Skip to content

Commit 367237e

Browse files
refactor read_res
1 parent 7155ae1 commit 367237e

File tree

1 file changed

+22
-32
lines changed

1 file changed

+22
-32
lines changed

llm/predict/predictor.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,24 +1043,19 @@ def predict(self, input_texts: list[str], return_tokens=False):
10431043

10441044
# whether speculative decoding
10451045
if self.proposer is None:
1046-
read_res_process = mp.Process(
1047-
target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
1048-
)
1049-
if self.tensor_parallel_rank == 0:
1050-
read_res_process.start()
1051-
1052-
output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
1046+
read_res_func = llm_utils.read_res
1047+
output_tensor_shape = [MAX_BSZ + 2, 1]
10531048
else:
1054-
read_res_process = mp.Process(
1055-
target=llm_utils.speculate_read_res,
1056-
args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
1057-
)
1058-
if self.tensor_parallel_rank == 0:
1059-
read_res_process.start()
1049+
read_res_func = llm_utils.speculate_read_res
1050+
output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
1051+
1052+
read_res_process = mp.Process(
1053+
target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
1054+
)
1055+
if self.tensor_parallel_rank == 0:
1056+
read_res_process.start()
10601057

1061-
output_tensor = paddle.full(
1062-
shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
1063-
).cpu()
1058+
output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()
10641059

10651060
tensor_queue.put(output_tensor)
10661061
if self.tensor_parallel_rank == 0:
@@ -1205,24 +1200,19 @@ def predict(self, input_texts: list[str], return_tokens=False):
12051200

12061201
# whether speculative decoding
12071202
if self.proposer is None:
1208-
read_res_process = mp.Process(
1209-
target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
1210-
)
1211-
1212-
if self.tensor_parallel_rank == 0:
1213-
read_res_process.start()
1214-
output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
1203+
read_res_func = llm_utils.read_res
1204+
output_tensor_shape = [MAX_BSZ + 2, 1]
12151205
else:
1216-
read_res_process = mp.Process(
1217-
target=llm_utils.speculate_read_res,
1218-
args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
1219-
)
1206+
read_res_func = llm_utils.speculate_read_res
1207+
output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
1208+
1209+
read_res_process = mp.Process(
1210+
target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
1211+
)
1212+
if self.tensor_parallel_rank == 0:
1213+
read_res_process.start()
12201214

1221-
if self.tensor_parallel_rank == 0:
1222-
read_res_process.start()
1223-
output_tensor = paddle.full(
1224-
shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
1225-
).cpu()
1215+
output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()
12261216

12271217
tensor_queue.put(output_tensor)
12281218
if self.tensor_parallel_rank == 0:

0 commit comments

Comments (0)