Skip to content

Commit 0e970cc

Browse files
authored
fix: pass weights_only=True to torch.load (#203)
1 parent 02ed72b commit 0e970cc

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/data/pytorch_mnist/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def model_fn(model_dir):
169169
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
170170
model = torch.nn.DataParallel(Net())
171171
with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
172-
model.load_state_dict(torch.load(f))
172+
model.load_state_dict(torch.load(f, weights_only=True))
173173
return model.to(device)
174174

175175

0 commit comments

Comments
 (0)