Coverage for adaro_rl / zoo / Enduro-v5 / environment.py: 83%

18 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1import numpy as np 

2import gymnasium as gym 

3import ale_py 

4from stable_baselines3.common.atari_wrappers import AtariWrapper 

5from ..eval import custom_evaluate_policy 

6from ..environment import make_env 

7 

8 

# Gymnasium id of the Atari game configured by this module.
ENV_ID = "ALE/Enduro-v5"
# Frame-stack depth: passed as ``n_frame_stack`` in the env configs and used
# to size the observation-perturbation space of ``EnvWrapper``.
N_FRAME_STACK = 4

# Importing ``ale_py`` is what makes the ``ALE/...`` ids (such as ``ENV_ID``)
# available to Gymnasium; ``register_envs`` marks that import as used.
gym.register_envs(ale_py)

12 

13 

class EnvWrapper(AtariWrapper):
    """SB3 ``AtariWrapper`` that also describes admissible observation perturbations.

    After applying the standard ``AtariWrapper`` preprocessing, this wrapper
    publishes two extra attributes:

    * ``observation_perturbation_space`` -- an ``int16`` ``Box`` whose bounds
      are ``[-255, 255]`` for every element. Its shape is the wrapped
      observation shape repeated ``N_FRAME_STACK`` times along the last axis.
      The wrapped observation is presumably a single frame, with stacking
      done downstream (e.g. via the ``n_frame_stack`` config entry) -- TODO
      confirm.
    * ``proportional_obs_perturbation_mask`` -- a float array of zeros with
      the same shape as ``observation_perturbation_space``.

    NOTE(review): ALE ``v5`` environments apply their own frame skip by
    default, while SB3 documents ``AtariWrapper`` for ``NoFrameskip`` envs.
    Confirm that the combined frame skip is intended.
    """

    def __init__(self, env):
        super().__init__(env)

        frame_shape = self.observation_space.shape
        pixel_bound = np.full(frame_shape, 255, dtype=np.int16)
        # Tile the single-observation bound once per stacked frame.
        stacked_high = np.repeat(pixel_bound, N_FRAME_STACK, axis=-1)
        stacked_low = -stacked_high

        self.observation_perturbation_space = gym.spaces.Box(
            low=stacked_low,
            high=stacked_high,
            dtype=np.int16,
        )
        self.proportional_obs_perturbation_mask = np.zeros(
            self.observation_perturbation_space.shape
        )

34 

35 

def _make_env_config():
    """Return a fresh environment-configuration dict for this game.

    A new dict, with its own ``env_kwargs`` dict, is built on every call. The
    train, eval and render configs therefore never share mutable state, and
    the common settings live in exactly one place.

    Returns:
        dict with keys ``env_id``, ``wrapper_class``, ``n_envs``,
        ``n_frame_stack`` and ``env_kwargs``.
    """
    return {
        "env_id": ENV_ID,
        "wrapper_class": EnvWrapper,
        "n_envs": 1,
        "n_frame_stack": N_FRAME_STACK,
        "env_kwargs": {},
    }


# Training, evaluation and rendering currently use identical settings. Each
# is a separate object, so one can be modified independently if it needs to
# diverge.
train_env_config = _make_env_config()
eval_env_config = _make_env_config()
render_env_config = _make_env_config()

59 

# Evaluation settings: the policy-evaluation callable and the
# ``deterministic_eval`` flag (presumably forwarded to that evaluator --
# confirm against the caller).
eval_config = dict(
    evaluate_policy_fct=custom_evaluate_policy,
    deterministic_eval=False,
)

63}