Coverage for robustAI/advertrain/training/autoattack_training.py: 85%

13 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1""" 

2This module contains the AutoAttackTraining class, which extends ClassicalTraining with adversarial training using the AutoPGD (APGD) attack.

3""" 

4from typing import Callable, Tuple 

5 

6import torch 

7from robustAI.advertrain.dependencies.autoattack import APGDAttack 

8from robustAI.advertrain.training.classical_training import ClassicalTraining 

9 

10 

class AutoAttackTraining(ClassicalTraining):
    """
    ClassicalTraining variant that trains on AutoPGD (APGD) adversarial examples.

    Every batch handed to ``preprocess_batch`` is replaced by the output of
    ``APGDAttack.perturb`` before it is returned; the labels are passed through
    unchanged.

    Attributes:
        epsilon (float): The maximum perturbation amount allowed for the APGD
            attack (forwarded to ``APGDAttack`` as ``eps``).
        apgd_loss (str): The loss function to be used in the APGD attack
            (forwarded to ``APGDAttack`` as ``loss``).
        apgd (APGDAttack): Attack instance used to generate adversarial examples.

    Methods:
        preprocess_batch(x, y, epoch): Returns the adversarial version of a batch.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_func: Callable,
        device: torch.device,
        loss: str,
        epsilon: float,
    ) -> None:
        """
        Set up the base trainer and build the APGD attack bound to the model.

        Args:
            model (torch.nn.Module): The neural network model to be trained.
            optimizer (torch.optim.Optimizer): The optimizer used for training.
            loss_func (Callable): The loss function used for training.
            device (torch.device): The device on which to perform computations.
            loss (str): The type of loss function to use in the APGD attack.
            epsilon (float): The maximum perturbation amount allowed for the
                APGD attack.
        """
        super().__init__(model, optimizer, loss_func, device)

        # Keep the attack settings on the instance so they stay inspectable.
        self.epsilon = epsilon
        self.apgd_loss = loss

        # NOTE(review): self.model and self.device are not assigned in this
        # class; presumably ClassicalTraining.__init__ sets them - confirm.
        self.apgd = APGDAttack(
            self.model,
            eps=self.epsilon,
            loss=self.apgd_loss,
            device=self.device,
        )

    def preprocess_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Replace the inputs of a batch with APGD adversarial examples.

        Args:
            x (torch.Tensor): Input data (images).
            y (torch.Tensor): Corresponding labels.
            epoch (int): The current epoch number. Not used by this
                implementation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The adversarial examples produced
                by ``self.apgd.perturb(x, y)`` and the labels ``y`` as received.
        """
        return self.apgd.perturb(x, y), y