Skip to content

Commit 1b89f0d

Browse files
hengoren-exphengoren
authored andcommitted
Support s3 log_dir in Tensorboard logger.
1 parent fe7fd9d commit 1b89f0d

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

composer/loggers/tensorboard_logger.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pathlib import Path
77
from typing import Any, Optional, Sequence, Union
8+
from urllib.parse import urlparse
89

910
import numpy as np
1011
import torch
@@ -108,9 +109,16 @@ def _initialize_summary_writer(self):
108109

109110
assert self.run_name is not None
110111
assert self.log_dir is not None
111-
# We name the child directory after the run_name to ensure the run_name shows up
112-
# in the Tensorboard GUI.
113-
summary_writer_log_dir = Path(self.log_dir) / self.run_name
112+
113+
parsed = urlparse(self.log_dir)
114+
# TODO: Handle other remote storage schemes
115+
if parsed.scheme == 's3':
116+
scheme, bucket, prefix, _, _, _ = parsed
117+
summary_writer_log_dir = f"{scheme}://{bucket}/{prefix.strip('/')}/{self.run_name}"
118+
else:
119+
# We name the child directory after the run_name to ensure the run_name shows up
120+
# in the Tensorboard GUI.
121+
summary_writer_log_dir = str(Path(self.log_dir) / self.run_name)
114122

115123
# Disable SummaryWriter's internal flushing to avoid file corruption while
116124
# file staged for upload to an ObjectStore.

tests/loggers/test_tensorboard_logger.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from typing import Sequence
55

6+
import boto3
7+
import moto
68
import pytest
79
import torch
810

@@ -54,8 +56,12 @@ def test_tensorboard_log_image(test_tensorboard_logger, dummy_state):
5456
# Tensorboard images are stored inline, so we can't check them automatically.
5557

5658

59+
@moto.mock_aws
5760
def test_tensorboard_logger_s3_log_dir(dummy_state):
5861
bucket_name = 'test-tensorboard-bucket'
62+
s3 = boto3.client('s3')
63+
s3.create_bucket(Bucket=bucket_name)
64+
5965
test_s3_log_dir = f's3://{bucket_name}/log_prefix'
6066

6167
dummy_state.run_name = 'tensorboard-test-log-s3'
@@ -64,4 +70,4 @@ def test_tensorboard_logger_s3_log_dir(dummy_state):
6470
tensorboard_logger.init(dummy_state, logger)
6571
assert tensorboard_logger.writer is not None
6672
expected_log_dir = f'{test_s3_log_dir}/{dummy_state.run_name}'
67-
assert str(tensorboard_logger.writer.log_dir) == expected_log_dir
73+
assert tensorboard_logger.writer.log_dir == expected_log_dir

0 commit comments

Comments
 (0)