Skip to content

Commit 9736031

Browse files
committed
get basic interface working
1 parent 1043fc3 commit 9736031

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

environments/environment.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ end
196196
abstract type Space{T,N} end
197197

198198
"""
199-
BoxSpace{T,N} <: Environment{T,N}
199+
BoxSpace{T,N} <: Space{T,N}
200200
201201
domain with lower and upper limits
202202
@@ -226,6 +226,10 @@ function contains(s::BoxSpace{T,N}, v::AbstractVector{T}) where {T,N}
226226
all(v .>= s.low) && all(v .<= s.high)
227227
end
228228

229+
# For compat with RLBase
230+
Base.in(v::AbstractVector{T}, s::BoxSpace{T,N}) where {T,N} = all(v .>= s.low) && all(v .<= s.high)
231+
Random.rand(rng::Random.AbstractRNG, s::BoxSpace{T,N}) where {T,N} = return rand(rng, T,N) .* (s.high .- s.low) .+ s.low
232+
229233
function clip(s::BoxSpace, u)
230234
clamp.(u, s.low, s.high)
231235
end

environments/rlenv.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
using ReinforcementLearningBase: RLBase
22

3-
mutable struct DojoRLEnv <: RLBase.AbstractEnv
4-
dojoenv
5-
action_space
6-
observation_space
7-
state
8-
reward
3+
mutable struct DojoRLEnv{T} <: RLBase.AbstractEnv
4+
dojoenv::Environment
5+
state::Vector{T}
6+
reward::T
97
done::Bool
108
info::Dict
119
end
1210

13-
function DojoRLEnv(dojoenv::Environment)
14-
action_space = convert(RLBase.Space, dojoenv.input_space)
15-
observation_space = convert(RLBase.Space, dojoenv.observation_space)
11+
function DojoRLEnv(dojoenv::Environment{X,T}) where {X,T}
1612
state = reset(dojoenv)
17-
return DojoRLEnv(dojoenv, action_space, observation_space, state, 0.0, false, Dict())
13+
return DojoRLEnv{T}(dojoenv, state, convert(T, 0.0), false, Dict())
1814
end
1915

20-
RLBase.action_space(env::DojoRLEnv) = env.action_space
21-
RLBase.state_space(env::DojoRLEnv) = env.observation_space
16+
function DojoRLEnv(name::String; kwargs...)
17+
DojoRLEnv(Dojo.get_environment(name; kwargs...))
18+
end
19+
20+
RLBase.action_space(env::DojoRLEnv) = env.dojoenv.input_space
21+
RLBase.state_space(env::DojoRLEnv) = env.dojoenv.observation_space
2222
RLBase.is_terminated(env::DojoRLEnv) = env.done
2323

2424
RLBase.reset!(env::DojoRLEnv) = reset(env.dojoenv)
@@ -28,6 +28,9 @@ RLBase.state(env::DojoRLEnv) = env.state
2828

2929
Random.seed!(env::DojoRLEnv, seed) = Dojo.seed(env.dojoenv, seed)
3030

31+
# TODO:
32+
# RLBase.ChanceStyle(env::DojoRLEnv) = RLBase.DETERMINISTIC
33+
3134
function (env::DojoRLEnv)(a)
3235
s, r, d, i = step(env.dojoenv, a)
3336
env.state = s

0 commit comments

Comments
 (0)