@@ -1043,24 +1043,19 @@ def predict(self, input_texts: list[str], return_tokens=False):
 
         # whether speculative decoding
         if self.proposer is None:
-            read_res_process = mp.Process(
-                target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
-            )
-            if self.tensor_parallel_rank == 0:
-                read_res_process.start()
-
-            output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
+            read_res_func = llm_utils.read_res
+            output_tensor_shape = [MAX_BSZ + 2, 1]
         else:
-            read_res_process = mp.Process(
-                target=llm_utils.speculate_read_res,
-                args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
-            )
-            if self.tensor_parallel_rank == 0:
-                read_res_process.start()
+            read_res_func = llm_utils.speculate_read_res
+            output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
+
+        read_res_process = mp.Process(
+            target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
+        )
+        if self.tensor_parallel_rank == 0:
+            read_res_process.start()
 
-            output_tensor = paddle.full(
-                shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
-            ).cpu()
+        output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()
 
         tensor_queue.put(output_tensor)
         if self.tensor_parallel_rank == 0:
@@ -1205,24 +1200,19 @@ def predict(self, input_texts: list[str], return_tokens=False):
 
         # whether speculative decoding
         if self.proposer is None:
-            read_res_process = mp.Process(
-                target=llm_utils.read_res, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
-            )
-
-            if self.tensor_parallel_rank == 0:
-                read_res_process.start()
-            output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
+            read_res_func = llm_utils.read_res
+            output_tensor_shape = [MAX_BSZ + 2, 1]
         else:
-            read_res_process = mp.Process(
-                target=llm_utils.speculate_read_res,
-                args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
-            )
+            read_res_func = llm_utils.speculate_read_res
+            output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
+
+        read_res_process = mp.Process(
+            target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
+        )
+        if self.tensor_parallel_rank == 0:
+            read_res_process.start()
 
-            if self.tensor_parallel_rank == 0:
-                read_res_process.start()
-            output_tensor = paddle.full(
-                shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
-            ).cpu()
+        output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()
 
         tensor_queue.put(output_tensor)
         if self.tensor_parallel_rank == 0:
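Both hunks apply the same refactor: the if/else now only decides which reader function to run (llm_utils.read_res vs. llm_utils.speculate_read_res) and how large the output buffer must be, while the process creation, the rank-0 start, and the paddle.full allocation are written once. Below is a minimal, self-contained sketch of that pattern, assuming nothing beyond the standard library: the worker bodies and the constants are stand-ins for illustration, not the real llm_utils readers, paddle tensors, or tensor-parallel rank checks.

import multiprocessing as mp

# Placeholder values; the real constants come from the predictor module.
MAX_BSZ = 4
SPECULATE_MAX_BSZ = 4
MAX_DRAFT_TOKENS = 8


def read_res(tensor_queue, result_queue, done_event):
    # Stand-in for llm_utils.read_res: consume one item, report it, signal done.
    result_queue.put(("normal", tensor_queue.get()))
    done_event.set()


def speculate_read_res(tensor_queue, result_queue, done_event):
    # Stand-in for llm_utils.speculate_read_res.
    result_queue.put(("speculative", tensor_queue.get()))
    done_event.set()


def spawn_reader(use_speculative: bool):
    tensor_queue, result_queue, done_event = mp.Queue(), mp.Queue(), mp.Event()

    # The branch only picks *what* to run and *how big* the buffer is.
    if not use_speculative:
        read_res_func = read_res
        output_tensor_shape = [MAX_BSZ + 2, 1]
    else:
        read_res_func = speculate_read_res
        output_tensor_shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]

    # Common path: one place that builds and starts the reader process.
    read_res_process = mp.Process(
        target=read_res_func, args=[tensor_queue, result_queue, done_event]
    )
    read_res_process.start()

    tensor_queue.put(output_tensor_shape)  # stands in for the paddle output tensor
    print(result_queue.get())
    done_event.wait()
    read_res_process.join()


if __name__ == "__main__":
    spawn_reader(use_speculative=False)
    spawn_reader(use_speculative=True)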