Coverage for adaro_rl / zoo / agent.py: 74%

241 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2import copy 

3import inspect 

4 

5 

6import numpy as np 

7 

8import torch 

9from torch import nn 

10 

11import gymnasium as gym 

12from gymnasium import spaces 

13 

14from stable_baselines3 import DQN, PPO, A2C, DDPG, TD3, SAC 

15from stable_baselines3.common.distributions import CategoricalDistribution 

16from stable_baselines3.common.callbacks import BaseCallback 

17from stable_baselines3.common.preprocessing import is_image_space 

18from stable_baselines3.common.utils import is_vectorized_observation 

19from collections import Counter 

20 

21 

# Type alias for observations fed to an SB3 policy: a single tensor, or a
# dict of named tensors when the observation space is a spaces.Dict.
PyTorchObs = torch.Tensor | dict[str, torch.Tensor]

23 

24 

def make_agent(**kwargs):
    """
    Create an agent instance based on the provided keyword arguments.

    Dynamically selects the agent class from the type of ``checkpoint``
    (a single path creates an :class:`Agent`, a list of paths creates an
    :class:`EnsembleAgent`), then instantiates it with the subset of
    ``kwargs`` its constructor accepts, after checking that all required
    positional arguments are present.

    Parameters
    ----------
    ``**kwargs``
        Keyword arguments to be passed to the agent's constructor. Must
        include ``checkpoint`` (str or list of str).

    Returns
    -------
    object
        An instance of the selected agent class.

    Raises
    ------
    ValueError
        If ``checkpoint`` is neither a str nor a list, or if required
        constructor arguments are missing.
    """
    checkpoint = kwargs.get("checkpoint")
    if isinstance(checkpoint, str):
        agent_class = Agent
    elif isinstance(checkpoint, list):
        agent_class = EnsembleAgent
    else:
        raise ValueError(
            "Invalid checkpoint : it should contain a path to a model checkpoint or a list of paths"
        )

    # Inspect the constructor to learn which parameters it accepts.
    sig = inspect.signature(agent_class.__init__)

    # Required positional args are those without a default value.
    # Note: inspect.Parameter.empty is a sentinel, so compare with `is`.
    required_args = [
        name
        for name, param in sig.parameters.items()
        if param.default is inspect.Parameter.empty and name != "self"
    ]

    # Keep only the kwargs the selected constructor actually accepts.
    init_kwargs = {arg: kwargs[arg] for arg in sig.parameters if arg in kwargs}

    # Check for any missing required positional arguments.
    missing_args = [arg for arg in required_args if arg not in init_kwargs]
    if missing_args:
        raise ValueError(f"Missing required arguments : {', '.join(missing_args)}")

    return agent_class(**init_kwargs)

78 

79 

