Coverage for adaro_rl / pipelines / train.py: 88%

16 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2 

3 

def train(
    config,
    output_dir="agent",
    checkpoint=None,
    device="cpu",
    seed=None,
    total_timesteps=None,
    prepopulate_timesteps=None,
    verbose=True,
):
    """
    Train a reinforcement learning agent using the provided configuration.

    This function creates training and evaluation environments, builds an agent
    through ``config.make_agent``, runs ``agent.learn`` against the evaluation
    environment, and saves the resulting model to ``<output_dir>/model.zip``.
    Both environments are closed when the function exits, whether training
    succeeded or raised.

    Parameters
    ----------
    config : object
        Configuration object. The following attributes are read:
        - training_kwargs (dict): supplies "total_timesteps" and
          "prepopulate_timesteps" when the matching argument is None.
        - train_env_config (dict), eval_env_config (dict): each must provide the
          keys "env_id", "n_envs", "n_frame_stack", "wrapper_class" and
          "env_kwargs".
        - agent_config (dict): must provide the keys "algo" and "algo_kwargs".
        - make_env (callable): builds an environment from the values above.
        - make_agent (callable): builds the agent.
    output_dir : str, optional
        Directory created if missing, forwarded to ``config.make_agent``, and
        used as the location of the saved "model.zip". Default is "agent".
    checkpoint : str, optional
        Forwarded to ``config.make_agent``; presumably a path to a pre-trained
        agent to resume from (confirm against ``make_agent``). Default is None.
    device : str, optional
        Computation device forwarded to ``config.make_agent``
        (e.g., "cpu", "cuda:0"). Default is "cpu".
    seed : int, optional
        Seed forwarded to both environment constructors and to the agent
        constructor. Default is None.
    total_timesteps : int, optional
        Forwarded to ``agent.learn``. When None, the value is taken from
        ``config.training_kwargs["total_timesteps"]``.
    prepopulate_timesteps : int, optional
        Forwarded to ``agent.learn``. When None, the value is taken from
        ``config.training_kwargs["prepopulate_timesteps"]``.
    verbose : bool, optional
        Forwarded to ``config.make_agent``. Default is True.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a required key is missing from one of the configuration dicts.

    Side Effects
    ------------
    - Creates ``output_dir`` if it does not exist.
    - Sets the ``SDL_VIDEODRIVER`` environment variable to "dummy" for the process.
    - Saves the trained model as "model.zip" in the output directory.
    - Closes the training and evaluation environments, even on failure.
    """

    os.makedirs(output_dir, exist_ok=True)

    # Explicit arguments take precedence over the values stored in the config.
    if total_timesteps is None:
        total_timesteps = config.training_kwargs["total_timesteps"]
    if prepopulate_timesteps is None:
        prepopulate_timesteps = config.training_kwargs["prepopulate_timesteps"]

    # ENV ####################

    # Presumably lets SDL-based environments run without a display (headless).
    os.environ["SDL_VIDEODRIVER"] = "dummy"

    # Pre-declared so the cleanup below can tell which environments were
    # actually created if a constructor raises part-way through.
    training_env = None
    eval_env = None

    try:
        training_env = config.make_env(
            env_id=config.train_env_config["env_id"],
            n_envs=config.train_env_config["n_envs"],
            n_frame_stack=config.train_env_config["n_frame_stack"],
            wrapper_class=config.train_env_config["wrapper_class"],
            adv_wrapper_class=None,
            env_kwargs=config.train_env_config["env_kwargs"],
            seed=seed,
        )

        eval_env = config.make_env(
            env_id=config.eval_env_config["env_id"],
            n_envs=config.eval_env_config["n_envs"],
            n_frame_stack=config.eval_env_config["n_frame_stack"],
            wrapper_class=config.eval_env_config["wrapper_class"],
            adv_wrapper_class=None,
            env_kwargs=config.eval_env_config["env_kwargs"],
            seed=seed,
        )

        # Agent ####################

        agent = config.make_agent(
            algo=config.agent_config["algo"],
            env=training_env,
            checkpoint=checkpoint,
            output_dir=output_dir,
            device=device,
            seed=seed,
            verbose=verbose,
            algo_kwargs=config.agent_config["algo_kwargs"],
        )

        # NOTE(review): presumably switches the agent into training mode —
        # confirm against the agent implementation.
        agent.train()

        # TRAINING ####################

        agent.learn(
            eval_env,
            total_timesteps=total_timesteps,
            prepopulate_timesteps=prepopulate_timesteps,
        )

        # SAVE ####################

        agent.save(checkpoint=os.path.join(output_dir, "model.zip"))

    finally:
        # CLOSE ####################

        # Release the environments even when training or saving failed;
        # close order (training first, then eval) matches the success path.
        for env in (training_env, eval_env):
            if env is not None:
                env.close()
118 eval_env.close()