Coverage for tests / test_pipelines.py: 98%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2import sys 

3import shutil 

4from pathlib import Path 

5 

6import adaro_rl 

7import adaro_rl.viz 

8import adaro_rl.zoo as zoo 

9 

# Deliberately tiny budgets so the end-to-end pipelines finish quickly.
TRAINING_STEPS = 100
N_EVAL_EPISODES = 1
DEVICE = "cpu"  # set to "cuda" to exercise the GPU code path
SEED = 0


def _ensure_agent(config_name):
    """Create ``outputs/<config_name>/agent`` and fetch the zoo model into it.

    The download is skipped when ``model.zip`` is already present, so repeated
    runs reuse the local copy instead of downloading it again.

    Args:
        config_name: Name passed to ``zoo.download_model``.

    Returns:
        ``(output_dir, agent_dir)`` as path strings (relative to the cwd).
    """
    output_dir = os.path.join("outputs", config_name)
    agent_dir = os.path.join(output_dir, "agent")
    # parents=True also creates output_dir.
    Path(agent_dir).mkdir(parents=True, exist_ok=True)

    print("📥 Download the agent")
    model_path = os.path.join(agent_dir, "model.zip")
    if os.path.isfile(model_path):
        print(f"{model_path} already exists")
    else:
        zoo.download_model(config_name, local_dir=agent_dir)
    return output_dir, agent_dir


def _online_attack(config, agent_dir, attack_name, target, eps, norm,
                   adversary_checkpoint=None, label="Online attack"):
    """Run the online-attack pipeline against ``<agent_dir>/model.zip``.

    Results go to ``<agent_dir>/results``. ``attack_name``, ``target``, ``eps``
    and ``adversary_checkpoint`` are forwarded unchanged, so they may be
    scalars (single attack) or equal-length lists (ensemble attack).
    """
    print(f"\n🧨 {label}: {attack_name}_{target}_{eps}_{norm}")
    adaro_rl.pipelines.online_attack(
        config=config,
        attack_name=attack_name,
        target=target,
        eps=eps,
        norm=norm,
        output_dir=os.path.join(agent_dir, "results"),
        agent_checkpoint=os.path.join(agent_dir, "model.zip"),
        adversary_checkpoint=adversary_checkpoint,
        self_reference=True,
        device=DEVICE,
        seed=SEED,
        n_eval_episodes=N_EVAL_EPISODES,
    )


def _adversarial_train(config, agent_dir, output_dir, attack_name, target,
                       eps, norm, adversary_checkpoint=None):
    """Adversarially fine-tune ``<agent_dir>/model.zip`` into ``output_dir``.

    Returns:
        The path ``<output_dir>/model.zip`` that later steps load.
    """
    print(f"\n🛡️ Adversarial training: {attack_name}_{target}_{eps}_{norm}")
    adaro_rl.pipelines.adversarial_train(
        config=config,
        attack_name=attack_name,
        target=target,
        eps=eps,
        norm=norm,
        output_dir=output_dir,
        agent_checkpoint=os.path.join(agent_dir, "model.zip"),
        adversary_checkpoint=adversary_checkpoint,
        self_reference=True,
        device=DEVICE,
        seed=SEED,
        total_timesteps=TRAINING_STEPS,
        prepopulate_timesteps=0,
        verbose=False,
    )
    return os.path.join(output_dir, "model.zip")


def _check_cli(monkeypatch):
    """Check that ``adaro_rl train ...`` parses and dispatches to ``train()``."""
    import adaro_rl.pipelines.main as cli

    calls = []
    # Stub the heavy training loop so this check is instant, and record the
    # call so we can verify the CLI actually dispatched to it.
    monkeypatch.setattr(
        "adaro_rl.pipelines.main.train",
        lambda *args, **kwargs: calls.append((args, kwargs)),
    )

    # Give argparse a clean argv (argv[0] is just the program name).
    monkeypatch.setattr(
        sys,
        "argv",
        [
            "adaro_rl",  # <- argv[0]
            "train",  # <- sub-command
            "--zoo", "adaro_rl.zoo",
            "--config", "Enduro-v5",
            "--checkpoint", "dummy.zip",
            "--output-dir", "tmp_out",
        ],
    )

    cli.main()
    assert calls, "CLI 'train' sub-command did not call the (stubbed) train()"


