Coverage for adaro_rl / wrappers / obs_attack_adversary_env.py: 28%
36 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import numpy as np
2import torch
3import gymnasium as gym
4from stable_baselines3.common.vec_env import VecEnvWrapper
class ObsAttackAdversaryEnv(VecEnvWrapper):
    """
    Vectorized adversarial-environment wrapper that sits "above" an agent's VecEnv,
    allowing an adversary to craft observation perturbations and see their effect on
    the agent's behavior.

    At each ``step_async`` call, the adversary provides a batch of flat perturbation
    vectors. Those are scaled and applied to the last true observations, passed
    into the agent for action selection, and then forwarded to the inner VecEnv.
    Rewards returned in ``step_wait`` are negated, as the adversary's objective is to
    minimize the agent's return.

    Parameters
    ----------
    venv : VecEnv
        The underlying vectorized environment.
    make_agent_fct : callable
        Zero-argument factory returning the agent whose behavior the adversary
        attacks. The agent is built once at construction time.
    agent_predict_fct : callable or None
        Function with signature
        ``(obs_batch: np.ndarray, deterministic: bool) -> (action_batch, state)``.
        If ``None``, the built agent's own ``predict`` method is used.
    scale_perturbation_fct : callable
        Function that takes an array of raw adversary actions of shape
        ``(n_envs, *obs_shape)`` and returns scaled perturbations of same shape.
    apply_perturbation_to_obs_fct : callable
        Function with signature
        ``(perturbations: np.ndarray, obs_batch: np.ndarray) -> np.ndarray``
        which returns the adversarial observations of shape ``(n_envs, *obs_shape)``.

    Attributes
    ----------
    last_obs : np.ndarray
        The most recent true observations from the underlying environment,
        used as the base for applying new perturbations.
    """

    def __init__(
        self,
        venv,
        make_agent_fct,
        agent_predict_fct,
        scale_perturbation_fct,
        apply_perturbation_to_obs_fct,
    ):
        super().__init__(venv)

        self.agent = make_agent_fct()
        # BUG FIX: the original code assigned ``self.agent.predict`` here,
        # silently discarding the caller-supplied ``agent_predict_fct`` that
        # the class docstring documents. Honor the supplied callable; fall
        # back to the agent's own predict method when the caller passes None.
        self.agent_predict_fct = (
            agent_predict_fct if agent_predict_fct is not None else self.agent.predict
        )

        self.scale_perturbation_fct = scale_perturbation_fct
        self.apply_perturbation_to_obs_fct = apply_perturbation_to_obs_fct

        # Build adversary action_space: the adversary emits one flat vector
        # per env, matching the flattened observation shape, bounded in [-1, 1]
        # (scaling to the real perturbation budget is scale_perturbation_fct's job).
        obs_shape = self.observation_space.shape
        flat_dim = int(np.prod(obs_shape))
        self.action_space = gym.spaces.Box(
            low=-np.ones(flat_dim, dtype=np.float32),
            high=np.ones(flat_dim, dtype=np.float32),
            dtype=np.float32,
        )

        # Placeholder for the last true observations; set on reset/step_wait.
        self.last_obs = None

    def reset(self, **kwargs):
        """
        Reset the underlying VecEnv and stash the initial true observations.

        Parameters
        ----------
        kwargs
            Any keyword arguments to forward to ``self.venv.reset``.

        Returns
        -------
        obs : np.ndarray
            The initial observations from the environment, shape
            ``(n_envs, *obs_shape)``.
        """
        batch_obs = self.venv.reset(**kwargs)
        # Copy so later in-place mutation of the returned array by callers
        # cannot corrupt the base used for the next perturbation.
        self.last_obs = batch_obs.copy()
        return batch_obs

    def step_async(self, adversary_actions: np.ndarray):
        """
        Receive a batch of adversarial action vectors, craft perturbed observations,
        query the agent for its actions, and forward those to the inner VecEnv.

        Parameters
        ----------
        adversary_actions : np.ndarray
            Array of shape (n_envs, flat_obs_dim) containing raw perturbation vectors.
        """
        n_envs = adversary_actions.shape[0]
        obs_shape = self.observation_space.shape

        # Reshape flat adversary vectors to (n_envs, *obs_shape).
        adv_flat = adversary_actions.reshape((n_envs,) + obs_shape)

        # Scale to the perturbation budget and apply to the last true obs.
        perturbations = self.scale_perturbation_fct(adv_flat)
        adv_obs = self.apply_perturbation_to_obs_fct(perturbations, self.last_obs)

        # Query the agent on the perturbed observations (deterministic so the
        # adversary optimizes against the agent's greedy policy).
        agent_actions, _ = self.agent_predict_fct(adv_obs, deterministic=True)
        if isinstance(agent_actions, torch.Tensor):
            agent_actions = agent_actions.detach().cpu().numpy()

        # Forward the (possibly fooled) agent actions to the real env.
        self.venv.step_async(agent_actions)

    def step_wait(self):
        """
        Wait for the inner VecEnv to step, then retrieve next true obs, original reward,
        done flags, and infos. Stash the new observations, negate rewards, and return.

        Returns
        -------
        obs : np.ndarray
            Next true observations, shape (n_envs, *obs_shape).
        neg_reward : np.ndarray
            Negated rewards from the inner env, shape (n_envs,).
        dones : np.ndarray
            Boolean array indicating episode termination, shape (n_envs,).
        infos : List[dict]
            Info dictionaries returned by the underlying VecEnv.
        """
        batch_obs, rewards, dones, infos = self.venv.step_wait()
        self.last_obs = batch_obs.copy()
        # Zero-sum framing: the adversary's reward is the negated agent reward.
        adv_rewards = -rewards
        return batch_obs, adv_rewards, dones, infos

    def render(self, **kwargs):
        """
        Render the underlying VecEnv.

        Parameters
        ----------
        kwargs
            Forwarded to ``self.venv.render``.

        Returns
        -------
        Render output (e.g., array or None), as defined by the inner VecEnv.
        """
        return self.venv.render(**kwargs)