@@ -302,8 +302,8 @@ impl<B: Backend> Muon<B> {
302302 /// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py
303303 fn zeropower_via_newtonschulz < const D : usize > ( & self , g : Tensor < B , D > ) -> Tensor < B , D > {
304304 assert ! (
305- D > = 2 ,
306- "Newton-Schulz iteration requires at least 2D tensors, got {}D" ,
305+ D ! = 2 ,
306+ "Newton-Schulz iteration requires 2D tensors, got {}D" ,
307307 D
308308 ) ;
309309
@@ -389,14 +389,6 @@ impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
389389 grad : Tensor < B , D > ,
390390 state : Option < Self :: State < D > > ,
391391 ) -> ( Tensor < B , D > , Option < Self :: State < D > > ) {
392- assert ! (
393- D >= 2 ,
394- "Muon optimizer is designed for 2D+ parameters (matrices). \
395- For 1D parameters (biases, layer norms), use AdamW or SGD instead. \
396- Got {}D tensor.",
397- D
398- ) ;
399-
400392 // Step 1: Apply momentum
401393 let state_momentum = state. map ( |s| s. momentum ) ;
402394 let ( grad, new_momentum_state) = self . momentum . transform ( grad, state_momentum) ;
@@ -500,7 +492,7 @@ mod tests {
500492 }
501493
502494 #[ test]
503- #[ should_panic( expected = "2D+ parameters" ) ]
495+ #[ should_panic( expected = "2D parameters" ) ]
504496 fn test_1d_tensor_panics ( ) {
505497 let device = Default :: default ( ) ;
506498 let config = MuonConfig :: new ( ) ;
@@ -615,82 +607,6 @@ mod tests {
615607 ) ;
616608 }
617609
618- #[ test]
619- fn test_muon_with_3d_tensor ( ) {
620- // Test that Muon works with 3D tensors (e.g., batched weight matrices)
621- // Shape: [batch_size, height, width]
622- let device = Default :: default ( ) ;
623-
624- // Create a 3D tensor: [2, 4, 3] - 2 batches of 4x3 matrices
625- let tensor_3d = Tensor :: < TestBackend , 3 > :: from_floats (
626- [
627- // Batch 1
628- [
629- [ 1.0 , 0.5 , 0.2 ] ,
630- [ 0.5 , 1.0 , 0.3 ] ,
631- [ 0.2 , 0.3 , 1.0 ] ,
632- [ 0.1 , 0.2 , 0.3 ] ,
633- ] ,
634- // Batch 2
635- [
636- [ 1.0 , 0.4 , 0.1 ] ,
637- [ 0.4 , 1.0 , 0.2 ] ,
638- [ 0.1 , 0.2 , 1.0 ] ,
639- [ 0.3 , 0.1 , 0.2 ] ,
640- ] ,
641- ] ,
642- & device,
643- ) ;
644-
645- let grad_3d = Tensor :: < TestBackend , 3 > :: from_floats (
646- [
647- // Batch 1 gradients
648- [
649- [ 0.1 , 0.2 , 0.3 ] ,
650- [ 0.2 , 0.1 , 0.2 ] ,
651- [ 0.3 , 0.2 , 0.1 ] ,
652- [ 0.1 , 0.1 , 0.1 ] ,
653- ] ,
654- // Batch 2 gradients
655- [
656- [ 0.2 , 0.1 , 0.1 ] ,
657- [ 0.1 , 0.2 , 0.1 ] ,
658- [ 0.1 , 0.1 , 0.2 ] ,
659- [ 0.2 , 0.2 , 0.2 ] ,
660- ] ,
661- ] ,
662- & device,
663- ) ;
664-
665- let config = MuonConfig :: new ( ) ;
666- let muon: Muon < TestBackend > = Muon {
667- momentum : Momentum :: new ( & config. momentum ) ,
668- ns_params : NewtonSchulzParams :: new ( config. ns_coefficients , config. ns_steps ) ,
669- weight_decay_penalty : None ,
670- epsilon : config. epsilon ,
671- adjust_lr_fn : config. adjust_lr_fn ,
672- } ;
673-
674- // Should not panic - Muon supports D >= 2
675- let ( updated_tensor, state) = muon. step ( 0.01 , tensor_3d. clone ( ) , grad_3d, None ) ;
676-
677- // Verify state was created
678- assert ! ( state. is_some( ) ) ;
679-
680- // Verify tensor was updated (should be different from original)
681- let original_data = tensor_3d. into_data ( ) ;
682- let updated_data = updated_tensor. into_data ( ) ;
683-
684- assert_ne ! (
685- original_data. as_slice:: <f32 >( ) . unwrap( ) ,
686- updated_data. as_slice:: <f32 >( ) . unwrap( ) ,
687- "Tensor should be updated after optimization step"
688- ) ;
689-
690- // Verify shape is preserved
691- assert_eq ! ( updated_data. shape, original_data. shape) ;
692- }
693-
694610 #[ test]
695611 fn test_tall_matrix_transpose ( ) {
696612 // Test that tall matrices (A > B) are transposed during Newton-Schulz iteration
0 commit comments