Skip to content

Commit b466869

Browse files
Merge pull request #5 from andreped/tensorflow-backend
Tensorflow backend support
2 parents efa86d9 + b555b7c commit b466869

File tree

10 files changed

+213
-15
lines changed

10 files changed

+213
-15
lines changed

example.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import torchstain
44
import torch
55
from torchvision import transforms
6-
import numpy as np
6+
import time
7+
78

89
size = 1024
910
target = cv2.resize(cv2.cvtColor(cv2.imread("./data/target.png"), cv2.COLOR_BGR2RGB), (size, size))
@@ -12,7 +13,6 @@
1213
normalizer = torchstain.MacenkoNormalizer(backend='numpy')
1314
normalizer.fit(target)
1415

15-
1616
T = transforms.Compose([
1717
transforms.ToTensor(),
1818
transforms.Lambda(lambda x: x*255)
@@ -21,9 +21,15 @@
2121
torch_normalizer = torchstain.MacenkoNormalizer(backend='torch')
2222
torch_normalizer.fit(T(target))
2323

24+
tf_normalizer = torchstain.MacenkoNormalizer(backend='tensorflow')
25+
tf_normalizer.fit(T(target))
26+
2427
t_to_transform = T(to_transform)
2528

29+
t_ = time.time()
2630
norm, H, E = normalizer.normalize(I=to_transform, stains=True)
31+
print("numpy runtime:", time.time() - t_)
32+
2733
plt.figure()
2834
plt.suptitle('numpy normalizer')
2935
plt.subplot(2, 2, 1)
@@ -47,7 +53,10 @@
4753
plt.imshow(E)
4854
plt.show()
4955

56+
t_ = time.time()
5057
norm, H, E = torch_normalizer.normalize(I=t_to_transform, stains=True)
58+
print("torch runtime:", time.time() - t_)
59+
5160
plt.figure()
5261
plt.suptitle('torch normalizer')
5362
plt.subplot(2, 2, 1)
@@ -70,3 +79,30 @@
7079
plt.axis('off')
7180
plt.imshow(E)
7281
plt.show()
82+
83+
t_ = time.time()
84+
norm, H, E = tf_normalizer.normalize(I=t_to_transform, stains=True)
85+
print("tf runtime:", time.time() - t_)
86+
87+
plt.figure()
88+
plt.suptitle('tensorflow normalizer')
89+
plt.subplot(2, 2, 1)
90+
plt.title('Original')
91+
plt.axis('off')
92+
plt.imshow(to_transform)
93+
94+
plt.subplot(2, 2, 2)
95+
plt.title('Normalized')
96+
plt.axis('off')
97+
plt.imshow(norm)
98+
99+
plt.subplot(2, 2, 3)
100+
plt.title('H')
101+
plt.axis('off')
102+
plt.imshow(H)
103+
104+
plt.subplot(2, 2, 4)
105+
plt.title('E')
106+
plt.axis('off')
107+
plt.imshow(E)
108+
plt.show()

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
zip_safe=False,
1919
install_requires=[
2020
'torch',
21-
'numpy'
21+
'numpy',
22+
'tensorflow'
2223
],
2324
python_requires='>=3.6'
2425
)
Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from .numpy_macenko_normalizer import NumpyMacenkoNormalizer
22
from .torch_macenko_normalizer import TorchMacenkoNormalizer
3+
from .tensorflow_macenko_normalizer import TensorFlowMacenkoNormalizer
34

45
def MacenkoNormalizer(backend='torch'):
5-
if backend not in ['torch', 'numpy']:
6-
raise Exception(f'Unkown backend {backend}')
7-
86
if backend == 'numpy':
97
return NumpyMacenkoNormalizer()
10-
return TorchMacenkoNormalizer()
8+
elif backend == "tensorflow":
9+
return TensorFlowMacenkoNormalizer()
10+
elif backend == "torch":
11+
return TorchMacenkoNormalizer()
12+
else:
13+
raise Exception(f'Unknown backend {backend}')

torchstain/normalizers/numpy_macenko_normalizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __find_concentration(self, OD, HE):
5050
Y = np.reshape(OD, (-1, 3)).T
5151

5252
# determine concentrations of the individual stains
53-
C = np.linalg.lstsq(HE,Y, rcond=None)[0]
53+
C = np.linalg.lstsq(HE, Y, rcond=None)[0]
5454

5555
return C
5656

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import tensorflow as tf
2+
from torchstain.normalizers.he_normalizer import HENormalizer
3+
from torchstain.utils import cov, percentile, cov_tf, percentile_tf, solveLS
4+
import numpy as np
5+
import tensorflow.keras.backend as K
6+
7+
8+
"""
9+
Source code ported from: https://github.com/schaugf/HEnorm_python
10+
Original implementation: https://github.com/mitkovetta/staining-normalization
11+
"""
12+
class TensorFlowMacenkoNormalizer(HENormalizer):
13+
def __init__(self):
14+
super().__init__()
15+
16+
self.HERef = tf.constant([[0.5626, 0.2159],
17+
[0.7201, 0.8012],
18+
[0.4062, 0.5581]])
19+
self.maxCRef = tf.constant([1.9705, 1.0308])
20+
21+
def __convert_rgb2od(self, I, Io, beta):
22+
I = tf.transpose(I, [1, 2, 0])
23+
24+
# calculate optical density
25+
OD = -tf.math.log((tf.cast(tf.reshape(I, [tf.math.reduce_prod(I.shape[:-1]), I.shape[-1]]), tf.float32) + 1) / Io)
26+
27+
# remove transparent pixels
28+
ODhat = OD[~tf.math.reduce_any(OD < beta, axis=1)]
29+
30+
return OD, ODhat
31+
32+
def __find_HE(self, ODhat, eigvecs, alpha):
33+
# project on the plane spanned by the eigenvectors corresponding to the two
34+
# largest eigenvalues
35+
That = tf.linalg.matmul(ODhat, eigvecs)
36+
phi = tf.math.atan2(That[:, 1], That[:, 0])
37+
38+
minPhi = percentile_tf(phi, alpha)
39+
maxPhi = percentile_tf(phi, 100 - alpha)
40+
41+
vMin = tf.matmul(eigvecs, tf.expand_dims(tf.stack((tf.math.cos(minPhi), tf.math.sin(minPhi))), axis=-1))
42+
vMax = tf.matmul(eigvecs, tf.expand_dims(tf.stack((tf.math.cos(maxPhi), tf.math.sin(maxPhi))), axis=-1))
43+
44+
# a heuristic to make the vector corresponding to hematoxylin first and the
45+
# one corresponding to eosin second
46+
HE = tf.where(vMin[0] > vMax[0], tf.concat((vMin, vMax), axis=1), tf.concat((vMax, vMin), axis=1))
47+
48+
return HE
49+
50+
def __find_concentration(self, OD, HE):
51+
# rows correspond to channels (RGB), columns to OD values
52+
Y = tf.transpose(OD)
53+
54+
# determine concentrations of the individual stains
55+
return solveLS(HE, Y)[:2]
56+
57+
def __compute_matrices(self, I, Io, alpha, beta):
58+
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)
59+
60+
# compute eigenvectors
61+
_, eigvecs = tf.linalg.eigh(cov_tf(tf.transpose(ODhat)))
62+
eigvecs = eigvecs[:, 1:3]
63+
64+
HE = self.__find_HE(ODhat, eigvecs, alpha)
65+
66+
C = self.__find_concentration(OD, HE)
67+
maxC = tf.stack([percentile_tf(C[0, :], 99), percentile_tf(C[1, :], 99)])
68+
69+
return HE, C, maxC
70+
71+
def fit(self, I, Io=240, alpha=1, beta=0.15):
72+
HE, _, maxC = self.__compute_matrices(I, Io, alpha, beta)
73+
74+
self.HERef = HE
75+
self.maxCRef = maxC
76+
77+
def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
78+
''' Normalize staining appearence of H&E stained images
79+
80+
Example use:
81+
see test.py
82+
83+
Input:
84+
I: RGB input image: tensor of shape [C, H, W] and type uint8
85+
Io: (optional) transmitted light intensity
86+
alpha: percentile
87+
beta: transparency threshold
88+
stains: if true, return also H & E components
89+
90+
Output:
91+
Inorm: normalized image
92+
H: hematoxylin image
93+
E: eosin image
94+
95+
Reference:
96+
A method for normalizing histology slides for quantitative analysis. M.
97+
Macenko et al., ISBI 2009
98+
'''
99+
c, h, w = I.shape
100+
101+
HE, C, maxC = self.__compute_matrices(I, Io, alpha, beta)
102+
103+
# normalize stain concentrations
104+
C *= tf.expand_dims((self.maxCRef / maxC), axis=-1)
105+
106+
# recreate the image using reference mixing matrix
107+
Inorm = Io * tf.math.exp(-tf.linalg.matmul(self.HERef, C))
108+
Inorm = tf.clip_by_value(Inorm, 0, 255)
109+
Inorm = tf.cast(tf.reshape(tf.transpose(Inorm), shape=(h, w, c)), tf.int32)
110+
111+
H, E = None, None
112+
113+
if stains:
114+
H = tf.math.multiply(Io, tf.math.exp(tf.linalg.matmul(-tf.expand_dims(self.HERef[:, 0], axis=-1), tf.expand_dims(C[0, :], axis=0))))
115+
H = tf.clip_by_value(H, 0, 255)
116+
H = tf.cast(tf.reshape(tf.transpose(H), shape=(h, w, c)), tf.int32)
117+
118+
E = tf.math.multiply(Io, tf.math.exp(tf.linalg.matmul(-tf.expand_dims(self.HERef[:, 1], axis=-1), tf.expand_dims(C[1, :], axis=0))))
119+
E = tf.clip_by_value(E, 0, 255)
120+
E = tf.cast(tf.reshape(tf.transpose(E), shape=(h, w, c)), tf.int32)
121+
122+
return Inorm, H, E

