Coverage for robustAI/advertrain/dependencies/cleverhans/projected_gradient_descent.py: 93%


1""" 

2Taken from https://github.com/rwightman/pytorch-image-models 

3 

4The Projected Gradient Descent attack. 

5 

6MIT License 

7""" 

8from typing import Optional 

9 

10import numpy as np 

11import torch 

12 

13from robustAI.advertrain.dependencies.cleverhans.fast_gradient_method import fast_gradient_method 

14from robustAI.advertrain.dependencies.cleverhans.utils import clip_eta 

15 

16 

17def projected_gradient_descent( 

18 model_fn, 

19 x: torch.Tensor, 

20 eps: float, 

21 eps_iter: float, 

22 nb_iter: int, 

23 norm: int, 

24 clip_min: Optional[float] = None, 

25 clip_max: Optional[float] = None, 

26 y: Optional[torch.Tensor] = None, 

27 targeted: bool = False, 

28 rand_init: bool = True, 

29 rand_minmax: Optional[float] = None, 

30 sanity_checks: bool = True, 

31) -> torch.Tensor: 

32 """ 

33 Performs the Projected Gradient Descent attack. 

34 

35 Args: 

36 model_fn: A callable that takes an input tensor and returns the model logits. 

37 x (torch.Tensor): Input tensor. 

38 eps (float): Epsilon, the input variation parameter. 

39 eps_iter (float): Step size for each attack iteration. 

40 nb_iter (int): Number of attack iterations. 

41 norm (int): Order of the norm (np.inf, 1, or 2). 

42 clip_min (float, optional): Mininum value per input dimension. 

43 clip_max (float, optional): Maximum value per input dimension. 

44 y (torch.Tensor, optional): Labels or target labels for targeted attack. 

45 targeted (bool): Whether to perform a targeted attack or not. 

46 rand_init (bool): Whether to start from a randomly perturbed input. 

47 rand_minmax (float, optional): Range of the uniform distribution for initial random perturbation. 

48 sanity_checks (bool): If True, include sanity checks. 

49 

50 Returns: 

51 torch.Tensor: A tensor containing the adversarial examples. 

52 """ 

    if norm == 1:
        raise NotImplementedError(
            "It's not clear that FGM is a good inner loop step for PGD "
            "when norm=1, because norm=1 FGM changes only one pixel at "
            "a time. We need to rigorously test a strong norm=1 PGD "
            "before enabling this feature."
        )
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")
    if eps < 0:
        raise ValueError(f"eps must be non-negative, got {eps}")
    if eps_iter < 0 or eps_iter > eps:
        raise ValueError(f"eps_iter must be in the range [0, {eps}], got {eps_iter}")

    if clip_min is not None and clip_max is not None and clip_min > clip_max:
        raise ValueError(
            f"clip_min must be less than or equal to clip_max, got clip_min={clip_min}, clip_max={clip_max}"
        )

    if sanity_checks:
        # The input must already lie inside the clip bounds before the attack starts.
        if clip_min is not None:
            assert x.min() >= clip_min
        if clip_max is not None:
            assert x.max() <= clip_max

    # Start from a random point in the rand_minmax-ball around x (defaulting
    # to the eps-ball) when rand_init is set; otherwise start at x itself.
    if rand_init:
        if rand_minmax is None:
            rand_minmax = eps
        eta = torch.zeros_like(x).uniform_(-rand_minmax, rand_minmax)
    else:
        eta = torch.zeros_like(x)

80 

81 # Clip eta and prepare adv_x 

82 eta = clip_eta(eta, norm, eps) 

83 adv_x = torch.clamp(x + eta, clip_min, clip_max) if clip_min is not None or clip_max is not None else x + eta 

84 

85 y = torch.argmax(model_fn(x), dim=1) if y is None else y 

86 

87 for _ in range(nb_iter): 

88 adv_x = fast_gradient_method(model_fn, adv_x, eps_iter, norm, clip_min, clip_max, y, targeted) 

89 eta = clip_eta(adv_x - x, norm, eps) 

90 adv_x = x + eta 

91 adv_x = torch.clamp(adv_x, clip_min, clip_max) if clip_min is not None or clip_max is not None else adv_x 

92 

93 return adv_x