Coverage for robustAI/advertrain/dependencies/autoattack.py: 13%

356 statements  

coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1""" 

2Taken from https://github.com/fra31/auto-attack 

3 

4MIT License 

5""" 

6import math 

7import time 

8from typing import Callable, Optional, Tuple 

9 

10import torch 

11import torch.nn as nn 

12import torch.nn.functional as F 

13 

14 

def L0_norm(x: torch.Tensor) -> torch.Tensor:
    """
    Calculate the L0 norm of a tensor.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: The L0 norm of each sample in the batch.
    """
    return (x != 0.).view(x.shape[0], -1).sum(-1)

def L1_norm(x: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
    """
    Calculate the L1 norm of a tensor.

    Args:
        x (torch.Tensor): Input tensor.
        keepdim (bool, optional): Whether to keep the dimensions or not. Defaults to False.

    Returns:
        torch.Tensor: The L1 norm of each sample in the batch.
    """
    z = x.abs().view(x.shape[0], -1).sum(-1)
    if keepdim:
        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z

def L2_norm(x: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
    """
    Calculate the L2 norm of a tensor.

    Args:
        x (torch.Tensor): Input tensor.
        keepdim (bool, optional): Whether to keep the dimensions or not. Defaults to False.

    Returns:
        torch.Tensor: The L2 norm of each sample in the batch.
    """
    z = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
    if keepdim:
        z = z.view(-1, *[1] * (len(x.shape) - 1))
    return z
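# The three helpers above reduce a batched tensor to one value per sample.
# Shape sketch (illustrative comments only, not part of the upstream code):
#
#   x = torch.randn(4, 3, 8, 8)         # batch of 4 samples
#   L0_norm(x)                          # shape [4]: count of non-zero entries
#   L1_norm(x, keepdim=True)            # shape [4, 1, 1, 1], broadcastable
#   L2_norm(x)                          # shape [4]: Euclidean norm per sample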

def L1_projection(x2: torch.Tensor, y2: torch.Tensor, eps1: float) -> torch.Tensor:
    """
    Project a point onto an L1 ball.

    Args:
        x2 (torch.Tensor): Center of the L1 ball (bs x input_dim).
        y2 (torch.Tensor): Current perturbation (x2 + y2 is the point to be projected).
        eps1 (float): Radius of the L1 ball.

    Returns:
        torch.Tensor: Delta such that ||y2 + delta||_1 <= eps1 and 0 <= x2 + y2 + delta <= 1.
    """
    x = x2.clone().float().view(x2.shape[0], -1)
    y = y2.clone().float().view(y2.shape[0], -1)
    sigma = y.clone().sign()
    u = torch.min(1 - x - y, x + y)
    u = torch.min(torch.zeros_like(y), u)
    lvar = -torch.clone(y).abs()
    d = u.clone()
    bs, indbs = torch.sort(-torch.cat((u, lvar), 1), dim=1)
    bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1)
    inu = 2 * (indbs < u.shape[1]).float() - 1
    size1 = inu.cumsum(dim=1)
    s1 = -u.sum(dim=1)
    c = eps1 - y.clone().abs().sum(dim=1)
    c5 = s1 + c < 0
    c2 = c5.nonzero().squeeze(1)
    s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)

    if c2.nelement() != 0:
        lb = torch.zeros_like(c2).float()
        ub = torch.ones_like(lb) * (bs.shape[1] - 1)

        # bisection on the breakpoints; log2(dim) steps suffice
        nitermax = torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float()))
        counter2 = torch.zeros_like(lb).long()
        counter = 0
        while counter < nitermax:
            counter4 = torch.floor((lb + ub) / 2.)
            counter2 = counter4.long()
            c8 = s[c2, counter2] + c[c2] < 0
            ind3 = c8.nonzero().squeeze(1)
            ind32 = (~c8).nonzero().squeeze(1)
            if ind3.nelement() != 0:
                lb[ind3] = counter4[ind3]
            if ind32.nelement() != 0:
                ub[ind32] = counter4[ind32]
            counter += 1
        lb2 = lb.long()
        alpha = (-s[c2, lb2] - c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
        d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -lvar[c2])
    return (sigma * d).view(x2.shape)
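# A direct way to sanity-check L1_projection is to verify both constraints of
# its contract. This helper is an illustrative sketch, not part of the
# upstream auto-attack code; the name `_check_l1_projection` is ours.
def _check_l1_projection(bs: int = 4, shape: Tuple[int, ...] = (3, 8, 8), eps1: float = 5.0) -> None:
    x2 = torch.rand(bs, *shape)      # centers inside the [0, 1] box
    y2 = torch.randn(bs, *shape)     # arbitrary perturbations
    delta = L1_projection(x2, y2, eps1)
    proj = y2 + delta
    # ||y2 + delta||_1 <= eps1 for every sample
    assert (L1_norm(proj) <= eps1 + 1e-4).all()
    # 0 <= x2 + y2 + delta <= 1 elementwise (up to float tolerance)
    assert ((x2 + proj > -1e-5) & (x2 + proj < 1. + 1e-5)).all()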

class APGDAttack:
    """
    Implements the Auto-PGD (Auto Projected Gradient Descent) attack method.

    Attributes:
        model (Callable): A function representing the forward pass of the model to be attacked.
        n_iter (int): Number of iterations for the attack.
        norm (str): The type of norm for the attack ('Linf', 'L2', 'L1').
        n_restarts (int): Number of random restarts for the attack.
        eps (float): The maximum perturbation amount allowed.
        seed (int): Random seed for reproducibility.
        loss (str): Type of loss function to use ('ce' for cross-entropy, 'dlr').
        eot_iter (int): Number of iterations for Expectation over Transformation.
        rho (float): Parameter for adjusting step size.
        topk (Optional[float]): Parameter for controlling the sparsity of the attack.
        verbose (bool): If True, prints verbose output during the attack.
        device (Optional[torch.device]): The device on which to perform computations.
        use_largereps (bool): If True, uses larger epsilon values in initial iterations.
        is_tf_model (bool): If True, indicates the model is a TensorFlow model.

    Methods:
        init_hyperparam(x): Initializes hyperparameters based on the input data.
        check_oscillation(...): Checks for oscillation in the optimization process.
        check_shape(x): Ensures the input has the expected shape.
        normalize(x): Normalizes the input tensor.
        lp_norm(x): Computes the Lp norm of the input.
        dlr_loss(x, y): Computes the Difference of Logits Ratio (DLR) loss.
        attack_single_run(x, y, x_init=None): Performs a single run of the attack.
        perturb(x, y=None, best_loss=False, x_init=None): Generates adversarial examples for the given inputs.
        decr_eps_pgd(x, y, epss, iters, use_rs=True): Performs PGD with decreasing epsilon values.

    A usage sketch appears at the end of this file.
    """

    def __init__(
            self,
            predict: Callable,
            n_iter: int = 100,
            norm: str = 'Linf',
            n_restarts: int = 1,
            eps: Optional[float] = None,
            seed: int = 0,
            loss: str = 'ce',
            eot_iter: int = 1,
            rho: float = .75,
            topk: Optional[float] = None,
            verbose: bool = False,
            device: Optional[torch.device] = None,
            use_largereps: bool = False,
            is_tf_model: bool = False):
        """
        Initializes the APGDAttack object with the given parameters.

        Args:
            predict: A callable representing the forward pass of the model.
            n_iter: Number of iterations for the attack.
            norm: The norm type for the attack ('Linf', 'L2', 'L1').
            n_restarts: Number of random restarts for the attack.
            eps: The maximum perturbation amount allowed.
            seed: Random seed for reproducibility.
            loss: Type of loss function to use.
            eot_iter: Number of iterations for Expectation over Transformation.
            rho: Parameter for adjusting step size.
            topk: Parameter for controlling sparsity in 'L1' norm.
            verbose: If True, enables verbose output.
            device: The device on which to perform computations.
            use_largereps: If True, uses larger epsilon values initially.
            is_tf_model: If True, indicates a TensorFlow model.
        """
        self.model = predict
        self.n_iter = n_iter
        self.eps = eps
        self.norm = norm
        self.n_restarts = n_restarts
        self.seed = seed
        self.loss = loss
        self.eot_iter = eot_iter
        self.thr_decr = rho
        self.topk = topk
        self.verbose = verbose
        self.device = device
        self.use_rs = True
        self.use_largereps = use_largereps
        self.n_iter_orig = n_iter + 0
        # eps may still be None here; it is validated in init_hyperparam
        self.eps_orig = eps + 0. if eps is not None else None
        self.is_tf_model = is_tf_model
        self.y_target = None

    def init_hyperparam(self, x: torch.Tensor) -> None:
        """
        Initializes various hyperparameters based on the input data.

        Args:
            x (torch.Tensor): The input data.
        """
        assert self.norm in ['Linf', 'L2', 'L1']
        assert self.eps is not None
        if self.device is None:
            self.device = x.device
        self.orig_dim = list(x.shape[1:])
        self.ndims = len(self.orig_dim)
        if self.seed is None:
            self.seed = time.time()

        # set parameters for checkpoints (the fractions follow the APGD paper)
        self.n_iter_2 = max(int(0.22 * self.n_iter), 1)
        self.n_iter_min = max(int(0.06 * self.n_iter), 1)
        self.size_decr = max(int(0.03 * self.n_iter), 1)

    def check_oscillation(self, x: torch.Tensor, j: int, k: int, y5: torch.Tensor, k3: float = 0.75) -> torch.Tensor:
        """
        Checks for oscillation in the optimization process to adjust step sizes.

        Args:
            x (torch.Tensor): The per-iteration loss history (steps x batch).
            j (int): Current iteration index.
            k (int): The number of steps to look back for oscillation.
            y5 (torch.Tensor): The tensor of best losses (currently unused).
            k3 (float, optional): Threshold parameter for oscillation. Defaults to 0.75.

        Returns:
            torch.Tensor: Tensor indicating if oscillation is detected.
        """
        t = torch.zeros(x.shape[1]).to(self.device)
        for counter5 in range(k):
            t += (x[j - counter5] > x[j - counter5 - 1]).float()
        return (t <= k * k3 * torch.ones_like(t)).float()

    def check_shape(self, x: torch.Tensor) -> torch.Tensor:
        """
        Ensures the input tensor is at least one-dimensional.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The input, unsqueezed to 1-D if it was a scalar.
        """
        return x if len(x.shape) > 0 else x.unsqueeze(0)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the input tensor based on the specified norm type.

        Args:
            x (torch.Tensor): The input tensor to be normalized.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        if self.norm == 'Linf':
            t = x.abs().view(x.shape[0], -1).max(1)[0]
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
        elif self.norm == 'L2':
            t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)
        elif self.norm == 'L1':
            try:
                t = x.abs().view(x.shape[0], -1).sum(dim=-1)
            except RuntimeError:
                # .view can fail on non-contiguous tensors; fall back to .reshape
                t = x.abs().reshape([x.shape[0], -1]).sum(dim=-1)
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)

    def lp_norm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the norm of the input tensor in a broadcastable shape.
        Only implemented for the L2 norm.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The computed norm, shaped for broadcasting.
        """
        if self.norm == 'L2':
            t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
            return t.view(-1, *([1] * self.ndims))

    def dlr_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Computes the Difference of Logits Ratio (DLR) loss:

            DLR(x, y) = -(z_y - max_{i != y} z_i) / (z_{pi_1} - z_{pi_3})

        where z are the logits and pi sorts them in decreasing order.

        Args:
            x (torch.Tensor): The logits from the model.
            y (torch.Tensor): The target labels.

        Returns:
            torch.Tensor: The computed DLR loss.
        """
        x_sorted, ind_sorted = x.sort(dim=1)
        ind = (ind_sorted[:, -1] == y).float()
        u = torch.arange(x.shape[0])
        return -(x[u, y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (
            1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)

    def attack_single_run(self, x: torch.Tensor,
                          y: torch.Tensor,
                          x_init: Optional[torch.Tensor] = None
                          ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Performs a single run of the attack.

        Args:
            x (torch.Tensor): The input data (clean images).
            y (torch.Tensor): The target labels.
            x_init (Optional[torch.Tensor]): Initial starting point for the attack.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the best perturbed
            inputs, the accuracy tensor, the loss tensor, and the best adversarial examples found.
        """

        if len(x.shape) < self.ndims:
            x = x.unsqueeze(0)
            y = y.unsqueeze(0)
        if self.norm == 'Linf':
            t = 2 * torch.rand(x.shape).to(self.device).detach() - 1
            x_adv = x + self.eps * torch.ones_like(x).detach() * self.normalize(t)
        elif self.norm == 'L2':
            t = torch.randn(x.shape).to(self.device).detach()
            x_adv = x + self.eps * torch.ones_like(x).detach() * self.normalize(t)
        elif self.norm == 'L1':
            t = torch.randn(x.shape).to(self.device).detach()
            delta = L1_projection(x, t, self.eps)
            x_adv = x + t + delta
        if x_init is not None:
            x_adv = x_init.clone()
            if self.norm == 'L1' and self.verbose:
                print('[custom init] L1 perturbation {:.5f}'.format(
                    (x_adv - x).abs().view(x.shape[0], -1).sum(1).max()))
        x_adv = x_adv.clamp(0., 1.)
        x_best = x_adv.clone()
        x_best_adv = x_adv.clone()
        loss_steps = torch.zeros([self.n_iter, x.shape[0]]).to(self.device)
        loss_best_steps = torch.zeros([self.n_iter + 1, x.shape[0]]).to(self.device)
        acc_steps = torch.zeros_like(loss_best_steps)

        if not self.is_tf_model:
            if self.loss == 'ce':
                criterion_indiv = nn.CrossEntropyLoss(reduction='none')
            elif self.loss == 'ce-targeted-cfts':
                def criterion_indiv(x, y):
                    return -1. * F.cross_entropy(x, y, reduction='none')
            elif self.loss == 'dlr':
                criterion_indiv = self.dlr_loss
            elif self.loss == 'dlr-targeted':
                criterion_indiv = self.dlr_loss_targeted
            elif self.loss == 'ce-targeted':
                criterion_indiv = self.ce_loss_targeted
            else:
                raise ValueError('unknown loss')
        else:
            if self.loss == 'ce':
                criterion_indiv = self.model.get_logits_loss_grad_xent
            elif self.loss == 'dlr':
                criterion_indiv = self.model.get_logits_loss_grad_dlr
            elif self.loss == 'dlr-targeted':
                criterion_indiv = self.model.get_logits_loss_grad_target
            else:
                raise ValueError('unknown loss')

        x_adv.requires_grad_()
        grad = torch.zeros_like(x)
        for _ in range(self.eot_iter):
            if not self.is_tf_model:
                with torch.enable_grad():
                    logits = self.model(x_adv)
                    loss_indiv = criterion_indiv(logits, y)
                    loss = loss_indiv.sum()
                grad += torch.autograd.grad(loss, [x_adv])[0].detach()
            else:
                if self.y_target is None:
                    logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y)
                else:
                    logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y, self.y_target)
                grad += grad_curr

        grad /= float(self.eot_iter)
        grad_best = grad.clone()
        acc = logits.detach().max(1)[1] == y
        acc_steps[0] = acc + 0
        loss_best = loss_indiv.detach().clone()
        alpha = 2. if self.norm in ['Linf', 'L2'] else 1. if self.norm in ['L1'] else 2e-2
        step_size = alpha * self.eps * torch.ones([x.shape[0], *(
            [1] * self.ndims)]).to(self.device).detach()
        x_adv_old = x_adv.clone()
        k = self.n_iter_2 + 0
        if self.norm == 'L1':
            k = max(int(.04 * self.n_iter), 1)
            n_fts = math.prod(self.orig_dim)
            if x_init is None:
                topk = .2 * torch.ones([x.shape[0]], device=self.device)
                sp_old = n_fts * torch.ones_like(topk)
            else:
                topk = L0_norm(x_adv - x) / n_fts / 1.5
                sp_old = L0_norm(x_adv - x)

            adasp_redstep = 1.5
            adasp_minstep = 10.

        counter3 = 0
        loss_best_last_check = loss_best.clone()
        reduced_last_check = torch.ones_like(loss_best)
        # n_reduced = 0
        n_fts = x.shape[-3] * x.shape[-2] * x.shape[-1]
        u = torch.arange(x.shape[0], device=self.device)

        for i in range(self.n_iter):
            # gradient step
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()
                a = 0.75 if i > 0 else 1.0  # momentum weight
                if self.norm == 'Linf':
                    x_adv_1 = x_adv + step_size * torch.sign(grad)
                    x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps), 0.0, 1.0)
                    x_adv_1 = torch.clamp(torch.min(torch.max(
                        x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
                        x - self.eps), x + self.eps), 0.0, 1.0)
                elif self.norm == 'L2':
                    x_adv_1 = x_adv + step_size * self.normalize(grad)
                    x_adv_1 = torch.clamp(x + self.normalize(x_adv_1 - x) *
                                          torch.min(self.eps * torch.ones_like(x).detach(), self.lp_norm(x_adv_1 - x)),
                                          0.0, 1.0)
                    x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
                    x_adv_1 = torch.clamp(
                        x + self.normalize(x_adv_1 - x)
                        * torch.min(
                            self.eps * torch.ones_like(x).detach(),
                            self.lp_norm(x_adv_1 - x)
                        ),
                        0.0,
                        1.0
                    )
                elif self.norm == 'L1':
                    grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0]
                    topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long()
                    grad_topk = grad_topk[u, topk_curr].view(-1, *[1] * (len(x.shape) - 1))
                    sparsegrad = grad * (grad.abs() >= grad_topk).float()
                    x_adv_1 = x_adv + step_size * sparsegrad.sign() / (sparsegrad.sign().abs().view(x.shape[0], -1)
                                                                       .sum(dim=-1).view(-1, *[1] * (len(x.shape) - 1))
                                                                       + 1e-10)
                    delta_u = x_adv_1 - x
                    delta_p = L1_projection(x, delta_u, self.eps)
                    x_adv_1 = x + delta_u + delta_p
                x_adv = x_adv_1 + 0.

            # get gradient
            x_adv.requires_grad_()
            grad = torch.zeros_like(x)
            for _ in range(self.eot_iter):
                if not self.is_tf_model:
                    with torch.enable_grad():
                        logits = self.model(x_adv)
                        loss_indiv = criterion_indiv(logits, y)
                        loss = loss_indiv.sum()
                    grad += torch.autograd.grad(loss, [x_adv])[0].detach()
                else:
                    if self.y_target is None:
                        logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y)
                    else:
                        logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y, self.y_target)
                    grad += grad_curr

            grad /= float(self.eot_iter)
            pred = logits.detach().max(1)[1] == y
            acc = torch.min(acc, pred)
            acc_steps[i + 1] = acc + 0
            ind_pred = (pred == 0).nonzero().squeeze()
            x_best_adv[ind_pred] = x_adv[ind_pred] + 0.
            if self.verbose:
                str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format(
                    step_size.mean(), topk.mean() * n_fts) if self.norm in ['L1'] else ''
                print('[m] iteration: {} - best loss: {:.6f} - robust accuracy: {:.2%}{}'.format(
                    i, loss_best.sum(), acc.float().mean(), str_stats))

            # check step size
            with torch.no_grad():
                y1 = loss_indiv.detach().clone()
                loss_steps[i] = y1 + 0
                ind = (y1 > loss_best).nonzero().squeeze()
                x_best[ind] = x_adv[ind].clone()
                grad_best[ind] = grad[ind].clone()
                loss_best[ind] = y1[ind] + 0
                loss_best_steps[i + 1] = loss_best + 0
                counter3 += 1
                if counter3 == k:
                    if self.norm in ['Linf', 'L2']:
                        fl_oscillation = self.check_oscillation(loss_steps, i, k, loss_best, k3=self.thr_decr)
                        fl_reduce_no_impr = (1. - reduced_last_check) * (
                            loss_best_last_check >= loss_best).float()
                        fl_oscillation = torch.max(fl_oscillation, fl_reduce_no_impr)
                        reduced_last_check = fl_oscillation.clone()
                        loss_best_last_check = loss_best.clone()

                        if fl_oscillation.sum() > 0:
                            ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze()
                            step_size[ind_fl_osc] /= 2.0
                            # n_reduced = fl_oscillation.sum()

                            x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone()
                            grad[ind_fl_osc] = grad_best[ind_fl_osc].clone()
                        k = max(k - self.size_decr, self.n_iter_min)

                    elif self.norm == 'L1':
                        sp_curr = L0_norm(x_best - x)
                        fl_redtopk = (sp_curr / sp_old) < .95
                        topk = sp_curr / n_fts / 1.5
                        step_size[fl_redtopk] = alpha * self.eps
                        step_size[~fl_redtopk] /= adasp_redstep
                        step_size.clamp_(alpha * self.eps / adasp_minstep, alpha * self.eps)
                        sp_old = sp_curr.clone()

                        x_adv[fl_redtopk] = x_best[fl_redtopk].clone()
                        grad[fl_redtopk] = grad_best[fl_redtopk].clone()

                    counter3 = 0

        return (x_best, acc, loss_best, x_best_adv)

    def perturb(self, x: torch.Tensor, y: Optional[torch.Tensor] = None,
                best_loss: bool = False, x_init: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generates adversarial examples for the given inputs.

        Args:
            x (torch.Tensor): Clean images.
            y (Optional[torch.Tensor]): Clean labels. If None, predicted labels are used.
            best_loss (bool, optional): If True, returns points with highest loss. Defaults to False.
            x_init (Optional[torch.Tensor]): Initial starting point for the attack.

        Returns:
            torch.Tensor: Adversarial examples.
        """
        assert self.loss in ['ce', 'dlr']
        if y is not None and len(y.shape) == 0:
            x.unsqueeze_(0)
            y.unsqueeze_(0)
        self.init_hyperparam(x)
        x = x.detach().clone().float().to(self.device)
        if not self.is_tf_model:
            y_pred = self.model(x).max(1)[1]
        else:
            y_pred = self.model.predict(x).max(1)[1]
        if y is None:
            y = y_pred.detach().clone().long().to(self.device)
        else:
            y = y.detach().clone().long().to(self.device)
        adv = x.clone()
        if self.loss != 'ce-targeted':
            acc = y_pred == y
        else:
            acc = y_pred != y
        # loss = -1e10 * torch.ones_like(acc).float()
        if self.verbose:
            print('-------------------------- ', 'running {}-attack with epsilon {:.5f}'.format(self.norm, self.eps),
                  '--------------------------')
            print('initial accuracy: {:.2%}'.format(acc.float().mean()))

        if self.use_largereps:
            epss = [3. * self.eps_orig, 2. * self.eps_orig, 1. * self.eps_orig]
            iters = [.3 * self.n_iter_orig, .3 * self.n_iter_orig, .4 * self.n_iter_orig]
            iters = [math.ceil(c) for c in iters]
            iters[-1] = self.n_iter_orig - sum(iters[:-1])  # make sure to use the given iterations
            if self.verbose:
                print('using schedule [{}x{}]'.format('+'.join([str(c) for c in epss]),
                                                      '+'.join([str(c) for c in iters])))

        startt = time.time()
        if not best_loss:
            torch.random.manual_seed(self.seed)
            torch.cuda.random.manual_seed(self.seed)
            for counter in range(self.n_restarts):
                ind_to_fool = acc.nonzero().squeeze()
                if len(ind_to_fool.shape) == 0:
                    ind_to_fool = ind_to_fool.unsqueeze(0)
                if ind_to_fool.numel() != 0:
                    x_to_fool = x[ind_to_fool].clone()
                    y_to_fool = y[ind_to_fool].clone()

                    if not self.use_largereps:
                        res_curr = self.attack_single_run(x_to_fool, y_to_fool)
                    else:
                        res_curr = self.decr_eps_pgd(x_to_fool, y_to_fool, epss, iters)
                    best_curr, acc_curr, loss_curr, adv_curr = res_curr
                    ind_curr = (acc_curr == 0).nonzero().squeeze()
                    acc[ind_to_fool[ind_curr]] = 0
                    adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
                    if self.verbose:
                        print('restart {} - robust accuracy: {:.2%}'.format(
                            counter, acc.float().mean()),
                            '- cum. time: {:.1f} s'.format(
                                time.time() - startt))
            return adv.detach().clone()
        else:
            adv_best = x.detach().clone()
            loss_best = torch.ones([x.shape[0]]).to(
                self.device) * (-float('inf'))
            for counter in range(self.n_restarts):
                best_curr, _, loss_curr, _ = self.attack_single_run(x, y)
                ind_curr = (loss_curr > loss_best).nonzero().squeeze()
                adv_best[ind_curr] = best_curr[ind_curr] + 0.
                loss_best[ind_curr] = loss_curr[ind_curr] + 0.
                if self.verbose:
                    print('restart {} - loss: {:.5f}'.format(counter, loss_best.sum()))
            return adv_best

    def decr_eps_pgd(self, x: torch.Tensor, y: torch.Tensor, epss: list, iters: list, use_rs: bool = True
                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Performs PGD with decreasing epsilon values.

        Args:
            x (torch.Tensor): The input data.
            y (torch.Tensor): The target labels.
            epss (list): List of epsilon values to use in the attack.
            iters (list): List of iteration counts corresponding to each epsilon value.
            use_rs (bool, optional): If True, uses random start. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the final perturbed
            inputs, the accuracy tensor, the loss tensor, and the best adversarial examples found.
        """
        assert len(epss) == len(iters)
        assert self.norm in ['L1']
        self.use_rs = False
        if not use_rs:
            x_init = None
        else:
            x_init = x + torch.randn_like(x)
            x_init += L1_projection(x, x_init - x, 1. * float(epss[0]))
        # eps_target = float(epss[-1])
        if self.verbose:
            print('total iter: {}'.format(sum(iters)))
        for eps, niter in zip(epss, iters):
            if self.verbose:
                print('using eps: {:.2f}'.format(eps))
            self.n_iter = niter + 0
            self.eps = eps + 0.
            if x_init is not None:
                x_init += L1_projection(x, x_init - x, 1. * eps)
            x_init, acc, loss, x_adv = self.attack_single_run(x, y, x_init=x_init)
        return (x_init, acc, loss, x_adv)
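
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream auto-attack code).
# The toy classifier below is a hypothetical stand-in for a real model that
# maps [0, 1] images to logits.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    x = torch.rand(8, 3, 32, 32)                  # clean images in [0, 1]
    y = torch.randint(0, 10, (8,))                # ground-truth labels
    attack = APGDAttack(model, n_iter=20, norm='Linf', eps=8. / 255.)
    x_adv = attack.perturb(x, y)                  # same shape as x
    print('max Linf perturbation: {:.5f}'.format((x_adv - x).abs().max().item()))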