Skip to content

Commit 08188a6

Browse files
committed
fix ppo policy
1 parent b7b11e0 commit 08188a6

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

examples/deeprl/ant_ppo.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,17 @@ function RL.Experiment(
2828
agent = Agent(
2929
policy = PPOPolicy(
3030
approximator = ActorCritic(
31-
actor = Chain(
32-
Dense(ns, 256, relu; init = glorot_uniform(rng)),
33-
Dense(256, na; init = glorot_uniform(rng)),
34-
),
31+
actor = GaussianNetwork(
32+
pre = Chain(
33+
Dense(ns, 64, relu; init = glorot_uniform(rng)),
34+
Dense(64, 64, relu; init = glorot_uniform(rng)),
35+
),
36+
μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec),
37+
logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec),
38+
),
3539
critic = Chain(
3640
Dense(ns, 256, relu; init = glorot_uniform(rng)),
37-
Dense(256, 1; init = glorot_uniform(rng)),
41+
Dense(256, na; init = glorot_uniform(rng)),
3842
),
3943
optimizer = ADAM(1e-3),
4044
),

examples/deeprl/cartpole_ppo.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@ function RL.Experiment(
2828
agent = Agent(
2929
policy = PPOPolicy(
3030
approximator = ActorCritic(
31-
actor = Chain(
32-
Dense(ns, 256, relu; init = glorot_uniform(rng)),
33-
Dense(256, na; init = glorot_uniform(rng)),
34-
),
31+
actor = GaussianNetwork(
32+
pre = Chain(
33+
Dense(ns, 64, relu; init = glorot_uniform(rng)),
34+
Dense(64, 64, relu; init = glorot_uniform(rng)),
35+
),
36+
μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec),
37+
logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec),
38+
),
3539
critic = Chain(
3640
Dense(ns, 256, relu; init = glorot_uniform(rng)),
3741
Dense(256, 1; init = glorot_uniform(rng)),

0 commit comments

Comments
 (0)