Coverage for robustML/advertrain/training/fire_training.py: 63%
30 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
1import torch
3from robustML.advertrain.dependencies.fire import fire_loss
4from robustML.advertrain.training.classical_training import ClassicalTraining
class FIRETraining(ClassicalTraining):
    """
    Adversarial training loop driven by the FIRE loss (``fire_loss``).

    Extends ClassicalTraining: every training and validation batch is scored with
    ``fire_loss`` instead of a plain loss function, and metrics are recorded from
    the model's predictions on the clean (non-adversarial) batch.

    NOTE(review): the acronym was previously expanded as "Fast and Improved Robustness
    Estimation"; it is presumably the Fisher-Rao regularized (FIRE) loss of Picot et al.,
    "Adversarial Robustness via Fisher-Rao Regularization" — TODO confirm against
    ``robustML.advertrain.dependencies.fire``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        epsilon: float,
        beta: float,
        perturb_steps: int = 20
    ) -> None:
        """
        Initialize the FIRETraining class for adversarial and robust training of neural network models.

        Args:
            model (torch.nn.Module): The neural network model to be trained.
            optimizer (torch.optim.Optimizer): The optimizer used for training.
            device (torch.device): The device on which to perform calculations.
            epsilon (float): Perturbation size for adversarial example generation.
            beta (float): Weight for the robust loss in the overall loss calculation.
            perturb_steps (int, optional): Number of steps for adversarial example generation (Defaults to 20).
                Must be positive; the per-step size is ``epsilon / perturb_steps``.

        Raises:
            ValueError: If ``perturb_steps`` is not positive (it would otherwise cause a
                division by zero or a negative step size).
        """
        if perturb_steps <= 0:
            raise ValueError(f"perturb_steps must be a positive integer, got {perturb_steps!r}")

        # The third positional argument of ClassicalTraining is presumably its loss
        # function; None is passed because fire_loss computes the objective here.
        # TODO confirm against ClassicalTraining.__init__.
        super().__init__(model, optimizer, None, device)

        self.epsilon = epsilon
        self.beta = beta
        self.perturb_steps = perturb_steps
        # Evenly split the perturbation budget across the attack steps.
        self.step_size = epsilon / perturb_steps

    def train_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> tuple[float, int]:
        """
        Train the model for one batch of data.

        Moves the batch to ``self.device``, applies ``preprocess_batch`` and clamps the
        inputs to [0, 1], takes one optimizer step on the FIRE loss, then records metrics
        from a forward pass of the updated model on the (clean) inputs.

        Args:
            x (torch.Tensor): The input data.
            y (torch.Tensor): The labels corresponding to the input data.
            epoch (int): The current epoch number.

        Returns:
            tuple[float, int]: A tuple containing the loss value and the batch size.
        """
        x, y = x.to(self.device), y.to(self.device)
        x, y = self.preprocess_batch(x, y, epoch)
        # Preprocessing may push values outside [0, 1] (presumably the valid input range).
        x = x.clamp(0, 1)

        self.optimizer.zero_grad()
        # fire_loss returns the total loss plus three auxiliary values unused here.
        loss, _, _, _ = fire_loss(
            self.model,
            x,
            y,
            self.optimizer,
            epoch,
            self.device,
            step_size=self.step_size,
            epsilon=self.epsilon,
            perturb_steps=self.perturb_steps,
            beta=self.beta,
        )

        loss.backward()
        self.optimizer.step()

        # Predictions are only needed for metrics, so don't build an autograd graph for
        # this extra forward pass. The module's train/eval mode is untouched, so layers
        # such as BatchNorm behave exactly as without no_grad.
        with torch.no_grad():
            output = self.model(x)
        pred = torch.argmax(output, dim=1)

        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def val_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> tuple[float, int]:
        """
        Validate the model for one batch of data.

        Computes the FIRE loss and predictions on the clean inputs without updating the
        model, then records metrics. Unlike ``train_batch``, no preprocessing or clamping
        is applied.

        Args:
            x (torch.Tensor): The input data.
            y (torch.Tensor): The labels corresponding to the input data.
            epoch (int): The current epoch number.

        Returns:
            tuple[float, int]: A tuple containing the loss value and the batch size.
        """
        x, y = x.to(self.device), y.to(self.device)
        was_training = self.model.training

        with torch.no_grad():
            # NOTE(review): crafting adversarial examples needs input gradients; this
            # presumably works under no_grad because fire_loss re-enables grad
            # internally (TRADES-style torch.enable_grad()) — confirm.
            loss, _, _, _ = fire_loss(
                self.model,
                x,
                y,
                self.optimizer,
                epoch,
                self.device,
                step_size=self.step_size,
                epsilon=self.epsilon,
                perturb_steps=self.perturb_steps,
                beta=self.beta,
            )
            # fire_loss may switch the model's train/eval mode while attacking
            # (TRADES-style losses end with model.train()) — unverified here. Restore
            # the caller's mode so the clean forward pass below is not run in train
            # mode during validation; a no-op if fire_loss leaves the mode alone.
            self.model.train(was_training)
            output = self.model(x)
        pred = torch.argmax(output, dim=1)

        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)