import math

import torch
from xitorch import LinearOperator
from xitorch.linalg import symeig

from .knn import NeuralNearestNeighbors


def calc_ADL_from_dist(dist_matrix: torch.Tensor, sigma=1.):
    # affinity matrix via the heat kernel exp(-d / sigma^2)
    A = torch.exp(-dist_matrix / (sigma ** 2))
    # degree matrix: row sums of A on the diagonal
    D = torch.diag(A.sum(1))
    # unnormalized graph Laplacian
    L = D - A
    return A, D, L
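

# Property sketch (illustrative addition, not part of the original module):
# by construction L = D - A has zero row sums, so the constant vector lies in
# its null space. The sizes below are arbitrary assumptions.
def _demo_laplacian_rowsum(n: int = 12, d: int = 3):
    A, D, L = calc_ADL_from_dist(calc_euclid_dist(torch.randn(n, d)))
    assert torch.allclose(L.sum(1), torch.zeros(n))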


def calc_euclid_dist(data: torch.Tensor):
    # pairwise squared Euclidean distances, shape (n, n)
    return ((data.unsqueeze(0) - data.unsqueeze(1)) ** 2).sum(-1)


def calc_cos_dist(data: torch.Tensor):
    # negative cosine similarity used as a distance surrogate (range [-1, 1])
    return -torch.cosine_similarity(data.unsqueeze(0), data.unsqueeze(1), dim=-1)


def calc_dist_weiss(nu: torch.Tensor, logvar: torch.Tensor):
    # squared 2-Wasserstein distance between diagonal Gaussians:
    # ||mu_i - mu_j||^2 + ||sigma_i - sigma_j||^2
    var = logvar.exp()
    edist = calc_euclid_dist(nu)
    wdiff = (var.unsqueeze(0) + var.unsqueeze(1) - 2 * (torch.sqrt(var.unsqueeze(0) * var.unsqueeze(1)))).sum(-1)
    return edist + wdiff
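

# Property sketch (illustrative addition): with equal variances the variance
# term vanishes and `calc_dist_weiss` reduces to the squared Euclidean
# distance between the means.
def _demo_dist_weiss(n: int = 10, d: int = 4):
    mu = torch.randn(n, d)
    logvar = torch.zeros(n, d)  # unit variance everywhere
    assert torch.allclose(calc_dist_weiss(mu, logvar), calc_euclid_dist(mu), atol=1e-5)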


def calc_ADL_heat(dist_matrix: torch.Tensor, sigma=1.):
    # affinity matrix via the heat kernel; the bandwidth is set adaptively to
    # the mean distance (note: `sigma` is currently unused here)
    A = torch.exp(-dist_matrix / (dist_matrix.mean().detach()))
    # degree matrix
    d_values = A.sum(1)
    assert not (d_values == 0).any(), f'D contains zeros in diag: \n{d_values}'
    D = torch.diag(d_values)
    # graph Laplacian
    L = D - A
    return A, D, L


def calc_ADL_knn(distances: torch.Tensor, k: int, symmetric: bool = True):
    new_A = torch.clone(distances)
    # exclude self-loops by pushing the diagonal to +inf before the top-k
    new_A[torch.eye(len(new_A), device=new_A.device).bool()] = math.inf

    knn = NeuralNearestNeighbors(k)

    # hard k-NN adjacency: 1 for the k nearest neighbours of each row
    final_A = torch.zeros_like(new_A)
    idxes = new_A.topk(k, largest=False)[1]
    final_A[torch.arange(len(idxes), device=final_A.device).unsqueeze(1), idxes] = 1
    # soft (differentiable) neighbour weights from the relaxed k-NN selection
    w = knn(-new_A.unsqueeze(0)).squeeze().sum(-1)
    if symmetric:
        final_A = ((final_A + final_A.T) > 0).float()
        w = w + w.T

    # straight-through estimator: the forward pass uses the hard 0/1 adjacency,
    # the backward pass uses the gradient of the soft weights
    A = final_A.detach() + (w - w.detach())

    # degree matrix
    d_values = A.sum(1)
    assert not (d_values == 0).any(), f'D contains zeros in diag: \n{d_values}'
    D = torch.diag(d_values)
    # graph Laplacian
    L = D - A
    return A, D, L
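

# Usage sketch (illustrative addition): builds a k-NN graph from random
# features and checks that gradients reach the inputs through the hard 0/1
# adjacency via the straight-through trick above, assuming the
# NeuralNearestNeighbors relaxation is differentiable. Sizes and `k` are
# arbitrary assumptions.
def _demo_knn_graph(n: int = 32, d: int = 16, k: int = 5):
    x = torch.randn(n, d, requires_grad=True)
    A, D, L = calc_ADL_knn(calc_euclid_dist(x), k, symmetric=True)
    L.sum().backward()  # differentiable through the soft weights `w`
    assert x.grad is not None and A.shape == (n, n)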


def calc_ADL(data: torch.Tensor, sigma=1.):
    return calc_ADL_from_dist(calc_euclid_dist(data), sigma)


def find_eigs(laplacian: torch.Tensor, n_pairs: int = 0, largest=False):
    if n_pairs > 0:
        # iterative solver for the n_pairs extremal eigenpairs
        # (alternative: LOBPCG2.apply(laplacian, n_pairs), defined below)
        eigenvalues, eigenvectors = symeig(LinearOperator.m(laplacian, True), n_pairs)
    else:
        # dense solve: all eigenpairs of the symmetric Laplacian
        eigenvalues, eigenvectors = torch.linalg.eigh(laplacian)
    sorted_indices = torch.argsort(eigenvalues, descending=largest)
    eigenvalues, eigenvectors = eigenvalues[sorted_indices], eigenvectors[:, sorted_indices]

    return eigenvalues, eigenvectors
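

# Sanity sketch (illustrative addition): for the dense branch the returned
# pairs should satisfy L v = lambda v up to tolerance, and a graph Laplacian
# always has a (near-)zero smallest eigenvalue.
def _demo_find_eigs(n: int = 16, d: int = 4):
    _, _, L = calc_ADL_from_dist(calc_euclid_dist(torch.randn(n, d)))
    evals, evecs = find_eigs(L)
    residual = L @ evecs - evecs @ torch.diag(evals)
    assert residual.abs().max() < 1e-3 and evals[0].abs() < 1e-3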


def calc_energy_from_values(values: torch.Tensor, norm=False):
    # Dirichlet energy: sum of Laplacian eigenvalues, rescaled by an upper
    # bound on that sum (n - 1 for the normalized Laplacian, n(n - 1) for the
    # unnormalized one with affinities in [0, 1])
    nsamples = len(values)
    max_value = nsamples - 1 if norm else nsamples * (nsamples - 1)
    dir_energy = values.sum()
    energy_p = dir_energy / max_value
    return energy_p.cpu().item()


def normalize_A(A, D):
    # symmetric normalization: D^-1/2 A D^-1/2
    inv_d = torch.diag(D[torch.eye(len(D), device=D.device).bool()].pow(-0.5))
    assert not torch.isinf(inv_d).any(), 'D^-0.5 contains inf'
    return inv_d @ A @ inv_d


def dir_energy_normal(data: torch.Tensor, sigma=1.):
    A, D, L = calc_ADL(data, sigma)
    # normalized Laplacian: I - D^-1/2 A D^-1/2
    L_norm = torch.eye(A.shape[0]).to(data.device) - normalize_A(A, D)
    eigenvalues, eigenvectors = find_eigs(L_norm)
    energy = calc_energy_from_values(eigenvalues, norm=True)
    return energy, eigenvalues, eigenvectors


def dir_energy(data: torch.Tensor, sigma=1.):
    A, D, L = calc_ADL(data, sigma=sigma)
    eigenvalues, eigenvectors = find_eigs(L)
    energy = calc_energy_from_values(eigenvalues)
    return energy


def laplacian_analysis(data: torch.Tensor, sigma=1., knn=0, logvars: torch.Tensor = None,
                       norm_lap=False, norm_eigs=False, n_pairs=0):
    if logvars is None:
        distances = calc_euclid_dist(data)
    else:
        distances = calc_dist_weiss(data, logvars)
    if knn > 0:
        A, D, L = calc_ADL_knn(distances, knn, symmetric=True)
    else:
        A, D, L = calc_ADL_from_dist(distances, sigma)
    if norm_lap:
        L = torch.eye(A.shape[0]).to(data.device) - normalize_A(A, D)
    eigenvalues, eigenvectors = find_eigs(L, n_pairs=n_pairs)
    energy = calc_energy_from_values(eigenvalues, norm=norm_lap)
    if norm_eigs and not norm_lap:
        eigenvalues = eigenvalues / len(eigenvalues)
    return energy, eigenvalues, eigenvectors, L, (A, D, distances)
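

# End-to-end sketch (illustrative addition): runs the full analysis on random
# features with a k-NN graph and a normalized Laplacian; eigenvalues of the
# symmetric normalized Laplacian are known to lie in [0, 2]. Sizes and `knn`
# are arbitrary assumptions.
def _demo_laplacian_analysis(n: int = 64, d: int = 8):
    x = torch.randn(n, d)
    energy, evals, evecs, L, (A, D, dist) = laplacian_analysis(x, knn=5, norm_lap=True)
    assert evals.min() > -1e-4 and evals.max() < 2 + 1e-4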


class LOBPCG2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A: torch.Tensor, k: int):
        e, v = torch.lobpcg(A, k=k, largest=False)
        res = (A @ v) - (v @ torch.diag(e))
        assert (res.abs() < 1e-3).all(), 'A v != e v => incorrect eigenpairs'
        ctx.save_for_backward(e, v, A)
        return e, v

    @staticmethod
    def backward(ctx, de, dv):
        """
        Solve `dA v + A dv = dv diag(e) + v diag(de)` for `dA`.
        The system is underdetermined: only the first k rows of dA^T are
        recovered from the top k x k block of v^T; the rest are left zero.
        """
        e, v, A = ctx.saved_tensors

        vt = v.transpose(-2, -1)
        rhs = ((dv @ torch.diag(e)) + (v @ torch.diag(de)) - (A @ dv)).transpose(-2, -1)

        n, k = v.shape
        # top k x k block of v^T, assumed invertible (its determinant should be nonzero)
        K = vt[:, :k]
        iK = K.inverse()

        dAt = torch.zeros((n, n), device=rhs.device, dtype=rhs.dtype)
        dAt[:k] = iK @ rhs
        dA = dAt.transpose(-2, -1)

        return dA, None
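

# Gradient sketch (illustrative addition): LOBPCG2 implements an approximate
# custom backward for the k smallest eigenpairs. This only checks that a
# gradient of the right shape flows back on a well-conditioned symmetric
# matrix; double precision keeps the internal residual assert comfortable.
def _demo_lobpcg2(n: int = 20, k: int = 3):
    M = torch.randn(n, n, dtype=torch.float64)
    A = ((M + M.T) / 2 + n * torch.eye(n, dtype=torch.float64)).requires_grad_(True)
    e, v = LOBPCG2.apply(A, k)
    e.sum().backward()
    assert A.grad is not None and A.grad.shape == (n, n)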