Coverage for adaro_rl / pipelines / main.py: 88%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1from . import train, test, online_attack, adversarial_train 

2import argparse 

3import importlib 

4 

# Default import path of the module that exposes the ``configs`` mapping.
DEFAULT_ZOO = "adaro_rl.zoo"

# Help strings shared by several subcommands.
ZOO_HELP = "Zoo module where configs are stored."
CONFIG_HELP = "Name of the config to use."
CHECKPOINT_HELP = "Checkpoint to load (if any)."
AGENT_CHECKPOINT_HELP = "Agent checkpoint to load (if any)."
LIST_AGENT_CHECKPOINT_HELP = "List of agent checkpoint to load (if any)."
ADVERSARY_CHECKPOINT_HELP = "List of Adversary checkpoint to load (if any)."
OUTPUT_DIR_HELP_MODEL = "Directory to store agent and logs."
OUTPUT_DIR_HELP_RESULTS = "Directory to store the results."
DEVICE_HELP = "Device to run the experiment on."
SEED_HELP = "Random seed."
NORM_HELP = "Norm of the perturbation."
TRAIN_STEP_HELP = "Number of training steps."
PREPOPULATE_HELP = "Number of steps before starting training"


def _add_config_arguments(subparser: argparse.ArgumentParser) -> None:
    """Add the ``--zoo`` and ``--config`` options accepted by every subcommand."""
    subparser.add_argument("--zoo", type=str, default=DEFAULT_ZOO, help=ZOO_HELP)
    subparser.add_argument("--config", type=str, required=True, help=CONFIG_HELP)


def _add_runtime_arguments(subparser: argparse.ArgumentParser) -> None:
    """Add the ``--device`` and ``--seed`` options accepted by every subcommand."""
    subparser.add_argument("--device", type=str, default="cpu", help=DEVICE_HELP)
    subparser.add_argument("--seed", type=int, default=0, help=SEED_HELP)


def _add_training_arguments(subparser: argparse.ArgumentParser) -> None:
    """Add device/seed, timestep budgets and the verbosity switch (training commands)."""
    _add_runtime_arguments(subparser)
    subparser.add_argument(
        "--total-timesteps", type=int, default=None, help=TRAIN_STEP_HELP
    )
    subparser.add_argument(
        "--prepopulate-timesteps", type=int, default=None, help=PREPOPULATE_HELP
    )
    # Paired flags writing to the same destination; verbosity is off by default.
    subparser.add_argument("--verbose", dest="verbose", action="store_true")
    subparser.add_argument("--no-verbose", dest="verbose", action="store_false")
    subparser.set_defaults(verbose=False)


def _add_evaluation_arguments(subparser: argparse.ArgumentParser) -> None:
    """Add rendering, device/seed and episode-count options (evaluation commands)."""
    subparser.add_argument("--render", action="store_true", help="Enable rendering.")
    _add_runtime_arguments(subparser)
    subparser.add_argument(
        "--n-eval-episodes",
        type=int,
        default=None,
        help="Number of evaluation episodes.",
    )


def _add_attack_arguments(subparser: argparse.ArgumentParser) -> None:
    """Add the options describing the attack(s): names, targets, eps, norm, adversaries."""
    subparser.add_argument(
        "--attack-name",
        type=str,
        required=True,
        nargs="+",
        help="List of attack names.",
    )
    subparser.add_argument(
        "--target",
        type=str,
        required=True,
        nargs="+",
        help="List of target specifications for each attack.",
    )
    subparser.add_argument(
        "--eps",
        type=float,
        required=True,
        nargs="+",
        help="List of perturbation amounts (epsilon).",
    )
    subparser.add_argument("--norm", type=float, required=True, help=NORM_HELP)
    subparser.add_argument(
        "--adversary-checkpoint",
        type=str,
        default=None,
        nargs="+",
        help=ADVERSARY_CHECKPOINT_HELP,
    )


def _add_reference_arguments(subparser: argparse.ArgumentParser) -> None:
    """Add the options selecting the reference agent used by attack commands."""
    subparser.add_argument(
        "--reference-checkpoint",
        type=str,
        default=None,
        help="Reference agent checkpoint to load (if any).",
    )
    subparser.add_argument(
        "--self-reference",
        action="store_true",
        help="Flag for using the primary agent as the reference agent; otherwise it uses `reference_checkpoint`.",
    )


