Coverage for robustAI/advertrain/dependencies/cleverhans/fast_gradient_method.py: 70%
33 statements
1"""
2Taken from https://github.com/rwightman/pytorch-image-models
4The Fast Gradient Method attack.
6MIT License
7"""
8from typing import Optional
10import numpy as np
11import torch
13from robustAI.advertrain.dependencies.cleverhans.utils import optimize_linear
16def fast_gradient_method(
17 model_fn,
18 x: torch.Tensor,
19 eps: float,
20 norm: int,
21 clip_min: Optional[float] = None,
22 clip_max: Optional[float] = None,
23 y: Optional[torch.Tensor] = None,
24 targeted: bool = False,
25 sanity_checks: bool = False,
26) -> torch.Tensor:
    """
    PyTorch implementation of the Fast Gradient Method (FGM).

    Args:
        model_fn: A callable that takes an input tensor and returns the model logits.
        x (torch.Tensor): Input tensor.
        eps (float): Epsilon, the input variation parameter.
        norm (int): Order of the norm (np.inf, 1, or 2).
        clip_min (float, optional): Minimum value per input dimension.
        clip_max (float, optional): Maximum value per input dimension.
        y (torch.Tensor, optional): Labels or target labels for a targeted attack.
        targeted (bool): Whether to perform a targeted attack or not.
        sanity_checks (bool): If True, include sanity checks.

    Returns:
        torch.Tensor: A tensor containing the adversarial examples.
    """
    # Validate arguments
    if eps < 0:
        raise ValueError(f"eps must be greater than or equal to 0, got {eps} instead")
    if eps == 0:
        return x
    if clip_min is not None and clip_max is not None and clip_min > clip_max:
        raise ValueError(
            f"clip_min must be less than or equal to clip_max, got clip_min={clip_min}, clip_max={clip_max}."
        )
    asserts = []

    # If a data range was specified, check that the input lies within it
    if clip_min is not None:
        assert_ge = torch.all(
            torch.ge(x, torch.tensor(clip_min, device=x.device, dtype=x.dtype))
        )
        asserts.append(assert_ge)

    if clip_max is not None:
        assert_le = torch.all(
            torch.le(x, torch.tensor(clip_max, device=x.device, dtype=x.dtype))
        )
        asserts.append(assert_le)

    if sanity_checks:
        assert np.all(asserts)
    # Prepare input tensor for gradient computation
    x = x.clone().detach().float().requires_grad_(True)

    # If no labels are given, use the model's own predictions as labels
    # to avoid label leaking
    y = torch.argmax(model_fn(x), dim=1) if y is None else y

    # Compute the loss; negate it for targeted attacks so the gradient step
    # moves towards the target class
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(model_fn(x), y) * (-1 if targeted else 1)
    loss.backward()
    # Find the perturbation that maximizes the linearized loss under the norm constraint
    optimal_perturbation = optimize_linear(x.grad, eps, norm)
    adv_x = x + optimal_perturbation
    # Clipping perturbations
    if (clip_min is not None) or (clip_max is not None):
        if clip_min is None or clip_max is None:
            raise ValueError(
                "One of clip_min and clip_max is None but we don't currently support one-sided clipping"
            )
        adv_x = torch.clamp(adv_x, clip_min, clip_max)

    return adv_x
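

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of calling fast_gradient_method, assuming inputs in [0, 1];
# `SmallNet` and the tensor shapes below are hypothetical. With norm=np.inf the
# perturbation reduces to the classic FGSM step eps * sign(grad), so the result
# stays within an L-infinity ball of radius eps around the input.
if __name__ == "__main__":
    import torch.nn as nn

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

        def forward(self, x):
            return self.net(x)

    model = SmallNet().eval()
    images = torch.rand(4, 1, 28, 28)  # random inputs in [0, 1]

    adv_images = fast_gradient_method(
        model_fn=model,
        x=images,
        eps=0.1,
        norm=np.inf,
        clip_min=0.0,
        clip_max=1.0,
    )

    # The perturbation is bounded in L-infinity norm by eps
    assert torch.all((adv_images - images).abs() <= 0.1 + 1e-6)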