Coverage for robustML/advertrain/dependencies/trades.py: 59% (56 statements)


"""
Taken from https://github.com/yaodongyu/TRADES

MIT License
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


def squared_l2_norm(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the squared L2 norm of a tensor.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The squared L2 norm of the flattened input tensor.
    """
    flattened = x.view(x.unsqueeze(0).shape[0], -1)
    return (flattened ** 2).sum(1)


def l2_norm(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the L2 norm of a tensor.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The L2 norm of the input tensor.
    """
    return squared_l2_norm(x).sqrt()
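# Illustrative example (assumed values, not from the upstream file): for
# t = torch.tensor([3.0, 4.0]), squared_l2_norm(t) flattens the whole tensor
# into a single row and returns tensor([25.]), so l2_norm(t) returns tensor([5.]).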

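# In outline (per Zhang et al., 2019, "Theoretically Principled Trade-off between
# Robustness and Accuracy"), trades_loss below optimises a surrogate of
#
#     CE(f(x_natural), y) + beta * (1 / batch_size) * KL(f(x_adv) || f(x_natural)),
#
# where x_adv approximately maximises the KL term within an epsilon-ball around
# x_natural (l_inf via sign-gradient steps, l_2 via projected SGD on delta).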

def trades_loss(
    model: nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    step_size: float = 0.003,
    epsilon: float = 0.031,
    perturb_steps: int = 10,
    beta: float = 1.0,
    distance: str = "l_inf",
    device: torch.device = None
) -> torch.Tensor:
    """
    Calculate the TRADES loss for training robust models.

    Args:
        model (nn.Module): The neural network model.
        x_natural (torch.Tensor): Natural (clean) inputs.
        y (torch.Tensor): Target labels.
        optimizer (torch.optim.Optimizer): Optimizer for the model.
        step_size (float, optional): Step size for each perturbation step. Defaults to 0.003.
        epsilon (float, optional): Perturbation budget. Defaults to 0.031.
        perturb_steps (int, optional): Number of perturbation steps. Defaults to 10.
        beta (float, optional): Regularization weight for the robust (KL) term. Defaults to 1.0.
        distance (str, optional): Norm for the perturbation ('l_inf' or 'l_2'). Defaults to 'l_inf'.
        device (torch.device, optional): The device to use (e.g., 'cuda' or 'cpu').

    Returns:
        torch.Tensor: The TRADES loss.
    """
    # Define the KL divergence loss, summed over the batch.
    criterion_kl = nn.KLDivLoss(reduction="sum")
    model.eval()
    batch_size = len(x_natural)

    # Generate the adversarial example, starting from a slightly noised copy of the input.
    if "cuda" in str(device):
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda(device).detach()
    else:
        x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).detach()

    if distance == "l_inf":
        # Projected gradient ascent on the KL term inside the l_inf ball.
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(
                    F.log_softmax(model(x_adv), dim=1),
                    F.softmax(model(x_natural), dim=1),
                )
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(
                torch.max(x_adv, x_natural - epsilon), x_natural + epsilon
            )
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif distance == "l_2":
        if "cuda" in str(device):
            delta = 0.001 * torch.randn(x_natural.shape).cuda(device).detach()
        else:
            delta = 0.001 * torch.randn(x_natural.shape).detach()
        delta = Variable(delta.data, requires_grad=True)

        # Set up an optimizer over the perturbation delta.
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            # Maximise the KL term by minimising its negation.
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss = (-1) * criterion_kl(
                    F.log_softmax(model(adv), dim=1), F.softmax(model(x_natural), dim=1)
                )
            loss.backward()
            # Renormalise the gradient per sample.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # Avoid NaN or Inf if the gradient is zero.
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(
                    delta.grad[grad_norms == 0]
                )
            optimizer_delta.step()

            # Project delta back onto the valid image range and the epsilon ball.
            delta.data.add_(x_natural)
            delta.data.clamp_(0, 1).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x_natural + delta, requires_grad=False)
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    # Zero the gradients accumulated while crafting the adversarial example.
    optimizer.zero_grad()
    # Calculate the robust loss: natural cross-entropy plus the beta-weighted KL term.
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(
        F.log_softmax(model(x_adv), dim=1), F.softmax(model(x_natural), dim=1)
    )
    loss = loss_natural + beta * loss_robust
    return loss
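

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the upstream TRADES code):
# one TRADES training step on a random batch with a tiny classifier. The model,
# batch shape, and hyperparameters below are assumptions chosen only to show
# how trades_loss is called; beta = 6.0 mirrors a common choice from the paper.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    x = torch.rand(8, 3, 32, 32, device=device)    # inputs assumed to lie in [0, 1]
    y = torch.randint(0, 10, (8,), device=device)  # random class labels

    loss = trades_loss(
        model=model,
        x_natural=x,
        y=y,
        optimizer=optimizer,
        step_size=0.003,
        epsilon=0.031,
        perturb_steps=10,
        beta=6.0,
        distance="l_inf",
        device=device,
    )
    loss.backward()
    optimizer.step()
    print(f"TRADES loss on one random batch: {loss.item():.4f}")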