Coverage for adaro_rl / zoo / environment.py: 48%
44 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import os
3import gymnasium as gym
5from typing import Any, Callable, Dict, Optional
7from stable_baselines3.common.monitor import Monitor
8from stable_baselines3.common.vec_env import SubprocVecEnv, VecEnv, VecFrameStack
9from stable_baselines3.common.vec_env.patch_gym import _patch_env
def make_env(
    env_id: str | Callable[..., gym.Env],
    n_envs: int = 1,
    n_frame_stack: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    adv_wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
    adv_wrapper_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    vec_frame_stack_kwargs: Optional[Dict[str, Any]] = None,
    render_mode: Optional[str] = "rgb_array",
) -> VecEnv:
    """
    Create a wrapped, monitored vectorized environment with optional frame stacking and adversarial wrappers.

    This function instantiates a vectorized environment comprising ``n_envs`` individual gym
    environments. Each environment is constructed either via the gym ID (if ``env_id`` is a str)
    or by calling the provided environment constructor. Every environment is wrapped with a
    ``Monitor`` to record training information, and optionally with a user-specified wrapper
    (``wrapper_class``). The environments are then grouped into a vectorized environment using
    ``SubprocVecEnv``, wrapped with a frame stacking mechanism via ``VecFrameStack``, and lastly, if
    specified, an adversarial wrapper is applied to the entire vectorized environment.

    Parameters
    ----------
    env_id : str | Callable[..., gym.Env]
        The environment ID (if a string) or a callable returning an environment instance.
    n_envs : int, optional
        The number of parallel environments to create. Default is 1.
    n_frame_stack : int, optional
        The number of frames to stack in the final environment. ``VecFrameStack`` is always
        applied, including for the default value of 1.
    seed : Optional[int], optional
        The base random seed. When given, the environment with rank ``r`` (``start_index`` plus
        its position) is reset with ``seed + r`` and its action space is seeded with ``seed + r``,
        so that parallel environments do not share the same random stream. Default is None.
    start_index : int, optional
        The starting index used for seeding and identifying individual environments. Default is 0.
    monitor_dir : Optional[str], optional
        Directory path where Monitor log files will be saved (one file per rank). The directory
        is created if needed. If None, no file path is passed to ``Monitor``.
    wrapper_class : Optional[Callable[[gym.Env], gym.Env]], optional
        An additional wrapper to apply to each environment after the Monitor. Default is None.
    adv_wrapper_class : Optional[Callable[[gym.Env], gym.Env]], optional
        An adversarial wrapper applied to the frame-stacked vectorized environment (i.e. it
        receives the ``VecFrameStack`` instance, not a single env). Default is None.
    env_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the environment constructor. Default is None.
    monitor_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the Monitor wrapper. Default is None.
    wrapper_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the additional wrapper specified by ``wrapper_class``. Default is None.
    adv_wrapper_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the adversarial wrapper specified by ``adv_wrapper_class``. Default is None.
    vec_env_kwargs : Optional[Dict[str, Any]], optional
        Additional keyword arguments passed to the ``SubprocVecEnv`` constructor. Default is None.
    vec_frame_stack_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the frame stacking wrapper (``VecFrameStack``) constructor.
        Default is None.
    render_mode : Optional[str], optional
        Render mode forwarded to ``gym.make`` when ``env_id`` is a string. A ``render_mode`` entry
        in ``env_kwargs`` takes precedence. Ignored when ``env_id`` is a callable.
        Default is ``"rgb_array"``.

    Returns
    -------
    VecEnv
        A vectorized environment that has been wrapped in a Monitor for logging, stacked with frames
        as specified, and optionally wrapped with an adversarial wrapper if ``adv_wrapper_class`` is given.
    """
    # Bind non-Optional locals so the nested closures see plain dicts
    # (type checkers do not carry narrowing of the parameters into closures).
    env_kw: Dict[str, Any] = env_kwargs or {}
    monitor_kw: Dict[str, Any] = monitor_kwargs or {}
    wrapper_kw: Dict[str, Any] = wrapper_kwargs or {}
    vec_env_kw: Dict[str, Any] = vec_env_kwargs or {}
    frame_stack_kw: Dict[str, Any] = vec_frame_stack_kwargs or {}
    adv_wrapper_kw: Dict[str, Any] = adv_wrapper_kwargs or {}

    def _make_env_fn(rank: int) -> Callable[[], gym.Env]:
        """Return a zero-argument factory building the environment for ``rank``."""

        def _init() -> gym.Env:
            if isinstance(env_id, str):
                # Default to the requested render mode; an explicit
                # ``render_mode`` in env_kwargs overrides it.
                kwargs = {"render_mode": render_mode, **env_kw}
                try:
                    env = gym.make(env_id, **kwargs)  # type: ignore[arg-type]
                except TypeError:
                    # Retry without the injected render_mode (presumably for
                    # environments whose constructor does not accept it).
                    env = gym.make(env_id, **env_kw)
            else:
                env = env_id(**env_kw)
                # Patch to support gym 0.21/0.26 and gymnasium
                env = _patch_env(env)

            if seed is not None:
                # Offset by rank so each worker gets its own random stream;
                # reusing the bare seed would make all workers identical.
                env.reset(seed=seed + rank)
                env.action_space.seed(seed + rank)

            # Wrap the env in a Monitor wrapper to have additional training
            # information; each rank logs to its own file under monitor_dir.
            monitor_path = None
            if monitor_dir is not None:
                os.makedirs(monitor_dir, exist_ok=True)
                monitor_path = os.path.join(monitor_dir, str(rank))
            env = Monitor(env, filename=monitor_path, **monitor_kw)

            # Optionally, wrap the environment with the provided wrapper
            if wrapper_class is not None:
                env = wrapper_class(env, **wrapper_kw)
            return env

        return _init

    subproc_vec_env = SubprocVecEnv(
        [_make_env_fn(i + start_index) for i in range(n_envs)], **vec_env_kw
    )
    vec_framestack_env = VecFrameStack(
        subproc_vec_env, n_stack=n_frame_stack, **frame_stack_kw
    )
    if adv_wrapper_class is not None:
        return adv_wrapper_class(vec_framestack_env, **adv_wrapper_kw)
    return vec_framestack_env
138 return vec_framestack_env