torchstain/normalizers/torch_macenko_normalizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __convert_rgb2od(self, I, Io, beta):
2020
I = I.permute(1, 2, 0)
2121

2222
# calculate optical density
23-
OD = -torch.log((I.reshape((-1, I.shape[-1])).float()+1)/Io)
23+
OD = -torch.log((I.reshape((-1, I.shape[-1])).float() + 1)/Io)
2424

2525
# remove transparent pixels
2626
ODhat = OD[~torch.any(OD < beta, dim=1)]
@@ -36,8 +36,8 @@ def __find_HE(self, ODhat, eigvecs, alpha):
3636
minPhi = percentile(phi, alpha)
3737
maxPhi = percentile(phi, 100 - alpha)
3838

39-
vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi))).T).unsqueeze(1)
40-
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi))).T).unsqueeze(1)
39+
vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi)))).unsqueeze(1)
40+
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi)))).unsqueeze(1)
4141

4242
# a heuristic to make the vector corresponding to hematoxylin first and the
4343
# one corresponding to eosin second

torchstain/utils/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from torchstain.utils.cov import cov
2-
from torchstain.utils.percentile import percentile
1+
from torchstain.utils.cov import cov, cov_tf
2+
from torchstain.utils.percentile import percentile, percentile_tf
3+
from torchstain.utils.solveLS import solveLS

