Coverage for adaro_rl / zoo / HalfCheetah-v5 / agent.py: 96%
25 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from torch import nn
2from stable_baselines3 import SAC
3from ..agent import make_agent
def linear_schedule(
    initial_value: float,
    end_value: float = 0,
):
    """
    Build a linearly decaying learning-rate schedule.

    The returned callable maps ``progress_remaining`` (1.0 at the start of
    training, 0.0 at the end) to a value interpolated between
    ``initial_value`` (at the start) and ``end_value`` (at the end).
    """

    def schedule(progress_remaining: float) -> float:
        """Linear interpolation: 1.0 -> initial_value, 0.0 -> end_value."""
        return end_value + progress_remaining * (initial_value - end_value)

    return schedule
# Algorithm class shared by the agent, adversary, and finetuned configs below.
ALGO = SAC

# Training budget (timesteps) for the main agent.
TRAIN_STEPS = 20000000
# NOTE(review): None presumably disables periodic evaluation during training —
# confirm against the training loop in ..agent.
EVAL_FREQ = None

# Budget for adversarial fine-tuning of the agent (ADV_*); see
# finetuned_training_kwargs below.
ADV_TRAIN_STEPS = 2000000
ADV_EVAL_FREQ = None
# Timesteps used to prepopulate the replay buffer before fine-tuning
# (consumed via the "prepopulate_timesteps" key below).
ADV_PREPOPULATE_STEPS = 100000

# Training budget for the adversary itself.
ADVERSARY_TRAIN_STEPS = 20000000
ADVERSARY_EVAL_FREQ = None

# Episodes per evaluation run (shared by all *_training_kwargs below).
N_EVALS = 50

# Hugging Face Hub location of a pretrained checkpoint for this env/algo pair.
repo_id = "lucasschott/HalfCheetah-v5-SAC"
filename = "model.zip"
# Configuration for the main agent: SAC with an MLP policy. The contents of
# "algo_kwargs" are standard stable-baselines3 SAC constructor arguments.
agent_config = {
    "algo": ALGO,
    "algo_kwargs": {
        "policy": "MlpPolicy",
        "policy_kwargs": {
            # Initial log-std of the Gaussian policy (small initial noise).
            "log_std_init": -3,
            "activation_fn": nn.ReLU,
            # Two hidden layers of 256 units each.
            "net_arch": [256, 256],
        },
        # Learning rate decays linearly from 1e-3 to 5e-4 over training.
        "learning_rate": linear_schedule(1e-3, 5e-4),
        "batch_size": 256,
        "gamma": 0.99,  # discount factor
        "learning_starts": 10000,  # steps collected before updates begin
        "buffer_size": 1000000,  # replay buffer capacity
        "tau": 0.005,  # soft target-network update coefficient
        "ent_coef": "auto",  # automatic entropy-coefficient tuning
        "train_freq": 1,
        "gradient_steps": 1,
        "use_sde": True,  # generalized state-dependent exploration
    },
}
# Arguments for the main agent's training loop.
training_kwargs = {
    "total_timesteps": TRAIN_STEPS,
    "n_eval_episodes": N_EVALS,
    "eval_freq": EVAL_FREQ,
}
# Configuration for the adversary agent. Identical hyperparameters to the
# main agent's config: same SAC MLP policy, same decaying learning rate.
adversary_config = {
    "algo": ALGO,
    "algo_kwargs": {
        "policy": "MlpPolicy",
        "policy_kwargs": {
            # Initial log-std of the Gaussian policy (small initial noise).
            "log_std_init": -3,
            "activation_fn": nn.ReLU,
            # Two hidden layers of 256 units each.
            "net_arch": [256, 256],
        },
        # Learning rate decays linearly from 1e-3 to 5e-4 over training.
        "learning_rate": linear_schedule(1e-3, 5e-4),
        "batch_size": 256,
        "gamma": 0.99,  # discount factor
        "learning_starts": 10000,  # steps collected before updates begin
        "buffer_size": 1000000,  # replay buffer capacity
        "tau": 0.005,  # soft target-network update coefficient
        "ent_coef": "auto",  # automatic entropy-coefficient tuning
        "train_freq": 1,
        "gradient_steps": 1,
        "use_sde": True,  # generalized state-dependent exploration
    },
}
# Arguments for the adversary's training loop.
adversary_training_kwargs = {
    "total_timesteps": ADVERSARY_TRAIN_STEPS,
    "n_eval_episodes": N_EVALS,
    "eval_freq": ADVERSARY_EVAL_FREQ,
}
# Configuration for fine-tuning the (pretrained) agent. Same architecture as
# the main agent's config, but with a constant, lower learning rate instead
# of the decaying schedule used for from-scratch training.
finetuned_agent_config = {
    "algo": ALGO,
    "algo_kwargs": {
        "policy": "MlpPolicy",
        "policy_kwargs": {
            # Initial log-std of the Gaussian policy (small initial noise).
            "log_std_init": -3,
            "activation_fn": nn.ReLU,
            # Two hidden layers of 256 units each.
            "net_arch": [256, 256],
        },
        # Constant low learning rate for fine-tuning.
        "learning_rate": 1e-4,
        "batch_size": 256,
        "gamma": 0.99,  # discount factor
        "learning_starts": 10000,  # steps collected before updates begin
        "buffer_size": 1000000,  # replay buffer capacity
        "tau": 0.005,  # soft target-network update coefficient
        "ent_coef": "auto",  # automatic entropy-coefficient tuning
        "train_freq": 1,
        "gradient_steps": 1,
        "use_sde": True,  # generalized state-dependent exploration
    },
}
# Arguments for the fine-tuning loop.
finetuned_training_kwargs = {
    "total_timesteps": ADV_TRAIN_STEPS,
    "n_eval_episodes": N_EVALS,
    "eval_freq": ADV_EVAL_FREQ,
    # NOTE(review): not a standard SB3 learn() kwarg — presumably fills the
    # replay buffer before gradient updates start; confirm in ..agent.
    "prepopulate_timesteps": ADV_PREPOPULATE_STEPS,
}
# Episodes used for standalone evaluation of this zoo entry.
n_eval_episodes = N_EVALS