Coverage for robustAI/advertrain/dependencies/fire.py: 70% (69 statements), coverage.py v7.9.2, created at 2025-10-01 08:42 +0000


1""" 

2Taken from https://github.com/MarinePICOT/Adversarial-Robustness-via-Fisher-Rao-Regularization 

3 

4Robust training losses. Based on code from 

5https://github.com/MarinePICOT/Adversarial-Robustness-via-Fisher-Rao-Regularization/blob/main/src/losses.py 

6""" 

7 

8import numpy as np 

9import torch 

10import torch.nn.functional as F 

11from torch.autograd import Variable 

12 

13 



def entropy_loss(unlabeled_logits: torch.Tensor) -> torch.Tensor:
    """
    Calculate the entropy loss for a batch of unlabeled data.

    Args:
        unlabeled_logits (torch.Tensor): A tensor of logits from a model's output.
            It should have a shape of (batch_size, num_classes).

    Returns:
        torch.Tensor: The mean entropy loss across the batch.
    """
    unlabeled_probs = F.softmax(unlabeled_logits, dim=1)
    return (
        -(unlabeled_probs * F.log_softmax(unlabeled_logits, dim=1))
        .sum(dim=1)
        .mean(dim=0)
    )
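

# Usage sketch for entropy_loss (illustrative only; the shapes and values are
# assumptions, not part of the original module):
#
#     logits = torch.randn(8, 10)  # hypothetical batch of 8 samples, 10 classes
#     h = entropy_loss(logits)     # scalar: mean per-sample entropy
#
# The result is at most log(num_classes) (uniform predictions) and 0 for
# one-hot predictions, so adding it to the loss pushes the model toward
# confident outputs on unlabeled data.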



def fire_loss(
    model: torch.nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    device: torch.device,
    step_size: float = 0.003,
    epsilon: float = 0.001,
    perturb_steps: int = 10,
    beta: float = 1.0,
    adversarial: bool = True,
    distance: str = "Linf",
    entropy_weight: float = 0.0,
    pretrain: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the FIRE (FIsher-Rao REgularizer) loss: a combination of the
    natural loss, the robust loss, and an entropy loss on unlabeled data.
    It is used for adversarial training and stability training of neural
    networks.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        x_natural (torch.Tensor): Input tensor of natural (non-adversarial) images.
        y (torch.Tensor): Tensor of labels. Unlabeled data should have label -1.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        epoch (int): Current training epoch.
        device (torch.device): The device on which to perform calculations.
        step_size (float): Step size for adversarial example generation.
        epsilon (float): Perturbation size for adversarial example generation.
        perturb_steps (int): Number of steps for adversarial example generation.
        beta (float): Weight of the robust loss in the overall loss.
        adversarial (bool): Flag to enable/disable adversarial training.
        distance (str): Distance metric for adversarial example generation
            ("Linf" or "L2").
        entropy_weight (float): Weight of the entropy loss in the overall loss.
        pretrain (int): Number of pretraining epochs.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple
            containing the total loss, natural loss, robust loss, and entropy
            loss for unlabeled data.
    """


    if beta == 0:
        logits = model(x_natural)
        loss = F.cross_entropy(logits, y, ignore_index=-1)
        inf = torch.tensor([np.inf], device=device)
        zero = torch.tensor([0.0], device=device)
        return loss, loss, inf, zero

    is_unlabeled = y == -1
    if epoch < pretrain:
        logits = model(x_natural)
        # ignore_index=-1 skips unlabeled samples in the cross-entropy
        loss = F.cross_entropy(logits, y, ignore_index=-1)
        loss_natural = loss
        loss_robust = torch.tensor([0.0], device=device)
        if torch.sum(is_unlabeled) > 0:
            logits_unlabeled = logits[is_unlabeled]
            loss_entropy_unlabeled = entropy_loss(logits_unlabeled)
            loss = loss + entropy_weight * loss_entropy_unlabeled
        else:
            loss_entropy_unlabeled = torch.tensor(0.0, device=device)


    else:
        model.eval()  # eval mode freezes the batchnorm statistics
        # generate the adversarial example
        x_adv = x_natural.detach() + 0.0  # "+ 0.0" copies the tensor
        s_nat = model(x_natural).softmax(1).detach()
        if adversarial:
            if distance == "Linf":
                x_adv += 0.001 * torch.randn(x_natural.shape).to(device)

                for _ in range(perturb_steps):
                    x_adv.requires_grad_()
                    with torch.enable_grad():
                        s_adv = model(x_adv).softmax(1)
                        sqroot_prod = ((s_nat * s_adv) ** 0.5).sum(1)
                        # Subtract eps so the gradient of acos does not explode
                        # near 1 (https://github.com/pytorch/pytorch/issues/8069)
                        loss_kl = (torch.acos(sqroot_prod - 1e-7) ** 2).mean(0)

                    grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                    x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                    x_adv = torch.min(
                        torch.max(x_adv, x_natural - epsilon), x_natural + epsilon
                    )
                    x_adv = torch.clamp(x_adv, 0.0, 1.0)
            else:
                raise ValueError(f"No support for distance {distance} in adversarial training")
        else:
            if distance == "L2":
                x_adv = x_adv + epsilon * torch.randn_like(x_adv)
            else:
                raise ValueError(f"No support for distance {distance} in stability training")
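
        # At this point x_adv holds the perturbed inputs: a PGD-style iterate
        # projected onto the L-infinity ball of radius epsilon around x_natural
        # (adversarial=True), or a Gaussian-noised copy of the input
        # (adversarial=False).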

        model.train()  # back to train mode so batchnorm statistics are updated

        # zero gradient
        optimizer.zero_grad()

        x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
        logits_nat = model(x_natural)
        logits_adv = model(x_adv)

        s_adv = logits_adv.softmax(1)
        s_nat = logits_nat.softmax(1)

        loss_natural = F.cross_entropy(logits_nat, y, ignore_index=-1)

        sqroot_prod = ((s_nat * s_adv) ** 0.5).sum(1)
        # Subtract eps so the gradient of acos does not explode near 1
        # (https://github.com/pytorch/pytorch/issues/8069)
        loss_robust = (torch.acos(sqroot_prod - 1e-7) ** 2).mean(0)

        loss = loss_natural + beta * loss_robust

        if torch.sum(is_unlabeled) > 0:
            logits_unlabeled = logits_nat[is_unlabeled]
            loss_entropy_unlabeled = entropy_loss(logits_unlabeled)
            loss = loss + entropy_weight * loss_entropy_unlabeled
        else:
            loss_entropy_unlabeled = torch.tensor(0.0, device=device)

    return loss, loss_natural, loss_robust, loss_entropy_unlabeled
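

# Usage sketch for fire_loss inside a training step. Everything below is an
# illustrative assumption (model, optimizer, device, and the batch come from
# the caller; step_size and epsilon are common L-infinity values, not values
# prescribed by this module):
#
#     x, y = x.to(device), y.to(device)
#     optimizer.zero_grad()
#     loss, loss_nat, loss_rob, loss_ent = fire_loss(
#         model, x, y, optimizer, epoch, device,
#         step_size=2.0 / 255, epsilon=8.0 / 255, perturb_steps=10, beta=1.0,
#     )
#     loss.backward()
#     optimizer.step()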




def noise_loss(
    model: torch.nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    epsilon: float = 0.25,
    clamp_x: bool = True,
) -> torch.Tensor:
    """
    Augment the input data with random noise and compute the loss based on
    the model's predictions for the noisy data.

    Args:
        model (torch.nn.Module): The neural network model.
        x_natural (torch.Tensor): The original (clean) input data.
        y (torch.Tensor): The labels corresponding to the input data.
        epsilon (float, optional): The magnitude of the noise added to the
            input data. Defaults to 0.25.
        clamp_x (bool, optional): If True, the noisy data is clamped to the
            range [0.0, 1.0]. Defaults to True.

    Returns:
        torch.Tensor: The computed loss based on the model's predictions for
            the noisy data.
    """
    x_noise = x_natural + epsilon * torch.randn_like(x_natural)
    if clamp_x:
        x_noise = x_noise.clamp(0.0, 1.0)
    logits_noise = model(x_noise)
    loss = F.cross_entropy(logits_noise, y, ignore_index=-1)
    return loss
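

# Usage sketch for noise_loss (illustrative only): it can serve as a cheap
# stability term next to the clean cross-entropy. The 0.5 weight below is an
# assumption for illustration, not a value taken from this module.
#
#     loss = (
#         F.cross_entropy(model(x), y, ignore_index=-1)
#         + 0.5 * noise_loss(model, x, y, epsilon=0.25)
#     )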