Skip to content

Commit 2daba08

Browse files
pauldoucetRicharizarddguillaumejaume
authored
minor bug fixes + support for slide semantic segmentation + minor improvements (#132)
* bug fix, default argument for segment_tissue throws an exception * add num_workers arg to segment_tissue * bug fix, display mpp for patcher vis if magnification is None * don t install heavy libraries by default * add ability to create WSIPatcher from CLAM type coordinates * fix from_legacy_coords * add coords_only arg to from_legacy * add pil argument to from_legacy_coords * pass slide_path as positional argument in CuCIMWSI, ImageWSI + fix src_pixel_size exception WSIPatcher * add semantic segmentation inference * add semantic segmentation inference * test * add forward_kwargs and create save_coords_h5 * cleanup before PR * fix parameter propagation from ImageWSI to WSI * cleanup before PR * add README newline * remove forward_kwargs * remove get_channels * fix: set correct wsi path in processor unit test --------- Co-authored-by: Richard Chen <[email protected]> Co-authored-by: guillaumejaume <[email protected]>
1 parent 0bba7a7 commit 2daba08

File tree

9 files changed

+500
-155
lines changed

9 files changed

+500
-155
lines changed

pyproject.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@ repository = "https://github.com/mahmoodlab/TRIDENT"
1414
python = "^3.10" # Specify the Python version compatibility.
1515
ipywidgets = "*"
1616
torch = "*"
17-
transformers = "*"
1817
tqdm = "*"
1918
h5py = "*"
2019
matplotlib = "*"
21-
segmentation-models-pytorch = "*"
2220
opencv-python = "*"
2321
openslide-python = "*"
2422
Pillow = "*"
25-
timm = "0.9.16"
26-
einops_exts = "*"
2723
geopandas = "*"
2824
huggingface_hub = "*"
2925
openslide-bin = "*"
26+
27+
# Optional dependencies (can be marked as optional in later versions)
28+
transformers = "*"
29+
timm = "0.9.16"
30+
einops_exts = "*"
3031
scipy = "*"
32+
segmentation-models-pytorch = "*"
3133

3234
[tool.poetry.dev-dependencies]
3335
# Optional development dependencies

tests/test_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_processor_with_wsis(self):
7070

7171
self.processor = Processor(
7272
job_dir=self.TEST_OUTPUT_DIR,
73-
wsi_source=os.path.join(TestProcessor.TEST_OUTPUT_DIR),
73+
wsi_source=os.path.join(TestProcessor.TEST_OUTPUT_DIR, 'wsis'),
7474
wsi_ext=self.TEST_WSI_EXT,
7575
custom_list_of_wsis=os.path.join(self.custom_list_of_wsis, 'valid_list_of_wsis.csv')
7676
)

trident/IO.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,40 @@ def read_coords_legacy(coords_path):
517517
return patch_size, patch_level, custom_downsample, coords
518518

519519

520+
def coords_to_h5(
521+
coords: List[List[int]],
522+
save_path,
523+
patch_size,
524+
src_mag,
525+
target_mag,
526+
save_coords,
527+
width,
528+
height,
529+
name,
530+
overlap
531+
):
532+
""" Save tissue coordinates to .h5 """
533+
# Prepare assets for saving
534+
assets = {'coords' : np.array(coords)}
535+
attributes = {
536+
'patch_size': patch_size, # Reference frame: patch_level
537+
'patch_size_level0': patch_size * src_mag // target_mag, # Reference frame: level0
538+
'level0_magnification': src_mag,
539+
'target_magnification': target_mag,
540+
'overlap': overlap,
541+
'name': name,
542+
'savetodir': save_coords,
543+
'level0_width': width,
544+
'level0_height': height
545+
}
546+
547+
# Save the assets and attributes to an hdf5 file
548+
save_h5(save_path,
549+
assets = assets,
550+
attributes = {'coords': attributes},
551+
mode='w')
552+
553+
520554
def mask_to_gdf(
521555
mask: np.ndarray,
522556
keep_ids: List[int] = [],

trident/Visualization.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import numpy as np
22
import cv2
33
import matplotlib.pyplot as plt
4-
from scipy.stats import rankdata
54
from PIL import Image
6-
from typing import Optional, Tuple
5+
from typing import Optional, Tuple, Union
76
import os
7+
from shapely import Polygon, MultiPolygon
88

99

1010
def create_overlay(
@@ -73,6 +73,9 @@ def visualize_heatmap(
7373
normalize: bool = True,
7474
num_top_patches_to_save: int = -1,
7575
output_dir: Optional[str] = "output",
76+
vis_mag: Optional[int] = None,
77+
overlay_only = False,
78+
filename = 'heatmap.png'
7679
) -> str:
7780
"""
7881
Generate a heatmap visualization overlayed on a whole slide image (WSI).
@@ -87,30 +90,45 @@ def visualize_heatmap(
8790
normalize (bool): Whether to normalize the scores.
8891
num_top_patches_to_save (int): Number of high-score patches to save. If set to -1, do not save any. Defaults to -1.
8992
output_dir (Optional[str]): Directory to save heatmap and top-k patches.
90-
93+
vis_mag (Optional[int]): Visualization Magnification. This will overwrite vis_level
94+
overlay_only bool: Whenever to save the overlay only. If set to True, save the overlay on top of downscaled version of the WSI. Defaults to False.
95+
filename (str): file will be saved in `output_dir`/`filename`
96+
9197
Returns:
9298
str: Path to the saved heatmap image.
9399
"""
94100

95101
if normalize:
102+
from scipy.stats import rankdata
96103
scores = rankdata(scores, 'average') / len(scores) * 100 / 100
97104

98-
downsample = wsi.level_downsamples[vis_level]
105+
if vis_mag is None:
106+
downsample = wsi.level_downsamples[vis_level]
107+
else:
108+
src_mag = wsi.mag
109+
downsample = src_mag / vis_mag
110+
if not overlay_only:
111+
vis_level, _ = wsi.get_best_level_and_custom_downsample(downsample)
112+
99113
scale = np.array([1 / downsample, 1 / downsample])
100114
region_size = tuple((np.array(wsi.level_dimensions[0]) * scale).astype(int))
101-
102115
overlay = create_overlay(scores, coords, patch_size_level0, scale, region_size)
116+
117+
overlay_colored = apply_colormap(overlay, cmap)
103118

104-
img = wsi.read_region((0, 0), vis_level, wsi.level_dimensions[vis_level]).convert("RGB")
105-
img = img.resize(region_size, resample=Image.Resampling.BICUBIC)
106-
img = np.array(img)
119+
if overlay_only:
120+
blended_img = overlay_colored
121+
else:
122+
img = wsi.read_region((0, 0), vis_level, wsi.level_dimensions[vis_level]).convert("RGB")
123+
img = img.resize(region_size, resample=Image.Resampling.BICUBIC)
124+
img = np.array(img)
125+
126+
blended_img = cv2.addWeighted(img, 0.6, overlay_colored, 0.4, 0)
107127

108-
overlay_colored = apply_colormap(overlay, cmap)
109-
blended_img = cv2.addWeighted(img, 0.6, overlay_colored, 0.4, 0)
110128
blended_img = Image.fromarray(blended_img)
111129

112130
os.makedirs(output_dir, exist_ok=True)
113-
heatmap_path = os.path.join(output_dir, "heatmap.png")
131+
heatmap_path = os.path.join(output_dir, filename)
114132
blended_img.save(heatmap_path)
115133

116134
if num_top_patches_to_save > 0:

trident/wsi_objects/CuCIMWSI.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,28 @@
88

99
class CuCIMWSI(WSI):
1010

11-
def __init__(self, **kwargs) -> None:
11+
def __init__(self, slide_path, **kwargs) -> None:
12+
"""
13+
Initialize a WSI instance using CuCIM as a backend.
14+
15+
Parameters
16+
----------
17+
slide_path : str
18+
Path to the WSI file.
19+
**kwargs : dict
20+
Keyword arguments forwarded to the base `WSI` class. Most important key is:
21+
- lazy_init (bool, default=True): Whether to defer loading WSI and metadata.
22+
23+
Please refer to WSI constructor for all parameters.
24+
25+
Example
26+
-------
27+
>>> wsi = CuCIMWSI(slide_path="path/to/wsi.svs", lazy_init=False)
28+
>>> print(wsi)
29+
<width=100000, height=80000, backend=CuCIMWSI, mpp=0.25, mag=40>
30+
"""
1231
self.img = None
13-
super().__init__(**kwargs)
32+
super().__init__(slide_path, **kwargs)
1433

1534
def _lazy_initialize(self) -> None:
1635
"""

trident/wsi_objects/ImageWSI.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
22
import numpy as np
33
from PIL import Image
4-
from typing import Tuple, Union
4+
from typing import List, Tuple, Union
55

66
from trident.wsi_objects.WSI import WSI, ReadMode
77

88

99
class ImageWSI(WSI):
1010

11-
def __init__(self, **kwargs) -> None:
11+
def __init__(self, slide_path, **kwargs) -> None:
1212
"""
1313
Initialize a WSI object from a standard image file (e.g., PNG, JPEG, etc.).
1414
@@ -30,7 +30,7 @@ def __init__(self, **kwargs) -> None:
3030
3131
Example
3232
-------
33-
>>> wsi = ImageWSI(slide_path="path/to/image.png", lazy_init=False, mpp=0.51)
33+
>>> wsi = ImageWSI("path/to/image.png", lazy_init=False, mpp=0.51)
3434
>>> print(wsi)
3535
<width=5120, height=3840, backend=ImageWSI, mpp=0.51, mag=20>
3636
"""
@@ -48,7 +48,7 @@ def __init__(self, **kwargs) -> None:
4848
PngImagePlugin.MAX_IMAGE_PIXELS = None # Optional: disables large image warning
4949

5050
self.img = None
51-
super().__init__(**kwargs)
51+
super().__init__(slide_path, **kwargs)
5252

5353
def _lazy_initialize(self) -> None:
5454
"""
@@ -177,28 +177,28 @@ def read_region(
177177
else:
178178
raise ValueError(f"Invalid `read_as` value: {read_as}. Must be 'pil' or 'numpy'.")
179179

180-
def segment_tissue(self, **kwargs):
181-
out = super().segment_tissue(**kwargs)
180+
def segment_tissue(self, *args, **kwargs):
181+
out = super().segment_tissue(*args, **kwargs)
182182
self.close()
183183
return out
184184

185-
def extract_tissue_coords(self, **kwargs):
186-
out = super().extract_tissue_coords(**kwargs)
185+
def extract_tissue_coords(self, *args, **kwargs):
186+
out = super().extract_tissue_coords(*args, **kwargs)
187187
self.close()
188188
return out
189189

190-
def visualize_coords(self, **kwargs):
191-
out = super().visualize_coords(**kwargs)
190+
def visualize_coords(self, *args, **kwargs):
191+
out = super().visualize_coords(*args, **kwargs)
192192
self.close()
193193
return out
194194

195-
def extract_patch_features(self, **kwargs):
196-
out = super().extract_patch_features(**kwargs)
195+
def extract_patch_features(self, *args, **kwargs):
196+
out = super().extract_patch_features(*args, **kwargs)
197197
self.close()
198198
return out
199199

200-
def extract_slide_features(self, **kwargs):
201-
out = super().extract_slide_features(**kwargs)
200+
def extract_slide_features(self, *args, **kwargs):
201+
out = super().extract_slide_features(*args, **kwargs)
202202
self.close()
203203
return out
204204

trident/wsi_objects/OpenSlideWSI.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99

1010
class OpenSlideWSI(WSI):
1111

12-
def __init__(self, **kwargs) -> None:
12+
def __init__(self, slide_path, **kwargs) -> None:
1313
"""
1414
Initialize an OpenSlideWSI instance.
1515
1616
Parameters
1717
----------
18+
slide_path : str
19+
Path to the WSI file.
1820
**kwargs : dict
1921
Keyword arguments forwarded to the base `WSI` class. Most important key is:
20-
- slide_path (str): Path to the WSI.
2122
- lazy_init (bool, default=True): Whether to defer loading WSI and metadata.
2223
2324
Please refer to WSI constructor for all parameters.
@@ -28,7 +29,7 @@ def __init__(self, **kwargs) -> None:
2829
>>> print(wsi)
2930
<width=100000, height=80000, backend=OpenSlideWSI, mpp=0.25, mag=40>
3031
"""
31-
super().__init__(**kwargs)
32+
super().__init__(slide_path, **kwargs)
3233

3334
def _lazy_initialize(self) -> None:
3435
"""
@@ -244,4 +245,4 @@ def get_thumbnail(self, size: tuple[int, int]) -> Image.Image:
244245
PIL.Image.Image
245246
RGB thumbnail as a PIL Image.
246247
"""
247-
return self.img.get_thumbnail(size).convert('RGB')
248+
return self.img.get_thumbnail(size).convert('RGB')

0 commit comments

Comments
 (0)