Coverage for adaro_rl / pipelines / main.py: 88%
97 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from . import train, test, online_attack, adversarial_train
2import argparse
3import importlib
# Default import path of the module holding the experiment configurations;
# it must expose a ``configs`` mapping (config name -> config object).
DEFAULT_ZOO = "adaro_rl.zoo"

# Help strings shared by the subcommand parsers built in ``main``.
ZOO_HELP = "Zoo module where configs are stored."
CONFIG_HELP = "Name of the config to use."
CHECKPOINT_HELP = "Checkpoint to load (if any)."
AGENT_CHECKPOINT_HELP = "Agent checkpoint to load (if any)."
LIST_AGENT_CHECKPOINT_HELP = "List of agent checkpoints to load (if any)."
ADVERSARY_CHECKPOINT_HELP = "List of adversary checkpoints to load (if any)."
OUTPUT_DIR_HELP_MODEL = "Directory to store agent and logs."
OUTPUT_DIR_HELP_RESULTS = "Directory to store the results."
DEVICE_HELP = "Device to run the experiment on."
SEED_HELP = "Random seed."
NORM_HELP = "Norm of the perturbation."
TRAIN_STEP_HELP = "Number of training steps."
PREPOPULATE_HELP = "Number of steps to collect before training starts."
def _add_zoo_config_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``--zoo`` / ``--config`` options used by every subcommand."""
    parser.add_argument("--zoo", type=str, default=DEFAULT_ZOO, help=ZOO_HELP)
    parser.add_argument("--config", type=str, required=True, help=CONFIG_HELP)


def _add_device_seed_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``--device`` / ``--seed`` options used by every subcommand."""
    parser.add_argument("--device", type=str, default="cpu", help=DEVICE_HELP)
    parser.add_argument("--seed", type=int, default=0, help=SEED_HELP)


def _add_n_eval_episodes_arg(parser: argparse.ArgumentParser) -> None:
    """Register ``--n-eval-episodes`` for the evaluation-style subcommands."""
    parser.add_argument(
        "--n-eval-episodes",
        type=int,
        default=None,
        help="Number of evaluation episodes.",
    )


def _add_attack_args(parser: argparse.ArgumentParser) -> None:
    """Register the attack specification shared by the attack subcommands."""
    parser.add_argument(
        "--attack-name",
        type=str,
        required=True,
        nargs="+",
        help="List of attack names.",
    )
    parser.add_argument(
        "--target",
        type=str,
        required=True,
        nargs="+",
        help="List of target specifications for each attack.",
    )
    parser.add_argument(
        "--eps",
        type=float,
        required=True,
        nargs="+",
        help="List of perturbation amounts (epsilon).",
    )
    parser.add_argument("--norm", type=float, required=True, help=NORM_HELP)
    parser.add_argument(
        "--adversary-checkpoint",
        type=str,
        default=None,
        nargs="+",
        help=ADVERSARY_CHECKPOINT_HELP,
    )


def _add_reference_args(parser: argparse.ArgumentParser) -> None:
    """Register the reference-agent options shared by the attack subcommands."""
    parser.add_argument(
        "--reference-checkpoint",
        type=str,
        default=None,
        help="Reference agent checkpoint to load (if any).",
    )
    parser.add_argument(
        "--self-reference",
        action="store_true",
        help=(
            "Flag for using the primary agent as the reference agent; "
            "otherwise it uses `reference_checkpoint`."
        ),
    )


