Coverage for adaro_rl / zoo / agent.py: 74%
241 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
2import copy
3import inspect
6import numpy as np
8import torch
9from torch import nn
11import gymnasium as gym
12from gymnasium import spaces
14from stable_baselines3 import DQN, PPO, A2C, DDPG, TD3, SAC
15from stable_baselines3.common.distributions import CategoricalDistribution
16from stable_baselines3.common.callbacks import BaseCallback
17from stable_baselines3.common.preprocessing import is_image_space
18from stable_baselines3.common.utils import is_vectorized_observation
19from collections import Counter
22PyTorchObs = torch.Tensor | dict[str, torch.Tensor]
def make_agent(**kwargs):
    """
    Create an agent instance based on the specified type and keyword arguments.

    Selects :class:`Agent` when ``checkpoint`` is a single path (str) and
    :class:`EnsembleAgent` when it is a list of paths, then instantiates the
    selected class with the subset of ``kwargs`` accepted by its constructor.
    It also checks that all required positional arguments for the agent's
    constructor are provided.

    Parameters
    ----------
    ``**kwargs``
        Keyword arguments to be passed to the agent's constructor. Must
        include ``checkpoint`` (str or list of str).

    Returns
    -------
    object
        An instance of the selected agent class.

    Raises
    ------
    ValueError
        If ``checkpoint`` is neither a str nor a list, or if required
        constructor arguments are missing.
    """
    checkpoint = kwargs.get("checkpoint")
    if isinstance(checkpoint, str):
        agent_class = Agent
    elif isinstance(checkpoint, list):
        agent_class = EnsembleAgent
    else:
        raise ValueError(
            "Invalid checkpoint : it should contain a path to a model checkpoint or a list of paths"
        )

    # Inspect the constructor signature of the selected agent class.
    sig = inspect.signature(agent_class.__init__)

    # Parameters without a default value are required positional args.
    # `inspect.Parameter.empty` is a sentinel object: compare with `is`.
    required_args = [
        name
        for name, param in sig.parameters.items()
        if param.default is inspect.Parameter.empty and name != "self"
    ]

    # Keep only the kwargs that the selected constructor actually accepts.
    init_kwargs = {arg: kwargs[arg] for arg in sig.parameters if arg in kwargs}

    # Check for any missing required positional arguments.
    missing_args = [arg for arg in required_args if arg not in init_kwargs]
    if missing_args:
        raise ValueError(f"Missing required arguments : {', '.join(missing_args)}")

    return agent_class(**init_kwargs)
