Coverage for adaro_rl / attacks / base_attack.py: 90%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import warnings 

2import numpy as np 

3 

4 

5class BaseAttack: 

6 """ 

7 Base class for performing adversarial attacks. 

8 

9 Parameters 

10 ---------- 

11 obs_space : gym.Space 

12 Observation space. 

13 perturb_space : gym.Space 

14 Space of the perturbation. 

15 eps : int | float | np.ndarray] 

16 Magnitude of the perturbation. 

17 - If eps is an int and norm is 0 (i.e., an L0 "norm"), then eps specifies the maximum number 

18 of features that may be perturbed. 

19 In this case, each perturbed feature is allowed to change to the extreme values authorized 

20 by the observation space. 

21 - If eps is an int or a float and norm is 1 or greater (including float('inf')), then eps 

22 specifies the global bound on the perturbation, 

23 i.e., the maximum distance between the original observation and the perturbed one, 

24 measured in the given norm. 

25 - If eps is a np.ndarray, it must have the same shape as the input observation and should 

26 be used when no global norm is applied (i.e., norm is None). 

27 In this case, eps specifies an elementwise bound: each element of the observation can be 

28 perturbed by at most the corresponding value in eps. 

29 norm : Optional[int | float] 

30 Norm of the perturbation. Allowed values are 0, 1, 2, ..., float('inf'), or None. 

31 - If norm is 0, eps must be an int. 

32 - If norm is 1 or greater (including float('inf')), eps must be an int or a float. 

33 - If norm is None, eps must be a np.ndarray. 

34 is_proportional_mask : int | np.ndarray, optional, only used if eps is a np.ndarray 

35 Mask indicating perturbation mode for each element in the observation. 

36 If the mask value is 1, the element is perturbed proportionally, else absolutely. 

37 device : str, optional 

38 Device to run the attack on ('cpu', 'cuda:0', etc.). Default is 'cpu'. 

39 """ 

40 

41 def __init__( 

42 self, 

43 obs_space, 

44 perturb_space, 

45 eps: int | float | np.ndarray, 

46 norm: int | float = None, 

47 is_proportional_mask: int | np.ndarray = None, 

48 device="cpu", 

49 seed=None, 

50 ): 

51 self.obs_space = obs_space 

52 self.obs_shape = self.obs_space.shape 

53 self.perturb_space = perturb_space 

54 self.perturb_shape = perturb_space.shape 

55 self.norm = norm 

56 self.eps = eps 

57 self.is_proportional_mask = is_proportional_mask 

58 self.device = device 

59 self.seed = seed 

60 self.rng = np.random.default_rng(seed) 

61 

62 # Precompute and cache constant arrays on CPU 

63 self.flat_obs_shape = (int(np.prod(self.obs_shape)),) 

64 self.flat_perturb_shape = (int(np.prod(self.perturb_shape)),) 

65 

66 if self.norm is None: 

67 # No global norm: use elementwise bounds 

68 assert isinstance(self.eps, np.ndarray), ( 

69 "When norm is None, eps must be a numpy array for elementwise bounds." 

70 ) 

71 assert self.eps.shape == self.perturb_shape, ( 

72 "eps array shape must match the perturbation shape." 

73 ) 

74 # Define elementwise bounds 

75 self.max_eps = self.eps 

76 

77 elif self.norm == 0: 

78 # L0 norm: eps must be an int (the maximum number of features to change) 

79 assert isinstance(self.eps, int), "For L0 norm, eps must be an int." 

80 # Define elementwise bounds 

81 self.max_eps = np.full(self.perturb_shape, float("inf")) 

82 

83 elif self.norm >= 1: 

84 # Global norm: eps must be a scalar (int or float) 

85 assert isinstance(self.eps, (int, float)), ( 

86 "For global norm bounds (norm >= 1), eps must be an int or a float." 

87 ) 

88 # Define elementwise bounds 

89 self.max_eps = np.full(self.perturb_shape, self.eps) 

90 

91 else: 

92 raise ValueError( 

93 "Invalid norm value: norm must be None, 0, an int >= 1, or a float('inf'))." 

94 ) 

95 

96 self.flat_max_eps = self.max_eps.flatten() 

97 self.flat_search_space = np.zeros(self.flat_perturb_shape) 

