Skip to content

Commit d4db289

Browse files
committed
Add an FPS counter
1 parent 17ac569 commit d4db289

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@

 # uncomment when it's fixed in pytorch
 # torch.manual_seed(args.seed)
-
 env = create_atari_env(args.env_name)
 shared_model = ActorCritic(
     env.observation_space.shape[0], env.action_space)
@@ -63,12 +62,15 @@

 processes = []

-p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
+counter = mp.Value('i', 0)
+lock = mp.Lock()
+
+p = mp.Process(target=test, args=(args.num_processes, args, shared_model, counter))
 p.start()
 processes.append(p)

 for rank in range(0, args.num_processes):
-    p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
+    p = mp.Process(target=train, args=(rank, args, shared_model, counter, lock, optimizer))
     p.start()
     processes.append(p)
 for p in processes:

test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
 from model import ActorCritic


-def test(rank, args, shared_model):
+def test(rank, args, shared_model, counter):
     torch.manual_seed(args.seed + rank)

     env = create_atari_env(args.env_name)
@@ -55,9 +55,10 @@ def test(rank, args, shared_model):
         done = True

     if done:
-        print("Time {}, episode reward {}, episode length {}".format(
+        print("Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}".format(
             time.strftime("%Hh %Mm %Ss",
                           time.gmtime(time.time() - start_time)),
+            counter.value, counter.value / (time.time() - start_time),
             reward_sum, episode_length))
         reward_sum = 0
         episode_length = 0

train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def ensure_shared_grads(model, shared_model):
         shared_param._grad = param.grad


-def train(rank, args, shared_model, optimizer=None):
+def train(rank, args, shared_model, counter, lock, optimizer=None):
     torch.manual_seed(args.seed + rank)

     env = create_atari_env(args.env_name)
@@ -64,6 +64,9 @@ def train(rank, args, shared_model, optimizer=None):
            done = done or episode_length >= args.max_episode_length
            reward = max(min(reward, 1), -1)

+           with lock:
+               counter.value += 1
+
            if done:
                episode_length = 0
                state = env.reset()

0 commit comments

Comments
 (0)