Coverage for adaro_rl / attacks / fgm.py: 84%

159 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1from typing import Tuple 

2import torch 

3import numpy as np 

4from .base_attack import BaseAttack 

5 

6 

7NO_TARGET_ERROR = "No target given. Please provide a callable or a string target." 

8 

9 

10# =========================================== 

11# FastGradientMethod_BaseClass 

12# =========================================== 

13 

14 

class FastGradientMethodBaseClass(BaseAttack):
    """
    Base class for Fast Gradient Method (FGM) attacks.

    Subclasses implement `_get_perturbation_map`, which returns the gradient
    of an attack loss with respect to the observations.
    `generate_perturbation` flattens that map and hands it to
    `_scale_perturbation` (inherited from `BaseAttack`).

    Parameters
    ----------
    target : callable | str
        Target specification for the attack. Its interpretation is left to
        the subclass (e.g. "targeted", "untargeted", "min", "max", or a
        callable producing target values).
    obs_space : int | Tuple
        Shape or space definition of the observation.
    perturb_space : int | Tuple
        Shape or space definition of the perturbation.
    eps : int | float | np.ndarray
        Perturbation budget. Interpretation depends on the norm.
    norm : int | float | None, optional
        Norm of the perturbation. One of 0, 1, 2, np.inf, or None.
    is_proportional_mask : int | np.ndarray | None, optional
        Mask to specify proportional perturbation behavior. Default is None.
    device : str, optional
        Device for computation. Default is 'cpu'.
    """

    def __init__(
        self,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | np.ndarray | None = None,
        device="cpu",
    ):
        super().__init__(
            obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.target = target
        # Forwarded verbatim to `_scale_perturbation`; its meaning is defined
        # by `BaseAttack`.
        self.signed = False

    def _get_perturbation_map(self, observation_batch, batch_size):
        """Return the raw perturbation map; must be overridden by subclasses."""
        raise NotImplementedError

    def generate_perturbation(self, observation_batch: torch.Tensor | np.ndarray):
        """
        Generates perturbations for a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor | np.ndarray
            Batch of input observations to perturb. The first axis is the
            batch axis. The caller's object is never modified.

        Returns
        -------
        array-like
            Whatever `_scale_perturbation` returns, rounded when the input had
            an integer dtype and reshaped to `(batch_size,) + self.obs_shape`.
        """
        source_dtype = observation_batch.dtype
        if isinstance(source_dtype, torch.dtype):
            # np.issubdtype cannot interpret torch dtypes, so classify them
            # directly; bool is not an integer type, matching NumPy's
            # convention.
            int_conversion = not (
                source_dtype.is_floating_point
                or source_dtype.is_complex
                or source_dtype == torch.bool
            )
        else:
            int_conversion = np.issubdtype(source_dtype, np.integer)

        # `detach()` guarantees a fresh leaf tensor: without it a float32
        # tensor already on `self.device` would be returned as-is and
        # `requires_grad_` would mutate the caller's tensor.
        observation_batch = (
            torch.as_tensor(observation_batch, device=self.device)
            .detach()
            .to(torch.float32)
            .requires_grad_(True)
        )

        batch_size = observation_batch.shape[0]

        # One flat row per sample, as NumPy, for `_scale_perturbation`.
        perturbation_map_batch = (
            self._get_perturbation_map(observation_batch, batch_size)
            .cpu()
            .numpy()
            .reshape(batch_size, -1)
        )

        perturbation_batch = self._scale_perturbation(
            perturbation_map_batch, self.signed
        )

        if int_conversion:
            # Integer-valued observations can only take whole-number offsets.
            perturbation_batch = perturbation_batch.round()

        return perturbation_batch.reshape((batch_size,) + self.obs_shape)

103 

104 

105# ================================================================ 

106# FastGradientMethod_DiscreteAction 

107# ================================================================ 

108 

109 

class FastGradientMethodDiscreteAction(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for agents with discrete action spaces.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance exposing a
        `probs(observation_batch)` method.
    target : callable | str
        Attack target specification:
        - a callable taking the observation batch and returning
          `(target_batch, targeted)`;
        - "targeted": aim at the action the agent currently rates lowest;
        - any other value: untargeted attack against the action the agent
          currently rates highest.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Mask specifying proportional perturbation behavior. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_probs_fct = self.agent.probs

    def _generate_perturbation_map(
        self, action_probs_batch, target_batch, targeted, observation_batch
    ):
        """
        Computes the gradient of the attack loss with respect to the observations.

        Parameters
        ----------
        action_probs_batch : torch.Tensor
            Agent output for `observation_batch`, shaped (batch, n_actions).
        target_batch : torch.Tensor
            Targets accepted by `torch.nn.functional.cross_entropy`.
        targeted : bool
            If true the loss is the cross-entropy to the target; otherwise it
            is the negated cross-entropy.
        observation_batch : torch.Tensor
            Input batch with `requires_grad=True` from which
            `action_probs_batch` was computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss, shaped like `observation_batch`.
        """
        # NOTE(review): cross_entropy applies log-softmax to its input. If
        # `agent.probs` already returns normalized probabilities this is a
        # softmax of probabilities -- confirm that this is intended.
        loss = torch.nn.functional.cross_entropy(action_probs_batch, target_batch)
        if not targeted:
            loss = -loss

        # autograd.grad differentiates w.r.t. the observations only, so the
        # agent parameters' `.grad` buffers are not touched and no stale value
        # in `observation_batch.grad` can leak into the result.
        (saliency_map_batch,) = torch.autograd.grad(loss, observation_batch)
        return saliency_map_batch

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Perturbation map shaped like the input.
        """
        action_probs_batch = self.agent_probs_fct(observation_batch)
        action_probs_batch = action_probs_batch.reshape(batch_size, -1)

        if callable(self.target):
            # The target is a constant of the attack, not something to
            # differentiate through.
            with torch.no_grad():
                target_batch, targeted = self.target(observation_batch)
            target_batch = target_batch.reshape(batch_size, -1)

        elif self.target == "targeted":
            # Push the policy towards its currently least favoured action.
            target_batch = torch.argmin(action_probs_batch, dim=1).detach()
            targeted = True

        else:  # untargeted
            # Push the policy away from its currently preferred action.
            target_batch = torch.argmax(action_probs_batch, dim=1).detach()
            targeted = False

        return self._generate_perturbation_map(
            action_probs_batch, target_batch, targeted, observation_batch
        )

220 

221 

222# ================================================================ 

223# FastGradientMethod_ContinuousAction 

224# ================================================================ 

225 

226 

class FastGradientMethodContinuousAction(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for agents with continuous action spaces.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance exposing an
        `act(observation_batch)` method.
    target : callable | str
        Attack target specification:
        - a callable taking the observation batch and returning
          `(target_batch, targeted)`;
        - any other value: untargeted attack that uses the agent's own
          current action as the reference to move away from.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_action_fct = self.agent.act

    def _generate_perturbation_map(
        self, action_batch, target_batch, targeted, observation_batch
    ):
        """
        Computes the gradient of the attack loss with respect to the observations.

        Parameters
        ----------
        action_batch : torch.Tensor
            Agent actions computed from `observation_batch`.
        target_batch : torch.Tensor
            Target actions, same shape as `action_batch`.
        targeted : bool
            If true the loss is the MSE to the target; otherwise it is the
            negated MSE.
        observation_batch : torch.Tensor
            Input batch with `requires_grad=True` from which `action_batch`
            was computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss, shaped like `observation_batch`.
        """
        loss = torch.nn.functional.mse_loss(action_batch, target_batch)
        if not targeted:
            loss = -loss

        # autograd.grad differentiates w.r.t. the observations only, so the
        # agent parameters' `.grad` buffers are not touched and no stale value
        # in `observation_batch.grad` can leak into the result.
        (perturbation_map_batch,) = torch.autograd.grad(loss, observation_batch)
        return perturbation_map_batch

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Perturbation map shaped like the input.
        """
        action_batch = self.agent_action_fct(observation_batch)
        action_batch = action_batch.reshape((batch_size, -1))

        if callable(self.target):
            # The target is a constant of the attack, not something to
            # differentiate through.
            with torch.no_grad():
                target_batch, targeted = self.target(observation_batch)
            target_batch = target_batch.reshape((batch_size, -1))

        else:  # untargeted
            target_batch = action_batch.clone().detach()
            targeted = False

        # A tiny random offset keeps the MSE gradient from being exactly zero
        # when the target coincides with the current action (always the case
        # in the untargeted branch).
        # NOTE(review): applied to both branches here, mirroring the critic
        # attacks where callable targets also receive noise -- confirm.
        noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
        target_batch = target_batch + noise

        return self._generate_perturbation_map(
            action_batch, target_batch, targeted, observation_batch
        )

334 

335 

336# ================================================================ 

337# FastGradientMethod_V_Critic 

338# ================================================================ 

339 

340 

class FastGradientMethodVCritic(FastGradientMethodBaseClass):
    """
    Fast Gradient Method attack for value-based critics.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance exposing a
        `v_value(observation_batch)` method.
    target : callable | str
        Target for the attack:
        - a callable taking the observation batch and returning target values;
        - "max": target is the current value estimate plus one;
        - "min": target is the current value estimate minus one.
        Any other value makes the attack raise `ValueError`.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_v_value_fct = self.agent.v_value

    def _generate_perturbation_map(
        self, v_value_batch, target_batch, observation_batch
    ):
        """
        Computes the gradient of the MSE between value estimates and targets
        with respect to the observations.

        Parameters
        ----------
        v_value_batch : torch.Tensor
            Value estimates computed from `observation_batch`.
        target_batch : torch.Tensor
            Target values, same shape as `v_value_batch`.
        observation_batch : torch.Tensor
            Input batch with `requires_grad=True` from which `v_value_batch`
            was computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss, shaped like `observation_batch`.
        """
        loss = torch.nn.functional.mse_loss(v_value_batch, target_batch)

        # autograd.grad differentiates w.r.t. the observations only, so the
        # agent parameters' `.grad` buffers are not touched and no stale value
        # in `observation_batch.grad` can leak into the result.
        (perturbation_map_batch,) = torch.autograd.grad(loss, observation_batch)
        return perturbation_map_batch

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Perturbation map shaped like the input.

        Raises
        ------
        ValueError
            If `self.target` is neither callable, "max" nor "min".
        """
        v_value_batch = self.agent_v_value_fct(observation_batch)
        v_value_batch = v_value_batch.reshape((batch_size, -1))

        if callable(self.target):
            # The target is a constant of the attack, not something to
            # differentiate through.
            with torch.no_grad():
                target_batch = self.target(observation_batch)
            target_batch = target_batch.reshape((batch_size, -1))
            # A tiny random offset keeps the MSE gradient from being exactly
            # zero should the target coincide with the current estimate.
            noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
            target_batch = target_batch + noise

        elif self.target == "max":
            # A target one unit above the current estimate.
            target_batch = v_value_batch.clone().detach() + 1

        elif self.target == "min":
            # A target one unit below the current estimate.
            target_batch = v_value_batch.clone().detach() - 1

        else:
            raise ValueError(NO_TARGET_ERROR)

        return self._generate_perturbation_map(
            v_value_batch, target_batch, observation_batch
        )

449 

450 

451# ================================================================ 

452# FastGradientMethod_Q_Critic 

453# ================================================================ 

454 

455 

class FastGradientMethodQCritic(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for Q-value-based critics (continuous actions).

    The agent's action is computed without gradient tracking and treated as a
    constant, so the perturbation map only reflects how the observation
    enters the critic.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance exposing
        `act(observation_batch)` and `q_value(observation_batch, action_batch)`.
    target : callable | str
        Attack target specification:
        - a callable taking the observation batch and returning target values;
        - "max": target is the current Q estimate plus one;
        - "min": target is the current Q estimate minus one.
        Any other value makes the attack raise `ValueError`.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | torch.Tensor | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | torch.Tensor | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_action_fct = self.agent.act
        self.agent_q_value_fct = self.agent.q_value

    def _generate_perturbation_map(
        self, q_value_batch, target_batch, observation_batch
    ):
        """
        Computes the gradient of the MSE between Q estimates and targets with
        respect to the observations.

        Parameters
        ----------
        q_value_batch : torch.Tensor
            Q-value estimates computed from `observation_batch`.
        target_batch : torch.Tensor
            Target values, same shape as `q_value_batch`.
        observation_batch : torch.Tensor
            Input batch with `requires_grad=True` from which `q_value_batch`
            was computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss, shaped like `observation_batch`.
        """
        loss = torch.nn.functional.mse_loss(q_value_batch, target_batch)

        # autograd.grad differentiates w.r.t. the observations only, so the
        # agent parameters' `.grad` buffers are not touched and no stale value
        # in `observation_batch.grad` can leak into the result.
        (perturbation_map_batch,) = torch.autograd.grad(loss, observation_batch)
        return perturbation_map_batch

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Perturbation map shaped like the input.

        Raises
        ------
        ValueError
            If `self.target` is neither callable, "max" nor "min".
        """
        # The action is held fixed: no gradient flows through the actor.
        with torch.no_grad():
            action_batch = self.agent_action_fct(observation_batch)

        q_value_batch = self.agent_q_value_fct(observation_batch, action_batch)
        q_value_batch = q_value_batch.reshape((batch_size, -1))

        if callable(self.target):
            # The target is a constant of the attack, not something to
            # differentiate through.
            with torch.no_grad():
                target_batch = self.target(observation_batch)
            target_batch = target_batch.reshape((batch_size, -1))
            # A tiny random offset keeps the MSE gradient from being exactly
            # zero should the target coincide with the current estimate.
            noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
            target_batch = target_batch + noise

        elif self.target == "max":
            # A target one unit above the current estimate.
            target_batch = q_value_batch.clone().detach() + 1

        elif self.target == "min":
            # A target one unit below the current estimate.
            target_batch = q_value_batch.clone().detach() - 1

        else:
            raise ValueError(NO_TARGET_ERROR)

        return self._generate_perturbation_map(
            q_value_batch, target_batch, observation_batch
        )

569 

570 

571# ================================================================ 

572# FastGradientMethod_Q_ActorCritic 

573# ================================================================ 

574 

575 

class FastGradientMethodQActorCritic(FastGradientMethodBaseClass):
    """
    Fast Gradient Method for Q-value actor-critic agents with continuous actions.

    The action is computed with gradient tracking enabled before being fed to
    the critic, so (provided `agent.act` is differentiable) the perturbation
    map accounts for the observation's influence through both the actor and
    the critic.

    Parameters
    ----------
    make_agent_fct : callable
        Zero-argument function that returns an agent instance exposing
        `act(observation_batch)` and `q_value(observation_batch, action_batch)`.
    target : callable | str
        Attack target specification:
        - a callable taking the observation batch and returning target values;
        - "max": target is the current Q estimate plus one;
        - "min": target is the current Q estimate minus one.
        Any other value makes the attack raise `ValueError`.
    obs_space : int | Tuple
        Observation space specification.
    perturb_space : int | Tuple
        Perturbation space specification.
    eps : int | float | np.ndarray
        Perturbation budget.
    norm : int | float | None, optional
        Norm of the perturbation. Default is None.
    is_proportional_mask : int | np.ndarray | None, optional
        Proportional perturbation mask. Default is None.
    device : str, optional
        Computation device. Default is 'cpu'.
    """

    def __init__(
        self,
        make_agent_fct,
        target,
        obs_space: int | Tuple,
        perturb_space: int | Tuple,
        eps: int | float | np.ndarray,
        norm: int | float | None = None,
        is_proportional_mask: int | np.ndarray | None = None,
        device="cpu",
    ):
        super().__init__(
            target, obs_space, perturb_space, eps, norm, is_proportional_mask, device
        )
        self.agent = make_agent_fct()
        self.agent_action_fct = self.agent.act
        self.agent_q_value_fct = self.agent.q_value

    def _generate_perturbation_map(
        self, q_value_batch, target_batch, observation_batch
    ):
        """
        Computes the gradient of the MSE between Q estimates and targets with
        respect to the observations.

        Parameters
        ----------
        q_value_batch : torch.Tensor
            Q-value estimates computed from `observation_batch`.
        target_batch : torch.Tensor
            Target values, same shape as `q_value_batch`.
        observation_batch : torch.Tensor
            Input batch with `requires_grad=True` from which `q_value_batch`
            was computed.

        Returns
        -------
        torch.Tensor
            Gradient of the loss, shaped like `observation_batch`.
        """
        loss = torch.nn.functional.mse_loss(q_value_batch, target_batch)

        # autograd.grad differentiates w.r.t. the observations only, so the
        # agent parameters' `.grad` buffers are not touched and no stale value
        # in `observation_batch.grad` can leak into the result.
        (perturbation_map_batch,) = torch.autograd.grad(loss, observation_batch)
        return perturbation_map_batch

    def _get_perturbation_map(self, observation_batch, batch_size):
        """
        Computes the perturbation map given a batch of observations.

        Parameters
        ----------
        observation_batch : torch.Tensor
            Batch of observations (requires_grad=True).
        batch_size : int
            Number of samples in the batch.

        Returns
        -------
        torch.Tensor
            Perturbation map shaped like the input.

        Raises
        ------
        ValueError
            If `self.target` is neither callable, "max" nor "min".
        """
        # Computed with gradient tracking on, so the actor stays in the graph.
        action_batch = self.agent_action_fct(observation_batch)

        # `clone()` is differentiable, so gradients still reach
        # `observation_batch`.
        # NOTE(review): presumably the copy shields the original from in-place
        # edits inside `q_value` -- confirm.
        observation_batch_clone = observation_batch.clone()
        q_value_batch = self.agent_q_value_fct(observation_batch_clone, action_batch)
        q_value_batch = q_value_batch.reshape((batch_size, -1))

        if callable(self.target):
            # The target is a constant of the attack, not something to
            # differentiate through.
            with torch.no_grad():
                target_batch = self.target(observation_batch)
            target_batch = target_batch.reshape((batch_size, -1))
            # A tiny random offset keeps the MSE gradient from being exactly
            # zero should the target coincide with the current estimate.
            noise = (torch.rand_like(target_batch) - 0.5) * 1e-4
            target_batch = target_batch + noise

        elif self.target == "max":
            # A target one unit above the current estimate.
            target_batch = q_value_batch.clone().detach() + 1

        elif self.target == "min":
            # A target one unit below the current estimate.
            target_batch = q_value_batch.clone().detach() - 1

        else:
            raise ValueError(NO_TARGET_ERROR)

        return self._generate_perturbation_map(
            q_value_batch, target_batch, observation_batch
        )