Coverage for robustAI/advertrain/dependencies/cleverhans/utils.py: 67%

43 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1""" 

2Taken from https://github.com/rwightman/pytorch-image-models 

3 

4MIT License 

5""" 

6import numpy as np 

7import torch 

8 

9 

def clip_eta(eta: torch.Tensor, norm: float, eps: float) -> torch.Tensor:
    """
    Clips the perturbation eta to be within the specified norm ball.

    The first dimension of ``eta`` is treated as the batch dimension.

    * L-inf: every element is clamped to ``[-eps, eps]``.
    * L1 / L2: the norm of each sample is computed over all remaining
      dimensions. Samples whose norm exceeds ``eps`` are rescaled down to
      norm ``eps``, and samples already inside the ball are left unchanged.
      For L1 this is a rescaling, not the exact Euclidean projection onto the
      L1 ball.

    Args:
        eta (torch.Tensor): The perturbation tensor.
        norm (int or float): The norm to use: ``np.inf``, ``1`` or ``2``.
        eps (float): Epsilon, the maximum allowed norm of the perturbation.

    Returns:
        torch.Tensor: The clipped perturbation. A new tensor is always
        returned; ``eta`` itself is never modified.

    Raises:
        ValueError: If ``norm`` is not ``np.inf``, ``1`` or ``2``.
    """
    if norm not in [np.inf, 1, 2]:
        raise ValueError("Norm must be np.inf, 1, or 2.")

    if norm == np.inf:
        return torch.clamp(eta, -eps, eps)

    avoid_zero_div = torch.tensor(1e-12, dtype=eta.dtype, device=eta.device)
    reduc_ind = list(range(1, eta.dim()))
    if norm == 2:
        norm_val = torch.sqrt(torch.sum(eta ** 2, dim=reduc_ind, keepdim=True))
    else:
        norm_val = torch.sum(torch.abs(eta), dim=reduc_ind, keepdim=True)
    # Clamp away from zero so all-zero samples do not divide by zero.
    norm_val = torch.max(norm_val, avoid_zero_div)
    # Only ever shrink: samples already inside the ball get a factor of 1.
    factor = torch.min(torch.tensor(1.0, dtype=eta.dtype, device=eta.device), eps / norm_val)
    # Out-of-place multiply. An in-place `eta *= factor` would silently modify
    # the caller's tensor (unlike the L-inf branch). It would also fail on leaf
    # tensors that require grad.
    return eta * factor

40 

41 

def optimize_linear(grad: torch.Tensor, eps: float, norm: float = np.inf) -> torch.Tensor:
    """
    Solves for the optimal input to a linear function under a norm constraint.

    Returns the perturbation ``delta`` with ``||delta|| <= eps`` that maximizes
    ``<grad, delta>`` for the chosen norm. For L1 and L2 it is computed per
    sample, with the first dimension of ``grad`` treated as the batch dimension.

    * L-inf: ``eps * sign(grad)``.
    * L1: the whole budget goes to the coordinate(s) with the largest
      ``|grad|``, split evenly among ties.
    * L2: ``eps * grad / ||grad||_2``. The norm is clamped away from zero.

    Samples with an all-zero gradient receive a zero perturbation.

    Args:
        grad (torch.Tensor): Tensor of gradients.
        eps (float): Epsilon, the maximum allowed norm of the perturbation.
        norm (int or float): The norm to use: ``np.inf`` (default), ``1`` or ``2``.

    Returns:
        torch.Tensor: The optimized perturbation, scaled by ``eps``.

    Raises:
        ValueError: If ``norm`` is not ``np.inf``, ``1`` or ``2``.
    """
    if norm == np.inf:
        optimal_perturbation = torch.sign(grad)
    elif norm == 1:
        batch_size = grad.size(0)
        # Shape that broadcasts one value per sample back over grad: (B, 1, ..., 1).
        per_sample_shape = [batch_size] + [1] * (grad.dim() - 1)
        abs_grad = torch.abs(grad)
        sign = torch.sign(grad)

        # Use reshape rather than view so that non-contiguous gradients
        # (e.g. transposed or channels_last) are accepted.
        max_abs_grad, _ = torch.max(abs_grad.reshape(batch_size, -1), dim=1)
        max_mask = abs_grad.eq(max_abs_grad.view(per_sample_shape)).to(torch.float)
        # Split the unit L1 budget evenly across tied maximal coordinates.
        num_ties = max_mask.reshape(batch_size, -1).sum(dim=1).view(per_sample_shape)
        optimal_perturbation = sign * max_mask / num_ties

        # Sanity check: each sample has unit L1 norm, except samples whose
        # gradient is entirely zero (sign == 0 everywhere). Those get a zero
        # perturbation. The check uses a tolerance because summing 1/k over k
        # ties is not always exactly 1 in floating point.
        opt_pert_norm = optimal_perturbation.abs().reshape(batch_size, -1).sum(dim=1)
        expected_norm = (max_abs_grad > 0).to(opt_pert_norm.dtype)
        assert torch.allclose(opt_pert_norm, expected_norm, rtol=1e-05, atol=1e-08)
    elif norm == 2:
        red_ind = list(range(1, grad.dim()))
        avoid_zero_div = torch.tensor(1e-12, dtype=grad.dtype, device=grad.device)
        # Per-sample squared L2 norm, clamped away from zero to avoid division by zero.
        square = torch.max(avoid_zero_div, torch.sum(grad ** 2, red_ind, keepdim=True))
        optimal_perturbation = grad / torch.sqrt(square)

        # Sanity check: unit L2 norm per sample, except where the clamp kicked
        # in (near-zero gradient), in which case the norm may be below one.
        opt_pert_norm = optimal_perturbation.pow(2).sum(dim=red_ind, keepdim=True).sqrt()
        expected_norm = torch.where(
            square > avoid_zero_div, torch.ones_like(opt_pert_norm), opt_pert_norm
        )
        assert torch.allclose(opt_pert_norm, expected_norm, rtol=1e-05, atol=1e-08)
    else:
        raise ValueError("Only L-inf, L1 and L2 norms are currently implemented.")

    return eps * optimal_perturbation