Skip to content

Commit 64f5141

Browse files
garyhlaielusenji
authored and committed
LayoutLMv2Processor: ensure 1-to-1 mapping between images and samples in case of overflowing tokens (huggingface#17092)
* add get_overflowing_images function to ensure 1-to-1 mapping between samples and images in LayoutLMv2Processor
* make style
* add test for overflowing_tokens, change assert to ValueError, avoiding unrelated formatting changes
* change line length by passing --preview into black
1 parent fb58c9b commit 64f5141

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

src/transformers/models/layoutlmv2/processing_layoutlmv2.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,12 @@ def __call__(
8686

8787
if self.feature_extractor.apply_ocr and (word_labels is not None):
8888
raise ValueError(
89-
"You cannot provide word labels "
90-
"if you initialized the feature extractor with apply_ocr set to True."
89+
"You cannot provide word labels if you initialized the feature extractor with apply_ocr set to True."
9190
)
9291

92+
if return_overflowing_tokens is True and return_offsets_mapping is False:
93+
raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
94+
9395
# first, apply the feature extractor
9496
features = self.feature_extractor(images=images, return_tensors=return_tensors)
9597

@@ -122,6 +124,23 @@ def __call__(
122124
)
123125

124126
# add pixel values
125-
encoded_inputs["image"] = features.pop("pixel_values")
127+
images = features.pop("pixel_values")
128+
if return_overflowing_tokens is True:
129+
images = self.get_overflowing_images(images, encoded_inputs["overflow_to_sample_mapping"])
130+
encoded_inputs["image"] = images
126131

127132
return encoded_inputs
133+
134+
def get_overflowing_images(self, images, overflow_to_sample_mapping):
135+
# in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image
136+
images_with_overflow = []
137+
for sample_idx in overflow_to_sample_mapping:
138+
images_with_overflow.append(images[sample_idx])
139+
140+
if len(images_with_overflow) != len(overflow_to_sample_mapping):
141+
raise ValueError(
142+
"Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got"
143+
f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}"
144+
)
145+
146+
return images_with_overflow

tests/models/layoutlmv2/test_processor_layoutlmv2.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,39 @@ def test_save_load_pretrained_additional_features(self):
133133
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
134134
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
135135

136+
@slow
137+
def test_overflowing_tokens(self):
138+
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
139+
140+
from datasets import load_dataset
141+
142+
# set up
143+
datasets = load_dataset("nielsr/funsd")
144+
processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
145+
146+
def preprocess_data(examples):
147+
images = [Image.open(path).convert("RGB") for path in examples["image_path"]]
148+
words = examples["words"]
149+
boxes = examples["bboxes"]
150+
word_labels = examples["ner_tags"]
151+
encoded_inputs = processor(
152+
images,
153+
words,
154+
boxes=boxes,
155+
word_labels=word_labels,
156+
padding="max_length",
157+
truncation=True,
158+
return_overflowing_tokens=True,
159+
stride=50,
160+
return_offsets_mapping=True,
161+
return_tensors="pt",
162+
)
163+
return encoded_inputs
164+
165+
train_data = preprocess_data(datasets["train"])
166+
167+
self.assertEqual(len(train_data["image"]), len(train_data["input_ids"]))
168+
136169

137170
# different use cases tests
138171
@require_torch

0 commit comments

Comments
 (0)