Coverage for tests / test_pipelines.py: 98%
86 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
2import sys
3import shutil
4from pathlib import Path
6import adaro_rl
7import adaro_rl.viz
8import adaro_rl.zoo as zoo
def test(monkeypatch):
    """Smoke-test the CLI entry point with the training loop stubbed out.

    The heavy ``train`` pipeline is replaced by a no-op and ``sys.argv`` is
    swapped for a minimal command line, so ``main()`` only has to parse the
    arguments and dispatch to the stub.
    """
    # Sanity check: the package under test imports under its own name.
    assert adaro_rl.__name__ == "adaro_rl"

    import adaro_rl.pipelines.main as main_cli

    # Stub the heavy training loop so the test is instant.
    monkeypatch.setattr("adaro_rl.pipelines.main.train", lambda *_, **__: None)

    # Build a clean argv (argv[0] is just the program name) and install it so
    # argparse sees exactly this command line.
    fake_argv = ["adaro_rl", "train"]
    fake_argv += ["--zoo", "adaro_rl.zoo"]
    fake_argv += ["--config", "Enduro-v5"]
    fake_argv += ["--checkpoint", "dummy.zip"]
    fake_argv += ["--output-dir", "tmp_out"]
    monkeypatch.setattr(sys, "argv", fake_argv)

    # Parsing should succeed and the stubbed train() should be invoked.
    main_cli.main()
# --- Shared settings for every pipeline smoke run below ---
training_steps=100  # tiny budget: these are smoke tests, not real training runs
n_eval_episodes=1  # a single evaluation episode keeps each test fast
# device="cuda"  # alternative: switch to this line to exercise the GPU path
device="cpu"
seed=0  # fixed seed so the runs are reproducible
############### Testing : Download - Test - Train Pipelines ###############

config_name="Enduro-v5"

# Layout: outputs/<env>/{agent, agent_bis, adv_trained_agent}.
# Fix: the original wrapped config_name in a redundant f-string
# (f"{config_name}" == config_name for a plain str).
output_dir = os.path.join("outputs", config_name)
agent_dir = os.path.join(output_dir,"agent")
agent_bis_dir = os.path.join(output_dir,"agent_bis")
adv_trained_dir = os.path.join(output_dir,"adv_trained_agent")

output_path = Path(output_dir)
agent_path = Path(agent_dir)

# Create the output directory (idempotent).
output_path.mkdir(parents=True, exist_ok=True)

# Look up the environment's zoo configuration by name.
config = zoo.configs[config_name]

print("📥 Download the agent")
agent_path.mkdir(parents=True, exist_ok=True)
zoo.download_model(config_name, local_dir=str(agent_dir))

# Fine-tune the downloaded checkpoint for a handful of steps into agent_bis.
print("\n🏋️ Train the agent")
adaro_rl.pipelines.train(
    config=config,
    checkpoint=os.path.join(agent_dir,"model.zip"),
    output_dir=agent_bis_dir,
    device=device,
    seed=seed,
    total_timesteps=training_steps,
    prepopulate_timesteps=100,
    verbose=False
)

# Evaluate both the downloaded and the fine-tuned checkpoint in one call.
print("\n🧪 Test the agent")
adaro_rl.pipelines.test(
    config=config,
    checkpoint=[os.path.join(agent_dir,"model.zip"),os.path.join(agent_bis_dir,"model.zip")],
    output_dir=os.path.join(agent_dir,"results"),
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)
############### Testing : Train Adversary - Online Attack - Adversarial Training Pipelines ###############

# Single-attack settings: FGM_D, untargeted, L2-norm budget of 100.
# adversary_checkpoint=None presumably means the adversary is built from
# scratch rather than loaded — TODO confirm against the pipeline API.
attack_name="FGM_D"
target="untargeted"
adversary_checkpoint=None
eps=100.0
norm=2.0

