Skip to content

Commit 4654c1b

Browse files
authored
fix template (modelscope#1321)
1 parent 0e4c5a7 commit 4654c1b

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

swift/llm/utils/template.py

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -929,7 +929,7 @@ def _load_image(img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image':
929929
return image
930930

931931

932-
def _load_video(video_path: str) -> np.ndarray:
932+
def _load_video_llava(video_path: str) -> np.ndarray:
933933
import av
934934
container = av.open(video_path)
935935
total_frames = container.streams.video[0].frames
@@ -1036,7 +1036,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
10361036
if idx_list:
10371037
idx = idx_list[0]
10381038
images_path = example.get('images') or []
1039-
image = _load_image(images_path[0])
1039+
image = _read_batch(images_path)[0]
10401040
placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
10411041
placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
10421042
input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
@@ -1627,7 +1627,7 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
16271627
else:
16281628
videos_path.append(media_file)
16291629
if len(videos_path) > 0:
1630-
videos = _read_batch(videos_path, _load_video)
1630+
videos = _read_batch(videos_path, _load_video_llava)
16311631
video_processor = self.tokenizer.processor.video_processor
16321632
video_inputs = video_processor(videos, return_tensors='pt').to(self.model.dtype)
16331633
inputs['pixel_values_videos'] = video_inputs['pixel_values_videos']
@@ -1766,9 +1766,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
17661766
if len(inputs) == 0:
17671767
return inputs, {}
17681768
image_path = example.get('images') or []
1769-
if image_path:
1770-
raw_image = _load_image(image_path[0])
1771-
pixel_values = self.tokenizer.processor.image_processor(raw_image, return_tensors='pt')['pixel_values']
1769+
raw_image = _read_batch(image_path)
1770+
if raw_image:
1771+
pixel_values = self.tokenizer.processor.image_processor(raw_image[0], return_tensors='pt')['pixel_values']
17721772
inputs['pixel_values'] = pixel_values.to(self.model.dtype)
17731773
return inputs, {}
17741774

@@ -1806,9 +1806,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
18061806
inputs['token_type_ids'] = [0] * n + [1] * n2
18071807
else:
18081808
inputs['token_type_ids'] = [0] * len(inputs['input_ids'])
1809-
if image_path:
1810-
raw_image = _load_image(image_path[0])
1811-
model_inputs = processor(text=example['query'], images=raw_image, return_tensors='pt')
1809+
raw_image = _read_batch(image_path)
1810+
if raw_image:
1811+
model_inputs = processor(text=example['query'], images=raw_image[0], return_tensors='pt')
18121812
inputs['pixel_values'] = model_inputs['pixel_values']
18131813
return inputs, {}
18141814

@@ -1949,9 +1949,9 @@ def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any
19491949
example['query'], example['history'], images_path = replace_img_tag(example['query'], history,
19501950
'<image_placeholder>')
19511951
inputs, _ = super().encode(example)
1952-
images_path.extend(example.get('images') or [])
19531952
if len(inputs) == 0:
19541953
return inputs, {}
1954+
images_path.extend(example.get('images') or [])
19551955
images = _read_batch(images_path)
19561956
processor = self.tokenizer.processor
19571957
input_ids, labels = inputs['input_ids'], inputs['labels']
@@ -2024,34 +2024,35 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, exa
20242024

20252025
def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
20262026
inputs, _ = super().encode(example)
2027-
images_path = example.get('images') or []
2028-
image = _load_image(images_path[0]) if len(images_path) >= 1 else []
20292027
if len(inputs) == 0:
20302028
return inputs, {}
2029+
images_path = example.get('images') or []
2030+
image = _read_batch(images_path)
20312031
inputs.pop('loss_scale', None)
20322032
model = self.model
20332033
inputs2 = model.build_conversation_input_ids(
2034-
self.tokenizer, query=example['query'], history=example.get('history'), images=[image])
2035-
image_token_len = inputs2['token_type_ids'].sum()
2034+
self.tokenizer, query=example['query'], history=example.get('history'), images=image)
2035+
image_token_len = inputs2['token_type_ids'].sum().item()
20362036
input_ids = inputs['input_ids']
20372037
labels = inputs['labels']
20382038
inputs['token_type_ids'] = [0] + [1] * image_token_len + [0] * len(input_ids[1:])
20392039
inputs['input_ids'] = input_ids[:1] + [self.tokenizer.pad_token_id] * image_token_len + input_ids[1:]
20402040
if labels is not None:
20412041
inputs['labels'] = labels[:1] + [-100] * image_token_len + labels[1:]
2042-
dtype = model.dtype
2043-
inputs['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
2044-
if 'cross_images' in inputs2:
2045-
# is cogagent
2046-
inputs['cross_images'] = [[cross_img.to(dtype=dtype)] for cross_img in inputs2['cross_images']]
2042+
if len(image) > 0:
2043+
dtype = model.dtype
2044+
inputs['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
2045+
if 'cross_images' in inputs2:
2046+
# is cogagent
2047+
inputs['cross_images'] = [[cross_img.to(dtype=dtype)] for cross_img in inputs2['cross_images']]
20472048
return inputs, {}
20482049

20492050
def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
20502051
res = super().data_collator(batch, padding_to)
2051-
is_cogagent = 'cross_images' in batch[0]
2052-
keys = ['images', 'cross_images'] if is_cogagent else ['images']
2052+
keys = ['images', 'cross_images']
20532053
for key in keys:
2054-
res[key] = [b[key][0] for b in batch]
2054+
if key in batch[0]:
2055+
res[key] = [b[key][0] for b in batch]
20552056
token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
20562057
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=0)
20572058
res['token_type_ids'] = token_type_ids
@@ -2107,10 +2108,10 @@ def check_example(self, example):
21072108

21082109
def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
21092110
inputs, _ = super().encode(example)
2110-
images_path = example['images']
2111-
image = _load_image(images_path[0])
21122111
if len(inputs) == 0:
21132112
return inputs, {}
2113+
images_path = example['images']
2114+
image = _load_image(images_path[0])
21142115
input_ids = inputs['input_ids']
21152116
labels = inputs['labels']
21162117
idx_list = _findall(input_ids, -1)

0 commit comments

Comments (0)