Coverage for robustAI/advertrain/training/autoattack_training.py: 85%
13 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-01 08:42 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-01 08:42 +0000
1"""
2This module contains the AutoAttackTraining class, which extends ClassicalTraining with adversarial training using the AutoPGD (APGD) attack.
3"""
4from typing import Callable, Tuple
6import torch
7from robustAI.advertrain.dependencies.autoattack import APGDAttack
8from robustAI.advertrain.training.classical_training import ClassicalTraining
class AutoAttackTraining(ClassicalTraining):
    """
    ClassicalTraining variant whose batches are replaced by AutoPGD adversarial examples.

    Attributes:
        epsilon (float): Perturbation budget handed to the APGD attack as ``eps``.
        apgd_loss (str): Name of the loss handed to the APGD attack as ``loss``.
        apgd (APGDAttack): Attack object built once in ``__init__`` and reused for every batch.

    Methods:
        preprocess_batch(x, y, epoch): Swap the inputs of a batch for their adversarial counterparts.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_func: Callable,
        device: torch.device,
        loss: str,
        epsilon: float
    ):
        """
        Set up the underlying ClassicalTraining and build the APGD attack.

        Args:
            model (torch.nn.Module): Network to train; also the target of the attack.
            optimizer (torch.optim.Optimizer): Optimizer forwarded to ClassicalTraining.
            loss_func (Callable): Training loss forwarded to ClassicalTraining.
            device (torch.device): Device forwarded to ClassicalTraining.
            loss (str): Loss identifier used by the APGD attack (stored as ``apgd_loss``).
            epsilon (float): Maximum perturbation amount allowed for the APGD attack.
        """
        super().__init__(model, optimizer, loss_func, device)

        self.epsilon = epsilon
        self.apgd_loss = loss

        # The attack reads model/device from the instance (as set up by the
        # parent initializer) rather than from the raw constructor arguments.
        attack_options = {
            "eps": self.epsilon,
            "loss": self.apgd_loss,
            "device": self.device,
        }
        self.apgd = APGDAttack(self.model, **attack_options)

    def preprocess_batch(self, x: torch.Tensor, y: torch.Tensor, epoch: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Replace the inputs of a batch with APGD adversarial examples.

        Args:
            x (torch.Tensor): Input data (images).
            y (torch.Tensor): Labels associated with ``x``.
            epoch (int): Current epoch number; not used by this implementation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The result of ``self.apgd.perturb(x, y)`` paired with
                the labels ``y``, which are returned as received.
        """
        return self.apgd.perturb(x, y), y