
Gradient Vanishing in Stacked Pure-Mamba Image Model (Seeking Advice) #811

@LiulianC


Hello,
I’m currently encountering a gradient vanishing problem when training a pure Mamba-based image processing model. I would really appreciate your advice or any insights into possible causes and fixes.

🧩 Environment

  • PyTorch: 2.6
  • Python: 3.10
  • System: Ubuntu
  • Mamba-ssm: 2.2.2
  • Causal-conv1d: 1.4.0.post1
  • Note: I cannot use the latest mamba-ssm release because it fails with an illegal-character error (_Nz3).

🏗️ Model Architecture

I’m directly using the official classes:

class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
            )
        )
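
In my image model the tokens come from a patch embedding rather than the nn.Embedding above, so each stage is essentially the create_block stack without the vocabulary path. A minimal sketch of how I assemble one 10-block stage; the MambaStage name and the sizes are illustrative placeholders, not my exact code:

import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import create_block

class MambaStage(nn.Module):  # illustrative name
    def __init__(self, d_model, d_intermediate, n_layer=10):
        super().__init__()
        # Same helper MixerModel uses; each Block returns (hidden_states, residual).
        # d_intermediate is nonzero in my model, so every block also has the gated MLP (fc1/fc2).
        self.blocks = nn.ModuleList(
            [create_block(d_model, d_intermediate=d_intermediate, layer_idx=i)
             for i in range(n_layer)]
        )
        self.norm_f = nn.LayerNorm(d_model)

    def forward(self, x):  # x: (B, L, d_model) patch tokens
        hidden_states, residual = x, None
        for block in self.blocks:
            hidden_states, residual = block(hidden_states, residual)
        # Final add + norm, mirroring MixerModel's non-fused path.
        out = hidden_states if residual is None else hidden_states + residual
        return self.norm_f(out)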

My network follows a typical transformer-style backbone:

image → patch embed → (Mamba1 ×10 + downsample) → (Mamba1 ×10 + downsample)
       → (Mamba1 ×10 + downsample) → (Mamba1 ×10) → decoder → output (B, 3, H, W)

Each downsampling step doubles the number of channels and halves the spatial resolution.
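
The downsampling itself is along these lines (a minimal sketch; the Downsample name and the strided-conv implementation are placeholders for my actual code):

import torch.nn as nn

class Downsample(nn.Module):
    """Halve H and W, double C, operating on (B, L, C) tokens with L = H * W."""
    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=2, stride=2)

    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W
        x = x.transpose(1, 2).reshape(B, C, H, W)    # (B, C, H, W)
        x = self.reduction(x)                        # (B, 2C, H/2, W/2)
        return x.flatten(2).transpose(1, 2), H // 2, W // 2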

I used the official initialization, residual connections, and feed-forward structure.
The officially initialized layers are not overridden by any global initialization elsewhere in my code (see the check sketched below).
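
To confirm that, I snapshot the officially initialized parameters right after the Mamba blocks are built and compare them again once the rest of my model's setup has run. The helper below is a minimal sketch of that check; the function names are mine, not part of mamba_ssm:

import torch

def snapshot_params(model, keys=("dt_proj", "A_log")):
    # Clone the parameters whose official init we care about.
    return {n: p.detach().clone() for n, p in model.named_parameters()
            if any(k in n for k in keys)}

def assert_init_untouched(model, snapshot):
    # Call after any global weight-init pass; raises if something re-initialized them.
    for n, p in model.named_parameters():
        if n in snapshot and not torch.equal(p.detach(), snapshot[n]):
            raise RuntimeError(f"{n} was re-initialized after the official Mamba init")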

⚠️ Problem: Gradient Vanishing

After training, the model output quality was poor, so I checked the gradients layer by layer.
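
Concretely, the check is along these lines (a minimal sketch; log_grad_norms is a hypothetical helper name, not exactly what produced the grad.log attached below):

import torch

@torch.no_grad()
def log_grad_norms(model, logger=print):
    # Call right after loss.backward(): print the gradient L2 norm of every parameter.
    for name, p in model.named_parameters():
        if p.grad is None:
            logger(f"{name}: no grad")
        else:
            logger(f"{name}: grad_norm={p.grad.norm().item():.3e}")

Parameters whose gradient norm sits several orders of magnitude below their neighbors are the ones listed next.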

I found that the gradients of the following parameters gradually vanish, and the vanishing spreads into the surrounding Mamba layers:

  • dt_proj.weight and dt_proj.bias
  • mixer.fc1.weight and mixer.fc2.weight
  • A_log

Upon investigation:

  • dt_proj is the linear projection that prepares the per-step Δ parameters: it maps a low-rank delta of size roughly d_model/16 up to d_inner = 2*d_model, a 32× expansion, which seems quite large for a single linear layer (a worked shape example follows these code snippets).

self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

  • fc1 and fc2 are the two linear layers inside the gated MLP; their gradients also vanish.

hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
self.activation = activation
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
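
To make the shapes in these two snippets concrete, here is the arithmetic for an illustrative width of d_model = 192 (my real stage widths differ); the dt_rank = ceil(d_model / 16), expand = 2, and multiple_of = 128 values are the usual Mamba / GatedMLP defaults as I understand them, so treat them as assumptions:

import math

d_model = 192                      # illustrative stage width
expand = 2                         # Mamba default (assumed)
d_inner = expand * d_model         # 384
dt_rank = math.ceil(d_model / 16)  # 12, the "auto" dt_rank
print(d_inner / dt_rank)           # 32.0 -> the 32x expansion inside dt_proj

d_intermediate = 500               # illustrative MLP width before rounding
multiple_of = 128                  # assumed default
hidden_features = (d_intermediate + multiple_of - 1) // multiple_of * multiple_of  # 512
# fc1: d_model -> 2 * hidden_features (gate and value), fc2: hidden_features -> d_model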

🙏 Request for Help

Resolving this gradient vanishing issue is very important for my current work.
I would be deeply grateful for any insights, debugging suggestions, or theoretical explanations the team could provide.

Thank you very much for your time and for the excellent open-source work on Mamba.

📜 Gradient Log (Partial)

grad.log