def _check_enduro_pipelines():
    """Download/train/test, attacks, adversarial training and viz on Enduro-v5."""
    config_name = "Enduro-v5"
    config = zoo.configs[config_name]
    output_dir, agent_dir = _ensure_agent(config_name)
    agent_ckpt = os.path.join(agent_dir, "model.zip")
    agent_bis_dir = os.path.join(output_dir, "agent_bis")
    adv_trained_dir = os.path.join(output_dir, "adv_trained_agent")

    # --- Train - Test pipelines ---
    print("\n🏋️ Train the agent")
    adaro_rl.pipelines.train(
        config=config,
        checkpoint=agent_ckpt,
        output_dir=agent_bis_dir,
        device=DEVICE,
        seed=SEED,
        total_timesteps=TRAINING_STEPS,
        prepopulate_timesteps=100,
        verbose=False,
    )
    agent_bis_ckpt = os.path.join(agent_bis_dir, "model.zip")
    assert os.path.isfile(agent_bis_ckpt), f"train() did not write {agent_bis_ckpt}"

    print("\n🧪 Test the agent")
    adaro_rl.pipelines.test(
        config=config,
        checkpoint=[agent_ckpt, agent_bis_ckpt],
        output_dir=os.path.join(agent_dir, "results"),
        device=DEVICE,
        seed=SEED,
        n_eval_episodes=N_EVAL_EPISODES,
    )

    # --- Single attack: online attack, adversarial training, re-evaluation ---
    attack_name, target, eps, norm = "FGM_D", "untargeted", 100.0, 2.0
    _online_attack(config, agent_dir, attack_name, target, eps, norm)

    adv_ckpt = _adversarial_train(
        config, agent_dir, adv_trained_dir, attack_name, target, eps, norm
    )
    assert os.path.isfile(adv_ckpt), f"adversarial_train() did not write {adv_ckpt}"

    print("\n🧪 Test the adv trained agent")
    adaro_rl.pipelines.test(
        config=config,
        checkpoint=adv_ckpt,
        output_dir=os.path.join(adv_trained_dir, "results"),
        device=DEVICE,
        seed=SEED,
        n_eval_episodes=N_EVAL_EPISODES,
    )

    _online_attack(
        config, adv_trained_dir, attack_name, target, eps, norm,
        label="Online attack the adv trained agent",
    )

    print(f"\n🧨 Plot robustness matrix : {attack_name}_{target}_{eps}_{norm}")
    adaro_rl.viz.robustness_matrix(
        env_name=config_name,
        agent_dirs=[agent_dir, adv_trained_dir],
        agent_names=["agent", "adv_trained_agent"],
        norm=norm,
        attack_list=[attack_name],
        eps_list=[eps],
        output_dir=output_dir,
        fontsize=15,
    )

    # --- Single (non-ensemble) attack with norm=0.0 and an integer budget ---
    _online_attack(config, agent_dir, "FGM_D", "untargeted", 10, 0.0)

    # --- Ensemble attack: online attack + adversarial training ---
    attack_names = ["FGM_D", "FGM_D", "FGM_V", "FGM_V", "RNA", "RUA", "RSA",
                    "FGSM_D", "FGSM_D", "FGSM_V", "FGSM_V"]
    targets = ["targeted", "untargeted", "min", "max", None, None, None,
               "targeted", "untargeted", "min", "max"]
    adversary_ckpts = [None] * len(attack_names)
    eps_list = [100.0] * len(attack_names)

    _online_attack(config, agent_dir, attack_names, targets, eps_list, 2.0,
                   adversary_ckpts)
    _adversarial_train(config, agent_dir, adv_trained_dir, attack_names,
                       targets, eps_list, 2.0, adversary_ckpts)


def _check_halfcheetah_ensemble_attack():
    """Run an ensemble online attack against the HalfCheetah-v5 zoo agent."""
    config_name = "HalfCheetah-v5"
    config = zoo.configs[config_name]
    _, agent_dir = _ensure_agent(config_name)

    attack_names = ["FGM_C", "FGM_QC", "FGM_QC", "FGM_QAC", "FGM_QAC",
                    "FGSM_C", "FGSM_QC", "FGSM_QC", "FGSM_QAC", "FGSM_QAC"]
    targets = ["untargeted", "min", "max", "min", "max",
               "untargeted", "min", "max", "min", "max"]
    adversary_ckpts = [None] * len(attack_names)
    eps_list = [0.02] * len(attack_names)

    _online_attack(config, agent_dir, attack_names, targets, eps_list, 2.0,
                   adversary_ckpts)


def test(monkeypatch):
    """End-to-end smoke test of the adaro_rl CLI and pipelines.

    Downloads real zoo models and writes artifacts under ``outputs/`` relative
    to the working directory; training budgets are kept tiny (see the module
    constants above) so the run stays short.
    """
    assert adaro_rl.__name__ == "adaro_rl"

    # Must run first: it imports adaro_rl.pipelines.main.
    _check_cli(monkeypatch)
    _check_enduro_pipelines()
    _check_halfcheetah_ensemble_attack()

288 

289 

290 

291 

292 

293 

294 

if __name__ == "__main__":
    # test() requires pytest's ``monkeypatch`` fixture. Calling test() bare
    # raises TypeError, so build a MonkeyPatch by hand; it undoes all patches
    # when the context exits.
    import pytest

    with pytest.MonkeyPatch.context() as mp:
        test(mp)