class Agent(nn.Module):
    """
    Reinforcement learning agent wrapper that integrates a Stable Baselines3 algorithm.

    This class is a wrapper around Stable Baselines3 agents, providing additional utilities
    for model saving/loading, environment preprocessing, prediction, and training controls.

    Parameters
    ----------
    algo : callable
        The Stable Baselines3 algorithm class (e.g., PPO, A2C) to be used.
    env : gym.Env or vecenv
        The environment used for training the agent.
    checkpoint : str, optional
        Path to a checkpoint file to load an existing model. Default is None.
    output_dir : str, optional
        Directory where model logs and checkpoints will be saved. Default is None.
    device : str, optional
        Computation device (e.g., "cpu", "cuda:0"). Default is "cpu".
    seed : int, optional
        Random seed for reproducibility. Default is None.
    verbose : bool, optional
        Verbosity level. Default is True.
    algo_kwargs : dict, optional
        Additional keyword arguments for the algorithm. Default is an empty dict.
    """

    def __init__(
        self,
        algo,
        env,
        checkpoint=None,
        output_dir=None,
        device="cpu",
        seed=None,
        verbose=True,
        algo_kwargs=None,
    ):
        super().__init__()

        self.algo = algo
        self.device = device

        self.observation_space = env.observation_space
        self.obs_shape = env.observation_space.shape

        self.action_space = env.action_space
        self.action_shape = env.action_space.shape

        # None default (instead of a mutable {}) avoids sharing state
        # across instances.
        if algo_kwargs is None:
            algo_kwargs = {}
        self.algo_kwargs = algo_kwargs

        self.output_dir = output_dir
        self.algo_kwargs["tensorboard_log"] = output_dir

        self.n_envs = env.num_envs
        self.verbose = verbose

        # Either resume from a checkpoint or build a fresh SB3 model.
        if checkpoint is not None:
            self.model = self.load(checkpoint=checkpoint, env=env)
        else:
            self.model = self.algo(
                env=env,
                seed=seed,
                device=self.device,
                verbose=self.verbose,
                **self.algo_kwargs,
            )

        self.model.tensorboard_log = output_dir

        # The policy's observation space may differ from the env's
        # (e.g., channel-first images after SB3 preprocessing).
        self.model_observation_space = self.model.policy.observation_space
        self.model_obs_shape = self.model.policy.observation_space.shape

    def load(self, checkpoint, env=None):
        """
        Load a model from the given checkpoint.

        Parameters
        ----------
        checkpoint : str
            Path to the model checkpoint to load.
        env : gym.Env, optional
            Environment to use with the loaded model. Default is None.

        Returns
        -------
        object
            The loaded Stable Baselines3 model (also stored on ``self.model``).
        """
        self.model = self.algo.load(
            checkpoint, env=env, device=self.device, verbose=self.verbose
        )
        # Allow overriding the checkpoint's learning rate; the LR schedule
        # must be rebuilt for the new value to take effect.
        lr = self.algo_kwargs.get("learning_rate")
        if lr is not None:
            self.model.learning_rate = lr
            self.model._setup_lr_schedule()

        return self.model

    def change_output_dir(self, output_dir):
        """
        Change the output directory for the agent, updating the TensorBoard log path.

        Parameters
        ----------
        output_dir : str
            The new output directory path.
        """
        self.output_dir = output_dir
        self.model.tensorboard_log = output_dir

    def save(self, checkpoint="model"):
        """
        Save the current model to the specified checkpoint.

        Parameters
        ----------
        checkpoint : str, optional
            The name or path for the saved checkpoint. Default is "model".
        """
        self.model.save(checkpoint)

    def to(self, device):
        """
        Move the agent's model to the specified computation device.

        NOTE: intentionally shadows ``nn.Module.to`` to also update the
        wrapped SB3 model's device bookkeeping.

        Parameters
        ----------
        device : str
            The target device (e.g., "cpu", "cuda:0").
        """
        self.device = device
        self.model.device = self.device
        self.model.policy.to(self.device)
        # Off-policy algos (SAC/TD3/DDPG) keep a separate critic network.
        if hasattr(self.model, "critic"):
            self.model.critic.to(self.device)

    def learn(
        self,
        eval_env,
        total_timesteps=100000,
        eval_freq=5000,
        n_eval_episodes=10,
        prepopulate_timesteps=100000,
        verbose=True,
    ):
        """
        Train the agent using the given evaluation environment and training parameters.

        Parameters
        ----------
        eval_env : gym.Env or vecenv
            The evaluation environment for periodically assessing the agent's performance.
        total_timesteps : int, optional
            Total number of timesteps for the training process. Default is 100000.
        eval_freq : int, optional
            Frequency (in timesteps) to perform evaluation. Default is 5000.
        n_eval_episodes : int, optional
            Number of episodes to run for each evaluation. Default is 10.
        prepopulate_timesteps : int, optional
            Number of timesteps used for pre-population before normal training. Default is 100000.
        verbose : bool, optional
            Currently unused; kept for backward compatibility. Default is True.
        """
        if self.model.num_timesteps == 0:
            self.model._setup_learn(
                total_timesteps=total_timesteps,
                reset_num_timesteps=False,
                tb_log_name="training_logs",
                progress_bar=True,
            )

        saved_learning_rate = self.model.learning_rate
        saved_tensorboard_log = self.model.tensorboard_log

        if prepopulate_timesteps > 0:
            # Run with lr=0 (no weight updates) and no TB logging so this
            # phase only pre-populates experience (e.g., the replay buffer).
            self.model.learning_rate = 0
            self.model._setup_lr_schedule()
            self.model.tensorboard_log = None
            self.model.learn(
                total_timesteps=prepopulate_timesteps,
                reset_num_timesteps=False,
                progress_bar=True,
            )

        # Restore the real learning rate and logging before actual training.
        self.model.learning_rate = saved_learning_rate
        self.model._setup_lr_schedule()
        self.model.tensorboard_log = saved_tensorboard_log
        best_reward_callback = _SaveOnBestRewardCallback(
            check_freq=1000,
            best_model_save_path=os.path.join(self.output_dir, "best_train_model"),
        )

        self.model.learn(
            total_timesteps=total_timesteps,
            reset_num_timesteps=False,
            callback=best_reward_callback,
            tb_log_name="training_logs",
            progress_bar=True,
        )

    def preprocess_obs(
        self, observation: torch.Tensor | dict[str, torch.Tensor]
    ) -> tuple[PyTorchObs, bool]:
        """
        Preprocess input observations for the model, including normalization and channel reordering.

        Parameters
        ----------
        observation : torch.Tensor or dict[str, torch.Tensor]
            The input observation or batch of observations.

        Returns
        -------
        tuple
            A tuple containing:
            - The processed observation as a PyTorch tensor or dictionary of tensors.
            - A boolean indicating whether the observation is vectorized.
        """
        vectorized_env = False
        if isinstance(observation, dict):
            assert isinstance(self.observation_space, spaces.Dict), (
                f"The observation provided is a dict but the obs space is {self.observation_space}"
            )
            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
            observation = copy.deepcopy(observation)
            for key, obs in observation.items():
                obs_space = self.model_observation_space.spaces[key]
                if is_image_space(obs_space):
                    obs_ = _maybe_transpose(obs, obs_space)
                else:
                    obs_ = obs
                vectorized_env = vectorized_env or is_vectorized_observation(
                    obs_, obs_space
                )
                # Add batch dimension if needed
                observation[key] = obs_.reshape(
                    (-1, *self.observation_space[key].shape)
                )  # type: ignore[misc]

        elif is_image_space(self.observation_space):
            # Handle the different cases for images
            # as PyTorch use channel first format
            observation = _maybe_transpose(observation, self.model_observation_space)

        if not isinstance(observation, dict):
            # Dict obs need to be handled separately
            vectorized_env = is_vectorized_observation(
                observation, self.model_observation_space
            )
            # Add batch dimension if needed
            observation = observation.reshape((-1, *self.model_observation_space.shape))  # type: ignore[misc]

        return observation, vectorized_env

    def predict(self, observations, state=None, episode_start=None, deterministic=True):
        """
        Predict the next action for given observations.

        Parameters
        ----------
        observations : array-like
            The observations based on which the agent should compute an action.
        state : any, optional
            The internal state of the agent (if applicable). Default is None.
        episode_start : any, optional
            Indicator for the start of an episode (if applicable). Default is None.
        deterministic : bool, optional
            Whether to use a deterministic policy. Default is True.

        Returns
        -------
        tuple
            A tuple (action, state) where action is the computed action and state is the updated state.
        """
        action, state = self.model.predict(
            observations, state, episode_start, deterministic
        )
        return action, state

    def act(self, obs, deterministic=True):
        """
        Compute the action for the given observation.

        Parameters
        ----------
        obs : array-like
            The input observation.
        deterministic : bool, optional
            Whether to use a deterministic policy. Default is True.

        Returns
        -------
        array-like
            The computed action, reshaped to match the action space.
        """
        obs, _ = self.preprocess_obs(obs)
        if isinstance(self.model, DQN):
            q_values = self.model.policy(obs)
            # Greedy action per batch element: argmax over the action axis,
            # not over the flattened tensor (which would collapse a batch
            # to a single action).
            action = torch.argmax(q_values, dim=1)
        else:
            action = self.model.policy._predict(obs, deterministic=deterministic)

        action = action.reshape((-1, *self.model.action_space.shape))
        return action

    def probs(self, obs):
        """
        Compute the action probabilities for a given observation.

        Parameters
        ----------
        obs : array-like
            The input observation.

        Returns
        -------
        torch.Tensor or np.ndarray
            The action probability distribution. For continuous actions, an error is raised.

        Raises
        ------
        TypeError
            If the action space is continuous or the algorithm has no
            finite probability distribution (SAC/DDPG/TD3).
        """
        obs, _ = self.preprocess_obs(obs)
        action_probs = None
        if isinstance(self.model.action_space, spaces.Box):
            raise TypeError(
                "Continuous action spaces have no finite distribution probs"
            )
        elif isinstance(self.model, DQN):
            q_values = self.model.policy(obs)
            # Softmax over Q-values; Discrete spaces expose the number of
            # actions as `n` (there is no `dim` attribute).
            action_distribution = (
                CategoricalDistribution(self.model.action_space.n)
                .proba_distribution(q_values)
                .distribution
            )
            action_probs = action_distribution.probs
        elif isinstance(self.model, (A2C, PPO)):
            action_distribution = self.model.policy.get_distribution(obs).distribution
            action_probs = action_distribution.probs
        elif isinstance(self.model, (SAC, DDPG, TD3)):
            raise TypeError(f"{type(self.model)} has no finite distribution probs")
        return action_probs

    def log_prob(self, obs, action):
        """
        Compute the log probability of an action given an observation.

        Parameters
        ----------
        obs : array-like
            The input observation.
        action : torch.Tensor
            The action for which to compute the log probability.

        Returns
        -------
        torch.Tensor or np.ndarray
            The log probability of the action.

        Raises
        ------
        TypeError
            If the algorithm does not expose an action log probability.
        """
        obs, _ = self.preprocess_obs(obs)
        log_prob = None
        if isinstance(self.model, DQN):
            q_values = self.model.policy(obs)
            # Discrete spaces expose the number of actions as `n`.
            log_prob = (
                CategoricalDistribution(self.model.action_space.n)
                .proba_distribution(q_values)
                .distribution.log_prob(action)
                .squeeze(0)
            )
        elif isinstance(self.model, (A2C, PPO)):
            log_prob = self.model.policy.get_distribution(obs).distribution.log_prob(
                action
            )
        else:
            raise TypeError(f"{type(self.model)} has no action log prob")
        return log_prob

    def v_value(self, obs):
        """
        Estimate the state value for a given observation.

        Parameters
        ----------
        obs : array-like
            The input observation.

        Returns
        -------
        torch.Tensor or np.ndarray
            The estimated state value.

        Raises
        ------
        TypeError
            If the algorithm does not expose a state value.
        """
        obs, _ = self.preprocess_obs(obs)
        value = None
        if isinstance(self.model, DQN):
            q_values = self.model.policy(obs)
            # Discrete spaces expose the number of actions as `n`.
            action_distribution = (
                CategoricalDistribution(self.model.action_space.n)
                .proba_distribution(q_values)
                .distribution
            )
            # V(s) as the probability-weighted sum of Q-values.
            value = torch.sum(action_distribution.probs * q_values, dim=1)
        elif isinstance(self.model, (PPO, A2C)):
            values = self.model.policy.predict_values(obs)
            value = torch.min(values, dim=1)[0]
        else:
            raise TypeError(f"{type(self.model)} has no state value")
        return value

    def q_value(self, obs, action):
        """
        Estimate the Q-value of a given observation-action pair.

        Parameters
        ----------
        obs : array-like
            The input observation.
        action : torch.Tensor
            The action for which to evaluate the Q-value.

        Returns
        -------
        torch.Tensor or np.ndarray
            The estimated Q-value.

        Raises
        ------
        TypeError
            If the algorithm does not expose a Q value.
        """
        obs, _ = self.preprocess_obs(obs)
        q_value = None
        if isinstance(self.model, DQN):
            q_values = self.model.policy(obs)
            q_value = q_values[torch.arange(q_values.size(0)), action]
        elif isinstance(self.model, (SAC, DDPG, TD3)):
            q_values = self.model.critic(obs, action)
            # Take the minimum across the critic outputs for each batch element
            q_values = torch.stack(q_values, dim=0)
            q_value = torch.min(q_values, dim=0)[0].squeeze()
        else:
            raise TypeError(f"{type(self.model)} has no Q value")
        return q_value