98 self.flat_search_space[self.flat_max_eps != 0] = 1 

99 

100 if self.is_proportional_mask is None: 

101 self.is_proportional_mask = np.zeros(self.perturb_shape) 

102 elif not isinstance(self.is_proportional_mask, np.ndarray): 

103 self.is_proportional_mask = np.full( 

104 self.perturb_shape, self.is_proportional_mask 

105 ) 

106 else: 

107 assert self.is_proportional_mask.shape == self.obs_shape 

108 

109 # Precompute low and high arrays for observations and perturbations 

110 self.obs_low = self.obs_space.low 

111 self.obs_high = self.obs_space.high 

112 self.perturb_low = self.perturb_space.low.flatten() 

113 self.perturb_high = self.perturb_space.high.flatten() 

114 

115 # Ensure shapes match 

116 assert self.obs_low.shape == self.obs_shape 

117 assert self.obs_high.shape == self.obs_shape 

118 

119 def _scale_perturbation(self, perturbation_map_batch: np.ndarray, signed=False): 

120 """ 

121 Scale and normalize a batch of perturbation maps according to the specified norm constraint. 

122 

123 Parameters 

124 ---------- 

125 perturbation_map_batch : np.ndarray 

126 A batch of perturbation maps with shape (batch_size, ...) representing the raw 

127 perturbations 

128 to be applied to each input in the batch. 

129 

130 signed : bool, optional 

131 If True, the perturbation direction is taken as the sign of the gradient (FGSM-style). 

132 If False, uses the raw gradient direction (FGM-style). Default is False. 

133 

134 Returns 

135 ------- 

136 np.ndarray 

137 A batch of scaled and clipped perturbation maps with the same shape as the input, 

138 normalized to satisfy the specified norm constraint (`self.norm`) and bounded within 

139 the allowed perturbation range [`self.perturb_low`, `self.perturb_high`]. 

140 

141 Notes 

142 ----- 

143 - If `self.norm` is 0, the perturbation is applied to only the top-k features per sample, 

144 where k = `self.eps`. 

145 - If `self.norm` is ∞ and `signed=True`, this behaves like FGSM. 

146 - If `signed=False`, the perturbation is normalized to lie on the surface of the Lₚ ball of 

147 radius `self.eps`, 

148 where p = `self.norm`. 

149 - The perturbation is masked using `self.flat_search_space` and clamped between 

150 `self.perturb_low` and `self.perturb_high` after scaling. 

151 """ 

152 batch_size = perturbation_map_batch.shape[0] 

153 original_shape = perturbation_map_batch.shape 

154 

155 # Flatten the perturbation maps 

156 perturbation_map_batch = perturbation_map_batch.reshape(batch_size, -1) 

157 

158 # Apply the search space mask 

159 perturbation_map_batch = perturbation_map_batch * self.flat_search_space 

160 

161 sign_map_batch = np.sign(perturbation_map_batch) 

162 

163 if self.norm is None: 

164 perturbation_map_batch = sign_map_batch * self.flat_max_eps 

165 

166 elif self.norm == 0: 

167 # For each example in the batch 

168 abs_perturbation_map_batch = np.abs(perturbation_map_batch) 

169 k = self.eps 

170 

171 # Get the indices of the top-k elements per example 

172 inds = np.argpartition(-abs_perturbation_map_batch, k, axis=1)[:, :k] 

173 

174 # Zero array for the perturbation map batch 

175 perturbation_map_batch.fill(0) 

176 

177 # Set the top-k elements 

178 batch_indices = np.arange(batch_size)[:, None] 

179 perturbation_map_batch[batch_indices, inds] = 1 

180 perturbation_map_batch *= sign_map_batch 

181 perturbation_map_batch *= self.flat_max_eps 

182 

183 elif self.norm >= 1: 

184 if signed: 

185 perturbation_map_batch = sign_map_batch * self.flat_max_eps 

186 

187 perturbation_eps_s = np.linalg.norm( 

188 perturbation_map_batch, ord=self.norm, axis=1, keepdims=True 

189 ) 

190 perturbation_map_batch = perturbation_map_batch * ( 

191 self.eps / perturbation_eps_s 

192 ) 

193 

194 # Clamp perturbation to allowed ranges 

