|
1 | 1 | import os |
2 | 2 | import cv2 |
3 | | -import matplotlib.pyplot as plt |
4 | 3 | import torchstain |
5 | 4 | import torch |
6 | 5 | from torchvision import transforms |
7 | 6 | import time |
8 | 7 | from skimage.metrics import structural_similarity as ssim |
9 | 8 | import numpy as np |
10 | 9 |
|
11 | | - |
12 | | -size = 1024 |
13 | | -curr_file_path = os.path.dirname(os.path.realpath(__file__)) |
14 | | -target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) |
15 | | -to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) |
16 | | - |
17 | | -# setup preprocessing and preprocess image to be normalized |
18 | | -T = transforms.Compose([ |
19 | | - transforms.ToTensor(), |
20 | | - transforms.Lambda(lambda x: x*255) |
21 | | -]) |
22 | | -t_to_transform = T(to_transform) |
23 | | - |
24 | | -# initialize normalizers for each backend and fit to target image |
25 | | -normalizer = torchstain.MacenkoNormalizer(backend='numpy') |
26 | | -normalizer.fit(target) |
27 | | - |
28 | | -torch_normalizer = torchstain.MacenkoNormalizer(backend='torch') |
29 | | -torch_normalizer.fit(T(target)) |
30 | | - |
31 | | -tf_normalizer = torchstain.MacenkoNormalizer(backend='tensorflow') |
32 | | -tf_normalizer.fit(T(target)) |
33 | | - |
34 | | -# transform |
35 | | -result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) |
36 | | -result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True) |
37 | | -result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True) |
38 | | - |
39 | | -# convert to numpy and set dtype |
40 | | -result_numpy = result_numpy.astype("float32") |
41 | | -result_torch = result_torch.numpy().astype("float32") |
42 | | -result_tf = result_tf.numpy().astype("float32") |
43 | | - |
44 | | -# assess whether the normalized images are identical across backends |
45 | | -np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) |
46 | | -np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) |
| 10 | +def test_normalize_all(): |
| 11 | + size = 1024 |
| 12 | + curr_file_path = os.path.dirname(os.path.realpath(__file__)) |
| 13 | + target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size)) |
| 14 | + to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size)) |
| 15 | + |
| 16 | + # setup preprocessing and preprocess image to be normalized |
| 17 | + T = transforms.Compose([ |
| 18 | + transforms.ToTensor(), |
| 19 | + transforms.Lambda(lambda x: x*255) |
| 20 | + ]) |
| 21 | + t_to_transform = T(to_transform) |
| 22 | + |
| 23 | + # initialize normalizers for each backend and fit to target image |
| 24 | + normalizer = torchstain.normalizers.MacenkoNormalizer(backend='numpy') |
| 25 | + normalizer.fit(target) |
| 26 | + |
| 27 | + torch_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch') |
| 28 | + torch_normalizer.fit(T(target)) |
| 29 | + |
| 30 | + tf_normalizer = torchstain.normalizers.MacenkoNormalizer(backend='tensorflow') |
| 31 | + tf_normalizer.fit(T(target)) |
| 32 | + |
| 33 | + # transform |
| 34 | + result_numpy, _, _ = normalizer.normalize(I=to_transform, stains=True) |
| 35 | + result_torch, _, _ = torch_normalizer.normalize(I=t_to_transform, stains=True) |
| 36 | + result_tf, _, _ = tf_normalizer.normalize(I=t_to_transform, stains=True) |
| 37 | + |
| 38 | + # convert to numpy and set dtype |
| 39 | + result_numpy = result_numpy.astype("float32") |
| 40 | + result_torch = result_torch.numpy().astype("float32") |
| 41 | + result_tf = result_tf.numpy().astype("float32") |
| 42 | + |
| 43 | + # assess whether the normalized images are identical across backends |
| 44 | + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_torch.flatten()), 1.0, decimal=4, verbose=True) |
| 45 | + np.testing.assert_almost_equal(ssim(result_numpy.flatten(), result_tf.flatten()), 1.0, decimal=4, verbose=True) |
0 commit comments