Coverage for adaro_rl / wrappers / obs_attack_adversary_env.py: 28%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import numpy as np 

2import torch 

3import gymnasium as gym 

4from stable_baselines3.common.vec_env import VecEnvWrapper 

5 

6 

class ObsAttackAdversaryEnv(VecEnvWrapper):
    """
    Vectorized adversarial-environment wrapper that sits "above" an agent's VecEnv,
    allowing an adversary to craft observation perturbations and see their effect on
    the agent's behavior.

    At each ``step_async`` call, the adversary provides a batch of flat perturbation
    vectors. Those are scaled and applied to the last true observations, passed
    into the agent for action selection, and then forwarded to the inner VecEnv.
    Rewards returned in ``step_wait`` are negated, as the adversary's objective is to
    minimize the agent's return.

    Parameters
    ----------
    venv : VecEnv
        The underlying vectorized environment.
    make_agent_fct : callable
        Zero-argument factory that builds and returns the (fixed-policy) agent
        under attack.
    agent_predict_fct : callable or None
        Function with signature
        ``(obs_batch: np.ndarray, deterministic: bool) -> (action_batch, state)``.
        If ``None``, the built agent's own ``predict`` method is used.
    scale_perturbation_fct : callable
        Function that takes an array of raw adversary actions of shape
        ``(n_envs, *obs_shape)`` and returns scaled perturbations of same shape.
    apply_perturbation_to_obs_fct : callable
        Function with signature
        ``(perturbations: np.ndarray, obs_batch: np.ndarray) -> np.ndarray``
        which returns the adversarial observations of shape ``(n_envs, *obs_shape)``.

    Attributes
    ----------
    last_obs : np.ndarray
        The most recent true observations from the underlying environment,
        used as the base for applying new perturbations.
    """

    def __init__(
        self,
        venv,
        make_agent_fct,
        agent_predict_fct,
        scale_perturbation_fct,
        apply_perturbation_to_obs_fct,
    ):
        super().__init__(venv)

        # Build the agent whose behavior the adversary will probe.
        self.agent = make_agent_fct()

        # BUG FIX: the original ignored `agent_predict_fct` and always bound
        # `self.agent.predict`, contradicting the documented contract. Honor the
        # caller-supplied prediction function when one is given; fall back to
        # the agent's own predict otherwise (preserves previous behavior for
        # callers passing None).
        if agent_predict_fct is None:
            self.agent_predict_fct = self.agent.predict
        else:
            self.agent_predict_fct = agent_predict_fct

        self.scale_perturbation_fct = scale_perturbation_fct
        self.apply_perturbation_to_obs_fct = apply_perturbation_to_obs_fct

        # The adversary acts in a flat [-1, 1]^d box, where d is the flattened
        # observation dimensionality; actions are reshaped back in step_async.
        obs_shape = self.observation_space.shape
        flat_dim = int(np.prod(obs_shape))
        self.action_space = gym.spaces.Box(
            low=-np.ones(flat_dim, dtype=np.float32),
            high=np.ones(flat_dim, dtype=np.float32),
            dtype=np.float32,
        )

        # Placeholder for the last true observations; set on reset/step_wait.
        self.last_obs = None

    def reset(self, **kwargs):
        """
        Reset the underlying VecEnv and stash the initial true observations.

        Parameters
        ----------
        kwargs
            Any keyword arguments to forward to ``self.venv.reset``.

        Returns
        -------
        obs : np.ndarray
            The initial observations from the environment, shape
            ``(n_envs, *obs_shape)``.
        """
        batch_obs = self.venv.reset(**kwargs)
        # Copy so later in-place perturbation of observations cannot corrupt
        # the stored "true" baseline.
        self.last_obs = batch_obs.copy()
        return batch_obs

    def step_async(self, adversary_actions: np.ndarray):
        """
        Receive a batch of adversarial action vectors, craft perturbed observations,
        query the agent for its actions, and forward those to the inner VecEnv.

        Parameters
        ----------
        adversary_actions : np.ndarray
            Array of shape (n_envs, flat_obs_dim) containing raw perturbation vectors.
        """
        n_envs = adversary_actions.shape[0]
        obs_shape = self.observation_space.shape

        # Reshape flat adversary actions to (n_envs, *obs_shape).
        adv_flat = adversary_actions.reshape((n_envs,) + obs_shape)

        # Scale raw actions into admissible perturbations, then apply them to
        # the last true observations to obtain the adversarial observations.
        perturbations = self.scale_perturbation_fct(adv_flat)
        adv_obs = self.apply_perturbation_to_obs_fct(perturbations, self.last_obs)

        # Query the agent deterministically on the perturbed observations.
        agent_actions, _ = self.agent_predict_fct(adv_obs, deterministic=True)
        if isinstance(agent_actions, torch.Tensor):
            # Some predict implementations return tensors; the inner VecEnv
            # expects numpy arrays.
            agent_actions = agent_actions.detach().cpu().numpy()

        # Forward the (manipulated) agent actions to the real environment.
        self.venv.step_async(agent_actions)

    def step_wait(self):
        """
        Wait for the inner VecEnv to step, then retrieve next true obs, original reward,
        done flags, and infos. Stash the new observations, negate rewards, and return.

        Returns
        -------
        obs : np.ndarray
            Next true observations, shape (n_envs, *obs_shape).
        neg_reward : np.ndarray
            Negated rewards from the inner env, shape (n_envs,).
        dones : np.ndarray
            Boolean array indicating episode termination, shape (n_envs,).
        infos : List[dict]
            Info dictionaries returned by the underlying VecEnv.
        """
        batch_obs, rewards, dones, infos = self.venv.step_wait()
        # Update the true-observation baseline for the next perturbation step.
        self.last_obs = batch_obs.copy()
        # Zero-sum framing: the adversary's reward is the negated agent reward.
        adv_rewards = -rewards
        return batch_obs, adv_rewards, dones, infos

    def render(self, **kwargs):
        """
        Render the underlying VecEnv.

        Parameters
        ----------
        kwargs
            Forwarded to ``self.venv.render``.

        Returns
        -------
        Render output (e.g., array or None), as defined by the inner VecEnv.
        """
        return self.venv.render(**kwargs)