Commit 1a3461d

Add BloomModel hydra support (#129)
1 parent 33deeb1 commit 1a3461d
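
In trlx's "hydra" PPO setup, the trainable policy and the frozen reference model share the lower transformer trunk, and only a small frozen upper branch is kept to produce the reference logits used in the KL penalty; this commit adds such a branch for Bloom. A minimal construction sketch, not part of the diff: the checkpoint name and `num_layers_unfrozen` are arbitrary choices, and the `transformer.h` / `transformer.ln_f` / `lm_head` attribute names are assumptions based on Hugging Face's Bloom implementation.

```python
# Illustrative only: build the frozen Bloom branch from the top blocks of a
# base checkpoint. Model name and num_layers_unfrozen are arbitrary choices.
from transformers import AutoModelForCausalLM
from trlx.model.nn.ppo_models import BloomModelBranch

base = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
num_layers_unfrozen = 2  # how many upper blocks the frozen branch copies

branch = BloomModelBranch(
    config=base.config,
    transformer_blocks=base.transformer.h[-num_layers_unfrozen:],  # top N Bloom blocks
    final_norm=base.transformer.ln_f,
    lm_head=base.lm_head,
)
```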

1 file changed (+203, −5)

trlx/model/nn/ppo_models.py

Lines changed: 203 additions & 5 deletions
@@ -9,7 +9,8 @@
 import transformers
 from torchtyping import TensorType
 from transformers.modeling_outputs import ModelOutput
-from transformers.models.opt.modeling_opt import _make_causal_mask, _expand_mask
+from transformers.models.bloom import modeling_bloom
+from transformers.models.opt import modeling_opt

 from trlx.data.method_configs import MethodConfig, register_method
 from trlx.utils.modeling import (
@@ -698,15 +699,15 @@ def forward(
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
         if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
+            combined_attention_mask = modeling_opt._make_causal_mask(
                 input_shape,
                 hidden_states.dtype,
                 past_key_values_length=past_key_values_length,
             ).to(hidden_states.device)

         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(
+            expanded_attn_mask = modeling_opt._expand_mask(
                 attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
             ).to(hidden_states.device)
             combined_attention_mask = (
@@ -798,6 +799,193 @@ def forward(
         )


+class BloomModelBranch(transformers.PreTrainedModel):
+    """
+    BloomModelBranch implements the frozen upper trunk of the reference model
+    used when computing the PPO KL-divergence penalty. Expects a list of
+    frozen transformer blocks and an lm_head from the base model.
+    """
+
+    def __init__(
+        self,
+        config: transformers.PretrainedConfig,
+        transformer_blocks: nn.ModuleList,
+        final_norm: nn.Module,
+        lm_head: nn.Module,
+    ):
+        super().__init__(config)
+
+        # Defined by the main trunk
+        self.hidden_size = hf_get_hidden_size(config)
+        self.transformer_blocks = deepcopy(nn.ModuleList(transformer_blocks))
+        self.final_norm = deepcopy(final_norm)
+        self.lm_head = deepcopy(lm_head)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.gradient_checkpointing = False
+
+        # Turning off grad saves memory
+        for block in self.transformer_blocks:
+            for parameter in block.parameters():
+                parameter.requires_grad = False
+        for parameter in lm_head.parameters():
+            parameter.requires_grad = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,  # Takes hidden_states as input instead of input_ids
+        output_shape: torch.Tensor,  # output_size given by the main trunk
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = False,
+        position_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        #######################################################################
+        # Modified BloomModel.forward
+        #######################################################################
+
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.transformer_blocks))
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape batch_size x num_heads x N x N
+        # head_mask has shape n_layer x batch x num_heads x N x N
+        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        # Compute alibi tensor: check modeling_bloom.build_alibi_tensor documentation
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+        if past_key_values[0] is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), device=hidden_states.device
+            )
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)
+
+        alibi = modeling_bloom.build_alibi_tensor(
+            attention_mask, self.config.n_head, dtype=hidden_states.dtype
+        )
+
+        # create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        combined_attention_mask = None
+        device = attention_mask.device
+        input_shape = (batch_size, seq_length)
+        _, src_length = input_shape
+
+        if src_length > 1:
+            combined_attention_mask = modeling_bloom._make_causal_mask(
+                input_shape,
+                device=device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = modeling_bloom._expand_mask(
+            attention_mask, tgt_length=src_length
+        )
+        combined_attention_mask = (
+            expanded_attn_mask
+            if combined_attention_mask is None
+            else expanded_attn_mask | combined_attention_mask
+        )
+        causal_mask = combined_attention_mask
+
+        for i, (block, layer_past) in enumerate(
+            zip(self.transformer_blocks, past_key_values)
+        ):
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            outputs = block(
+                hidden_states,
+                layer_past=layer_past,
+                attention_mask=causal_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                alibi=alibi,
+            )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (
+                    outputs[2 if use_cache else 1],
+                )
+
+        # Add last hidden state
+        hidden_states = self.final_norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        #######################################################################
+        # End of modified BloomModel.forward
+        #######################################################################
+
+        lm_logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    lm_logits,
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
+            )
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=None,
+            logits=lm_logits,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=None,
+            value=None,
+        )
+
+
 def hf_get_causal_lm_branch_class(
     config: transformers.PretrainedConfig,
 ) -> "ModelBranch":
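
To make the data flow of the new branch concrete, the sketch below feeds it the hidden states produced just beneath the frozen blocks and reads back reference logits (as used for the KL penalty). This is a hedged, standalone example rather than trlx's actual wiring: `base`, `branch`, and `num_layers_unfrozen` come from the construction sketch near the top of this page, and the `-(num_layers_unfrozen + 1)` index relies on how Hugging Face's `BloomModel` collects `output_hidden_states` (the input to each block, then the post-`ln_f` state last).

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
tokens = tokenizer("Hello, world", return_tensors="pt")

with torch.no_grad():
    # Hidden states entering the first frozen block.
    trunk_out = base.transformer(**tokens, output_hidden_states=True)
    hidden = trunk_out.hidden_states[-(num_layers_unfrozen + 1)]

    ref_out = branch(
        hidden_states=hidden,
        output_shape=hidden.size(),  # not referenced in the Bloom forward above; kept for the shared signature
        attention_mask=tokens["attention_mask"],
        return_dict=True,
    )

print(ref_out.logits.shape)  # (batch_size, seq_len, vocab_size)
```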
@@ -809,14 +997,24 @@ def hf_get_causal_lm_branch_class(
         "GPTNeoXForCausalLM",
     ]
     opt_branch_supported_archs = ["OPTForCausalLM"]
+    bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
     arch = config.architectures[0]
     if arch in gpt_branch_supported_archs:
         return GPTModelBranch
     elif arch in opt_branch_supported_archs:
         return OPTModelBranch
+    elif arch in bloom_branch_supported_archs:
+        return BloomModelBranch
     else:
+        all_supported_archs = sum(
+            [
+                gpt_branch_supported_archs,
+                opt_branch_supported_archs,
+                bloom_branch_supported_archs,
+            ],
+            [],
+        )
         raise ValueError(
             f"Unsupported architecture: `{arch}`. The following architectures are "
-            "available for model branching:\n"
-            f"{gpt_branch_supported_archs + opt_branch_supported_archs}"
+            f"available for model branching:\n{all_supported_archs}"
         )
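
With the dispatcher updated, selecting the right branch class for a loaded config becomes a simple lookup. A small example (the checkpoint name is an arbitrary choice, and the expected architecture string is an assumption about its config):

```python
from transformers import AutoConfig
from trlx.model.nn.ppo_models import BloomModelBranch, hf_get_causal_lm_branch_class

config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # architectures: ["BloomForCausalLM"]
branch_class = hf_get_causal_lm_branch_class(config)
assert branch_class is BloomModelBranch
```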
