Coverage for adaro_rl / zoo / Enduro-v5 / agent.py: 100%

17 statements

coverage.py v7.13.5, report created at 2026-04-14 07:50 +0000

1from torch import nn 

2from stable_baselines3 import PPO 

3from stable_baselines3.common.torch_layers import NatureCNN 

4from ..agent import make_agent 

5 

6 

7def linear_schedule( 

8 initial_value: float, 

9 end_value: float = 0, 

10): 

11 """ 

12 Linear learning rate schedule. 

13 """ 

14 

15 def func(progress_remaining: float) -> float: 

16 """ 

17 Progress will decrease from 1 (beginning) to 0. 

18 """ 

19 return progress_remaining * (initial_value - end_value) + end_value 

20 

21 return func 

22 

23 

# Hugging Face Hub location of the pretrained Enduro-v5 PPO checkpoint
# that `make_agent` downloads/loads.
repo_id = "lucasschott/Enduro-v5-PPO"
filename = "model.zip"

26 

27 

# PPO hyper-parameters for the main Enduro-v5 agent.  Atari frames go
# through the NatureCNN feature extractor, then separate value/policy
# MLP heads; the learning rate decays linearly from 3e-4 to 1e-6.
agent_config = dict(
    algo=PPO,
    algo_kwargs=dict(
        policy="MlpPolicy",
        policy_kwargs=dict(
            features_extractor_class=NatureCNN,
            activation_fn=nn.ReLU,
            net_arch={"vf": [512, 256], "pi": [256, 128]},
        ),
        learning_rate=linear_schedule(3e-4, 1e-6),
        n_steps=256,
        batch_size=1024,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.1,
        vf_coef=0.5,
        ent_coef=0.01,
    ),
)

# Training schedule for the main agent: 40M environment steps, final
# evaluation over 50 episodes, no intermediate evaluation, no prefill.
training_kwargs = dict(
    total_timesteps=40_000_000,
    n_eval_episodes=50,
    eval_freq=0,
    prepopulate_timesteps=0,
)

54 

55 

# PPO hyper-parameters for the adversary agent.  Same architecture as
# the main agent; the only difference is a fixed 2.5e-4 learning rate
# instead of a decaying schedule.
adversary_config = dict(
    algo=PPO,
    algo_kwargs=dict(
        policy="MlpPolicy",
        policy_kwargs=dict(
            features_extractor_class=NatureCNN,
            activation_fn=nn.ReLU,
            net_arch={"vf": [512, 256], "pi": [256, 128]},
        ),
        learning_rate=2.5e-4,
        n_steps=256,
        batch_size=1024,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.1,
        vf_coef=0.5,
        ent_coef=0.01,
    ),
)

# Training schedule for the adversary: 40M steps, no evaluation at all.
adversary_training_kwargs = dict(
    total_timesteps=40_000_000,
    n_eval_episodes=0,
    eval_freq=0,
    prepopulate_timesteps=0,
)

82 

83 

# PPO hyper-parameters for fine-tuning the pretrained agent.  Identical
# architecture to `agent_config`, but with a small fixed learning rate
# (5e-5) so fine-tuning does not destroy the pretrained weights.
finetuned_agent_config = dict(
    algo=PPO,
    algo_kwargs=dict(
        policy="MlpPolicy",
        policy_kwargs=dict(
            features_extractor_class=NatureCNN,
            activation_fn=nn.ReLU,
            net_arch={"vf": [512, 256], "pi": [256, 128]},
        ),
        learning_rate=5e-5,
        n_steps=256,
        batch_size=1024,
        n_epochs=4,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.1,
        vf_coef=0.5,
        ent_coef=0.01,
    ),
)

# Fine-tuning schedule: a short 2M-step run, no evaluation, with a
# 100k-step prepopulation phase before updates begin.
finetuned_training_kwargs = dict(
    total_timesteps=2_000_000,
    n_eval_episodes=0,
    eval_freq=0,
    prepopulate_timesteps=100_000,
)

110 

111 

# Episode count used when evaluating the agents defined in this module.
n_eval_episodes = 50