Coverage for adaro_rl / attacks / fgsm.py: 100%
24 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from typing import Tuple
2import torch
3import numpy as np
5from .fgm import (
6 FastGradientMethodDiscreteAction,
7 FastGradientMethodContinuousAction,
8 FastGradientMethodVCritic,
9 FastGradientMethodQCritic,
10 FastGradientMethodQActorCritic,
11)
# ================================================================
# FastGradientSignMethodDiscreteAction
# ================================================================


class FastGradientSignMethodDiscreteAction(FastGradientMethodDiscreteAction):
    """
    Fast Gradient Sign Method (FGSM) for discrete-action agents.

    Inherits from FastGradientMethodDiscreteAction and sets `signed=True`
    to enable sign-based perturbations (i.e., using the sign of the gradient).

    Parameters
    ----------
    make_agent_fct : callable
        Function that returns an agent instance.
    target : callable | str
        Target for the attack. Can be a callable or a string ("targeted", "untargeted").
    obs_space : int | Tuple
        Shape or space of the observation.
    perturb_space : int | Tuple
        Shape or space of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm to use for the perturbation (e.g., 1, 2, np.inf, or None).
        Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask to control proportional perturbation. Default is None.
    device : str, optional
        Device to run the attack on (e.g., 'cpu', 'cuda:0'). Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device: str = "cpu",
    ) -> None:
        # Every argument is forwarded positionally, in signature order, to the
        # FastGradientMethod base constructor.
        super().__init__(
            make_agent_fct,
            target,
            obs_space,
            perturb_space,
            eps,
            norm,
            is_proportional_mask,
            device,
        )
        # Assigned after the base constructor so it overrides anything the base
        # __init__ may store in `signed`. Presumably the base class reads this
        # flag to switch to sign-of-gradient perturbations; confirm in fgm.py.
        self.signed = True
# ================================================================
# FastGradientSignMethodContinuousAction
# ================================================================


class FastGradientSignMethodContinuousAction(FastGradientMethodContinuousAction):
    """
    Fast Gradient Sign Method (FGSM) for continuous-action agents.

    Inherits from FastGradientMethodContinuousAction and sets `signed=True`,
    enabling sign-based perturbation for agents that return continuous actions.

    Parameters
    ----------
    make_agent_fct : callable
        Function that returns an agent instance.
    target : callable | str
        Target for the attack. Can be a callable or a string (e.g., "self").
    obs_space : int | Tuple
        Shape or space of the observation.
    perturb_space : int | Tuple
        Shape or space of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm to use for the perturbation (e.g., 1, 2, np.inf, or None).
        Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask to control proportional perturbation. Default is None.
    device : str, optional
        Device to run the attack on (e.g., 'cpu', 'cuda:0'). Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device: str = "cpu",
    ) -> None:
        # Every argument is forwarded positionally, in signature order, to the
        # FastGradientMethod base constructor.
        super().__init__(
            make_agent_fct,
            target,
            obs_space,
            perturb_space,
            eps,
            norm,
            is_proportional_mask,
            device,
        )
        # Assigned after the base constructor so it overrides anything the base
        # __init__ may store in `signed`. Presumably the base class reads this
        # flag to switch to sign-of-gradient perturbations; confirm in fgm.py.
        self.signed = True
# ================================================================
# FastGradientSignMethodVCritic
# ================================================================


class FastGradientSignMethodVCritic(FastGradientMethodVCritic):
    """
    Fast Gradient Sign Method (FGSM) for value-based critics (V-functions).

    Inherits from FastGradientMethodVCritic and sets `signed=True`, enabling
    sign-based adversarial perturbations based on the gradient of the value
    estimate.

    Parameters
    ----------
    make_agent_fct : callable
        Function that returns an agent instance.
    target : callable | str
        Target specification ("min", "max", or a callable).
    obs_space : int | Tuple
        Shape or space of the observation.
    perturb_space : int | Tuple
        Shape or space of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm to use for the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask to control proportional perturbation. Default is None.
    device : str, optional
        Device to run the attack on. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device: str = "cpu",
    ) -> None:
        # Every argument is forwarded positionally, in signature order, to the
        # FastGradientMethod base constructor.
        super().__init__(
            make_agent_fct,
            target,
            obs_space,
            perturb_space,
            eps,
            norm,
            is_proportional_mask,
            device,
        )
        # Assigned after the base constructor so it overrides anything the base
        # __init__ may store in `signed`. Presumably the base class reads this
        # flag to switch to sign-of-gradient perturbations; confirm in fgm.py.
        self.signed = True
# ================================================================
# FastGradientSignMethodQCritic
# ================================================================


class FastGradientSignMethodQCritic(FastGradientMethodQCritic):
    """
    Fast Gradient Sign Method (FGSM) for Q-value critics with continuous actions.

    Inherits from FastGradientMethodQCritic and sets `signed=True`, applying
    sign-based perturbations to maximize or minimize Q-values based on the
    critic model.

    Parameters
    ----------
    make_agent_fct : callable
        Function that returns an agent instance.
    target : callable | str
        Target specification ("min", "max", or a callable).
    obs_space : int | Tuple
        Shape or space of the observation.
    perturb_space : int | Tuple
        Shape or space of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm to use for the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask to control proportional perturbation. Default is None.
    device : str, optional
        Device to run the attack on. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device: str = "cpu",
    ) -> None:
        # Every argument is forwarded positionally, in signature order, to the
        # FastGradientMethod base constructor.
        super().__init__(
            make_agent_fct,
            target,
            obs_space,
            perturb_space,
            eps,
            norm,
            is_proportional_mask,
            device,
        )
        # Assigned after the base constructor so it overrides anything the base
        # __init__ may store in `signed`. Presumably the base class reads this
        # flag to switch to sign-of-gradient perturbations; confirm in fgm.py.
        self.signed = True
# ================================================================
# FastGradientSignMethodQActorCritic
# ================================================================


class FastGradientSignMethodQActorCritic(FastGradientMethodQActorCritic):
    """
    Fast Gradient Sign Method (FGSM) for actor-critic agents using Q-values.

    Inherits from FastGradientMethodQActorCritic and sets `signed=True`,
    enabling sign-based perturbations for actor-critic models where Q-values
    depend on both observation and action.

    Parameters
    ----------
    make_agent_fct : callable
        Function that returns an agent instance.
    target : callable | str
        Target specification ("min", "max", or a callable).
    obs_space : int | Tuple
        Shape or space of the observation.
    perturb_space : int | Tuple
        Shape or space of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm to use for the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask to control proportional perturbation. Default is None.
    device : str, optional
        Device to run the attack on. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device: str = "cpu",
    ) -> None:
        # Every argument is forwarded positionally, in signature order, to the
        # FastGradientMethod base constructor.
        super().__init__(
            make_agent_fct,
            target,
            obs_space,
            perturb_space,
            eps,
            norm,
            is_proportional_mask,
            device,
        )
        # Assigned after the base constructor so it overrides anything the base
        # __init__ may store in `signed`. Presumably the base class reads this
        # flag to switch to sign-of-gradient perturbations; confirm in fgm.py.
        self.signed = True