class Agent(nn.Module):
    """
    Reinforcement learning agent wrapper that integrates a Stable Baselines3 algorithm.

    This class is a wrapper around Stable Baselines3 agents, providing additional
    utilities for model saving/loading, observation preprocessing, prediction,
    and training controls.

    Parameters
    ----------
    algo : callable
        The Stable Baselines3 algorithm class (e.g., PPO, A2C) to be used.
    env : gym.Env or vecenv
        The environment used for training the agent. Must expose
        ``observation_space``, ``action_space`` and ``num_envs``.
    checkpoint : str, optional
        Path to a checkpoint file to load an existing model. Default is None.
    output_dir : str, optional
        Directory where model logs and checkpoints will be saved. Default is None.
    device : str, optional
        Computation device (e.g., "cpu", "cuda:0"). Default is "cpu".
    seed : int, optional
        Random seed for reproducibility. Default is None.
    verbose : bool, optional
        Verbosity level. Default is True.
    algo_kwargs : dict, optional
        Additional keyword arguments for the algorithm. Default is an empty dict.
    """

    def __init__(
        self,
        algo,
        env,
        checkpoint=None,
        output_dir=None,
        device="cpu",
        seed=None,
        verbose=True,
        algo_kwargs=None,
    ):
        super().__init__()

        self.algo = algo

        self.device = device

        self.observation_space = env.observation_space
        self.obs_shape = env.observation_space.shape

        self.action_space = env.action_space
        self.action_shape = env.action_space.shape

        # None default instead of a mutable {} default: each instance
        # gets its own dict.
        if algo_kwargs is None:
            algo_kwargs = {}
        self.algo_kwargs = algo_kwargs

        self.output_dir = output_dir
        self.algo_kwargs["tensorboard_log"] = output_dir

        self.n_envs = env.num_envs

        self.verbose = verbose

        if checkpoint is not None:
            self.model = self.load(checkpoint=checkpoint, env=env)
        else:
            self.model = self.algo(
                env=env,
                seed=seed,
                device=self.device,
                verbose=self.verbose,
                **self.algo_kwargs,
            )

        # Loading a checkpoint restores the old log dir; override it so logs
        # always go to the current output_dir.
        self.model.tensorboard_log = output_dir

        # The policy's observation space may differ from the env's
        # (e.g. channel-first images after SB3 preprocessing).
        self.model_observation_space = self.model.policy.observation_space
        self.model_obs_shape = self.model.policy.observation_space.shape

    def load(self, checkpoint, env=None):
        """
        Load a model from the given checkpoint.

        Parameters
        ----------
        checkpoint : str
            Path to the model checkpoint to load.
        env : gym.Env, optional
            Environment to use with the loaded model. Default is None.

        Returns
        -------
        object
            The loaded Stable Baselines3 model (also stored in ``self.model``).
        """
        self.model = self.algo.load(
            checkpoint, env=env, device=self.device, verbose=self.verbose
        )
        # A checkpoint carries its own learning rate; allow algo_kwargs to
        # override it, rebuilding the schedule so the new value takes effect.
        lr = self.algo_kwargs.get("learning_rate")
        if lr is not None:
            self.model.learning_rate = lr
            self.model._setup_lr_schedule()

        return self.model

    def change_output_dir(self, output_dir):
        """
        Change the output directory for the agent, updating the TensorBoard log path.

        Parameters
        ----------
        output_dir : str
            The new output directory path.
        """
        self.output_dir = output_dir
        self.model.tensorboard_log = output_dir

    def save(self, checkpoint="model"):
        """
        Save the current model to the specified checkpoint.

        Parameters
        ----------
        checkpoint : str, optional
            The name or path for the saved checkpoint. Default is "model".
        """
        self.model.save(checkpoint)

    def to(self, device):
        """
        Move the agent's model to the specified computation device.

        Parameters
        ----------
        device : str
            The target device (e.g., "cpu", "cuda:0").
        """
        self.device = device
        # NOTE(review): assigns model.device directly rather than going through
        # an SB3 API — verify this stays in sync with SB3's device handling.
        self.model.device = self.device
        self.model.policy.to(self.device)
        # Off-policy algorithms (SAC/TD3/DDPG) also carry a critic network.
        if hasattr(self.model, "critic"):
            self.model.critic.to(self.device)

    def learn(
        self,
        eval_env,
        total_timesteps=100000,
        eval_freq=5000,
        n_eval_episodes=10,
        prepopulate_timesteps=100000,
        verbose=True,
    ):
        """
        Train the agent, optionally pre-populating the replay buffer first.

        The pre-population phase runs the model with learning rate 0 and
        TensorBoard logging disabled, so experience is collected without
        updating the policy; the original settings are restored afterwards.

        Parameters
        ----------
        eval_env : gym.Env or vecenv
            Evaluation environment. NOTE(review): currently unused by this
            method, as are ``eval_freq``, ``n_eval_episodes`` and ``verbose`` —
            kept for interface compatibility; confirm intended evaluation hookup.
        total_timesteps : int, optional
            Total number of timesteps for the training process. Default is 100000.
        eval_freq : int, optional
            Frequency (in timesteps) to perform evaluation. Default is 5000.
        n_eval_episodes : int, optional
            Number of episodes to run for each evaluation. Default is 10.
        prepopulate_timesteps : int, optional
            Number of timesteps used for pre-population before normal training.
            Default is 100000.
        verbose : bool, optional
            Verbosity flag. Default is True.
        """

        if self.model.num_timesteps == 0:
            self.model._setup_learn(
                total_timesteps=total_timesteps,
                reset_num_timesteps=False,
                tb_log_name="training_logs",
                progress_bar=True,
            )

        # Remember the real settings so they can be restored after
        # pre-population.
        saved_learning_rate = self.model.learning_rate
        saved_tensorboard_log = self.model.tensorboard_log

        if prepopulate_timesteps > 0:
            self.model.learning_rate = 0
            self.model._setup_lr_schedule()
            self.model.tensorboard_log = None
            self.model.learn(
                total_timesteps=prepopulate_timesteps,
                reset_num_timesteps=False,
                progress_bar=True,
            )

        self.model.learning_rate = saved_learning_rate
        self.model._setup_lr_schedule()
        self.model.tensorboard_log = saved_tensorboard_log
        best_reward_callback = _SaveOnBestRewardCallback(
            check_freq=1000,
            best_model_save_path=os.path.join(self.output_dir, "best_train_model"),
        )

        self.model.learn(
            total_timesteps=total_timesteps,
            reset_num_timesteps=False,
            callback=best_reward_callback,
            tb_log_name="training_logs",
            progress_bar=True,
        )

    def preprocess_obs(
        self, observation: torch.Tensor | dict[str, torch.Tensor]
    ) -> tuple[PyTorchObs, bool]:
        """
        Preprocess input observations for the model, including channel
        reordering for image spaces and batch-dimension handling.

        Parameters
        ----------
        observation : torch.Tensor or dict[str, torch.Tensor]
            The input observation or batch of observations.

        Returns
        -------
        tuple
            A tuple containing:
            - The processed observation as a PyTorch tensor or dictionary of tensors.
            - A boolean indicating whether the observation is vectorized.
        """
        vectorized_env = False
        if isinstance(observation, dict):
            assert isinstance(self.observation_space, spaces.Dict), (
                f"The observation provided is a dict but the obs space is {self.observation_space}"
            )
            # need to copy the dict as the dict in VecFrameStack will become a torch tensor
            observation = copy.deepcopy(observation)
            for key, obs in observation.items():
                obs_space = self.model_observation_space.spaces[key]
                if is_image_space(obs_space):
                    obs_ = _maybe_transpose(obs, obs_space)
                else:
                    obs_ = obs
                vectorized_env = vectorized_env or is_vectorized_observation(
                    obs_, obs_space
                )
                # Add batch dimension if needed
                observation[key] = obs_.reshape(
                    (-1, *self.observation_space[key].shape)
                )  # type: ignore[misc]

        elif is_image_space(self.observation_space):
            # Handle the different cases for images
            # as PyTorch use channel first format
            observation = _maybe_transpose(observation, self.model_observation_space)

        if not isinstance(observation, dict):
            # Dict obs need to be handled separately
            vectorized_env = is_vectorized_observation(
                observation, self.model_observation_space
            )
            # Add batch dimension if needed
            observation = observation.reshape((-1, *self.model_observation_space.shape))  # type: ignore[misc]

        return observation, vectorized_env

    def predict(self, observations, state=None, episode_start=None, deterministic=True):
        """
        Predict the next action for given observations.

        Parameters
        ----------
        observations : array-like
            The observations based on which the agent should compute an action.
        state : any, optional
            The internal state of the agent (if applicable). Default is None.
        episode_start : any, optional
            Indicator for the start of an episode (if applicable). Default is None.
        deterministic : bool, optional
            Whether to use a deterministic policy. Default is True.

        Returns
        -------
        tuple
            A tuple (action, state) where action is the computed action and
            state is the updated state.
        """
        action, state = self.model.predict(
            observations, state, episode_start, deterministic
        )
        return action, state

    def act(self, obs, deterministic=True):
        """
        Compute the action for the given observation.

        Parameters
        ----------
        obs : array-like
            The input observation.
        deterministic : bool, optional
            Whether to use a deterministic policy. Default is True.

        Returns
        -------
        array-like
            The computed action, reshaped to match the action space.
        """
        obs, _ = self.preprocess_obs(obs)
        if isinstance(self.model, DQN):
            # For DQN, acting greedily over the Q-values is the policy.
            q_values = self.model.policy(obs)
            action = torch.argmax(q_values)
        else:
            action = self.model.policy._predict(obs, deterministic=deterministic)

        action = action.reshape((-1, *self.model.action_space.shape))
        return action

    def probs(self, obs):
        """
        Compute the action probabilities for a given observation.

        Parameters
        ----------
        obs : array-like
            The input observation.

        Returns
        -------
        torch.Tensor or np.ndarray
            The action probability distribution.

        Raises
        ------
        TypeError
            For continuous action spaces, or algorithms (SAC, DDPG, TD3)
            that have no finite action distribution.
        """
        obs, _ = self.preprocess_obs(obs)
        action_probs = None
        if isinstance(self.model.action_space, spaces.Box):
            raise TypeError(
                "Continuous action spaces have no finite distribution probs"
            )
        elif isinstance(self.model, DQN):
            # Interpret Q-values as logits of a categorical distribution.
            # Fix: a Discrete space exposes its number of actions as `.n`;
            # it has no `.dim` attribute.
            q_values = self.model.policy(obs)
            action_distribution = (
                CategoricalDistribution(self.model.action_space.n)
                .proba_distribution(q_values)
                .distribution
            )
            action_probs = action_distribution.probs
        elif isinstance(self.model, (A2C, PPO)):
            action_distribution = self.model.policy.get_distribution(obs).distribution
            action_probs = action_distribution.probs
        elif isinstance(self.model, (SAC, DDPG, TD3)):
            raise TypeError(f"{type(self.model)} has no finite distribution probs")
        return action_probs

    def log_prob(self, obs, action):
        """
        Compute the log probability of an action given an observation.

        Parameters
        ----------
        obs : array-like
            The input observation.
        action : torch.Tensor
            The action for which to compute the log probability.

        Returns
        -------
        torch.Tensor or np.ndarray
            The log probability of the action.

        Raises
        ------
        TypeError
            If the wrapped algorithm has no action log probability.
        """
        obs, _ = self.preprocess_obs(obs)
        log_prob = None
        if isinstance(self.model, DQN):
            # Fix: use Discrete.n (number of actions); `.dim` does not exist.
            q_values = self.model.policy(obs)
            log_prob = (
                CategoricalDistribution(self.model.action_space.n)
                .proba_distribution(q_values)
                .distribution.log_prob(action)
                .squeeze(0)
            )
        elif isinstance(self.model, (A2C, PPO)):
            log_prob = self.model.policy.get_distribution(obs).distribution.log_prob(
                action
            )
        else:
            raise TypeError(f"{type(self.model)} has no action log prob")
        return log_prob

    def v_value(self, obs):
        """
        Estimate the state value for a given observation.

        Parameters
        ----------
        obs : array-like
            The input observation.

        Returns
        -------
        torch.Tensor or np.ndarray
            The estimated state value.

        Raises
        ------
        TypeError
            If the wrapped algorithm has no state value.
        """
        obs, _ = self.preprocess_obs(obs)
        value = None
        if isinstance(self.model, DQN):
            # V(s) as the expectation of Q(s, a) under the softmax policy.
            # Fix: use Discrete.n (number of actions); `.dim` does not exist.
            q_values = self.model.policy(obs)
            action_distribution = (
                CategoricalDistribution(self.model.action_space.n)
                .proba_distribution(q_values)
                .distribution
            )
            value = torch.sum(action_distribution.probs * q_values, dim=1)
        elif isinstance(self.model, (PPO, A2C)):
            values = self.model.policy.predict_values(obs)
            value = torch.min(values, dim=1)[0]
        else:
            raise TypeError(f"{type(self.model)} has no state value")
        return value

    def q_value(self, obs, action):
        """
        Estimate the Q-value of a given observation-action pair.

        Parameters
        ----------
        obs : array-like
            The input observation.
        action : torch.Tensor
            The action for which to evaluate the Q-value.

        Returns
        -------
        torch.Tensor or np.ndarray
            The estimated Q-value.

        Raises
        ------
        TypeError
            If the wrapped algorithm has no Q value.
        """
        obs, _ = self.preprocess_obs(obs)
        q_value = None
        if isinstance(self.model, DQN):
            q_values = self.model.policy(obs)
            q_value = q_values[torch.arange(q_values.size(0)), action]
        elif isinstance(self.model, (SAC, DDPG, TD3)):
            q_values = self.model.critic(obs, action)
            # Take the minimum across the critic outputs for each batch element
            q_values = torch.stack(q_values, dim=0)
            q_value = torch.min(q_values, dim=0)[0].squeeze()
        else:
            raise TypeError(f"{type(self.model)} has no Q value")
        return q_value

