teja-rao commented Oct 30, 2025

There are three commits in this PR; each commit can be reviewed individually.

Commit 1: Ensures we use FlashAttention 3 for VLLMCompatibleFlashAttention. Without this fix, we run into the following error, which shows we are using torch.ops._vllm_fa2_C.varlen_fwd.

Traceback (most recent call last):
  File "/data/users/teja/spirl/simple_rl.py", line 880, in <module>
    main()
  File "/data/users/teja/spirl/simple_rl.py", line 833, in main
    metrics = rl_update_step(
              ^^^^^^^^^^^^^^^
  File "/data/users/teja/spirl/simple_rl.py", line 677, in rl_update_step
    loss, loss_metrics = compute_policy_gradient_loss_vllm(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/teja/spirl/simple_rl.py", line 496, in compute_policy_gradient_loss_vllm
    logits = model(full_tensor)
             ^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/teja/spirl/torchtitan/torchtitan/models/qwen3/model/model_vllm_compat.py", line 357, in forward
    h = layer(h, self.rope_cache, attention_masks)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/teja/spirl/torchtitan/torchtitan/models/qwen3/model/model_vllm_compat.py", line 265, in forward
    x = x + self.attention(attn_norm_out, rope_cache, attention_masks)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/teja/spirl/torchtitan/torchtitan/models/qwen3/model/model_vllm_compat.py", line 223, in forward
    output = self.inner_attention(xq, xk, xv, scale=self.scaling)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/teja/spirl/torchtitan/torchtitan/models/attention.py", line 149, in forward
    output_varlen = self.flash_attn_varlen_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/vllm/vllm_flash_attn/flash_attn_interface.py", line 236, in flash_attn_varlen_func
    out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/teja/.conda/envs/pytorch/lib/python3.12/site-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain.
Search for `cudaErrorUnsupportedPtxVersion' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
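
For context, a minimal sketch of how the attention wrapper can be steered onto the FA3 kernels. It assumes the installed vllm_flash_attn exposes is_fa_version_supported and that flash_attn_varlen_func accepts a fa_version keyword; both vary across vLLM releases, so treat the names as illustrative rather than the exact change in the commit.

import functools

# Assumption: these names exist in the installed vllm_flash_attn build; older
# releases may not expose is_fa_version_supported or the fa_version keyword.
from vllm.vllm_flash_attn.flash_attn_interface import (
    flash_attn_varlen_func,
    is_fa_version_supported,
)

def make_varlen_attn(prefer_fa3: bool = True):
    # Pick FA3 when the wheel ships it and the GPU supports it; otherwise fall
    # back to FA2 (the torch.ops._vllm_fa2_C path seen in the traceback above).
    fa_version = 3 if (prefer_fa3 and is_fa_version_supported(3)) else 2
    return functools.partial(flash_attn_varlen_func, fa_version=fa_version)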

Commit 2: Output layer weights are updated on the TorchTitan side but aren't getting synced to vLLM. This happens because Qwen3 has tie_word_embeddings=True and uses the token embedding weights for the output layer in vLLM. The fix ensures these are also tied on the TorchTitan side, so whenever the output layer is updated, the token embedding layer is updated as well.
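
A minimal sketch of the tying, assuming hypothetical attribute names tok_embeddings and output on the TorchTitan-side module (the real names in model_vllm_compat.py may differ):

import torch.nn as nn

def tie_output_to_embeddings(model: nn.Module) -> None:
    # Share a single Parameter object, mirroring tie_word_embeddings=True on
    # the vLLM side. Any update to output.weight is then also an update to
    # tok_embeddings.weight, so the weight synced to vLLM stays current.
    model.output.weight = model.tok_embeddings.weight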

Before the fix:


INFO 10-30 12:03:48 [llm.py:345] Supported tasks: ['generate']
Using provided vLLM state dict with 311 weights
Converted to 311 TorchTitan weights
Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 2336.92it/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.85it/s, est. speed input: 26.93 toks/s, output: 76.93 toks/s]
  ⚠ vLLM-TorchTitan logprobs differ: 15/20 tokens
    Max delta: 3.700256e-04, Avg delta: 3.901466e-05
    vLLM logprobs:     ['-0.6266400814', '-0.9714866281', '-2.9318428040', '-4.8173923492', '-2.0535356998']
    TorchTitan logprobs: ['-0.6266930103', '-0.9714921713', '-2.9318408966', '-4.8170223236', '-2.0535359383']

Step   1 | Loss: -0.0300 | Reward: +1.156±0.731 | Advantage: -0.000±0.802
  Sample:  Paris, the United Kingdom's capital is London, and the capi...

After the fix:

INFO 10-30 12:11:00 [llm.py:345] Supported tasks: ['generate']
Using provided vLLM state dict with 311 weights
Converted to 311 TorchTitan weights
Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 2753.61it/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████| 20/20 [00:05<00:00,  3.77it/s, est. speed input: 26.41 toks/s, output: 75.46 toks/s]
  ✓ vLLM-TorchTitan bitwise determinism verified: 20 tokens match exactly

Step   1 | Loss: -0.0130 | Reward: +1.150±0.727 | Advantage: -0.000±0.802
  Sample:  Paris, the United Kingdom's capital is London, and the capi...

================================================================================
Training complete!

Commit 3: Gradient flow is broken: only the output layer's weight receives a gradient update. This is because we are using vllm_rms_norm, which does not have backward-pass support. This was verified by adding a log that counts how many params have grads after the backward pass (a sketch of a differentiable replacement norm follows the before/after numbers below):
# Check gradients after the backward pass
params_with_grad = sum(1 for p in model.parameters() if p.requires_grad and p.grad is not None)
params_requiring_grad = sum(1 for p in model.parameters() if p.requires_grad)
print(f"  Gradients computed: {params_with_grad}/{params_requiring_grad} params")

Before the fix:
Gradients computed: 1/282 params

After the fix:
Gradients computed: 198/282 params
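
A minimal sketch of a differentiable replacement, assuming the fix swaps the fused vllm_rms_norm kernel for a plain PyTorch RMSNorm during training so autograd can derive the backward pass (the actual replacement in the commit may differ):

import torch
import torch.nn as nn

class DifferentiableRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for numerical stability, then cast back to the input dtype.
        x_fp32 = x.float()
        rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_fp32 * rms).type_as(x) * self.weight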

teja-rao merged commit 91c7ee1 into bwasti:shim on Oct 30, 2025