Skip to content

Commit 33520d5

Browse files
Merge pull request #71 from SamirMoustafa/patch-1
Fix device mismatch errors in torchstain color space conversion functions
2 parents 7c2a95f + 970f819 commit 33520d5

File tree

6 files changed

+43
-34
lines changed

6 files changed

+43
-34
lines changed

.github/workflows/build.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
jobs:
77
build_wheels:
88
name: Build release
9-
runs-on: ubuntu-20.04
9+
runs-on: ubuntu-24.04
1010

1111
steps:
1212
- uses: actions/checkout@v4
@@ -26,7 +26,7 @@ jobs:
2626

2727
upload_pypi:
2828
needs: build_wheels
29-
runs-on: ubuntu-20.04
29+
runs-on: ubuntu-24.04
3030

3131
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
3232

.github/workflows/tests_full.yml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
name: Full Tests
22

33
on:
4-
push:
5-
branches:
6-
- main
7-
pull_request:
8-
branches:
9-
- main
4+
# push:
5+
# branches:
6+
# - main
7+
# pull_request:
8+
# branches-ignore:
9+
# - main
1010
workflow_dispatch:
1111

1212
jobs:
1313
build:
14-
runs-on: ubuntu-20.04
14+
runs-on: ubuntu-24.04
1515

1616
if: startsWith(github.ref, 'refs/tags/v') != true
1717

1818
steps:
1919
- uses: actions/checkout@v4
20-
- name: Set up Python 3.6
20+
- name: Set up Python 3.8
2121
uses: actions/setup-python@v4
2222
with:
23-
python-version: 3.6
23+
python-version: 3.8
2424

2525
- name: Install dependencies
2626
run: pip install wheel setuptools
@@ -40,8 +40,8 @@ jobs:
4040
runs-on: ${{ matrix.os }}
4141
strategy:
4242
matrix:
43-
os: [ windows-2019, ubuntu-20.04, macos-13 ]
44-
python-version: [ 3.7, 3.8, 3.9 ]
43+
os: [ windows-latest, ubuntu-latest, macos-latest ]
44+
python-version: [ 3.8, 3.9]
4545
tf-version: [2.7.0, 2.8.0, 2.9.0]
4646

4747
steps:
@@ -71,8 +71,8 @@ jobs:
7171
runs-on: ${{ matrix.os }}
7272
strategy:
7373
matrix:
74-
os: [ windows-2019, ubuntu-20.04, macos-13 ]
75-
python-version: [ 3.7, 3.8, 3.9 ]
74+
os: [ windows-latest, ubuntu-latest, macos-latest ]
75+
python-version: [ 3.8, 3.9 ]
7676
pytorch-version: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
7777

7878
steps:

.github/workflows/tests_quick.yml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
name: Quick Tests
22

3-
on:
4-
push:
5-
branches-ignore:
6-
- main
7-
pull_request:
8-
branches-ignore:
9-
- main
10-
workflow_dispatch:
3+
on: [push, pull_request, workflow_dispatch]
4+
# push:
5+
#branches-ignore:
6+
# - main
7+
# pull_request:
8+
#branches-ignore:
9+
# - main
10+
# workflow_dispatch:
1111

1212
jobs:
1313
build:
14-
runs-on: ubuntu-20.04
14+
runs-on: ubuntu-24.04
1515
steps:
1616
- uses: actions/checkout@v4
17-
- name: Set up Python 3.6
17+
- name: Set up Python 3.8
1818
uses: actions/setup-python@v4
1919
with:
20-
python-version: 3.6
20+
python-version: 3.8
2121

2222
- name: Install dependencies
2323
run: pip install wheel setuptools
@@ -34,7 +34,7 @@ jobs:
3434

3535
test-tf:
3636
needs: build
37-
runs-on: ubuntu-20.04
37+
runs-on: ubuntu-24.04
3838

3939
steps:
4040
- uses: actions/checkout@v4
@@ -60,7 +60,7 @@ jobs:
6060

6161
test-torch:
6262
needs: build
63-
runs-on: ubuntu-20.04
63+
runs-on: ubuntu-24.04
6464

6565
steps:
6666
- uses: actions/checkout@v4

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setup(
88
name='torchstain',
9-
version='1.4.0',
9+
version='1.4.1',
1010
description='Stain normalization tools for histological analysis and computational pathology',
1111
long_description=README,
1212
long_description_content_type='text/markdown',

torchstain/torch/utils/lab2rgb.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
def lab2rgb(lab):
77
lab = lab.type(torch.float32)
88

9+
# Move constant tensors to the same device as input
10+
device = lab.device
11+
_white_device = _white.to(device)
12+
_xyz2rgb_device = _xyz2rgb.to(device)
13+
914
# rescale back from OpenCV format and extract LAB channel
1015
L, a, b = lab[0] / 2.55, lab[1] - 128, lab[2] - 128
1116

@@ -24,10 +29,10 @@ def lab2rgb(lab):
2429
out.masked_scatter_(not_mask, (torch.masked_select(out, not_mask) - 16 / 116) / 7.787)
2530

2631
# rescale to the reference white (illuminant)
27-
out = torch.mul(out, _white.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))
32+
out = torch.mul(out, _white_device.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))
2833

2934
# convert XYZ -> RGB color domain
30-
arr = torch.tensordot(out, torch.t(_xyz2rgb).type(out.dtype), dims=([0], [0]))
35+
arr = torch.tensordot(out, torch.t(_xyz2rgb_device).type(out.dtype), dims=([0], [0]))
3136
mask = arr > 0.0031308
3237
not_mask = torch.logical_not(mask)
3338
arr.masked_scatter_(mask, 1.055 * torch.pow(torch.masked_select(arr, mask), 1 / 2.4) - 0.055)

torchstain/torch/utils/rgb2lab.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@
99

1010
def rgb2lab(rgb):
1111
arr = rgb.type(torch.float32)
12-
12+
# Move constant tensors to the same device as input
13+
device = arr.device
14+
_rgb2xyz_device = _rgb2xyz.to(device)
15+
_white_device = _white.to(device)
16+
1317
# convert rgb -> xyz color domain
1418
mask = arr > 0.04045
1519
not_mask = torch.logical_not(mask)
1620
arr.masked_scatter_(mask, torch.pow((torch.masked_select(arr, mask) + 0.055) / 1.055, 2.4))
1721
arr.masked_scatter_(not_mask, torch.masked_select(arr, not_mask) / 12.92)
1822

19-
xyz = torch.tensordot(torch.t(_rgb2xyz), arr, dims=([0], [0]))
23+
xyz = torch.tensordot(torch.t(_rgb2xyz_device), arr, dims=([0], [0]))
2024

2125
# scale by CIE XYZ tristimulus values of the reference white point
22-
arr = torch.mul(xyz, 1 / _white.type(xyz.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))
26+
arr = torch.mul(xyz, 1 / _white_device.type(xyz.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))
2327

2428
# nonlinear distortion and linear transformation
2529
mask = arr > 0.008856

0 commit comments

Comments
 (0)