Coverage for robustML/advertrain/training/adversarial_training.py: 81%
16 statements
from typing import Callable, Tuple

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from robustML.advertrain.dependencies.cleverhans.projected_gradient_descent import \
    projected_gradient_descent
from robustML.advertrain.training.classical_training import ClassicalTraining


class AdversarialTraining(ClassicalTraining):
    """
    A training class that incorporates adversarial training using Projected Gradient Descent (PGD).

    This class extends ClassicalTraining by overriding batch preprocessing so that each
    clean batch is replaced with adversarial examples generated by PGD.

    Attributes:
        epsilon (float): The maximum perturbation allowed for adversarial examples.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        loss_func: Callable,
        device: torch.device,
        epsilon: float,
    ) -> None:
        """
        Initializes the AdversarialTraining class.

        Args:
            model (Module): The neural network model to be trained.
            optimizer (Optimizer): The optimizer for training the model.
            loss_func (Callable): The loss function used for training.
            device (torch.device): The device on which training runs.
            epsilon (float): The maximum perturbation allowed for adversarial examples.
        """
        self.epsilon = epsilon
        super().__init__(model, optimizer, loss_func, device)

    def preprocess_batch(
        self, x: Tensor, y: Tensor, epoch: int
    ) -> Tuple[Tensor, Tensor]:
        """
        Preprocesses a batch of data by generating adversarial examples using PGD.

        Args:
            x (Tensor): The input data batch.
            y (Tensor): The ground truth labels batch.
            epoch (int): The current training epoch (unused here; kept for
                interface compatibility with ClassicalTraining).

        Returns:
            Tuple[Tensor, Tensor]: A tuple of adversarial examples and their corresponding labels.
        """
        # Fixed iteration count; a step size of epsilon / n_steps keeps the
        # cumulative step budget equal to the perturbation bound epsilon.
        n_steps = 20

        adv_x = projected_gradient_descent(
            model_fn=self.model,
            x=x,
            eps=self.epsilon,
            eps_iter=self.epsilon / n_steps,
            nb_iter=n_steps,
            norm=np.inf,  # L-infinity threat model
            clip_min=0,  # keep adversarial inputs in the valid [0, 1] range
            clip_max=1,
            sanity_checks=False,
        )

        return adv_x, y
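

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal wiring of AdversarialTraining,
# assuming only the constructor and preprocess_batch signatures defined above.
# The toy model, optimizer, loss, and hyperparameters below are hypothetical
# stand-ins, not code from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch import nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Hypothetical stand-in classifier for 1x28x28 inputs in [0, 1].
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_func = nn.CrossEntropyLoss()

    trainer = AdversarialTraining(model, optimizer, loss_func, device, epsilon=8 / 255)

    # Each batch is perturbed on the fly before the usual training step.
    x = torch.rand(4, 1, 28, 28, device=device)
    y = torch.randint(0, 10, (4,), device=device)
    adv_x, adv_y = trainer.preprocess_batch(x, y, epoch=0)
    print(adv_x.shape, (adv_x - x).abs().max())  # perturbation bounded by epsilon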