# Coverage for adaro_rl/pipelines/adversarial_train.py: 78% (59 statements)
# coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
2import gymnasium as gym
3from stable_baselines3.common.vec_env import DummyVecEnv
5from ..attacks import make_attack, training_method_names, EnsembleAttackWrapper
6from ..wrappers import ObsAttackAgentEnv
7from .utils import normalize_lists, make_attack_list
def adversarial_train(
    config,
    attack_name,
    target,
    eps,
    norm,
    adversary_checkpoint=None,
    output_dir="agent",
    agent_checkpoint=None,
    reference_checkpoint=None,
    self_reference=False,
    device="cpu",
    seed=None,
    total_timesteps=None,
    prepopulate_timesteps=None,
    verbose=True,
):
    """
    Perform adversarial training with the specified attack configuration(s).

    Sets up the training and evaluation environments, constructs the main
    (adversarially trained) agent and the adversary or reference agents,
    builds the observation attack (single or ensemble), trains the agent
    under attack, and saves the resulting model to ``output_dir``.

    Parameters
    ----------
    config : object
        Project configuration object. It must provide the factory methods
        ``make_env`` and ``make_agent`` plus the attribute groups read here:
        ``train_env_config`` and ``eval_env_config`` (with keys "env_id",
        "n_envs", "n_frame_stack", "wrapper_class", "env_kwargs"),
        ``finetuned_agent_config``, ``agent_config``, ``adversary_config``
        (with "algo" and/or "algo_kwargs"), and ``finetuned_training_kwargs``
        (with "total_timesteps" and "prepopulate_timesteps").
    attack_name : str or list of str
        Name(s) of the attack method(s) used for adversarial training.
    target : str or list of str
        Target specification(s) for the attack. If an entry is "target_fct",
        ``config.target_fct`` is used downstream.
    eps : float or list of float
        Perturbation budget(s) (epsilon) for the adversarial attack.
    norm : {0, 1, 2, float('inf')}
        Norm used for the perturbation (0 for sparse, 1 for L1, 2 for L2,
        or float('inf') for L-infinity).
    adversary_checkpoint : str or list of str, optional
        Checkpoint path(s) for the adversarial policy agent(s). For attacks
        that are not training-based methods, the reference agent is used
        instead.
    output_dir : str, optional
        Directory where the agent and training artifacts are saved.
        Default is "agent".
    agent_checkpoint : str, optional
        Checkpoint for the main (adversarially trained) agent.
    reference_checkpoint : str, optional
        Checkpoint for the reference agent used when ``self_reference`` is
        False.
    self_reference : bool, optional
        If True, the adversarially trained agent is also used as the
        reference agent. Otherwise, the reference agent is constructed from
        ``reference_checkpoint``.
    device : str, optional
        Computation device string (e.g., "cpu" or "cuda"). Default is "cpu".
    seed : int, optional
        Random seed for reproducibility.
    total_timesteps : int, optional
        Number of training timesteps. Defaults to
        ``config.finetuned_training_kwargs["total_timesteps"]``.
    prepopulate_timesteps : int, optional
        Number of prepopulation timesteps. Defaults to
        ``config.finetuned_training_kwargs["prepopulate_timesteps"]``.
    verbose : bool, optional
        Verbosity flag forwarded to the agent factories.

    Returns
    -------
    int
        Returns 1 upon successful completion of adversarial training.
    """
    # Normalize scalar-or-list arguments into parallel lists, one entry per
    # attack. Relies on dict insertion order of normalize_lists' result.
    attack_name_list, target_list, eps_list, adversary_checkpoint_list = (
        normalize_lists(
            attack_name=attack_name,
            target=target,
            eps=eps,
            adversary_checkpoint=adversary_checkpoint,
        ).values()
    )

    os.makedirs(output_dir, exist_ok=True)

    if total_timesteps is None:
        total_timesteps = config.finetuned_training_kwargs["total_timesteps"]
    if prepopulate_timesteps is None:
        prepopulate_timesteps = config.finetuned_training_kwargs[
            "prepopulate_timesteps"
        ]

    # ENV ####################

    # Headless rendering: prevents SDL from trying to open a display.
    os.environ["SDL_VIDEODRIVER"] = "dummy"

    # Un-attacked base environment; used to derive spaces and build attacks,
    # closed before training starts.
    base_env = config.make_env(
        env_id=config.train_env_config["env_id"],
        n_envs=config.train_env_config["n_envs"],
        n_frame_stack=config.train_env_config["n_frame_stack"],
        wrapper_class=config.train_env_config["wrapper_class"],
        adv_wrapper_class=None,
        env_kwargs=config.train_env_config["env_kwargs"],
        seed=seed,
    )

    # AGENT ####################

    # Lightweight stand-in env exposing only the spaces, so agents can be
    # constructed before the attacked training env exists.
    class Placeholder(gym.Env):
        def __init__(self, env):
            self.observation_space = env.observation_space
            self.action_space = env.action_space

    def placeholder_fct(cls=Placeholder, env=base_env):
        return cls(env)

    placeholder_env = DummyVecEnv(
        [placeholder_fct for _ in range(config.train_env_config["n_envs"])]
    )

    def make_adv_trained_agent_fct():
        return config.make_agent(
            algo=config.finetuned_agent_config["algo"],
            env=placeholder_env,
            checkpoint=agent_checkpoint,
            output_dir=output_dir,
            device=device,
            seed=seed,
            verbose=verbose,
            algo_kwargs=config.finetuned_agent_config["algo_kwargs"],
        )

    adv_trained_agent = make_adv_trained_agent_fct()

    adv_trained_agent.train()

    # REFERENCE AGENT ####################

    if self_reference:
        make_reference_agent_fct = make_adv_trained_agent_fct
    else:

        def make_reference_agent_fct():
            return config.make_agent(
                algo=config.agent_config["algo"],
                env=placeholder_env,
                checkpoint=reference_checkpoint,
                device=device,
                seed=seed,
                verbose=verbose,
                algo_kwargs=config.agent_config["algo_kwargs"],
            )

    # ADVERSARIAL POLICY ####################

    make_agent_fct_list_for_attack = []

    # Renamed loop variables so they do not shadow the function parameters.
    for atk_name, adv_ckpt in zip(attack_name_list, adversary_checkpoint_list):
        if atk_name in training_method_names:
            # The adversary acts in the observation-perturbation space, not
            # the base action space.
            adv_action_space = base_env.get_attr("observation_perturbation_space")[0]

            class AdversaryPlaceholder(gym.Env):
                def __init__(self, env):
                    self.observation_space = env.observation_space
                    self.action_space = adv_action_space

            def adversary_placeholder_fct(cls=AdversaryPlaceholder, env=base_env):
                return cls(env)

            adversary_placeholder_env = DummyVecEnv(
                [
                    adversary_placeholder_fct
                    for _ in range(config.train_env_config["n_envs"])
                ]
            )

            # Default arguments bind the current loop values (avoids the
            # late-binding closure pitfall across iterations).
            def make_adversary_agent_fct(
                adversary_placeholder_env=adversary_placeholder_env,
                adversary_checkpoint=adv_ckpt,
            ):
                return config.make_agent(
                    algo=config.agent_config["algo"],
                    env=adversary_placeholder_env,
                    checkpoint=adversary_checkpoint,
                    device=device,
                    seed=seed,
                    verbose=verbose,
                    algo_kwargs=config.adversary_config["algo_kwargs"],
                )

            make_agent_fct_list_for_attack.append(make_adversary_agent_fct)

        else:
            make_agent_fct_list_for_attack.append(make_reference_agent_fct)

    make_attack_fct_list = make_attack_list(
        base_env,
        attack_name_list,
        make_agent_fct_list_for_attack,
        target_list,
        eps_list,
        config,
        make_attack,
        norm,
        device,
    )

    if len(make_attack_fct_list) > 1:
        # Multiple attacks are combined into a single ensemble attack.
        def make_attack_fct():
            return EnsembleAttackWrapper(make_attack_fct_list=make_attack_fct_list)

    else:
        make_attack_fct = make_attack_fct_list[0]

    # ADV WRAPPER ####################

    adv_wrapper_class = ObsAttackAgentEnv
    adv_wrapper_kwargs = {"make_attack_fct": make_attack_fct, "freq": 1}

    # MAKE ENV ####################

    training_env = config.make_env(
        env_id=config.train_env_config["env_id"],
        n_envs=config.train_env_config["n_envs"],
        n_frame_stack=config.train_env_config["n_frame_stack"],
        wrapper_class=config.train_env_config["wrapper_class"],
        adv_wrapper_class=adv_wrapper_class,
        adv_wrapper_kwargs=adv_wrapper_kwargs,
        env_kwargs=config.train_env_config["env_kwargs"],
        seed=seed,
    )

    eval_env = config.make_env(
        env_id=config.eval_env_config["env_id"],
        n_envs=config.eval_env_config["n_envs"],
        n_frame_stack=config.eval_env_config["n_frame_stack"],
        wrapper_class=config.eval_env_config["wrapper_class"],
        adv_wrapper_class=adv_wrapper_class,
        adv_wrapper_kwargs=adv_wrapper_kwargs,
        env_kwargs=config.eval_env_config["env_kwargs"],
        seed=seed,
    )

    # Swap the placeholder env for the real attacked training env, then
    # release the temporary envs.
    adv_trained_agent.model.set_env(training_env)
    placeholder_env.close()
    base_env.close()

    # TRAINING ####################

    adv_trained_agent.learn(
        eval_env,
        total_timesteps=total_timesteps,
        prepopulate_timesteps=prepopulate_timesteps,
    )

    # SAVE ####################

    adv_trained_agent.save(checkpoint=os.path.join(output_dir, "model.zip"))

    # CLOSE ###################
    training_env.close()
    eval_env.close()

    # Honor the documented contract: return 1 on success (the original body
    # fell off the end and returned None despite the docstring).
    return 1