1 from typing import Callable, Tuple
2
3 import torch
4 from robustML.advertrain.dependencies.autoattack import APGDAttack
5 from robustML.advertrain.training.classical_training import ClassicalTraining
6
7
class AutoAttackTraining(ClassicalTraining):
    """
    Extends ClassicalTraining to include adversarial training using AutoPGD attacks.

    Every training batch passed through ``preprocess_batch`` is replaced by
    adversarial examples produced by a single APGDAttack instance that is
    built once in ``__init__``.

    Attributes:
        epsilon (float): The maximum perturbation amount allowed for the APGD attack.
        apgd_loss (str): The loss function to be used in the APGD attack.
        apgd (APGDAttack): Instance of APGDAttack for generating adversarial examples.

    Methods:
        preprocess_batch(x, y, epoch): Processes each batch by generating adversarial examples.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_func: Callable,
        device: torch.device,
        loss: str,
        epsilon: float
    ):
        """
        Initializes the AutoAttackTraining object with the given parameters.

        Args:
            model (torch.nn.Module): The neural network model to be trained.
            optimizer (torch.optim.Optimizer): The optimizer used for training.
            loss_func (Callable): The loss function used for training.
            device (torch.device): The device on which to perform computations.
            loss (str): Name of the loss the APGD attack uses; passed through
                unchanged as ``APGDAttack(loss=...)``. NOTE(review): the accepted
                values are defined by the vendored APGDAttack (the reference
                AutoAttack implementation uses e.g. "ce" / "dlr") — confirm
                against that copy.
            epsilon (float): The maximum perturbation amount allowed for the APGD attack.
        """
        super().__init__(model, optimizer, loss_func, device)

        self.epsilon = epsilon
        self.apgd_loss = loss
        # self.model and self.device are not assigned in this class; they are
        # expected to be set by ClassicalTraining.__init__ above, which is why
        # the attack is only constructed after the super() call.
        self.apgd = APGDAttack(
            self.model, eps=self.epsilon, loss=self.apgd_loss, device=self.device
        )

    def preprocess_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Processes each batch by generating adversarial examples.

        The clean inputs ``x`` are replaced by ``self.apgd.perturb(x, y)``; the
        labels ``y`` are returned as-is.

        Args:
            x (torch.Tensor): Input data (images).
            y (torch.Tensor): Corresponding labels.
            epoch (int): The current epoch number. Not used by this
                implementation; presumably kept so the signature matches the
                ClassicalTraining.preprocess_batch hook — verify against the base class.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the adversarial
                examples and their corresponding (unchanged) labels.

        NOTE(review): this method does not switch the model between train and
        eval mode, so the attack runs in whatever mode the caller left the
        model in. Whether the vendored APGDAttack changes the mode is not
        visible here. Confirm this is intended (e.g. BatchNorm running
        statistics during the attack iterations).
        """
        adv_x = self.apgd.perturb(x, y)
        return adv_x, y