# Attack the downloaded agent online and record results under agent/results.
print(f"\n🧨 Online attack: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.online_attack(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=os.path.join(agent_dir,"results"),
    agent_checkpoint=os.path.join(agent_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)

# Adversarially train the agent under the same attack, writing the hardened
# checkpoint to adv_trained_agent/.
print(f"\n🛡️ Adversarial training: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.adversarial_train(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=str(adv_trained_dir),
    agent_checkpoint=os.path.join(agent_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    total_timesteps=training_steps,
    prepopulate_timesteps=0,
    verbose=False
)

# Clean evaluation of the adversarially trained agent.
print("\n🧪 Test the adv trained agent")
adaro_rl.pipelines.test(
    config=config,
    checkpoint=os.path.join(adv_trained_dir,"model.zip"),
    output_dir=os.path.join(adv_trained_dir,"results"),
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)

# Re-run the same attack against the hardened agent for comparison.
print(f"\n🧨 Online attack the adv trained agent: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.online_attack(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=os.path.join(adv_trained_dir,"results"),
    agent_checkpoint=os.path.join(adv_trained_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)

# Visualize baseline vs adversarially trained robustness side by side.
print(f"\n🧨 Plot robustness matrix : {attack_name}_{target}_{eps}_{norm}")
adaro_rl.viz.robustness_matrix(
    env_name=config_name,
    agent_dirs=[agent_dir, adv_trained_dir],
    agent_names=["agent", "adv_trained_agent"],
    norm=norm,
    attack_list=[attack_name],
    eps_list=[eps],
    output_dir=output_dir,
    fontsize=15
)
############### Testing Ensemble Attack : Online Attack Pipeline ###############

# NOTE(review): eps is an int (10) here while every other section uses a
# float, and norm=0.0 differs from the 2.0 used above — presumably this
# deliberately exercises the L0/int code path; confirm with the pipeline API.
attack_name="FGM_D"
target="untargeted"
adversary_checkpoint=None
eps=10
norm=0.0

# Re-attack the downloaded agent with these alternative budget settings.
print(f"\n🧨 Online attack: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.online_attack(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=os.path.join(agent_dir,"results"),
    agent_checkpoint=os.path.join(agent_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)
############### Testing Ensemble Attack : Online Attack - Adversarial Training Pipelines ###############

# Ensemble configuration: the four lists are parallel (11 entries each),
# one entry per attack in the ensemble. RNA/RUA/RSA take no target (None).
attack_name=["FGM_D", "FGM_D", "FGM_V", "FGM_V", "RNA", "RUA", "RSA", "FGSM_D", "FGSM_D", "FGSM_V", "FGSM_V",]
target=["targeted", "untargeted", "min", "max", None, None, None, "targeted", "untargeted", "min", "max"]
adversary_checkpoint=[None, None, None, None, None, None, None, None, None, None, None]
eps=[100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0]
norm=2.0

# Run the whole ensemble as one online attack against the downloaded agent.
print(f"\n🧨 Online attack: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.online_attack(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=os.path.join(agent_dir,"results"),
    agent_checkpoint=os.path.join(agent_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)

# Adversarially train against the same ensemble; overwrites the
# adv_trained_agent output produced by the single-attack section above.
print(f"\n🛡️ Adversarial training: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.adversarial_train(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=str(adv_trained_dir),
    agent_checkpoint=os.path.join(agent_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    total_timesteps=training_steps,
    prepopulate_timesteps=0,
    verbose=False
)
############### Testing Ensemble Attack : Online Attack Pipeline ###############

# Continuous-control variant of the ensemble attack (10 parallel entries).
config_name="HalfCheetah-v5"

attack_name=["FGM_C", "FGM_QC", "FGM_QC", "FGM_QAC", "FGM_QAC", "FGSM_C", "FGSM_QC", "FGSM_QC", "FGSM_QAC", "FGSM_QAC"]
target=["untargeted", "min", "max", "min", "max", "untargeted", "min", "max", "min", "max"]
adversary_checkpoint=[None, None, None, None, None, None, None, None, None, None]
eps=[0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]
norm=2.0

# Rebind the path variables for the new environment.
# Fix: dropped the redundant f-string wrapper around config_name
# (f"{config_name}" == config_name for a plain str).
output_dir = os.path.join("outputs", config_name)
agent_dir = os.path.join(output_dir,"agent")
adv_trained_dir = os.path.join(output_dir,"adv_trained_agent")

output_path = Path(output_dir)
agent_path = Path(agent_dir)

# Create the output directory (idempotent).
output_path.mkdir(parents=True, exist_ok=True)

config = zoo.configs[config_name]

# Download only when the checkpoint is not already cached locally.
print("📥 Download the agent")
agent_path.mkdir(parents=True, exist_ok=True)
if not os.path.isfile(os.path.join(agent_dir,"model.zip")):
    zoo.download_model(config_name, local_dir=str(agent_dir))
else:
    print(f"{os.path.join(agent_dir,'model.zip')} already exists")

# Run the continuous-control ensemble attack against the downloaded agent.
print(f"\n🧨 Online attack: {attack_name}_{target}_{eps}_{norm}")
adaro_rl.pipelines.online_attack(
    config=config,
    attack_name=attack_name,
    target=target,
    eps=eps,
    norm=norm,
    output_dir=os.path.join(agent_dir,"results"),
    agent_checkpoint=os.path.join(agent_dir,"model.zip"),
    adversary_checkpoint=adversary_checkpoint,
    self_reference=True,
    device=device,
    seed=seed,
    n_eval_episodes=n_eval_episodes
)
if __name__ == "__main__":
    # Bug fix: test() requires pytest's ``monkeypatch`` fixture, so the old
    # bare ``test()`` call raised TypeError whenever this file was run as a
    # script. Provide a minimal stand-in covering exactly the two setattr()
    # forms the test uses: setattr("pkg.module.attr", value) and
    # setattr(obj, "name", value). No undo/teardown is needed for a one-shot
    # script run.
    class _MonkeyPatchStandIn:
        """Tiny subset of pytest's MonkeyPatch (setattr only, no undo)."""

        _UNSET = object()  # sentinel distinguishing the 2-arg string form

        def setattr(self, target, name, value=_UNSET):
            if isinstance(target, str) and value is self._UNSET:
                # String form: dotted "module.path.attr" plus the replacement
                # value (which arrives in the ``name`` slot).
                import importlib
                module_path, _, attr = target.rpartition(".")
                setattr(importlib.import_module(module_path), attr, name)
            else:
                setattr(target, name, value)

    test(_MonkeyPatchStandIn())