Description
Hello,
I’m currently encountering a gradient vanishing problem when training a pure Mamba-based image processing model. I would really appreciate your advice or any insights into possible causes and fixes.
🧩 Environment
- PyTorch: 2.6
- Python: 3.10
- System: Ubuntu
- Mamba-ssm: 2.2.2
- Causal-conv1d: 1.4.0.post1
- Note: I cannot use the latest Mamba version because it reports an illegal character error (`_Nz3`).
🏗️ Model Architecture
I’m directly using the official classes:
mamba/mamba_ssm/models/mixer_seq_simple.py
Lines 118 to 182 in 10b5d63
```python
class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
            )
        )
```
My network follows a typical transformer-style backbone:
image → patch embed → (Mamba1 ×10 + downsample) → (Mamba1 ×10 + downsample)
→ (Mamba1 ×10 + downsample) → (Mamba1 ×10) → decoder → output of shape B × 3 × H × W
Each downsampling step doubles the number of channels and halves the spatial resolution.
I used the official initialization, residual connections, and feed-forward structure.
The official initialization was not overridden by any global weight initialization.
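For concreteness, here is a simplified sketch of one stage plus downsampling. This is not my exact code: the PixelUnshuffle-based downsampling and the `Mamba` hyperparameters (`d_state=16`, `d_conv=4`, `expand=2`) shown here are just illustrative defaults.

```python
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba


class MambaStage(nn.Module):
    """A stack of pre-norm residual Mamba blocks operating on (B, L, C) tokens."""

    def __init__(self, dim, depth=10):
        super().__init__()
        self.norms = nn.ModuleList(nn.LayerNorm(dim) for _ in range(depth))
        self.blocks = nn.ModuleList(
            Mamba(d_model=dim, d_state=16, d_conv=4, expand=2) for _ in range(depth)
        )

    def forward(self, x):
        for norm, block in zip(self.norms, self.blocks):
            x = x + block(norm(x))  # residual around each Mamba block
        return x


class Downsample(nn.Module):
    """Halve the spatial resolution and double the channel count."""

    def __init__(self, dim):
        super().__init__()
        self.reduce = nn.Linear(4 * dim, 2 * dim)

    def forward(self, x, h, w):           # x: (B, H*W, C)
        b, _, c = x.shape
        x = x.transpose(1, 2).reshape(b, c, h, w)
        x = F.pixel_unshuffle(x, 2)        # (B, 4C, H/2, W/2)
        x = x.flatten(2).transpose(1, 2)   # (B, H*W/4, 4C)
        return self.reduce(x), h // 2, w // 2
```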
⚠️ Problem: Gradient Vanishing
After training, the model's output quality was poor, so I checked the gradients layer by layer (a minimal version of this check is sketched after the list below).
I found that the gradients of the following parameters gradually vanished, and the vanishing propagated to other Mamba layers:
- `dt_proj.weight` and `dt_proj.bias`
- `mixer.fc1.weight` and `mixer.fc2.weight`
- `A_log`
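The per-layer check mentioned above can be reproduced with a loop roughly like this (the `1e-7` threshold is an arbitrary cut-off I picked, not something from the library):

```python
# After loss.backward(), print the gradient norm of every parameter to spot
# which layers receive (near-)zero gradients.
for name, param in model.named_parameters():
    if param.grad is not None:
        g = param.grad.detach().norm().item()
        if g < 1e-7:  # arbitrary threshold for "vanished"
            print(f"vanishing gradient: {name:60s} |grad| = {g:.3e}")
```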
Upon investigation:
`dt_proj` is a linear projection layer for preparing parameters: it maps a low-rank temporal delta from `d_model / 16` up to `2 * d_model`, a 32× expansion, which might be quite large for a standard linear layer.
mamba/mamba_ssm/modules/mamba_simple.py
Line 80 in 10b5d63
```python
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
```
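To make the 32× figure concrete, with the defaults in `mamba_simple.py` (`dt_rank = ceil(d_model / 16)` and `d_inner = expand * d_model` with `expand = 2`), the shapes look like this (`d_model = 256` is just an example value, not one of my actual stage widths):

```python
import math
import torch.nn as nn

d_model = 256                       # example width; my stages use different sizes
dt_rank = math.ceil(d_model / 16)   # 16  (the "auto" default)
d_inner = 2 * d_model               # 512 (expand = 2)

dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
print(dt_proj)             # Linear(in_features=16, out_features=512, bias=True)
print(d_inner / dt_rank)   # 32.0 -> the 32x expansion mentioned above
```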
`fc1` and `fc2` are the two linear layers inside the MLP, which also show gradient vanishing.
mamba/mamba_ssm/modules/mlp.py
Lines 24 to 27 in 10b5d63
```python
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
```
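For context, my understanding of this gated MLP (a sketch, not the verbatim library code) is that `fc1` produces both the value and the gate, which is why its output width is `2 * hidden_features`:

```python
# Sketch of a gated-MLP forward: split fc1's output into value and gate,
# gate with the activation, then project back down with fc2.
def gated_mlp_forward(x, fc1, fc2, activation):
    y = fc1(x)                    # (..., 2 * hidden_features)
    y, gate = y.chunk(2, dim=-1)  # value / gate split
    return fc2(y * activation(gate))
```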
🙏 Request for Help
Resolving this gradient vanishing issue is very important for my current work.
I would be deeply grateful for any insights, debugging suggestions, or theoretical explanations the team could provide.
Thank you very much for your time and for the excellent open-source work on Mamba.