523 

524 

class EnsembleAgent:
    """
    Ensemble agent that aggregates multiple agents for decision making.

    The ensemble loads one :class:`Agent` per checkpoint and combines their
    predictions via majority voting (discrete actions) or averaging
    (continuous actions).

    Parameters
    ----------
    algo : callable
        The algorithm class to be used for each individual agent.
    env : gym.Env or vecenv
        The environment used for agent initialization.
    checkpoint : list of str
        List of checkpoints to load, one per ensemble member.
    device : str, optional
        Computation device (e.g., "cpu", "cuda:0"). Default is "cpu".
    seed : int, optional
        Random seed for initialization. Default is None.
    verbose : bool, optional
        Verbosity level. Default is True.
    """

    def __init__(
        self,
        algo,
        env,
        checkpoint,
        device="cpu",
        seed=None,
        verbose=True,
    ):
        self.algo = algo

        self.device = device

        self.verbose = verbose

        self.checkpoints = checkpoint

        self.observation_space = env.observation_space
        self.obs_shape = env.observation_space.shape

        self.action_space = env.action_space
        self.action_shape = env.action_space.shape

        self.n_envs = env.num_envs

        self.seed = seed

        self.load(checkpoints=self.checkpoints, env=env)

        # All members share the same policy observation space; use the first.
        self.model_observation_space = self.agents[0].model.policy.observation_space
        self.model_obs_shape = self.agents[0].model.policy.observation_space.shape

    def load(self, checkpoints=None, env=None):
        """
        Load one agent per checkpoint into ``self.agents``.

        Parameters
        ----------
        checkpoints : list of str, optional
            Checkpoint paths, one per ensemble member. Default is an
            empty list (no agents loaded).
        env : gym.Env, optional
            Environment to be passed to each agent during loading.
            Default is None.
        """
        # None default instead of a mutable [] default argument.
        if checkpoints is None:
            checkpoints = []
        self.agents = [
            Agent(
                self.algo,
                env,
                checkpoint=checkpoint,
                device=self.device,
                seed=self.seed,
                verbose=self.verbose,
            )
            for checkpoint in checkpoints
        ]

    def to(self, device):
        """
        Move all ensemble agents to the specified device.

        Parameters
        ----------
        device : str
            The target device (e.g., "cpu", "cuda:0").
        """
        for agent in self.agents:
            agent.to(device)

    def eval(self):
        """
        Set all ensemble agents to evaluation mode.
        """
        for agent in self.agents:
            agent.eval()

    def train(self):
        """
        Set all ensemble agents to training mode.
        """
        for agent in self.agents:
            agent.train()

    def predict(self, observations, state=None, episode_start=None, deterministic=True):
        """
        Aggregate predictions from all ensemble agents and return a combined action.

        For discrete action spaces, uses majority voting. For continuous spaces,
        averages the actions.

        Parameters
        ----------
        observations : array-like
            The input observations.
        state : list, optional
            List of agent states, one per agent. If None, defaults to
            [None] * num_agents.
        episode_start : any, optional
            Episode start indicator, passed to every agent. Default is None.
        deterministic : bool, optional
            Whether to use a deterministic policy. Default is True.

        Returns
        -------
        tuple
            A tuple (final_actions, states) where final_actions is the
            aggregated action and states is the list of per-agent states.

        Raises
        ------
        ValueError
            If ``state`` is a list whose length differs from the number of agents.
        TypeError
            If ``state`` is neither a list nor None.
        NotImplementedError
            For unsupported action space types.
        """

        num_agents = len(self.agents)

        # Validate state
        if state is None:
            state = [None] * num_agents
        elif isinstance(state, list):
            if len(state) != num_agents:
                raise ValueError(
                    f"Expected states list of length {num_agents}, got {len(state)}"
                )
        else:
            raise TypeError(f"state must be a list or None, got {type(state).__name__}")

        # Collect predictions from each agent. The same episode_start is
        # passed to every agent (no per-agent list needed), and the loop
        # variable no longer shadows the method parameter.
        batch_actions_states = [
            agent.model.predict(observations, state_i, episode_start, deterministic)
            for agent, state_i in zip(self.agents, state)
        ]
        batch_actions, new_states = zip(*batch_actions_states)

        # Aggregate actions
        if isinstance(self.agents[0].model.action_space, gym.spaces.Discrete):
            # Transpose to group actions per observation, then majority-vote.
            actions_transposed = list(zip(*batch_actions))
            final_actions = []
            for actions in actions_transposed:
                counter = Counter(actions)
                action, _ = counter.most_common(1)[0]
                final_actions.append(action)
        elif isinstance(self.agents[0].model.action_space, gym.spaces.Box):
            # Average over agents
            batch_actions = np.stack(batch_actions, axis=0)
            final_actions = np.mean(batch_actions, axis=0)
        else:
            raise NotImplementedError("Unsupported action space type")

        return np.array(final_actions), list(new_states)

