Coverage for adaro_rl / zoo / Enduro-v5 / environment.py: 83%
18 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1import numpy as np
2import gymnasium as gym
3import ale_py
4from stable_baselines3.common.atari_wrappers import AtariWrapper
5from ..eval import custom_evaluate_policy
6from ..environment import make_env
# Gymnasium id of the environment this module configures, and the number of
# frames stacked into each observation.
ENV_ID = "ALE/Enduro-v5"
N_FRAME_STACK = 4

# Reference the ale_py module explicitly. Per the Gymnasium docs, importing
# ale_py is what registers the "ALE/..." ids; this call keeps that import
# from looking unused to linters/IDEs.
gym.register_envs(ale_py)
class EnvWrapper(AtariWrapper):
    """``AtariWrapper`` that also describes the observation-perturbation space.

    On top of the preprocessing inherited from ``AtariWrapper``, instances
    expose:

    * ``observation_perturbation_space`` -- an ``int16`` ``Box`` with every
      bound at ``[-255, 255]``. Its shape equals the wrapped observation shape
      except that the last axis is multiplied by ``n_frame_stack``, i.e. it is
      sized for a frame-stacked observation.
    * ``proportional_obs_perturbation_mask`` -- an all-zero ``float64`` array
      with the same shape as ``observation_perturbation_space``.

    Args:
        env: The environment to wrap.
        n_frame_stack: Number of stacked frames the perturbation space is
            sized for. ``None`` (the default) uses the module-level
            ``N_FRAME_STACK``. Presumably this should match the
            ``n_frame_stack`` given to ``make_env`` -- confirm with callers.
        **atari_kwargs: Forwarded unchanged to ``AtariWrapper.__init__``
            (e.g. ``frame_skip``, ``screen_size``).

    Raises:
        ValueError: If ``n_frame_stack`` is smaller than 1.
    """

    def __init__(self, env, n_frame_stack=None, **atari_kwargs):
        # Resolve the default at call time, as the original code did, so the
        # module constant is read on every construction.
        if n_frame_stack is None:
            n_frame_stack = N_FRAME_STACK
        if n_frame_stack < 1:
            raise ValueError(
                f"n_frame_stack must be >= 1, got {n_frame_stack!r}"
            )
        super().__init__(env, **atari_kwargs)

        # Frames are stacked along the last axis, so only that axis grows;
        # every other dimension matches a single wrapped observation.
        *frame_dims, n_channels = self.observation_space.shape
        stacked_shape = (*frame_dims, n_channels * n_frame_stack)

        # Signed int16 bounds of +/-255. NOTE(review): presumably chosen so a
        # perturbation can move a uint8 pixel across its full range in either
        # direction -- confirm against the attack code.
        self.observation_perturbation_space = gym.spaces.Box(
            low=-255,
            high=255,
            shape=stacked_shape,
            dtype=np.int16,
        )
        self.proportional_obs_perturbation_mask = np.zeros(
            self.observation_perturbation_space.shape
        )
def _make_env_config():
    """Return a fresh environment-config dict for this module's env.

    Every call builds new dict objects, including a new ``env_kwargs`` dict.
    Mutating one returned config therefore never affects another.
    Presumably consumed by ``make_env``; verify against its signature.
    """
    return {
        "env_id": ENV_ID,
        "wrapper_class": EnvWrapper,
        "n_envs": 1,
        "n_frame_stack": N_FRAME_STACK,
        # NOTE(review): per the Gymnasium/ALE docs, "ALE/*-v5" envs default to
        # frameskip=4 and repeat_action_probability=0.25, and SB3's
        # AtariWrapper applies its own frame_skip=4 on top. Confirm the
        # combined skip is intended before changing it here; trained agents
        # depend on the current setting.
        "env_kwargs": {},
    }


# Training, evaluation and rendering currently share identical settings. Each
# gets its own independent copy so they can diverge later without aliasing.
train_env_config = _make_env_config()
eval_env_config = _make_env_config()
render_env_config = _make_env_config()
# Evaluation settings: the evaluation routine to call, and whether the policy
# is queried deterministically (here it is not).
eval_config = dict(
    evaluate_policy_fct=custom_evaluate_policy,
    deterministic_eval=False,
)