"""Source code for advertrain.training.autoattack_training: adversarial training with AutoPGD (APGD) attacks."""

from typing import Callable, Tuple

import torch
from robustML.advertrain.dependencies.autoattack import APGDAttack
from robustML.advertrain.training.classical_training import ClassicalTraining


class AutoAttackTraining(ClassicalTraining):
    """Training loop that swaps every batch for AutoPGD adversarial examples.

    Builds on ``ClassicalTraining``: before a batch is used, its inputs are
    perturbed with an ``APGDAttack`` instance, and the perturbed inputs are
    handed on together with the original labels.

    Attributes:
        epsilon (float): Perturbation budget, passed to the attack as ``eps``.
        apgd_loss (str): Attack objective name, passed to the attack as ``loss``.
        apgd (APGDAttack): Attack object created once in ``__init__`` and
            reused by ``preprocess_batch`` for every batch.

    Methods:
        preprocess_batch(x, y, epoch): Replace a clean batch with its
            adversarial counterpart.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_func: Callable,
        device: torch.device,
        loss: str,
        epsilon: float,
    ) -> None:
        """Initialise the base trainer and build the APGD attack.

        Args:
            model (torch.nn.Module): Network being trained; also attacked.
            optimizer (torch.optim.Optimizer): Forwarded to ``ClassicalTraining``.
            loss_func (Callable): Training loss, forwarded to ``ClassicalTraining``.
            device (torch.device): Forwarded to ``ClassicalTraining``; the attack
                is given ``self.device``.
            loss (str): Attack objective name, passed to ``APGDAttack`` as ``loss``.
            epsilon (float): Maximum perturbation, passed to ``APGDAttack`` as ``eps``.
        """
        super().__init__(model, optimizer, loss_func, device)

        self.apgd_loss = loss
        self.epsilon = epsilon

        # The attack reads self.model / self.device, which are presumably
        # populated by ClassicalTraining.__init__ — confirm in the base class.
        attack_options = {
            "eps": self.epsilon,
            "loss": self.apgd_loss,
            "device": self.device,
        }
        self.apgd = APGDAttack(self.model, **attack_options)

    def preprocess_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return APGD adversarial examples in place of the clean batch.

        Args:
            x (torch.Tensor): Clean input batch.
            y (torch.Tensor): Labels for ``x``; given to the attack and
                returned unchanged.
            epoch (int): Current epoch number; not used by this implementation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(adversarial_x, y)``.
        """
        # Only the inputs are attacked; the labels pass straight through.
        return self.apgd.perturb(x, y), y