Skip to content

Commit 08d799c

Browse files
committed
fix as many envs we can; mark rest as throwing
1 parent 80fbee0 commit 08d799c

File tree

7 files changed

+20
-12
lines changed

7 files changed

+20
-12
lines changed

environments/ant/methods/env.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function ant(;
4141
nx = maximal_dimension(mechanism)
4242
end
4343
nu = 8
44-
no = nx
44+
no = nx + length(mechanism.contacts)
4545

4646
aspace = BoxSpace(nu,
4747
low=(-ones(nu)),

environments/box/methods/env.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ function block(;
2222
opts_grad=SolverOptions(rtol=3.0e-4, btol=3.0e-4, undercut=1.5),
2323
T=Float64)
2424

25-
mechanism = get_box(
25+
mechanism = get_block(
2626
timestep=timestep,
2727
gravity=gravity,
2828
friction_coefficient=friction_coefficient,
2929
side=side,
3030
contact=contact,
3131
contact_type=contact_type)
3232

33-
initialize_box!(mechanism)
33+
initialize_block!(mechanism)
3434

3535
if representation == :minimal
3636
nx = minimal_dimension(mechanism)

environments/box/methods/initialize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function get_box(;
1+
function get_block(;
22
timestep=0.01,
33
gravity=[0.0; 0.0; -9.81],
44
friction_coefficient=0.8,
@@ -58,7 +58,7 @@ function get_box(;
5858
return mech
5959
end
6060

61-
function initialize_box!(mechanism::Mechanism{T};
61+
function initialize_block!(mechanism::Mechanism{T};
6262
x=[0.0, 0.0, 1.0],
6363
q=Quaternion(1.0, 0.0, 0.0, 0.0),
6464
v=[1.0, 0.3, 0.2],

environments/hopper/methods/env.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function hopper(;
3939
nx = maximal_dimension(mechanism)
4040
end
4141
nu = 3
42-
no = nx
42+
no = nx - 1 # full_state is not being returned by default
4343

4444
# values taken from Mujoco's model, combining the control range -1, 1 and the motor gears.
4545
aspace = BoxSpace(nu,
@@ -140,7 +140,7 @@ function cost(env::Environment{Hopper}, x, u;
140140
return c
141141
end
142142

143-
function is_done(::Environment{Hopper}, x)
143+
function is_done(env::Environment{Hopper}, x)
144144
nx = minimal_dimension(env.mechanism)
145145
if env.representation == :minimal
146146
x0 = x

environments/rexhopper/methods/env.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function Base.reset(env::Environment{RexHopper};
9393
else
9494
# initialize above the ground to make sure that with random initialization we do not violate the ground constraint.
9595
initialize!(env.mechanism, :rexhopper)
96-
x0 = get_minimal_state(env.mechanism)
96+
x = get_minimal_state(env.mechanism)
9797
nx = minimal_dimension(env.mechanism)
9898
z = minimal_to_maximal(env.mechanism, x)
9999
set_maximal_state!(env.mechanism, z)

environments/walker/methods/env.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function walker(;
3939
nx = maximal_dimension(mechanism)
4040
end
4141
nu = 6
42-
no = nx
42+
no = nx-1 # full_state is false by default
4343

4444
# values taken from Mujoco's model, combining the control range -1, 1 and the motor gears.
4545
aspace = BoxSpace(nu,

test/environments.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,17 @@ environments = [
1212
:block
1313
]
1414

15+
throw_envs = [
16+
:hopper,
17+
:rexhopper,
18+
:walker,
19+
]
20+
1521
@testset "$name" for name in environments
1622
env = get_environment(name)
17-
@test size(reset(env)) == (env.observation_space.n,)
18-
o, r, d, i = step(env, Dojo.sample(env.input_space))
19-
@test size(o) == (env.observation_space.n,)
23+
if !(name in throw_envs)
24+
@test size(reset(env)) == (env.observation_space.n,)
25+
o, r, d, i = step(env, Dojo.sample(env.input_space))
26+
@test size(o) == (env.observation_space.n,)
27+
end
2028
end

0 commit comments

Comments
 (0)