Coverage for adaro_rl / zoo / eval.py: 84%
55 statements
coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
# flake8: noqa: E731
from typing import List, Optional, Tuple
import warnings

import gymnasium as gym
import numpy as np

from stable_baselines3.common import type_aliases
from stable_baselines3.common.vec_env import (
    VecEnv,
    DummyVecEnv,
    is_vecenv_wrapped,
    VecMonitor,
)
def custom_evaluate_policy(
    model: "type_aliases.PolicyPredictor",
    env: gym.Env | VecEnv,
    n_eval_episodes: int = 10,
    deterministic: bool = True,
    render: bool = False,
    reward_threshold: Optional[float] = None,
    return_episode_rewards: bool = False,
    warn: bool = True,
) -> Tuple[float, float] | Tuple[List[float], List[int]]:
25 """
26 Runs policy for ``n_eval_episodes`` episodes and returns average reward.
27 If a vector env is passed in, this divides the episodes to evaluate onto the
28 different elements of the vector env. This static division of work is done to
29 remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more
30 details and discussion.
32 .. note::
33 If environment has not been wrapped with ``Monitor`` wrapper, reward and
34 episode lengths are counted as it appears with ``env.step`` calls. If
35 the environment contains wrappers that modify rewards or episode lengths
36 (e.g. reward scaling, early episode reset), these will affect the evaluation
37 results as well. You can avoid this by wrapping environment with ``Monitor``
38 wrapper before anything else.
40 :param model: The RL agent you want to evaluate. This can be any object
41 that implements a `predict` method, such as an RL algorithm (``BaseAlgorithm``)
42 or policy (``BasePolicy``).
43 :param env: The gym environment or ``VecEnv`` environment.
44 :param n_eval_episodes: Number of episode to evaluate the agent
45 :param deterministic: Whether to use deterministic or stochastic actions
46 :param render: Whether to render the environment or not
47 :param reward_threshold: Minimum expected reward per episode,
48 this will raise an error if the performance is not met
49 :param return_episode_rewards: If True, a list of rewards and episode lengths
50 per episode will be returned instead of the mean.
51 :param warn: If True (default), warns user about lack of a Monitor wrapper in the
52 evaluation environment.
53 :return: Mean reward per episode, std of reward per episode.
54 Returns ([float], [int]) when ``return_episode_rewards`` is True, first
55 list containing per-episode rewards and second containing per-episode lengths
56 (in number of steps).
57 """
    is_monitor_wrapped = False
    # Avoid circular import
    from stable_baselines3.common.monitor import Monitor

    if not isinstance(env, VecEnv):
        env = DummyVecEnv([lambda: env])  # type: ignore[list-item, return-value]

    is_monitor_wrapped = (
        is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
    )
    if not is_monitor_wrapped and warn:
        warnings.warn(
            "Evaluation environment is not wrapped with a ``Monitor`` wrapper. "
            "This may result in reporting modified episode lengths and rewards, if other wrappers "
            "happen to modify these. "
            "Consider wrapping environment first with ``Monitor`` wrapper.",
            UserWarning,
        )
    n_envs = env.num_envs
    episode_rewards = []
    episode_lengths = []

    episode_counts = np.zeros(n_envs, dtype="int")
    # Divides episodes among different sub environments in the vector as evenly as possible
    episode_count_targets = np.array(
        [(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int"
    )
    current_rewards = np.zeros(n_envs)
    current_lengths = np.zeros(n_envs, dtype="int")
    observations = env.reset()
    states = None
    episode_starts = np.ones((env.num_envs,), dtype=bool)
    while (episode_counts < episode_count_targets).any():
        actions, states = model.predict(
            observations,  # type: ignore[arg-type]
            state=states,
            episode_start=episode_starts,
            deterministic=deterministic,
        )
        new_observations, rewards, dones, infos = env.step(actions)
        current_rewards += rewards
        current_lengths += 1
        for i in range(n_envs):
            if episode_counts[i] < episode_count_targets[i]:
                # unpack values
                done = dones[i]
                info = infos[i]
                episode_starts[i] = done

                if dones[i]:
                    if is_monitor_wrapped:
                        # Atari wrapper can send a "done" signal when
                        # the agent loses a life, but it does not correspond
                        # to the true end of episode
                        if "episode" in info.keys():
                            # Do not trust "done" with episode endings.
                            # Monitor wrapper includes "episode" key in info if environment
                            # has been wrapped with it. Use those rewards instead.
                            episode_rewards.append(info["episode"]["r"])
                            episode_lengths.append(info["episode"]["l"])
                            # Only increment at the real end of an episode
                            episode_counts[i] += 1
                    else:
                        episode_rewards.append(current_rewards[i])
                        episode_lengths.append(current_lengths[i])
                        episode_counts[i] += 1
                    current_rewards[i] = 0
                    current_lengths[i] = 0

        observations = new_observations

        if render:
            env.render()

    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)
    if reward_threshold is not None:
        assert mean_reward > reward_threshold, (
            f"Mean reward below threshold: {mean_reward:.2f} < {reward_threshold:.2f}"
        )
    if return_episode_rewards:
        return episode_rewards, episode_lengths
    return mean_reward, std_reward
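

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how custom_evaluate_policy might be called. The PPO model
# and the "CartPole-v1" environment below are assumptions for demonstration only,
# not taken from this repository.
if __name__ == "__main__":
    from stable_baselines3 import PPO

    # Train a throwaway model just so there is something to evaluate.
    demo_model = PPO("MlpPolicy", "CartPole-v1", verbose=0).learn(total_timesteps=1_000)
    eval_env = gym.make("CartPole-v1")

    mean_reward, std_reward = custom_evaluate_policy(
        demo_model, eval_env, n_eval_episodes=5, deterministic=True
    )
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")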