import transformers
from torchtyping import TensorType
from transformers.modeling_outputs import ModelOutput
- from transformers.models.opt.modeling_opt import _make_causal_mask, _expand_mask
+ from transformers.models.bloom import modeling_bloom
+ from transformers.models.opt import modeling_opt

from trlx.data.method_configs import MethodConfig, register_method
from trlx.utils.modeling import (
@@ -698,15 +699,15 @@ def forward(
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
-             combined_attention_mask = _make_causal_mask(
+             combined_attention_mask = modeling_opt._make_causal_mask(
                input_shape,
                hidden_states.dtype,
                past_key_values_length=past_key_values_length,
            ).to(hidden_states.device)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-             expanded_attn_mask = _expand_mask(
+             expanded_attn_mask = modeling_opt._expand_mask(
                attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
            ).to(hidden_states.device)
            combined_attention_mask = (
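The two OPT helpers referenced above build the additive attention masks that OPT's attention layers expect. A minimal shape check, outside the diff, assuming a transformers release contemporary with this change that still exposes these private helpers from `modeling_opt` (later releases replace them with shared mask utilities):

import torch
from transformers.models.opt import modeling_opt

batch_size, seq_len = 2, 5
attention_mask = torch.ones(batch_size, seq_len)

# Lower-triangular causal mask, filled with dtype-min above the diagonal.
causal = modeling_opt._make_causal_mask(
    (batch_size, seq_len), torch.float32, past_key_values_length=0
)
# Padding mask broadcast to [bsz, 1, tgt_seq_len, src_seq_len].
padding = modeling_opt._expand_mask(attention_mask, torch.float32, tgt_len=seq_len)

print(causal.shape, padding.shape)  # torch.Size([2, 1, 5, 5]) torch.Size([2, 1, 5, 5])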
@@ -798,6 +799,193 @@ def forward(
        )


+ class BloomModelBranch(transformers.PreTrainedModel):
+     """
+     BloomModelBranch implements the frozen upper trunk of the reference model
+     used when computing the PPO KL-divergence penalty. Expects a list of
+     frozen transformer blocks and an lm_head from the base model.
+     """
+
+     def __init__(
+         self,
+         config: transformers.PretrainedConfig,
+         transformer_blocks: nn.ModuleList,
+         final_norm: nn.Module,
+         lm_head: nn.Module,
+     ):
+         super().__init__(config)
+
+         # Defined by the main trunk
+         self.hidden_size = hf_get_hidden_size(config)
+         self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks))
+         self.final_norm = deepcopy(final_norm)
+         self.lm_head = deepcopy(lm_head)
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+         self.gradient_checkpointing = False
+
+         # Turning off grad saves memory
+         for block in self.transformer_blocks:
+             for parameter in block.parameters():
+                 parameter.requires_grad = False
+         for parameter in lm_head.parameters():
+             parameter.requires_grad = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,  # Takes as input hidden_states instead of input_ids
+         output_shape: torch.Tensor,  # output_size given by main trunk
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = False,
+         position_ids: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         #######################################################################
+         # Modified BloomModel.forward
+         #######################################################################
+
+         batch_size, seq_length = hidden_states.shape[:2]
+
+         if past_key_values is None:
+             past_key_values = tuple([None] * len(self.transformer_blocks))
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape batch_size x num_heads x N x N
+         # head_mask has shape n_layer x batch x num_heads x N x N
+         head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))
+
+         presents = () if use_cache else None
+         all_self_attentions = () if output_attentions else None
+         all_hidden_states = () if output_hidden_states else None
+
+         # Compute alibi tensor: check modeling_bloom.build_alibi_tensor documentation
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+         if past_key_values[0] is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 (batch_size, seq_length_with_past), device=hidden_states.device
+             )
+         else:
+             attention_mask = attention_mask.to(hidden_states.device)
+
+         alibi = modeling_bloom.build_alibi_tensor(
+             attention_mask, self.config.n_head, dtype=hidden_states.dtype
+         )
+
+         # create causal mask
+         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+         combined_attention_mask = None
+         device = attention_mask.device
+         input_shape = (batch_size, seq_length)
+         _, src_length = input_shape
+
+         if src_length > 1:
+             combined_attention_mask = modeling_bloom._make_causal_mask(
+                 input_shape,
+                 device=device,
+                 past_key_values_length=past_key_values_length,
+             )
+
+         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+         expanded_attn_mask = modeling_bloom._expand_mask(
+             attention_mask, tgt_length=src_length
+         )
+         combined_attention_mask = (
+             expanded_attn_mask
+             if combined_attention_mask is None
+             else expanded_attn_mask | combined_attention_mask
+         )
+         causal_mask = combined_attention_mask
+
+         for i, (block, layer_past) in enumerate(
+             zip(self.transformer_blocks, past_key_values)
+         ):
+
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             outputs = block(
+                 hidden_states,
+                 layer_past=layer_past,
+                 attention_mask=causal_mask,
+                 head_mask=head_mask[i],
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+                 alibi=alibi,
+             )
+
+             hidden_states = outputs[0]
+             if use_cache is True:
+                 presents = presents + (outputs[1],)
+
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (
+                     outputs[2 if use_cache else 1],
+                 )
+
+         # Add last hidden state
+         hidden_states = self.final_norm(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         #######################################################################
+         # End of modified BloomModel.forward
+         #######################################################################
+
+         lm_logits = self.lm_head(hidden_states)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     lm_logits,
+                     hidden_states,
+                     presents,
+                     all_hidden_states,
+                     all_self_attentions,
+                 ]
+                 if v is not None
+             )
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=None,
+             logits=lm_logits,
+             past_key_values=presents,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+             cross_attentions=None,
+             value=None,
+         )
+
+
def hf_get_causal_lm_branch_class(
    config: transformers.PretrainedConfig,
) -> "ModelBranch":
@@ -809,14 +997,24 @@ def hf_get_causal_lm_branch_class(
        "GPTNeoXForCausalLM",
    ]
    opt_branch_supported_archs = ["OPTForCausalLM"]
+     bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
    arch = config.architectures[0]
    if arch in gpt_branch_supported_archs:
        return GPTModelBranch
    elif arch in opt_branch_supported_archs:
        return OPTModelBranch
+     elif arch in bloom_branch_supported_archs:
+         return BloomModelBranch
    else:
+         all_supported_archs = sum(
+             [
+                 gpt_branch_supported_archs,
+                 opt_branch_supported_archs,
+                 bloom_branch_supported_archs,
+             ],
+             [],
+         )
        raise ValueError(
            f"Unsupported architecture: `{arch}`. The following architectures are "
-             "available for model branching:\n"
-             f"{gpt_branch_supported_archs + opt_branch_supported_archs}"
+             f"available for model branching:\n{all_supported_archs}"
        )
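To see how the pieces above fit together, here is a minimal usage sketch (not part of the commit). It assumes a transformers version contemporary with this change, an example BLOOM checkpoint (`bigscience/bloom-560m`), the standard Hugging Face `BloomForCausalLM` attribute names (`transformer.h`, `transformer.ln_f`, `lm_head`), and that it runs alongside the definitions above, since the module path of this file is not shown in the diff:

import torch
import transformers

# Assumes BloomModelBranch and hf_get_causal_lm_branch_class are in scope
# (defined above); their module path is not shown in this excerpt.

model = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
config = model.config

# Dispatch on the architecture string; "BloomForCausalLM" resolves to BloomModelBranch.
branch_cls = hf_get_causal_lm_branch_class(config)

# Copy the top two transformer blocks, the final LayerNorm, and the LM head into a
# branch (the constructor deepcopies them and turns off requires_grad on the block
# and head parameters).
branch = branch_cls(
    config,
    transformer_blocks=model.transformer.h[-2:],
    final_norm=model.transformer.ln_f,
    lm_head=model.lm_head,
)

# The branch consumes hidden states from the trunk rather than input_ids.
hidden_states = torch.randn(1, 8, config.hidden_size)
out = branch(hidden_states, output_shape=hidden_states.size(), return_dict=True)
print(out.logits.shape)  # (1, 8, vocab_size)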