Coverage for robustML/advertrain/training/adversarial_training.py: 81%

16 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1from typing import Tuple 

2 

3import numpy as np 

4import torch 

5from torch import Tensor 

6from torch.nn import Module 

7from torch.optim import Optimizer 

8 

9from robustML.advertrain.dependencies.cleverhans.projected_gradient_descent import \ 

10 projected_gradient_descent 

11from robustML.advertrain.training.classical_training import ClassicalTraining 

12 

13 

class AdversarialTraining(ClassicalTraining):
    """
    A training class that incorporates adversarial training using Projected Gradient Descent (PGD).

    This class extends ClassicalTraining by modifying the preprocessing of batches to include
    the generation of adversarial examples using PGD.

    Attributes:
        epsilon (float): The maximum perturbation allowed for adversarial examples.
        n_steps (int): The number of PGD iterations used to craft each adversarial batch.
    """

    # Class-level fallback so that instances created without running __init__
    # keep the historical behaviour of 20 PGD iterations.
    n_steps: int = 20

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        loss_func,
        device: torch.device,
        epsilon: float,
        n_steps: int = 20,
    ) -> None:
        """
        Initializes the AdversarialTraining class.

        Args:
            model (Module): The neural network model to be trained.
            optimizer (Optimizer): The optimizer for training the model.
            loss_func: The loss function to be used for training.
            device (torch.device): The device for training.
            epsilon (float): The maximum perturbation allowed for adversarial examples.
            n_steps (int): The number of PGD iterations. Defaults to 20, the value that
                was previously hard-coded in preprocess_batch.

        Raises:
            ValueError: If n_steps is smaller than 1.
        """
        # n_steps is used as a divisor for the per-iteration step size, so it
        # must be strictly positive.
        if n_steps < 1:
            raise ValueError(f"n_steps must be a positive integer, got {n_steps}")

        # Attack settings are stored before delegating to the parent
        # constructor, as in the original implementation.
        self.epsilon = epsilon
        self.n_steps = n_steps
        super().__init__(model, optimizer, loss_func, device)

    def preprocess_batch(
        self, x: Tensor, y: Tensor, epoch: int
    ) -> Tuple[Tensor, Tensor]:
        """
        Preprocesses a batch of data by generating adversarial examples using PGD.

        The attack is run with the L-infinity norm, a total budget of ``epsilon``,
        a per-iteration step of ``epsilon / n_steps`` and ``n_steps`` iterations.
        The adversarial inputs are clipped to the range [0, 1].

        Args:
            x (Tensor): The input data batch.
            y (Tensor): The ground truth labels batch. It is returned unchanged and
                is not passed to the attack.
            epoch (int): The current training epoch. Not used by this implementation.

        Returns:
            Tuple[Tensor, Tensor]: A tuple of adversarial examples and their corresponding labels.
        """
        n_steps = self.n_steps

        # NOTE(review): no labels are handed to the attack; presumably the
        # vendored PGD then falls back to the model's own predictions --
        # confirm against the vendored cleverhans implementation.
        adv_x = projected_gradient_descent(
            model_fn=self.model,
            x=x,
            eps=self.epsilon,
            eps_iter=self.epsilon / n_steps,
            nb_iter=n_steps,
            norm=np.inf,
            clip_min=0,
            clip_max=1,
            sanity_checks=False,
        )

        return adv_x, y