695 

696 

class _SaveOnBestRewardCallback(BaseCallback):
    """
    Callback for saving a model based on the mean reward from the episode information buffer.

    The callback checks the reward every ``check_freq`` steps and, if the mean reward
    improves over the best recorded value, saves the model at the given path.

    Parameters
    ----------
    check_freq : int
        How often (in steps) to check for improvements in mean reward.
    best_model_save_path : str
        Path where the model will be saved when a new best reward is achieved.
    verbose : int, optional
        Verbosity level. If greater than 0, prints information about improvements.
        Default is 0.
    """

    def __init__(self, check_freq: int, best_model_save_path: str, verbose=0):
        # Modern zero-argument super() instead of the legacy Python 2 form.
        super().__init__(verbose)
        self.check_freq = check_freq
        self.best_model_save_path = best_model_save_path
        self.best_mean_reward = -float("inf")

    def _on_step(self) -> bool:
        """
        Called at every step during training. Checks if the mean reward has
        improved and saves the model.

        Returns
        -------
        bool
            Always returns True to continue training.
        """
        # Only check every `check_freq` callback calls
        if self.n_calls % self.check_freq == 0:
            # Gather rewards from the ep_info_buffer
            ep_rewards = [
                info["r"] for info in self.model.ep_info_buffer if "r" in info
            ]
            if len(ep_rewards) > 0:
                mean_reward = sum(ep_rewards) / len(ep_rewards)

                # Save the model if the mean reward improves
                if mean_reward > self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    if self.verbose > 0:
                        print(
                            f"[_SaveOnBestRewardCallback] New best mean reward: "
                            f"{mean_reward:.2f} over {len(ep_rewards)} episodes - saving model."
                        )
                    # os.path.join with a single argument was a no-op;
                    # save directly to the configured path.
                    self.model.save(self.best_model_save_path)
        return True