def _add_training_schedule_args(parser: argparse.ArgumentParser) -> None:
    """Register the timestep budget and verbosity options of training commands."""
    parser.add_argument(
        "--total-timesteps", type=int, default=None, help=TRAIN_STEP_HELP
    )
    parser.add_argument(
        "--prepopulate-timesteps", type=int, default=None, help=PREPOPULATE_HELP
    )
    # Paired flags writing to the same destination; quiet unless --verbose.
    parser.add_argument("--verbose", dest="verbose", action="store_true")
    parser.add_argument("--no-verbose", dest="verbose", action="store_false")
    parser.set_defaults(verbose=False)


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with one subparser per supported command.

    Arguments are registered in the same order for each subcommand so the
    ``--help`` listings stay stable.
    """
    parser = argparse.ArgumentParser(description="Script for multiple functionalities")
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Available commands"
    )

    # train: train an agent from scratch or from a checkpoint.
    parser_train = subparsers.add_parser("train", help="Train an agent")
    _add_zoo_config_args(parser_train)
    parser_train.add_argument(
        "--checkpoint", type=str, default=None, help=CHECKPOINT_HELP
    )
    parser_train.add_argument(
        "--output-dir", type=str, default="agent", help=OUTPUT_DIR_HELP_MODEL
    )
    _add_device_seed_args(parser_train)
    _add_training_schedule_args(parser_train)

    # test: evaluate a trained agent.
    parser_test = subparsers.add_parser("test", help="Test an agent")
    _add_zoo_config_args(parser_test)
    parser_test.add_argument(
        "--checkpoint", type=str, required=True, help=LIST_AGENT_CHECKPOINT_HELP
    )
    parser_test.add_argument(
        "--output-dir", type=str, default="agent", help=OUTPUT_DIR_HELP_RESULTS
    )
    parser_test.add_argument("--render", action="store_true", help="Enable rendering.")
    _add_device_seed_args(parser_test)
    _add_n_eval_episodes_arg(parser_test)

    # online-attack: evaluate an agent under adversarial perturbations.
    parser_online_attack = subparsers.add_parser(
        "online-attack", help="Perform an online attack on an agent"
    )
    _add_zoo_config_args(parser_online_attack)
    _add_attack_args(parser_online_attack)
    parser_online_attack.add_argument(
        "--output-dir", type=str, default="agent", help=OUTPUT_DIR_HELP_RESULTS
    )
    parser_online_attack.add_argument(
        "--agent-checkpoint", type=str, required=True, help=LIST_AGENT_CHECKPOINT_HELP
    )
    _add_reference_args(parser_online_attack)
    parser_online_attack.add_argument(
        "--render", action="store_true", help="Enable rendering."
    )
    _add_device_seed_args(parser_online_attack)
    _add_n_eval_episodes_arg(parser_online_attack)

    # adversarial-training: train an agent against the specified attacks.
    parser_adv_training = subparsers.add_parser(
        "adversarial-training", help="Perform adversarial training"
    )
    _add_zoo_config_args(parser_adv_training)
    _add_attack_args(parser_adv_training)
    parser_adv_training.add_argument(
        "--output-dir", type=str, default="adv_agent", help=OUTPUT_DIR_HELP_MODEL
    )
    parser_adv_training.add_argument(
        "--agent-checkpoint", type=str, default=None, help=AGENT_CHECKPOINT_HELP
    )
    _add_reference_args(parser_adv_training)
    _add_device_seed_args(parser_adv_training)
    _add_training_schedule_args(parser_adv_training)

    return parser


def _load_config(zoo_name: str, config_name: str):
    """Import the zoo module ``zoo_name`` and return ``zoo.configs[config_name]``.

    Raises
    ------
    ImportError
        If the zoo module cannot be imported. The original error is chained
        as ``__cause__`` so a failing import *inside* the zoo stays visible.
    KeyError
        If ``config_name`` is not a key of the zoo's ``configs`` mapping.
    """
    try:
        zoo = importlib.import_module(zoo_name)
    except ImportError as err:
        raise ImportError(
            f"Could not import the specified zoo module: {zoo_name}. Please ensure it is "
            "installed and available."
        ) from err

    if config_name not in zoo.configs:
        available_configs = list(zoo.configs.keys())
        raise KeyError(
            f"The configuration '{config_name}' does not exist in the zoo module '{zoo_name}'. "
            f"Available configurations: {available_configs}"
        )

    return zoo.configs[config_name]


def main(argv=None) -> None:
    """
    Main entry point for the CLI utility to run various reinforcement learning tasks.

    This function parses command-line arguments to select and run one of several tasks:
    - Train a reinforcement learning agent.
    - Test a trained agent in the environment.
    - Perform an online adversarial attack using specified attack methods.
    - Perform adversarial training using single or ensemble attacks.

    The function loads a configuration from the chosen zoo module, checks that
    it exists, and dispatches to the handler matching the selected command.

    CLI Commands Supported
    ----------------------
    - `train`: Train a new RL agent.
    - `test`: Evaluate a trained agent.
    - `online-attack`: Run online adversarial attacks on an agent.
    - `adversarial-training`: Perform adversarial training with attack strategies.

    Command-Line Arguments
    ----------------------
    Each subcommand has its own set of arguments, including but not limited to:
    zoo : str
        The import path to the module that contains available experiment configurations.
    config : str
        The key name of the configuration to use from the zoo module's ``configs``.
    checkpoint / agent-checkpoint / adversary-checkpoint : str, optional
        Checkpoints to load for the agent or the adversaries.
    attack-name / target / eps : list, attack commands only
        Attack specification; one entry per attack.
    norm : float, attack commands only
        Norm of the perturbation.
    output-dir : str, optional
        Directory where model artifacts or results should be saved.
    device : str, optional
        The device to run the experiment on (default "cpu").
    seed : int, optional
        Random seed passed to the handler (default 0).
    render : bool, optional
        Passed through to the `test` / `online-attack` handlers.

    Parameters
    ----------
    argv : sequence of str, optional
        Arguments to parse instead of ``sys.argv[1:]`` (useful for calling
        the CLI programmatically). ``None`` keeps the default behaviour.

    Raises
    ------
    ImportError
        If the specified zoo module cannot be imported.
    KeyError
        If the specified configuration name is not found in the zoo module.
    SystemExit
        Raised by argparse on invalid arguments or ``--help``.

    Returns
    -------
    None
        The function executes the specified command and returns nothing.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    config = _load_config(args.zoo, args.config)
    device = args.device

    # Dispatch based on the selected command; each handler receives the
    # config positionally and the CLI options as keywords.
    if args.command == "train":
        train(
            config,
            output_dir=args.output_dir,
            checkpoint=args.checkpoint,
            device=device,
            seed=args.seed,
            total_timesteps=args.total_timesteps,
            prepopulate_timesteps=args.prepopulate_timesteps,
            verbose=args.verbose,
        )

    elif args.command == "test":
        test(
            config,
            output_dir=args.output_dir,
            checkpoint=args.checkpoint,
            render=args.render,
            device=device,
            seed=args.seed,
            n_eval_episodes=args.n_eval_episodes,
        )

    elif args.command == "online-attack":
        online_attack(
            config,
            attack_name=args.attack_name,
            target=args.target,
            eps=args.eps,
            norm=args.norm,
            output_dir=args.output_dir,
            agent_checkpoint=args.agent_checkpoint,
            reference_checkpoint=args.reference_checkpoint,
            adversary_checkpoint=args.adversary_checkpoint,
            self_reference=args.self_reference,
            render=args.render,
            device=device,
            seed=args.seed,
            n_eval_episodes=args.n_eval_episodes,
        )

    elif args.command == "adversarial-training":
        adversarial_train(
            config,
            attack_name=args.attack_name,
            target=args.target,
            eps=args.eps,
            norm=args.norm,
            output_dir=args.output_dir,
            agent_checkpoint=args.agent_checkpoint,
            reference_checkpoint=args.reference_checkpoint,
            adversary_checkpoint=args.adversary_checkpoint,
            self_reference=args.self_reference,
            device=device,
            seed=args.seed,
            total_timesteps=args.total_timesteps,
            prepopulate_timesteps=args.prepopulate_timesteps,
            verbose=args.verbose,
        )

    else:
        # Defensive only: the subparsers are ``required=True``, so argparse
        # rejects a missing/unknown command before reaching this point.
        parser.print_help()


if __name__ == "__main__":
    main()