Coverage for adaro_rl / wrappers / obs_attack_agent_env.py: 90%

30 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import torch 

2from stable_baselines3.common.vec_env import VecEnvWrapper 

3 

4 

class ObsAttackAgentEnv(VecEnvWrapper):
    """
    Vectorized environment wrapper applying adversarial attacks on observations from the agent's perspective.

    The observation batch returned by ``reset`` is always perturbed. After that, the
    observation batch of step ``k`` following a reset is perturbed when
    ``k % freq == 0``. This simulates adversarial manipulation of the agent's input.

    Parameters
    ----------
    venv : VecEnv
        The underlying vectorized environment to wrap.
    make_attack_fct : callable
        Zero-argument factory, called once at construction. It must return an attack
        object exposing ``generate_adv_obs(batch_obs)``. That method takes a batch of
        observations (presumably an ``np.ndarray`` of shape ``(n_envs, ...)`` -- TODO
        confirm against the attack implementations) and returns a perturbed batch of
        the same shape.
    freq : int, optional
        Attack frequency in steps; must be >= 1. Defaults to 1 (attack every step).

    Attributes
    ----------
    attack : object
        Attack object returned by ``make_attack_fct()``.
    attack_fct : callable
        Bound ``attack.generate_adv_obs`` method used to perturb observation batches.
    freq : int
        Attack frequency.
    count : int
        Step counter deciding when to perturb. It is set to 1 by ``reset`` and
        incremented by every ``step_wait``. It is shared by all sub-environments and is
        not touched when a single sub-environment auto-resets.
    render_mode : Any
        Render mode forwarded from the underlying environment.
    action_space : gym.Space
        Action space of the underlying environment.
    observation_space : gym.Space
        Observation space of the underlying environment.

    Raises
    ------
    ValueError
        If ``freq`` is smaller than 1.
    """

    def __init__(self, venv, make_attack_fct, freq=1):
        # Fail fast: freq == 0 would otherwise raise ZeroDivisionError in step_wait.
        if freq < 1:
            raise ValueError(f"freq must be >= 1, got {freq!r}")

        super().__init__(venv)

        self.attack = make_attack_fct()
        self.attack_fct = self.attack.generate_adv_obs

        # Mirror the wrapped env's spaces and render mode explicitly.
        self.action_space = self.venv.action_space
        self.observation_space = self.venv.observation_space
        self.render_mode = self.venv.render_mode

        self.freq = freq
        self.count = 0

    def reset(self, seed=None):
        """
        Reset the environment and apply an initial adversarial attack.

        Parameters
        ----------
        seed : int, optional
            If given, passed to ``venv.seed`` before resetting so the reset is
            reproducible. ``None`` (default) leaves the environment's seeding untouched.

        Returns
        -------
        batch_adv_obs : np.ndarray
            Batch of perturbed observations after reset (whatever ``attack_fct``
            returns).
        """
        if seed is not None:
            # The wrapped VecEnv.reset() is called without arguments, so the seed has
            # to go through VecEnv.seed() instead.
            self.venv.seed(seed)
        batch_obs = self.venv.reset()
        batch_adv_obs = self.attack_fct(batch_obs)
        # Start at 1: the reset observation counts as index 0, so the next attacked
        # step is the freq-th one after this reset.
        self.count = 1
        return batch_adv_obs

    def step_async(self, actions):
        """
        Asynchronously send agent actions to the environment.

        Parameters
        ----------
        actions : np.ndarray or torch.Tensor
            Batch of actions from the agent. Torch tensors are detached and moved
            to CPU NumPy arrays before being forwarded.
        """
        if isinstance(actions, torch.Tensor):
            actions = actions.detach().cpu().numpy()
        self.venv.step_async(actions)

    def step_wait(self):
        """
        Wait for the environment step to complete, then optionally apply an adversarial attack.

        The batch is perturbed when ``count % freq == 0``. The counter is then
        incremented whether or not an attack was applied.

        Returns
        -------
        batch_adv_obs : np.ndarray
            Batch of (possibly perturbed) observations.
        reward : np.ndarray
            Batch of rewards from the environment.
        terminal : np.ndarray
            Batch of done flags indicating episode termination.
        info : list of dict
            Batch of info dictionaries, passed through unchanged.
        """
        batch_obs, reward, terminal, info = self.venv.step_wait()
        # NOTE(review): any observations stored inside `info` (e.g. SB3's
        # "terminal_observation" on auto-reset) are NOT perturbed -- confirm intended.
        if self.count % self.freq == 0:
            batch_adv_obs = self.attack_fct(batch_obs)
        else:
            batch_adv_obs = batch_obs
        self.count += 1
        return batch_adv_obs, reward, terminal, info

    def render(self, mode=None):
        """
        Render the underlying environment.

        Parameters
        ----------
        mode : str, optional
            Render mode to request from the underlying environment. If ``None``
            (default), ``venv.render()`` is called with no arguments, exactly as before.

        Returns
        -------
        Any
            Render output from the underlying environment (e.g., image array or None).
        """
        if mode is None:
            return self.venv.render()
        return self.venv.render(mode=mode)