749 

750 

751def _transpose_image(image: torch.Tensor) -> torch.Tensor: 

752 """ 

753 Transpose an image or a batch of images to change the channel dimension order. 

754 

755 For a single image (3D tensor), reorders dimensions from (height, width, channels) to 

756 (channels, height, width). For a batch of images (4D tensor), reorders dimensions from 

757 (batch, height, width, channels) to (batch, channels, height, width). 

758 

759 Parameters 

760 ---------- 

761 image : torch.Tensor 

762 The input image tensor. 

763 

764 Returns 

765 ------- 

766 torch.Tensor 

767 The transposed image tensor. 

768 """ 

769 if len(image.shape) == 3: 

770 return image.permute(2, 0, 1) 

771 return image.permute(0, 3, 1, 2) 

772 

773 

def _maybe_transpose(
    observation: torch.Tensor, observation_space: spaces.Space
) -> np.ndarray:
    """
    Reorder an observation to channel-first format when the space is an image.

    When ``observation_space`` is an image space and the observation's shape
    (with or without a leading batch dimension) does not already match it,
    the channels are re-ordered via :func:`_transpose_image`; the transposed
    tensor is kept only if it then matches the space's shape.

    Parameters
    ----------
    observation : torch.Tensor
        The observation tensor.
    observation_space : gym.spaces.Space
        The observation space definition.

    Returns
    -------
    np.ndarray
        The observation, transposed to channel-first format if it is an image.
    """
    if not is_image_space(observation_space):
        return observation

    expected = observation_space.shape

    def _matches(shape):
        # Accept either an unbatched shape or one with a leading batch axis.
        return shape == expected or shape[1:] == expected

    if not _matches(observation.shape):
        # Try to re-order the channels
        candidate = _transpose_image(observation)
        if _matches(candidate.shape):
            observation = candidate
    return observation