Coverage for adaro_rl / pipelines / online_attack.py: 84%

92 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2import numpy as np 

3import pandas as pd 

4import gymnasium as gym 

5from stable_baselines3.common.vec_env import DummyVecEnv 

6 

7from ..attacks import make_attack, training_method_names, EnsembleAttackWrapper 

8from ..wrappers import ObsAttackAgentEnv 

9from .utils import normalize_lists, make_attack_list 

10 

11 

def online_attack(
    config,
    attack_name,
    target,
    eps,
    norm,
    adversary_checkpoint=None,
    output_dir="agent",
    agent_checkpoint=None,
    reference_checkpoint=None,
    self_reference=False,
    render=False,
    device="cpu",
    seed=None,
    n_eval_episodes=None,
):
    """
    Run an online adversarial attack evaluation with a configured agent and environment.

    This function creates an evaluation environment and loads a primary agent along with
    a (possibly ensemble) adversarial attack. It then evaluates the agent in the environment
    under adversarial conditions, computes and displays evaluation metrics (reward and episode
    length statistics), and finally saves these metrics to CSV files.

    Parameters
    ----------
    config : object
        A configuration object containing environment and agent settings.
    attack_name : str or list of str
        The name(s) of the adversarial attack method(s) to be used.
    target : str or list of str
        The target specification(s) for the attack. If an entry equals "target_fct", then
        `config.target_fct` is used.
    eps : float or list of float
        The perturbation budget(s) (epsilon) for the adversarial attack.
    norm : {0, 1, 2, float('inf')}
        The norm constraint under which the perturbation is applied (e.g., 0 for sparse,
        1 for L1, 2 for L2, or float('inf') for L∞).
    adversary_checkpoint : str or list of str, optional
        Path(s) to checkpoint(s) for adversarial agents. If not provided, the reference agent is used.
    output_dir : str, optional
        Directory where result files (CSV) and model artifacts are saved. Default is "agent".
    agent_checkpoint : str, optional
        Path to a checkpoint for the primary agent.
    reference_checkpoint : str, optional
        Path to a checkpoint for the reference agent (used when self_reference is False).
    self_reference : bool, optional
        If True, the primary agent is used also as the reference agent; otherwise, a separate
        reference agent is created using `reference_checkpoint`.
    render : bool, optional
        Flag indicating whether the environment should render during evaluation. Default is False.
    device : str, optional
        The computation device (e.g., "cpu" or "cuda"). Default is "cpu".
    seed : int, optional
        Random seed used for reproducibility. Default is None.
    n_eval_episodes : int, optional
        Number of evaluation episodes. Defaults to `config.n_eval_episodes` when None.

    Returns
    -------
    int
        Returns 1 upon successful evaluation of the adversarial attack.
    """

    # Broadcast scalar arguments to aligned lists (one entry per attack).
    attack_name_list, target_list, eps_list, adversary_checkpoint_list = (
        normalize_lists(
            attack_name=attack_name,
            target=target,
            eps=eps,
            adversary_checkpoint=adversary_checkpoint,
        ).values()
    )

    # Result CSVs are keyed by norm; rows are upserted by eps (see save() below).
    os.makedirs(output_dir, exist_ok=True)
    adv_attack_mean_filename = "online_adv_attacks_norm_{}_mean_reward.csv".format(norm)
    adv_attack_mean_path = os.path.join(output_dir, adv_attack_mean_filename)
    adv_attack_std_filename = "online_adv_attacks_norm_{}_std_reward.csv".format(norm)
    adv_attack_std_path = os.path.join(output_dir, adv_attack_std_filename)

    if n_eval_episodes is None:
        n_eval_episodes = config.n_eval_episodes

    # ENV ####################

    if not render:
        # Headless SDL so environments that open a display still run.
        os.environ["SDL_VIDEODRIVER"] = "dummy"

    # Base env without any adversarial wrapper; used to read spaces and to
    # build the attacks before the real evaluation env is constructed.
    base_env = config.make_env(
        env_id=config.eval_env_config["env_id"],
        n_envs=config.eval_env_config["n_envs"],
        n_frame_stack=config.eval_env_config["n_frame_stack"],
        wrapper_class=config.eval_env_config["wrapper_class"],
        adv_wrapper_class=None,
        env_kwargs=config.eval_env_config["env_kwargs"],
        seed=seed,
    )

    # AGENT ####################

    # Minimal stand-in env exposing only the spaces, so agents can be
    # instantiated/loaded without stepping a real environment.
    class Placeholder(gym.Env):
        def __init__(self, env):
            self.observation_space = env.observation_space
            self.action_space = env.action_space

    def placeholder_fct(cls=Placeholder, env=base_env):
        return cls(env)

    placeholder_env = DummyVecEnv(
        [placeholder_fct for _ in range(config.eval_env_config["n_envs"])]
    )

    def make_agent_fct():
        return config.make_agent(
            algo=config.agent_config["algo"],
            env=placeholder_env,
            checkpoint=agent_checkpoint,
            device=device,
            seed=seed,
            algo_kwargs=config.agent_config["algo_kwargs"],
        )

    agent = make_agent_fct()

    # REFERENCE AGENT ####################

    if self_reference:
        # The attacked agent doubles as its own reference.
        make_reference_agent_fct = make_agent_fct
    else:

        def make_reference_agent_fct():
            return config.make_agent(
                algo=config.agent_config["algo"],
                env=placeholder_env,
                checkpoint=reference_checkpoint,
                device=device,
                seed=seed,
                algo_kwargs=config.agent_config["algo_kwargs"],
            )

    # ADVERSARIAL POLICY ####################

    # For each attack: a trained adversary policy if the attack is a training
    # method, otherwise the reference agent.
    make_agent_fct_list_for_attack = []

    for attack_name, adversary_checkpoint in zip(
        attack_name_list, adversary_checkpoint_list
    ):
        if attack_name in training_method_names:
            # Adversary acts in the observation-perturbation space, not the
            # task action space.
            adv_action_space = base_env.get_attr("observation_perturbation_space")[0]

            class AdversaryPlaceholder(gym.Env):
                def __init__(self, env):
                    self.observation_space = env.observation_space
                    self.action_space = adv_action_space

            # Default args pin the current loop iteration's objects so the
            # closures are not affected by late binding.
            adversary_placeholder_fct = (
                lambda cls=AdversaryPlaceholder, env=base_env: cls(env)
            )
            adversary_placeholder_env = DummyVecEnv(
                [
                    adversary_placeholder_fct
                    for _ in range(config.eval_env_config["n_envs"])
                ]
            )

            make_adversary_agent_fct = (
                lambda adversary_checkpoint=adversary_checkpoint,
                adversary_placeholder_env=adversary_placeholder_env: config.make_agent(
                    algo=config.agent_config["algo"],
                    env=adversary_placeholder_env,
                    checkpoint=adversary_checkpoint,
                    device=device,
                    seed=seed,
                    algo_kwargs=config.adversary_config["algo_kwargs"],
                )
            )

            make_agent_fct_list_for_attack.append(make_adversary_agent_fct)

        else:
            make_agent_fct_list_for_attack.append(make_reference_agent_fct)

    make_attack_fct_list = make_attack_list(
        base_env,
        attack_name_list,
        make_agent_fct_list_for_attack,
        target_list,
        eps_list,
        config,
        make_attack,
        norm,
        device,
    )

    if len(make_attack_fct_list) > 1:
        # Several attacks: combine them and label the CSV row with eps = -1
        # since a single budget is not meaningful for the ensemble.
        def make_attack_fct():
            return EnsembleAttackWrapper(
                make_attack_fct_list=make_attack_fct_list
            )
        attack_label = "Ensemble"
        eps_label = -1
    else:
        make_attack_fct = make_attack_fct_list[0]
        attack_label = "{}_{}".format(attack_name_list[0], target_list[0])
        eps_label = eps_list[0]

    # ADV WRAPPER ####################

    adv_wrapper_class = ObsAttackAgentEnv
    adv_wrapper_kwargs = {"make_attack_fct": make_attack_fct, "freq": 1}

    # MAKE ENV ####################

    if render:
        env_config = config.render_env_config
    else:
        env_config = config.eval_env_config
        os.environ["SDL_VIDEODRIVER"] = "dummy"

    eval_env = config.make_env(
        env_id=env_config["env_id"],
        n_envs=env_config["n_envs"],
        n_frame_stack=env_config["n_frame_stack"],
        wrapper_class=env_config["wrapper_class"],
        adv_wrapper_class=adv_wrapper_class,
        adv_wrapper_kwargs=adv_wrapper_kwargs,
        env_kwargs=env_config["env_kwargs"],
        seed=seed,
    )

    # RUN ####################

    rewards, lengths = config.eval_config["evaluate_policy_fct"](
        agent,
        eval_env,
        n_eval_episodes=n_eval_episodes,
        render=render,
        deterministic=config.eval_config["deterministic_eval"],
        return_episode_rewards=True,
    )

    # DISPLAY ####################

    mean_lengths = np.mean(lengths)
    std_lengths = np.std(lengths)
    print("lengths : {}".format(lengths))
    print("mean : {}".format(mean_lengths))
    print("std : {}".format(std_lengths))

    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)
    print("rewards : {}".format(rewards))
    print("mean : {}".format(mean_reward))
    print("std : {}".format(std_reward))

    # SAVE ####################

    def save(path, result):
        """Upsert `result` into the CSV at `path`, keyed by (eps row, attack column)."""
        if os.path.isfile(path):
            df = pd.read_csv(path)
        else:
            df = pd.DataFrame(columns=["eps", attack_label])

        # Ensure all required columns exist
        if "eps" not in df.columns:
            df["eps"] = np.nan
        if attack_label not in df.columns:
            df[attack_label] = np.nan

        # Convert to correct type for lookup
        df["eps"] = df["eps"].astype(float)

        # Update or insert the result
        if eps_label in df["eps"].values:
            df.loc[df["eps"] == eps_label, attack_label] = result
        else:
            new_row = pd.DataFrame([{"eps": eps_label, attack_label: result}])
            df = pd.concat([df, new_row], ignore_index=True)

        df.to_csv(path, index=False)

    save(adv_attack_mean_path, mean_reward)
    save(adv_attack_std_path, std_reward)

    print("save in {} and {}".format(adv_attack_mean_path, adv_attack_std_path))

    # CLOSE ####################

    eval_env.close()

    # Honor the documented contract: signal successful completion.
    return 1