 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 from envs import create_atari_env
 from model import ActorCritic
@@ -37,11 +36,11 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
         # Sync with the shared model
         model.load_state_dict(shared_model.state_dict())
         if done:
-            cx = Variable(torch.zeros(1, 256))
-            hx = Variable(torch.zeros(1, 256))
+            cx = torch.zeros(1, 256)
+            hx = torch.zeros(1, 256)
         else:
-            cx = Variable(cx.data)
-            hx = Variable(hx.data)
+            cx = cx.detach()
+            hx = hx.detach()
 
         values = []
         log_probs = []
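Since PyTorch 0.4 merged Variable into Tensor, wrapping a tensor in Variable is redundant, and .detach() replaces the old Variable(x.data) idiom for cutting the LSTM state off from the previous rollout's graph. A minimal standalone sketch of that behaviour (the LSTMCell input size of 32 is an illustrative assumption, not the repo's model):

    import torch
    import torch.nn as nn

    lstm = nn.LSTMCell(32, 256)              # hypothetical input size; 256 matches the hidden size above
    hx, cx = torch.zeros(1, 256), torch.zeros(1, 256)

    x = torch.randn(1, 32)
    hx, cx = lstm(x, (hx, cx))               # hx/cx are now attached to this step's graph
    print(hx.requires_grad)                  # True

    hx, cx = hx.detach(), cx.detach()        # same values, graph history dropped
    print(hx.requires_grad)                  # False -> the next rollout starts a fresh graph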
@@ -50,15 +49,15 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
 
         for step in range(args.num_steps):
             episode_length += 1
-            value, logit, (hx, cx) = model((Variable(state.unsqueeze(0)),
+            value, logit, (hx, cx) = model((state.unsqueeze(0),
                                             (hx, cx)))
-            prob = F.softmax(logit)
-            log_prob = F.log_softmax(logit)
+            prob = F.softmax(logit, dim=-1)
+            log_prob = F.log_softmax(logit, dim=-1)
             entropy = -(log_prob * prob).sum(1, keepdim=True)
             entropies.append(entropy)
 
-            action = prob.multinomial(num_samples=1).data
-            log_prob = log_prob.gather(1, Variable(action))
+            action = prob.multinomial(num_samples=1).detach()
+            log_prob = log_prob.gather(1, action)
 
             state, reward, done, _ = env.step(action.numpy())
             done = done or episode_length >= args.max_episode_length
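In the sampling step, F.softmax and F.log_softmax without an explicit dim are deprecated (PyTorch guesses the dimension and emits a warning), so dim=-1 is spelled out over the action logits; the sampled action index is detached because it is only used as a gather index and as input to env.step. A minimal sketch with made-up shapes (6 actions is an assumption, not the repo's action space):

    import torch
    import torch.nn.functional as F

    logit = torch.randn(1, 6, requires_grad=True)        # 1 x num_actions, stand-in for the policy head
    prob = F.softmax(logit, dim=-1)                       # explicit dim: softmax over actions
    log_prob = F.log_softmax(logit, dim=-1)

    action = prob.multinomial(num_samples=1).detach()     # (1, 1) LongTensor, no grad needed
    chosen = log_prob.gather(1, action)                   # log pi(a|s), still differentiable w.r.t. logit
    print(action.shape, chosen.requires_grad)             # torch.Size([1, 1]) True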
@@ -81,13 +80,12 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
 
         R = torch.zeros(1, 1)
         if not done:
-            value, _, _ = model((Variable(state.unsqueeze(0)), (hx, cx)))
-            R = value.data
+            value, _, _ = model((state.unsqueeze(0), (hx, cx)))
+            R = value.detach()
 
-        values.append(Variable(R))
+        values.append(R)
         policy_loss = 0
         value_loss = 0
-        R = Variable(R)
         gae = torch.zeros(1, 1)
         for i in reversed(range(len(rewards))):
             R = args.gamma * R + rewards[i]
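The return target R is bootstrapped from the critic's value of the last state only when the episode did not terminate, and it is detached so the discounted-return recursion does not backpropagate through that bootstrap; the old separate R = Variable(R) line becomes unnecessary. A toy sketch of the recursion, with made-up rewards and an assumed gamma of 0.99:

    import torch

    gamma = 0.99
    done = False

    R = torch.zeros(1, 1)
    if not done:
        value = torch.tensor([[1.5]], requires_grad=True)  # stand-in for V(s_T) from the critic
        R = value.detach()                                  # bootstrap, cut off from the graph

    rewards = [0.0, 1.0]
    for r in reversed(rewards):
        R = gamma * R + r                                   # n-step discounted return
    print(R)                                                # ~2.46 = 0.99 * (1.0 + 0.99 * 1.5)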
@@ -96,16 +94,16 @@ def train(rank, args, shared_model, counter, lock, optimizer=None):
 
             # Generalized Advantage Estimation
             delta_t = rewards[i] + args.gamma * \
-                values[i + 1].data - values[i].data
+                values[i + 1] - values[i]
             gae = gae * args.gamma * args.tau + delta_t
 
             policy_loss = policy_loss - \
-                log_probs[i] * Variable(gae) - args.entropy_coef * entropies[i]
+                log_probs[i] * gae.detach() - args.entropy_coef * entropies[i]
 
         optimizer.zero_grad()
 
         (policy_loss + args.value_loss_coef * value_loss).backward()
-        torch.nn.utils.clip_grad_norm(model.parameters(), args.max_grad_norm)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 
         ensure_shared_grads(model, shared_model)
         optimizer.step()
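Two further API points in this hunk: the TD residual now uses the value tensors directly and defers the cut to gae.detach() in the policy loss, which, as far as gradients go, matches the old .data behaviour (the advantage still does not push gradients into the critic); and clip_grad_norm was renamed to the in-place-suffixed clip_grad_norm_. A toy sketch of the GAE recursion, with made-up values and assumed gamma=0.99, tau=1.0:

    import torch

    gamma, tau = 0.99, 1.0
    rewards = [0.0, 1.0]
    # values[i] is V(s_i); the last entry plays the role of the detached bootstrap appended above
    values = [torch.tensor([[0.5]]), torch.tensor([[0.8]]), torch.tensor([[1.5]])]

    gae = torch.zeros(1, 1)
    for i in reversed(range(len(rewards))):
        delta_t = rewards[i] + gamma * values[i + 1] - values[i]   # TD residual delta_i
        gae = gae * gamma * tau + delta_t                          # discounted sum of residuals
    print(gae)   # delta_0 + gamma * tau * delta_1 ~ 1.96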