Coverage for adaro_rl/zoo/HalfCheetah-v5/agent.py: 96% (25 statements)


from torch import nn

from stable_baselines3 import SAC

from ..agent import make_agent


def linear_schedule(
    initial_value: float,
    end_value: float = 0,
):
    """
    Linear learning rate schedule from ``initial_value`` down to ``end_value``.
    """

    def func(progress_remaining: float) -> float:
        """
        Progress will decrease from 1 (beginning) to 0.
        """
        return progress_remaining * (initial_value - end_value) + end_value

    return func
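
# Illustrative check of the schedule semantics (not part of the original file):
# with func = linear_schedule(1e-3, 5e-4),
#
#     func(1.0) == 1.0e-3   # start of training
#     func(0.5) == 7.5e-4   # halfway
#     func(0.0) == 5.0e-4   # end of training
#
# stable-baselines3 calls such a schedule with progress_remaining, which decays
# from 1 to 0 over the course of learn(total_timesteps=...).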

ALGO = SAC

TRAIN_STEPS = 20_000_000
EVAL_FREQ = None

ADV_TRAIN_STEPS = 2_000_000
ADV_EVAL_FREQ = None
ADV_PREPOPULATE_STEPS = 100_000

ADVERSARY_TRAIN_STEPS = 20_000_000
ADVERSARY_EVAL_FREQ = None

N_EVALS = 50


repo_id = "lucasschott/HalfCheetah-v5-SAC"
filename = "model.zip"
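
# repo_id and filename identify a pretrained SAC checkpoint on the Hugging Face
# Hub. A minimal loading sketch -- an assumption, since the surrounding adaro_rl
# framework may fetch the checkpoint through its own helpers:
#
#     from huggingface_hub import hf_hub_download
#
#     checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
#     pretrained = SAC.load(checkpoint_path)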


agent_config = {
    "algo": ALGO,
    "algo_kwargs": {
        "policy": "MlpPolicy",
        "policy_kwargs": {
            "log_std_init": -3,
            "activation_fn": nn.ReLU,
            "net_arch": [256, 256],
        },
        "learning_rate": linear_schedule(1e-3, 5e-4),
        "batch_size": 256,
        "gamma": 0.99,
        "learning_starts": 10_000,
        "buffer_size": 1_000_000,
        "tau": 0.005,
        "ent_coef": "auto",
        "train_freq": 1,
        "gradient_steps": 1,
        "use_sde": True,
    },
}
training_kwargs = {
    "total_timesteps": TRAIN_STEPS,
    "n_eval_episodes": N_EVALS,
    "eval_freq": EVAL_FREQ,
}
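
# make_agent presumably consumes these dictionaries, but algo_kwargs maps
# one-to-one onto the stable-baselines3 constructor. An equivalent raw SB3
# call would look like this (a sketch only; the env id is assumed from the
# file path, and evaluation wiring is left out):
#
#     import gymnasium as gym
#
#     env = gym.make("HalfCheetah-v5")
#     model = agent_config["algo"](env=env, **agent_config["algo_kwargs"])
#     model.learn(total_timesteps=training_kwargs["total_timesteps"])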


# The adversary is trained with the same SAC hyperparameters as the agent above.
adversary_config = {
    "algo": ALGO,
    "algo_kwargs": {
        "policy": "MlpPolicy",
        "policy_kwargs": {
            "log_std_init": -3,
            "activation_fn": nn.ReLU,
            "net_arch": [256, 256],
        },
        "learning_rate": linear_schedule(1e-3, 5e-4),
        "batch_size": 256,
        "gamma": 0.99,
        "learning_starts": 10_000,
        "buffer_size": 1_000_000,
        "tau": 0.005,
        "ent_coef": "auto",
        "train_freq": 1,
        "gradient_steps": 1,
        "use_sde": True,
    },
}
adversary_training_kwargs = {
    "total_timesteps": ADVERSARY_TRAIN_STEPS,
    "n_eval_episodes": N_EVALS,
    "eval_freq": ADVERSARY_EVAL_FREQ,
}


# Fine-tuning keeps the same architecture but uses a fixed, lower learning
# rate (1e-4) instead of the linear schedule.
finetuned_agent_config = {
    "algo": ALGO,
    "algo_kwargs": {
        "policy": "MlpPolicy",
        "policy_kwargs": {
            "log_std_init": -3,
            "activation_fn": nn.ReLU,
            "net_arch": [256, 256],
        },
        "learning_rate": 1e-4,
        "batch_size": 256,
        "gamma": 0.99,
        "learning_starts": 10_000,
        "buffer_size": 1_000_000,
        "tau": 0.005,
        "ent_coef": "auto",
        "train_freq": 1,
        "gradient_steps": 1,
        "use_sde": True,
    },
}
finetuned_training_kwargs = {
    "total_timesteps": ADV_TRAIN_STEPS,
    "n_eval_episodes": N_EVALS,
    "eval_freq": ADV_EVAL_FREQ,
    "prepopulate_timesteps": ADV_PREPOPULATE_STEPS,
}
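
# Note: "prepopulate_timesteps" is not a stable-baselines3 argument. Judging by
# the name, it is consumed by the adaro_rl fine-tuning loop to refill the replay
# buffer before fine-tuning updates begin (an assumption, not confirmed here).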


n_eval_episodes = N_EVALS