Coverage for robustAI/advertrain/dependencies/cleverhans/utils.py: 67% of 43 statements
1"""
2Taken from https://github.com/rwightman/pytorch-image-models
4MIT License
5"""
6import numpy as np
7import torch
def clip_eta(eta: torch.Tensor, norm: float, eps: float) -> torch.Tensor:
    """
    Clips the perturbation eta to be within the specified norm ball.

    Args:
        eta (torch.Tensor): The perturbation tensor.
        norm (float): The norm to use: np.inf, 1, or 2.
        eps (float): Epsilon, the maximum allowed norm of the perturbation.

    Returns:
        torch.Tensor: The clipped perturbation.
    """
    if norm not in [np.inf, 1, 2]:
        raise ValueError("Norm must be np.inf, 1, or 2.")

    if norm == np.inf:
        # L-inf ball: clamp each coordinate independently.
        eta = torch.clamp(eta, -eps, eps)
    else:
        # L1/L2 ball: compute each example's norm over all non-batch
        # dimensions, then rescale eta wherever that norm exceeds eps.
        avoid_zero_div = torch.tensor(1e-12, dtype=eta.dtype, device=eta.device)
        reduc_ind = list(range(1, len(eta.size())))
        norm_val = (
            torch.sqrt(torch.sum(eta**2, dim=reduc_ind, keepdim=True))
            if norm == 2
            else torch.sum(torch.abs(eta), dim=reduc_ind, keepdim=True)
        )
        # Guard against division by zero for all-zero perturbations.
        norm_val = torch.max(norm_val, avoid_zero_div)
        factor = torch.min(torch.tensor(1.0, dtype=eta.dtype, device=eta.device), eps / norm_val)
        eta *= factor

    return eta
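# Usage sketch (an illustrative addition, not part of the upstream file):
# project a batch of random perturbations onto an L2 ball of radius eps.
# The tensor shape and eps value below are assumptions for demonstration.
def _demo_clip_eta() -> None:
    eta = torch.randn(8, 3, 32, 32)
    clipped = clip_eta(eta.clone(), norm=2, eps=0.5)
    # Every per-example L2 norm is now at most eps (up to float tolerance).
    per_example = clipped.flatten(1).norm(p=2, dim=1)
    assert torch.all(per_example <= 0.5 + 1e-6)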
def optimize_linear(grad: torch.Tensor, eps: float, norm: float = np.inf) -> torch.Tensor:
    """
    Solves for the optimal input to a linear function under a norm constraint.

    Args:
        grad (torch.Tensor): Tensor of gradients.
        eps (float): Epsilon, the maximum allowed norm of the perturbation.
        norm (float): The norm to use: np.inf, 1, or 2.

    Returns:
        torch.Tensor: The optimized perturbation.
    """
    red_ind = list(range(1, len(grad.size())))
    avoid_zero_div = torch.tensor(1e-12, dtype=grad.dtype, device=grad.device)

    if norm == np.inf:
        # Under L-inf, the maximizer of a linear function is the sign of the gradient.
        optimal_perturbation = torch.sign(grad)
    elif norm == 1:
        # Under L1, all mass goes onto the coordinate(s) with the largest
        # absolute gradient, split evenly in case of ties.
        abs_grad = torch.abs(grad)
        sign = torch.sign(grad)
        ori_shape = [1] * len(grad.size())
        ori_shape[0] = grad.size(0)

        max_abs_grad, _ = torch.max(abs_grad.view(grad.size(0), -1), 1)
        max_mask = abs_grad.eq(max_abs_grad.view(ori_shape)).to(torch.float)
        num_ties = max_mask
        for red_scalar in red_ind:
            num_ties = torch.sum(num_ties, red_scalar, keepdim=True)
        optimal_perturbation = sign * max_mask / num_ties
        # Sanity check: each example's perturbation has unit L1 norm.
        opt_pert_norm = optimal_perturbation.abs().sum(dim=red_ind)
        assert torch.all(opt_pert_norm == torch.ones_like(opt_pert_norm))
    elif norm == 2:
        # Under L2, the maximizer is the gradient scaled to unit L2 norm.
        square = torch.max(avoid_zero_div, torch.sum(grad**2, red_ind, keepdim=True))
        optimal_perturbation = grad / torch.sqrt(square)
        # Sanity check: unit L2 norm, except where the gradient was all zeros.
        opt_pert_norm = (
            optimal_perturbation.pow(2).sum(dim=red_ind, keepdim=True).sqrt()
        )
        one_mask = (square <= avoid_zero_div).to(torch.float) * opt_pert_norm + (
            square > avoid_zero_div
        ).to(torch.float)
        assert torch.allclose(opt_pert_norm, one_mask, rtol=1e-05, atol=1e-08)
    else:
        raise ValueError("Only L-inf, L1 and L2 norms are currently implemented.")

    # Scale the unit-norm perturbation to the edge of the eps ball.
    scaled_perturbation = eps * optimal_perturbation
    return scaled_perturbation
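# Usage sketch (an illustrative addition, not part of the upstream file):
# optimize_linear is the core step of fast-gradient-style attacks. Given the
# gradient of the loss w.r.t. the input, it returns the perturbation of norm
# eps that maximizes the linearized loss; with norm=np.inf this is the FGSM
# step eps * sign(grad). The model, inputs, and labels are assumptions.
def _demo_optimize_linear() -> None:
    model = torch.nn.Linear(10, 3)
    x = torch.randn(4, 10, requires_grad=True)
    y = torch.tensor([0, 1, 2, 0])
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (grad,) = torch.autograd.grad(loss, x)
    x_adv = x.detach() + optimize_linear(grad, eps=0.1, norm=np.inf)
    assert x_adv.shape == x.shape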