Coverage for adaro_rl / pipelines / adversarial_train.py: 78%

59 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2import gymnasium as gym 

3from stable_baselines3.common.vec_env import DummyVecEnv 

4 

5from ..attacks import make_attack, training_method_names, EnsembleAttackWrapper 

6from ..wrappers import ObsAttackAgentEnv 

7from .utils import normalize_lists, make_attack_list 

8 

9 

def adversarial_train(
    config,
    attack_name,
    target,
    eps,
    norm,
    adversary_checkpoint=None,
    output_dir="agent",
    agent_checkpoint=None,
    reference_checkpoint=None,
    self_reference=False,
    device="cpu",
    seed=None,
    total_timesteps=None,
    prepopulate_timesteps=None,
    verbose=True,
):
    """
    Perform adversarial training with the specified attack configuration(s).

    Sets up the training and evaluation environments, constructs the
    adversarially trained agent and the adversarial-policy agents (or a
    reference agent), builds the adversarial attack (single or ensemble),
    trains the agent with the attack applied to its observations, and saves
    the resulting model to ``output_dir``.

    Parameters
    ----------
    config : object
        Configuration object. The attributes read here are:
        - ``train_env_config`` / ``eval_env_config`` : dicts with keys
          "env_id", "n_envs", "n_frame_stack", "wrapper_class", "env_kwargs".
        - ``finetuned_agent_config`` / ``agent_config`` : dicts with keys
          "algo" and "algo_kwargs".
        - ``adversary_config`` : dict with key "algo_kwargs".
        - ``finetuned_training_kwargs`` : dict with keys "total_timesteps"
          and "prepopulate_timesteps" (used as defaults).
        - factory methods ``make_env`` and ``make_agent``.
    attack_name : str or list of str
        Name(s) of the attack method(s) used for adversarial training.
    target : str or list of str
        Target specification(s) for the attack.
    eps : float or list of float
        Perturbation budget(s) (epsilon) for the adversarial attack.
    norm : {0, 1, 2, float('inf')}
        Norm used for the perturbation (e.g. 0 for sparse, 1/L1, 2/L2, or
        float('inf') for L-infinity).
    adversary_checkpoint : str or list of str, optional
        Checkpoint path(s) for the adversarial agent(s). When an attack is
        not a training method, the reference agent is used instead.
    output_dir : str, optional
        Directory where the agent and training artifacts are saved.
        Default is "agent". Created if it does not exist.
    agent_checkpoint : str, optional
        Checkpoint for the adversarially trained (main) agent.
    reference_checkpoint : str, optional
        Checkpoint for the reference agent, used when ``self_reference`` is
        False.
    self_reference : bool, optional
        If True, the adversarially trained agent is also used as the
        reference agent; otherwise the reference agent is built from
        ``reference_checkpoint``.
    device : str, optional
        Computation device (e.g. "cpu" or "cuda"). Default is "cpu".
    seed : int, optional
        Random seed for reproducibility.
    total_timesteps : int, optional
        Number of training timesteps. Defaults to
        ``config.finetuned_training_kwargs["total_timesteps"]``.
    prepopulate_timesteps : int, optional
        Number of timesteps used to prepopulate the replay buffer. Defaults
        to ``config.finetuned_training_kwargs["prepopulate_timesteps"]``.
    verbose : bool, optional
        Verbosity flag forwarded to the constructed agents. Default True.

    Returns
    -------
    int
        1 upon successful completion of adversarial training.
    """

    # Broadcast scalar arguments into aligned per-attack lists.
    attack_name_list, target_list, eps_list, adversary_checkpoint_list = (
        normalize_lists(
            attack_name=attack_name,
            target=target,
            eps=eps,
            adversary_checkpoint=adversary_checkpoint,
        ).values()
    )

    os.makedirs(output_dir, exist_ok=True)

    if total_timesteps is None:
        total_timesteps = config.finetuned_training_kwargs["total_timesteps"]
    if prepopulate_timesteps is None:
        prepopulate_timesteps = config.finetuned_training_kwargs[
            "prepopulate_timesteps"
        ]

    # ENV ####################

    # Headless rendering: avoid requiring a display for SDL/pygame-based envs.
    os.environ["SDL_VIDEODRIVER"] = "dummy"

    # Un-attacked environment, used only to expose spaces for the placeholder
    # envs below; the actual training env is rebuilt later with the attack
    # wrapper installed, and this one is closed before training starts.
    base_env = config.make_env(
        env_id=config.train_env_config["env_id"],
        n_envs=config.train_env_config["n_envs"],
        n_frame_stack=config.train_env_config["n_frame_stack"],
        wrapper_class=config.train_env_config["wrapper_class"],
        adv_wrapper_class=None,
        env_kwargs=config.train_env_config["env_kwargs"],
        seed=seed,
    )

    # AGENT ####################

    class Placeholder(gym.Env):
        # Minimal env exposing only the spaces; lets agents be constructed
        # before the real (attacked) training env exists.
        def __init__(self, env):
            self.observation_space = env.observation_space
            self.action_space = env.action_space

    def placeholder_fct(cls=Placeholder, env=base_env):
        return cls(env)

    placeholder_env = DummyVecEnv(
        [placeholder_fct for _ in range(config.train_env_config["n_envs"])]
    )

    def make_adv_trained_agent_fct():
        return config.make_agent(
            algo=config.finetuned_agent_config["algo"],
            env=placeholder_env,
            checkpoint=agent_checkpoint,
            output_dir=output_dir,
            device=device,
            seed=seed,
            verbose=verbose,
            algo_kwargs=config.finetuned_agent_config["algo_kwargs"],
        )

    adv_trained_agent = make_adv_trained_agent_fct()

    adv_trained_agent.train()

    # REFERENCE AGENT ####################

    if self_reference:
        # The agent being adversarially trained doubles as the reference.
        make_reference_agent_fct = make_adv_trained_agent_fct
    else:

        def make_reference_agent_fct():
            return config.make_agent(
                algo=config.agent_config["algo"],
                env=placeholder_env,
                checkpoint=reference_checkpoint,
                device=device,
                seed=seed,
                verbose=verbose,
                algo_kwargs=config.agent_config["algo_kwargs"],
            )

    # ADVERSARIAL POLICY ####################

    make_agent_fct_list_for_attack = []

    # NOTE: loop variables are deliberately named differently from the
    # `attack_name` / `adversary_checkpoint` parameters to avoid shadowing.
    for one_attack_name, one_adversary_checkpoint in zip(
        attack_name_list, adversary_checkpoint_list
    ):
        if one_attack_name in training_method_names:
            adv_action_space = base_env.get_attr("observation_perturbation_space")[0]

            class AdversaryPlaceholder(gym.Env):
                # Bind the per-iteration action space at class-creation time
                # rather than closing over the loop variable, so later
                # iterations cannot retroactively change earlier placeholders.
                _adv_action_space = adv_action_space

                def __init__(self, env):
                    self.observation_space = env.observation_space
                    self.action_space = self._adv_action_space

            def adversary_placeholder_fct(cls=AdversaryPlaceholder, env=base_env):
                return cls(env)

            adversary_placeholder_env = DummyVecEnv(
                [
                    adversary_placeholder_fct
                    for _ in range(config.train_env_config["n_envs"])
                ]
            )

            # Default arguments pin the per-iteration values (early binding).
            def make_adversary_agent_fct(
                adversary_placeholder_env=adversary_placeholder_env,
                adversary_checkpoint=one_adversary_checkpoint,
            ):
                return config.make_agent(
                    algo=config.agent_config["algo"],
                    env=adversary_placeholder_env,
                    checkpoint=adversary_checkpoint,
                    device=device,
                    seed=seed,
                    verbose=verbose,
                    algo_kwargs=config.adversary_config["algo_kwargs"],
                )

            make_agent_fct_list_for_attack.append(make_adversary_agent_fct)

        else:
            # Non-training attacks run against the reference agent.
            make_agent_fct_list_for_attack.append(make_reference_agent_fct)

    make_attack_fct_list = make_attack_list(
        base_env,
        attack_name_list,
        make_agent_fct_list_for_attack,
        target_list,
        eps_list,
        config,
        make_attack,
        norm,
        device,
    )

    if len(make_attack_fct_list) > 1:
        # Several attacks: wrap them in an ensemble.
        def make_attack_fct():
            return EnsembleAttackWrapper(make_attack_fct_list=make_attack_fct_list)

    else:
        make_attack_fct = make_attack_fct_list[0]

    # ADV WRAPPER ####################

    adv_wrapper_class = ObsAttackAgentEnv
    adv_wrapper_kwargs = {"make_attack_fct": make_attack_fct, "freq": 1}

    # MAKE ENV ####################

    training_env = config.make_env(
        env_id=config.train_env_config["env_id"],
        n_envs=config.train_env_config["n_envs"],
        n_frame_stack=config.train_env_config["n_frame_stack"],
        wrapper_class=config.train_env_config["wrapper_class"],
        adv_wrapper_class=adv_wrapper_class,
        adv_wrapper_kwargs=adv_wrapper_kwargs,
        env_kwargs=config.train_env_config["env_kwargs"],
        seed=seed,
    )

    eval_env = config.make_env(
        env_id=config.eval_env_config["env_id"],
        n_envs=config.eval_env_config["n_envs"],
        n_frame_stack=config.eval_env_config["n_frame_stack"],
        wrapper_class=config.eval_env_config["wrapper_class"],
        adv_wrapper_class=adv_wrapper_class,
        adv_wrapper_kwargs=adv_wrapper_kwargs,
        env_kwargs=config.eval_env_config["env_kwargs"],
        seed=seed,
    )

    # Swap the placeholder env for the real attacked env, then release the
    # construction-time envs.
    adv_trained_agent.model.set_env(training_env)
    placeholder_env.close()
    base_env.close()

    # TRAINING ####################

    adv_trained_agent.learn(
        eval_env,
        total_timesteps=total_timesteps,
        prepopulate_timesteps=prepopulate_timesteps,
    )

    # SAVE ####################

    adv_trained_agent.save(checkpoint=os.path.join(output_dir, "model.zip"))

    # CLOSE ###################
    training_env.close()
    eval_env.close()

    # Return 1 as documented (previously the function implicitly returned
    # None despite the docstring's contract).
    return 1