Coverage for adaro_rl / attacks / fgm.py: 84%
159 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from typing import Tuple
2import torch
3import numpy as np
4from .base_attack import BaseAttack
# Message of the ValueError raised by the critic attacks when `target` is
# neither a callable nor one of the strings they recognise.
NO_TARGET_ERROR = (
    "No target given. Please provide a callable or a string target."
)
# ===========================================
# FastGradientMethodBaseClass
# ===========================================
class FastGradientMethodBaseClass(BaseAttack):
    """
    Base class for Fast Gradient Method (FGM) attacks.

    This class provides a common interface for implementing FGM-based
    adversarial attacks across different agent types (discrete, continuous,
    critic, etc.). Subclasses implement :meth:`_get_perturbation_map`, which
    returns the gradient of an attack loss with respect to the observation
    batch. :meth:`generate_perturbation` flattens that map per sample and
    hands it to ``BaseAttack._scale_perturbation``.

    Parameters
    ----------
    target : callable | str
        Target specification for the attack, interpreted by each subclass.
        Can be:
        - A callable evaluated on the observation batch.
        - A string: "targeted", "untargeted", "min", or "max".
    obs_space : int | Tuple
        Shape or space definition of the observation.
    perturb_space : int | Tuple
        Shape or space definition of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget. Interpretation depends on the norm.
    norm : int | float | None, optional
        Norm of the perturbation. One of 0, 1, 2, np.inf, or None.
    is_proportional_mask : int | np.ndarray | None, optional
        Mask to specify proportional perturbation behavior. Default is None.
    device : str, optional
        Device for computation. Default is 'cpu'.
    """

    def __init__(
        self,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | np.ndarray | None = None,
        device="cpu",
    ):
        super().__init__(
            obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.target = target
        # Forwarded to `_scale_perturbation` as its second argument.
        self.signed = False

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Return the attack-loss gradient w.r.t. ``observation_batch``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_perturbation_map()."
        )

    @staticmethod
    def _is_integer_dtype(dtype) -> bool:
        """
        Return True if ``dtype`` is an integer dtype (NumPy or PyTorch).

        ``np.issubdtype`` cannot interpret ``torch.dtype`` objects, so torch
        dtypes are classified through their own attributes. Booleans are not
        treated as integers, matching ``np.issubdtype(np.bool_, np.integer)``.
        """
        if isinstance(dtype, torch.dtype):
            return not (
                dtype.is_floating_point or dtype.is_complex or dtype == torch.bool
            )
        return bool(np.issubdtype(dtype, np.integer))

    def generate_perturbation(self, observation_batch: torch.Tensor | np.ndarray):
        """
        Generates perturbations for a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor | np.ndarray
            Batch of input observations to perturb. The first dimension is
            the batch dimension. The caller's object is not flagged
            ``requires_grad`` and never receives a ``.grad``.

        Returns
        -------
        array-like
            The output of ``_scale_perturbation`` (called on a NumPy array of
            shape ``(batch_size, -1)``). It is rounded when the observations
            have an integer dtype and reshaped to
            ``(batch_size,) + self.obs_shape``.
        """
        int_conversion = self._is_integer_dtype(observation_batch.dtype)

        # Differentiate w.r.t. a separate float32 leaf. `as_tensor` and `.to`
        # return the caller's tensor unchanged when it is already float32 on
        # `self.device`. `detach()` makes sure requires_grad and the gradient
        # land on a distinct tensor object instead of on the caller's.
        observation_batch = (
            torch.as_tensor(observation_batch, device=self.device)
            .detach()
            .to(torch.float32)
            .requires_grad_(True)
        )

        batch_size = observation_batch.shape[0]

        # Flatten each sample's gradient map before scaling.
        perturbation_map_batch = (
            self._get_perturbation_map(observation_batch, batch_size)
            .detach()
            .cpu()
            .numpy()
            .reshape(batch_size, -1)
        )

        perturbation_batch = self._scale_perturbation(
            perturbation_map_batch, self.signed
        )

        if int_conversion:
            # Keep perturbations of integer-typed observations integral.
            perturbation_batch = perturbation_batch.round()

        return perturbation_batch.reshape((batch_size,) + self.obs_shape)
