Skip to content

Commit d4643f3

Browse files
author
Guglielmo Camporese
committed
fix predict
1 parent 761f54a commit d4643f3

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,13 @@ def main(args):
158158
_ = model.eval()
159159
device = next(model.parameters()).device
160160
for x, x_path in tqdm(ds, desc='Save predictions'):
161+
H, W = x.shape[-2:]
162+
x = transforms.Resize((256, 256))(x)
161163
x = x.unsqueeze(0).to(device)
162164
logits = model(x).detach().cpu()
163165
preds = F.softmax(logits, 1).argmax(1)[0] * 255 # [h, w]
164-
preds = Image.fromarray(preds.numpy().astype(np.uint8), 'P')
166+
preds = Image.fromarray(preds.numpy().astype(np.uint8), 'L')
167+
preds = preds.resize((W, H))
165168
preds.save(f'{x_path}.png')
166169

167170
else:

0 commit comments

Comments
 (0)