
Stochastic Layer-Wise Shuffle for Improving Vision Mamba Training

Official PyTorch implementation of SLWS (Stochastic Layer-Wise Shuffle), a regularization method for improving Vision Mamba training.

(Framework overview figure)

  • Core Code for Stochastic Layer-Wise Shuffle [plug and play]:

    import torch
    from torch import nn

    def shuffle_forward(x, residual, layer: nn.Module, inference_params=None, prob: float = 0.0, training: bool = False):
        """
        Forward pass with optional shuffling of the sequence dimension.
    
        Args:
        - x (torch.Tensor): Input tensor with shape (B, L, d).
        - residual (torch.Tensor or None): Tensor of the same shape as x, required by the Mamba block.
        - layer (nn.Module): A PyTorch module through which x should be passed.
        - inference_params: Optional inference cache forwarded to the layer.
        - prob (float): Probability of shuffling the sequence dimension L.
        - training (bool): Indicates whether the model is in training mode.
    
        Returns:
        - (torch.Tensor, torch.Tensor): Output and residual tensors from layer, with the
          sequence dimension potentially shuffled and then restored.
        """
        
        B, L, _ = x.shape
        if training and torch.rand(1).item() < prob:
            # Generate a random permutation of indices
            shuffled_indices = torch.randperm(L, device=x.device).repeat(B, 1)
            # Get inverse indices by sorting the shuffled indices
            inverse_indices = torch.argsort(shuffled_indices, dim=1)
    
            # Apply the permutation to shuffle the sequence
            x_permuted = x.gather(1, shuffled_indices.unsqueeze(-1).expand(-1, -1, x.size(2)))
            if residual is not None:
                residual_permuted = residual.gather(1, shuffled_indices.unsqueeze(-1).expand(-1, -1, x.size(2)))
            else:
                residual_permuted = residual            
            
            # Forward pass through the layer
            output_permuted, residual_permuted = layer(x_permuted, residual_permuted, inference_params=inference_params)
            # Restore the original order
            output = output_permuted.gather(1, inverse_indices.unsqueeze(-1).expand(-1, -1, output_permuted.size(2)))
            residual = residual_permuted.gather(1, inverse_indices.unsqueeze(-1).expand(-1, -1, residual_permuted.size(2)))
        else:
            # Forward pass without shuffling
            output, residual = layer(x, residual, inference_params=inference_params)
    
        return output, residual
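
  • Minimal usage sketch (hypothetical, not part of this repo's API): SLWS is applied layer-wise, so the example below loops over a stack of Mamba blocks and assigns each block a shuffle probability that grows with depth. The exact schedule used in the paper may differ, and `forward_blocks`, `blocks`, and `max_prob` are illustrative names; `shuffle_forward` from above is assumed to be in scope.

    from torch import nn

    def forward_blocks(blocks: nn.ModuleList, x, residual=None, max_prob: float = 0.5, training: bool = True):
        # Illustrative driver: each block gets a depth-dependent shuffle probability.
        depth = len(blocks)
        for i, blk in enumerate(blocks):
            # Assumed schedule: deeper layers are shuffled more often.
            prob = max_prob * (i + 1) / depth
            x, residual = shuffle_forward(x, residual, blk, prob=prob, training=training)
        return x, residual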

Installation

For ImageNet-1K classification training and masked feature distillation:

# torch>=2.0, cuda>=11.8
pip install timm==0.4.12 mlflow==2.9.1
pip install causal-conv1d==1.1.0
pip install mamba-ssm==1.1.1

For ADE20K segmentation and COCO detection with OpenMMLab tools:

pip install mmengine==0.10.1 mmcv==2.1.0 opencv-python-headless ftfy regex
pip install mmdet==3.3.0 mmsegmentation==1.2.2 mmpretrain==1.2.0

Training

  • Example for training Vim-B for 300 epochs: run the script ShuffleMamba/Supervised_training/run.sh.
  • Example for pre-training MambaMLP-L: run the script ShuffleMamba/Masked_distillation/run_pt_large.sh.
  • Example for fine-tuning the pre-trained MambaMLP-L: run the script ShuffleMamba/Masked_distillation/Finetuning/run_ft_large.sh.

Model Zoo

Models are available at [huggingface🤗]

Citation

@inproceedings{shufflemamba,
      title={Stochastic Layer-Wise Shuffle for Improving Vision Mamba Training}, 
      author={Zizheng Huang and Haoxing Chen and Jiaqi Li and Jun Lan and Huijia Zhu and Weiqiang Wang and Limin Wang},
      booktitle={International Conference on Machine Learning},
      year={2025},
}

Acknowledgement

This repo is built upon Mamba-Reg, VideoMamba, VMamba, ARM, and ViT-Adapter. Thanks to these projects!
