# robustML/advertrain/dependencies/fire.py

1 """
2 Taken from https://github.com/MarinePICOT/Adversarial-Robustness-via-Fisher-Rao-Regularization
3  
4 Robust training losses. Based on code from
5 https://github.com/MarinePICOT/Adversarial-Robustness-via-Fisher-Rao-Regularization/blob/main/src/losses.py
6 """
7  
8 import numpy as np
9 import torch
10 import torch.nn.functional as F
11 from torch.autograd import Variable
12  
13  
def entropy_loss(unlabeled_logits: torch.Tensor) -> torch.Tensor:
    """
    Calculate the entropy loss for a batch of unlabeled data.

    Args:
        unlabeled_logits (torch.Tensor): A tensor of logits from a model's output,
            of shape (batch_size, num_classes).

    Returns:
        torch.Tensor: The mean entropy loss across the batch.
    """
    unlabeled_probs = F.softmax(unlabeled_logits, dim=1)
    return (
        -(unlabeled_probs * F.log_softmax(unlabeled_logits, dim=1))
        .sum(dim=1)
        .mean(dim=0)
    )
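
# Illustrative note (not part of the original module): entropy_loss is the mean
# Shannon entropy -sum_c p_c * log(p_c) over the batch, in nats. For uniform
# logits over C classes it returns log(C); a quick sanity check:
#
#     logits = torch.zeros(4, 10)  # batch of 4, uniform over 10 classes
#     entropy_loss(logits)         # ~= tensor(2.3026) == log(10)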


def fire_loss(
    model: torch.nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    device: torch.device,
    step_size: float = 0.003,
    epsilon: float = 0.001,
    perturb_steps: int = 10,
    beta: float = 1.0,
    adversarial: bool = True,
    distance: str = "Linf",
    entropy_weight: float = 0,
    pretrain: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the FIRE loss (Fisher-Rao regularization): a combination of the
    natural cross-entropy loss, the Fisher-Rao robust loss, and an entropy loss
    on unlabeled data. It is used for adversarial training and stability
    training of neural networks.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        x_natural (torch.Tensor): Input tensor of natural (non-adversarial) images.
        y (torch.Tensor): Tensor of labels. Unlabeled data should have label -1.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        epoch (int): Current training epoch.
        device (torch.device): The device on which to perform calculations.
        step_size (float): Step size for adversarial example generation.
        epsilon (float): Perturbation size for adversarial example generation.
        perturb_steps (int): Number of steps for adversarial example generation.
        beta (float): Weight of the robust loss in the overall loss.
        adversarial (bool): Flag to enable/disable adversarial training.
        distance (str): Distance metric for adversarial example generation
            ("Linf" or "L2").
        entropy_weight (float): Weight of the entropy loss in the overall loss.
        pretrain (int): Number of pretraining epochs.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple
            containing the total loss, natural loss, robust loss, and entropy
            loss for unlabeled data.
    """
    if beta == 0:
        logits = model(x_natural)
        loss = F.cross_entropy(logits, y)
        inf = torch.Tensor([np.inf])
        zero = torch.Tensor([0.0])
        return loss, loss, inf, zero

    is_unlabeled = y == -1
    if epoch < pretrain:
        logits = model(x_natural)
        loss = F.cross_entropy(logits, y)
        loss_natural = loss
        loss_robust = torch.Tensor([0.0])
        if torch.sum(is_unlabeled) > 0:
            logits_unlabeled = logits[is_unlabeled]
            loss_entropy_unlabeled = entropy_loss(logits_unlabeled)
            loss = loss + entropy_weight * loss_entropy_unlabeled
        else:
            loss_entropy_unlabeled = torch.tensor(0.0)

    else:
        model.eval()  # moving to eval mode to freeze batchnorm stats
        # generate adversarial example
        x_adv = x_natural.detach().clone()  # start from a copy of the clean batch
        s_nat = model(x_natural).softmax(1).detach()
        if adversarial:
            if distance == "Linf":
                x_adv += 0.001 * torch.randn(x_natural.shape).to(device)

                for _ in range(perturb_steps):
                    x_adv.requires_grad_()
                    with torch.enable_grad():
                        s_adv = model(x_adv).softmax(1)
                        sqroot_prod = ((s_nat * s_adv) ** 0.5).sum(1)
                        # Subtract eps to prevent gradient explosion near 1
                        # (https://github.com/pytorch/pytorch/issues/8069)
                        loss_kl = (torch.acos(sqroot_prod - 1e-7) ** 2).mean(0)
                    grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                    x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                    x_adv = torch.min(
                        torch.max(x_adv, x_natural - epsilon), x_natural + epsilon
                    )
                    x_adv = torch.clamp(x_adv, 0.0, 1.0)
            else:
                raise ValueError(
                    f"No support for distance {distance} in adversarial training"
                )
        else:
            if distance == "L2":
                x_adv = x_adv + epsilon * torch.randn_like(x_adv)
            else:
                raise ValueError(
                    f"No support for distance {distance} in stability training"
                )

        model.train()  # moving to train mode to update batchnorm stats

        # zero gradient
        optimizer.zero_grad()

        x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
        logits_nat = model(x_natural)
        logits_adv = model(x_adv)

        s_adv = logits_adv.softmax(1)
        s_nat = logits_nat.softmax(1)

        loss_natural = F.cross_entropy(logits_nat, y, ignore_index=-1)

        sqroot_prod = ((s_nat * s_adv) ** 0.5).sum(1)

        # Subtract eps to prevent gradient explosion near 1
        # (https://github.com/pytorch/pytorch/issues/8069)
        loss_robust = (torch.acos(sqroot_prod - 1e-7) ** 2).mean(0)

        loss = loss_natural + beta * loss_robust

        if torch.sum(is_unlabeled) > 0:
            # entropy on the clean logits of the unlabeled samples
            logits_unlabeled = logits_nat[is_unlabeled]
            loss_entropy_unlabeled = entropy_loss(logits_unlabeled)
            loss = loss + entropy_weight * loss_entropy_unlabeled
        else:
            loss_entropy_unlabeled = torch.tensor(0.0)

    return loss, loss_natural, loss_robust, loss_entropy_unlabeled
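
# Usage sketch (illustrative, not part of the original module). In its main
# branch fire_loss calls optimizer.zero_grad() itself, so a training step only
# needs backward() and step(); the hyperparameter values below are assumptions,
# not values prescribed by this module:
#
#     loss, loss_nat, loss_rob, loss_ent = fire_loss(
#         model, x_batch, y_batch, optimizer, epoch, device,
#         step_size=2 / 255, epsilon=8 / 255, perturb_steps=10, beta=6.0,
#     )
#     loss.backward()
#     optimizer.step()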


def noise_loss(
    model: torch.nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    epsilon: float = 0.25,
    clamp_x: bool = True
) -> torch.Tensor:
    """
    Augment the input data with random noise and compute the loss based on the
    model's predictions for the noisy data.

    Args:
        model (torch.nn.Module): The neural network model.
        x_natural (torch.Tensor): The original (clean) input data.
        y (torch.Tensor): The labels corresponding to the input data.
        epsilon (float, optional): The magnitude of the noise added to the
            input data. Defaults to 0.25.
        clamp_x (bool, optional): If True, the noisy data is clamped to the
            range [0.0, 1.0]. Defaults to True.

    Returns:
        torch.Tensor: The computed loss based on the model's predictions for
            the noisy data.
    """
    x_noise = x_natural + epsilon * torch.randn_like(x_natural)
    if clamp_x:
        x_noise = x_noise.clamp(0.0, 1.0)
    logits_noise = model(x_noise)
    loss = F.cross_entropy(logits_noise, y, ignore_index=-1)
    return loss
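

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module;
    # the model, shapes, and hyperparameters are arbitrary assumptions).
    torch.manual_seed(0)
    device = torch.device("cpu")
    model = torch.nn.Sequential(
        torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10)
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.rand(4, 3, 8, 8, device=device)
    y = torch.randint(0, 10, (4,), device=device)

    loss, loss_nat, loss_rob, _ = fire_loss(
        model, x, y, optimizer, epoch=0, device=device
    )
    loss.backward()
    optimizer.step()
    print(
        f"fire total: {loss.item():.4f}  natural: {loss_nat.item():.4f}  "
        f"robust: {loss_rob.item():.4f}  "
        f"noise: {noise_loss(model, x, y).item():.4f}  "
        f"uniform entropy: {entropy_loss(torch.zeros(4, 10)).item():.4f}"
    )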