Merged
27 commits (the diff below shows changes from 6 of them)
b8fd539
init commit
Badr-MOUFAD Aug 24, 2022
c2aecba
gram solver && unit test
Badr-MOUFAD Aug 24, 2022
507fc8a
fix bug gram solver && tighten test
Badr-MOUFAD Aug 24, 2022
c9b64c2
add anderson acceleration
Badr-MOUFAD Aug 24, 2022
20c1911
bug ``stop_criter`` && refactor
Badr-MOUFAD Aug 24, 2022
f2e985d
refactoring of var names
Badr-MOUFAD Aug 25, 2022
2dbc8e4
handle ``w_init``
Badr-MOUFAD Aug 25, 2022
8ca7a41
refactor ``_gram_cd_``
Badr-MOUFAD Aug 25, 2022
3453233
gram epoch greedy and cyclic strategy
Badr-MOUFAD Aug 25, 2022
8d3dbc1
extend to sparse case && unitest
Badr-MOUFAD Aug 25, 2022
cdd7e34
one implementation of _gram_cd && unittest
Badr-MOUFAD Aug 25, 2022
f4bfeaf
greedy_cd arg instead of cd_strategy
Badr-MOUFAD Aug 25, 2022
95cf1d4
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Aug 25, 2022
4c0acca
add docs
Badr-MOUFAD Aug 25, 2022
dcab054
script fast gram, not faster than scipy
mathurinm Aug 25, 2022
e8bc96e
fast gram timing
Badr-MOUFAD Aug 25, 2022
61a67c4
keep grads instead
Badr-MOUFAD Aug 25, 2022
1b6c169
refactor ``chosen_j``
Badr-MOUFAD Aug 25, 2022
c9c5575
script to profile
Badr-MOUFAD Aug 25, 2022
68a0458
potential improvements, docstring
mathurinm Aug 26, 2022
3788cc4
warnings.warn arguments in correct order
mathurinm Aug 26, 2022
1ce391d
cleanups: ann files
Badr-MOUFAD Aug 26, 2022
2476a34
fix ``p_obj`` computation
Badr-MOUFAD Aug 26, 2022
0f766e9
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Aug 26, 2022
3208dfa
typos + less cases in test, smaller X in tests
mathurinm Aug 26, 2022
16f6ee4
typo: ``XtXw`` --> ``grad``
Badr-MOUFAD Aug 26, 2022
e9b7224
Merge branch 'gram-solver' of https://github.com/Badr-MOUFAD/skglm in…
Badr-MOUFAD Aug 26, 2022
2 changes: 1 addition & 1 deletion skglm/estimators.py
@@ -1142,7 +1142,7 @@ def path(self, yXT, y, Cs, coef_init=None, return_n_iter=True, **params):
Target vector relative to X.

Cs : ndarray shape (n_Cs,)
-            Values of regularization strenghts for which solutions are
+            Values of regularization strengths for which solutions are
computed.

coef_init : array, shape (n_features,), optional
89 changes: 89 additions & 0 deletions skglm/solvers/gram_cd.py
@@ -0,0 +1,89 @@
import numpy as np
from numba import njit
from skglm.utils import AndersonAcceleration


def gram_cd_solver(X, y, penalty, max_iter=20, use_acc=True, tol=1e-4, verbose=False):
"""Run coordinate descent while keeping the gradients up-to-date with Gram updates.

Minimize::
1 / (2*n_samples) * norm(y - Xw)**2 + penalty(w)

    Which can be rewritten as::
        w.T @ Q @ w / (2*n_samples) - q.T @ w / n_samples
        + norm(y)**2 / (2*n_samples) + penalty(w)

    where::
        Q = X.T @ X (Gram matrix)
        q = X.T @ y
"""
n_samples, n_features = X.shape
scaled_gram = X.T @ X / n_samples
scaled_Xty = X.T @ y / n_samples
scaled_y_norm2 = np.linalg.norm(y) ** 2 / (2 * n_samples)
all_features = np.arange(n_features)
stop_crit = np.inf # prevent ref before assign
p_objs_out = []

w = np.zeros(n_features)
scaled_gram_w = np.zeros(n_features)
    opt = penalty.subdiff_distance(w, -scaled_Xty, all_features)  # grad at w=0
if use_acc:
accelerator = AndersonAcceleration(K=5)
w_acc = np.zeros(n_features)
scaled_gram_w_acc = np.zeros(n_features)

for t in range(max_iter):
        # check convergence
stop_crit = np.max(opt)
if verbose:
p_obj = (0.5 * w @ scaled_gram_w - scaled_Xty @ w +
scaled_y_norm2 + penalty.value(w))
print(
f"Iteration {t+1}: {p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}"
)

if stop_crit <= tol:
if verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
break

        # in-place update of w and scaled_gram_w
opt = _gram_cd_iter(scaled_gram, scaled_Xty, w, scaled_gram_w, penalty,
all_features, n_updates=n_features)

        # perform Anderson extrapolation
if use_acc:
w_acc, scaled_gram_w_acc, is_extrapolated = accelerator.extrapolate(
w, scaled_gram_w)

            if is_extrapolated:
                # keep the extrapolated point only if it decreases the objective
                p_obj_acc = (0.5 * w_acc @ scaled_gram_w_acc - scaled_Xty @ w_acc
                             + scaled_y_norm2 + penalty.value(w_acc))
                p_obj = (0.5 * w @ scaled_gram_w - scaled_Xty @ w
                         + scaled_y_norm2 + penalty.value(w))
                if p_obj_acc < p_obj:
                    w[:] = w_acc
                    scaled_gram_w[:] = scaled_gram_w_acc

        p_obj = (0.5 * w @ scaled_gram_w - scaled_Xty @ w + scaled_y_norm2
                 + penalty.value(w))
p_objs_out.append(p_obj)
return w, np.array(p_objs_out), stop_crit


@njit
def _gram_cd_iter(scaled_gram, scaled_Xty, w, scaled_gram_w, penalty, ws, n_updates):
    # in-place updates of w and scaled_gram_w; returns the subdiff distances
    # greedy CD: each update picks the coordinate with the largest violation
for _ in range(n_updates):
grad = scaled_gram_w - scaled_Xty
opt = penalty.subdiff_distance(w, grad, ws)
j_max = np.argmax(opt)

old_w_j = w[j_max]
        step = 1 / scaled_gram[j_max, j_max]  # 1 / Lipschitz constant of j_max
w[j_max] = penalty.prox_1d(old_w_j - step * grad[j_max], step, j_max)

        # rank-one update to keep scaled_gram_w = scaled_gram @ w
        if w[j_max] != old_w_j:
scaled_gram_w += (w[j_max] - old_w_j) * scaled_gram[:, j_max]
return opt
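For intuition on the Gram updates the solver relies on, here is a minimal NumPy sketch (illustrative only, not part of the PR; the random data and dimensions are arbitrary) checking that maintaining `scaled_gram_w` with rank-one updates matches recomputing the gradient from scratch:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 50, 10
X = rng.standard_normal((n_samples, n_features))
y = rng.standard_normal(n_samples)

scaled_gram = X.T @ X / n_samples   # Q / n_samples
scaled_Xty = X.T @ y / n_samples    # q / n_samples

w = np.zeros(n_features)
scaled_gram_w = scaled_gram @ w     # kept up to date across updates

# update a few coordinates, maintaining scaled_gram_w with rank-one updates
for j in (3, 7, 3):
    old_w_j = w[j]
    w[j] += 1.0
    scaled_gram_w += (w[j] - old_w_j) * scaled_gram[:, j]

# the maintained gradient matches the one recomputed from scratch:
# grad = X.T @ (X @ w - y) / n_samples = scaled_gram @ w - scaled_Xty
np.testing.assert_allclose(scaled_gram_w - scaled_Xty,
                           X.T @ (X @ w - y) / n_samples)
```

Each coordinate update thus costs O(n_features) instead of a full O(n_samples * n_features) gradient recomputation, which is the point of precomputing the Gram matrix.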
42 changes: 42 additions & 0 deletions skglm/tests/test_gram_solver.py
@@ -0,0 +1,42 @@
import pytest
from itertools import product

import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso

from skglm.penalties import L1
from skglm.solvers.gram_cd import gram_cd_solver
from skglm.utils import make_correlated_data, compiled_clone


@pytest.mark.parametrize("n_samples, n_features",
product([100, 200], [50, 90]))
def test_alpha_max(n_samples, n_features):
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples

l1_penalty = compiled_clone(L1(alpha_max))
w = gram_cd_solver(X, y, l1_penalty, tol=1e-9, verbose=0)[0]

np.testing.assert_equal(w, 0)


@pytest.mark.parametrize("n_samples, n_features, rho",
product([500, 100], [30, 80], [1e-1, 1e-2, 1e-3]))
def test_vs_lasso_sklearn(n_samples, n_features, rho):
X, y, _ = make_correlated_data(n_samples, n_features, random_state=0)
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
alpha = rho * alpha_max

sk_lasso = Lasso(alpha, fit_intercept=False, tol=1e-9)
sk_lasso.fit(X, y)

l1_penalty = compiled_clone(L1(alpha))
w = gram_cd_solver(X, y, l1_penalty, tol=1e-9, verbose=0, max_iter=1000)[0]

np.testing.assert_allclose(w, sk_lasso.coef_.flatten(), rtol=1e-7, atol=1e-7)


if __name__ == '__main__':
pass
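The solver delegates extrapolation to `skglm.utils.AndersonAcceleration`, whose internals are not shown in this diff. As a rough sketch of the underlying idea, here is a toy type-II Anderson acceleration over the last K + 1 iterates; the class name `AndersonSketch` and its single-array interface are illustrative assumptions (the class used above takes the pair `(w, scaled_gram_w)` and returns an `is_extrapolated` flag):

```python
import numpy as np

class AndersonSketch:
    """Toy Anderson acceleration over the last K + 1 iterates (illustrative)."""

    def __init__(self, K=5):
        self.K = K
        self.iterates = []

    def extrapolate(self, w):
        self.iterates.append(w.copy())
        if len(self.iterates) <= self.K:
            return w, False  # not enough iterates yet
        W = np.vstack(self.iterates).T      # shape (n_features, K + 1)
        U = np.diff(W, axis=1)              # consecutive differences, shape (n, K)
        try:
            # coefficients c minimizing ||U @ c|| subject to sum(c) == 1
            z = np.linalg.solve(U.T @ U, np.ones(self.K))
        except np.linalg.LinAlgError:
            self.iterates = []
            return w, False
        c = z / z.sum()
        self.iterates = []
        return W[:, 1:] @ c, True           # extrapolated point
```

In `gram_cd_solver`, the extrapolated pair is additionally gated on the primal objective before being accepted, so a bad extrapolation can never increase the objective.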