Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions include/cppoptlib/solver/lbfgs.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

#include "../linesearch/more_thuente.h"
#include "Eigen/Core"
#include "solver.h" // NOLINT
#include "solver.h" // NOLINT

namespace cppoptlib::solver {

Expand All @@ -49,7 +49,7 @@ class Lbfgs
cppoptlib::function::DifferentiabilityMode::Second,
"L-BFGS only supports first- or second-order differentiable functions");

private:
private:
using StateType = typename cppoptlib::function::FunctionState<
typename FunctionType::ScalarType, FunctionType::Dimension>;
using Superclass = Solver<FunctionType, StateType>;
Expand All @@ -63,7 +63,7 @@ class Lbfgs
using memory_MatrixType = Eigen::Matrix<ScalarType, Eigen::Dynamic, m>;
using memory_VectorType = Eigen::Matrix<ScalarType, 1, m>;

public:
public:
EIGEN_MAKE_ALIGNED_OPERATOR_NEW
using Superclass::Superclass;

Expand Down Expand Up @@ -107,43 +107,52 @@ class Lbfgs
// Start with the preconditioned gradient as the initial search direction.
VectorType search_direction = grad_precond;

// Determine the number of corrections available for the two-loop recursion.
// We exclude the most recent correction (which was just computed) from use.
int k = (mem_count_ > 0 ? static_cast<int>(mem_count_) - 1 : 0);
// Determine the actual number of stored corrections to use
const int k = static_cast<int>(mem_count_);

// --- First Loop (Backward Pass) ---
// Iterate over stored corrections in reverse chronological order.
// First loop: computes q = q - alpha_i * y_i
// Iterates from the newest correction (i = k-1) down to the oldest (i = 0).
// i is the chronological index (0 = oldest, k-1 = newest); idx is the
// corresponding column in the circular correction buffer.
for (int i = k - 1; i >= 0; i--) {
// Compute the index in chronological order.
// When mem_count_ < m, corrections are stored in order [0 ...
// mem_count_-1]. When full, they are stored cyclically starting at
// mem_pos_ (oldest) up to (mem_pos_ + m - 1) mod m.
int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
const ScalarType denom =
x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx));
if (std::abs(denom) < eps) {
const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m;

const VectorType &s_col = x_diff_memory_.col(idx);
const VectorType &y_col = grad_diff_memory_.col(idx);

const ScalarType s_dot_y = s_col.dot(y_col);
if (std::abs(s_dot_y) < eps) { // Avoid division by zero or near-zero
continue;
}
const ScalarType rho = 1.0 / denom;
alpha(i) = rho * x_diff_memory_.col(idx).dot(search_direction);
search_direction -= alpha(i) * grad_diff_memory_.col(idx);
const ScalarType rho_val = static_cast<ScalarType>(1.0) / s_dot_y;
alpha(i) = rho_val * s_col.dot(search_direction);
search_direction -= alpha(i) * y_col;
}

// Apply the initial Hessian approximation.
// Apply the initial Hessian approximation H_k^0 = gamma_k * I
// gamma_k = s_{k-1}^T y_{k-1} / (y_{k-1}^T y_{k-1})
// Here, scaling_factor_ is this gamma_k from the *previous* iteration.
search_direction *= scaling_factor_;

// --- Second Loop (Forward Pass) ---
// Second loop: computes r = r + s_i * (alpha_i - beta_i)
// Iterates from the oldest correction (i = 0) up to the newest (i = k-1)
for (int i = 0; i < k; i++) {
int idx = (mem_count_ < m ? i : ((mem_pos_ + i) % m));
const ScalarType denom =
x_diff_memory_.col(idx).dot(grad_diff_memory_.col(idx));
if (std::abs(denom) < eps) {
const int idx = (mem_count_ < m) ? i : (mem_pos_ + i) % m;

const VectorType &s_col = x_diff_memory_.col(idx);
const VectorType &y_col = grad_diff_memory_.col(idx);

const ScalarType s_dot_y = s_col.dot(y_col);
if (std::abs(s_dot_y) < eps) {
continue;
}
const ScalarType rho = 1.0 / denom;
const ScalarType beta =
rho * grad_diff_memory_.col(idx).dot(search_direction);
search_direction += x_diff_memory_.col(idx) * (alpha(i) - beta);
const ScalarType rho_val = static_cast<ScalarType>(1.0) / s_dot_y;
const ScalarType beta = rho_val * y_col.dot(search_direction);
search_direction += s_col * (alpha(i) - beta);
}

// Check descent direction validity.
Expand Down Expand Up @@ -210,21 +219,21 @@ class Lbfgs
return next;
}

private:
private:
memory_MatrixType x_diff_memory_;
memory_MatrixType grad_diff_memory_;
// Circular buffer state:
size_t mem_count_ = 0; // Number of corrections stored so far (max m).
size_t mem_pos_ = 0; // Index of the oldest correction in the buffer.
size_t mem_count_ = 0; // Number of corrections stored so far (max m).
size_t mem_pos_ = 0; // Index of the oldest correction in the buffer.

memory_VectorType
alpha; // Storage for the coefficients in the two-loop recursion.
alpha; // Storage for the coefficients in the two-loop recursion.
ScalarType scaling_factor_ = 1;
// Cautious factor to determine whether to accept a new correction pair.
// You may want to expose this parameter or adjust its default value.
ScalarType cautious_factor_ = 1e-6;
};

} // namespace cppoptlib::solver
} // namespace cppoptlib::solver

#endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_
#endif // INCLUDE_CPPOPTLIB_SOLVER_LBFGS_H_
Loading