
Commit 4c4fd65

To align with the PyTorch implementation, allow 2D tensors only
1 parent 2edeeaf commit 4c4fd65

File tree

1 file changed: +3 / -87 lines

crates/burn-optim/src/optim/muon.rs

Lines changed: 3 additions & 87 deletions
@@ -302,8 +302,8 @@ impl<B: Backend> Muon<B> {
     /// - PyTorch: https://github.com/pytorch/pytorch/blob/main/torch/optim/muon.py
     fn zeropower_via_newtonschulz<const D: usize>(&self, g: Tensor<B, D>) -> Tensor<B, D> {
         assert!(
-            D >= 2,
-            "Newton-Schulz iteration requires at least 2D tensors, got {}D",
+            D == 2,
+            "Newton-Schulz iteration requires 2D tensors, got {}D",
             D
         );
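
For context, zeropower_via_newtonschulz is the step that gives Muon its name: it approximately orthogonalizes the momentum-averaged gradient with a fixed number of quintic Newton-Schulz iterations, and those iterations are defined purely in terms of matrix products, which is why a 2D operand is required. Below is a minimal standalone sketch of the iteration, using the nalgebra crate instead of burn tensors; the function name, coefficients, step count, and epsilon handling are illustrative assumptions, not values taken from this crate.

use nalgebra::DMatrix;

/// Sketch only: approximately orthogonalize a 2D matrix with the quintic
/// Newton-Schulz iteration used by Muon-style optimizers.
fn newton_schulz_sketch(g: &DMatrix<f64>, steps: usize, eps: f64) -> DMatrix<f64> {
    // Illustrative quintic coefficients (a, b, c) from public Muon references.
    let (a, b, c) = (3.4445, -4.7750, 2.0315);

    // Work on the wide orientation so the Gram matrix X * X^T is the smaller one.
    let tall = g.nrows() > g.ncols();
    let mut x = if tall { g.transpose() } else { g.clone() };

    // Normalize so the spectral norm is at most ~1 (the Frobenius norm bounds it).
    let n = x.norm() + eps;
    x = x.unscale(n);

    for _ in 0..steps {
        let gram = &x * x.transpose();               // A = X * X^T
        let poly = &gram * b + (&gram * &gram) * c;  // B = b*A + c*A^2
        x = &x * a + poly * x;                       // X = a*X + B*X
    }

    if tall { x.transpose() } else { x }
}

The restriction to D == 2 in the assertion above mirrors this structure: there is no single canonical way to form X * X^T for a 3D or higher tensor, so the choice of how to matrix-ize a parameter is left to the caller.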

@@ -389,14 +389,6 @@ impl<B: Backend> SimpleOptimizer<B> for Muon<B> {
         grad: Tensor<B, D>,
         state: Option<Self::State<D>>,
     ) -> (Tensor<B, D>, Option<Self::State<D>>) {
-        assert!(
-            D >= 2,
-            "Muon optimizer is designed for 2D+ parameters (matrices). \
-            For 1D parameters (biases, layer norms), use AdamW or SGD instead. \
-            Got {}D tensor.",
-            D
-        );
-
         // Step 1: Apply momentum
         let state_momentum = state.map(|s| s.momentum);
         let (grad, new_momentum_state) = self.momentum.transform(grad, state_momentum);
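
With the blanket rank check removed from step, the only remaining guard is the 2D assertion inside the Newton-Schulz routine, so, as in the PyTorch implementation the commit message points to, parameters are expected to reach Muon already shaped as matrices. Below is a hedged sketch of a caller-side helper for that, assuming burn's Tensor::dims and Tensor::reshape APIs; the helper name is invented here for illustration and is not part of this crate.

use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical helper: collapse a higher-rank parameter (e.g. a conv kernel of
/// shape [out, in, k, k]) into the [out, fan_in] matrix a 2D-only Muon expects.
fn flatten_to_2d<B: Backend, const D: usize>(param: Tensor<B, D>) -> Tensor<B, 2> {
    let dims = param.dims();                        // [d0, d1, ..., d_{D-1}]
    let fan_in: usize = dims[1..].iter().product(); // collapse everything after dim 0
    param.reshape([dims[0], fan_in])
}

The update produced for the flattened matrix would then be reshaped back to the parameter's original dimensions after the optimizer step.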
@@ -500,7 +492,7 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "2D+ parameters")]
+    #[should_panic(expected = "2D tensors")]
     fn test_1d_tensor_panics() {
         let device = Default::default();
         let config = MuonConfig::new();
@@ -615,82 +607,6 @@ mod tests {
         );
     }
 
-    #[test]
-    fn test_muon_with_3d_tensor() {
-        // Test that Muon works with 3D tensors (e.g., batched weight matrices)
-        // Shape: [batch_size, height, width]
-        let device = Default::default();
-
-        // Create a 3D tensor: [2, 4, 3] - 2 batches of 4x3 matrices
-        let tensor_3d = Tensor::<TestBackend, 3>::from_floats(
-            [
-                // Batch 1
-                [
-                    [1.0, 0.5, 0.2],
-                    [0.5, 1.0, 0.3],
-                    [0.2, 0.3, 1.0],
-                    [0.1, 0.2, 0.3],
-                ],
-                // Batch 2
-                [
-                    [1.0, 0.4, 0.1],
-                    [0.4, 1.0, 0.2],
-                    [0.1, 0.2, 1.0],
-                    [0.3, 0.1, 0.2],
-                ],
-            ],
-            &device,
-        );
-
-        let grad_3d = Tensor::<TestBackend, 3>::from_floats(
-            [
-                // Batch 1 gradients
-                [
-                    [0.1, 0.2, 0.3],
-                    [0.2, 0.1, 0.2],
-                    [0.3, 0.2, 0.1],
-                    [0.1, 0.1, 0.1],
-                ],
-                // Batch 2 gradients
-                [
-                    [0.2, 0.1, 0.1],
-                    [0.1, 0.2, 0.1],
-                    [0.1, 0.1, 0.2],
-                    [0.2, 0.2, 0.2],
-                ],
-            ],
-            &device,
-        );
-
-        let config = MuonConfig::new();
-        let muon: Muon<TestBackend> = Muon {
-            momentum: Momentum::new(&config.momentum),
-            ns_params: NewtonSchulzParams::new(config.ns_coefficients, config.ns_steps),
-            weight_decay_penalty: None,
-            epsilon: config.epsilon,
-            adjust_lr_fn: config.adjust_lr_fn,
-        };
-
-        // Should not panic - Muon supports D >= 2
-        let (updated_tensor, state) = muon.step(0.01, tensor_3d.clone(), grad_3d, None);
-
-        // Verify state was created
-        assert!(state.is_some());
-
-        // Verify tensor was updated (should be different from original)
-        let original_data = tensor_3d.into_data();
-        let updated_data = updated_tensor.into_data();
-
-        assert_ne!(
-            original_data.as_slice::<f32>().unwrap(),
-            updated_data.as_slice::<f32>().unwrap(),
-            "Tensor should be updated after optimization step"
-        );
-
-        // Verify shape is preserved
-        assert_eq!(updated_data.shape, original_data.shape);
-    }
-
     #[test]
     fn test_tall_matrix_transpose() {
         // Test that tall matrices (A > B) are transposed during Newton-Schulz iteration
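
For reference, the transpose exercised by test_tall_matrix_transpose is an efficiency trick rather than a correctness requirement, matching the reference Muon implementations: running the iteration on the wide orientation keeps the Gram matrix X * X^T at size min(m, n) × min(m, n), so each Newton-Schulz step multiplies the smallest matrices possible, and the result is transposed back at the end.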
