Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions cpp/tensorrt_llm/common/attentionOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.paged_kv_cache = mPagedKVCache;
xqaParams.tokens_per_block = mTokensPerBlock;
xqaParams.kv_cache_quant_mode = mKVCacheQuantMode;
xqaParams.tp_size = mTpSize;
xqaParams.tp_rank = mTpRank;
xqaParams.tp_size = mAttnTpSize;
xqaParams.tp_rank = mAttnTpRank;
xqaParams.qkv_bias_enabled = mQKVBiasEnabled;
xqaParams.cross_attention = mCrossAttention;
xqaParams.max_distance = mMaxDistance;
Expand Down Expand Up @@ -223,6 +223,15 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.host_context_lengths = generationsParams.host_context_lengths;
xqaParams.semaphores = generationsParams.semaphores;
xqaParams.workspaces = generationsParams.workspace;
if (mCpSize > 1)
{
size_t const batch_beam = generationsParams.beam_width * generationsParams.num_requests;
size_t const cpMaxPaddedSequenceLength = (batch_beam + mCpSize - 1) / mCpSize * mCpSize;
size_t const cpWorkspaceSize
= 2 * sizeof(T) * cpMaxPaddedSequenceLength * (mNumHeads + 2 * mNumKVHeads) * mHeadSize;
xqaParams.workspaces
= reinterpret_cast<void*>(reinterpret_cast<int8_t*>(xqaParams.workspaces) + cpWorkspaceSize);
}
xqaParams.batch_size = generationsParams.num_requests;
xqaParams.beam_width = generationsParams.beam_width;
// Speculative decoding mode has generation input_length > 1.
Expand Down Expand Up @@ -254,7 +263,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.mrope_position_deltas = generationsParams.mrope_position_deltas;

xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
xqaParams.total_num_input_tokens = generationsParams.num_tokens;
xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens;
xqaParams.is_fp8_output = mFP8ContextFMHA;
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
// Parameters required for FP4 output.
Expand Down
21 changes: 20 additions & 1 deletion tests/integration/defs/accuracy/test_llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import llm_models_root, skip_post_blackwell, skip_pre_ada
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness
from .accuracy_core import GSM8K, MMLU, CnnDailymail, LlmapiAccuracyTestHarness


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
Expand All @@ -38,6 +38,25 @@ def test_fp8_rowwise(self):
task.evaluate(llm)


class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"

@pytest.mark.skip_less_device(2)
def test_cp2(self):
with LLM(self.MODEL_PATH, context_parallel_size=2) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
def test_tp2cp2(self):
with LLM(self.MODEL_PATH,
tensor_parallel_size=2,
context_parallel_size=2) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


class TestMistral7B_0_3(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
MODEL_PATH = f"{llm_models_root()}/Mistral-7B-Instruct-v0.3"
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ l0_dgx_h100:
- accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights
- accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp4
- accuracy/test_llm_api.py::TestQwen2_7BInstruct::test_tp2
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_cp2
- accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_tp2cp2
- examples/test_llama.py::test_llm_llama_long_alpaca_8gpu_summary[pg64317-tp4pp2-nb:4]
- examples/test_llama.py::test_llm_llama_v2_lora_benchmark_2gpu[chinese_lora-llama-v2-13b-hf]
- examples/test_mixtral.py::test_llm_mixtral_moe_plugin_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora]
Expand Down