19 changes: 6 additions & 13 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -335,8 +335,8 @@ using RunnerPtr = std::shared_ptr<torch_ext::trtllm::attention::RunnerBase>;
using torch_ext::trtllm::attention::Runner;
using torch_ext::trtllm::attention::AttentionInputType;

torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
std::optional<torch::ScalarType> out_dtype, torch::optional<torch::Tensor> workspace_,
void attention_inplace(torch::Tensor q, torch::optional<torch::Tensor> k, torch::optional<torch::Tensor> v,
torch::Tensor& output, std::optional<torch::ScalarType> out_dtype, torch::optional<torch::Tensor> workspace_,
torch::Tensor sequence_length, torch::Tensor host_past_key_value_lengths, torch::Tensor context_lengths,
torch::Tensor host_context_lengths, torch::Tensor host_request_types,
torch::optional<torch::Tensor> kv_cache_block_offsets, torch::optional<torch::Tensor> host_kv_cache_block_offsets,
@@ -549,12 +549,6 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
workspace = torch::empty({workspace_size}, torch::dtype(torch::kByte).device(qkv.device()));
}

int64_t v_head_size = !op->mIsMLAEnabled ? head_size
: is_gen_only ? op->mMLAParams.kv_lora_rank
: v_head_dim.value();
auto output = torch::empty(
{num_tokens, num_heads * v_head_size}, qkv.options().dtype(out_dtype.value_or(qkv.scalar_type())));

if ((num_contexts > 0) && (attn_input_type != AttentionInputType::GenerationOnly))
{
auto seq_offset = 0;
@@ -585,19 +579,18 @@ torch::Tensor attention(torch::Tensor q, torch::optional<torch::Tensor> k, torch
}

TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);

return output;
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"attention("
"attention_inplace("
"Tensor q"
", Tensor? k"
", Tensor? v"
", Tensor(a!) output"
", ScalarType? out_dtype"
", Tensor? workspace"
", Tensor sequence_length"
@@ -653,10 +646,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
", int? v_head_dim"
", Tensor? mrope_rotary_cos_sin"
", Tensor? mrope_position_deltas"
") -> Tensor");
") -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("attention", &torch_ext::attention);
m.impl("attention_inplace", &torch_ext::attention_inplace);
}
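
The key change above is that the op no longer allocates and returns its result: it now takes a caller-provided Tensor(a!) output argument, writes into it, and returns nothing, so the num_tokens x (num_heads * v_head_size) allocation that was deleted from the C++ side has to happen in the caller before the op is invoked (presumably so the same buffer can be reused across CUDA graph capture and replay). Below is a minimal, runnable sketch of the same mutable-argument registration pattern using a toy op; the toy namespace, the scale_into name, and the scaling logic are invented for illustration and are not part of TensorRT-LLM.

import torch

# Toy library fragment mirroring the "Tensor(a!) output ... -> ()" schema above.
# Namespace and op name are hypothetical.
lib = torch.library.Library("toy", "FRAGMENT")
lib.define("scale_into(Tensor x, Tensor(a!) out, float alpha) -> ()")

def scale_into_impl(x, out, alpha):
    # Write the result into the caller-provided buffer instead of returning a new tensor.
    out.copy_(x * alpha)

lib.impl("scale_into", scale_into_impl, "CompositeExplicitAutograd")

x = torch.randn(4, 8)
out = torch.empty_like(x)  # the caller preallocates, as attention_inplace now requires
torch.ops.toy.scale_into(x, out, 0.5)
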
10 changes: 10 additions & 0 deletions examples/pytorch/quickstart_advanced.py
@@ -83,6 +83,14 @@ def add_llm_args(parser):
default=False,
action='store_true',
help='Print iteration logs during execution')
parser.add_argument('--use_torch_compile',
default=False,
action='store_true',
help='Use torch.compile to optimize the model')
parser.add_argument('--use_piecewise_cuda_graph',
default=False,
action='store_true',
help='Use piecewise CUDA graph to optimize the model')

# Sampling
parser.add_argument("--max_tokens", type=int, default=64)
@@ -122,6 +130,8 @@ def setup_llm(args):
use_cuda_graph=args.use_cuda_graph,
load_format=args.load_format,
print_iter_log=args.print_iter_log,
torch_compile_enabled=args.use_torch_compile,
torch_compile_piecewise_cuda_graph=args.use_piecewise_cuda_graph,
moe_backend=args.moe_backend,
enable_trtllm_decoder=args.enable_trtllm_decoder)

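With these flags in place, the compile path can be exercised directly from the quickstart, e.g. (the model path and sampling options are the script's existing arguments and are omitted here):

python examples/pytorch/quickstart_advanced.py --use_torch_compile --use_piecewise_cuda_graph

Since setup_llm forwards --use_piecewise_cuda_graph as torch_compile_piecewise_cuda_graph, it presumably only takes effect when --use_torch_compile is also set.
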
25 changes: 19 additions & 6 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -545,6 +545,13 @@ def prepare(self) -> None:
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, f"Please set max_seq_len to at least {self.kv_lens[:self.num_seqs].max()} for kv cache manager."

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
self.host_request_types_runtime = self.host_request_types[:self.
num_seqs]

def prepare_flash_mla(self) -> None:
block_ids_per_seq = self.kv_cache_manager.get_block_ids_per_seq(
self.request_ids).pin_memory()
@@ -554,6 +561,13 @@ def prepare_flash_mla(self) -> None:
self.block_ids_per_seq[:self.num_generations, :num_blocks].copy_(
block_ids_per_seq[self.num_contexts:], non_blocking=True)

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
self.kv_lens_runtime = self.kv_lens[:self.num_seqs]
self.prompt_lens_cuda_runtime = self.prompt_lens_cuda[:self.num_seqs]
self.prompt_lens_cpu_runtime = self.prompt_lens_cpu[:self.num_seqs]
self.host_request_types_runtime = self.host_request_types[:self.
num_seqs]


class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):

@@ -662,7 +676,6 @@ def forward(
or metadata.runtime_features.has_speculative_draft_tokens
) if metadata.runtime_features else False

num_seqs = metadata.num_seqs
self.wrapper.plan(
tokens_per_block=metadata.tokens_per_block,
max_num_requests=metadata.max_num_requests,
Expand All @@ -672,11 +685,11 @@ def forward(
attention_window_size=None,
sink_token_length=0,
beam_width=1,
sequence_length=metadata.kv_lens_cuda[:num_seqs],
host_past_key_value_lengths=metadata.kv_lens[:num_seqs],
context_lengths=metadata.prompt_lens_cuda[:num_seqs],
host_context_lengths=metadata.prompt_lens_cpu[:num_seqs],
host_request_types=metadata.host_request_types[:num_seqs],
sequence_length=metadata.kv_lens_cuda_runtime,
host_past_key_value_lengths=metadata.kv_lens_runtime,
context_lengths=metadata.prompt_lens_cuda_runtime,
host_context_lengths=metadata.prompt_lens_cpu_runtime,
host_request_types=metadata.host_request_types_runtime,
kv_cache_block_offsets=metadata.kv_cache_block_offsets,
host_kv_cache_block_offsets=metadata.host_kv_cache_block_offsets,
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
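Both prepare() and prepare_flash_mla() now cache the per-batch slices (kv_lens_runtime, prompt_lens_cuda_runtime, and so on) as attributes, and forward() passes those cached views to wrapper.plan() instead of slicing the metadata tensors on every call. A plausible motivation, given the piecewise CUDA graph work in this PR, is to keep the slicing out of the compiled/captured region so the planner always sees the same tensor objects. A minimal sketch of the pattern, with illustrative names rather than the real TrtllmAttentionMetadata API:

import torch

# Slice the preallocated metadata buffers once per batch in prepare(),
# then reuse the cached views instead of re-slicing in the hot path.
class ToyMetadata:
    def __init__(self, max_num_seqs: int):
        self.kv_lens = torch.zeros(max_num_seqs, dtype=torch.int32)
        self.prompt_lens = torch.zeros(max_num_seqs, dtype=torch.int32)

    def prepare(self, num_seqs: int) -> None:
        # Cached once here; forward() sees the same view objects on every call.
        self.kv_lens_runtime = self.kv_lens[:num_seqs]
        self.prompt_lens_runtime = self.prompt_lens[:num_seqs]

md = ToyMetadata(max_num_seqs=64)
md.prepare(num_seqs=8)
assert md.kv_lens_runtime.shape[0] == 8
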
99 changes: 74 additions & 25 deletions tensorrt_llm/_torch/compilation/backend.py
@@ -1,25 +1,36 @@
import os
from typing import List, Optional, Union
from typing import List, Optional

import torch
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.compile_fx import compile_fx
from torch._inductor.compile_fx import compile_fx, select_decomp_table
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.fx import Graph, GraphModule
from torch._subclasses import FakeTensor
from torch.fx import GraphModule

import tensorrt_llm
from tensorrt_llm import logger

from .patterns.ar_residual_norm import register_ar_residual_norm
from .patterns.residual_add_norm import register_add_norm
from .patterns.ub_allreduce import register_ub_patterns
from .piecewise_optimizer import piecewise_optimizer
from .recover_pass import recover_pass
from .remove_copy_pass import remove_copy_for_mutates_args


class Backend:

_custom_pass_instances: List[PatternMatcherPass] = None
_graph_pool_handle: tuple[int, int] = None

def __init__(self, enable_inductor=True, enable_userbuffers=False) -> None:
def __init__(
self,
enable_inductor=True,
enable_userbuffers=False,
enable_piecewise_cuda_graph: bool = False,
cuda_graph_batch_sizes: Optional[List[int]] = None,
) -> None:
super().__init__()
self.elapsed_time = 0
self.module_inference_event = []
@@ -28,14 +39,16 @@ def __init__(self, enable_inductor=True, enable_userbuffers=False) -> None:
self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.cuda_graph_batch_sizes = (cuda_graph_batch_sizes
if cuda_graph_batch_sizes is not None
else [])
self.piecewise_cuda_graph = enable_piecewise_cuda_graph
self.no_optimization = False

self.match_count = []

if enable_inductor:
from torch._inductor import config
if Backend._graph_pool_handle is None:
Backend._graph_pool_handle = torch.cuda.graph_pool_handle()

self.inductor_config = config.get_config_copy()
self.inductor_config["joint_custom_post_pass"] = self.optimize
self.match_count = []

@classmethod
def get_custom_pass(cls, enable_userbuffers):
@@ -56,32 +69,68 @@ def get_custom_pass(cls, enable_userbuffers):
register_add_norm(cls._custom_pass_instances[0])
return cls._custom_pass_instances

def bypass_optimization(self):
self.no_optimization = True

def enable_optimization(self):
self.no_optimization = False

def optimize(
self,
gm: Union[GraphModule | Graph],
example_inputs: Optional[List[torch.Tensor]] = None,
gm: GraphModule,
example_inputs: List[torch.Tensor],
):
graph = gm.graph if isinstance(gm, GraphModule) else gm
graph = gm.graph
for custom_pass in self.custom_passes:
self.match_count.append(custom_pass.apply(graph))
while self.match_count[-1]:
self.match_count.append(custom_pass.apply(graph))
graph.eliminate_dead_code()
if isinstance(gm, GraphModule):
gm.recompile()

return gm
# After this pass, no further dead-code elimination may be run.
remove_copy_for_mutates_args(graph)
gm.recompile()

if self.piecewise_cuda_graph:
return piecewise_optimizer(
gm,
example_inputs,
self.enable_inductor,
self.input_num_tokens,
self.cuda_graph_batch_sizes,
self._graph_pool_handle,
)
elif self.enable_inductor:
return compile_fx(gm, example_inputs)
else:
return gm

def __call__(self, gm: GraphModule,
example_inputs: List[torch.Tensor]) -> callable:

if self.no_optimization:
logger.warning(
"Bypassing torch.compile optimization and fallback to eager execution!"
)
return gm

for node in gm.graph.nodes:
if node.op == "placeholder":
if node.name == "l_input_ids_":
example_value = node.meta["example_value"]
assert isinstance(example_value, FakeTensor)
self.input_num_tokens = example_value.shape[0]
break

if self.piecewise_cuda_graph:
assert (
self.input_num_tokens is not None
), "Cannot detect input_num_tokens. Cannot use piecewise CUDA graph. What is the name of `input_ids`?"

gm = recover_pass(gm)

if self.enable_inductor:
return compile_fx(gm,
example_inputs,
config_patches=self.inductor_config)
else:
return aot_module_simplified(gm,
example_inputs,
fw_compiler=self.optimize)
return aot_module_simplified(
gm,
example_inputs,
fw_compiler=self.optimize,
decompositions=select_decomp_table(),
)
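
Putting the backend changes together: __call__ now reads the number of input tokens from the l_input_ids_ placeholder, can bypass optimization entirely, and optimize() runs the custom pattern passes plus remove_copy_for_mutates_args before dispatching to piecewise_optimizer, compile_fx, or the plain graph. The snippet below is a stripped-down, runnable sketch of that overall custom-backend structure only; ToyBackend is illustrative and omits the pattern matching, user buffers, Inductor, and CUDA graph pieces.

from typing import List

import torch
from torch._functorch.aot_autograd import aot_module_simplified
from torch.fx import GraphModule


class ToyBackend:
    """Skeleton of the Backend above: __call__ inspects the dynamo graph, then
    lowers via aot_module_simplified with a custom fw_compiler."""

    def __init__(self) -> None:
        self.input_num_tokens = None

    def optimize(self, gm: GraphModule, example_inputs: List[torch.Tensor]):
        # Stand-in for the custom pattern-matcher passes.
        gm.graph.eliminate_dead_code()
        gm.recompile()
        return gm  # a GraphModule is itself a valid compiled callable

    def __call__(self, gm: GraphModule,
                 example_inputs: List[torch.Tensor]) -> callable:
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                example_value = node.meta.get("example_value")
                if example_value is not None:
                    # Analogous to reading the first dimension of l_input_ids_ above.
                    self.input_num_tokens = example_value.shape[0]
                break
        return aot_module_simplified(gm, example_inputs, fw_compiler=self.optimize)


model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=ToyBackend())
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 8])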