class EnsembleAgent:
    """
    Ensemble agent that aggregates multiple agents for decision making.

    The ensemble collects several agents, loads or creates them from specified checkpoints,
    and allows for combined predictions via either voting (discrete actions) or averaging (continuous actions).

    Parameters
    ----------
    algo : callable
        The algorithm class to be used for each individual agent.
    env : gym.Env or vecenv
        The environment used for agent initialization.
    checkpoint : list of str
        List of checkpoints to load, one per agent.
    device : str, optional
        Computation device (e.g., "cpu", "cuda:0"). Default is "cpu".
    seed : int, optional
        Random seed for initialization. Default is None.
    verbose : bool, optional
        Verbosity level. Default is True.
    """

    def __init__(
        self,
        algo,
        env,
        checkpoint,
        device="cpu",
        seed=None,
        verbose=True,
    ):
        self.algo = algo
        self.device = device
        self.verbose = verbose
        self.checkpoints = checkpoint

        self.observation_space = env.observation_space
        self.obs_shape = env.observation_space.shape

        self.action_space = env.action_space
        self.action_shape = env.action_space.shape

        self.n_envs = env.num_envs
        self.seed = seed

        self.load(checkpoints=self.checkpoints, env=env)

        # All agents share the same policy observation space; use the first.
        self.model_observation_space = self.agents[0].model.policy.observation_space
        self.model_obs_shape = self.agents[0].model.policy.observation_space.shape

    def load(self, checkpoints=None, env=None):
        """
        Load checkpoints for the ensemble agents into ``self.agents``.

        Parameters
        ----------
        checkpoints : list of str, optional
            Checkpoint paths to load, one :class:`Agent` per entry.
            Default is an empty list.
        env : gym.Env, optional
            Environment to be passed to each agent during loading. Default is None.
        """
        # None default instead of a mutable [] (a shared default list
        # persists across calls).
        if checkpoints is None:
            checkpoints = []
        self.agents = [
            Agent(
                self.algo,
                env,
                checkpoint=checkpoint,
                device=self.device,
                seed=self.seed,
                verbose=self.verbose,
            )
            for checkpoint in checkpoints
        ]

    def to(self, device):
        """
        Move all ensemble agents to the specified device.

        Parameters
        ----------
        device : str
            The target device (e.g., "cpu", "cuda:0").
        """
        for agent in self.agents:
            agent.to(device)

    def eval(self):
        """
        Set all ensemble agents to evaluation mode.
        """
        for agent in self.agents:
            agent.eval()

    def train(self):
        """
        Set all ensemble agents to training mode.
        """
        for agent in self.agents:
            agent.train()

    def predict(self, observations, state=None, episode_start=None, deterministic=True):
        """
        Aggregate predictions from all ensemble agents and return a combined action.

        For discrete action spaces, uses majority voting. For continuous spaces, averages the actions.

        Parameters
        ----------
        observations : array-like
            The input observations.
        state : list, optional
            List of agent states. If None, defaults to [None] * num_agents.
        episode_start : list, optional
            Episode start flags, broadcast to every agent.
        deterministic : bool, optional
            Whether to use a deterministic policy. Default is True.

        Returns
        -------
        tuple
            A tuple (final_actions, state) where final_actions is the aggregated action.

        Raises
        ------
        ValueError
            If ``state`` is a list of the wrong length.
        TypeError
            If ``state`` is neither a list nor None.
        NotImplementedError
            If the action space is neither Discrete nor Box.
        """
        num_agents = len(self.agents)

        # Validate state
        if state is None:
            state = [None] * num_agents
        elif isinstance(state, list):
            if len(state) != num_agents:
                raise ValueError(
                    f"Expected states list of length {num_agents}, got {len(state)}"
                )
        else:
            raise TypeError(f"state must be a list or None, got {type(state).__name__}")

        # Broadcast the same episode-start flags to every agent.
        episodes_start = [episode_start] * num_agents

        # Collect predictions from each agent
        batch_actions_states = [
            agent.model.predict(observations, state_i, ep_start_i, deterministic)
            for agent, state_i, ep_start_i in zip(self.agents, state, episodes_start)
        ]
        batch_actions, new_states = zip(*batch_actions_states)

        # Aggregate actions
        if isinstance(self.agents[0].model.action_space, gym.spaces.Discrete):
            # Transpose to group actions per observation, then majority-vote.
            actions_transposed = list(zip(*batch_actions))
            final_actions = []
            for actions in actions_transposed:
                counter = Counter(actions)
                action, _ = counter.most_common(1)[0]
                final_actions.append(action)
        elif isinstance(self.agents[0].model.action_space, gym.spaces.Box):
            # Average over agents
            batch_actions = np.stack(batch_actions, axis=0)
            final_actions = np.mean(batch_actions, axis=0)
        else:
            raise NotImplementedError("Unsupported action space type")

        return np.array(final_actions), list(new_states)
