75 changes: 43 additions & 32 deletions tensorrt_llm/bench/utils/data.py
@@ -15,9 +15,11 @@ def prepare_multimodal_inputs(model_dir: str,
                                model_type: str,
                                modality: str,
                                prompts: List[str],
-                               media: List[str],
+                               media: List[List[str]],
                                image_data_format: str = "pt",
                                num_frames: int = 8) -> List[Dict[str, Any]]:
+    assert model_type in INPUT_FORMATTER_MAP, f"Model type {model_type} not in supported list of models: {INPUT_FORMATTER_MAP.keys()}"
+    formatter = INPUT_FORMATTER_MAP[model_type]
 
     inputs = []
     if modality == "image":
@@ -28,7 +30,7 @@ def prepare_multimodal_inputs(model_dir: str,
     else:
         raise ValueError(f"Unsupported modality: {modality}")
 
-    inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
+    inputs = formatter(model_dir, inputs)
 
     return inputs
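Note on the refactor above: looking up the formatter once and asserting membership first turns an unsupported model_type into a readable error instead of a bare KeyError deep in the call. A minimal standalone sketch of the same fail-fast dispatch pattern, using a placeholder map rather than the real INPUT_FORMATTER_MAP:

from typing import Any, Callable, Dict, List

# Placeholder map for illustration; the real INPUT_FORMATTER_MAP maps
# supported model types to input-formatting callables.
INPUT_FORMATTER_MAP: Dict[str, Callable[[str, List[Any]], List[Any]]] = {
    "example_model": lambda model_dir, inputs: inputs,
}

def format_inputs(model_type: str, model_dir: str,
                  inputs: List[Any]) -> List[Any]:
    # Fail fast, listing the supported keys in the error message.
    assert model_type in INPUT_FORMATTER_MAP, (
        f"Model type {model_type} not in supported list of models: "
        f"{INPUT_FORMATTER_MAP.keys()}")
    formatter = INPUT_FORMATTER_MAP[model_type]
    return formatter(model_dir, inputs)

print(format_inputs("example_model", "/models/example", [{"prompt": "hi"}]))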
@@ -100,56 +102,65 @@ def create_dataset_from_stream(
     # For each line in the standard input, parse out the JSON string we expect
     # to see.
     # Note the := walrus -- we're assigning and checking the condition.
-    all_isl = []
     all_osl = []
-    all_seq_len = []
-    while (line := stream.readline()) and len(dataset) < max_requests:
+    prompts = []
+    media_paths = []
+    all_logits = []
+    task_ids = []
+    while (line := stream.readline()) and len(task_ids) < max_requests:
         # We expect the data to come in as a JSON string.
         # For example:
         # {"prompt": "Generate an infinite response to the following:
         # There once was a man who.", "output_tokens": 1000}
+        #
+        # For multimodal data, the data should be of the form:
+        # {"prompt": "Generate an infinite response to the following:
+        # There once was a man who.", "output_tokens": 1000,
+        # "media_paths": ["/path/to/image1.jpg", "/path/to/image2.jpg"]}
+        #
         # Each line should be a complete JSON dictionary with no indentation
         # or newline characters.
         data = json.loads(line)
+        prompts.append(data.get("prompt"))
+        media_paths.append(data.get("media_paths", None))
+        all_logits.append(data.get("input_ids", data.get("logits", None)))
+        all_osl.append(data.get("output_tokens"))
+        task_ids.append(data.get("task_id"))
+
+    if modality is not None:
+        # Multimodal data need extra preprocessing
+        assert modality in [
+            "image", "video"
+        ], f"Modality must be one of ['image', 'video'] but got {modality}."
+        prompts = prepare_multimodal_inputs(model_dir,
+                                            model_type,
+                                            modality,
+                                            prompts=prompts,
+                                            media=media_paths)  # list of dicts
+
+    all_isl = []
+    all_seq_len = []
+    for prompt, logits, osl, task_id in zip(prompts, all_logits, all_osl,
+                                            task_ids):
         if modality is not None:
-            # Multimodal data
-            assert modality in [
-                "image", "video"
-            ], f"Modality must be one of ['image', 'video'] but got {modality}."
-
-            prompt = data.get("prompt")  # cannot be None
-            media_paths = data.get("media_paths", None)
-            inputs = prepare_multimodal_inputs(
-                model_dir,
-                model_type,
-                modality,
-                prompts=[prompt],
-                media=media_paths)  # list of dicts
-            logits = None  # cannot tokenize multi-modal data, handled by preprocessor
-            prompt = inputs[0]
+            # NOTE: we cannot tokenize multi-modal data, handled by preprocessor
+            # so the actual sequence length is unknown until the model is run
+            logits = None
+            cur_isl = max_input_seq_len_for_multimodal
         else:
-            logits = data.get("input_ids", data.get("logits", None))
-            prompt = data.get("prompt", None)
             # If the request comes in with logits, just use the provided.
             # Otherwise we need to tokenize it.
             logits = tokenize(prompt)["input_ids"] if logits is None else logits
-        task_id = data["task_id"]
-        osl = data["output_tokens"]
+            cur_isl = len(logits)
+        all_isl.append(cur_isl)
+        all_seq_len.append(cur_isl + osl)
 
         request = InferenceRequest(
             task_id=task_id,
             prompt=prompt,
             output_tokens=output_limiter(osl),
             input_ids=logits,
         )
-        all_osl.append(osl)
-        if modality is not None:
-            cur_isl = max_input_seq_len_for_multimodal  # NOTE: actual sequence length is unknown until the model is run
-            all_isl.append(cur_isl)
-            all_seq_len.append(cur_isl + osl)
-        else:
-            all_isl.append(len(logits))
-            all_seq_len.append(len(logits) + osl)
         dataset.append(request)
 
     isl_stats = PercentileStats.from_iterable(all_isl)
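The rewritten loop in create_dataset_from_stream now collects all fields in a first pass and builds requests in a second pass, so multimodal preprocessing can run once over the whole batch. A self-contained sketch of the first pass against the JSONL format documented in the comments (the records here are made up; the field names come from the diff):

import io
import json

# One text-only record and one multimodal record; each line is a complete
# JSON dictionary with no indentation or embedded newlines.
records = [
    {"task_id": 0, "prompt": "There once was a man who.",
     "output_tokens": 1000},
    {"task_id": 1, "prompt": "Describe these images.",
     "output_tokens": 256,
     "media_paths": ["/path/to/image1.jpg", "/path/to/image2.jpg"]},
]
stream = io.StringIO("".join(json.dumps(r) + "\n" for r in records))

# Same walrus-style read loop as the diff: accumulate fields first,
# post-process afterwards.
prompts, media_paths, all_osl, task_ids = [], [], [], []
while (line := stream.readline()):
    data = json.loads(line)
    prompts.append(data.get("prompt"))
    media_paths.append(data.get("media_paths", None))
    all_osl.append(data.get("output_tokens"))
    task_ids.append(data.get("task_id"))

print(task_ids, all_osl, media_paths)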
11 changes: 9 additions & 2 deletions tensorrt_llm/inputs/utils.py
@@ -182,10 +182,13 @@ def apply_template(prompt, multimodal_data):
     return inputs
 
 
-def default_image_loader(prompts, images, image_data_format="pt"):
+def default_image_loader(prompts: List[str],
+                         images: Union[List[List[str]], List[str]],
+                         image_data_format: str = "pt"):
     if len(images) > len(prompts) and len(prompts) == 1:
         # 1 prompt + N media
         images = [images]
+    assert len(images) == len(prompts)
     inputs = [{
         "prompt": prompt,
         "multi_modal_data": {
@@ -199,10 +202,14 @@ def default_image_loader(prompts, images, image_data_format="pt"):
     return inputs
 
 
-def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
+def default_video_loader(prompts: List[str],
+                         videos: Union[List[List[str]], List[str]],
+                         image_data_format: str = "pt",
+                         num_frames: int = 8):
     if len(videos) > len(prompts) and len(prompts) == 1:
         # 1 prompt + N media
         videos = [videos]
+    assert len(videos) == len(prompts)
     inputs = [{
         "prompt": prompt,
         "multi_modal_data": {
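Both loaders now share the same normalization rule: when a single prompt arrives with several media paths, the media list is wrapped so prompts and media stay zip-compatible, and the new assert enforces the one-to-one pairing. A standalone sketch of that rule (the helper name is illustrative, not part of the library):

from typing import List, Union

def pair_prompts_with_media(
        prompts: List[str],
        media: Union[List[List[str]], List[str]]):
    if len(media) > len(prompts) and len(prompts) == 1:
        # 1 prompt + N media: the single prompt owns the whole media list.
        media = [media]
    assert len(media) == len(prompts)
    return list(zip(prompts, media))

# One prompt with two images -> one (prompt, [image1, image2]) pair.
print(pair_prompts_with_media(["Compare the two photos."],
                              ["/img/a.jpg", "/img/b.jpg"]))
# Two prompts, each with its own media list -> pairs one-to-one.
print(pair_prompts_with_media(["One.", "Two."],
                              [["/img/a.jpg"], ["/img/b.jpg"]]))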