Skip to content

Commit 4fdc15e

Browse files
committed
[tnx] neuron rolling batch test suite
1 parent 1038c63 commit 4fdc15e

File tree

5 files changed

+582
-30
lines changed

5 files changed

+582
-30
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
from collections import defaultdict
15+
from typing import List, Dict
16+
from dataclasses import dataclass
17+
from djl_python import test_model, Input
18+
from djl_python.request import Request
19+
from djl_python.input_parser import format_input
20+
21+
@dataclass
class SimulationSchedule:
    """Script describing one rolling-batch simulation run.

    ``prompts`` and ``params`` are parallel lists (one entry per request);
    ``reqs_to_prefill`` and ``wait_steps`` are parallel lists describing, for
    each scheduling round, how many new requests to prefill and how many
    decode steps to run before the next round. Both pairings are asserted by
    the simulator that consumes this schedule.
    """
    # One input prompt per request.
    prompts: List[str]
    # Generation parameters per request, parallel to ``prompts``.
    params: List[Dict]
    # Number of new requests to prefill at each scheduling round.
    reqs_to_prefill: List[int]
    # Decode steps to run after each round, parallel to ``reqs_to_prefill``.
    wait_steps: List[int]
27+
28+
29+
class NeuronRollingBatchGenerator:
    """Test harness that drives a TransformersNeuronX rolling batch.

    Supports two usage styles:
      * :meth:`simulator` — replay a :class:`SimulationSchedule` of staged
        prefills and decode steps, accumulating finished request text into
        ``self.responses``.
      * :meth:`step` — legacy list-based driving via ``input_str``/``params``
        deltas, accumulating per-request chunks into ``self.output_all``.
    """

    def __init__(self):
        # Rolling batch handle; populated by init_neuron_service().
        self.rolling_batch = None
        # Monotonic counter used to hand out unique request ids.
        self._req_id = 0
        # Store the results (step() path): req_id -> list of output chunks.
        self.output_all = defaultdict(list)
        # req_id -> (input string, params) for requests seen by step().
        self.input_all = {}
        # Partial outputs for in-flight requests, positionally parallel to
        # the engine's current batch (simulator() path).
        self.data_collector = []
        # Completed request outputs, in finish order (simulator() path).
        self.responses = []

        # Status variables for the step() path, positionally parallel lists.
        self.input_str = []
        self.params = []
        self.req_ids = []

        # Spec_dec: req_id -> tokens emitted per decode step.
        self.token_numbers = defaultdict(list)

    def init_neuron_service(self, properties: dict):
        """Initialize a TransformersNeuronXService and adopt its rolling batch."""
        from djl_python.transformers_neuronx import TransformersNeuronXService
        _service = TransformersNeuronXService()
        _service.initialize(properties)
        self.rolling_batch = _service.rolling_batch

    def get_req_id(self) -> int:
        """Return the next unique request id (monotonically increasing from 0)."""
        req_id = self._req_id
        self._req_id += 1
        return req_id

    def collect_data(self, result):
        """Fold one inference step's outputs into the accumulators.

        Args:
            result: per-request dicts, positionally aligned with the current
                batch, each carrying ``"data"`` (generated text chunk) and
                ``"last"`` (True when that request finished this step).

        Returns:
            Ascending indices of the requests that finished this step.
        """
        done_requests_indices = []
        for idx, item in enumerate(result):
            if len(self.data_collector) <= idx:
                # First chunk for a request not yet tracked: open a new slot.
                self.data_collector.append(item["data"])
            else:
                self.data_collector[idx] += item["data"]
            if item['last']:
                done_requests_indices.append(idx)
        # Pop from the back so earlier indices remain valid while popping.
        for idx in sorted(done_requests_indices, reverse=True):
            value = self.data_collector.pop(idx)
            self.responses.append(value)
            print(f"\nFinished request: {value}\n")
        return done_requests_indices

    def build_request(self, raw_input):
        """Wrap a raw ``{"inputs", "parameters"}`` payload into a Request with a fresh id."""
        inputs = test_model.create_json_request(raw_input)
        parsed_inputs = format_input(inputs)
        request = Request(parsed_inputs)
        request.id = self.get_req_id()
        return request

    def simulator(self, schedule: "SimulationSchedule"):
        """Replay *schedule* against the rolling batch.

        Each round prefills ``reqs_to_prefill[i]`` new requests and then runs
        ``wait_steps[i]`` decode steps; after all rounds, decoding continues
        (with no new prefills) until every request has finished.
        """
        assert len(schedule.prompts) == len(schedule.params)
        assert len(schedule.reqs_to_prefill) == len(schedule.wait_steps)
        all_requests = [{
            "inputs": prompt,
            "parameters": params
        } for prompt, params in zip(schedule.prompts, schedule.params)]
        current_requests = []
        new_requests = []
        for batch_size, step in zip(schedule.reqs_to_prefill,
                                    schedule.wait_steps):
            for _ in range(batch_size):
                request = self.build_request(all_requests.pop(0))
                # Newest request is prepended; preserved from the original
                # implementation — ordering appears significant downstream.
                new_requests = [request] + new_requests
                current_requests.append(request)

            for _ in range(step):
                if not current_requests:
                    break
                generated_tokens = self.rolling_batch.inference(new_requests)
                new_requests.clear()
                finished_indices = self.collect_data(generated_tokens)
                for idx in sorted(finished_indices, reverse=True):
                    current_requests.pop(idx)
        # Drain phase: decode until everything finishes.
        # NOTE(review): new_requests is not cleared here; if the final round
        # had wait_steps == 0 it may still hold requests that get passed on
        # every drain iteration — confirm the engine tolerates this.
        while current_requests:
            generated_tokens = self.rolling_batch.inference(new_requests)
            finished_indices = self.collect_data(generated_tokens)
            for idx in sorted(finished_indices, reverse=True):
                current_requests.pop(idx)

    def step(self, step=20, input_str_delta=None, params_delta=None):
        """Run up to *step* legacy-style inference iterations.

        ``input_str_delta`` and ``params_delta`` must be parallel lists when
        provided (params_delta is required whenever input_str_delta is given).
        Finished requests are dropped from the tracked state; the loop exits
        early once no requests remain.
        """
        if input_str_delta:
            # Fresh ids start right after the largest id handed out so far.
            begin_id = max(self.input_all.keys(), default=0) + 1
            req_ids_delta = list(
                range(begin_id, begin_id + len(input_str_delta)))

            self.input_str += input_str_delta
            self.params += params_delta
            self.req_ids += req_ids_delta
            for req_id, input_s, param in zip(req_ids_delta, input_str_delta,
                                              params_delta):
                self.input_all[req_id] = (input_s, param)

        for _ in range(step):
            result = self.rolling_batch.inference(self.input_str, self.params)
            for res, req_id in zip(result, self.req_ids):
                self.output_all[req_id].append(res['data'])
                self.token_numbers[req_id].append(res.get('step_token_num', 1))
            # Keep only unfinished requests, preserving positional alignment
            # across the three parallel lists.
            keep = [not res['last'] for res in result]
            self.req_ids = [r for r, k in zip(self.req_ids, keep) if k]
            self.input_str = [s for s, k in zip(self.input_str, keep) if k]
            self.params = [p for p, k in zip(self.params, keep) if k]
            if not self.req_ids:
                break

    def is_empty(self):
        """True when no step()-tracked requests remain in flight."""
        return not self.req_ids

    def reset(self):
        """Clear all accumulated state so the generator can be reused.

        Fix: also clears ``responses`` and the request-id counter, which the
        previous implementation leaked across resets (stale finished outputs
        would contaminate the next run).
        """
        self.data_collector = []
        self.responses = []
        self.rolling_batch = None
        self._req_id = 0
        # Store the results
        self.output_all = defaultdict(list)
        self.input_all = {}

        # Status variables, the remaining
        self.input_str = []
        self.params = []
        self.req_ids = []

        self.token_numbers = defaultdict(list)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
