Coverage for adaro_rl / zoo / Enduro-v5 / agent.py: 100%
17 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from torch import nn
2from stable_baselines3 import PPO
3from stable_baselines3.common.torch_layers import NatureCNN
4from ..agent import make_agent
def linear_schedule(
    initial_value: float,
    end_value: float = 0,
):
    """
    Build a learning-rate schedule that interpolates linearly between two values.

    :param initial_value: value returned at the start of training
        (``progress_remaining == 1``).
    :param end_value: value returned at the end of training
        (``progress_remaining == 0``). Defaults to 0.
    :return: a callable mapping ``progress_remaining`` to the interpolated
        value. Inputs outside [0, 1] are not clamped and extrapolate linearly.
    """
    # Distance covered between the start and the end of training; computed
    # once instead of on every call of the returned schedule.
    span = initial_value - end_value

    def schedule(progress_remaining: float) -> float:
        """
        Interpolate for the given remaining progress, which goes from
        1 (beginning of training) down to 0 (end of training).
        """
        return end_value + span * progress_remaining

    return schedule
# Where the pretrained agent checkpoint lives: presumably a Hugging Face Hub
# repository id and the archive name inside it (the download code is not in
# this module — confirm with its consumer).
repo_id, filename = "lucasschott/Enduro-v5-PPO", "model.zip"
def _ppo_config(learning_rate):
    """
    Return a freshly built PPO config shared by every Enduro-v5 model.

    The agent, adversary and fine-tuned agent configs differ only in their
    learning rate. The network architecture and the remaining PPO
    hyperparameters are therefore kept in this single place, so they cannot
    silently drift apart. Each call builds new dicts (including the nested
    ones), so the module-level configs never alias each other.

    :param learning_rate: passed through unchanged as PPO's ``learning_rate``.
        It is either a constant or a schedule such as the callable returned
        by ``linear_schedule``.
    :return: dict holding the algorithm class under ``"algo"`` and its
        constructor keyword arguments under ``"algo_kwargs"``.
    """
    return {
        "algo": PPO,
        "algo_kwargs": {
            # NOTE(review): "MlpPolicy" combined with a NatureCNN extractor is
            # kept as-is, presumably to stay compatible with the published
            # checkpoint; confirm before switching to "CnnPolicy".
            "policy": "MlpPolicy",
            "policy_kwargs": {
                "features_extractor_class": NatureCNN,
                "activation_fn": nn.ReLU,
                # Separate heads: the value head (vf) is wider than the
                # policy head (pi).
                "net_arch": dict(vf=[512, 256], pi=[256, 128]),
            },
            "learning_rate": learning_rate,
            "n_steps": 256,
            "batch_size": 1024,
            "n_epochs": 4,
            "gamma": 0.99,
            "gae_lambda": 0.95,
            "clip_range": 0.1,
            "vf_coef": 0.5,
            "ent_coef": 0.01,
        },
    }


# Main agent: learning rate decays linearly from 3e-4 to 1e-6 over training.
agent_config = _ppo_config(linear_schedule(3e-4, 1e-6))
training_kwargs = {
    "total_timesteps": 40_000_000,
    "n_eval_episodes": 50,
    "eval_freq": 0,
    "prepopulate_timesteps": 0,
}

# Adversary: same architecture and hyperparameters, constant learning rate.
adversary_config = _ppo_config(2.5e-4)
adversary_training_kwargs = {
    "total_timesteps": 40_000_000,
    "n_eval_episodes": 0,
    "eval_freq": 0,
    "prepopulate_timesteps": 0,
}

# Fine-tuned agent: smaller constant learning rate and a 2M-step run
# (versus 40M steps above).
finetuned_agent_config = _ppo_config(5e-5)
finetuned_training_kwargs = {
    "total_timesteps": 2_000_000,
    "n_eval_episodes": 0,
    "eval_freq": 0,
    "prepopulate_timesteps": 100_000,
}
# Module-level evaluation episode count. It is presumably read by the
# evaluation code that imports this zoo entry — confirm with its callers.
n_eval_episodes: int = 50