Skip to content

Commit 6b29724

Browse files
move impl MuonConfig right behind its struct
1 parent a637eb8 commit 6b29724

File tree

1 file changed

+52
-52
lines changed

1 file changed

+52
-52
lines changed

crates/burn-optim/src/optim/muon.rs

Lines changed: 52 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,57 @@ pub struct MuonConfig {
163163
adjust_lr_fn: AdjustLrFn,
164164
}
165165

166+
impl MuonConfig {
167+
/// Initialize Muon optimizer.
168+
///
169+
/// # Returns
170+
///
171+
/// Returns an optimizer adaptor that can be used to optimize a module.
172+
///
173+
/// # Example
174+
///
175+
/// ```ignore
176+
/// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
177+
///
178+
/// // Basic configuration with default (Original) LR adjustment
179+
/// let optimizer = MuonConfig::new()
180+
/// .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
181+
/// .init();
182+
///
183+
/// // With AdamW-compatible settings using MatchRmsAdamW
184+
/// let optimizer = MuonConfig::new()
185+
/// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
186+
/// .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
187+
/// .init();
188+
///
189+
/// // Custom momentum and NS settings
190+
/// let optimizer = MuonConfig::new()
191+
/// .with_momentum(MomentumConfig {
192+
/// momentum: 0.9,
193+
/// dampening: 0.1,
194+
/// nesterov: false,
195+
/// })
196+
/// .with_ns_steps(7)
197+
/// .init();
198+
/// ```
199+
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
200+
&self,
201+
) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
202+
let momentum = Momentum::new(&self.momentum);
203+
let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);
204+
205+
let optim = Muon {
206+
momentum,
207+
ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
208+
weight_decay_penalty,
209+
epsilon: self.epsilon,
210+
adjust_lr_fn: self.adjust_lr_fn,
211+
};
212+
213+
OptimizerAdaptor::from(optim)
214+
}
215+
}
216+
166217
/// Parameters for Newton-Schulz orthogonalization.
167218
#[derive(Clone, Copy)]
168219
struct NewtonSchulzParams {
@@ -326,7 +377,7 @@ impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
326377
/// 4. Apply weight decay (using original lr)
327378
/// 5. Update parameter (using adjusted lr)
328379
///
329-
/// # Important
380+
/// # Notes
330381
///
331382
/// Unlike typical optimizers, the weight decay and parameter update use
332383
/// different learning rates:
@@ -380,57 +431,6 @@ impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
380431
}
381432
}
382433

383-
impl MuonConfig {
384-
/// Initialize Muon optimizer.
385-
///
386-
/// # Returns
387-
///
388-
/// Returns an optimizer adaptor that can be used to optimize a module.
389-
///
390-
/// # Example
391-
///
392-
/// ```ignore
393-
/// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
394-
///
395-
/// // Basic configuration with default (Original) LR adjustment
396-
/// let optimizer = MuonConfig::new()
397-
/// .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
398-
/// .init();
399-
///
400-
/// // With AdamW-compatible settings using MatchRmsAdamW
401-
/// let optimizer = MuonConfig::new()
402-
/// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
403-
/// .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
404-
/// .init();
405-
///
406-
/// // Custom momentum and NS settings
407-
/// let optimizer = MuonConfig::new()
408-
/// .with_momentum(MomentumConfig {
409-
/// momentum: 0.9,
410-
/// dampening: 0.1,
411-
/// nesterov: false,
412-
/// })
413-
/// .with_ns_steps(7)
414-
/// .init();
415-
/// ```
416-
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
417-
&self,
418-
) -> OptimizerAdaptor<Muon<B::InnerBackend>, M, B> {
419-
let momentum = Momentum::new(&self.momentum);
420-
let weight_decay_penalty = self.weight_decay.as_ref().map(|wd| wd.penalty);
421-
422-
let optim = Muon {
423-
momentum,
424-
ns_params: NewtonSchulzParams::new(self.ns_coefficients, self.ns_steps),
425-
weight_decay_penalty,
426-
epsilon: self.epsilon,
427-
adjust_lr_fn: self.adjust_lr_fn,
428-
};
429-
430-
OptimizerAdaptor::from(optim)
431-
}
432-
}
433-
434434
#[cfg(test)]
435435
mod tests {
436436
use super::*;

0 commit comments

Comments
 (0)