import unittest
15+
import json
16+
import os
17+
18+
# Probe for the Neuron service; when it cannot be imported (Neuron stack not
# installed on this host), SKIP_TEST makes the whole suite skip below.
try:
    from djl_python.transformers_neuronx import TransformersNeuronXService
    SKIP_TEST = False
except ImportError:
    SKIP_TEST = True
24+
# Expected deterministic (do_sample=False) output prefixes, keyed by model id
# and then by request index (request ids follow prompt order). Tests compare
# only up to the stored prefix length.
expected_text_30 = {
    "TinyLlama/TinyLlama-1.1B-Chat-v0.6": {
        0:
        "Hello, my name is [Your Name] and I am a [Your Job Title] at [Your Company Name]. I am interested in learning more about your company'",
        1:
        'The president of the United States is a man named Donald Trump.\n\n2. The president of the United States is a man named Donald Trump.\n\n3. The president',
        2:
        'The capital of France is Paris.\n\n2. The capital of the United States is Washington, D.C.\n\n3. The capital of Canada is Ott',
        3:
        "The future of AI is bright, and it's already here. With the help of AI, we can create more personalized experiences, automate repetitive tasks, and even predict the future.",
    }
}
36+
37+
38+
@unittest.skipIf(SKIP_TEST, "Neuron dependencies are not available")
class TestNeuronRollingBatch(unittest.TestCase):
    """End-to-end tests driving the tnx rolling batch via a scripted schedule.

    Both tests share the same prompts and the same prefill/wait schedule; the
    per-model setup/teardown is factored into ``_run_simulation``/``_cleanup``.
    """

    # Prompts shared by both tests; request ids are assigned in list order,
    # so index i here corresponds to request id i in the outputs.
    PROMPTS = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    @staticmethod
    def _run_simulation(properties, input_str, params):
        """Build a generator for *properties*, run the canonical schedule.

        Returns the generator so the caller can inspect ``responses``;
        the caller is responsible for ``_cleanup``.
        """
        from djl_python.tests.neuron_test_scripts.neuron_rb_generator import (
            NeuronRollingBatchGenerator, SimulationSchedule)
        gen = NeuronRollingBatchGenerator()
        gen.init_neuron_service(properties)
        print('========== init inference ===========')
        gen.simulator(
            SimulationSchedule(prompts=input_str,
                               params=params,
                               reqs_to_prefill=[1, 2, 1],
                               wait_steps=[1, 4, 5]))
        return gen

    @staticmethod
    def _cleanup(gen):
        """Reset and release the generator so the next model starts clean."""
        import gc
        gen.reset()
        del gen
        gc.collect()

    def test_models(self):
        """Run real HF models and check outputs against known prefixes."""
        model_names = [
            "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
        ]

        for model_id in model_names:
            properties = {
                "tensor_parallel_degree": 2,
                "n_positions": "128",
                "rolling_batch": "tnx",
                "max_rolling_batch_size": 4,
                "model_id": model_id
            }
            input_str = list(self.PROMPTS)
            # Each request gets its own dict (the comprehension already builds
            # a fresh literal per iteration; no .copy() needed).
            params = [{
                "max_new_tokens": 100,
                "do_sample": False,
            } for _ in input_str]

            gen = self._run_simulation(properties, input_str, params)

            for i, out in enumerate(gen.responses):
                out_dict = json.loads(''.join(out))
                test_generation = input_str[i] + " " + out_dict["generated_text"]
                print(f"\n====req_id: {i}=====\n{test_generation}\n")
                expected = expected_text_30.get(model_id, {}).get(i)
                if expected is not None:
                    # Compare only the stored prefix length.
                    self.assertEqual(expected,
                                     test_generation[:len(expected)])

            self._cleanup(gen)

    def test_tiny_models(self):
        """Smoke-test tiny local model artifacts; no output assertions."""
        from djl_python.tests.neuron_test_scripts.tiny_models import artifacts
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        model_names = [
            "llama",
            "gpt2",
            "gptneox",
            "bloom",
        ]

        for model_id in model_names:
            properties = {
                "tensor_parallel_degree": 2,
                "n_positions": "128",
                "rolling_batch": "tnx",
                "max_rolling_batch_size": 4,
                "model_loading_timeout": 3600,
                "model_id": artifacts(model_id)
            }
            input_str = list(self.PROMPTS)
            params = [{
                "max_new_tokens": 100,
                "do_sample": False,
                "ignore_eos": True,
            } for _ in input_str]

            gen = self._run_simulation(properties, input_str, params)
            self._cleanup(gen)
153+
154+
155+
# Allow running this module directly (outside a pytest/unittest discovery run).
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)