def _build_parser() -> argparse.ArgumentParser:
    """Build the top-level parser with its four subcommands."""
    parser = argparse.ArgumentParser(description="Script for multiple functionalities")
    subparsers = parser.add_subparsers(
        dest="command", required=True, help="Available commands"
    )

    # train
    parser_train = subparsers.add_parser("train", help="Train an agent")
    _add_config_arguments(parser_train)
    parser_train.add_argument(
        "--checkpoint", type=str, default=None, help=CHECKPOINT_HELP
    )
    parser_train.add_argument(
        "--output-dir", type=str, default="agent", help=OUTPUT_DIR_HELP_MODEL
    )
    _add_training_arguments(parser_train)

    # test
    parser_test = subparsers.add_parser("test", help="Test an agent")
    _add_config_arguments(parser_test)
    parser_test.add_argument(
        "--checkpoint", type=str, required=True, help=LIST_AGENT_CHECKPOINT_HELP
    )
    parser_test.add_argument(
        "--output-dir", type=str, default="agent", help=OUTPUT_DIR_HELP_RESULTS
    )
    _add_evaluation_arguments(parser_test)

    # online-attack
    parser_online_attack = subparsers.add_parser(
        "online-attack", help="Perform an online attack on an agent"
    )
    _add_config_arguments(parser_online_attack)
    _add_attack_arguments(parser_online_attack)
    parser_online_attack.add_argument(
        "--output-dir", type=str, default="agent", help=OUTPUT_DIR_HELP_RESULTS
    )
    parser_online_attack.add_argument(
        "--agent-checkpoint", type=str, required=True, help=LIST_AGENT_CHECKPOINT_HELP
    )
    _add_reference_arguments(parser_online_attack)
    _add_evaluation_arguments(parser_online_attack)

    # adversarial-training
    parser_adv_training = subparsers.add_parser(
        "adversarial-training", help="Perform adversarial training"
    )
    _add_config_arguments(parser_adv_training)
    _add_attack_arguments(parser_adv_training)
    parser_adv_training.add_argument(
        "--output-dir", type=str, default="adv_agent", help=OUTPUT_DIR_HELP_MODEL
    )
    parser_adv_training.add_argument(
        "--agent-checkpoint", type=str, default=None, help=AGENT_CHECKPOINT_HELP
    )
    _add_reference_arguments(parser_adv_training)
    _add_training_arguments(parser_adv_training)

    return parser


def main(argv=None) -> None:
    """
    Main entry point for the CLI utility to run various reinforcement learning tasks.

    The function parses the command line, imports the requested zoo module,
    looks up the requested configuration in ``zoo.configs`` and dispatches to
    the handler matching the selected subcommand.

    CLI Commands Supported
    ----------------------
    - `train`: calls ``train``.
    - `test`: calls ``test``.
    - `online-attack`: calls ``online_attack``.
    - `adversarial-training`: calls ``adversarial_train``.

    Parameters
    ----------
    argv : sequence of str, optional
        Argument list to parse (without the program name). When ``None``
        (the default), ``argparse`` reads ``sys.argv[1:]``.

    Command-Line Arguments
    ----------------------
    Each subcommand has its own set of arguments, including but not limited to:
    zoo : str
        The import path to the module that contains available experiment configurations.
    config : str
        The key name of the configuration to use from the zoo module.
    checkpoint : str, optional
        Checkpoint forwarded to the handler (required for `test`).
    output-dir : str, optional
        Output directory forwarded to the handler.
    device : str, optional
        Device string forwarded to the handler (default "cpu").
    seed : int, optional
        Random seed forwarded to the handler (default 0).
    render : bool, optional
        Rendering flag for `test` and `online-attack`.

    Raises
    ------
    ImportError
        If the specified zoo module cannot be imported. The original import
        failure is attached as ``__cause__``.
    KeyError
        If the specified configuration name is not found in the zoo module.

    Returns
    -------
    None
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Import the zoo module; chain the original error so a failure raised from
    # inside the zoo module itself (e.g. a missing dependency) stays visible.
    try:
        zoo = importlib.import_module(args.zoo)
    except ImportError as err:
        raise ImportError(
            f"Could not import the specified zoo module: {args.zoo}. Please ensure it is "
            "installed and available."
        ) from err

    # Ensure the config exists before handing it to a handler.
    if args.config not in zoo.configs:
        available_configs = list(zoo.configs.keys())
        raise KeyError(
            f"The configuration '{args.config}' does not exist in the zoo module '{args.zoo}'. "
            f"Available configurations: {available_configs}"
        )

    config = zoo.configs[args.config]

    # Dispatch on the selected command. The subparsers are registered with
    # required=True, so argparse has already rejected a missing or unknown
    # command and exactly one of these branches runs.
    if args.command == "train":
        train(
            config,
            output_dir=args.output_dir,
            checkpoint=args.checkpoint,
            device=args.device,
            seed=args.seed,
            total_timesteps=args.total_timesteps,
            prepopulate_timesteps=args.prepopulate_timesteps,
            verbose=args.verbose,
        )

    elif args.command == "test":
        test(
            config,
            output_dir=args.output_dir,
            checkpoint=args.checkpoint,
            render=args.render,
            device=args.device,
            seed=args.seed,
            n_eval_episodes=args.n_eval_episodes,
        )

    elif args.command == "online-attack":
        online_attack(
            config,
            attack_name=args.attack_name,
            target=args.target,
            eps=args.eps,
            norm=args.norm,
            output_dir=args.output_dir,
            agent_checkpoint=args.agent_checkpoint,
            reference_checkpoint=args.reference_checkpoint,
            adversary_checkpoint=args.adversary_checkpoint,
            self_reference=args.self_reference,
            render=args.render,
            device=args.device,
            seed=args.seed,
            n_eval_episodes=args.n_eval_episodes,
        )

    elif args.command == "adversarial-training":
        adversarial_train(
            config,
            attack_name=args.attack_name,
            target=args.target,
            eps=args.eps,
            norm=args.norm,
            output_dir=args.output_dir,
            agent_checkpoint=args.agent_checkpoint,
            reference_checkpoint=args.reference_checkpoint,
            adversary_checkpoint=args.adversary_checkpoint,
            self_reference=args.self_reference,
            device=args.device,
            seed=args.seed,
            total_timesteps=args.total_timesteps,
            prepopulate_timesteps=args.prepopulate_timesteps,
            verbose=args.verbose,
        )


if __name__ == "__main__":
    main()