
Commit bf99cf6

Merge pull request #7 from rejuvyesh/jkg/folds: Allow using Folds to parallelize `ARS`

2 parents d926fc2 + 5cc6c87

3 files changed: +26 −14 lines changed


Project.toml (2 additions, 2 deletions)

```diff
@@ -35,11 +35,11 @@ JLD2 = "0.4.21"
 LaTeXStrings = "1.3.0"
 LightGraphs = "1.3.5"
 LightXML = "<0.9.0, 0.9.0"
-MeshCat = "0.13.0 - 0.13.0"
+MeshCat = "0.14"
 Meshing = "0.5.7"
 Parameters = "0.12"
 Polyhedra = "0.6.18"
 Rotations = "1.0.2 - 1.0.2"
 Scratch = "1.1"
-StaticArrays = "0.12, 1.0"
+StaticArrays = "1.4"
 julia = "1.6"
```
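For context: Julia's Project.toml compat entries use caret semantics by default, so the loosened bounds above admit any compatible release rather than the single pinned versions that the old hyphen ranges (e.g. `"0.13.0 - 0.13.0"`) allowed. A minimal sketch of how to verify this; the exact module path of the compat parser (`Pkg.Types.semver_spec`) is an assumption about Pkg internals:

```julia
using Pkg

# Caret semantics: "0.14" admits [0.14.0, 0.15.0), unlike the old exact
# range "0.13.0 - 0.13.0", which pinned a single version.
spec = Pkg.Types.semver_spec("0.14")   # assumed location of the compat parser
println(v"0.14.3" in spec)   # true
println(v"0.15.0" in spec)   # false
```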

examples/Project.toml (2 additions, 0 deletions)

```diff
@@ -1,12 +1,14 @@
 [deps]
 Dojo = "ac60b53e-8d92-4c83-b960-e78698fa1916"
+Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
 IterativeLQR = "605048dd-e178-462b-beb9-98a09398ef27"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
 PGFPlots = "3b7a836e-365b-5785-a47d-02c71176b4aa"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

 [compat]
+Folds = "0.2"
 IterativeLQR = "0.1.1"
 JLD2 = "0.4.21"
 Literate = "2.13.0"
```
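Picking up the new Folds dependency only requires re-instantiating the examples environment; a minimal sketch, assuming you run it from the root of a Dojo.jl checkout:

```julia
using Pkg

Pkg.activate("examples")   # path assumed relative to the repository root
Pkg.instantiate()          # resolves and installs Folds with the other example deps
```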

examples/reinforcement_learning/algorithms/ars.jl (22 additions, 12 deletions)

```diff
@@ -3,12 +3,13 @@
 ################################################################################
 using LinearAlgebra
 using Statistics
+using Folds

 import LinearAlgebra.normalize
-import GeometryBasics.update
+import Dojo.GeometryBasics.update

 # ARS options: hyper parameters
-@with_kw struct HyperParameters{T}
+Base.@kwdef struct HyperParameters{T}
     main_loop_size::Int = 100
     horizon::Int = 200
     step_size::T = 0.02
```
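Two small cleanups ride along in this hunk: `update` is now imported through `Dojo`'s own `GeometryBasics`, so the example doesn't need GeometryBasics as a direct dependency, and `Parameters.@with_kw` is replaced by the stdlib `Base.@kwdef`, which generates the same kind of keyword constructor. A minimal sketch of the `Base.@kwdef` behavior the diff relies on (`Opts` is a stand-in name, not from the source):

```julia
# Base.@kwdef generates keyword constructors with per-field defaults,
# covering what @with_kw was used for in HyperParameters above.
Base.@kwdef struct Opts{T}
    main_loop_size::Int = 100
    step_size::T = 0.02
end

Opts{Float64}()                  # all defaults
Opts{Float64}(step_size = 0.1)   # override a single field
```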
```diff
@@ -114,16 +115,19 @@ function rollout_policy(θ::Matrix, env::Environment, normalizer::Normalizer, hp
 end

 function train(env::Environment, policy::Policy{T}, normalizer::Normalizer{T},
-    hp::HyperParameters{T}; distributed=false) where T
+    hp::HyperParameters{T}; distributed=false, usefolds=false, foldsexec=Folds.ThreadedEx(;basesize=1)) where T
     println("Training linear policy with Augmented Random Search (ARS)\n ")
     if distributed
         envs = [deepcopy(env) for i = 1:(2 * hp.n_directions)]
         normalizers = [deepcopy(normalizer) for i = 1:(2 * hp.n_directions)]
         hps = [deepcopy(hp) for i = 1:(2 * hp.n_directions)]
         print(" $(nprocs()) processors")
+    elseif usefolds
+        envs = [deepcopy(env) for i = 1:(2*hp.n_directions)]
+        print(" $(Threads.nthreads()) threads with Folds")
     else
-        envs = [deepcopy(env) for i = 1:Threads.nthreads()]
-        print(" $(Threads.nthreads()) threads")
+        envs = [deepcopy(env) for i = 1:(Threads.nthreads())]
+        print(" $(Threads.nthreads()) ")
     end

     # pre-allocate for rewards
```
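The notable default here is `Folds.ThreadedEx(basesize=1)`: `basesize` is the number of input elements each spawned task handles, so `basesize=1` gives every rollout its own task. That costs a little scheduling overhead but load-balances well when rollout lengths vary. A small self-contained sketch of the same executor on a toy reduction (the executor usage mirrors the diff; the reduction itself is just an illustration):

```julia
using Folds

# basesize=1 spawns one task per element; a larger basesize batches elements
# per task, trading load balancing for less scheduling overhead.
ex = Folds.ThreadedEx(basesize = 1)
total = Folds.sum(x -> x^2, 1:1_000, ex)
println(total)   # 333833500
```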
```diff
@@ -134,14 +138,20 @@ function train(env::Environment, policy::Policy{T}, normalizer::Normalizer{T},
         θs, δs = sample_policy(policy)

         # evaluate policies
-        if distributed
-            rewards .= pmap(rollout_policy, θs, envs, normalizers, hps)
-        else
-            Threads.@threads for k = 1:(2 * hp.n_directions)
-                rewards[k] = rollout_policy(θs[k], envs[Threads.threadid()], normalizer, hp)
+        roll_time = @elapsed begin
+            if distributed
+                rewards .= pmap(rollout_policy, θs, envs, normalizers, hps)
+            elseif usefolds
+                @assert length(envs) == size(θs, 1) "$(length(envs))"
+                Folds.map!(rewards, θs, envs, foldsexec) do θ, env
+                    rollout_policy(θ, env, normalizer, hp)
+                end
+            else
+                Threads.@threads for k = 1:(2 * hp.n_directions)
+                    rewards[k] = rollout_policy(θs[k], envs[Threads.threadid()], normalizer, hp)
+                end
             end
         end
-
         # reward evaluation
         r_max = [max(rewards[k], rewards[hp.n_directions + k]) for k = 1:hp.n_directions]
         σ_r = std(rewards)
```
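The `usefolds` branch replaces the manual `Threads.@threads` loop with `Folds.map!`, writing one reward per (policy, environment) pair into the pre-allocated buffer. Each rollout gets its own environment copy, so no mutable state is shared across tasks, which the `threadid()` indexing in the old branch only approximated. A self-contained sketch of the same pattern with placeholder data (the arrays and the do-block body are stand-ins, not from the source):

```julia
using Folds

# Stand-ins for the 2 * n_directions perturbed policies and the per-rollout
# environment copies; the do-block plays the role of rollout_policy.
θs      = [rand(2, 3) for _ in 1:8]
envs    = [randn(3) for _ in 1:8]
rewards = zeros(8)

Folds.map!(rewards, θs, envs, Folds.ThreadedEx(basesize = 1)) do θ, env
    sum(θ * env)   # placeholder reward computation
end
```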
```diff
@@ -152,7 +162,7 @@ function train(env::Environment, policy::Policy{T}, normalizer::Normalizer{T},
         update(policy, rollouts, σ_r)

         # finish, print:
-        println("episode $episode reward_evaluation $(mean(rewards))")
+        println("episode $episode reward_evaluation $(mean(rewards)). Took $(roll_time) seconds")
     end

     return nothing
```
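Putting the pieces together, a hypothetical call into the new code path might look like the following. Only the keyword arguments come from this diff; `env`, `policy`, and `normalizer` are stand-ins for whatever the surrounding example script constructs, so this is a sketch rather than a runnable driver:

```julia
# Hypothetical driver; env/policy/normalizer construction is script-specific.
hp = HyperParameters{Float64}(main_loop_size = 10, horizon = 100)

train(env, policy, normalizer, hp;
    usefolds = true,                             # parallelize rollouts via Folds
    foldsexec = Folds.ThreadedEx(basesize = 1))  # one task per rollout
```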
