@@ -1785,24 +1785,32 @@ def execute_model(
1785
1785
1786
1786
if model_input .inputs_embeds is not None :
1787
1787
if self .is_driver_worker :
1788
- sampled = broadcast_tensor_dict (
1789
- {"token_ids" : output .sampled_token_ids })
1788
+ sampled_token_ids = []
1789
+ valid_outputs = []
1790
+ for sequence_group_output in output .outputs :
1791
+ if len (sequence_group_output .samples ) == 0 :
1792
+ continue
1793
+ assert len (sequence_group_output .samples ) == 1
1794
+ valid_outputs .append (sequence_group_output )
1795
+ sampled_token_ids .append (
1796
+ sequence_group_output .samples [0 ].output_token )
1797
+ sampled_token_ids = torch .tensor (sampled_token_ids ).to (
1798
+ self .device )
1799
+ sampled_token_ids = broadcast_tensor_dict (
1800
+ {"sampled_token_ids" :
1801
+ sampled_token_ids })["sampled_token_ids" ]
1790
1802
else :
1791
- sampled = broadcast_tensor_dict ()
1792
- if sampled ["token_ids" ] is not None :
1793
- sampled_token_embeds = self .model .get_input_embeddings (
1794
- sampled ["token_ids" ].squeeze (1 ))
1803
+ sampled_token_ids = broadcast_tensor_dict (
1804
+ )["sampled_token_ids" ]
1805
+ if len (sampled_token_ids ) > 0 :
1806
+ sampled_token_embeds = \
1807
+ self .model .get_input_embeddings (sampled_token_ids )
1795
1808
if self .is_driver_worker :
1796
1809
self .sampler .include_gpu_probs_tensor = \
1797
1810
orig_include_gpu_probs
1798
-
1799
- output .sampled_token_embeds = sampled_token_embeds
1800
-
1801
- for token_embed , sequence_group_output in zip (
1802
- output .sampled_token_embeds , output .outputs ):
1803
- assert len (sequence_group_output .samples ) == 1
1804
- sequence_group_output .samples [
1805
- 0 ].output_embed = token_embed
1811
+ for i , sequence_group_output in enumerate (valid_outputs ):
1812
+ sequence_group_output .samples [0 ].output_embed = \
1813
+ sampled_token_embeds [i ]
1806
1814
1807
1815
if not self .is_driver_worker :
1808
1816
return []
0 commit comments