torchstain/utils/cov.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import tensorflow as tf
23

34
def cov(x):
45
"""
@@ -7,3 +8,12 @@ def cov(x):
78
E_x = x.mean(dim=1)
89
x = x - E_x[:, None]
910
return torch.mm(x, x.T) / (x.size(1) - 1)
11+
12+
13+
def cov_tf(x):
14+
"""
15+
https://en.wikipedia.org/wiki/Covariance_matrix
16+
"""
17+
E_x = tf.math.reduce_mean(x, axis=1)
18+
x = x - E_x[:, None]
19+
return tf.linalg.matmul(x, tf.transpose(x)) / (x.shape[1] - 1)

torchstain/utils/percentile.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from typing import Union
3+
import tensorflow as tf
34

45
"""
56
Author: https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30
@@ -21,5 +22,21 @@ def percentile(t: torch.tensor, q: float) -> Union[int, float]:
2122
# indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
2223
# so that ``round()`` returns an integer, even if q is a np.float32.
2324
k = 1 + round(.01 * float(q) * (t.numel() - 1))
24-
result = t.view(-1).kthvalue(k).values
25-
return result
25+
return t.view(-1).kthvalue(k).values
26+
27+
28+
def percentile_tf(t: tf.Tensor, q: float) -> Union[int, float]:
29+
"""
30+
Return the ``q``-th percentile of the flattened input tensor's data.
31+
32+
CAUTION:
33+
* Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
34+
* Values are not interpolated, which corresponds to
35+
``numpy.percentile(..., interpolation="nearest")``.
36+
37+
:param t: Input tensor.
38+
:param q: Percentile to compute, which must be between 0 and 100 inclusive.
39+
:return: Resulting value (scalar).
40+
"""
41+
k = 1 + tf.math.round(.01 * tf.cast(q, tf.float32) * (tf.cast(tf.size(t), tf.float32) - 1))
42+
return tf.sort(tf.reshape(t, [-1]))[tf.cast(k, tf.int32)]

torchstain/utils/solveLS.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import tensorflow as tf
2+
3+
4+
def solveLS(A, B):
5+
q_full, _ = tf.linalg.qr(A, full_matrices=True)
6+
ret1 = tf.linalg.lstsq(A, B)
7+
ret2 = tf.linalg.lstsq(q_full, B)
8+
return tf.concat([ret1, ret2[ret1.shape[0]:]], axis=0)

0 commit comments

Comments
 (0)