Skip to content

Commit ef59f58

Browse files
committed
added GPU support
1 parent f64f589 commit ef59f58

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

torchstain/normalizers/torch_macenko_normalizer.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,15 @@ def __find_HE(self, ODhat, eigvecs, alpha):
3333
That = torch.matmul(ODhat, eigvecs)
3434
phi = torch.atan2(That[:, 1], That[:, 0])
3535

36-
minPhi = torch.tensor(percentile(phi, alpha))
37-
maxPhi = torch.tensor(percentile(phi, 100 - alpha))
36+
minPhi = percentile(phi, alpha)
37+
maxPhi = percentile(phi, 100 - alpha)
3838

3939
vMin = torch.matmul(eigvecs, torch.stack((torch.cos(minPhi), torch.sin(minPhi))).T).unsqueeze(1)
4040
vMax = torch.matmul(eigvecs, torch.stack((torch.cos(maxPhi), torch.sin(maxPhi))).T).unsqueeze(1)
4141

4242
# a heuristic to make the vector corresponding to hematoxylin first and the
4343
# one corresponding to eosin second
44-
if vMin[0] > vMax[0]:
45-
HE = torch.cat((vMin, vMax), dim=1)
46-
47-
else:
48-
HE = torch.cat((vMax, vMin), dim=1)
44+
HE = torch.where(vMin[0] > vMax[0], torch.cat((vMin, vMax), dim=1), torch.cat((vMax, vMin), dim=1))
4945

5046
return HE
5147

@@ -66,7 +62,7 @@ def __compute_matrices(self, I, Io, alpha, beta):
6662
HE = self.__find_HE(ODhat, eigvecs, alpha)
6763

6864
C = self.__find_concentration(OD, HE)
69-
maxC = torch.tensor([percentile(C[0, :], 99), percentile(C[1, :], 99)])
65+
maxC = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
7066

7167
return HE, C, maxC
7268

torchstain/utils/percentile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ def percentile(t: torch.tensor, q: float) -> Union[int, float]:
2121
# indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
2222
# so that ``round()`` returns an integer, even if q is a np.float32.
2323
k = 1 + round(.01 * float(q) * (t.numel() - 1))
24-
result = t.view(-1).kthvalue(k).values.item()
24+
result = t.view(-1).kthvalue(k).values
2525
return result

0 commit comments

Comments
 (0)