Coverage for adaro_rl / pipelines / train.py: 88%
16 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
def _make_env_from_config(config, env_config, seed):
    """Build one environment through ``config.make_env`` from an env-config dict.

    ``env_config`` must provide the keys "env_id", "n_envs", "n_frame_stack",
    "wrapper_class" and "env_kwargs". No adversarial wrapper is applied
    (``adv_wrapper_class=None``).
    """
    return config.make_env(
        env_id=env_config["env_id"],
        n_envs=env_config["n_envs"],
        n_frame_stack=env_config["n_frame_stack"],
        wrapper_class=env_config["wrapper_class"],
        adv_wrapper_class=None,
        env_kwargs=env_config["env_kwargs"],
        seed=seed,
    )


def train(
    config,
    output_dir="agent",
    checkpoint=None,
    device="cpu",
    seed=None,
    total_timesteps=None,
    prepopulate_timesteps=None,
    verbose=True,
):
    """
    Train a reinforcement learning agent using the provided configuration.

    Creates a training and an evaluation environment, builds (or loads) an agent
    through ``config.make_agent``, runs ``agent.learn`` and saves the result to
    ``<output_dir>/model.zip``. Both environments are closed on exit, including
    when an exception is raised part-way through.

    Parameters
    ----------
    config : object
        Configuration object exposing:
        - train_env_config (dict): keys "env_id", "n_envs", "n_frame_stack",
          "wrapper_class" and "env_kwargs" for the training environment.
        - eval_env_config (dict): same keys, for the evaluation environment.
        - agent_config (dict): keys "algo" and "algo_kwargs".
        - training_kwargs (dict): keys "total_timesteps" and
          "prepopulate_timesteps", used only when the corresponding argument
          below is None.
        - make_env (callable): environment factory, called with keyword
          arguments env_id, n_envs, n_frame_stack, wrapper_class,
          adv_wrapper_class, env_kwargs and seed.
        - make_agent (callable): agent factory, called with keyword arguments
          algo, env, checkpoint, output_dir, device, seed, verbose and
          algo_kwargs.
    output_dir : str, optional
        Directory passed to the agent factory and where the final model is
        saved. Created if missing. Default is "agent".
    checkpoint : str, optional
        Forwarded to ``config.make_agent``; presumably a path to a pre-trained
        agent to load (confirm against the factory). Default is None.
    device : str, optional
        Computation device forwarded to the agent (e.g. "cpu", "cuda:0").
        Default is "cpu".
    seed : int, optional
        Seed forwarded to both environment factories and the agent factory.
        Default is None.
    total_timesteps : int, optional
        Training budget passed to ``agent.learn``. If None, read from
        ``config.training_kwargs["total_timesteps"]``.
    prepopulate_timesteps : int, optional
        Passed to ``agent.learn``. If None, read from
        ``config.training_kwargs["prepopulate_timesteps"]``.
    verbose : bool, optional
        Forwarded to ``config.make_agent``. Default is True.

    Returns
    -------
    None

    Raises
    ------
    KeyError
        If a timestep argument is None and the matching key is missing from
        ``config.training_kwargs``, or if an env/agent config dict lacks a
        required key.

    Side Effects
    ------------
    - Creates ``output_dir`` if it does not exist.
    - Sets ``os.environ["SDL_VIDEODRIVER"] = "dummy"`` for the whole process.
    - Saves the trained agent via ``agent.save`` to ``<output_dir>/model.zip``.
    - Closes both environments before returning or propagating an exception.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Resolve step budgets before building anything, so a missing config key
    # fails fast without creating (and having to tear down) environments.
    if total_timesteps is None:
        total_timesteps = config.training_kwargs["total_timesteps"]
    if prepopulate_timesteps is None:
        prepopulate_timesteps = config.training_kwargs["prepopulate_timesteps"]

    # ENV ####################

    # Select SDL's dummy (headless) video driver; presumably so SDL/pygame-based
    # environments can run without a display. Note: this is process-wide.
    os.environ["SDL_VIDEODRIVER"] = "dummy"

    training_env = _make_env_from_config(config, config.train_env_config, seed)
    # Nested try/finally guarantees each environment that was successfully
    # created is closed, even if a later step (eval-env creation, agent
    # construction, learning, saving) raises.
    try:
        eval_env = _make_env_from_config(config, config.eval_env_config, seed)
        try:
            # Agent ####################
            agent = config.make_agent(
                algo=config.agent_config["algo"],
                env=training_env,
                checkpoint=checkpoint,
                output_dir=output_dir,
                device=device,
                seed=seed,
                verbose=verbose,
                algo_kwargs=config.agent_config["algo_kwargs"],
            )

            # Presumably switches the agent into training mode (analogous to
            # torch's Module.train()) — confirm against the agent class.
            agent.train()

            # TRAINING ####################
            agent.learn(
                eval_env,
                total_timesteps=total_timesteps,
                prepopulate_timesteps=prepopulate_timesteps,
            )

            # SAVE ####################
            agent.save(checkpoint=os.path.join(output_dir, "model.zip"))
        finally:
            eval_env.close()
    finally:
        training_env.close()