Skip to content

Commit ffa570c

Browse files
authored
[Tracing][Testing] Add tracing tests (#1335)
## Purpose ## * Add regression testing to model tracing beyond example tests * These tests complete in ~1 min and can be run at a quicker cadence than example tests * These can also be used to test tracing capabilities beyond those in the examples, for example tracing into linear layers ## Prerequisites ## * #1334 * #1402 ## Changes ## * Fix function signature of peoples speech dataset * Add `trust_remote_code` argument to debugger * Add `tests/llmcompressor/transformers/tracing/models.py` * I did not include phi3 because it's a very difficult model to work with programmatically. I will revisit once the major tracing improvements have landed --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0effb4e commit ffa570c

File tree

5 files changed

+136
-13
lines changed

5 files changed

+136
-13
lines changed

.github/workflows/test-check-transformers.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ on:
88

99
env:
1010
CADENCE: "commit"
11+
HF_TOKEN: ${{ secrets.HF_TOKEN_READ }}
1112

1213
jobs:
1314
detect-changes:
@@ -95,6 +96,10 @@ jobs:
9596
if: (success() || failure()) && steps.install.outcome == 'success'
9697
run: |
9798
pytest -v tests/llmcompressor/transformers/obcq
99+
- name: Running Tracing Tests
100+
if: (success() || failure()) && steps.install.outcome == 'success'
101+
run: |
102+
pytest -v tests/llmcompressor/transformers/tracing
98103
- name: Running KV Cache Tests
99104
if: (success() || failure()) && steps.install.outcome == 'success'
100105
run: |

src/llmcompressor/transformers/finetune/data/peoples_speech.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,20 @@ class PeoplesSpeech(TextGenerationDataset):
2626
:param processor: processor or tokenizer to use on dataset
2727
"""
2828

29-
def __init__(self, data_args: "DataArgs", split: str, processor: Processor):
30-
data_args = deepcopy(data_args)
31-
data_args.dataset = "MLCommons/peoples_speech"
32-
data_args.dataset_config_name = "test"
33-
if not data_args.overwrite_cache:
29+
def __init__(self, dataset_args: "DataArgs", split: str, processor: Processor):
30+
dataset_args = deepcopy(dataset_args)
31+
dataset_args.dataset = "MLCommons/peoples_speech"
32+
dataset_args.dataset_config_name = "test"
33+
if not dataset_args.overwrite_cache:
3434
logger.warning(
3535
"Because audio processors are more complex, dataset mapping functions "
3636
"vary with model architecture and their results cannot be cached. "
3737
"Setting overwrite_cache=True"
3838
)
39-
data_args.overwrite_cache = True
39+
dataset_args.overwrite_cache = True
4040
self.processor_type = processor.__class__.__name__
4141

42-
super().__init__(data_args=data_args, split=split, processor=processor)
42+
super().__init__(dataset_args=dataset_args, split=split, processor=processor)
4343

4444
def dataset_template(self, example):
4545
audio = example["audio"]["array"]

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from llmcompressor.transformers import TextGenerationDataset
1313
from llmcompressor.args import DatasetArguments
1414

15+
from llmcompressor.utils.dev import skip_weights_download
16+
1517
__all__ = [
1618
"get_model_class"
1719
]
@@ -24,6 +26,7 @@ def parse_args():
2426
parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing") # noqa: E501
2527
parser.add_argument("--ignore", type=str, nargs="*", default=[], metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
2628
parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text") # noqa: E501
29+
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
2730
return parser.parse_args()
2831

2932

@@ -33,6 +36,7 @@ def trace(
3336
sequential_targets: Optional[Union[List[str], str]] = None,
3437
ignore: Union[List[str], str] = [],
3538
modality: str = "text",
39+
trust_remote_code: bool = True
3640
):
3741
"""
3842
Debug traceability by tracing a pre-trained model into subgraphs
@@ -44,6 +48,7 @@ def trace(
4448
inference
4549
:param ignore: patterns to ignore during tracing
4650
:param modality: data modality for dummy tracing data, defaults to 'text'
51+
:param trust_remote_code: trust remote model code
4752
4853
Example usage from CLI
4954
llmcompressor.trace \
@@ -54,12 +59,16 @@ def trace(
5459
--modality text
5560
"""
5661
# Load model
57-
model = model_class.from_pretrained(
58-
model_id,
59-
device_map="auto",
60-
torch_dtype="auto",
62+
with skip_weights_download(model_class):
63+
model = model_class.from_pretrained(
64+
model_id,
65+
device_map="cpu",
66+
torch_dtype="auto",
67+
trust_remote_code=trust_remote_code,
68+
)
69+
processor = AutoProcessor.from_pretrained(
70+
model_id, trust_remote_code=trust_remote_code
6171
)
62-
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
6372
print("Loaded model")
6473

6574
# Prepare sample data
@@ -138,6 +147,7 @@ def main():
138147
sequential_targets=args.sequential_targets,
139148
ignore=args.ignore,
140149
modality=args.modality,
150+
trust_remote_code=args.trust_remote_code
141151
)
142152

143153

src/llmcompressor/transformers/tracing/idefics3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def __init__(self, config: Idefics3Config):
285285

286286
def forward(
287287
self,
288-
input_ids: torch.LongTensor = None,
288+
input_ids: Optional[torch.LongTensor] = None,
289289
attention_mask: Optional[torch.Tensor] = None,
290290
position_ids: Optional[torch.LongTensor] = None,
291291
past_key_values: Optional[List[torch.FloatTensor]] = None,
@@ -296,6 +296,7 @@ def forward(
296296
use_cache: Optional[bool] = None,
297297
output_attentions: Optional[bool] = None,
298298
output_hidden_states: Optional[bool] = None,
299+
cache_position: Optional[torch.LongTensor] = None,
299300
return_dict: Optional[bool] = None,
300301
) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
301302
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -394,6 +395,7 @@ def forward(
394395
use_cache=use_cache,
395396
output_attentions=output_attentions,
396397
output_hidden_states=output_hidden_states,
398+
cache_position=cache_position,
397399
return_dict=return_dict,
398400
)
399401

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import pytest
2+
from transformers import AutoModelForCausalLM, WhisperForConditionalGeneration
3+
4+
from llmcompressor.transformers.tracing import (
5+
TraceableIdefics3ForConditionalGeneration,
6+
TraceableLlavaForConditionalGeneration,
7+
TraceableMllamaForConditionalGeneration,
8+
TraceableQwen2_5_VLForConditionalGeneration,
9+
TraceableQwen2VLForConditionalGeneration,
10+
)
11+
from llmcompressor.transformers.tracing.debug import trace
12+
13+
14+
@pytest.mark.parametrize(
15+
"model_id,model_class,targets",
16+
[
17+
("meta-llama/Meta-Llama-3-8B-Instruct", AutoModelForCausalLM, None),
18+
],
19+
)
20+
def test_text_trace(model_id, model_class, targets):
21+
trace(
22+
model_id,
23+
model_class,
24+
targets,
25+
ignore=[],
26+
modality="text",
27+
trust_remote_code=True,
28+
)
29+
30+
31+
@pytest.mark.parametrize(
32+
"model_id,model_class,targets,ignore",
33+
[
34+
(
35+
"HuggingFaceM4/Idefics3-8B-Llama3",
36+
TraceableIdefics3ForConditionalGeneration,
37+
["LlamaDecoderLayer"],
38+
[],
39+
),
40+
(
41+
"llava-hf/llava-1.5-7b-hf",
42+
TraceableLlavaForConditionalGeneration,
43+
["LlamaDecoderLayer"],
44+
[],
45+
),
46+
(
47+
"meta-llama/Llama-3.2-11B-Vision-Instruct",
48+
TraceableMllamaForConditionalGeneration,
49+
["MllamaSelfAttentionDecoderLayer"],
50+
[],
51+
),
52+
# skip phi3_v because of its processor is annoying and requires special code
53+
(
54+
"mgoin/pixtral-12b",
55+
TraceableLlavaForConditionalGeneration,
56+
["MistralDecoderLayer"],
57+
[],
58+
),
59+
(
60+
"Qwen/Qwen2.5-VL-7B-Instruct",
61+
TraceableQwen2_5_VLForConditionalGeneration,
62+
["Qwen2_5_VLDecoderLayer"],
63+
[],
64+
),
65+
(
66+
"Qwen/Qwen2-VL-2B-Instruct",
67+
TraceableQwen2VLForConditionalGeneration,
68+
["Qwen2VLDecoderLayer"],
69+
[],
70+
),
71+
],
72+
)
73+
def test_vision_trace(model_id, model_class, targets, ignore):
74+
trace(
75+
model_id,
76+
model_class,
77+
targets,
78+
ignore=ignore,
79+
modality="vision",
80+
trust_remote_code=True,
81+
)
82+
83+
84+
@pytest.mark.parametrize(
85+
"model_id,model_class,targets,ignore",
86+
[
87+
(
88+
"openai/whisper-large-v3",
89+
WhisperForConditionalGeneration,
90+
None,
91+
[],
92+
),
93+
],
94+
)
95+
def test_audio_trace(model_id, model_class, targets, ignore):
96+
pytest.importorskip("librosa")
97+
pytest.importorskip("soundfile")
98+
99+
trace(
100+
model_id,
101+
model_class,
102+
targets,
103+
ignore=ignore,
104+
modality="audio",
105+
trust_remote_code=True,
106+
)

0 commit comments

Comments
 (0)