Skip to content

Commit 1016a23

Browse files
authored
[multimodal][ci] adding multimodal tests (#2234)
1 parent 87b0c19 commit 1016a23

File tree

4 files changed

+84
-1
lines changed

4 files changed

+84
-1
lines changed

.github/workflows/integration.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ jobs:
122122
instance: inf2
123123
- test: TestNeuronxRollingBatch
124124
instance: inf2
125+
- test: TestMultiModal
126+
instance: g6
125127
steps:
126128
- uses: actions/checkout@v4
127129
- name: Clean env

tests/integration/llm/client.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,18 @@ def get_model_name():
760760
}
761761
}
762762

763+
# Multimodal (vision-language) models covered by the "multimodal" handler,
# mapped to the request batch sizes each one is exercised with.
multi_modal_spec = {
    "llava_v1.6-mistral": {
        "batch_size": [1, 4],
    },
    "paligemma-3b-mix-448": {
        "batch_size": [1, 4],
    },
    "phi-3-vision-128k-instruct": {
        "batch_size": [1, 4],
    },
}
774+
763775

764776
def add_file_handler_to_logger(file_path: str):
765777
handler = logging.FileHandler(file_path, mode='w')
@@ -1430,6 +1442,42 @@ def test_correctness(model, model_spec):
14301442
validate_correctness(dataset, data, score)
14311443

14321444

1445+
def get_multimodal_prompt():
    """Build a sample chat-completions request that includes an image.

    The payload carries a single user message whose content pairs a text
    question with an image URL, together with fixed sampling parameters.

    Returns:
        dict: request body with "messages", "temperature", "top_p", and
        "max_new_tokens" keys.
    """
    text_part = {
        "type": "text",
        "text": "What is this an image of?",
    }
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://resources.djl.ai/images/dog_bike_car.jpg",
        },
    }
    user_message = {
        "role": "user",
        # Text comes first, then the image reference.
        "content": [text_part, image_part],
    }
    return {
        "messages": [user_message],
        "temperature": 0.9,
        "top_p": 0.6,
        "max_new_tokens": 512,
    }
1465+
1466+
1467+
def test_multimodal(model, model_spec):
1468+
if model not in model_spec:
1469+
raise ValueError(
1470+
f"{model} is not currently supported {list(model_spec.keys())}")
1471+
spec = model_spec[model]
1472+
messages = get_multimodal_prompt()
1473+
for i, batch_size in enumerate(spec["batch_size"]):
1474+
awscurl_run(messages,
1475+
spec.get("tokenizer", None),
1476+
batch_size,
1477+
num_run=5,
1478+
output=True)
1479+
1480+
14331481
def run(raw_args):
14341482
parser = argparse.ArgumentParser(description="Build the LLM configs")
14351483
parser.add_argument("handler", help="the handler used in the model")
@@ -1507,6 +1555,8 @@ def run(raw_args):
15071555
test_handler_rolling_batch(args.model, no_code_rolling_batch_spec)
15081556
elif args.handler == "correctness":
15091557
test_correctness(args.model, correctness_model_spec)
1558+
elif args.handler == "multimodal":
1559+
test_multimodal(args.model, multi_modal_spec)
15101560

15111561
else:
15121562
raise ValueError(

tests/integration/llm/prepare.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,16 @@
574574
"option.model_id": "s3://djl-llm/llama-2-tiny/",
575575
"option.quantize": "awq",
576576
"option.tensor_parallel_degree": 4
577+
},
578+
"llava_v1.6-mistral": {
579+
"option.model_id": "s3://djl-llm/llava-v1.6-mistral-7b-hf/",
580+
},
581+
"paligemma-3b-mix-448": {
582+
"option.model_id": "s3://djl-llm/paligemma-3b-mix-448/"
583+
},
584+
"phi-3-vision-128k-instruct": {
585+
"option.model_id": "s3://djl-llm/phi-3-vision-128k-instruct/",
586+
"option.trust_remote_code": True,
577587
}
578588
}
579589

@@ -784,7 +794,7 @@
784794
"option.dtype": "fp16",
785795
"option.tensor_parallel_degree": 4,
786796
"option.max_rolling_batch_size": 4,
787-
}
797+
},
788798
}
789799

790800
lmi_dist_aiccl_model_list = {

tests/integration/tests.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,3 +883,24 @@ def test_llama3_1_8b(self):
883883
prepare.build_correctness_model("neuronx-llama3-1-8b")
884884
r.launch(container='pytorch-inf2-2')
885885
client.run("correctness neuronx-llama3-1-8b".split())
886+
887+
888+
class TestMultiModalLmiDist:
    """Integration tests for multimodal (vision-language) models on lmi-dist."""

    def _launch_and_test(self, model):
        # Shared flow: build the model config, launch the serving
        # container, then drive the multimodal client against it.
        with Runner('lmi', model) as r:
            prepare.build_lmi_dist_model(model)
            r.launch()
            client.run(["multimodal", model])

    def test_llava_next(self):
        self._launch_and_test('llava_v1.6-mistral')

    def test_paligemma(self):
        self._launch_and_test('paligemma-3b-mix-448')

    def test_phi3_v(self):
        self._launch_and_test('phi-3-vision-128k-instruct')

0 commit comments

Comments
 (0)