Skip to content

Commit 8b02b9f

Browse files
Merge pull request #25 from andreped/reinhard
Add reinhard color normalization support
2 parents b257184 + 946db08 commit 8b02b9f

File tree

21 files changed

+364
-11
lines changed

21 files changed

+364
-11
lines changed

README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
[![Pip Downloads](https://img.shields.io/pypi/dm/torchstain?label=pip%20downloads&logo=python)](https://pypi.org/project/torchstain/)
66
[![DOI](https://zenodo.org/badge/323590093.svg)](https://zenodo.org/badge/latestdoi/323590093)
77

8-
9-
108
GPU-accelerated stain normalization tools for histopathological images. Compatible with PyTorch, TensorFlow, and Numpy.
119
Normalization algorithms currently implemented:
1210

1311
- Macenko et al. [\[1\]](#reference) (ported from [numpy implementation](https://github.com/schaugf/HEnorm_python))
12+
- Reinhard et al. [\[2\]](#reference) (only numpy & TensorFlow backend support)
1413

1514
## Installation
1615

@@ -43,20 +42,19 @@ t_to_transform = T(to_transform)
4342
norm, H, E = normalizer.normalize(I=t_to_transform, stains=True)
4443
```
4544

46-
![alt text](result.png)
45+
![alt text](data/result.png)
4746

4847
## Implemented algorithms
4948

5049
| Algorithm | numpy | torch | tensorflow |
5150
|-|-|-|-|
5251
| Macenko | ✓ | ✓ | ✓ |
53-
52+
| Reinhard | ✓ | ✗ | ✓ |
5453

5554
## Backend comparison
5655

5756
Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
5857

59-
6058
| size | numpy avg. time | torch avg. time | tf avg. time |
6159
|--------|-------------------|-------------------|------------------|
6260
| 224 | 0.0182s ± 0.0016 | 0.0180s ± 0.0390 | 0.0048s ± 0.0002 |
@@ -68,17 +66,15 @@ Results with 10 runs per size on a Intel(R) Core(TM) i5-8365U CPU @ 1.60GHz
6866
| 1568 | 1.1935s ± 0.0739 | 0.2590s ± 0.0088 | 0.2531s ± 0.0031 |
6967
| 1792 | 1.4523s ± 0.0207 | 0.3402s ± 0.0114 | 0.3080s ± 0.0188 |
7068

71-
7269
## Reference
7370

7471
- [1] Macenko, Marc, et al. "A method for normalizing histology slides for quantitative analysis." 2009 IEEE International Symposium on Biomedical Imaging: From Nano to Macro. IEEE, 2009.
75-
72+
- [2] Reinhard, Erik, et al. "Color transfer between images." IEEE Computer Graphics and Applications. IEEE, 2001.
7673

7774
## Citing
7875

7976
If you find this software useful for your research, please cite it as:
8077

81-
8278
```bibtex
8379
@software{barbano2022torchstain,
8480
author = {Carlo Alberto Barbano and
File renamed without changes.

tests/test_color_conv.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from torchstain.numpy.utils.rgb2lab import rgb2lab
2+
from torchstain.numpy.utils.lab2rgb import lab2rgb
3+
import numpy as np
4+
import cv2
5+
import os
6+
7+
def test_rgb_to_lab():
8+
size = 1024
9+
curr_file_path = os.path.dirname(os.path.realpath(__file__))
10+
img = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))
11+
12+
reconstructed_img = lab2rgb(rgb2lab(img))
13+
val = np.mean(np.abs(reconstructed_img - img))
14+
print("MAE:", val)
15+
assert val < 0.1

tests/test_torch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from torchvision import transforms
1010
from skimage.metrics import structural_similarity as ssim
1111

12-
1312
def setup_function(fn):
1413
print("torch version:", torch.__version__, "torchvision version:", torchvision.__version__)
1514

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .he_normalizer import HENormalizer
22
from .macenko import MacenkoNormalizer
3+
from .reinhard import ReinhardNormalizer
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
def ReinhardNormalizer(backend='numpy'):
2+
if backend == 'numpy':
3+
from torchstain.numpy.normalizers import NumpyReinhardNormalizer
4+
return NumpyReinhardNormalizer()
5+
elif backend == "torch":
6+
raise NotImplementedError
7+
elif backend == "tensorflow":
8+
from torchstain.tf.normalizers import TensorFlowReinhardNormalizer
9+
return TensorFlowReinhardNormalizer()
10+
else:
11+
raise Exception(f'Unknown backend {backend}')

torchstain/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from torchstain.numpy import normalizers, utils
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .macenko import NumpyMacenkoNormalizer
1+
from .macenko import NumpyMacenkoNormalizer
2+
from .reinhard import NumpyReinhardNormalizer
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import numpy as np
2+
from torchstain.base.normalizers import HENormalizer
3+
from torchstain.numpy.utils.rgb2lab import rgb2lab
4+
from torchstain.numpy.utils.lab2rgb import lab2rgb
5+
from torchstain.numpy.utils.split import csplit, cmerge, lab_split, lab_merge
6+
from torchstain.numpy.utils.stats import get_mean_std, standardize
7+
8+
"""
9+
Source code adapted from:
10+
https://github.com/DigitalSlideArchive/HistomicsTK/blob/master/histomicstk/preprocessing/color_normalization/reinhard.py
11+
https://github.com/Peter554/StainTools/blob/master/staintools/reinhard_color_normalizer.py
12+
"""
13+
class NumpyReinhardNormalizer(HENormalizer):
14+
def __init__(self):
15+
super().__init__()
16+
self.target_mus = None
17+
self.target_stds = None
18+
19+
def fit(self, target):
20+
# normalize
21+
target = target.astype("float32") / 255
22+
23+
# convert to LAB
24+
lab = rgb2lab(target)
25+
26+
# get summary statistics
27+
stack_ = np.array([get_mean_std(x) for x in lab_split(lab)])
28+
self.target_means = stack_[:, 0]
29+
self.target_stds = stack_[:, 1]
30+
31+
def normalize(self, I):
32+
# normalize
33+
I = I.astype("float32") / 255
34+
35+
# convert to LAB
36+
lab = rgb2lab(I)
37+
labs = lab_split(lab)
38+
39+
# get summary statistics from LAB
40+
stack_ = np.array([get_mean_std(x) for x in labs])
41+
mus = stack_[:, 0]
42+
stds = stack_[:, 1]
43+
44+
# standardize intensities channel-wise and normalize using target mus and stds
45+
result = [standardize(x, mu_, std_) * std_T + mu_T for x, mu_, std_, mu_T, std_T \
46+
in zip(labs, mus, stds, self.target_means, self.target_stds)]
47+
48+
# rebuild LAB
49+
lab = lab_merge(*result)
50+
51+
# convert back to RGB from LAB
52+
lab = lab2rgb(lab)
53+
54+
# rescale to [0, 255] uint8
55+
return (lab * 255).astype("uint8")

torchstain/numpy/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from torchstain.numpy.utils.rgb2lab import *
2+
from torchstain.numpy.utils.lab2rgb import *
3+
from torchstain.numpy.utils.split import *
4+
from torchstain.numpy.utils.stats import *

0 commit comments

Comments
 (0)