Coverage for robustAI/advertrain/dependencies/autoattack.py: 13%
356 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-01 08:42 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-01 08:42 +0000
1"""
2Taken from https://github.com/fra31/auto-attack
4MIT License
5"""
6import math
7import time
8from typing import Callable, Optional, Tuple
10import torch
11import torch.nn as nn
12import torch.nn.functional as F
def L0_norm(x: torch.Tensor) -> torch.Tensor:
    """
    Count the non-zero entries of every sample in a batch.

    Args:
        x (torch.Tensor): Input tensor; the first axis is treated as the batch axis.

    Returns:
        torch.Tensor: 1-D integer tensor with one non-zero count per sample.
    """
    batch_size = x.shape[0]
    # Boolean mask of non-zero entries, flattened to (batch, features).
    nonzero_mask = x.ne(0.).view(batch_size, -1)
    return nonzero_mask.sum(dim=-1)
def L1_norm(x: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
    """
    Sum of absolute values of every sample in a batch.

    Args:
        x (torch.Tensor): Input tensor; the first axis is treated as the batch axis.
        keepdim (bool, optional): If True, the result is reshaped to
            (batch, 1, ..., 1) so it broadcasts against ``x``. Defaults to False.

    Returns:
        torch.Tensor: Per-sample L1 norms.
    """
    per_sample = torch.abs(x).view(x.shape[0], -1).sum(dim=-1)
    if not keepdim:
        return per_sample
    # One singleton axis for every non-batch axis of the input.
    singleton_axes = [1] * (len(x.shape) - 1)
    return per_sample.view(-1, *singleton_axes)
def L2_norm(x: torch.Tensor, keepdim: bool = False) -> torch.Tensor:
    """
    Euclidean norm of every sample in a batch.

    Args:
        x (torch.Tensor): Input tensor; the first axis is treated as the batch axis.
        keepdim (bool, optional): If True, the result is reshaped to
            (batch, 1, ..., 1) so it broadcasts against ``x``. Defaults to False.

    Returns:
        torch.Tensor: Per-sample L2 norms.
    """
    squared_sum = x.pow(2).view(x.shape[0], -1).sum(dim=-1)
    per_sample = squared_sum.sqrt()
    if not keepdim:
        return per_sample
    # One singleton axis for every non-batch axis of the input.
    singleton_axes = [1] * (len(x.shape) - 1)
    return per_sample.view(-1, *singleton_axes)
def L1_projection(x2: torch.Tensor, y2: torch.Tensor, eps1: float) -> torch.Tensor:
    """
    Project a point onto the intersection of an L1 ball and the [0, 1] box.

    Args:
        x2 (torch.Tensor): Center of the L1 ball (bs x input_dim).
        y2 (torch.Tensor): Current perturbation (x2 + y2 is the point to be projected).
        eps1 (float): Radius of the L1 ball.

    Returns:
        torch.Tensor: Delta (same shape as ``x2``) such that
        ||y2 + delta||_1 <= eps1 and 0 <= x2 + y2 + delta <= 1.
    """
    # Work on flattened float copies: one row per sample.
    x = x2.clone().float().view(x2.shape[0], -1)
    y = y2.clone().float().view(y2.shape[0], -1)
    sigma = y.clone().sign()

    # u <= 0 is the (signed) correction needed just to satisfy the box constraint.
    u = torch.min(1 - x - y, x + y)
    u = torch.min(torch.zeros_like(y), u)
    lvar = -torch.clone(y).abs()
    d = u.clone()

    # Sorted breakpoints of the per-sample piecewise-linear function.
    bs, indbs = torch.sort(-torch.cat((u, lvar), 1), dim=1)
    bs2 = torch.cat((bs[:, 1:], torch.zeros(bs.shape[0], 1).to(bs.device)), 1)

    # +1 for breakpoints coming from u, -1 for those coming from lvar.
    inu = 2 * (indbs < u.shape[1]).float() - 1
    size1 = inu.cumsum(dim=1)

    s1 = -u.sum(dim=1)

    # c < -s1 means the box correction alone does not bring the sample inside the ball.
    c = eps1 - y.clone().abs().sum(dim=1)
    c5 = s1 + c < 0
    c2 = c5.nonzero().squeeze(1)

    s = s1.unsqueeze(-1) + torch.cumsum((bs2 - bs) * size1, dim=1)

    # nelement is a method: it must be called, otherwise the guard is always True.
    if c2.nelement() != 0:
        lb = torch.zeros_like(c2).float()
        ub = torch.ones_like(lb) * (bs.shape[1] - 1)

        # Bisection over the sorted breakpoints, ceil(log2(#breakpoints)) steps.
        n_bisect = int(torch.ceil(torch.log2(torch.tensor(bs.shape[1]).float())))
        for _ in range(n_bisect):
            counter4 = torch.floor((lb + ub) / 2.)
            # .long() keeps the index on the same device as lb/ub.
            counter2 = counter4.long()

            c8 = s[c2, counter2] + c[c2] < 0
            ind3 = c8.nonzero().squeeze(1)
            ind32 = (~c8).nonzero().squeeze(1)
            if ind3.nelement() != 0:
                lb[ind3] = counter4[ind3]
            if ind32.nelement() != 0:
                ub[ind32] = counter4[ind32]

        lb2 = lb.long()
        alpha = (-s[c2, lb2] - c[c2]) / size1[c2, lb2 + 1] + bs2[c2, lb2]
        d[c2] = -torch.min(torch.max(-u[c2], alpha.unsqueeze(-1)), -lvar[c2])

    return (sigma * d).view(x2.shape)
class APGDAttack:
    """
    Implements the Auto-PGD (Auto Projected Gradient Descent) attack method.

    Attributes:
        model (Callable): A function representing the forward pass of the model to be attacked.
        n_iter (int): Number of iterations for the attack.
        norm (str): The type of norm for the attack ('Linf', 'L2', 'L1').
        n_restarts (int): Number of random restarts for the attack.
        eps (float): The maximum perturbation amount allowed.
        seed (int): Random seed for reproducibility.
        loss (str): Type of loss function to use ('ce' for cross-entropy, 'dlr').
        eot_iter (int): Number of iterations for Expectation over Transformation.
        thr_decr (float): Threshold used by the oscillation check (constructor argument ``rho``).
        topk (Optional[float]): Stored constructor argument; not read by the methods of this class.
        verbose (bool): If True, prints verbose output during the attack.
        device (Optional[torch.device]): The device on which to perform computations.
        use_largereps (bool): If True, uses larger epsilon values in initial iterations.
        is_tf_model (bool): If True, indicates the model is a TensorFlow model.

    Methods:
        init_hyperparam(x): Initializes hyperparameters based on the input data.
        check_oscillation(...): Checks for oscillation in the optimization process.
        check_shape(x): Ensures the input has at least one dimension.
        normalize(x): Normalizes the input tensor.
        lp_norm(x): Computes the per-sample L2 norm of the input.
        dlr_loss(x, y): Computes the Difference of Logits Ratio (DLR) loss.
        attack_single_run(x, y, x_init=None): Performs a single run of the attack.
        perturb(x, y=None, best_loss=False, x_init=None): Generates adversarial examples for the given inputs.
        decr_eps_pgd(x, y, epss, iters, use_rs=True): Performs PGD with decreasing epsilon values.
    """

    def __init__(
            self,
            predict: Callable,
            n_iter: int = 100,
            norm: str = 'Linf',
            n_restarts: int = 1,
            eps: Optional[float] = None,
            seed: int = 0,
            loss: str = 'ce',
            eot_iter: int = 1,
            rho: float = .75,
            topk: Optional[float] = None,
            verbose: bool = False,
            device: Optional[torch.device] = None,
            use_largereps: bool = False,
            is_tf_model: bool = False):
        """
        Initializes the APGDAttack object with the given parameters.

        Args:
            predict: A callable representing the forward pass of the model.
            n_iter: Number of iterations for the attack.
            norm: The norm type for the attack ('Linf', 'L2', 'L1').
            n_restarts: Number of random restarts for the attack.
            eps: The maximum perturbation amount allowed. May be left as None
                here, but must be set before ``perturb`` is called.
            seed: Random seed for reproducibility.
            loss: Type of loss function to use.
            eot_iter: Number of iterations for Expectation over Transformation.
            rho: Parameter for adjusting step size.
            topk: Parameter for controlling sparsity in 'L1' norm.
            verbose: If True, enables verbose output.
            device: The device on which to perform computations.
            use_largereps: If True, uses larger epsilon values initially.
            is_tf_model: If True, indicates a TensorFlow model.
        """
        self.model = predict
        self.n_iter = n_iter
        self.eps = eps
        self.norm = norm
        self.n_restarts = n_restarts
        self.seed = seed
        self.loss = loss
        self.eot_iter = eot_iter
        self.thr_decr = rho
        self.topk = topk
        self.verbose = verbose
        self.device = device
        self.use_rs = True
        self.use_largereps = use_largereps
        self.n_iter_orig = n_iter + 0
        # eps defaults to None; "None + 0." would raise a TypeError at construction.
        self.eps_orig = None if eps is None else eps + 0.
        self.is_tf_model = is_tf_model
        self.y_target = None

    def init_hyperparam(self, x: torch.Tensor) -> None:
        """
        Initializes various hyperparameters based on the input data.

        Sets the device (if unset), the per-sample shape ``orig_dim`` / ``ndims``,
        the seed (if unset) and the checkpoint schedule derived from ``n_iter``.

        Args:
            x (torch.Tensor): The input data.
        """
        assert self.norm in ['Linf', 'L2', 'L1']
        assert self.eps is not None

        if self.device is None:
            self.device = x.device
        self.orig_dim = list(x.shape[1:])
        self.ndims = len(self.orig_dim)
        if self.seed is None:
            # The seed is declared as an int; keep it integral.
            self.seed = int(time.time())

        # set parameters for checkpoints
        self.n_iter_2 = max(int(0.22 * self.n_iter), 1)
        self.n_iter_min = max(int(0.06 * self.n_iter), 1)
        self.size_decr = max(int(0.03 * self.n_iter), 1)

    def check_oscillation(self, x: torch.Tensor, j: int, k: int, y5: torch.Tensor, k3: float = 0.75) -> torch.Tensor:
        """
        Checks for oscillation in the optimization process to adjust step sizes.

        For each column of ``x`` it counts how many of the last ``k`` steps
        (ending at row ``j``) increased the value, and flags the column when
        that count is at most ``k * k3``.

        Args:
            x (torch.Tensor): Per-step values, indexed as x[step, sample].
            j (int): Current iteration index.
            k (int): The number of steps to look back for oscillation.
            y5 (torch.Tensor): The tensor of losses (not used by the computation).
            k3 (float, optional): Threshold parameter for oscillation. Defaults to 0.75.

        Returns:
            torch.Tensor: Float tensor with 1.0 where oscillation is detected, else 0.0.
        """
        t = torch.zeros(x.shape[1]).to(self.device)
        for counter5 in range(k):
            t += (x[j - counter5] > x[j - counter5 - 1]).float()

        return (t <= k * k3 * torch.ones_like(t)).float()

    def check_shape(self, x: torch.Tensor) -> torch.Tensor:
        """
        Ensures the input tensor has at least one dimension.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: ``x`` itself, or ``x`` with a leading axis added if it was 0-dimensional.
        """
        return x if len(x.shape) > 0 else x.unsqueeze(0)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalizes the input tensor based on the specified norm type.

        Each sample is divided by its own Linf / L2 / L1 norm (plus 1e-12 to
        avoid division by zero). Requires ``self.ndims`` to be set.

        Args:
            x (torch.Tensor): The input tensor to be normalized.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        if self.norm == 'Linf':
            t = x.abs().view(x.shape[0], -1).max(1)[0]
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)

        elif self.norm == 'L2':
            t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)

        elif self.norm == 'L1':
            try:
                t = x.abs().view(x.shape[0], -1).sum(dim=-1)
            except RuntimeError:
                # view() fails on non-contiguous input; reshape() copies when needed.
                t = x.abs().reshape([x.shape[0], -1]).sum(dim=-1)
            return x / (t.view(-1, *([1] * self.ndims)) + 1e-12)

    def lp_norm(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """
        Computes the per-sample L2 norm of the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            Optional[torch.Tensor]: Norms shaped (batch, 1, ..., 1) when
            ``self.norm == 'L2'``; None for any other norm setting.
        """
        if self.norm == 'L2':
            t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt()
            return t.view(-1, *([1] * self.ndims))
        return None

    def dlr_loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Computes the Difference of Logits Ratio (DLR) loss.

        Indexes the three largest logits per row, so ``x`` needs at least
        three columns.

        Args:
            x (torch.Tensor): The logits from the model.
            y (torch.Tensor): The target labels.

        Returns:
            torch.Tensor: The computed DLR loss, one value per sample.
        """
        x_sorted, ind_sorted = x.sort(dim=1)
        # 1.0 where the top logit belongs to the labelled class.
        ind = (ind_sorted[:, -1] == y).float()
        u = torch.arange(x.shape[0], device=x.device)

        return -(x[u, y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (
            1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)

    def _loss_and_grad(self, x_adv: torch.Tensor, y: torch.Tensor, criterion_indiv: Callable
                       ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluates the model at ``x_adv`` and returns the gradient averaged over ``eot_iter`` passes.

        Args:
            x_adv (torch.Tensor): Current iterate; must already require grad for non-TF models.
            y (torch.Tensor): The labels.
            criterion_indiv (Callable): Per-sample loss selected by ``attack_single_run``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Logits and per-sample
            loss of the last pass, and the averaged gradient w.r.t. ``x_adv``.
        """
        grad = torch.zeros_like(x_adv)
        for _ in range(self.eot_iter):
            if not self.is_tf_model:
                with torch.enable_grad():
                    logits = self.model(x_adv)
                    loss_indiv = criterion_indiv(logits, y)
                    loss = loss_indiv.sum()

                grad += torch.autograd.grad(loss, [x_adv])[0].detach()
            else:
                if self.y_target is None:
                    logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y)
                else:
                    logits, loss_indiv, grad_curr = criterion_indiv(x_adv, y, self.y_target)
                grad += grad_curr

        grad /= float(self.eot_iter)
        return logits, loss_indiv, grad

    def attack_single_run(self, x: torch.Tensor,
                          y: torch.Tensor,
                          x_init: Optional[torch.Tensor] = None
                          ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Performs a single run of the attack.

        Expects ``init_hyperparam`` to have been called (it is called by ``perturb``).

        Args:
            x (torch.Tensor): The input data (clean images).
            y (torch.Tensor): The target labels.
            x_init (Optional[torch.Tensor]): Initial starting point for the attack.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the best perturbed inputs,
            the accuracy tensor, the loss tensor, and the best adversarial examples found.
        """
        if len(x.shape) < self.ndims:
            x = x.unsqueeze(0)
            y = y.unsqueeze(0)

        # Random start inside the eps-ball of the chosen norm.
        if self.norm == 'Linf':
            t = 2 * torch.rand(x.shape).to(self.device).detach() - 1
            x_adv = x + self.eps * torch.ones_like(x).detach() * self.normalize(t)
        elif self.norm == 'L2':
            t = torch.randn(x.shape).to(self.device).detach()
            x_adv = x + self.eps * torch.ones_like(x).detach() * self.normalize(t)
        elif self.norm == 'L1':
            t = torch.randn(x.shape).to(self.device).detach()
            delta = L1_projection(x, t, self.eps)
            x_adv = x + t + delta

        # An explicit starting point overrides the random start.
        if x_init is not None:
            x_adv = x_init.clone()
            if self.norm == 'L1' and self.verbose:
                print('[custom init] L1 perturbation {:.5f}'.format(
                    (x_adv - x).abs().view(x.shape[0], -1).sum(1).max()))

        x_adv = x_adv.clamp(0., 1.)
        x_best = x_adv.clone()
        x_best_adv = x_adv.clone()
        loss_steps = torch.zeros([self.n_iter, x.shape[0]]).to(self.device)
        loss_best_steps = torch.zeros([self.n_iter + 1, x.shape[0]]).to(self.device)
        acc_steps = torch.zeros_like(loss_best_steps)

        if not self.is_tf_model:
            if self.loss == 'ce':
                criterion_indiv = nn.CrossEntropyLoss(reduction='none')
            elif self.loss == 'ce-targeted-cfts':
                def criterion_indiv(x, y):
                    return -1. * F.cross_entropy(x, y, reduction='none')
            elif self.loss == 'dlr':
                criterion_indiv = self.dlr_loss
            elif self.loss == 'dlr-targeted':
                # NOTE(review): dlr_loss_targeted is not defined in this class;
                # presumably supplied by a subclass - confirm.
                criterion_indiv = self.dlr_loss_targeted
            elif self.loss == 'ce-targeted':
                # NOTE(review): ce_loss_targeted is not defined in this class;
                # presumably supplied by a subclass - confirm.
                criterion_indiv = self.ce_loss_targeted
            else:
                raise ValueError('unknowkn loss')
        else:
            if self.loss == 'ce':
                criterion_indiv = self.model.get_logits_loss_grad_xent
            elif self.loss == 'dlr':
                criterion_indiv = self.model.get_logits_loss_grad_dlr
            elif self.loss == 'dlr-targeted':
                criterion_indiv = self.model.get_logits_loss_grad_target
            else:
                raise ValueError('unknowkn loss')

        x_adv.requires_grad_()
        logits, loss_indiv, grad = self._loss_and_grad(x_adv, y, criterion_indiv)
        grad_best = grad.clone()

        acc = logits.detach().max(1)[1] == y
        acc_steps[0] = acc + 0
        loss_best = loss_indiv.detach().clone()

        alpha = 2. if self.norm in ['Linf', 'L2'] else 1. if self.norm in ['L1'] else 2e-2
        step_size = alpha * self.eps * torch.ones([x.shape[0], *(
            [1] * self.ndims)]).to(self.device).detach()
        x_adv_old = x_adv.clone()
        k = self.n_iter_2 + 0
        # Number of features per sample, from the recorded sample shape
        # (valid for any number of non-batch dimensions).
        n_fts = math.prod(self.orig_dim)
        if self.norm == 'L1':
            k = max(int(.04 * self.n_iter), 1)
            if x_init is None:
                topk = .2 * torch.ones([x.shape[0]], device=self.device)
                sp_old = n_fts * torch.ones_like(topk)
            else:
                topk = L0_norm(x_adv - x) / n_fts / 1.5
                sp_old = L0_norm(x_adv - x)

            adasp_redstep = 1.5
            adasp_minstep = 10.

        counter3 = 0
        loss_best_last_check = loss_best.clone()
        reduced_last_check = torch.ones_like(loss_best)

        u = torch.arange(x.shape[0], device=self.device)
        for i in range(self.n_iter):
            # gradient step
            with torch.no_grad():
                x_adv = x_adv.detach()
                grad2 = x_adv - x_adv_old
                x_adv_old = x_adv.clone()

                # Momentum weight: no momentum on the very first step.
                a = 0.75 if i > 0 else 1.0

                if self.norm == 'Linf':
                    x_adv_1 = x_adv + step_size * torch.sign(grad)
                    x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, x - self.eps), x + self.eps), 0.0, 1.0)
                    x_adv_1 = torch.clamp(torch.min(torch.max(
                        x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a),
                        x - self.eps), x + self.eps), 0.0, 1.0)

                elif self.norm == 'L2':
                    x_adv_1 = x_adv + step_size * self.normalize(grad)
                    x_adv_1 = torch.clamp(
                        x + self.normalize(x_adv_1 - x)
                        * torch.min(
                            self.eps * torch.ones_like(x).detach(),
                            self.lp_norm(x_adv_1 - x)
                        ),
                        0.0,
                        1.0
                    )
                    x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
                    x_adv_1 = torch.clamp(
                        x + self.normalize(x_adv_1 - x)
                        * torch.min(
                            self.eps * torch.ones_like(x).detach(),
                            self.lp_norm(x_adv_1 - x)
                        ),
                        0.0,
                        1.0
                    )

                elif self.norm == 'L1':
                    # Keep only the largest-magnitude gradient coordinates.
                    grad_topk = grad.abs().view(x.shape[0], -1).sort(-1)[0]
                    topk_curr = torch.clamp((1. - topk) * n_fts, min=0, max=n_fts - 1).long()
                    grad_topk = grad_topk[u, topk_curr].view(-1, *[1] * (len(x.shape) - 1))
                    sparsegrad = grad * (grad.abs() >= grad_topk).float()
                    x_adv_1 = x_adv + step_size * sparsegrad.sign() / (
                        sparsegrad.sign().abs().view(x.shape[0], -1)
                        .sum(dim=-1).view(-1, *[1] * (len(x.shape) - 1))
                        + 1e-10)

                    delta_u = x_adv_1 - x
                    delta_p = L1_projection(x, delta_u, self.eps)
                    x_adv_1 = x + delta_u + delta_p

                x_adv = x_adv_1 + 0.

            # get gradient
            x_adv.requires_grad_()
            logits, loss_indiv, grad = self._loss_and_grad(x_adv, y, criterion_indiv)

            pred = logits.detach().max(1)[1] == y
            acc = torch.min(acc, pred)
            acc_steps[i + 1] = acc + 0
            ind_pred = (pred == 0).nonzero().squeeze()
            # Detach so the stored adversarial examples do not carry an autograd graph.
            x_best_adv[ind_pred] = x_adv[ind_pred].detach() + 0.
            if self.verbose:
                str_stats = ' - step size: {:.5f} - topk: {:.2f}'.format(
                    step_size.mean(), topk.mean() * n_fts) if self.norm in ['L1'] else ''
                print('[m] iteration: {} - best loss: {:.6f} - robust accuracy: {:.2%}{}'.format(
                    i, loss_best.sum(), acc.float().mean(), str_stats))

            # check step size
            with torch.no_grad():
                y1 = loss_indiv.detach().clone()
                loss_steps[i] = y1 + 0
                ind = (y1 > loss_best).nonzero().squeeze()
                x_best[ind] = x_adv[ind].clone()
                grad_best[ind] = grad[ind].clone()
                loss_best[ind] = y1[ind] + 0
                loss_best_steps[i + 1] = loss_best + 0

                counter3 += 1

                if counter3 == k:
                    if self.norm in ['Linf', 'L2']:
                        fl_oscillation = self.check_oscillation(loss_steps, i, k, loss_best, k3=self.thr_decr)
                        fl_reduce_no_impr = (1. - reduced_last_check) * (
                            loss_best_last_check >= loss_best).float()
                        fl_oscillation = torch.max(fl_oscillation, fl_reduce_no_impr)
                        reduced_last_check = fl_oscillation.clone()
                        loss_best_last_check = loss_best.clone()

                        if fl_oscillation.sum() > 0:
                            # Halve the step and restart from the best point so far.
                            ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze()
                            step_size[ind_fl_osc] /= 2.0

                            x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone()
                            grad[ind_fl_osc] = grad_best[ind_fl_osc].clone()

                        k = max(k - self.size_decr, self.n_iter_min)

                    elif self.norm == 'L1':
                        sp_curr = L0_norm(x_best - x)
                        fl_redtopk = (sp_curr / sp_old) < .95
                        topk = sp_curr / n_fts / 1.5
                        step_size[fl_redtopk] = alpha * self.eps
                        step_size[~fl_redtopk] /= adasp_redstep
                        step_size.clamp_(alpha * self.eps / adasp_minstep, alpha * self.eps)
                        sp_old = sp_curr.clone()

                        x_adv[fl_redtopk] = x_best[fl_redtopk].clone()
                        grad[fl_redtopk] = grad_best[fl_redtopk].clone()

                    counter3 = 0

        return (x_best, acc, loss_best, x_best_adv)

    def perturb(self, x: torch.Tensor, y: Optional[torch.Tensor] = None,
                best_loss: bool = False, x_init: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generates adversarial examples for the given inputs.

        Args:
            x (torch.Tensor): Clean images.
            y (Optional[torch.Tensor]): Clean labels. If None, predicted labels are used.
            best_loss (bool, optional): If True, returns points with highest loss. Defaults to False.
            x_init (Optional[torch.Tensor]): Accepted for interface compatibility;
                not read by this method.

        Returns:
            torch.Tensor: Adversarial examples.
        """
        assert self.loss in ['ce', 'dlr']
        if y is not None and len(y.shape) == 0:
            # Add the batch axis on new views instead of mutating the caller's tensors.
            x = x.unsqueeze(0)
            y = y.unsqueeze(0)
        self.init_hyperparam(x)

        x = x.detach().clone().float().to(self.device)
        if not self.is_tf_model:
            y_pred = self.model(x).max(1)[1]
        else:
            y_pred = self.model.predict(x).max(1)[1]
        if y is None:
            y = y_pred.detach().clone().long().to(self.device)
        else:
            y = y.detach().clone().long().to(self.device)

        adv = x.clone()
        if self.loss != 'ce-targeted':
            acc = y_pred == y
        else:
            acc = y_pred != y

        if self.verbose:
            print('-------------------------- ', 'running {}-attack with epsilon {:.5f}'.format(self.norm, self.eps),
                  '--------------------------')
            print('initial accuracy: {:.2%}'.format(acc.float().mean()))

        if self.use_largereps:
            epss = [3. * self.eps_orig, 2. * self.eps_orig, 1. * self.eps_orig]
            iters = [.3 * self.n_iter_orig, .3 * self.n_iter_orig, .4 * self.n_iter_orig]
            iters = [math.ceil(c) for c in iters]
            iters[-1] = self.n_iter_orig - sum(iters[:-1])  # make sure to use the given iterations
            if self.verbose:
                print('using schedule [{}x{}]'.format('+'.join([str(c) for c in epss]),
                                                      '+'.join([str(c) for c in iters])))

        startt = time.time()
        if not best_loss:
            torch.random.manual_seed(self.seed)
            torch.cuda.random.manual_seed(self.seed)

            for counter in range(self.n_restarts):
                # Only attack the points that are still correctly classified.
                ind_to_fool = acc.nonzero().squeeze()
                if len(ind_to_fool.shape) == 0:
                    ind_to_fool = ind_to_fool.unsqueeze(0)
                if ind_to_fool.numel() != 0:
                    x_to_fool = x[ind_to_fool].clone()
                    y_to_fool = y[ind_to_fool].clone()

                    if not self.use_largereps:
                        res_curr = self.attack_single_run(x_to_fool, y_to_fool)
                    else:
                        res_curr = self.decr_eps_pgd(x_to_fool, y_to_fool, epss, iters)
                    best_curr, acc_curr, loss_curr, adv_curr = res_curr
                    ind_curr = (acc_curr == 0).nonzero().squeeze()

                    acc[ind_to_fool[ind_curr]] = 0
                    adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()
                    if self.verbose:
                        print('restart {} - robust accuracy: {:.2%}'.format(
                            counter, acc.float().mean()),
                            '- cum. time: {:.1f} s'.format(
                            time.time() - startt))

            return adv.detach().clone()

        else:
            adv_best = x.detach().clone()
            loss_best = torch.ones([x.shape[0]]).to(
                self.device) * (-float('inf'))
            for counter in range(self.n_restarts):
                best_curr, _, loss_curr, _ = self.attack_single_run(x, y)
                ind_curr = (loss_curr > loss_best).nonzero().squeeze()
                adv_best[ind_curr] = best_curr[ind_curr] + 0.
                loss_best[ind_curr] = loss_curr[ind_curr] + 0.

                if self.verbose:
                    print('restart {} - loss: {:.5f}'.format(counter, loss_best.sum()))

            return adv_best

    def decr_eps_pgd(self, x: torch.Tensor, y: torch.Tensor, epss: list, iters: list, use_rs: bool = True
                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Performs PGD with decreasing epsilon values.

        NOTE: overwrites ``self.eps`` and ``self.n_iter`` with the values of
        each stage and leaves the last stage's values in place on return.

        Args:
            x (torch.Tensor): The input data.
            y (torch.Tensor): The target labels.
            epss (list): List of epsilon values to use in the attack.
            iters (list): List of iteration counts corresponding to each epsilon value.
            use_rs (bool, optional): If True, uses random start. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the final perturbed
            inputs, the accuracy tensor, the loss tensor, and the best adversarial examples found.
        """
        assert len(epss) == len(iters)
        assert self.norm in ['L1']
        self.use_rs = False
        if not use_rs:
            x_init = None
        else:
            # Random point projected onto the largest L1 ball of the schedule.
            x_init = x + torch.randn_like(x)
            x_init += L1_projection(x, x_init - x, 1. * float(epss[0]))

        if self.verbose:
            print('total iter: {}'.format(sum(iters)))
        for eps, niter in zip(epss, iters):
            if self.verbose:
                print('using eps: {:.2f}'.format(eps))
            self.n_iter = niter + 0
            self.eps = eps + 0.
            # Re-project the previous stage's result onto the current (smaller) ball.
            if x_init is not None:
                x_init += L1_projection(x, x_init - x, 1. * eps)
            x_init, acc, loss, x_adv = self.attack_single_run(x, y, x_init=x_init)

        return (x_init, acc, loss, x_adv)