Coverage for adaro_rl / wrappers / obs_attack_agent_env.py: 90%
30 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import torch
2from stable_baselines3.common.vec_env import VecEnvWrapper
class ObsAttackAgentEnv(VecEnvWrapper):
    """
    Vectorized environment wrapper applying adversarial attacks on observations from the agent's perspective.

    The observations returned by the wrapped environment are replaced by the
    output of an attack on reset and on every ``freq``-th step, simulating
    adversarial manipulations of the agent's input. Reset counts as step 0,
    so the reset observation is always attacked.

    Parameters
    ----------
    venv : VecEnv
        The underlying vectorized environment to wrap.
    make_attack_fct : callable
        Zero-argument factory returning an attack object. The object's
        ``generate_adv_obs`` method is called with a batch of observations
        (as returned by ``venv``) and is expected to return a batch of
        perturbed observations of the same shape.
    freq : int, optional
        Apply the attack every ``freq`` steps, counting reset as step 0.
        Must be >= 1; defaults to 1 (attack every step).

    Attributes
    ----------
    attack : object
        Attack object created by ``make_attack_fct()``.
    attack_fct : callable
        ``attack.generate_adv_obs``, applied to observation batches.
    freq : int
        Attack frequency.
    count : int
        Step counter used to determine when to apply perturbations. It is set
        to 1 after reset and incremented after every step.
    render_mode : Any
        Render mode forwarded from the underlying environment.
    action_space : gym.Space
        Action space of the underlying environment.
    observation_space : gym.Space
        Observation space of the underlying environment.

    Raises
    ------
    ValueError
        If ``freq`` is smaller than 1.
    """

    def __init__(self, venv, make_attack_fct, freq=1):
        # Fail fast: freq == 0 would otherwise raise ZeroDivisionError on the
        # first step, and negative values give a meaningless schedule.
        if freq < 1:
            raise ValueError(f"freq must be a positive integer, got {freq!r}")

        super().__init__(venv)

        self.attack = make_attack_fct()
        self.attack_fct = self.attack.generate_adv_obs

        self.action_space = self.venv.action_space
        self.observation_space = self.venv.observation_space
        self.render_mode = self.venv.render_mode

        self.freq = freq
        self.count = 0

    def reset(self, seed=None):
        """
        Reset the environment and apply an initial adversarial attack.

        Parameters
        ----------
        seed : int, optional
            Random seed for the environment. When given, it is passed to
            ``venv.seed`` before resetting. SB3's ``VecEnv.reset`` takes no
            seed argument, so seeding has to go through ``seed()``.

        Returns
        -------
        batch_adv_obs : np.ndarray
            Batch of perturbed observations after reset.
        """
        if seed is not None:
            self.venv.seed(seed)
        batch_obs = self.venv.reset()
        batch_adv_obs = self.attack_fct(batch_obs)
        # Reset is step 0 (always attacked); the next step is step 1.
        self.count = 1
        return batch_adv_obs

    def step_async(self, actions):
        """
        Asynchronously send agent actions to the environment.

        Parameters
        ----------
        actions : np.ndarray or torch.Tensor
            Batch of actions from the agent. Torch tensors are detached and
            moved to CPU as NumPy arrays before being forwarded.
        """
        if isinstance(actions, torch.Tensor):
            actions = actions.detach().cpu().numpy()
        self.venv.step_async(actions)

    def step_wait(self):
        """
        Wait for the environment step to complete, then optionally apply an adversarial attack.

        The attack is applied when ``count % freq == 0``. The counter is shared
        across all sub-environments.

        Returns
        -------
        batch_adv_obs : np.ndarray
            Batch of (possibly perturbed) observations.
        reward : np.ndarray
            Batch of rewards from the environment.
        terminal : np.ndarray
            Batch of done flags indicating episode termination.
        info : list of dict
            Batch of info dictionaries, returned unmodified.
        """
        batch_obs, reward, terminal, info = self.venv.step_wait()
        # NOTE(review): any observations stored inside `info` (e.g. an
        # auto-reset terminal observation) are passed through unattacked;
        # confirm this is intended.
        if self.count % self.freq == 0:
            batch_adv_obs = self.attack_fct(batch_obs)
        else:
            batch_adv_obs = batch_obs
        self.count += 1
        return batch_adv_obs, reward, terminal, info

    def render(self, mode=None):
        """
        Render the underlying environment.

        Parameters
        ----------
        mode : str, optional
            Render mode forwarded to ``venv.render``. When omitted,
            ``venv.render()`` is called without arguments, as before.

        Returns
        -------
        Any
            Render output from the underlying environment (e.g., image array or None).
        """
        if mode is None:
            return self.venv.render()
        return self.venv.render(mode=mode)