
Commit 59a6c0e

Merge pull request #62 from ema-frasca/master
Added CaSpeR, CSCCT and ZSCL methods
2 parents f9f443a + f4aa002 commit 59a6c0e

32 files changed: +3109 additions, −6 deletions

README.md

Lines changed: 3 additions & 0 deletions

@@ -53,10 +53,12 @@ Mammoth currently supports **more than 50** models, with new releases covering t
 - Efficient Lifelong Learning with A-GEM (A-GEM, A-GEM-R - A-GEM with reservoir buffer): `agem`, `agem_r`.
 - AttriCLIP: A Non-Incremental Learner for Incremental Knowledge Learning (AttriCLIP): `attriclip`.
 - Bias Correction (BiC): `bic`.
+- CaSpeR-IL (on DER++, X-DER with RPC, iCaRL, and ER-ACE): `derpp_casper`, `xder_rpc_casper`, `icarl_casper`, `er_ace_casper`.
 - Continual Contrastive Interpolation Consistency (CCIC) - _Requires_ `pip install kornia`: `ccic`.
 - Continual Generative training for Incremental prompt-Learning (CGIL): `cgil`
 - Contrastive Language-Image Pre-Training (CLIP): `clip` (*static* method with no learning).
 - CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning (CODA-Prompt) - _Requires_ `pip install timm==0.9.8`: `coda-prompt`.
+- CSCCT (on DER++, X-DER with RPC, iCaRL, and ER-ACE): `derpp_cscct`, `xder_rpc_cscct`, `icarl_cscct`, `er_ace_cscct`.
 - Generating Instance-level Prompts for Rehearsal-free Continual Learning (DAP): `dap`.
 - Dark Experience for General Continual Learning: a Strong, Simple Baseline (DER & DER++): `der` and `derpp`.
 - DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning (DualPrompt) - _Requires_ `pip install timm==0.9.8`: `dualprompt`.

@@ -92,6 +94,7 @@ Mammoth currently supports **more than 50** models, with new releases covering t
 - Semantic Two-level Additive Residual Prompt (STAR-Prompt): `starprompt`. Also includes the first-stage only (`first_stage_starprompt`) and second-stage only (`second_stage_starprompt`) versions.
 - Transfer without Forgetting (TwF): `twf`.
 - eXtended-DER (X-DER): `xder` (full version), `xder_ce` (X-DER with CE), `xder_rpc` (X-DER with RPC).
+- ZSCL: Zero-Shot Continual Learning: `zscl`.

 ## Datasets

datasets/utils/continual_dataset.py

Lines changed: 1 addition & 1 deletion

@@ -320,7 +320,7 @@ def get_data_loaders(self) -> Tuple[DataLoader, DataLoader]:
         """
         raise NotImplementedError

-    def get_backbone() -> str:
+    def get_backbone(self) -> str:
         """Returns the name of the backbone to be used for the current dataset. This can be changed using the `--backbone` argument or by setting it in the `dataset_config`."""
         raise NotImplementedError
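With `self` added, `get_backbone` becomes a proper instance method that concrete datasets can override. A minimal sketch of such an override, assuming the `ContinualDataset` base class above (the dataset class name and backbone string are hypothetical, not part of this commit):

```python
# Hypothetical sketch: SequentialMyData and 'resnet18' are illustrative only.
from datasets.utils.continual_dataset import ContinualDataset


class SequentialMyData(ContinualDataset):
    def get_backbone(self) -> str:
        # default backbone for this dataset; still overridable via --backbone
        return 'resnet18'
```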

docs/utils/index.rst

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ Other arguments such as the size of the training batch and the number of epochs

 .. code-block:: bash

-    python utils/main.py --dataset seq-cifar10 --model der --buffer_size 500 --lr 0.03 --batch_size 128 --epochs 10
+    python utils/main.py --dataset seq-cifar10 --model der --buffer_size 500 --lr 0.03 --batch_size 128 --n_epochs 10

 .. note::
     To ease hyper-parameter tuning, all boolean arguments follow the convention: ``--<argument>=1`` for ``True`` and ``--<argument>=0`` for ``False``.

models/casper_utils/casper_model.py

Lines changed: 53 additions & 0 deletions (new file)

import torch

from models.utils.continual_model import ContinualModel
from utils.buffer import Buffer

from .spectral_analysis import calc_ADL_knn, calc_euclid_dist, find_eigs, normalize_A


class CasperModel(ContinualModel):

    @staticmethod
    def add_casper_args(parser):
        parser.add_argument('--casper_batch', type=int, default=None,
                            help='Size of minibatch for casper. Equal to batch_size by default; if negative, equal to buffer_size.')
        parser.add_argument('--rho', type=float, default=0.01, help='Weight for casper loss.')
        parser.add_argument('--knn_laplace', type=int, default=10, help='K of knn to build the graph for the laplacian.')
        parser.add_argument('--p', default=None, type=int, help='Number of classes to be drawn from the buffer. Default is N_CLASSES_PER_TASK.')
        return parser

    def __init__(self, backbone, loss, args, transform, dataset=None):
        assert 'buffer_size' in args, 'The model requires a buffer'
        if args.casper_batch is None:
            args.casper_batch = args.batch_size
        if args.casper_batch < 0:
            args.casper_batch = args.buffer_size
        super().__init__(backbone, loss, args, transform, dataset)

        self.buffer = Buffer(self.args.buffer_size, device=self.device, sample_selection_strategy='balancoir')

        # number of classes expected in a casper batch (defaults to classes per task)
        self.nc = self.args.p if self.args.p is not None else self.cpt

    def get_casper_loss(self):
        if self.args.rho == 0:
            return torch.tensor(0., dtype=torch.float, device=self.device)
        if self.args.casper_batch == self.args.buffer_size:
            buffer_data = self.buffer.get_all_data(transform=self.transform)
        else:
            buffer_data = self.buffer.get_balanced_data(self.args.casper_batch, transform=self.transform, n_classes=self.nc)
        inputs, labels = buffer_data[0], buffer_data[1]
        features = self.net.features(inputs.to(self.device))

        # build a k-NN affinity graph from pairwise euclidean distances
        dists = calc_euclid_dist(features)
        A, D, L = calc_ADL_knn(dists, k=self.args.knn_laplace, symmetric=True)

        # symmetrically normalized laplacian: I - D^{-1/2} A D^{-1/2}
        L = torch.eye(A.shape[0], device=A.device) - normalize_A(A, D)

        # shrink the first n+1 eigenvalues while pushing up the next one,
        # i.e. widen the eigen-gap at the number of classes
        n = self.nc
        evals, _ = find_eigs(L, n_pairs=min(2 * n, len(L)))

        return evals[:n + 1].sum() - evals[n + 1]
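`get_casper_loss` returns the unweighted regularizer; applying the `--rho` weight is left to the caller. As a hedged sketch (not code from this commit), a derived method might fold it into its training step roughly as follows, assuming Mammoth's usual `observe` signature and `opt`/`buffer` attributes:

```python
# Illustrative only: how a subclass could combine its base objective with
# the casper regularizer, weighted by --rho. Not taken from this commit.
class ErAceCasperSketch(CasperModel):
    def observe(self, inputs, labels, not_aug_inputs, epoch=None):
        self.opt.zero_grad()
        outputs = self.net(inputs)
        loss = self.loss(outputs, labels)                      # base loss
        loss = loss + self.args.rho * self.get_casper_loss()   # spectral term
        loss.backward()
        self.opt.step()
        self.buffer.add_data(examples=not_aug_inputs, labels=labels)
        return loss.item()
```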

models/casper_utils/knn.py

Lines changed: 103 additions & 0 deletions (new file)

'''
Author: Tobias Plötz, TU Darmstadt ([email protected])

This file is part of the implementation as described in the NIPS 2018 paper:
Tobias Plötz and Stefan Roth, Neural Nearest Neighbors Networks.
Please see the file LICENSE.txt for the license governing this code.
'''
import math
from math import log

import torch
import torch.nn as nn
import torch.nn.functional as F


def log1mexp(x, expm1_guard=1e-7):
    # computes log(1 - exp(x)) in a numerically stable way;
    # see https://cran.r-project.org/package=Rmpfr/.../log1mexp-note.pdf
    t = x < math.log(0.5)
    y = torch.zeros_like(x)
    y[t] = torch.log1p(-x[t].exp())

    # for x close to 0 we need expm1 for a numerically stable computation;
    # we furthermore modify the backward pass to avoid unstable gradients,
    # i.e. situations where the incoming output gradient is close to 0
    # and the gradient of expm1 is very large
    expxm1 = torch.expm1(x[~t])
    log1mexp_fw = (-expxm1).log()
    log1mexp_bw = (-expxm1 + expm1_guard).log()  # limits magnitude of gradient

    y[~t] = log1mexp_fw.detach() + (log1mexp_bw - log1mexp_bw.detach())
    return y


class NeuralNearestNeighbors(nn.Module):
    r"""
    Computes neural nearest neighbor volumes based on pairwise distances
    """

    def __init__(self, k, temp_opt={}):
        r"""
        :param k: Number of neighbor volumes to compute
        :param temp_opt: temperature options:
            external_temp: Whether temperature is given as external input
                rather than fixed parameter
            temp_bias: A fixed bias to add to the log temperature
            distance_bn: Whether to put distances through a batchnorm layer
        """
        super(NeuralNearestNeighbors, self).__init__()
        self.external_temp = temp_opt.get("external_temp")
        self.log_temp_bias = log(temp_opt.get("temp_bias", 1))
        distance_bn = temp_opt.get("distance_bn")

        if not self.external_temp:
            self.log_temp = nn.Parameter(torch.FloatTensor(1).fill_(0.0))
        if distance_bn:
            self.bn = nn.BatchNorm1d(1)
        else:
            self.bn = None

        self.k = k

    def forward(self, D, log_temp=None):
        b, m, o = D.shape
        if self.bn is not None:
            D = self.bn(D.view(b, 1, m * o)).view(D.shape)

        if self.external_temp:
            log_temp = log_temp.view(D.shape[0], D.shape[1], -1)
        else:
            log_temp = self.log_temp.view(1, 1, 1)

        log_temp = log_temp + self.log_temp_bias

        temperature = log_temp.exp()
        if self.training:
            M = D.data > -float("Inf")
            if len(temperature) > 1:
                D[M] /= temperature.expand_as(D)[M]
            else:
                D[M] = D[M] / temperature[0, 0, 0]
        else:
            D /= temperature

        logits = D.view(D.shape[0] * D.shape[1], -1)

        samples_arr = []

        for r in range(self.k):
            # Eqs. 8 and 10
            weights = F.log_softmax(logits, dim=1)
            weights_exp = weights.exp()

            samples_arr.append(weights_exp.view(b, m, o))

            # Eq. 9: remove the mass of the selected neighbor before the next draw
            logits = logits + log1mexp(weights.view(*logits.shape))

        W = torch.stack(samples_arr, dim=3)

        return W
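`calc_ADL_knn` in the next file calls this module as `knn(-dists.unsqueeze(0))` to obtain differentiable neighbor weights. A small shape-check sketch on random data (illustrative only):

```python
import torch
from models.casper_utils.knn import NeuralNearestNeighbors

x = torch.randn(16, 8)                                     # 16 points, 8-d
dists = ((x.unsqueeze(0) - x.unsqueeze(1)) ** 2).sum(-1)   # (16, 16)

knn = NeuralNearestNeighbors(k=5)
W = knn(-dists.unsqueeze(0))   # negated distances act as logits
print(W.shape)                 # torch.Size([1, 16, 16, 5]): one soft
                               # neighbor distribution per point per slot
```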
models/casper_utils/spectral_analysis.py

Lines changed: 181 additions & 0 deletions (new file)

import math

import torch
from xitorch import LinearOperator
from xitorch.linalg import symeig

from .knn import NeuralNearestNeighbors


def calc_ADL_from_dist(dist_matrix: torch.Tensor, sigma=1.):
    # compute affinity matrix (heat kernel)
    A = torch.exp(-dist_matrix / (sigma ** 2))
    # compute degree matrix
    D = torch.diag(A.sum(1))
    # compute laplacian
    L = D - A
    return A, D, L


def calc_euclid_dist(data: torch.Tensor):
    return ((data.unsqueeze(0) - data.unsqueeze(1)) ** 2).sum(-1)


def calc_cos_dist(data):
    return -torch.cosine_similarity(data.unsqueeze(0), data.unsqueeze(1), dim=-1)


def calc_dist_weiss(nu: torch.Tensor, logvar: torch.Tensor):
    var = logvar.exp()
    edist = calc_euclid_dist(nu)
    wdiff = (var.unsqueeze(0) + var.unsqueeze(1) - 2 * (torch.sqrt(var.unsqueeze(0) * var.unsqueeze(1)))).sum(-1)
    return edist + wdiff


def calc_ADL_heat(dist_matrix: torch.Tensor, sigma=1.):
    # compute affinity matrix (heat kernel with data-dependent bandwidth)
    A = torch.exp(-dist_matrix / (dist_matrix.mean().detach()))
    # compute degree matrix
    d_values = A.sum(1)
    assert not (d_values == 0).any(), f'D contains zeros in diag: \n{d_values}'
    D = torch.diag(d_values)
    # compute laplacian
    L = D - A
    return A, D, L


def calc_ADL_knn(distances: torch.Tensor, k: int, symmetric: bool = True):
    new_A = torch.clone(distances)
    new_A[torch.eye(len(new_A)).bool()] = +math.inf

    knn = NeuralNearestNeighbors(k)

    # hard k-NN adjacency from the k smallest distances per row
    final_A = torch.zeros_like(new_A)
    idxes = new_A.topk(k, largest=False)[1]
    final_A[torch.arange(len(idxes)).unsqueeze(1), idxes] = 1
    # backpropagation trick: soft neighbor weights carry the gradient
    w = knn(-new_A.unsqueeze(0)).squeeze().sum(-1)
    if symmetric:
        final_A = ((final_A + final_A.T) > 0).float()
        w = w + w.T

    # forward pass uses the hard adjacency, backward pass the soft weights
    A = final_A.detach() + (w - w.detach())

    # compute degree matrix
    d_values = A.sum(1)
    assert not (d_values == 0).any(), f'D contains zeros in diag: \n{d_values}'
    D = torch.diag(d_values)
    # compute laplacian
    L = D - A
    return A, D, L


def calc_ADL(data: torch.Tensor, sigma=1.):
    return calc_ADL_from_dist(calc_euclid_dist(data), sigma)


def find_eigs(laplacian: torch.Tensor, n_pairs: int = 0, largest=False):
    if n_pairs > 0:
        eigenvalues, eigenvectors = symeig(LinearOperator.m(laplacian, True), n_pairs)
    else:
        eigenvalues, eigenvectors = torch.linalg.eigh(laplacian)
    sorted_indices = torch.argsort(eigenvalues, descending=largest)
    eigenvalues, eigenvectors = eigenvalues[sorted_indices], eigenvectors[:, sorted_indices]

    return eigenvalues, eigenvectors


def calc_energy_from_values(values: torch.Tensor, norm=False):
    nsamples = len(values)
    max_value = nsamples - 1 if norm else nsamples * (nsamples - 1)
    dir_energy = values.sum()
    energy_p = dir_energy / max_value
    return energy_p.cpu().item()


def normalize_A(A, D):
    inv_d = torch.diag(D[torch.eye(len(D)).bool()].pow(-0.5))
    assert not torch.isinf(inv_d).any(), 'D^-0.5 contains inf'
    return inv_d @ A @ inv_d


def dir_energy_normal(data: torch.Tensor, sigma=1.):
    A, D, L = calc_ADL(data, sigma)
    L_norm = torch.eye(A.shape[0]).to(data.device) - normalize_A(A, D)
    eigenvalues, eigenvectors = find_eigs(L_norm)
    energy = calc_energy_from_values(eigenvalues, norm=True)
    return energy, eigenvalues, eigenvectors


def dir_energy(data: torch.Tensor, sigma=1):
    A, D, L = calc_ADL(data, sigma=sigma)
    eigenvalues, eigenvectors = find_eigs(L)
    energy = calc_energy_from_values(eigenvalues)
    return energy


def laplacian_analysis(data: torch.Tensor, sigma=1., knn=0, logvars: torch.Tensor = None,
                       norm_lap=False, norm_eigs=False, n_pairs=0):
    if logvars is None:
        distances = calc_euclid_dist(data)
    else:
        distances = calc_dist_weiss(data, logvars)
    if knn > 0:
        A, D, L = calc_ADL_knn(distances, knn, symmetric=True)
    else:
        A, D, L = calc_ADL_from_dist(distances, sigma)
    if norm_lap:
        L = torch.eye(A.shape[0]).to(data.device) - normalize_A(A, D)
    eigenvalues, eigenvectors = find_eigs(L, n_pairs=n_pairs)
    energy = calc_energy_from_values(eigenvalues, norm=norm_lap)
    if norm_eigs and not norm_lap:
        eigenvalues = eigenvalues / (len(eigenvalues))
    return energy, eigenvalues, eigenvectors, L, (A, D, distances)


class LOBPCG2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A: torch.Tensor, k: int):
        e, v = torch.lobpcg(A, k=k, largest=False)
        res = (A @ v) - (v @ torch.diag(e))
        assert (res.abs() < 1e-3).all(), 'A v != e v => incorrect eigenpairs'
        ctx.save_for_backward(e, v, A)
        return e, v

    @staticmethod
    def backward(ctx, de, dv):
        """
        solve `dA v + A dv = dv diag(e) + v diag(de)` for `dA`
        """
        e, v, A = ctx.saved_tensors

        vt = v.transpose(-2, -1)
        rhs = ((dv @ torch.diag(e)) + (v @ torch.diag(de)) - (A @ dv)).transpose(-2, -1)

        n, k = v.shape
        K = vt[:, :vt.shape[0]]
        iK = K.inverse()

        dAt = torch.zeros((n, n), device=rhs.device)
        dAt[:k] = (iK @ rhs)[:k]
        dA = dAt.transpose(-2, -1)

        return dA, None
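End to end, these helpers reproduce the pipeline `CasperModel.get_casper_loss` runs on buffer features: pairwise distances → k-NN graph → normalized laplacian → smallest eigenvalues → eigen-gap. A compact sketch on random features (sizes illustrative; requires `xitorch` when `n_pairs > 0`):

```python
import torch
from models.casper_utils.spectral_analysis import (
    calc_ADL_knn, calc_euclid_dist, find_eigs, normalize_A)

feats = torch.randn(64, 128)                     # e.g. 64 buffer samples
dists = calc_euclid_dist(feats)                  # (64, 64) squared distances
A, D, L = calc_ADL_knn(dists, k=10, symmetric=True)
L_norm = torch.eye(A.shape[0]) - normalize_A(A, D)

n = 10                                           # classes expected in the batch
evals, _ = find_eigs(L_norm, n_pairs=min(2 * n, len(L_norm)))
casper_reg = evals[:n + 1].sum() - evals[n + 1]  # the eigen-gap term
```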
