🚀 The feature, motivation and pitch
At the cudagraph capture stage, MLACommonMetadataBuilder builds the metadata with max_query_len = 1, as introduced by this PR:
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata) -> M:
    """
    This method builds the metadata for full cudagraph capture.
    Currently, only decode is supported for full cudagraphs with MLA.
    """
    m = common_attn_metadata

    assert m.num_reqs == m.num_actual_tokens, \
        "MLA only supports decode-only full CUDAGraph capture. " \
        "Make sure all cudagraph capture sizes <= max_num_seq."

    m.max_query_len = 1  # decode-only
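For clarity, here is a minimal, self-contained sketch of the decode-only invariant this method relies on (FakeMetadata and check_decode_only are mock names for illustration, not real vLLM classes): in a pure decode batch every request contributes exactly one token, so num_reqs == num_actual_tokens and max_query_len == 1.

```python
from dataclasses import dataclass

@dataclass
class FakeMetadata:
    # Only the three fields used by the assert above; not the real
    # CommonAttentionMetadata class.
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int

def check_decode_only(m: FakeMetadata) -> None:
    # Mirrors the assert in build_for_cudagraph_capture above.
    assert m.num_reqs == m.num_actual_tokens, (
        "MLA only supports decode-only full CUDAGraph capture.")

# Pure decode batch: 8 requests, 1 token each -> the invariant holds.
check_decode_only(FakeMetadata(num_reqs=8, num_actual_tokens=8, max_query_len=1))
```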
When we run DeepSeek-R1 with DeepSeek MTP using simple-cuda-graph, and apply the patches that add full cudagraph support for eagle, max_query_len may be 2 (one actual decode token plus one speculative token), which produces garbage output like ok, ok, ok, ok....
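Continuing the mock sketch above (assuming exactly one MTP draft token per request): each request now contributes two tokens, so num_actual_tokens == 2 * num_reqs and max_query_len == 2, which breaks both the assert and the max_query_len = 1 assumption.

```python
# Same mock as above, now with one speculative (MTP) token per request:
# 8 requests * (1 decode token + 1 draft token) = 16 actual tokens.
m = FakeMetadata(num_reqs=8, num_actual_tokens=16, max_query_len=2)
try:
    check_decode_only(m)
except AssertionError as err:
    print(err)  # the decode-only capture invariant no longer holds
```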
We also find that, when running DeepSeek with MTP, MLACommonImpl.forward always calls self._forward_prefill and self._forward_decode is never called, while the cudagraph capture is decode-only and requires max_query_len = 1. This may cause a conflict, but I'm not sure.
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if has_prefill:
            output[num_decode_tokens:] = self._forward_prefill(
                prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
                attn_metadata)

        if has_decode:  # When running DeepSeek R1 with MTP, has_decode is always False
            ...
            output[:num_decode_tokens] = self._forward_decode(
                decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
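My guess (not verified in the code, so the snippet below is only an illustrative sketch with made-up names, not the real vLLM split logic) is that requests are classified as decode vs. prefill by their query length; with one MTP draft token every request has a query length of 2, so everything lands on the prefill path and has_decode stays False:

```python
def split_by_query_len(query_lens: list[int], decode_threshold: int = 1):
    """Toy split: requests with query_len <= threshold go to the decode path,
    the rest go to the prefill path. Hypothetical helper, for illustration only."""
    decodes = [q for q in query_lens if q <= decode_threshold]
    prefills = [q for q in query_lens if q > decode_threshold]
    return decodes, prefills

# Plain decoding: every request has query_len == 1 -> all requests hit
# _forward_decode, and has_decode is True.
print(split_by_query_len([1, 1, 1]))   # ([1, 1, 1], [])

# With one MTP draft token: every request has query_len == 2 -> everything is
# routed to _forward_prefill, so has_decode stays False, as observed above.
print(split_by_query_len([2, 2, 2]))   # ([], [2, 2, 2])
```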
So, is there any way to support DeepSeek MTP with full cudagraph?
@ProExpertProg @LucasWilkinson @zixi-qi @YaoJiayi Looking forward to your reply.
Alternatives
No response
Additional context
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.