105# ================================================================
106# FastGradientMethod_DiscreteAction
107# ================================================================
class FastGradientMethodDiscreteAction(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for agents with discrete action spaces.

    The attack loss is the cross-entropy between the agent's ``probs`` output
    and a target action (or per-action target scores).

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance. The agent must
        expose a ``probs`` method taking an observation batch.
    target : callable | str
        Attack target specification:
        - callable: ``target(observation_batch)`` returns
          ``(target_batch, targeted)``. ``target_batch`` holds either one
          integer action index per sample or one row of per-action scores
          per sample. ``targeted`` selects the sign of the loss.
        - "targeted": the currently least likely action (argmin) is the
          target and the loss is the cross-entropy to it.
        - anything else: untargeted. The loss is the negated cross-entropy to
          the currently most likely action (argmax).
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask specifying proportional perturbation behavior. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_probs_fct = self.agent.probs

    def _generate_perturbation_map(
        self, action_probs_batch, target_batch, targeted, observation_batch
    ):
        """
        Gradient of the cross-entropy attack loss w.r.t. the observations.

        Parameters
        ----------
        action_probs_batch : torch.Tensor
            Agent output reshaped to ``(batch_size, -1)``. It is passed to
            ``cross_entropy`` as its ``input`` argument.
        target_batch : torch.Tensor
            Action indices of shape ``(batch_size,)`` or per-action scores of
            shape ``(batch_size, n_actions)``.
        targeted : bool
            If truthy, the loss is the cross-entropy; otherwise its negation.
        observation_batch : torch.Tensor
            Tensor with ``requires_grad=True`` from which
            ``action_probs_batch`` was computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss w.r.t. ``observation_batch`` (same shape).
        """
        # NOTE(review): cross_entropy applies log_softmax to its input. If
        # `agent.probs` returns normalized probabilities rather than logits,
        # they are softmaxed a second time. TODO confirm.
        loss = torch.nn.functional.cross_entropy(action_probs_batch, target_batch)
        if not targeted:
            loss = -loss
        # Differentiate only w.r.t. the observations. loss.backward() would
        # also accumulate `.grad` into the agent's parameters on every call.
        (gradient,) = torch.autograd.grad(loss, observation_batch)
        return gradient

    @staticmethod
    def _format_callable_target(target_batch, batch_size, reference):
        """
        Convert a target returned by a callable into a ``cross_entropy`` target.

        Integer targets with exactly one entry per sample are action indices.
        They are returned as a ``(batch_size,)`` long tensor, which is the
        shape ``cross_entropy`` requires for class indices. Any other target
        is treated as per-action scores, cast to ``reference``'s dtype and
        reshaped to ``(batch_size, -1)``. The result is always on
        ``reference``'s device.
        """
        target_batch = torch.as_tensor(target_batch, device=reference.device)
        if (
            not target_batch.is_floating_point()
            and target_batch.numel() == batch_size
        ):
            return target_batch.reshape(batch_size).long()
        return target_batch.to(reference.dtype).reshape(batch_size, -1)

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Gradient of the attack loss, shaped like ``observation_batch``.
        """
        action_probs_batch = self.agent_probs_fct(observation_batch)
        action_probs_batch = action_probs_batch.reshape(batch_size, -1)

        if callable(self.target):
            with torch.no_grad():
                target_batch, targeted = self.target(observation_batch)
                target_batch = self._format_callable_target(
                    target_batch, batch_size, action_probs_batch
                )

        elif self.target == "targeted":
            # Target the currently least likely action.
            target_batch = torch.argmin(action_probs_batch, dim=1).detach()
            targeted = True

        else:  # untargeted
            # Loss (negated below) is measured against the currently most
            # likely action.
            target_batch = torch.argmax(action_probs_batch, dim=1).detach()
            targeted = False

        return self._generate_perturbation_map(
            action_probs_batch, target_batch, targeted, observation_batch
        )
222# ================================================================
223# FastGradientMethod_ContinuousAction
224# ================================================================
class FastGradientMethodContinuousAction(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for agents with continuous action spaces.

    The attack loss is the mean-squared error between the agent's ``act``
    output and a target action batch.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance. The agent must
        expose an ``act`` method taking an observation batch.
    target : callable | str
        Attack target specification:
        - callable: ``target(observation_batch)`` returns
          ``(target_batch, targeted)``. ``targeted`` selects the sign of the
          loss.
        - any non-callable value (e.g. "self"): untargeted. The target is the
          agent's current action and the loss is negated.
        In both cases uniform noise in [-5e-5, 5e-5) is added to the target.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_action_fct = self.agent.act

    def _generate_perturbation_map(
        self, action_batch, target_batch, targeted, observation_batch
    ):
        """
        Gradient of the MSE attack loss w.r.t. the observations.

        Parameters
        ----------
        action_batch : torch.Tensor
            Agent actions reshaped to ``(batch_size, -1)``.
        target_batch : torch.Tensor
            Target actions with the same shape as ``action_batch``.
        targeted : bool
            If truthy, the loss is the MSE; otherwise its negation.
        observation_batch : torch.Tensor
            Tensor with ``requires_grad=True`` from which ``action_batch`` was
            computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss w.r.t. ``observation_batch`` (same shape).
        """
        loss = torch.nn.functional.mse_loss(action_batch, target_batch)
        if not targeted:
            loss = -loss
        # Differentiate only w.r.t. the observations. loss.backward() would
        # also accumulate `.grad` into the agent's parameters on every call.
        (gradient,) = torch.autograd.grad(loss, observation_batch)
        return gradient

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Gradient of the attack loss, shaped like ``observation_batch``.
        """
        action_batch = self.agent_action_fct(observation_batch)
        action_batch = action_batch.reshape((batch_size, -1))

        if callable(self.target):
            with torch.no_grad():
                target_batch, targeted = self.target(observation_batch)
                # Match the action's dtype/device so mse_loss accepts the
                # target (e.g. float64 tensors or NumPy arrays).
                target_batch = torch.as_tensor(
                    target_batch, dtype=action_batch.dtype, device=action_batch.device
                ).reshape((batch_size, -1))

        else:  # untargeted
            target_batch = action_batch.clone().detach()
            targeted = False

        # A target equal to the current action gives an exactly-zero MSE
        # gradient. The small noise breaks that tie.
        noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
        target_batch = target_batch + noise

        return self._generate_perturbation_map(
            action_batch, target_batch, targeted, observation_batch
        )
336# ================================================================
337# FastGradientMethod_V_Critic
338# ================================================================
class FastGradientMethodVCritic(FastGradientMethodBaseClass):
    """
    Fast Gradient Method attack for value-based critics.

    The attack loss is the mean-squared error between the agent's ``v_value``
    output and a target value batch.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance. The agent must
        expose a ``v_value`` method taking an observation batch.
    target : callable | str
        Target for the attack:
        - callable: ``target(observation_batch)`` returns the target values
          only (no ``targeted`` flag). Uniform noise in [-5e-5, 5e-5) is added.
        - "max": target is the current value + 1.
        - "min": target is the current value - 1.
        Any other value raises ``ValueError`` when a perturbation is generated.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_v_value_fct = self.agent.v_value

    def _generate_perturbation_map(
        self, v_value_batch, target_batch, observation_batch
    ):
        """
        Gradient of the MSE attack loss w.r.t. the observations.

        Parameters
        ----------
        v_value_batch : torch.Tensor
            Critic output reshaped to ``(batch_size, -1)``.
        target_batch : torch.Tensor
            Target values with the same shape as ``v_value_batch``.
        observation_batch : torch.Tensor
            Tensor with ``requires_grad=True`` from which ``v_value_batch``
            was computed.

        Returns
        -------
        torch.Tensor
            Gradient of ``mse_loss(v_value_batch, target_batch)`` w.r.t.
            ``observation_batch`` (same shape).
        """
        loss = torch.nn.functional.mse_loss(v_value_batch, target_batch)
        # Differentiate only w.r.t. the observations. loss.backward() would
        # also accumulate `.grad` into the agent's parameters on every call.
        (gradient,) = torch.autograd.grad(loss, observation_batch)
        return gradient

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Gradient of the attack loss, shaped like ``observation_batch``.

        Raises
        ------
        ValueError
            If ``self.target`` is not callable and not "max" or "min".
        """
        v_value_batch = self.agent_v_value_fct(observation_batch)
        v_value_batch = v_value_batch.reshape((batch_size, -1))

        if callable(self.target):
            with torch.no_grad():
                # Match the critic output's dtype/device so mse_loss accepts
                # the target (e.g. float64 tensors or NumPy arrays).
                target_batch = torch.as_tensor(
                    self.target(observation_batch),
                    dtype=v_value_batch.dtype,
                    device=v_value_batch.device,
                ).reshape((batch_size, -1))
                noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
                target_batch = target_batch + noise

        elif self.target == "max":
            target_batch = v_value_batch.clone().detach() + 1

        elif self.target == "min":
            target_batch = v_value_batch.clone().detach() - 1

        else:
            raise ValueError(NO_TARGET_ERROR)

        return self._generate_perturbation_map(
            v_value_batch, target_batch, observation_batch
        )
451# ================================================================
452# FastGradientMethod_Q_Critic
453# ================================================================
class FastGradientMethodQCritic(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for Q-value-based critics (continuous actions).

    The agent's action is computed without gradient tracking. Only the
    critic path ``q_value(observation, action)`` is differentiated. The
    attack loss is the mean-squared error between the Q-values and a target
    batch.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance. The agent must
        expose ``act(observation_batch)`` and
        ``q_value(observation_batch, action_batch)``.
    target : callable | str
        Attack target specification:
        - callable: ``target(observation_batch)`` returns the target values
          only (no ``targeted`` flag). Uniform noise in [-5e-5, 5e-5) is added.
        - "max": target is the current Q-value + 1.
        - "min": target is the current Q-value - 1.
        Any other value raises ``ValueError`` when a perturbation is generated.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_action_fct = self.agent.act
        self.agent_q_value_fct = self.agent.q_value

    def _generate_perturbation_map(
        self, q_value_batch, target_batch, observation_batch
    ):
        """
        Gradient of the MSE attack loss w.r.t. the observations.

        Parameters
        ----------
        q_value_batch : torch.Tensor
            Critic output reshaped to ``(batch_size, -1)``.
        target_batch : torch.Tensor
            Target values with the same shape as ``q_value_batch``.
        observation_batch : torch.Tensor
            Tensor with ``requires_grad=True`` from which ``q_value_batch``
            was computed.

        Returns
        -------
        torch.Tensor
            Gradient of ``mse_loss(q_value_batch, target_batch)`` w.r.t.
            ``observation_batch`` (same shape).
        """
        loss = torch.nn.functional.mse_loss(q_value_batch, target_batch)
        # Differentiate only w.r.t. the observations. loss.backward() would
        # also accumulate `.grad` into the agent's parameters on every call.
        (gradient,) = torch.autograd.grad(loss, observation_batch)
        return gradient

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Gradient of the attack loss, shaped like ``observation_batch``.

        Raises
        ------
        ValueError
            If ``self.target`` is not callable and not "max" or "min".
        """
        # Actions are treated as constants: no gradient flows through `act`.
        with torch.no_grad():
            action_batch = self.agent_action_fct(observation_batch)

        q_value_batch = self.agent_q_value_fct(observation_batch, action_batch)
        q_value_batch = q_value_batch.reshape((batch_size, -1))

        if callable(self.target):
            with torch.no_grad():
                # Match the critic output's dtype/device so mse_loss accepts
                # the target (e.g. float64 tensors or NumPy arrays).
                target_batch = torch.as_tensor(
                    self.target(observation_batch),
                    dtype=q_value_batch.dtype,
                    device=q_value_batch.device,
                ).reshape((batch_size, -1))
                noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
                target_batch = target_batch + noise

        elif self.target == "max":
            target_batch = q_value_batch.clone().detach() + 1

        elif self.target == "min":
            target_batch = q_value_batch.clone().detach() - 1

        else:
            raise ValueError(NO_TARGET_ERROR)

        return self._generate_perturbation_map(
            q_value_batch, target_batch, observation_batch
        )
571# ================================================================
572# FastGradientMethod_Q_ActorCritic
573# ================================================================
class FastGradientMethodQActorCritic(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for Q-value actor-critic agents with continuous actions.

    The gradient flows through both the actor (``act``) and the critic
    (``q_value``). The attack loss is the mean-squared error between the
    Q-values of the agent's own actions and a target batch.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance. The agent must
        expose ``act(observation_batch)`` and
        ``q_value(observation_batch, action_batch)``.
    target : callable | str
        Attack target specification:
        - callable: ``target(observation_batch)`` returns the target values
          only (no ``targeted`` flag). Uniform noise in [-5e-5, 5e-5) is added.
        - "max": target is the current Q-value + 1.
        - "min": target is the current Q-value - 1.
        Any other value raises ``ValueError`` when a perturbation is generated.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | np.ndarray | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | np.ndarray | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_action_fct = self.agent.act
        self.agent_q_value_fct = self.agent.q_value

    def _generate_perturbation_map(
        self, q_value_batch, target_batch, observation_batch
    ):
        """
        Gradient of the MSE attack loss w.r.t. the observations.

        Parameters
        ----------
        q_value_batch : torch.Tensor
            Critic output reshaped to ``(batch_size, -1)``.
        target_batch : torch.Tensor
            Target values with the same shape as ``q_value_batch``.
        observation_batch : torch.Tensor
            Tensor with ``requires_grad=True`` from which ``q_value_batch``
            was computed (through both actor and critic).

        Returns
        -------
        torch.Tensor
            Gradient of ``mse_loss(q_value_batch, target_batch)`` w.r.t.
            ``observation_batch`` (same shape).
        """
        loss = torch.nn.functional.mse_loss(q_value_batch, target_batch)
        # Differentiate only w.r.t. the observations. loss.backward() would
        # also accumulate `.grad` into the agent's parameters on every call.
        (gradient,) = torch.autograd.grad(loss, observation_batch)
        return gradient

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Gradient of the attack loss, shaped like ``observation_batch``.

        Raises
        ------
        ValueError
            If ``self.target`` is not callable and not "max" or "min".
        """
        # Actions keep their graph, so the gradient also flows through `act`.
        action_batch = self.agent_action_fct(observation_batch)

        # The critic gets a clone; gradients still reach `observation_batch`
        # through it. NOTE(review): presumably this shields the leaf tensor
        # from in-place ops inside `q_value`. Confirm before removing.
        observation_batch_clone = observation_batch.clone()
        q_value_batch = self.agent_q_value_fct(observation_batch_clone, action_batch)
        q_value_batch = q_value_batch.reshape((batch_size, -1))

        if callable(self.target):
            with torch.no_grad():
                # Match the critic output's dtype/device so mse_loss accepts
                # the target (e.g. float64 tensors or NumPy arrays).
                target_batch = torch.as_tensor(
                    self.target(observation_batch),
                    dtype=q_value_batch.dtype,
                    device=q_value_batch.device,
                ).reshape((batch_size, -1))
                noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
                target_batch = target_batch + noise

        elif self.target == "max":
            target_batch = q_value_batch.clone().detach() + 1

        elif self.target == "min":
            target_batch = q_value_batch.clone().detach() - 1

        else:
            raise ValueError(NO_TARGET_ERROR)

        return self._generate_perturbation_map(
            q_value_batch, target_batch, observation_batch
        )