Coverage for adaro_rl / pipelines / online_attack.py: 84%
92 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
2import numpy as np
3import pandas as pd
4import gymnasium as gym
5from stable_baselines3.common.vec_env import DummyVecEnv
7from ..attacks import make_attack, training_method_names, EnsembleAttackWrapper
8from ..wrappers import ObsAttackAgentEnv
9from .utils import normalize_lists, make_attack_list
def online_attack(
    config,
    attack_name,
    target,
    eps,
    norm,
    adversary_checkpoint=None,
    output_dir="agent",
    agent_checkpoint=None,
    reference_checkpoint=None,
    self_reference=False,
    render=False,
    device="cpu",
    seed=None,
    n_eval_episodes=None,
):
    """
    Run an online adversarial attack evaluation with a configured agent and environment.

    This function creates an evaluation environment and loads a primary agent along with
    a (possibly ensemble) adversarial attack. It then evaluates the agent in the environment
    under adversarial conditions, computes and displays evaluation metrics (reward and episode
    length statistics), and finally saves these metrics to CSV files.

    Parameters
    ----------
    config : object
        A configuration object containing environment and agent settings.
    attack_name : str or list of str
        The name(s) of the adversarial attack method(s) to be used.
    target : str or list of str
        The target specification(s) for the attack. If an entry equals "target_fct", then
        `config.target_fct` is used.
    eps : float or list of float
        The perturbation budget(s) (epsilon) for the adversarial attack.
    norm : {0, 1, 2, float('inf')}
        The norm constraint under which the perturbation is applied (e.g., 0 for sparse,
        1 for L1, 2 for L2, or float('inf') for L∞).
    adversary_checkpoint : str or list of str, optional
        Path(s) to checkpoint(s) for adversarial agents. If not provided, the reference agent is used.
    output_dir : str, optional
        Directory where result files (CSV) and model artifacts are saved. Default is "agent".
    agent_checkpoint : str, optional
        Path to a checkpoint for the primary agent.
    reference_checkpoint : str, optional
        Path to a checkpoint for the reference agent (used when self_reference is False).
    self_reference : bool, optional
        If True, the primary agent is used also as the reference agent; otherwise, a separate
        reference agent is created using `reference_checkpoint`.
    render : bool, optional
        Flag indicating whether the environment should render during evaluation. Default is False.
    device : str, optional
        The computation device (e.g., "cpu" or "cuda"). Default is "cpu".
    seed : int, optional
        Random seed used for reproducibility. Default is None (seeding is left to
        `config.make_env` / `config.make_agent`).
    n_eval_episodes : int, optional
        Number of evaluation episodes. Default is None, in which case
        `config.n_eval_episodes` is used.

    Returns
    -------
    int
        Returns 1 upon successful evaluation of the adversarial attack.
    """

    # Broadcast scalar/list arguments into parallel lists of equal length.
    attack_name_list, target_list, eps_list, adversary_checkpoint_list = (
        normalize_lists(
            attack_name=attack_name,
            target=target,
            eps=eps,
            adversary_checkpoint=adversary_checkpoint,
        ).values()
    )

    os.makedirs(output_dir, exist_ok=True)
    adv_attack_mean_filename = "online_adv_attacks_norm_{}_mean_reward.csv".format(norm)
    adv_attack_mean_path = os.path.join(output_dir, adv_attack_mean_filename)
    adv_attack_std_filename = "online_adv_attacks_norm_{}_std_reward.csv".format(norm)
    adv_attack_std_path = os.path.join(output_dir, adv_attack_std_filename)

    if n_eval_episodes is None:
        n_eval_episodes = config.n_eval_episodes

    # ENV ####################

    if not render:
        # Headless mode: prevent SDL-based envs from opening a window.
        os.environ["SDL_VIDEODRIVER"] = "dummy"

    # base_env is only used to read observation/action spaces and to build the
    # attack factories; the actual evaluation uses eval_env created below.
    # NOTE(review): base_env is never closed here — presumably its lifetime is
    # tied to the attack factories; confirm before adding a close() call.
    base_env = config.make_env(
        env_id=config.eval_env_config["env_id"],
        n_envs=config.eval_env_config["n_envs"],
        n_frame_stack=config.eval_env_config["n_frame_stack"],
        wrapper_class=config.eval_env_config["wrapper_class"],
        adv_wrapper_class=None,
        env_kwargs=config.eval_env_config["env_kwargs"],
        seed=seed,
    )

    # AGENT ####################

    # Lightweight env stand-in exposing only the spaces; avoids running real
    # environment dynamics just to instantiate the agent.
    class Placeholder(gym.Env):
        def __init__(self, env):
            self.observation_space = env.observation_space
            self.action_space = env.action_space

    # Default arguments bind cls/env eagerly (late-binding closure pitfall).
    def placeholder_fct(cls=Placeholder, env=base_env):
        return cls(env)

    placeholder_env = DummyVecEnv(
        [placeholder_fct for _ in range(config.eval_env_config["n_envs"])]
    )

    def make_agent_fct():
        return config.make_agent(
            algo=config.agent_config["algo"],
            env=placeholder_env,
            checkpoint=agent_checkpoint,
            device=device,
            seed=seed,
            algo_kwargs=config.agent_config["algo_kwargs"],
        )

    agent = make_agent_fct()

    # REFERENCE AGENT ####################

    if self_reference:
        # The attacked agent is its own reference model.
        make_reference_agent_fct = make_agent_fct
    else:

        def make_reference_agent_fct():
            return config.make_agent(
                algo=config.agent_config["algo"],
                env=placeholder_env,
                checkpoint=reference_checkpoint,
                device=device,
                seed=seed,
                algo_kwargs=config.agent_config["algo_kwargs"],
            )

    # ADVERSARIAL POLICY ####################

    # One agent factory per attack: a trained adversary for learning-based
    # attacks, otherwise the reference agent.
    make_agent_fct_list_for_attack = []

    for attack_name, adversary_checkpoint in zip(
        attack_name_list, adversary_checkpoint_list
    ):
        if attack_name in training_method_names:
            # Adversary acts in the observation-perturbation space, not the
            # task action space.
            adv_action_space = base_env.get_attr("observation_perturbation_space")[0]

            class AdversaryPlaceholder(gym.Env):
                def __init__(self, env):
                    self.observation_space = env.observation_space
                    self.action_space = adv_action_space

            adversary_placeholder_fct = (
                lambda cls=AdversaryPlaceholder, env=base_env: cls(env)
            )
            adversary_placeholder_env = DummyVecEnv(
                [
                    adversary_placeholder_fct
                    for _ in range(config.eval_env_config["n_envs"])
                ]
            )

            # Default arguments capture the loop-local checkpoint/env so each
            # factory keeps its own values (late-binding closure pitfall).
            make_adversary_agent_fct = (
                lambda adversary_checkpoint=adversary_checkpoint,
                adversary_placeholder_env=adversary_placeholder_env: config.make_agent(
                    algo=config.agent_config["algo"],
                    env=adversary_placeholder_env,
                    checkpoint=adversary_checkpoint,
                    device=device,
                    seed=seed,
                    algo_kwargs=config.adversary_config["algo_kwargs"],
                )
            )

            make_agent_fct_list_for_attack.append(make_adversary_agent_fct)

        else:
            make_agent_fct_list_for_attack.append(make_reference_agent_fct)

    make_attack_fct_list = make_attack_list(
        base_env,
        attack_name_list,
        make_agent_fct_list_for_attack,
        target_list,
        eps_list,
        config,
        make_attack,
        norm,
        device,
    )

    if len(make_attack_fct_list) > 1:
        # Multiple attacks: wrap them into a single ensemble attack.
        def make_attack_fct():
            return EnsembleAttackWrapper(
                make_attack_fct_list=make_attack_fct_list
            )
        attack_label = "Ensemble"
        eps_label = -1  # sentinel row key: ensemble has no single epsilon
    else:
        make_attack_fct = make_attack_fct_list[0]
        attack_label = "{}_{}".format(attack_name_list[0], target_list[0])
        eps_label = eps_list[0]

    # ADV WRAPPER ####################

    adv_wrapper_class = ObsAttackAgentEnv
    adv_wrapper_kwargs = {"make_attack_fct": make_attack_fct, "freq": 1}

    # MAKE ENV ####################

    if render:
        env_config = config.render_env_config
    else:
        env_config = config.eval_env_config
        os.environ["SDL_VIDEODRIVER"] = "dummy"

    eval_env = config.make_env(
        env_id=env_config["env_id"],
        n_envs=env_config["n_envs"],
        n_frame_stack=env_config["n_frame_stack"],
        wrapper_class=env_config["wrapper_class"],
        adv_wrapper_class=adv_wrapper_class,
        adv_wrapper_kwargs=adv_wrapper_kwargs,
        env_kwargs=env_config["env_kwargs"],
        seed=seed,
    )

    # RUN ####################

    rewards, lengths = config.eval_config["evaluate_policy_fct"](
        agent,
        eval_env,
        n_eval_episodes=n_eval_episodes,
        render=render,
        deterministic=config.eval_config["deterministic_eval"],
        return_episode_rewards=True,
    )

    # DISPLAY ####################

    mean_lengths = np.mean(lengths)
    std_lengths = np.std(lengths)
    print("lengths : {}".format(lengths))
    print("mean : {}".format(mean_lengths))
    print("std : {}".format(std_lengths))

    mean_reward = np.mean(rewards)
    std_reward = np.std(rewards)
    print("rewards : {}".format(rewards))
    print("mean : {}".format(mean_reward))
    print("std : {}".format(std_reward))

    # SAVE ####################

    def save(path, result):
        # Upsert `result` into the CSV at `path`, keyed on the eps column.
        if os.path.isfile(path):
            df = pd.read_csv(path)
        else:
            df = pd.DataFrame(columns=["eps", attack_label])

        # Ensure all required columns exist
        if "eps" not in df.columns:
            df["eps"] = np.nan
        if attack_label not in df.columns:
            df[attack_label] = np.nan

        # Convert to correct type for lookup
        df["eps"] = df["eps"].astype(float)

        # Update or insert the result
        if eps_label in df["eps"].values:
            df.loc[df["eps"] == eps_label, attack_label] = result
        else:
            new_row = pd.DataFrame([{"eps": eps_label, attack_label: result}])
            df = pd.concat([df, new_row], ignore_index=True)

        df.to_csv(path, index=False)

    save(adv_attack_mean_path, mean_reward)
    save(adv_attack_std_path, std_reward)

    print("save in {} and {}".format(adv_attack_mean_path, adv_attack_std_path))

    # CLOSE ####################

    eval_env.close()

    # Fix: the docstring promises 1 on success, but no value was ever
    # returned; make the implementation match the documented contract.
    return 1