@@ -3818,6 +3818,7 @@ def new_get_rank(group=None):
3818
3818
eos_token = '<|im_end|>' ,
3819
3819
support_flash_attn = True ,
3820
3820
tags = ['multi-modal' , 'vision' ],
3821
+ function_kwargs = {'is_v2_5' : True },
3821
3822
hf_model_id = 'internlm/internlm-xcomposer2d5-7b' )
3822
3823
@register_model (
3823
3824
ModelType .internlm_xcomposer2_7b_chat ,
@@ -3833,6 +3834,7 @@ def get_model_tokenizer_internlm_xcomposer2(model_dir: str,
3833
3834
model_kwargs : Dict [str , Any ],
3834
3835
load_model : bool = True ,
3835
3836
** kwargs ):
3837
+ is_v2_5 = kwargs .pop ('is_v2_5' , False )
3836
3838
model_config = AutoConfig .from_pretrained (model_dir , trust_remote_code = True )
3837
3839
use_flash_attn = kwargs .pop ('use_flash_attn' , False )
3838
3840
model_config ._flash_attn_2_enabled = use_flash_attn
@@ -3850,23 +3852,36 @@ def get_model_tokenizer_internlm_xcomposer2(model_dir: str,
3850
3852
model .model .layers [0 ].attention .__class__ .attention_dropout = 0.
3851
3853
3852
3854
model_cls = model .__class__
3853
- if not hasattr (model_cls , '__old_encode_img' ): # avoid double patching
3854
- model_cls .__old_encode_img = model_cls .encode_img
3855
-
3856
- def _new_encode_img (self , image ):
3857
- if image is None :
3858
- return None
3859
- if isinstance (image , str ):
3860
- from PIL import Image
3861
- image = Image .open (image ).convert ('RGB' )
3862
- image = self .vis_processor (image ).unsqueeze (0 ).to (self .device )
3863
- else :
3864
- assert isinstance (image , torch .Tensor )
3865
3855
3866
- img_embeds , atts_img , img_target = self .img2emb (image )
3867
- return img_embeds .to (device = self .device ) # FIX device_map
3856
+ if is_v2_5 :
3857
+
3858
+ def _output_device_map_hook (module , input , output ):
3859
+ output = (output [0 ].to (input [1 ].device ), output [1 ])
3860
+ return output
3861
+
3862
+ def _output_device_map_hook2 (module , input , output ):
3863
+ return output .to (input [0 ].device )
3864
+
3865
+ model .vit .register_forward_hook (_output_device_map_hook )
3866
+ model .vision_proj .register_forward_hook (_output_device_map_hook2 )
3867
+ else :
3868
+ if not hasattr (model_cls , '__old_encode_img' ): # avoid double patching
3869
+ model_cls .__old_encode_img = model_cls .encode_img
3870
+
3871
+ def _new_encode_img (self , image ):
3872
+ if image is None :
3873
+ return None
3874
+ if isinstance (image , str ):
3875
+ from PIL import Image
3876
+ image = Image .open (image ).convert ('RGB' )
3877
+ image = self .vis_processor (image ).unsqueeze (0 ).to (self .device )
3878
+ else :
3879
+ assert isinstance (image , torch .Tensor )
3880
+
3881
+ img_embeds , atts_img , img_target = self .img2emb (image )
3882
+ return img_embeds .to (device = self .device ) # FIX device_map
3868
3883
3869
- model_cls .encode_img = _new_encode_img
3884
+ model_cls .encode_img = _new_encode_img
3870
3885
3871
3886
return model , tokenizer
3872
3887
0 commit comments