
Conversation

@Siritao commented Jun 9, 2025

Motivation

This PR is part of our upcoming work, FlowMLA, at @antgroup. Targeting MLA in Data Parallel scenarios, it enables serving DeepSeek-like models on fewer devices through seamless memory optimization while providing larger batching capacity for higher throughput.
Specifically, we partition the o_proj weights (the largest memory consumer in DP attention) across multiple DP ranks and prefetch them with a non-blocking all_gather just before each layer's computation. The gathered weights physically reuse the same GPU memory buffer across layers. This reduces per-rank parameter redundancy to 1/dp_size, freeing up more memory for the KV cache and in turn enabling significant throughput gains. Furthermore, the optimization makes serving FP8 DeepSeek-R1 models on 8×H20 GPU instances feasible, which was previously impossible.


The figure above shows the serving profile for DeepSeek-R1 FP8 on 8×H20 GPUs with a single request, demonstrating that the prefetch communication can be effectively hidden.
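For intuition, the following is a minimal sketch of the double-buffered prefetch pattern in plain PyTorch, not the vTensor-based code from this PR; the names layers, o_proj_shards, gathered_bufs, and the o_proj_weight argument are placeholders. The actual implementation uses dedicated odd/even CUDA streams and VMM-backed buffers, while this sketch leans on async_op=True and NCCL's internal communication stream for brevity.

import torch.distributed as dist

def forward_with_prefetch(layers, hidden, o_proj_shards, gathered_bufs):
    # o_proj_shards[i]: this rank's shard of layer i's o_proj weight.
    # gathered_bufs[0/1]: two reusable buffers, each holding one fully gathered o_proj.
    handles = [None, None]
    # Prefetch layer 0's full weight before entering the loop.
    handles[0] = dist.all_gather_into_tensor(
        gathered_bufs[0], o_proj_shards[0], async_op=True)

    for i, layer in enumerate(layers):
        # Make the current stream wait until layer i's gathered weight is ready.
        handles[i % 2].wait()

        # Issue the prefetch for layer i + 1; it overlaps with layer i's compute,
        # which has not been enqueued yet.
        if i + 1 < len(layers):
            handles[(i + 1) % 2] = dist.all_gather_into_tensor(
                gathered_bufs[(i + 1) % 2], o_proj_shards[i + 1], async_op=True)

        # Layer i consumes the gathered weight from its double-buffer slot.
        hidden = layer(hidden, o_proj_weight=gathered_bufs[i % 2])

    return hidden

Two buffers suffice because layer i only reads slot i % 2 while the gather for layer i + 1 writes the other slot, which matches the odd/even stream design described in the PR.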

Throughput experiment: 32×H800 (4 nodes, NVLink), DeepSeek-R1 BF16, TP32 MoE + DP32 attention, input length 1024, output length 2048, maximized batch size.

Metric                     w/o FlowMLA    w/ FlowMLA
Throughput (tokens/sec)    280            400

Note that the current version only supports DP Attention, not DP+TP.

Modifications

To achieve this, we implemented the vTensor library for weight management, added two server arguments (--vtensor-enable and --vtensor-port) to enable the optimization, modified parts of deepseek_v2.py, the RowParallelLinear class, and cuda_graph_runner.py, and added support for fp8_e4m3.

To run our optimization:

  • Install the vTensor module: cd sglang/vtensor && python setup.py install
  • When serving, add --vtensor-enable to your startup command, optionally specifying a port with --vtensor-port [PORT]

Checklist

@gemini-code-assist bot left a comment

Summary of Changes

Hello @Siritao, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant memory optimization technique, dubbed FlowMLA, for serving large language models like DeepSeek-V2 in Data Parallel configurations. It utilizes a new vTensor library based on CUDA VMM to efficiently manage and access partitioned o_proj weights across devices. By prefetching weights for future layers asynchronously, the PR aims to hide communication latency and reduce the memory footprint per GPU, thereby boosting throughput and enabling deployment on less hardware.

Highlights

  • Virtual Memory Management (VMM) with vTensor: Introduces a new C++ library with Python bindings (vTensor) that leverages CUDA Virtual Memory Management (VMM) to manage large tensors, specifically targeting the distribution and access of model weights across devices.
  • Distributed Weight Partitioning and Prefetching: Implements a strategy to partition the large o_proj weights of DeepSeek-V2 models across Data Parallel (DP) ranks. It uses non-blocking all_gather operations on dedicated CUDA streams to prefetch the necessary weight slices for upcoming layers, aiming to overlap communication with computation.
  • Memory Optimization for Data Parallelism: By partitioning weights and using VMM, the per-device memory footprint for model parameters is significantly reduced, freeing up memory for KV cache and enabling larger batch sizes or deployment on fewer GPUs in DP setups.
  • FP8 Support: Adds support for FP8 weights within the vTensor framework and updates the NCCL wrapper to handle FP8 data types for communication.
  • Configuration Options: Adds new server arguments (--vtensor-enable, --vtensor-port) to enable and configure the vTensor optimization.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature                 Command               Description
Code Review             /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary    /gemini summary       Provides a summary of the current pull request in its current state.
Comment                 @gemini-code-assist   Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                    /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

@gemini-code-assist bot left a comment

Code Review

The pull request introduces FlowMLA, a memory optimization technique for DeepSeek-like models in data parallel scenarios. It partitions o_proj weights and uses virtual memory management (vTensor) with NCCL all_gather for non-blocking prefetching. The changes involve adding a new C++ extension (vTensor), integrating it into the Python model and infrastructure code, and adding new server arguments. The approach seems promising for reducing memory footprint and boosting throughput, but there are several areas for improvement regarding correctness, maintainability, and clarity of the distributed logic.

Comment on lines +523 to +529
if self.vtensor_enable:
    with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
        self._stream_odd.wait_stream(stream)
        self._stream_even.wait_stream(stream)
        out = run_once()
        stream.wait_stream(self._stream_odd)
        stream.wait_stream(self._stream_even)

high

Adding stream synchronization (wait_stream) within the CUDA graph capture ensures that the prefetch operations on the odd/even streams are correctly ordered with respect to the main computation stream. This is crucial for correctness when using asynchronous operations.
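For readers unfamiliar with this constraint, the snippet below is a standalone sketch (not code from this PR) of the fork/join pattern CUDA graph capture requires when side streams are involved: the side stream must first wait on the capture stream, and the capture stream must wait on the side stream again before dependent work, so all recorded operations form a single ordered graph.

import torch

main = torch.cuda.Stream()
side = torch.cuda.Stream()
x = torch.randn(1024, 1024, device="cuda")
y = torch.empty_like(x)

# Warm-up outside of capture (recommended before capturing a CUDA graph).
y.copy_(x, non_blocking=True)
z = y @ y
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=main):
    # Fork: side-stream work is only recorded into the graph if the side
    # stream first waits on the capture stream.
    side.wait_stream(main)
    with torch.cuda.stream(side):
        y.copy_(x, non_blocking=True)  # stands in for a weight prefetch
    # Join: the capture stream waits for the side stream before the
    # computation that consumes the prefetched data.
    main.wait_stream(side)
    z = y @ y

graph.replay()
torch.cuda.synchronize()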

Comment on lines 168 to 171
VmmTensor::~VmmTensor() {
  DRV_CALL(cuMemUnmap(v_ptr, padded_size));
  auto tmp = std::move(u_p_block);
  DRV_CALL(cuMemAddressFree(v_ptr, padded_size));

high

The VmmTensor destructor correctly unmaps and frees the main virtual address space (v_ptr). It also moves the unique physical block (u_p_block), which should handle its deallocation via std::unique_ptr. However, it does not explicitly manage the shared physical blocks, which are held in global vectors (shared_phy_blocks_pre, shared_phy_blocks_post). These global vectors and the blocks they hold might not be properly released, potentially leading to resource leaks, especially if release_shared_phy_blocks is not called or if VmmTensor objects are created and destroyed without a corresponding global cleanup.
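One possible way to make the global cleanup explicit on the Python side is sketched below; it assumes the extension exposes release_shared_phy_blocks (the function name mentioned above) through a hypothetical binding module named vtensor, which is not confirmed by this PR.

import atexit
import vtensor  # hypothetical binding module name

# Ensure the globally held shared physical blocks are released at process
# teardown, even if no VmmTensor destructor triggers the cleanup.
atexit.register(vtensor.release_shared_phy_blocks)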

Comment on lines 134 to 140
DRV_CALL(cuMemAddressReserve(&offset_v_ptr, offset_size, 0ULL, 0ULL, 0ULL));
DRV_CALL(cuMemMap(offset_v_ptr, offset_size, 0ULL, this->u_p_block->alloc_handle, 0ULL));
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = this->device_id;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
DRV_CALL(cuMemSetAccess(offset_v_ptr, offset_size, &accessDesc, 1));

high

The SplitTensor method reserves a new virtual address space (cuMemAddressReserve) and maps the unique physical block to it. It's not clear if this offset_v_ptr address space is ever freed. The VmmTensor destructor only frees v_ptr. This could lead to a resource leak if SplitTensor is called.

Comment on lines 113 to 117
  assert(shared_phy_index < shared_phy_blocks_pre.size());
  phy_block = shared_phy_blocks_pre[shared_phy_index];
} else {
  assert(shared_phy_index < shared_phy_blocks_post.size());
  phy_block = shared_phy_blocks_post[shared_phy_index];

high

The logic here assumes that shared_phy_blocks_pre and shared_phy_blocks_post are indexed sequentially (shared_phy_index++) to assign shared blocks to the other ranks' slices. This also relies on the order of initialization and allocation matching across ranks. Ensure this is robust in a distributed setting.

Comment on lines 1204 to 1205
if vtensor_enable and "o_proj" in prefix:
    assert tp_size == 1, "FlowMLA currently only supported tp size == 1"

high

The assertion assert tp_size == 1 indicates that FlowMLA is currently limited to tensor parallel size 1. The PR description mentions partitioning weights across DP ranks. This assertion seems to imply that TP > 1 is not supported when FlowMLA is enabled. Please clarify this limitation in the documentation or PR description if it's intended. If FlowMLA is meant for DP, the assertion might be misplaced or the condition for enabling vtensor should be tied to DP size.

world_size = get_tensor_model_parallel_world_size()
if world_size > 8:
    world_size = 8
# TODO: only support float8_e4m3 weight

medium

The TODO comment indicates that only float8_e4m3 weight is supported for vtensor. Please update this comment or the code if other dtypes are supported. Also, the condition quant_config is not None and quant_config.is_checkpoint_fp8_serialized checks if the checkpoint is FP8 serialized, not necessarily if the weight being loaded is FP8. This might be okay if vtensor is only enabled for FP8 models, but the logic could be clearer.

Comment on lines +162 to +175
if vtensor_weight is None:
    weight = Parameter(
        torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=params_dtype,
        ),
        requires_grad=False,
    )
else:
    weight = Parameter(
        vtensor_weight,
        requires_grad=False,
    )

medium

The create_weights method now supports initializing the weight parameter with a pre-created vtensor_weight. This allows injecting the VMM-backed tensor.

"IPEXAWQLinearMethod",
]

vtensor_pynccl = None

medium

Using a global variable for the NCCL communicator (vtensor_pynccl) can make the code harder to manage and test, especially in multi-process or multi-threaded environments. Consider passing the communicator instance explicitly where needed or managing its lifecycle within a dedicated class or context.
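As a sketch of this suggestion (using plain torch.distributed rather than the PR's pynccl wrapper, and assuming the default process group has already been initialized), the communicator and its prefetch stream could be owned by a small class with an explicit lifecycle instead of a module-level global:

from contextlib import contextmanager

import torch
import torch.distributed as dist


class PrefetchCommunicator:
    # Owns the process group and CUDA stream used for weight prefetching.
    def __init__(self, ranks, backend="nccl"):
        self.group = dist.new_group(ranks=ranks, backend=backend)
        self.stream = torch.cuda.Stream()

    def all_gather(self, output, shard):
        # Launch the gather on the dedicated prefetch stream.
        with torch.cuda.stream(self.stream):
            return dist.all_gather_into_tensor(
                output, shard, group=self.group, async_op=True)

    def close(self):
        dist.destroy_process_group(self.group)


@contextmanager
def prefetch_communicator(ranks):
    comm = PrefetchCommunicator(ranks)
    try:
        yield comm
    finally:
        comm.close()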

Comment on lines +119 to +120
if dtype == torch.float8_e4m3fn:
    return cls.ncclFloat8e4m3

medium

Mapping torch.float8_e4m3fn to the new ncclFloat8e4m3 enum value is correctly implemented.

Comment on lines +62 to +64
std::vector<std::shared_ptr<PhyBlock>> shared_phy_blocks_pre;
std::vector<std::shared_ptr<PhyBlock>> shared_phy_blocks_post;
std::vector<std::unique_ptr<PhyBlock>> unique_phy_blocks;

medium

The global vectors shared_phy_blocks_pre, shared_phy_blocks_post, and unique_phy_blocks are used to manage physical memory blocks. Their lifecycle and thread safety need careful consideration, especially in a multi-process or multi-threaded environment. Ensure these resources are properly initialized and released.

@ch-wan (Collaborator) commented Jun 10, 2025

Is it applied to prefill or decode?

@Siritao (Author) commented Jun 10, 2025

Is it applied to prefill or decode?

The optimization applies to both the prefill and decode stages. For decoding, we add CUDA graph support.

@ch-wan self-assigned this Jun 10, 2025
@ch-wan (Collaborator) commented Jun 24, 2025

@Siritao Thank you for your excellent contribution. I'd like to get more details about the settings. What is the batch size of your evaluation? I assume that when bs is small, communication cannot be fully hidden. What is the threshold?

@Siritao (Author) commented Jun 25, 2025

@Siritao Thank you for your excellent contribution. I'd like to get more details about the settings. What is the batch size of your evaluation? I assume that when bs is small, communication cannot be fully hidden. What is the threshold?

Thanks for your attention. The figure was traced on 8×H20 (DeepSeek-R1 FP8) with batch_size=1, demonstrating that communication overhead can be effectively hidden even under minimal batch conditions. However, FlowMLA's performance does vary with GPU type and model scale. To address this, we've added large-scale throughput experiments on 32×H800 (DeepSeek-R1 BF16) to the PR description. For this test, we used TP32 MoE, DP32 attention (o_proj sharded intra-node), fixed input/output lengths (1k/2k tokens), and the maximum batch size.
In this scenario, enabling FlowMLA increased the KV cache size from 6.22 GB to 16.80 GB per device (with --mem-fraction-static=0.9) while significantly boosting throughput, demonstrating the effectiveness of FlowMLA for memory-bound workloads.

} \
}

class PhyBlock {


Hi, this memory block is not a classic cache-based design (it implements its own memory block and segment logic).

Since it does not extend PyTorch's caching memory allocator, I am concerned that these memory blocks may not be correctly recorded by the PyTorch event profiler.

Last year we discussed in the PyTorch community how to extend the caching allocator for a block-based GPU memory allocator (for training/inference):

pytorch/pytorch#115336

At that time, direct inheritance was not possible, so I rewrote a caching PyTorch memory allocator, but I believe we can now directly inherit from the PyTorch caching allocator interface.

This would bring useful features: allocations can be correctly recorded by the PyTorch event profiler, and the management of block-based GPU memory in the allocator logic can be minimized.

@yiakwy-xpu-ml-framework-team (Contributor) commented Jul 2, 2025

As discussed offline, mapping/unmapping this kind of memory is essential for applications such as RL (cc @zhaochenyang20); I recall the referenced vLLM PR #11743 added release of KV cache memory in sleep mode last year.

My suggestion is to make a CachedCuMemAllocator that inherits from PyTorch's CachedMemoryPlugin and maintains a private memory pool.

Since this functionality is not exclusive to this PR (this PR does not offload memory to CPU), we can do it in SGLang-Kernel with an SGLang version of the cuMem allocator, and I can add a cached version in the proposal.

To use PyTorch's pluggable memory allocator mechanism, simply export alloc/dealloc through an external C API and bind the functions on the Python side:

def get_pluggable_allocator(
    python_malloc_fn: Callable[[int], int],
    python_free_func: Callable[[int, int], None],
) -> torch.cuda.memory.CUDAPluggableAllocator:
    init_module(python_malloc_fn, python_free_func)
    new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_name, 'my_malloc', 'my_free')
    return new_alloc
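
For reference, an allocator created this way can then be installed process-wide before the first CUDA allocation; a minimal sketch, where python_malloc_fn and python_free_func are the callbacks passed above:

new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
torch.cuda.memory.change_current_allocator(new_alloc)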

cc @zhyncs

@miter6 (Contributor) commented Jul 15, 2025

Does FlowMLA support DeepEP?

@miter6 (Contributor) commented Jul 15, 2025

Another question: does FlowMLA support PD serving?
