@@ -248,16 +248,19 @@ def _preprocess(self, source):
     def _infer(self, inputs):
         raise NotImplementedError
 
-    def _postprocess(self, predictions):
+    def _postprocess(self, predictions, return_tokens=False):
         decoded_predictions = self.tokenizer.batch_decode(
             predictions, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-        return decoded_predictions
+        if return_tokens:
+            return decoded_predictions, predictions
+        else:
+            return decoded_predictions
 
-    def predict(self, input_texts: str | list[str]):
+    def predict(self, input_texts: str | list[str], return_tokens=False):
         tokenized_source = self._preprocess(input_texts)
         predictions = self._infer(tokenized_source)
-        decoded_predictions = self._postprocess(predictions)
+        decoded_predictions = self._postprocess(predictions, return_tokens=return_tokens)
         return decoded_predictions
 
 
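With this base-class change, predict now has two return shapes. A minimal usage sketch, assuming predictor is an already-constructed subclass instance (the variable name and prompt are illustrative, not part of the patch):

    # Default call: unchanged behavior, only decoded strings are returned.
    decoded = predictor.predict(["Hello, how are you?"])

    # With return_tokens=True, the decoded strings and the raw predictions
    # from _infer are returned together as a tuple.
    decoded, token_ids = predictor.predict(["Hello, how are you?"], return_tokens=True)
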
@@ -470,13 +473,16 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
         )
         self.generation_config = None
 
-    def _postprocess(self, predictions):
+    def _postprocess(self, predictions, return_tokens=False):
         if paddle.distributed.get_rank() == 0:
             tokens: np.ndarray = load_real_time_tokens()
             decoded_predictions = self.tokenizer.batch_decode(
                 tokens.tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=False
             )
-            return decoded_predictions
+            if return_tokens:
+                return decoded_predictions, tokens.tolist()
+            else:
+                return decoded_predictions
         else:
             return None
 
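Note that in this distributed variant, non-zero ranks still return None, so only rank 0 receives the (decoded, tokens) tuple when return_tokens=True. A small sketch of how calling code might guard for this (names are illustrative, not from the patch):

    result = predictor.predict(input_texts, return_tokens=True)
    # Only rank 0 gets (decoded_predictions, token_ids); other ranks get None.
    if result is not None:
        decoded, token_ids = result
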
@@ -1034,7 +1040,7 @@ def _infer(self, inputs: dict[str, paddle.Tensor]):
         )
 
     @paddle.no_grad()
-    def predict(self, input_texts: str | list[str]):
+    def predict(self, input_texts: str | list[str], return_tokens=False):
         self._preprocess(input_texts)
 
         result_queue = mp.Queue()
@@ -1055,9 +1061,15 @@ def predict(self, input_texts: str | list[str]):
             self.used_list[i] = []
 
         outputs = []
+        output_tokens = []
         while len(outputs) < self.batch_size:
-            outputs.append(result_queue.get(timeout=1)[-1])
-        return outputs
+            result = result_queue.get(timeout=1)
+            outputs.append(result[-1])
+            output_tokens.append(result[-2])
+        if return_tokens:
+            return outputs, output_tokens
+        else:
+            return outputs
 
 
 class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor):
@@ -1180,7 +1192,7 @@ def _share_data(self):
     def _infer(self):
         self.predictor.run()
 
-    def predict(self, input_texts: str | list[str]):
+    def predict(self, input_texts: str | list[str], return_tokens=False):
 
         s_time = time.time()
         self._preprocess(input_texts)
@@ -1213,9 +1225,15 @@ def predict(self, input_texts: str | list[str]):
             self.used_list[i] = []
 
         outputs = []
+        output_tokens = []
         while len(outputs) < self.batch_size:
-            outputs.append(result_queue.get(timeout=1)[-1])
-        return outputs
+            result = result_queue.get(timeout=1)
+            outputs.append(result[-1])
+            output_tokens.append(result[-2])
+        if return_tokens:
+            return outputs, output_tokens
+        else:
+            return outputs
 
     def _preprocess(self, source):
         BlockInferencePredictorMixin._preprocess(self, source)
@@ -1681,8 +1699,8 @@ def benchmark(predictor, predictor_args, model_args):
     output_tokens = 0
     for _ in range(test_time):
         for bs, batch_source_text in enumerate(batch_benchmark_texts):
-            outputs = predictor.predict(batch_source_text)
-            output_tokens += sum([len(output) for output in outputs])
+            outputs, batch_tokens = predictor.predict(batch_source_text, return_tokens=True)
+            output_tokens += sum([len(tokens) for tokens in batch_tokens])
     end = time.perf_counter()
     print("Avg Elapse time is: ", (end - start) / test_time)
     print("Output tokens is: ", output_tokens)
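Because the benchmark now sums len(tokens) over the returned token lists rather than len(output) over decoded strings, the "Output tokens" figure counts generated tokens instead of characters. A rough sketch of the resulting throughput arithmetic, with hypothetical numbers used only for illustration:

    # Assume the loop measured 1500 generated tokens at an average
    # elapsed time of 3.0 seconds per test iteration.
    output_tokens = 1500
    avg_elapse = 3.0
    tokens_per_second = output_tokens / avg_elapse  # 500.0 tokens/s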