@@ -163,6 +163,57 @@ pub struct MuonConfig {
163163 adjust_lr_fn : AdjustLrFn ,
164164}
165165
166+ impl MuonConfig {
167+ /// Initialize Muon optimizer.
168+ ///
169+ /// # Returns
170+ ///
171+ /// Returns an optimizer adaptor that can be used to optimize a module.
172+ ///
173+ /// # Example
174+ ///
175+ /// ```ignore
176+ /// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
177+ ///
178+ /// // Basic configuration with default (Original) LR adjustment
179+ /// let optimizer = MuonConfig::new()
180+ /// .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
181+ /// .init();
182+ ///
183+ /// // With AdamW-compatible settings using MatchRmsAdamW
184+ /// let optimizer = MuonConfig::new()
185+ /// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
186+ /// .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
187+ /// .init();
188+ ///
189+ /// // Custom momentum and NS settings
190+ /// let optimizer = MuonConfig::new()
191+ /// .with_momentum(MomentumConfig {
192+ /// momentum: 0.9,
193+ /// dampening: 0.1,
194+ /// nesterov: false,
195+ /// })
196+ /// .with_ns_steps(7)
197+ /// .init();
198+ /// ```
199+ pub fn init < B : AutodiffBackend , M : AutodiffModule < B > > (
200+ & self ,
201+ ) -> OptimizerAdaptor < Muon < B :: InnerBackend > , M , B > {
202+ let momentum = Momentum :: new ( & self . momentum ) ;
203+ let weight_decay_penalty = self . weight_decay . as_ref ( ) . map ( |wd| wd. penalty ) ;
204+
205+ let optim = Muon {
206+ momentum,
207+ ns_params : NewtonSchulzParams :: new ( self . ns_coefficients , self . ns_steps ) ,
208+ weight_decay_penalty,
209+ epsilon : self . epsilon ,
210+ adjust_lr_fn : self . adjust_lr_fn ,
211+ } ;
212+
213+ OptimizerAdaptor :: from ( optim)
214+ }
215+ }
216+
166217/// Parameters for Newton-Schulz orthogonalization.
167218#[ derive( Clone , Copy ) ]
168219struct NewtonSchulzParams {
@@ -326,7 +377,7 @@ impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
326377 /// 4. Apply weight decay (using original lr)
327378 /// 5. Update parameter (using adjusted lr)
328379 ///
329- /// # Important
380+ /// # Notes
330381 ///
331382 /// Unlike typical optimizers, the weight decay and parameter update use
332383 /// different learning rates:
@@ -380,57 +431,6 @@ impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
380431 }
381432}
382433
383- impl MuonConfig {
384- /// Initialize Muon optimizer.
385- ///
386- /// # Returns
387- ///
388- /// Returns an optimizer adaptor that can be used to optimize a module.
389- ///
390- /// # Example
391- ///
392- /// ```ignore
393- /// use burn_optim::{MuonConfig, AdjustLrFn, decay::WeightDecayConfig};
394- ///
395- /// // Basic configuration with default (Original) LR adjustment
396- /// let optimizer = MuonConfig::new()
397- /// .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
398- /// .init();
399- ///
400- /// // With AdamW-compatible settings using MatchRmsAdamW
401- /// let optimizer = MuonConfig::new()
402- /// .with_adjust_lr_fn(AdjustLrFn::MatchRmsAdamW)
403- /// .with_weight_decay(Some(WeightDecayConfig::new(0.1)))
404- /// .init();
405- ///
406- /// // Custom momentum and NS settings
407- /// let optimizer = MuonConfig::new()
408- /// .with_momentum(MomentumConfig {
409- /// momentum: 0.9,
410- /// dampening: 0.1,
411- /// nesterov: false,
412- /// })
413- /// .with_ns_steps(7)
414- /// .init();
415- /// ```
416- pub fn init < B : AutodiffBackend , M : AutodiffModule < B > > (
417- & self ,
418- ) -> OptimizerAdaptor < Muon < B :: InnerBackend > , M , B > {
419- let momentum = Momentum :: new ( & self . momentum ) ;
420- let weight_decay_penalty = self . weight_decay . as_ref ( ) . map ( |wd| wd. penalty ) ;
421-
422- let optim = Muon {
423- momentum,
424- ns_params : NewtonSchulzParams :: new ( self . ns_coefficients , self . ns_steps ) ,
425- weight_decay_penalty,
426- epsilon : self . epsilon ,
427- adjust_lr_fn : self . adjust_lr_fn ,
428- } ;
429-
430- OptimizerAdaptor :: from ( optim)
431- }
432- }
433-
434434#[ cfg( test) ]
435435mod tests {
436436 use super :: * ;
0 commit comments