We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 02ed72b commit 0e970ccCopy full SHA for 0e970cc
tests/data/pytorch_mnist/mnist.py
@@ -169,7 +169,7 @@ def model_fn(model_dir):
169
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
170
model = torch.nn.DataParallel(Net())
171
with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
172
- model.load_state_dict(torch.load(f))
+ model.load_state_dict(torch.load(f, weights_only=True))
173
return model.to(device)
174
175
0 commit comments