@@ -3,12 +3,13 @@
 # ###############################################################################
 using LinearAlgebra
 using Statistics
+using Folds

 import LinearAlgebra.normalize
-import GeometryBasics.update
+import Dojo.GeometryBasics.update

 # ARS options: hyper parameters
-@with_kw struct HyperParameters{T}
+Base.@kwdef struct HyperParameters{T}
     main_loop_size::Int = 100
     horizon::Int = 200
     step_size::T = 0.02
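
A standalone sketch (not part of the diff): the switch from Parameters.jl's @with_kw to Base.@kwdef keeps keyword construction with defaults working, since the stdlib macro also generates a keyword constructor. DemoParams is a hypothetical stand-in type, not something defined in the package.

    # Hypothetical demo type mirroring the HyperParameters{T} pattern above.
    Base.@kwdef struct DemoParams{T}
        horizon::Int = 200
        step_size::T = 0.02
    end

    DemoParams{Float64}(horizon = 100)   # step_size falls back to its declared default, 0.02
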
@@ -114,16 +115,19 @@ function rollout_policy(θ::Matrix, env::Environment, normalizer::Normalizer, hp
 end

 function train(env::Environment, policy::Policy{T}, normalizer::Normalizer{T},
-    hp::HyperParameters{T}; distributed=false) where T
+    hp::HyperParameters{T}; distributed=false, usefolds=false, foldsexec=Folds.ThreadedEx(; basesize=1)) where T
     println("Training linear policy with Augmented Random Search (ARS)\n")
     if distributed
         envs = [deepcopy(env) for i = 1:(2 * hp.n_directions)]
         normalizers = [deepcopy(normalizer) for i = 1:(2 * hp.n_directions)]
         hps = [deepcopy(hp) for i = 1:(2 * hp.n_directions)]
         print("$(nprocs()) processors")
+    elseif usefolds
+        envs = [deepcopy(env) for i = 1:(2 * hp.n_directions)]
+        print("$(Threads.nthreads()) threads with Folds")
     else
-        envs = [deepcopy(env) for i = 1:Threads.nthreads()]
-        print("$(Threads.nthreads()) threads ")
+        envs = [deepcopy(env) for i = 1:(Threads.nthreads())]
+        print("$(Threads.nthreads()) ")
     end

     # pre-allocate for rewards
@@ -134,14 +138,20 @@ function train(env::Environment, policy::Policy{T}, normalizer::Normalizer{T},
         θs, δs = sample_policy(policy)

         # evaluate policies
-        if distributed
-            rewards .= pmap(rollout_policy, θs, envs, normalizers, hps)
-        else
-            Threads.@threads for k = 1:(2 * hp.n_directions)
-                rewards[k] = rollout_policy(θs[k], envs[Threads.threadid()], normalizer, hp)
+        roll_time = @elapsed begin
+            if distributed
+                rewards .= pmap(rollout_policy, θs, envs, normalizers, hps)
+            elseif usefolds
+                @assert length(envs) == size(θs, 1) "$(length(envs))"
+                Folds.map!(rewards, θs, envs, foldsexec) do θ, env
+                    rollout_policy(θ, env, normalizer, hp)
+                end
+            else
+                Threads.@threads for k = 1:(2 * hp.n_directions)
+                    rewards[k] = rollout_policy(θs[k], envs[Threads.threadid()], normalizer, hp)
+                end
             end
         end
-
         # reward evaluation
         r_max = [max(rewards[k], rewards[hp.n_directions + k]) for k = 1:hp.n_directions]
         σ_r = std(rewards)
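
A self-contained sketch of the Folds.map! pattern introduced above (only Folds.jl is assumed; the arrays are stand-ins, not package objects): the do block is the mapped function, each destination element is computed from one θ/env pair, and basesize = 1 puts every rollout in its own task, a reasonable granularity when each call is as expensive as a full policy rollout.

    using Folds

    rewards = zeros(8)
    params  = collect(1.0:8.0)            # stands in for the perturbed policies θs
    buffers = [zeros(3) for _ = 1:8]      # stands in for the per-rollout environment copies

    Folds.map!(rewards, params, buffers, Folds.ThreadedEx(; basesize = 1)) do θ, buf
        buf .= θ                          # each task touches only its own buffer
        sum(buf)
    end
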
@@ -152,7 +162,7 @@ function train(env::Environment, policy::Policy{T}, normalizer::Normalizer{T},
         update(policy, rollouts, σ_r)

         # finish, print:
-        println("episode $episode reward_evaluation $(mean(rewards))")
+        println("episode $episode reward_evaluation $(mean(rewards)). Took $(roll_time) seconds")
     end

     return nothing
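
A hedged usage sketch of the two new keyword arguments; env, policy, normalizer, and hp are assumed to be in scope, built exactly as they were before this change.

    train(env, policy, normalizer, hp;
        usefolds = true,                                # route rollouts through Folds.map!
        foldsexec = Folds.ThreadedEx(; basesize = 1))   # one rollout per task (matches the default above)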