
Commit a0c2345

fix xcomposer2.5 device_map (modelscope#1343)
1 parent: 93b468e

File tree: 3 files changed (+54, -17 lines)


docs/source/Multi-Modal/cogvlm2-video最佳实践.md

Lines changed: 12 additions & 1 deletion
@@ -110,7 +110,18 @@ response: The video shows a person lighting a fire in a backyard setting. The pe
 # 40GB GPU memory
 CUDA_VISIBLE_DEVICES=0 swift sft \
     --model_type cogvlm2-video-13b-chat \
-    --dataset video-chatgpt
+    --dataset video-chatgpt \
+    --num_train_epochs 3 \
+
+# ZeRO2
+# Experimental environment: 4 * A100
+# 4 * 40GB GPU memory
+NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \
+    --model_type cogvlm2-video-13b-chat \
+    --dataset video-chatgpt \
+    --num_train_epochs 3 \
+    --deepspeed default-zero2
 ```
 
 [Custom datasets](../LLM/自定义与拓展.md#-推荐命令行参数的形式) support json, jsonl formats. Here is an example of a custom dataset:

docs/source_en/Multi-Modal/cogvlm2-video-best-practice.md

Lines changed: 12 additions & 1 deletion
@@ -109,7 +109,18 @@ Fine-tuning multimodal large models usually uses **custom datasets**. Here is a
 # 40GB GPU memory
 CUDA_VISIBLE_DEVICES=0 swift sft \
     --model_type cogvlm2-video-13b-chat \
-    --dataset video-chatgpt
+    --dataset video-chatgpt \
+    --num_train_epochs 3 \
+
+# ZeRO2
+# Experimental environment: 4 * A100
+# 4 * 40GB GPU memory
+NPROC_PER_NODE=4 \
+CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \
+    --model_type cogvlm2-video-13b-chat \
+    --dataset video-chatgpt \
+    --num_train_epochs 3 \
+    --deepspeed default-zero2
 ```
 
 [Custom datasets](../LLM/Customization.md#-Recommended-Command-line-arguments) support json, jsonl formats. Here is an example of a custom dataset:
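
Note on the doc change above: the new multi-GPU example launches one process per GPU (`NPROC_PER_NODE=4`) and enables DeepSpeed ZeRO stage 2 via `--deepspeed default-zero2`, which shards optimizer states and gradients across the data-parallel ranks so each 40GB card only holds its share. As a rough illustration of what such a preset typically contains (a sketch only, not the exact contents of swift's bundled `default-zero2` file):

```python
# Sketch of a minimal DeepSpeed ZeRO stage-2 config.
# Illustrative only: swift's bundled `default-zero2` preset may differ.
import json

zero2_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "fp16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,                # shard optimizer states and gradients
        "overlap_comm": True,      # overlap gradient reduction with backward
        "contiguous_gradients": True,
    },
}

with open("zero2.json", "w") as f:
    json.dump(zero2_config, f, indent=2)

# If --deepspeed also accepts a config path (as DeepSpeed launchers usually do),
# a custom file like this could replace the `default-zero2` preset.
```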

swift/llm/utils/model.py

Lines changed: 30 additions & 15 deletions
@@ -3818,6 +3818,7 @@ def new_get_rank(group=None):
     eos_token='<|im_end|>',
     support_flash_attn=True,
     tags=['multi-modal', 'vision'],
+    function_kwargs={'is_v2_5': True},
     hf_model_id='internlm/internlm-xcomposer2d5-7b')
 @register_model(
     ModelType.internlm_xcomposer2_7b_chat,
@@ -3833,6 +3834,7 @@ def get_model_tokenizer_internlm_xcomposer2(model_dir: str,
                                              model_kwargs: Dict[str, Any],
                                              load_model: bool = True,
                                              **kwargs):
+    is_v2_5 = kwargs.pop('is_v2_5', False)
     model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
     use_flash_attn = kwargs.pop('use_flash_attn', False)
     model_config._flash_attn_2_enabled = use_flash_attn
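
The two hunks above work as a pair: `function_kwargs={'is_v2_5': True}` is stored with the xcomposer2.5 registration and forwarded into `get_model_tokenizer_internlm_xcomposer2` as an ordinary keyword argument, where `kwargs.pop('is_v2_5', False)` reads it; the older xcomposer2 registrations omit the flag and fall back to `False`. A toy sketch of that registry pattern (a simplified stand-in, not ms-swift's real `register_model` signature):

```python
# Toy registry that threads per-model flags into the loader via function_kwargs.
# Simplified stand-in for illustration; not ms-swift's actual register_model.
from typing import Any, Callable, Dict

MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {}


def register_model(model_type: str, function_kwargs: Dict[str, Any] = None):
    def decorator(get_fn: Callable) -> Callable:
        MODEL_REGISTRY[model_type] = {
            'get_fn': get_fn,
            'function_kwargs': function_kwargs or {},
        }
        return get_fn

    return decorator


@register_model('xcomposer2_5-chat', function_kwargs={'is_v2_5': True})
def get_model_tokenizer(model_dir: str, **kwargs):
    # The registration-time flag arrives as a plain keyword argument.
    is_v2_5 = kwargs.pop('is_v2_5', False)
    return f'loaded {model_dir} (v2.5={is_v2_5})'


entry = MODEL_REGISTRY['xcomposer2_5-chat']
print(entry['get_fn']('/path/to/model', **entry['function_kwargs']))
```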
@@ -3850,23 +3852,36 @@ def get_model_tokenizer_internlm_xcomposer2(model_dir: str,
         model.model.layers[0].attention.__class__.attention_dropout = 0.
 
     model_cls = model.__class__
-    if not hasattr(model_cls, '__old_encode_img'):  # avoid double patching
-        model_cls.__old_encode_img = model_cls.encode_img
-
-        def _new_encode_img(self, image):
-            if image is None:
-                return None
-            if isinstance(image, str):
-                from PIL import Image
-                image = Image.open(image).convert('RGB')
-                image = self.vis_processor(image).unsqueeze(0).to(self.device)
-            else:
-                assert isinstance(image, torch.Tensor)
 
-            img_embeds, atts_img, img_target = self.img2emb(image)
-            return img_embeds.to(device=self.device)  # FIX device_map
+    if is_v2_5:
+
+        def _output_device_map_hook(module, input, output):
+            output = (output[0].to(input[1].device), output[1])
+            return output
+
+        def _output_device_map_hook2(module, input, output):
+            return output.to(input[0].device)
+
+        model.vit.register_forward_hook(_output_device_map_hook)
+        model.vision_proj.register_forward_hook(_output_device_map_hook2)
+    else:
+        if not hasattr(model_cls, '__old_encode_img'):  # avoid double patching
+            model_cls.__old_encode_img = model_cls.encode_img
+
+            def _new_encode_img(self, image):
+                if image is None:
+                    return None
+                if isinstance(image, str):
+                    from PIL import Image
+                    image = Image.open(image).convert('RGB')
+                    image = self.vis_processor(image).unsqueeze(0).to(self.device)
+                else:
+                    assert isinstance(image, torch.Tensor)
+
+                img_embeds, atts_img, img_target = self.img2emb(image)
+                return img_embeds.to(device=self.device)  # FIX device_map
 
-        model_cls.encode_img = _new_encode_img
+            model_cls.encode_img = _new_encode_img
 
     return model, tokenizer
 
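What the v2.5 branch above fixes: with `device_map='auto'` the vision tower (`model.vit`) and the projector (`model.vision_proj`) can be placed on a different GPU than the modules that consume their outputs, so the commit registers forward hooks that move each output back onto the device of the hook's input, instead of patching `encode_img` as is still done for the older xcomposer2 models. The pattern in isolation, as a minimal sketch on a toy module (names are illustrative, not xcomposer internals):

```python
# Minimal, self-contained sketch of the forward-hook device fix used above.
# When a model is sharded across GPUs with device_map, a submodule's output can
# land on a different device than the caller's tensors; a forward hook that
# returns a relocated output keeps downstream ops on a consistent device.
import torch
import torch.nn as nn


def output_to_input_device_hook(module, inputs, output):
    # `inputs` is the tuple of positional args passed to forward();
    # returning a value from a forward hook replaces the module's output.
    return output.to(inputs[0].device)


class ToyProjector(nn.Module):
    """Stand-in for a vision projection layer, not xcomposer's vision_proj."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


proj = ToyProjector()
proj.register_forward_hook(output_to_input_device_hook)

x = torch.randn(2, 16)  # in a sharded model this may live on a different GPU
y = proj(x)
assert y.device == x.device  # the hook keeps the output co-located with the input
```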