195 perturbation_map_batch = np.clip( 

196 perturbation_map_batch, self.perturb_low, self.perturb_high 

197 ) 

198 

199 # Reshape back to original shape 

200 perturbation_map_batch = perturbation_map_batch.reshape(original_shape) 

201 

202 return perturbation_map_batch 

203 

204 def generate_perturbation(self, observation_batch: np.ndarray): 

205 """ 

206 Abstract method to generate perturbations for a batch of observations. 

207 

208 Parameters 

209 ---------- 

210 observation_batch : np.ndarray 

211 Batch of input observations. 

212 

213 Raises 

214 ------ 

215 NotImplementedError 

216 This method must be implemented in subclasses. 

217 """ 

218 raise NotImplementedError 

219 

220 def _apply_proportional_perturbation_mask( 

221 self, perturbation_map: np.ndarray, observation_batch: np.ndarray 

222 ): 

223 """ 

224 Applies the proportional perturbation mask to a batch of perturbation maps. 

225 

226 Each element in the perturbation is either applied absolutely or scaled 

227 proportionally to the corresponding observation_batch value, based on the mask. 

228 

229 Parameters 

230 ---------- 

231 perturbation_map : np.ndarray 

232 Batch of perturbation maps of shape (batch_size, *obs_shape). 

233 observation_batch : np.ndarray 

234 Batch of observations of shape (batch_size, *obs_shape). 

235 

236 Returns 

237 ------- 

238 np.ndarray 

239 Batch of masked perturbation maps. 

240 """ 

241 original_type = observation_batch.dtype 

242 observation_batch_float = observation_batch.astype(np.float32) 

243 

244 # Ensure proportional mask is broadcastable over batch dimension 

245 if self.is_proportional_mask.ndim < observation_batch_float.ndim: 

246 mask = np.expand_dims(self.is_proportional_mask, axis=0) 

247 else: 

248 mask = self.is_proportional_mask 

249 

250 perturbed_observation_batch_float = np.where( 

251 mask == 1, 

252 perturbation_map * np.abs(observation_batch_float), 

253 perturbation_map, 

254 ) 

255 

256 if original_type == np.uint8: 

257 perturbed_observation_batch_float = np.clip( 

258 perturbed_observation_batch_float, 0, 255 

259 ) 

260 

261 return perturbed_observation_batch_float.astype(original_type) 

262 

263 def apply_perturbation_on_obs( 

264 self, perturbation_batch: np.ndarray, observation_batch: np.ndarray 

265 ): 

266 """ 

267 Applies perturbations to a batch of observations and clamps to observation bounds. 

268 

269 Parameters 

270 ---------- 

271 perturbation_batch : np.ndarray 

272 Batch of perturbation vectors. 

273 observation_batch : np.ndarray 

274 Batch of original observations. 

275 

276 Returns 

277 ------- 

278 np.ndarray 

279 Batch of perturbed and clamped observations. 

280 """ 

281 # Apply proportional perturbation mask to the entire batch 

282 perturbation_batch = self._apply_proportional_perturbation_mask( 

283 perturbation_batch, observation_batch 

284 ) 

285 

286 # Apply perturbation and clamp 

287 perturbed_observation_batch = np.clip( 

288 observation_batch + perturbation_batch, self.obs_low, self.obs_high 

289 ) 

290 

291 return perturbed_observation_batch 

292 

293 def generate_adv_obs(self, observation_batch: np.ndarray): 

294 """ 

295 Generates adversarial observations by computing and applying perturbations. 

296 

297 Parameters 

298 ---------- 

299 observation_batch : np.ndarray 

300 Batch of input observations. 

301 

302 Returns 

303 ------- 

304 np.ndarray 

305 Batch of adversarially perturbed observations. 

306 

307 Warnings 

308 -------- 

309 Will raise a warning and return original observations if NaN values are present in the 

310 result. 

311 """ 

312 

313 perturbation_batch = self.generate_perturbation(observation_batch) 

314 

315 # Apply perturbation 

316 perturbed_observation_batch = self.apply_perturbation_on_obs( 

317 perturbation_batch, observation_batch 

318 ) 

319 

320 if np.isnan(perturbed_observation_batch).any(): 

321 warnings.warn("Perturbation cancelled because it contains NaN values") 

322 return observation_batch 

323 else: 

324 return perturbed_observation_batch