Source code for robustML.advertrain.dependencies.cleverhans.projected_gradient_descent

"""
Taken from https://github.com/rwightman/pytorch-image-models

The Projected Gradient Descent attack.

MIT License
"""
from typing import Optional

import numpy as np
import torch

from robustML.advertrain.dependencies.cleverhans.fast_gradient_method import fast_gradient_method
from robustML.advertrain.dependencies.cleverhans.utils import clip_eta


[docs] def projected_gradient_descent( model_fn, x: torch.Tensor, eps: float, eps_iter: float, nb_iter: int, norm: int, clip_min: Optional[float] = None, clip_max: Optional[float] = None, y: Optional[torch.Tensor] = None, targeted: bool = False, rand_init: bool = True, rand_minmax: Optional[float] = None, sanity_checks: bool = True, ) -> torch.Tensor: """ Performs the Projected Gradient Descent attack. Args: model_fn: A callable that takes an input tensor and returns the model logits. x (torch.Tensor): Input tensor. eps (float): Epsilon, the input variation parameter. eps_iter (float): Step size for each attack iteration. nb_iter (int): Number of attack iterations. norm (int): Order of the norm (np.inf, 1, or 2). clip_min (float, optional): Mininum value per input dimension. clip_max (float, optional): Maximum value per input dimension. y (torch.Tensor, optional): Labels or target labels for targeted attack. targeted (bool): Whether to perform a targeted attack or not. rand_init (bool): Whether to start from a randomly perturbed input. rand_minmax (float, optional): Range of the uniform distribution for initial random perturbation. sanity_checks (bool): If True, include sanity checks. Returns: torch.Tensor: A tensor containing the adversarial examples. """ if norm == 1: raise NotImplementedError( "It's not clear that FGM is a good inner loop" " step for PGD when norm=1, because norm=1 FGM " " changes only one pixel at a time. We need " " to rigorously test a strong norm=1 PGD " "before enabling this feature." ) if norm not in [np.inf, 2]: raise ValueError("Norm order must be either np.inf or 2.") if eps < 0: raise ValueError(f"eps must be non-negative, got {eps}") if eps_iter < 0 or eps_iter > eps: raise ValueError(f"eps_iter must be in the range [0, {eps}], got {eps_iter}") if clip_min is not None and clip_max is not None and clip_min > clip_max: raise ValueError(f"clip_min must be less than or equal to clip_max, got clip_min={clip_min}, clip_max={clip_max}") if sanity_checks: assert x.min() >= clip_min if clip_min is not None else True assert x.max() <= clip_max if clip_max is not None else True eta = torch.zeros_like(x).uniform_(-rand_minmax if rand_minmax else eps, rand_minmax if rand_minmax else eps) if rand_init else torch.zeros_like(x) # Clip eta and prepare adv_x eta = clip_eta(eta, norm, eps) adv_x = torch.clamp(x + eta, clip_min, clip_max) if clip_min is not None or clip_max is not None else x + eta y = torch.argmax(model_fn(x), dim=1) if y is None else y for _ in range(nb_iter): adv_x = fast_gradient_method(model_fn, adv_x, eps_iter, norm, clip_min, clip_max, y, targeted) eta = clip_eta(adv_x - x, norm, eps) adv_x = x + eta adv_x = torch.clamp(adv_x, clip_min, clip_max) if clip_min is not None or clip_max is not None else adv_x return adv_x