class _SaveOnBestRewardCallback(BaseCallback):
    """
    Callback for saving a model based on the mean reward from the episode information buffer.

    The callback checks the reward every ``check_freq`` steps and, if the mean reward improves
    over the best recorded value, saves the model at the given path.

    Parameters
    ----------
    check_freq : int
        How often (in steps) to check for improvements in mean reward.
    best_model_save_path : str
        Path where the model will be saved when a new best reward is achieved.
    verbose : int, optional
        Verbosity level. If greater than 0, prints information about improvements. Default is 0.
    """

    def __init__(self, check_freq: int, best_model_save_path: str, verbose=0):
        # Python-3 style super() call (no need to repeat the class).
        super().__init__(verbose)
        self.check_freq = check_freq
        self.best_model_save_path = best_model_save_path
        # Start below any achievable reward so the first mean always wins.
        self.best_mean_reward = -float("inf")

    def _on_step(self) -> bool:
        """
        Called at every step during training. Checks if the mean reward has improved and saves the model.

        Returns
        -------
        bool
            Always returns True to continue training.
        """
        # Only check every `check_freq` callback calls
        if self.n_calls % self.check_freq == 0:
            # Gather rewards from the ep_info_buffer
            ep_rewards = [
                info["r"] for info in self.model.ep_info_buffer if "r" in info
            ]
            if len(ep_rewards) > 0:
                mean_reward = sum(ep_rewards) / len(ep_rewards)

                # Save the model if the mean reward improves
                if mean_reward > self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    if self.verbose > 0:
                        print(
                            f"[_SaveOnBestRewardCallback] New best mean reward: "
                            f"{mean_reward:.2f} over {len(ep_rewards)} episodes - saving model."
                        )
                    # os.path.join with a single argument was a no-op;
                    # save directly to the configured path.
                    self.model.save(self.best_model_save_path)
        return True
