Coverage for robustML/advertrain/training/fire_training.py: 63%

30 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import torch 

2 

3from robustML.advertrain.dependencies.fire import fire_loss 

4from robustML.advertrain.training.classical_training import ClassicalTraining 

5 

6 

class FIRETraining(ClassicalTraining):
    """Training loop whose per-batch loss is computed by ``fire_loss``.

    Both training and validation score each batch with ``fire_loss`` using the
    ``epsilon`` / ``beta`` / ``perturb_steps`` / ``step_size`` hyper-parameters
    stored on the instance. Predictions fed to ``self.metrics`` come from a plain
    forward pass of ``self.model`` on the (clean) batch inputs.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        epsilon: float,
        beta: float,
        perturb_steps: int = 20
    ):
        """
        Initialize the FIRETraining class for adversarial and robust training of neural network models.

        This class extends ClassicalTraining by computing every batch loss with the FIRE
        adversarial-training loss (``fire_loss``) instead of a standalone loss function.

        Args:
            model (torch.nn.Module): The neural network model to be trained.
            optimizer (torch.optim.Optimizer): The optimizer used for training.
            device (torch.device): The device on which to perform calculations.
            epsilon (float): Perturbation size for adversarial example generation.
            beta (float): Weight for the robust loss in the overall loss calculation.
            perturb_steps (int, optional): Number of steps for adversarial example generation (Defaults to 20).
                The per-step size is derived as ``epsilon / perturb_steps``.

        Raises:
            ValueError: If ``perturb_steps`` is not a positive integer (the per-step size
                ``epsilon / perturb_steps`` would be undefined or negative).
        """
        if perturb_steps <= 0:
            raise ValueError(
                f"perturb_steps must be a positive integer, got {perturb_steps}"
            )

        # The third positional argument is passed as None: presumably ClassicalTraining's
        # loss-function slot, unused here because fire_loss supplies the loss — confirm.
        super().__init__(model, optimizer, None, device)

        self.epsilon = epsilon
        self.beta = beta
        self.perturb_steps = perturb_steps
        # Spread the total perturbation budget evenly over the attack steps.
        self.step_size = epsilon / perturb_steps

    def _fire_loss(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> torch.Tensor:
        """Return the FIRE loss for one batch using this trainer's hyper-parameters.

        ``fire_loss`` returns a 4-tuple; only the first element (the total loss) is used here.
        """
        loss, _, _, _ = fire_loss(
            self.model,
            x,
            y,
            self.optimizer,
            epoch,
            self.device,
            step_size=self.step_size,
            epsilon=self.epsilon,
            perturb_steps=self.perturb_steps,
            beta=self.beta,
        )
        return loss

    def _record_predictions(self, x: torch.Tensor, y: torch.Tensor, loss: torch.Tensor) -> None:
        """Predict classes for ``x`` and feed them, with ``loss``, to ``self.metrics``."""
        # The forward pass only produces predictions for metrics, so no autograd graph is needed.
        with torch.no_grad():
            output = self.model(x)
        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

    def train_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> tuple[float, int]:
        """
        Train the model for one batch of data.

        Args:
            x (torch.Tensor): The input data.
            y (torch.Tensor): The labels corresponding to the input data.
            epoch (int): The current epoch number.

        Returns:
            tuple[float, int]: A tuple containing the loss value and the batch size.
        """
        x, y = x.to(self.device), y.to(self.device)
        x, y = self.preprocess_batch(x, y, epoch)
        # Keep inputs inside [0, 1] after preprocessing.
        x = x.clamp(0, 1)

        self.optimizer.zero_grad()
        loss = self._fire_loss(x, y, epoch)
        loss.backward()
        self.optimizer.step()

        # Metrics use predictions from the freshly updated weights.
        self._record_predictions(x, y, loss)

        return loss.item(), len(x)

    def val_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> tuple[float, int]:
        """
        Validate the model for one batch of data.

        Args:
            x (torch.Tensor): The input data.
            y (torch.Tensor): The labels corresponding to the input data.
            epoch (int): The current epoch number.

        Returns:
            tuple[float, int]: A tuple containing the loss value and the batch size.
        """
        x, y = x.to(self.device), y.to(self.device)

        with torch.no_grad():
            # NOTE(review): fire_loss generates adversarial examples and receives the optimizer;
            # if its perturbation loop relies on autograd without re-enabling it (torch.enable_grad),
            # it will fail under no_grad — confirm against fire_loss's implementation.
            loss = self._fire_loss(x, y, epoch)
            self._record_predictions(x, y, loss)

        return loss.item(), len(x)