Coverage for robustML/advertrain/training/autoattack_training.py: 85%

13 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1from typing import Callable, Tuple 

2 

3import torch 

4from robustML.advertrain.dependencies.autoattack import APGDAttack 

5from robustML.advertrain.training.classical_training import ClassicalTraining 

6 

7 

class AutoAttackTraining(ClassicalTraining):
    """
    Adversarial variant of ClassicalTraining driven by an APGDAttack instance.

    Every batch handed to ``preprocess_batch`` is replaced by the output of
    the attack's ``perturb`` call; the labels are passed through untouched.

    Attributes:
        epsilon (float): Value forwarded to APGDAttack as ``eps``.
        apgd_loss (str): Value forwarded to APGDAttack as ``loss``.
        apgd (APGDAttack): Attack object built once in ``__init__`` and reused
            for every batch.

    Methods:
        preprocess_batch(x, y, epoch): Swap the inputs of a batch for their
            APGD-perturbed counterparts.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_func: Callable,
        device: torch.device,
        loss: str,
        epsilon: float
    ):
        """
        Set up the parent trainer and build the APGD attack.

        Args:
            model (torch.nn.Module): Network to train; forwarded to the parent class.
            optimizer (torch.optim.Optimizer): Optimizer forwarded to the parent class.
            loss_func (Callable): Training loss forwarded to the parent class.
            device (torch.device): Compute device forwarded to the parent class.
            loss (str): Name of the loss handed to APGDAttack (stored as ``apgd_loss``).
            epsilon (float): Perturbation budget handed to APGDAttack as ``eps``.
        """
        super().__init__(model, optimizer, loss_func, device)

        self.epsilon = epsilon
        self.apgd_loss = loss

        # The attack is configured from the instance attributes: ``self.model``
        # and ``self.device`` are not assigned here, so they are expected to
        # have been set by the parent initializer.
        self.apgd = APGDAttack(
            self.model,
            eps=self.epsilon,
            loss=self.apgd_loss,
            device=self.device,
        )

    def preprocess_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Replace the inputs of a batch with APGD adversarial examples.

        Args:
            x (torch.Tensor): Batch inputs given to the attack.
            y (torch.Tensor): Labels given to the attack and returned unchanged.
            epoch (int): Current epoch number; not used by this implementation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The result of ``self.apgd.perturb(x, y)``
            paired with the original ``y``.
        """
        return self.apgd.perturb(x, y), y

61 return adv_x, y