Motivation.
Accessing `envs.ENVIRON` has non-negligible overhead. Given that LLM models have many ops and layers, the overhead from accessing `envs.ENVIRON` can add up to 0.1 ~ 1 ms per token. I have observed a large overhead in the MLA prefill forward pass when `envs.ENVIRON` is used in kernel-selection logic (where an if-else statement is involved).
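For context, `vllm/envs.py` resolves these attributes dynamically through a module-level `__getattr__`, so every access pays the cost of an environment lookup. A simplified sketch of the mechanism (illustrative, not the exact source):

```python
# Simplified sketch of the lookup mechanism in vllm/envs.py (illustrative,
# not the exact source). Each attribute access goes through the module-level
# __getattr__ and re-evaluates the lambda, i.e. re-reads os.environ.
import os

environment_variables = {
    "VLLM_ROCM_USE_AITER_LINEAR":
    lambda: os.environ.get("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
    ("true", "1"),
}


def __getattr__(name: str):
    if name in environment_variables:
        return environment_variables[name]()  # evaluated on every access
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```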
Proposed actions:

- Layer modules should store the selected kernel op as a property of the layer.
- `@cache` is discouraged because of the growing complexity of clearing the cache, as many properties depend on envs. It has also been discouraged in several PR reviews because of use cases such as the following: users instantiate multiple LLMs in a single Python program, and each LLM instance uses a different set of environment variables (see the sketch after this list).
- Document the overhead issue on the vLLM documentation page under the Contributing section, to remind developers of this abstraction and of the overhead caused by `envs.ENVIRON` access.
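To illustrate the `@cache` pitfall, here is a minimal sketch (the helper name is hypothetical): once the first lookup is memoized, a second LLM instance configured with a different environment variable value silently sees the stale result.

```python
import os
from functools import cache


@cache
def use_aiter_linear() -> bool:  # hypothetical helper for illustration
    return os.environ.get("VLLM_ROCM_USE_AITER_LINEAR", "0") == "1"


os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "1"
print(use_aiter_linear())  # True (first call populates the cache)

os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0"
print(use_aiter_linear())  # still True -- the cached value ignores the change
```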
Overhead experiments

- Average time per `envs.ENVIRON` access: 1.0514259338378907e-06 seconds
- Average time per class-method attribute access: 3.0994415283203126e-08 seconds
Script:

```python
import time

import torch

import vllm.envs as envs


class Layer:

    @staticmethod
    def forward():
        return torch.randn(1024, 1024, dtype=torch.float16, device='cuda')


def test_envs_timing():
    # Time the operation
    num_runs = 100
    start_time = time.time()
    for _ in range(num_runs):
        envs.VLLM_ROCM_USE_AITER_LINEAR
    end_time = time.time()

    # Calculate average time per run
    avg_time = (end_time - start_time) / num_runs
    print(f"Average time per accessing envs.ENVIRON : {avg_time} seconds")


def test_class_method():
    # Time the operation
    num_runs = 100
    start_time = time.time()
    for _ in range(num_runs):
        Layer.forward  # attribute access only; the method is not called
    end_time = time.time()

    # Calculate average time per run
    avg_time = (end_time - start_time) / num_runs
    print(f"Average time per accessing class method access : {avg_time} seconds")


if __name__ == "__main__":
    test_envs_timing()
    test_class_method()
```
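As a side note, `time.time()` over 100 iterations is fairly noisy for operations in the microsecond range. A lower-noise variant of the same measurement (a sketch, not part of the original experiment) could use `timeit`, which defaults to `time.perf_counter` and amortizes scheduling noise over many iterations:

```python
# Lower-noise variant of the same measurement (a sketch, not part of the
# original experiment).
import timeit

import vllm.envs as envs

num_runs = 100_000
total = timeit.timeit(lambda: envs.VLLM_ROCM_USE_AITER_LINEAR,
                      number=num_runs)
print(f"Average time per envs access: {total / num_runs:.3e} seconds")
```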
Proposed Change.
As an example, model layers such as linear layers should use the following abstraction:
`vllm/model_executor/layers/quantization/fp8.py`:

```python
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    dispatch_w8a8_blockscale_func)


class Fp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: Fp8Config):
        ...
        if self.block_quant:
            self.linear_func = dispatch_w8a8_blockscale_func(
                self.cutlass_block_fp8_supported)
        else:
            ...  # logic to select the per-tensor / per-channel scaled gemm kernel

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # No more per-call condition checks: the kernel was selected in __init__.
        return self.linear_func(input=x,
                                weight=layer.weight,
                                weight_scale=layer.weight_scale,
                                out_dtype=self.out_dtype,
                                input_scale=layer.input_scale,
                                bias=bias)
```
`vllm/model_executor/layers/quantization/utils/fp8_utils.py`:

```python
def dispatch_w8a8_blockscale_func(
        use_cutlass: bool) -> Callable[..., torch.Tensor]:
    if use_cutlass:
        return cutlass_scaled_mm
    if (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_LINEAR):
        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
    return w8a8_block_fp8_matmul
```
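With this split, the env lookups happen exactly once, inside the dispatch function called at layer construction, and the hot path only invokes the stored callable. For completeness, here is a minimal, self-contained sketch of the same pattern with hypothetical names:

```python
# Self-contained sketch of the proposed pattern (hypothetical names):
# resolve env-dependent dispatch once at construction time, then call the
# stored function in the hot path with no further checks.
import os
from typing import Callable

import torch


def _aiter_gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w  # stand-in for an AITER kernel


def _default_gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.matmul(x, w)


def dispatch_gemm() -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    # The env var is read here, once per layer construction.
    if os.environ.get("USE_AITER_GEMM", "0") == "1":
        return _aiter_gemm
    return _default_gemm


class LinearLayer:

    def __init__(self, weight: torch.Tensor) -> None:
        self.weight = weight
        self.gemm = dispatch_gemm()  # kernel stored as a layer property

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gemm(x, self.weight)  # no env lookups in the hot path
```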
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.