Coverage for robustAI/advertrain/dependencies/cleverhans/fast_gradient_method.py: 70%

33 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1""" 

2Taken from https://github.com/rwightman/pytorch-image-models 

3 

4The Fast Gradient Method attack. 

5 

6MIT License 

7""" 

8from typing import Optional 

9 

10import numpy as np 

11import torch 

12 

13from robustAI.advertrain.dependencies.cleverhans.utils import optimize_linear 

14 

15 

def fast_gradient_method(
    model_fn,
    x: torch.Tensor,
    eps: float,
    norm: float,
    clip_min: Optional[float] = None,
    clip_max: Optional[float] = None,
    y: Optional[torch.Tensor] = None,
    targeted: bool = False,
    sanity_checks: bool = False,
) -> torch.Tensor:
    """
    PyTorch implementation of the Fast Gradient Method (FGM).

    Takes one step of size ``eps`` (in the direction chosen by
    ``optimize_linear`` for the given norm) along the gradient of the
    cross-entropy loss w.r.t. the input. For an untargeted attack the step
    increases the loss on ``y``; for a targeted attack the loss is negated so
    the step decreases it, pushing the prediction toward ``y``. The result is
    optionally clipped to ``[clip_min, clip_max]``.

    Args:
        model_fn: A callable that takes an input tensor and returns the model
            logits (class scores along dimension 1).
        x (torch.Tensor): Input tensor.
        eps (float): Epsilon, the input variation parameter.
        norm (int or float): Order of the norm (np.inf, 1, or 2); forwarded
            to ``optimize_linear``.
        clip_min (float, optional): Minimum value per input dimension.
        clip_max (float, optional): Maximum value per input dimension.
            ``clip_min`` and ``clip_max`` must be given together or not at all.
        y (torch.Tensor, optional): Labels, or target labels for a targeted
            attack. If None, the model's own argmax predictions on ``x`` are used.
        targeted (bool): Whether to perform a targeted attack or not.
        sanity_checks (bool): If True, verify that ``x`` already lies within
            ``[clip_min, clip_max]`` before attacking.

    Returns:
        torch.Tensor: The adversarial examples. If ``eps == 0`` the input
        ``x`` itself is returned unchanged; otherwise the result is computed
        from a float32 copy of ``x``.

    Raises:
        ValueError: If ``eps`` is negative, ``clip_min > clip_max``, or only
            one of ``clip_min`` / ``clip_max`` is provided (and ``eps > 0``).
        AssertionError: If ``sanity_checks`` is True and ``x`` lies outside
            ``[clip_min, clip_max]``.
    """
    # --- Argument validation -------------------------------------------------
    if eps < 0:
        raise ValueError(f"eps must be greater than or equal to 0, got {eps} instead")
    if eps == 0:
        return x
    if clip_min is not None and clip_max is not None and clip_min > clip_max:
        raise ValueError(
            f"clip_min must be less than or equal to clip_max, got clip_min={clip_min}, clip_max={clip_max}."
        )
    # Reject one-sided clipping before paying for a forward/backward pass.
    if (clip_min is None) != (clip_max is None):
        raise ValueError(
            "One of clip_min and clip_max is None but we don't currently support one-sided clipping"
        )

    # --- Optional input-range check ------------------------------------------
    # Only build the comparison tensors (and force the device sync that bool()
    # implies) when the caller asked for it. bool() on a 0-dim tensor works on
    # any device, unlike np.all, which would try to convert CUDA tensors to
    # NumPy. Raise explicitly so the check is not stripped under `python -O`.
    # After the validation above, both bounds are set or neither is.
    if sanity_checks and clip_min is not None:
        lower = torch.tensor(clip_min, device=x.device, dtype=x.dtype)
        upper = torch.tensor(clip_max, device=x.device, dtype=x.dtype)
        if not (bool(torch.all(torch.ge(x, lower))) and bool(torch.all(torch.le(x, upper)))):
            raise AssertionError("x must lie within [clip_min, clip_max] when sanity_checks=True")

    # --- Gradient of the loss w.r.t. the input --------------------------------
    # Work on a detached float copy so the caller's tensor is never modified.
    x = x.clone().detach().float().requires_grad_(True)

    # enable_grad lets the attack run even when the caller is inside
    # torch.no_grad() (e.g. an evaluation loop).
    with torch.enable_grad():
        # A single forward pass serves both the default labels and the loss.
        logits = model_fn(x)
        if y is None:
            y = torch.argmax(logits, dim=1)

        loss = torch.nn.CrossEntropyLoss()(logits, y)
        if targeted:
            # Minimise the loss on the target labels instead of maximising it.
            loss = -loss

        # autograd.grad returns d(loss)/dx only; unlike loss.backward() it does
        # not accumulate .grad into the model's parameters, so calling the
        # attack inside a training step does not leak into the optimiser.
        (grad,) = torch.autograd.grad(loss, x)

    # --- Build and apply the perturbation ------------------------------------
    optimal_perturbation = optimize_linear(grad, eps, norm)
    adv_x = x + optimal_perturbation

    if clip_min is not None:
        adv_x = torch.clamp(adv_x, clip_min, clip_max)
    return adv_x