⬅ robustML/advertrain/dependencies/cleverhans/projected_gradient_descent.py source

1 """
2 Taken from https://github.com/rwightman/pytorch-image-models
3  
4 The Projected Gradient Descent attack.
5  
6 MIT License
7 """
8 from typing import Optional
9  
10 import numpy as np
11 import torch
12  
13 from robustML.advertrain.dependencies.cleverhans.fast_gradient_method import fast_gradient_method
14 from robustML.advertrain.dependencies.cleverhans.utils import clip_eta
15  
16  
def _clip_to_bounds(t: torch.Tensor, clip_min: Optional[float], clip_max: Optional[float]) -> torch.Tensor:
    """Clamp ``t`` to ``[clip_min, clip_max]``; return it unchanged when both bounds are None."""
    # torch.clamp requires at least one bound, so skip the call when neither is given.
    if clip_min is None and clip_max is None:
        return t
    return torch.clamp(t, clip_min, clip_max)


def projected_gradient_descent(
    model_fn,
    x: torch.Tensor,
    eps: float,
    eps_iter: float,
    nb_iter: int,
    norm: float,
    clip_min: Optional[float] = None,
    clip_max: Optional[float] = None,
    y: Optional[torch.Tensor] = None,
    targeted: bool = False,
    rand_init: bool = True,
    rand_minmax: Optional[float] = None,
    sanity_checks: bool = True,
) -> torch.Tensor:
    """
    Performs the Projected Gradient Descent attack.

    Starting from ``x`` (optionally plus a uniform random perturbation), runs ``nb_iter``
    fast-gradient-method steps of size ``eps_iter``. After every step the accumulated
    perturbation is passed through ``clip_eta`` with radius ``eps``, and the result is
    clamped to ``[clip_min, clip_max]`` when at least one bound is given.

    Args:
        model_fn: A callable that takes an input tensor and returns the model logits.
        x (torch.Tensor): Input tensor.
        eps (float): Epsilon, the input variation parameter; must be non-negative.
        eps_iter (float): Step size for each attack iteration; must lie in ``[0, eps]``.
        nb_iter (int): Number of attack iterations.
        norm (float): Order of the norm, ``np.inf`` or ``2``. ``1`` is rejected.
        clip_min (float, optional): Minimum value per input dimension.
        clip_max (float, optional): Maximum value per input dimension.
        y (torch.Tensor, optional): Labels, or target labels for a targeted attack. If None,
            ``argmax`` of ``model_fn(x)`` along dim 1 is used.
        targeted (bool): Whether to perform a targeted attack or not.
        rand_init (bool): Whether to start from a randomly perturbed input.
        rand_minmax (float, optional): Half-width of the uniform distribution used for the
            initial random perturbation. Defaults to ``eps`` when None; an explicit ``0.0``
            is honoured rather than replaced by ``eps``.
        sanity_checks (bool): If True, assert that ``x`` lies within ``[clip_min, clip_max]``.

    Returns:
        torch.Tensor: A tensor containing the adversarial examples.

    Raises:
        NotImplementedError: If ``norm == 1``.
        ValueError: If ``norm`` is not ``np.inf``/``2``, ``eps < 0``, ``eps_iter`` is outside
            ``[0, eps]``, or ``clip_min > clip_max``.
        AssertionError: If ``sanity_checks`` is True and ``x`` violates the clip bounds.
    """
    if norm == 1:
        raise NotImplementedError(
            "It's not clear that FGM is a good inner loop"
            " step for PGD when norm=1, because norm=1 FGM "
            " changes only one pixel at a time. We need "
            " to rigorously test a strong norm=1 PGD "
            "before enabling this feature."
        )
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if eps_iter < 0 or eps_iter > eps:
        raise ValueError(f"eps_iter must be in the range [0, {eps}], got {eps_iter}")
    if clip_min is not None and clip_max is not None and clip_min > clip_max:
        raise ValueError(
            f"clip_min must be less than or equal to clip_max, got clip_min={clip_min}, clip_max={clip_max}"
        )

    if sanity_checks:
        if clip_min is not None:
            assert x.min() >= clip_min
        if clip_max is not None:
            assert x.max() <= clip_max

    # Initial perturbation. Use an explicit None check so rand_minmax=0.0 is not
    # silently replaced by eps (a plain truthiness test would do that).
    if rand_init:
        half_width = eps if rand_minmax is None else rand_minmax
        eta = torch.zeros_like(x).uniform_(-half_width, half_width)
    else:
        eta = torch.zeros_like(x)

    eta = clip_eta(eta, norm, eps)
    adv_x = _clip_to_bounds(x + eta, clip_min, clip_max)

    if y is None:
        # Fall back to the model's own predictions on the clean input. Only the argmax
        # is needed, so skip building an autograd graph for this forward pass.
        with torch.no_grad():
            y = torch.argmax(model_fn(x), dim=1)

    for _ in range(nb_iter):
        # One fast-gradient step of size eps_iter from the current iterate.
        adv_x = fast_gradient_method(model_fn, adv_x, eps_iter, norm, clip_min, clip_max, y, targeted)
        # Re-limit the total perturbation relative to the original x, then re-apply bounds.
        eta = clip_eta(adv_x - x, norm, eps)
        adv_x = _clip_to_bounds(x + eta, clip_min, clip_max)

    return adv_x