The core of ShuffleMamba is a drop-in wrapper that randomly permutes the token sequence before a layer and restores the original order afterwards:

```python
import torch
import torch.nn as nn


def shuffle_forward(x, residual, layer: nn.Module, inference_params=None,
                    prob: float = 0.0, training: bool = False):
    """
    Forward pass with optional shuffling of the sequence dimension.

    Args:
        x (torch.Tensor): Input tensor of shape (B, L, d).
        residual (torch.Tensor): Tensor of the same shape as x, required by the Mamba block.
        layer (nn.Module): A PyTorch module through which x is passed.
        inference_params: Optional inference-time cache forwarded to the layer.
        prob (float): Probability of shuffling the sequence dimension L.
        training (bool): Whether the model is in training mode.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output and residual from layer, with the
        sequence dimension shuffled before the layer and restored afterwards.
    """
    B, L, _ = x.shape
    if training and torch.rand(1).item() < prob:
        # Generate a random permutation of indices, shared across the batch
        shuffled_indices = torch.randperm(L, device=x.device).repeat(B, 1)
        # Get inverse indices by sorting the shuffled indices
        inverse_indices = torch.argsort(shuffled_indices, dim=1)
        # Apply the permutation to shuffle the sequence
        x_permuted = x.gather(1, shuffled_indices.unsqueeze(-1).expand(-1, -1, x.size(2)))
        if residual is not None:
            residual_permuted = residual.gather(1, shuffled_indices.unsqueeze(-1).expand(-1, -1, x.size(2)))
        else:
            residual_permuted = residual
        # Forward pass through the layer
        output_permuted, residual_permuted = layer(x_permuted, residual_permuted, inference_params=inference_params)
        # Restore the original token order
        output = output_permuted.gather(1, inverse_indices.unsqueeze(-1).expand(-1, -1, output_permuted.size(2)))
        residual = residual_permuted.gather(1, inverse_indices.unsqueeze(-1).expand(-1, -1, residual_permuted.size(2)))
    else:
        # Forward pass without shuffling
        output, residual = layer(x, residual, inference_params=inference_params)
    return output, residual
```
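For reference, here is a minimal sketch of driving `shuffle_forward` with a per-layer shuffle probability. The linear depth schedule in `build_schedule`, the `max_prob` value, and the `ToyBlock` stand-in are illustrative assumptions, not this repo's actual API; see the training code for the schedule actually used.

```python
import torch
import torch.nn as nn

def build_schedule(num_layers: int, max_prob: float = 0.5):
    # Hypothetical schedule: probability 0 for the first layer, growing
    # linearly to max_prob for the last (an assumption for illustration,
    # not necessarily the paper's exact schedule).
    return [max_prob * i / max(num_layers - 1, 1) for i in range(num_layers)]

class ToyBlock(nn.Module):
    """Stand-in for a Mamba block: returns (hidden, residual) like the real layers."""
    def __init__(self, dim):
        super().__init__()
        self.mixer = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, residual=None, inference_params=None):
        residual = x if residual is None else residual + x
        return self.mixer(self.norm(residual)), residual

if __name__ == "__main__":
    B, L, d, depth = 2, 16, 64, 4
    layers = nn.ModuleList(ToyBlock(d) for _ in range(depth))
    probs = build_schedule(depth)

    x, residual = torch.randn(B, L, d), None
    for layer, p in zip(layers, probs):
        # Tokens are permuted before each block and un-permuted after it,
        # so the output stays aligned with the input positions.
        x, residual = shuffle_forward(x, residual, layer, prob=p, training=True)
    print(x.shape)  # torch.Size([2, 16, 64])
```

Since the `training` flag gates the permutation, the wrapper acts as a training-time regularizer and reduces to a plain forward pass at inference.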
```bash
# torch>=2.0, cuda>=11.8
pip install timm==0.4.12 mlflow==2.9.1
pip install causal-conv1d==1.1.0
pip install mamba-ssm==1.1.1
pip install mmengine==0.10.1 mmcv==2.1.0 opencv-python-headless ftfy regex
pip install mmdet==3.3.0 mmsegmentation==1.2.2 mmpretrain==1.2.0
```

- Example for training Vim-B for 300 epochs: run the script in ShuffleMamba/Supervised_training/run.sh.
- Example for pre-training MambaMLP-L: run the script in ShuffleMamba/Masked_distillation/run_pt_large.sh.
- Example for fine-tuning the pre-trained MambaMLP-L: run the script in ShuffleMamba/Masked_distillation/Finetuning/run_ft_large.sh.
Models are available at [huggingface🤗].
```bibtex
@inproceedings{shufflemamba,
  title={Stochastic Layer-Wise Shuffle for Improving Vision Mamba Training},
  author={Zizheng Huang and Haoxing Chen and Jiaqi Li and Jun Lan and Huijia Zhu and Weiqiang Wang and Limin Wang},
  booktitle={International Conference on Machine Learning},
  year={2025},
}
```

This repo is built on Mamba-Reg, VideoMamba, VMamba, ARM, and ViT-Adapter, thanks!
