Coverage for adaro_rl / zoo / environment.py: 48%

44 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import os 

2 

3import gymnasium as gym 

4 

5from typing import Any, Callable, Dict, Optional 

6 

7from stable_baselines3.common.monitor import Monitor 

8from stable_baselines3.common.vec_env import SubprocVecEnv, VecEnv, VecFrameStack 

9from stable_baselines3.common.vec_env.patch_gym import _patch_env 

10 

11 

def make_env(
    env_id: str | Callable[..., gym.Env],
    n_envs: int = 1,
    n_frame_stack: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    adv_wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
    adv_wrapper_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    vec_frame_stack_kwargs: Optional[Dict[str, Any]] = None,
    render_mode: Optional[str] = "rgb_array",
) -> VecEnv:
    """
    Create a wrapped, monitored vectorized environment with optional frame stacking and adversarial wrappers.

    This function instantiates a vectorized environment comprising ``n_envs`` individual gym
    environments. Each environment is constructed either via the gym ID (if ``env_id`` is a str)
    or by calling the provided environment constructor. Every environment is wrapped with a
    ``Monitor`` to record training information, and optionally with a user-specified wrapper
    (``wrapper_class``). The environments are then grouped into a vectorized environment using
    ``SubprocVecEnv``, wrapped with a frame stacking mechanism via ``VecFrameStack`` (always
    applied, including when ``n_frame_stack`` is 1), and lastly, if specified, an adversarial
    wrapper is applied to the entire vectorized environment.

    Parameters
    ----------
    env_id : str | Callable[..., gym.Env]
        The environment ID (if a string) or a callable returning an environment instance.
    n_envs : int, optional
        The number of parallel environments to create. Default is 1.
    n_frame_stack : int, optional
        The number of frames to stack in the final environment. Default is 1.
    seed : Optional[int], optional
        The base random seed. The sub-environment of rank ``r`` (``start_index <= r <
        start_index + n_envs``) has both its environment and its action space seeded with
        ``seed + r``. If None, no seeding is performed. Default is None.
    start_index : int, optional
        The starting index used for seeding and identifying individual environments. Default is 0.
    monitor_dir : Optional[str], optional
        Directory path where Monitor log files will be saved (one file per rank). The directory
        is created if it does not exist. If None, logging to file is disabled.
    wrapper_class : Optional[Callable[[gym.Env], gym.Env]], optional
        An additional wrapper to apply to each environment after the Monitor. Default is None.
    adv_wrapper_class : Optional[Callable[[gym.Env], gym.Env]], optional
        An adversarial wrapper to apply to the entire vectorized environment. Default is None.
    env_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the environment constructor. Default is None.
    monitor_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the Monitor wrapper. Default is None.
    wrapper_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the additional wrapper specified by ``wrapper_class``. Default is None.
    adv_wrapper_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the adversarial wrapper specified by ``adv_wrapper_class``. Default is None.
    vec_env_kwargs : Optional[Dict[str, Any]], optional
        Additional keyword arguments passed to the vectorized environment constructor (e.g. ``SubprocVecEnv``).
        Default is None.
    vec_frame_stack_kwargs : Optional[Dict[str, Any]], optional
        Keyword arguments passed to the frame stacking wrapper (``VecFrameStack``) constructor.
        Default is None.
    render_mode : Optional[str], optional
        Render mode forwarded to ``gym.make`` when ``env_id`` is a string. A ``render_mode``
        entry in ``env_kwargs`` takes precedence. Ignored when ``env_id`` is a callable.
        Default is ``"rgb_array"``.

    Returns
    -------
    VecEnv
        A vectorized environment that has been wrapped in a Monitor for logging, stacked with frames
        as specified, and optionally wrapped with an adversarial wrapper if ``adv_wrapper_class`` is given.
    """
    env_kwargs = env_kwargs or {}
    monitor_kwargs = monitor_kwargs or {}
    wrapper_kwargs = wrapper_kwargs or {}
    vec_env_kwargs = vec_env_kwargs or {}
    vec_frame_stack_kwargs = vec_frame_stack_kwargs or {}
    adv_wrapper_kwargs = adv_wrapper_kwargs or {}

    def _make_thunk(rank: int) -> Callable[[], gym.Env]:
        """Return a zero-argument factory that builds the sub-environment for ``rank``."""

        def _init() -> gym.Env:
            # For type checker: narrowing done in the outer scope is lost inside closures.
            assert env_kwargs is not None
            assert monitor_kwargs is not None
            assert wrapper_kwargs is not None

            if isinstance(env_id, str):
                # Entries in env_kwargs override the render_mode argument.
                kwargs = {"render_mode": render_mode}
                kwargs.update(env_kwargs)
                try:
                    env = gym.make(env_id, **kwargs)  # type: ignore[arg-type]
                except TypeError:
                    # Retry without the injected render_mode, for environments
                    # whose constructor rejects that keyword.
                    env = gym.make(env_id, **env_kwargs)
            else:
                env = env_id(**env_kwargs)
                # Patch to support gym 0.21/0.26 and gymnasium
                env = _patch_env(env)

            if seed is not None:
                # Offset by rank so parallel workers do not all start from the
                # same RNG state (which would make them produce identical rollouts).
                env.reset(seed=seed + rank)
                env.action_space.seed(seed + rank)

            # Wrap the env in a Monitor wrapper to have additional training
            # information; each rank logs to its own file inside monitor_dir.
            monitor_path = None
            if monitor_dir is not None:
                # Create the monitor folder if needed
                os.makedirs(monitor_dir, exist_ok=True)
                monitor_path = os.path.join(monitor_dir, str(rank))
            env = Monitor(env, filename=monitor_path, **monitor_kwargs)

            # Optionally, wrap the environment with the provided wrapper
            if wrapper_class is not None:
                env = wrapper_class(env, **wrapper_kwargs)
            return env

        return _init

    subproc_vec_env = SubprocVecEnv(
        [_make_thunk(i + start_index) for i in range(n_envs)], **vec_env_kwargs
    )
    vec_framestack_env = VecFrameStack(
        subproc_vec_env, n_stack=n_frame_stack, **vec_frame_stack_kwargs
    )
    if adv_wrapper_class is not None:
        return adv_wrapper_class(vec_framestack_env, **adv_wrapper_kwargs)
    return vec_framestack_env