751def _transpose_image(image: torch.Tensor) -> torch.Tensor:
752 """
753 Transpose an image or a batch of images to change the channel dimension order.
755 For a single image (3D tensor), reorders dimensions from (height, width, channels) to
756 (channels, height, width). For a batch of images (4D tensor), reorders dimensions from
757 (batch, height, width, channels) to (batch, channels, height, width).
759 Parameters
760 ----------
761 image : torch.Tensor
762 The input image tensor.
764 Returns
765 -------
766 torch.Tensor
767 The transposed image tensor.
768 """
769 if len(image.shape) == 3:
770 return image.permute(2, 0, 1)
771 return image.permute(0, 3, 1, 2)
def _maybe_transpose(
    observation: torch.Tensor, observation_space: spaces.Space
) -> torch.Tensor:
    """
    Transpose the observation tensor if the observation space is of image type.

    If the observation tensor does not match the expected shape of the image space
    (directly or with a leading batch dimension), this function attempts to
    re-order the channels using :func:`_transpose_image`. The transposed tensor
    is kept only when it actually matches the expected shape; otherwise the
    input is returned unchanged.

    Parameters
    ----------
    observation : torch.Tensor
        The observation tensor.
    observation_space : gym.spaces.Space
        The observation space definition.

    Returns
    -------
    torch.Tensor
        The observation tensor, transposed to channel-first format if it is an image.
    """
    if is_image_space(observation_space):
        if not (
            observation.shape == observation_space.shape
            or observation.shape[1:] == observation_space.shape
        ):
            # Try to re-order the channels
            transpose_obs = _transpose_image(observation)
            if (
                transpose_obs.shape == observation_space.shape
                or transpose_obs.shape[1:] == observation_space.shape
            ):
                observation = transpose_obs
    return observation