We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 761f54a commit d4643f3Copy full SHA for d4643f3
main.py
@@ -158,10 +158,13 @@ def main(args):
158
_ = model.eval()
159
device = next(model.parameters()).device
160
for x, x_path in tqdm(ds, desc='Save predictions'):
161
+ H, W = x.shape[-2:]
162
+ x = transforms.Resize((256, 256))(x)
163
x = x.unsqueeze(0).to(device)
164
logits = model(x).detach().cpu()
165
preds = F.softmax(logits, 1).argmax(1)[0] * 255 # [h, w]
- preds = Image.fromarray(preds.numpy().astype(np.uint8), 'P')
166
+ preds = Image.fromarray(preds.numpy().astype(np.uint8), 'L')
167
+ preds = preds.resize((W, H))
168
preds.save(f'{x_path}.png')
169
170
else:
0 commit comments