[RFC]: All Ops should be determined during init and wrapped in a Layer Module to avoid envs.ENVIRON overhead #17067

@tjtanaa

Motivation.

Accessing envs.ENVIRON has non-negligible overhead. Because LLM models contain many ops and layers, the accumulated cost of these accesses can reach 0.1 ~ 1 ms per token. I have observed a large overhead in the MLA prefill forward pass when envs.ENVIRON is used in kernel selection logic (where an if-else statement is involved).

Proposed action:

  1. Each Layer Module should store its selected kernel ops as a property of the layer.
    @cache is discouraged because it makes clearing the cache increasingly complex: many properties depend on envs.
    @cache has also been discouraged in several PR reviews because of the following use case:
    Users instantiate multiple LLMs in a single Python program, and each LLM instance uses a different set of ENV variables.

  2. Document the overhead issue in the vLLM documentation under the Contribution section to remind developers of this abstraction and of the overhead caused by envs.ENVIRON invocation.
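The multi-instance concern in item 1 can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical and only stands in for real vLLM code: the variable name DEMO_USE_FAST_KERNEL, the two toy kernels, and DemoLayer are assumptions made for illustration. It shows a process-wide @cache keeping the first decision forever, while a per-instance property follows the environment that was active when each layer was built.

```python
import os
from functools import cache

FLAG = "DEMO_USE_FAST_KERNEL"  # hypothetical variable, not a real vLLM flag


def slow_kernel(x):
    return ("slow", x)


def fast_kernel(x):
    return ("fast", x)


def _flag_enabled() -> bool:
    # Reads the process environment every time it is called.
    return os.environ.get(FLAG, "0") == "1"


@cache
def cached_select():
    # Process-wide cache: the first decision is kept until cache_clear().
    return fast_kernel if _flag_enabled() else slow_kernel


class DemoLayer:
    def __init__(self) -> None:
        # Decision is made once per instance and stored on the layer.
        self.kernel = fast_kernel if _flag_enabled() else slow_kernel

    def forward(self, x):
        # Hot path: plain attribute access, no environment lookup.
        return self.kernel(x)


# First "LLM instance" is created with the flag enabled.
os.environ[FLAG] = "1"
layer_a = DemoLayer()
first_cached = cached_select()

# Second "LLM instance" is created with the flag disabled.
os.environ[FLAG] = "0"
layer_b = DemoLayer()
second_cached = cached_select()

print(layer_a.forward(1)[0], layer_b.forward(1)[0])
print(first_cached.__name__, second_cached.__name__)
```

In this sketch the two layers end up with different kernels, whereas cached_select() keeps returning the kernel chosen for the first instance unless cached_select.cache_clear() is called at the right moment.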

Overhead experiments

Average time per accessing envs.ENVIRON : 1.0514259338378907e-06 seconds
Average time per accessing class method access : 3.0994415283203126e-08 seconds
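To relate these two measurements to the 0.1 ~ 1 ms per token figure, here is a back-of-the-envelope calculation. The number of envs accesses per token (100 and 1000) is an assumption chosen for illustration; it was not measured in this issue.

```python
# Measured averages reported above, in seconds per access.
ENV_ACCESS_S = 1.0514259338378907e-06
ATTR_ACCESS_S = 3.0994415283203126e-08


def per_token_overhead_ms(access_s: float, accesses_per_token: int) -> float:
    """Total access cost per token, in milliseconds."""
    return access_s * accesses_per_token * 1e3


# Assumed access counts per token (hypothetical, for illustration only).
for n in (100, 1000):
    env_ms = per_token_overhead_ms(ENV_ACCESS_S, n)
    attr_ms = per_token_overhead_ms(ATTR_ACCESS_S, n)
    print(f"{n:5d} accesses/token: envs {env_ms:.4f} ms, attribute {attr_ms:.4f} ms")

# Relative cost of one envs access versus one attribute access (roughly 34x).
ratio = ENV_ACCESS_S / ATTR_ACCESS_S
print(f"envs access is about {ratio:.0f}x slower than attribute access")
```

Under these assumptions, a few hundred to a thousand accesses per token is enough to land in the 0.1 ~ 1 ms range.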

Script:

import time
import torch
import vllm.envs as envs


class Layer:
    @staticmethod
    def forward():
        return torch.randn(1024, 1024, dtype=torch.float16, device='cuda')

def test_envs_timing():
    
    # Time the operation
    num_runs = 100
    start_time = time.time()
    for _ in range(num_runs):
        envs.VLLM_ROCM_USE_AITER_LINEAR
    end_time = time.time()
    
    # Calculate average time per run
    avg_time = (end_time - start_time) / num_runs
    print(f"Average time per accessing envs.ENVIRON : {avg_time} seconds")


def test_class_method():
    
    # Time the operation
    num_runs = 100
    start_time = time.time()
    for _ in range(num_runs):
        Layer.forward
    end_time = time.time()
    
    # Calculate average time per run
    avg_time = (end_time - start_time) / num_runs
    print(f"Average time per accessing class method access : {avg_time} seconds")

if __name__ == "__main__":
    test_envs_timing()
    test_class_method()
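For anyone who wants to reproduce the comparison without vLLM or a GPU, here is a sketch using timeit, which runs many more iterations and takes the best of several repeats. LazyEnvs is a hypothetical stand-in written under the assumption that the cost comes from re-reading and parsing the environment on every attribute access; it is not the actual vllm.envs implementation, and DEMO_FLAG is a made-up variable name.

```python
import os
import timeit


class LazyEnvs:
    """Hypothetical stand-in: every attribute access re-reads the environment."""

    _getters = {
        "DEMO_FLAG": lambda: os.getenv("DEMO_FLAG", "0").lower() in ("1", "true"),
    }

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for the lazy flag names.
        try:
            return self._getters[name]()
        except KeyError:
            raise AttributeError(name) from None


class Layer:
    def __init__(self, envs) -> None:
        # Read the flag once at init and keep the result on the instance.
        self.use_flag = envs.DEMO_FLAG


envs = LazyEnvs()
layer = Layer(envs)


def bench(func, number: int = 100_000, repeat: int = 5) -> float:
    """Best-of-N average time per call, in seconds."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number


lazy_s = bench(lambda: envs.DEMO_FLAG)
stored_s = bench(lambda: layer.use_flag)

print(f"lazy env access      : {lazy_s:.3e} s")
print(f"stored attribute read: {stored_s:.3e} s")
```

Both timings include the cost of calling the lambda, so the absolute numbers are not comparable to the ones above; only the gap between the two lines is meaningful.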

Proposed Change.

As an example, model layers such as linear layers should use the following abstraction:
vllm/model_executor/layers/quantization/fp8.py

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    dispatch_w8a8_blockscale_func)


class Fp8LinearMethod(LinearMethodBase):
    def __init__(self, quant_config: Fp8Config):
        ...
        if self.block_quant:
            # Resolve the block-scaled GEMM kernel once, at init time.
            self.linear_func = dispatch_w8a8_blockscale_func(
                self.cutlass_block_fp8_supported)
        else:
            # logic to select the per-tensor / per-channel scaled gemm kernel
            ...

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # No condition checks here: the kernel was already selected in __init__.
        return self.linear_func(input=x,
                                weight=layer.weight,
                                weight_scale=layer.weight_scale,
                                out_dtype=self.out_dtype,
                                input_scale=layer.input_scale,
                                bias=bias)

aiter-block-gemm-integration/vllm/model_executor/layers/quantization/utils/fp8_utils.py


def dispatch_w8a8_blockscale_func(
        use_cutlass: bool) -> Callable[..., torch.Tensor]:
    if use_cutlass:
        return cutlass_scaled_mm
    if (current_platform.is_rocm() and 
        envs.VLLM_ROCM_USE_AITER and 
        envs.VLLM_ROCM_USE_AITER_LINEAR):
        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
    return w8a8_block_fp8_matmul
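The property this dispatch pattern is meant to guarantee can be checked with a small stand-alone sketch: env flags are read only while the layer method is constructed, and never inside apply. All names below (CountingEnvs, the toy matmul functions, DemoLinearMethod, and the flag names) are hypothetical stand-ins, not vLLM APIs.

```python
from typing import Callable


class CountingEnvs:
    """Hypothetical stand-in for an envs module that counts every flag read."""

    def __init__(self, **flags: bool) -> None:
        self._flags = flags
        self.reads = 0

    def __getattr__(self, name: str) -> bool:
        # Only reached for names not found normally, i.e. the flag names.
        flags = self.__dict__.get("_flags", {})
        if name not in flags:
            raise AttributeError(name)
        self.reads += 1
        return flags[name]


def cutlass_matmul(x: int) -> str:
    return f"cutlass({x})"


def aiter_matmul(x: int) -> str:
    return f"aiter({x})"


def generic_matmul(x: int) -> str:
    return f"generic({x})"


def dispatch_demo_func(use_cutlass: bool, is_rocm: bool,
                       envs: CountingEnvs) -> Callable[[int], str]:
    # Same shape as the dispatch function above, with toy kernels.
    if use_cutlass:
        return cutlass_matmul
    if is_rocm and envs.USE_AITER and envs.USE_AITER_LINEAR:
        return aiter_matmul
    return generic_matmul


class DemoLinearMethod:
    def __init__(self, envs: CountingEnvs, use_cutlass: bool = False,
                 is_rocm: bool = True) -> None:
        # Kernel is selected once here; this is the only place envs is read.
        self.linear_func = dispatch_demo_func(use_cutlass, is_rocm, envs)

    def apply(self, x: int) -> str:
        # Hot path: no env access, no branching.
        return self.linear_func(x)


envs = CountingEnvs(USE_AITER=True, USE_AITER_LINEAR=True)
method = DemoLinearMethod(envs)
reads_after_init = envs.reads

outputs = [method.apply(i) for i in range(1000)]
print(reads_after_init, envs.reads, outputs[0])
```

The read counter stays the same after 1000 calls to apply, which is the behavior the RFC asks for in every layer.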

