
Commit 3da4519

kylesayrs and dsikka authored
Vision Datasets (#943)
* clean up CustomDataset
* chchchchanges
* wip: rename to processor, going through tests
* remove labels from calibration dataset rather than assuming that all tokenized datasets should not be given labels
* cleanup
* cleanup, etc
* fix typehinting
* add typechecking imports
* remove sparseml utilities
* use in model_load
* remove use of RECIPE_FILE_NAME
* rename to RECIPE_FILE_NAME, avoid circular import
* image dataset collation
* cleanup, do not handle case where processor is None
* remove qa ignore
* add documentation
* add data collator arg
* use default factory
* validate flickr
* discover bug, tests and multimodal working
* dataset split fallbacks
* cleanup, deprecate remove_columns argument
* silently assign tokenizer to processor
* replace tokenizer with processor
* typehinting, add not-implemented error
* remove todos
* update dataset manager api in tests
* Delete examples/multimodal_vision/qwen_vl2.py
* Delete examples/multimodal_vision/mllama.py
* handle columns better
* filter_tokenizer_args
* more tests
* remove duplicate file
* better help texts
* revert data split fallbacks
* handle non-fast tokenizers
* address nits, add logging
* add back copyrights
* correctly update helptext
* do not remove prompt key
* remove prompt key
* do not process tokenized datasets, including adding labels
* remove default chat template
* add back default templates with warning

---------

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 384059b commit 3da4519

File tree

21 files changed: +524 -481 lines

src/llmcompressor/transformers/finetune/data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from .custom import CustomDataset
 from .data_args import DataTrainingArguments
 from .evolcodealpaca import EvolCodeAlpacaDataset
+from .flickr_30k import Flickr30K
 from .gsm8k import GSM8KDataset
 from .open_platypus import OpenPlatypusDataset
 from .ptb import PtbDataset

src/llmcompressor/transformers/finetune/data/base.py

Lines changed: 220 additions & 139 deletions
Large diffs are not rendered by default.

src/llmcompressor/transformers/finetune/data/c4.py

Lines changed: 9 additions & 4 deletions
@@ -1,6 +1,11 @@
 from copy import deepcopy
+from typing import TYPE_CHECKING

 from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.typing import Processor
+
+if TYPE_CHECKING:
+    from llmcompressor.transformers import DataTrainingArguments as DataArgs


 @TextGenerationDataset.register(name="c4")
@@ -13,9 +18,9 @@ class C4Dataset(TextGenerationDataset):
     :param processor: processor or tokenizer to use on dataset
     """

-    def __init__(self, data_args, split, processor):
+    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "allenai/c4"
-        super().__init__(
-            text_column="text", data_args=data_args, split=split, processor=processor
-        )
+        data_args.text_column = "text"
+
+        super().__init__(data_args=data_args, split=split, processor=processor)
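The c4.py change illustrates the new registration pattern: subclasses mutate a copied `data_args` (dataset id, `text_column`) and defer all loading and preprocessing to the base class. A minimal sketch of registering an additional dataset under this pattern; the registry name and dataset id below are hypothetical:

```python
from copy import deepcopy

from llmcompressor.transformers.finetune.data import TextGenerationDataset


@TextGenerationDataset.register(name="my_corpus")  # hypothetical registry name
class MyCorpusDataset(TextGenerationDataset):
    def __init__(self, data_args, split, processor):
        # copy so the caller's data_args are not mutated, mirroring C4Dataset
        data_args = deepcopy(data_args)
        data_args.dataset = "my-org/my-corpus"  # hypothetical HF dataset id
        data_args.text_column = "text"
        super().__init__(data_args=data_args, split=split, processor=processor)
```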

src/llmcompressor/transformers/finetune/data/cnn_dailymail.py

Lines changed: 11 additions & 30 deletions
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from copy import deepcopy
-from typing import Optional
+from typing import TYPE_CHECKING

 from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.typing import Processor
+
+if TYPE_CHECKING:
+    from llmcompressor.transformers import DataTrainingArguments as DataArgs


 @TextGenerationDataset.register(name="cnn_dailymail")
@@ -29,39 +33,16 @@ class CNNDailyMailDataset(TextGenerationDataset):

     SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

-    def __init__(self, data_args, split, processor):
+    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "cnn_dailymail"
         data_args.dataset_config_name = "3.0.0"

-        super().__init__(
-            text_column="text", data_args=data_args, split=split, processor=processor
-        )
-
-    def get_raw_dataset(self, cache_dir: Optional[str] = None):
-        """
-        Load the raw dataset from Hugging Face, using cached copy if available.
-        Additionally reformats the entries to fit the template.
-
-        :param cache_dir: disk location to search for cached dataset
-        :return: the requested dataset
-        """
-        raw_dataset = super().get_raw_dataset(cache_dir=cache_dir)
+        super().__init__(data_args=data_args, split=split, processor=processor)

-        def restructure_fn(sample):
-            sample["text"] = self.SAMPLE_TEMPLATE.format(
+    def dataset_template(self, sample):
+        return {
+            "text": self.SAMPLE_TEMPLATE.format(
                 article=sample["article"], highlights=sample["highlights"]
             )
-
-            return sample
-
-        raw_dataset = self.map(
-            raw_dataset,
-            function=restructure_fn,
-            batched=False,
-            remove_columns=["article", "highlights", "id"],
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
-            desc="Restructuring CNN/DailyMail Dataset",
-        )
-        return raw_dataset
+        }
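Here the `get_raw_dataset` override and its hand-rolled `map` call collapse into a `dataset_template` hook. The base.py diff is not rendered above, but judging from the subclasses, the base class presumably maps each sample through `dataset_template` and keeps only the columns the template returns. A sketch of that assumed behavior using Hugging Face `datasets` (the helper name is illustrative, not from the commit):

```python
from datasets import Dataset


def apply_dataset_template(dataset: Dataset, dataset_template) -> Dataset:
    # map every sample through the template; dropping the original columns
    # means only the keys the template returns (e.g. "text") survive
    return dataset.map(
        dataset_template,
        batched=False,
        remove_columns=dataset.column_names,
        desc="Applying dataset template",
    )


# toy usage with a CNN/DailyMail-shaped sample
def cnn_template(sample):
    return {
        "text": f"Article:\n{sample['article']}\n\n"
        f"### Summarization:\n{sample['highlights']}\n"
    }


toy = Dataset.from_dict(
    {"article": ["Some article."], "highlights": ["A summary."], "id": ["0"]}
)
print(apply_dataset_template(toy, cnn_template)[0]["text"])
```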

src/llmcompressor/transformers/finetune/data/custom.py

Lines changed: 1 addition & 82 deletions
@@ -11,16 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import deepcopy
-from typing import Dict, List, Union
-
-from datasets.dataset_dict import Dataset, DatasetDict
-
 from llmcompressor.transformers.finetune.data import TextGenerationDataset
-from llmcompressor.transformers.utils.preprocessing_functions import (
-    PreprocessingFunctionRegistry,
-)
-from llmcompressor.utils import import_from_path


 @TextGenerationDataset.register(name="custom", alias=["json", "csv"])
@@ -36,76 +27,4 @@ class CustomDataset(TextGenerationDataset):

     """

-    def __init__(self, data_args, split, processor):
-        data_args = deepcopy(data_args)
-        super().__init__(
-            text_column=data_args.text_column,
-            data_args=data_args,
-            split=split,
-            processor=processor,
-        )
-        self.preprocessing_func = data_args.preprocessing_func
-        self.remove_columns = data_args.remove_columns
-
-    def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
-        """Get the raw dataset and apply preprocessing func if provided"""
-
-        dataset = self.data_args.dataset
-        if isinstance(dataset, DatasetDict) or isinstance(dataset, Dataset):
-            # user passed in an already instantiated dataset, just use it directly
-            raw_dataset = dataset
-        else:
-            # dataset must be loaded from file or HF Hub
-            raw_dataset = super().get_raw_dataset()
-
-        if self.preprocessing_func is not None:
-            if callable(self.preprocessing_func):
-                func = self.preprocessing_func
-            elif ":" in self.preprocessing_func:
-                # load func_name from "/path/to/file.py:func_name"
-                func = import_from_path(self.preprocessing_func)
-            else:
-                # load from the registry
-                func = PreprocessingFunctionRegistry.get_value_from_registry(
-                    name=self.preprocessing_func
-                )
-
-            raw_dataset = self.map(
-                raw_dataset,
-                function=func,
-                batched=False,
-                num_proc=self.data_args.preprocessing_num_workers,
-                desc="Applying custom func to the custom dataset",
-            )
-
-        self.remove_columns = (
-            self.remove_columns or self.get_remove_columns_from_dataset(raw_dataset)
-        )
-
-        if self.remove_columns is not None:
-            raw_dataset = self.map(
-                raw_dataset,
-                batched=True,
-                remove_columns=self.remove_columns,
-                num_proc=self.data_args.preprocessing_num_workers,
-                desc="Removing unneeded columns",
-            )
-
-        return raw_dataset
-
-    def get_remove_columns_from_dataset(
-        self, raw_dataset: Union[DatasetDict, Dataset]
-    ) -> List[str]:
-        """Remove redandant columns from the dataset for processing"""
-
-        remove_columns = raw_dataset.column_names
-        if isinstance(remove_columns, Dict):
-            remove_columns = raw_dataset[list(raw_dataset.keys())[0]].column_names
-
-        remove_columns = set(remove_columns)
-        if self.text_column in remove_columns:
-            remove_columns.remove(self.text_column)
-        if self.PROMPT_KEY in remove_columns:
-            remove_columns.remove(self.PROMPT_KEY)
-
-        return list(remove_columns)
+    pass
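With the loading and column-pruning logic moved into the shared base class, `CustomDataset` reduces to a registration stub. Per the updated `preprocessing_func` help text in data_args.py below, users can still hook in custom preprocessing in three forms. A sketch, with an illustrative function body and registry name, and other loading arguments omitted:

```python
from llmcompressor.transformers import DataTrainingArguments


def wrap_in_chat(sample):
    # hypothetical preprocessing: wrap raw text as a single user turn
    return {"text": f"<|user|>\n{sample['text']}\n<|assistant|>\n"}


# 1. a callable applied directly to each sample
data_args = DataTrainingArguments(dataset="json", preprocessing_func=wrap_in_chat)

# 2. a name registered in preprocessing_functions.py
# data_args = DataTrainingArguments(dataset="json", preprocessing_func="my_registered_func")

# 3. a path of the form /path/to/file.py:func
# data_args = DataTrainingArguments(dataset="json", preprocessing_func="/path/to/file.py:wrap_in_chat")
```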

src/llmcompressor/transformers/finetune/data/data_args.py

Lines changed: 22 additions & 8 deletions
@@ -1,5 +1,7 @@
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from transformers import DefaultDataCollator


 @dataclass
@@ -31,26 +33,38 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
         },
     )

-    text_column: Optional[str] = field(
+    text_column: str = field(
         default="text",
-        metadata={"help": "For custom datasets only. The text field key"},
+        metadata={
+            "help": (
+                "Optional key to be used as the `text` input to tokenizer/processor "
+                "after dataset preprocessing"
+            )
+        },
     )

     remove_columns: Union[None, str, List] = field(
         default=None,
-        metadata={"help": "Column names to remove after preprocessing custom datasets"},
+        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
     )

     preprocessing_func: Union[None, str, Callable] = field(
         default=None,
         metadata={
             "help": (
-                "The preprocessing function to apply or the preprocessing func name in "
-                "src/llmcompressor/transformers/utils/preprocessing_functions.py"
+                "Typically a function which applies a chat template. Can take the form "
+                "of either a function to apply to the dataset, a name defined in "
+                "src/llmcompressor/transformers/utils/preprocessing_functions.py, or "
+                "a path to a function definition of the form /path/to/file.py:func"
             )
         },
     )

+    data_collator: Callable[[Any], Any] = field(
+        default_factory=lambda: DefaultDataCollator(),
+        metadata={"help": "The function used to form a batch from the dataset"},
+    )
+

 @dataclass
 class DataTrainingArguments(CustomDataTrainingArguments):
@@ -91,8 +105,8 @@ class DataTrainingArguments(CustomDataTrainingArguments):
             "help": "Whether or not to concatenate datapoints to fill max_seq_length"
         },
     )
-    raw_kwargs: Optional[Dict] = field(
-        default=None,
+    raw_kwargs: Dict = field(
+        default_factory=dict,
         metadata={"help": "Additional keyword args to pass to datasets load_data"},
     )
     splits: Union[None, str, List, Dict] = field(
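The new `data_collator` field defaults to `transformers.DefaultDataCollator`, and for processors that emit image features a custom collator can be passed instead. A sketch of what such a collator might look like; the feature keys (`input_ids`, `pixel_values`) and uniform shapes are assumptions, not guarantees from the commit:

```python
import torch

from llmcompressor.transformers import DataTrainingArguments


def vision_collator(features):
    # stack per-sample tensors into a batch; assumes every sample has the
    # same sequence length and image shape
    return {
        "input_ids": torch.tensor([f["input_ids"] for f in features]),
        "pixel_values": torch.stack(
            [torch.as_tensor(f["pixel_values"]) for f in features]
        ),
    }


data_args = DataTrainingArguments(dataset="flickr", data_collator=vision_collator)
```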

src/llmcompressor/transformers/finetune/data/evolcodealpaca.py

Lines changed: 17 additions & 33 deletions
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from copy import deepcopy
-from typing import Optional
+from typing import TYPE_CHECKING

 from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.typing import Processor
+
+if TYPE_CHECKING:
+    from llmcompressor.transformers import DataTrainingArguments as DataArgs


 @TextGenerationDataset.register(name="evolcodealpaca")
@@ -34,40 +38,20 @@ class EvolCodeAlpacaDataset(TextGenerationDataset):
         "\n\n### Response:\n"
     )

-    def __init__(self, data_args, split, processor):
+    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
         data_args = deepcopy(data_args)
         data_args.dataset = "theblackcat102/evol-codealpaca-v1"
-        super().__init__(
-            text_column="text", data_args=data_args, split=split, processor=processor
-        )
-
-    def get_raw_dataset(self, cache_dir: Optional[str] = None):
-        """
-        Load the raw dataset from Hugging Face, using cached copy if available.
-        Additionally reformats the entries to fit the alpaca template.
+        data_args.text_column = "text"

-        :param cache_dir: disk location to search for cached dataset
-        :return: the requested dataset
-        """
-        raw_dataset = super().get_raw_dataset(cache_dir=cache_dir)
+        super().__init__(data_args, split=split, processor=processor)

-        # helper fn for restructuring each dataset entry using the alpaca template
-        def restructure_fn(sample):
-            sample["text"] = self.EVOL_ALPACA_TEMPLATE.format(
-                instruction=sample["instruction"]
-            )
-            sample[self.PROMPT_KEY] = sample["text"]
-            if "output" in sample:
-                sample["text"] += sample["output"]
-            return sample
+    def dataset_template(self, sample):
+        prompt = self.EVOL_ALPACA_TEMPLATE.format(instruction=sample["instruction"])
+        text = prompt
+        if "output" in sample:
+            text += sample["output"]

-        raw_dataset = self.map(
-            raw_dataset,
-            function=restructure_fn,
-            batched=False,
-            remove_columns=["output", "instruction"],
-            num_proc=self.data_args.preprocessing_num_workers,
-            load_from_cache_file=not self.data_args.overwrite_cache,
-            desc="Restructuring Evol Code Alpaca Dataset",
-        )
-        return raw_dataset
+        return {
+            "text": text,
+            self.PROMPT_KEY: prompt,
+        }
src/llmcompressor/transformers/finetune/data/flickr_30k.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+from copy import deepcopy
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+from llmcompressor.transformers.finetune.data import TextGenerationDataset
+from llmcompressor.typing import Processor
+
+if TYPE_CHECKING:
+    from llmcompressor.transformers import DataTrainingArguments as DataArgs
+
+
+@TextGenerationDataset.register(name="flickr", alias="flickr30k")
+class Flickr30K(TextGenerationDataset):
+    """
+    :param data_args: configuration settings for dataset loading
+    :param split: split from dataset to load, for instance `test` or `train[:5%]`
+    :param processor: processor or tokenizer to use on dataset
+    """
+
+    DEFAULT_CHAT_TEMPLATE = (
+        "{% for message in messages %}\n"
+        "{% if message['role'] == 'user' %}\n"
+        "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
+        "{% elif message['role'] == 'system' %}\n"
+        "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
+        "{% elif message['role'] == 'assistant' %}\n"
+        "{{ '<|assistant|>\n' + message['content'] + eos_token }}\n"
+        "{% endif %}\n"
+        "{% if loop.last and add_generation_prompt %}\n"
+        "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
+    )
+
+    def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
+        data_args = deepcopy(data_args)
+        data_args.dataset = "lmms-lab/flickr30k"
+
+        super().__init__(data_args=data_args, split=split, processor=processor)
+
+        if (
+            self.tokenizer is not None
+            and getattr(self.tokenizer, "chat_template", None) is None
+        ):
+            # note that since tokenizer is a member of processor,
+            # this change affects processor.apply_chat_template
+            self.tokenizer.chat_template = self.DEFAULT_CHAT_TEMPLATE
+            logger.warning(
+                "tokenizer.chat_template is not set, using default chat template for "
+                f"{self.__class__.__name__}"
+            )
+
+    def dataset_template(self, sample):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What does the image show?"},
+                ],
+            }
+        ]
+        return {
+            "text": self.processor.apply_chat_template(
+                messages,
+                add_generation_prompt=True,
+            ),
+            "images": sample["image"],
+        }
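To exercise the new dataset, something like the following should work, patterned on llm-compressor's existing oneshot examples; the model id, recipe settings, and exact keyword arguments are assumptions and may differ by version:

```python
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

oneshot(
    model="Qwen/Qwen2-VL-2B-Instruct",  # any processor-backed vision LM
    dataset="flickr",  # resolves to Flickr30K through the registry above
    recipe=GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
    max_seq_length=2048,
    num_calibration_samples=512,
)
```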
