Coverage for adaro_rl / attacks / base_attack.py: 90%
89 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import warnings
2import numpy as np
5class BaseAttack:
6 """
7 Base class for performing adversarial attacks.
9 Parameters
10 ----------
11 obs_space : gym.Space
12 Observation space.
13 perturb_space : gym.Space
14 Space of the perturbation.
15 eps : int | float | np.ndarray]
16 Magnitude of the perturbation.
17 - If eps is an int and norm is 0 (i.e., an L0 "norm"), then eps specifies the maximum number
18 of features that may be perturbed.
19 In this case, each perturbed feature is allowed to change to the extreme values authorized
20 by the observation space.
21 - If eps is an int or a float and norm is 1 or greater (including float('inf')), then eps
22 specifies the global bound on the perturbation,
23 i.e., the maximum distance between the original observation and the perturbed one,
24 measured in the given norm.
25 - If eps is a np.ndarray, it must have the same shape as the input observation and should
26 be used when no global norm is applied (i.e., norm is None).
27 In this case, eps specifies an elementwise bound: each element of the observation can be
28 perturbed by at most the corresponding value in eps.
29 norm : Optional[int | float]
30 Norm of the perturbation. Allowed values are 0, 1, 2, ..., float('inf'), or None.
31 - If norm is 0, eps must be an int.
32 - If norm is 1 or greater (including float('inf')), eps must be an int or a float.
33 - If norm is None, eps must be a np.ndarray.
34 is_proportional_mask : int | np.ndarray, optional, only used if eps is a np.ndarray
35 Mask indicating perturbation mode for each element in the observation.
36 If the mask value is 1, the element is perturbed proportionally, else absolutely.
37 device : str, optional
38 Device to run the attack on ('cpu', 'cuda:0', etc.). Default is 'cpu'.
39 """
41 def __init__(
42 self,
43 obs_space,
44 perturb_space,
45 eps: int | float | np.ndarray,
46 norm: int | float = None,
47 is_proportional_mask: int | np.ndarray = None,
48 device="cpu",
49 seed=None,
50 ):
51 self.obs_space = obs_space
52 self.obs_shape = self.obs_space.shape
53 self.perturb_space = perturb_space
54 self.perturb_shape = perturb_space.shape
55 self.norm = norm
56 self.eps = eps
57 self.is_proportional_mask = is_proportional_mask
58 self.device = device
59 self.seed = seed
60 self.rng = np.random.default_rng(seed)
62 # Precompute and cache constant arrays on CPU
63 self.flat_obs_shape = (int(np.prod(self.obs_shape)),)
64 self.flat_perturb_shape = (int(np.prod(self.perturb_shape)),)
66 if self.norm is None:
67 # No global norm: use elementwise bounds
68 assert isinstance(self.eps, np.ndarray), (
69 "When norm is None, eps must be a numpy array for elementwise bounds."
70 )
71 assert self.eps.shape == self.perturb_shape, (
72 "eps array shape must match the perturbation shape."
73 )
74 # Define elementwise bounds
75 self.max_eps = self.eps
77 elif self.norm == 0:
78 # L0 norm: eps must be an int (the maximum number of features to change)
79 assert isinstance(self.eps, int), "For L0 norm, eps must be an int."
80 # Define elementwise bounds
81 self.max_eps = np.full(self.perturb_shape, float("inf"))
83 elif self.norm >= 1:
84 # Global norm: eps must be a scalar (int or float)
85 assert isinstance(self.eps, (int, float)), (
86 "For global norm bounds (norm >= 1), eps must be an int or a float."
87 )
88 # Define elementwise bounds
89 self.max_eps = np.full(self.perturb_shape, self.eps)
91 else:
92 raise ValueError(
93 "Invalid norm value: norm must be None, 0, an int >= 1, or a float('inf'))."
94 )
96 self.flat_max_eps = self.max_eps.flatten()
97 self.flat_search_space = np.zeros(self.flat_perturb_shape)
98 self.flat_search_space[self.flat_max_eps != 0] = 1
100 if self.is_proportional_mask is None:
101 self.is_proportional_mask = np.zeros(self.perturb_shape)
102 elif not isinstance(self.is_proportional_mask, np.ndarray):
103 self.is_proportional_mask = np.full(
104 self.perturb_shape, self.is_proportional_mask
105 )
106 else:
107 assert self.is_proportional_mask.shape == self.obs_shape
109 # Precompute low and high arrays for observations and perturbations
110 self.obs_low = self.obs_space.low
111 self.obs_high = self.obs_space.high
112 self.perturb_low = self.perturb_space.low.flatten()
113 self.perturb_high = self.perturb_space.high.flatten()
115 # Ensure shapes match
116 assert self.obs_low.shape == self.obs_shape
117 assert self.obs_high.shape == self.obs_shape
119 def _scale_perturbation(self, perturbation_map_batch: np.ndarray, signed=False):
120 """
121 Scale and normalize a batch of perturbation maps according to the specified norm constraint.
123 Parameters
124 ----------
125 perturbation_map_batch : np.ndarray
126 A batch of perturbation maps with shape (batch_size, ...) representing the raw
127 perturbations
128 to be applied to each input in the batch.
130 signed : bool, optional
131 If True, the perturbation direction is taken as the sign of the gradient (FGSM-style).
132 If False, uses the raw gradient direction (FGM-style). Default is False.
134 Returns
135 -------
136 np.ndarray
137 A batch of scaled and clipped perturbation maps with the same shape as the input,
138 normalized to satisfy the specified norm constraint (`self.norm`) and bounded within
139 the allowed perturbation range [`self.perturb_low`, `self.perturb_high`].
141 Notes
142 -----
143 - If `self.norm` is 0, the perturbation is applied to only the top-k features per sample,
144 where k = `self.eps`.
145 - If `self.norm` is ∞ and `signed=True`, this behaves like FGSM.
146 - If `signed=False`, the perturbation is normalized to lie on the surface of the Lₚ ball of
147 radius `self.eps`,
148 where p = `self.norm`.
149 - The perturbation is masked using `self.flat_search_space` and clamped between
150 `self.perturb_low` and `self.perturb_high` after scaling.
151 """
152 batch_size = perturbation_map_batch.shape[0]
153 original_shape = perturbation_map_batch.shape
155 # Flatten the perturbation maps
156 perturbation_map_batch = perturbation_map_batch.reshape(batch_size, -1)
158 # Apply the search space mask
159 perturbation_map_batch = perturbation_map_batch * self.flat_search_space
161 sign_map_batch = np.sign(perturbation_map_batch)
163 if self.norm is None:
164 perturbation_map_batch = sign_map_batch * self.flat_max_eps
166 elif self.norm == 0:
167 # For each example in the batch
168 abs_perturbation_map_batch = np.abs(perturbation_map_batch)
169 k = self.eps
171 # Get the indices of the top-k elements per example
172 inds = np.argpartition(-abs_perturbation_map_batch, k, axis=1)[:, :k]
174 # Zero array for the perturbation map batch
175 perturbation_map_batch.fill(0)
177 # Set the top-k elements
178 batch_indices = np.arange(batch_size)[:, None]
179 perturbation_map_batch[batch_indices, inds] = 1
180 perturbation_map_batch *= sign_map_batch
181 perturbation_map_batch *= self.flat_max_eps
183 elif self.norm >= 1:
184 if signed:
185 perturbation_map_batch = sign_map_batch * self.flat_max_eps
187 perturbation_eps_s = np.linalg.norm(
188 perturbation_map_batch, ord=self.norm, axis=1, keepdims=True
189 )
190 perturbation_map_batch = perturbation_map_batch * (
191 self.eps / perturbation_eps_s
192 )
194 # Clamp perturbation to allowed ranges
195 perturbation_map_batch = np.clip(
196 perturbation_map_batch, self.perturb_low, self.perturb_high
197 )
199 # Reshape back to original shape
200 perturbation_map_batch = perturbation_map_batch.reshape(original_shape)
202 return perturbation_map_batch
204 def generate_perturbation(self, observation_batch: np.ndarray):
205 """
206 Abstract method to generate perturbations for a batch of observations.
208 Parameters
209 ----------
210 observation_batch : np.ndarray
211 Batch of input observations.
213 Raises
214 ------
215 NotImplementedError
216 This method must be implemented in subclasses.
217 """
218 raise NotImplementedError
220 def _apply_proportional_perturbation_mask(
221 self, perturbation_map: np.ndarray, observation_batch: np.ndarray
222 ):
223 """
224 Applies the proportional perturbation mask to a batch of perturbation maps.
226 Each element in the perturbation is either applied absolutely or scaled
227 proportionally to the corresponding observation_batch value, based on the mask.
229 Parameters
230 ----------
231 perturbation_map : np.ndarray
232 Batch of perturbation maps of shape (batch_size, *obs_shape).
233 observation_batch : np.ndarray
234 Batch of observations of shape (batch_size, *obs_shape).
236 Returns
237 -------
238 np.ndarray
239 Batch of masked perturbation maps.
240 """
241 original_type = observation_batch.dtype
242 observation_batch_float = observation_batch.astype(np.float32)
244 # Ensure proportional mask is broadcastable over batch dimension
245 if self.is_proportional_mask.ndim < observation_batch_float.ndim:
246 mask = np.expand_dims(self.is_proportional_mask, axis=0)
247 else:
248 mask = self.is_proportional_mask
250 perturbed_observation_batch_float = np.where(
251 mask == 1,
252 perturbation_map * np.abs(observation_batch_float),
253 perturbation_map,
254 )
256 if original_type == np.uint8:
257 perturbed_observation_batch_float = np.clip(
258 perturbed_observation_batch_float, 0, 255
259 )
261 return perturbed_observation_batch_float.astype(original_type)
263 def apply_perturbation_on_obs(
264 self, perturbation_batch: np.ndarray, observation_batch: np.ndarray
265 ):
266 """
267 Applies perturbations to a batch of observations and clamps to observation bounds.
269 Parameters
270 ----------
271 perturbation_batch : np.ndarray
272 Batch of perturbation vectors.
273 observation_batch : np.ndarray
274 Batch of original observations.
276 Returns
277 -------
278 np.ndarray
279 Batch of perturbed and clamped observations.
280 """
281 # Apply proportional perturbation mask to the entire batch
282 perturbation_batch = self._apply_proportional_perturbation_mask(
283 perturbation_batch, observation_batch
284 )
286 # Apply perturbation and clamp
287 perturbed_observation_batch = np.clip(
288 observation_batch + perturbation_batch, self.obs_low, self.obs_high
289 )
291 return perturbed_observation_batch
293 def generate_adv_obs(self, observation_batch: np.ndarray):
294 """
295 Generates adversarial observations by computing and applying perturbations.
297 Parameters
298 ----------
299 observation_batch : np.ndarray
300 Batch of input observations.
302 Returns
303 -------
304 np.ndarray
305 Batch of adversarially perturbed observations.
307 Warnings
308 --------
309 Will raise a warning and return original observations if NaN values are present in the
310 result.
311 """
313 perturbation_batch = self.generate_perturbation(observation_batch)
315 # Apply perturbation
316 perturbed_observation_batch = self.apply_perturbation_on_obs(
317 perturbation_batch, observation_batch
318 )
320 if np.isnan(perturbed_observation_batch).any():
321 warnings.warn("Perturbation cancelled because it contains NaN values")
322 return observation_batch
323 else:
324 return perturbed_observation_batch