
Commit 8826e21

Update to pytorch 0.4.1
1 parent: e898f75

File tree: test.py, train.py

2 files changed: 23 additions, 26 deletions


test.py

Lines changed: 8 additions & 9 deletions
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn.functional as F
-from torch.autograd import Variable
 
 from envs import create_atari_env
 from model import ActorCritic
@@ -34,16 +33,16 @@ def test(rank, args, shared_model, counter):
         # Sync with the shared model
         if done:
             model.load_state_dict(shared_model.state_dict())
-            cx = Variable(torch.zeros(1, 256), volatile=True)
-            hx = Variable(torch.zeros(1, 256), volatile=True)
+            cx = torch.zeros(1, 256)
+            hx = torch.zeros(1, 256)
         else:
-            cx = Variable(cx.data, volatile=True)
-            hx = Variable(hx.data, volatile=True)
+            cx = cx.detach()
+            hx = hx.detach()
 
-        value, logit, (hx, cx) = model((Variable(
-            state.unsqueeze(0), volatile=True), (hx, cx)))
-        prob = F.softmax(logit)
-        action = prob.max(1, keepdim=True)[1].data.numpy()
+        with torch.no_grad():
+            value, logit, (hx, cx) = model((state.unsqueeze(0), (hx, cx)))
+        prob = F.softmax(logit, dim=-1)
+        action = prob.max(1, keepdim=True)[1].numpy()
 
         state, reward, done, _ = env.step(action[0, 0])
         done = done or episode_length >= args.max_episode_length
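
For orientation only (not part of the commit): under PyTorch 0.4.x, Variable is merged into Tensor, so the volatile=True flag that test.py previously relied on is gone; inference is wrapped in torch.no_grad() instead, and F.softmax takes an explicit dim. Below is a minimal sketch of that pattern, using a toy nn.Linear and a random state as stand-ins for the repository's ActorCritic model and observation.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins (assumptions): a Linear layer instead of the repo's ActorCritic,
# and a random tensor instead of a preprocessed Atari frame.
model = nn.Linear(4, 3)
state = torch.randn(4)

with torch.no_grad():                            # replaces volatile=True: no autograd graph is built
    logit = model(state.unsqueeze(0))
prob = F.softmax(logit, dim=-1)                  # dim is given explicitly in 0.4.x
action = prob.max(1, keepdim=True)[1].numpy()    # index tensor -> NumPy, no .data needed
print(action)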

train.py

Lines changed: 15 additions & 17 deletions
@@ -1,7 +1,6 @@
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 from envs import create_atari_env
 from model import ActorCritic
@@ -37,11 +36,11 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
         # Sync with the shared model
         model.load_state_dict(shared_model.state_dict())
         if done:
-            cx = Variable(torch.zeros(1, 256))
-            hx = Variable(torch.zeros(1, 256))
+            cx = torch.zeros(1, 256)
+            hx = torch.zeros(1, 256)
         else:
-            cx = Variable(cx.data)
-            hx = Variable(hx.data)
+            cx = cx.detach()
+            hx = hx.detach()
 
         values = []
         log_probs = []
@@ -50,15 +49,15 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
 
         for step in range(args.num_steps):
             episode_length += 1
-            value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)),
+            value, logit, (hx, cx) = model((state.unsqueeze(0),
                                             (hx, cx)))
-            prob = F.softmax(logit)
-            log_prob = F.log_softmax(logit)
+            prob = F.softmax(logit, dim=-1)
+            log_prob = F.log_softmax(logit, dim=-1)
             entropy = -(log_prob * prob).sum(1, keepdim=True)
             entropies.append(entropy)
 
-            action = prob.multinomial(num_samples=1).data
-            log_prob = log_prob.gather(1, Variable(action))
+            action = prob.multinomial(num_samples=1).detach()
+            log_prob = log_prob.gather(1, action)
 
             state, reward, done, _ = env.step(action.numpy())
             done = done or episode_length >= args.max_episode_length
@@ -81,13 +80,12 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
 
         R = torch.zeros(1, 1)
         if not done:
-            value, _, _ = model((Variable(state.unsqueeze(0)), (hx, cx)))
-            R = value.data
+            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
+            R = value.detach()
 
-        values.append(Variable(R))
+        values.append(R)
         policy_loss = 0
         value_loss = 0
-        R = Variable(R)
         gae = torch.zeros(1, 1)
         for i in reversed(range(len(rewards))):
             R = args.gamma * R + rewards[i]
@@ -96,16 +94,16 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
 
             # Generalized Advantage Estimataion
             delta_t = rewards[i] + args.gamma * \
-                values[i + 1].data - values[i].data
+                values[i + 1] - values[i]
             gae = gae * args.gamma * args.tau + delta_t
 
             policy_loss = policy_loss - \
-                log_probs[i] * Variable(gae) - args.entropy_coef * entropies[i]
+                log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]
 
         optimizer.zero_grad()
 
         (policy_loss + args.value_loss_coef * value_loss).backward()
-        torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 
         ensure_shared_grads(model, shared_model)
         optimizer.step()
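
Again for orientation, not part of the commit: the two recurring idioms in the train.py changes are tensor.detach(), which replaces the old Variable(tensor.data) trick for stopping gradient flow, and torch.nn.utils.clip_grad_norm_, the renamed in-place form of clip_grad_norm. A short sketch under the assumption of a toy nn.Linear model (not the repository's ActorCritic):

import torch
import torch.nn as nn

# Assumed toy setup, not the repository's A3C training loop.
model = nn.Linear(8, 1)
hx = torch.zeros(1, 256)

# Old: hx = Variable(hx.data)  -> new: detach() returns a tensor that shares
# storage but is cut out of the autograd graph.
hx = hx.detach()

loss = model(torch.randn(3, 8)).sum()
loss.backward()

# clip_grad_norm was renamed clip_grad_norm_ (trailing underscore) because it
# modifies the parameters' gradients in place.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=40.0)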
