
Commit da6f627

Add sparse attention integration to llm_eval
Signed-off-by: Kai Xu <[email protected]>
1 parent 1da834b commit da6f627

15 files changed, +318 −54 lines changed

examples/llm_eval/lm_eval_hf.py

Lines changed: 28 additions & 0 deletions
@@ -43,9 +43,11 @@
 from lm_eval.api.model import T
 from lm_eval.models.huggingface import HFLM
 from quantization_utils import quantize_model
+from sparse_attention_utils import sparsify_model

 import modelopt.torch.opt as mto
 from modelopt.torch.quantization.utils import is_quantized
+from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified


 def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:
@@ -57,9 +59,20 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
     calib_size = arg_dict.pop("calib_size", 512)
     compress = arg_dict.pop("compress", False)

+    # Sparse attention arguments
+    sparse_cfg = arg_dict.pop("sparse_cfg", None)
+
     additional_config = {} if additional_config is None else additional_config
     additional_config = {k: v for k, v in additional_config.items() if v is not None}

+    # Force eager attention if sparse attention is requested
+    if sparse_cfg:
+        additional_config["attn_implementation"] = "eager"
+        warnings.warn(
+            "Sparse attention requires attn_implementation='eager'. "
+            "Forcing eager attention implementation."
+        )
+
     # Enable automatic save/load of modelopt state huggingface checkpointing
     mto.enable_huggingface_checkpointing()

@@ -85,6 +98,15 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict |
         compress=compress,
     )

+    if sparse_cfg:
+        if is_attn_sparsified(model_obj.model):
+            warnings.warn("Skipping sparse attention: model already has sparse attention applied.")
+        else:
+            sparsify_model(
+                model=model_obj,
+                sparse_cfg=sparse_cfg,
+            )
+
     return model_obj


@@ -120,6 +142,11 @@ def setup_parser_with_modelopt_args():
         action="store_true",
         help="Compress the model after quantization",
     )
+    parser.add_argument(
+        "--sparse_cfg",
+        type=str,
+        help="Sparse attention configuration (e.g., SKIP_SOFTMAX_DEFAULT, SKIP_SOFTMAX_CALIB)",
+    )
     return parser


@@ -142,6 +169,7 @@ def setup_parser_with_modelopt_args():
             "calib_batch_size": args.calib_batch_size,
             "calib_size": args.calib_size,
             "compress": args.compress,
+            "sparse_cfg": args.sparse_cfg,
         }
     )
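For orientation, a minimal sketch of how the new argument is consumed, assuming HFLM as the wrapper class from lm_eval; the checkpoint name and the "pretrained" key are illustrative, and only sparse_cfg and the eager-attention override come from this diff:

# Hypothetical invocation of the patched helper; the model name is a placeholder.
arg_dict = {
    "pretrained": "meta-llama/Llama-3.1-8B-Instruct",  # assumed HFLM key
    "sparse_cfg": "SKIP_SOFTMAX_DEFAULT",
}
model_obj = create_from_arg_obj(HFLM, arg_dict)
# Internally: sparse_cfg is popped from arg_dict, attn_implementation is forced
# to "eager", and sparsify_model() is called unless is_attn_sparsified(model_obj.model)
# reports that sparse attention was already applied.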

examples/llm_eval/mmlu.py

Lines changed: 25 additions & 0 deletions
@@ -48,6 +48,7 @@
 from fire import Fire
 from modeling import EvalModel, select_model
 from quantization_utils import MAX_SEQ_LEN, get_tokenizer, quantize_model
+from sparse_attention_utils import sparsify_model
 from tqdm import tqdm

 try:
@@ -56,6 +57,7 @@
     LLM = None  # type: ignore[misc]
 import modelopt.torch.opt as mto
 from modelopt.torch.quantization.utils import is_quantized
+from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified

 os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -227,6 +229,7 @@ def main(
     batch_size: int = 0,
     calib_size: int = 512,
     dtype: str = "bfloat16",
+    sparse_cfg: str | None = None,
     **kwargs,
 ):
     random.seed(RAND_SEED)
@@ -263,6 +266,14 @@ def main(
             max_batch_size=1,
         )
     else:
+        # Force eager attention if sparse attention is requested
+        if sparse_cfg:
+            kwargs["attn_implementation"] = "eager"
+            warnings.warn(
+                "Sparse attention requires attn_implementation='eager'. "
+                "Forcing eager attention implementation."
+            )
+
         model = select_model(
             max_input_length=MAX_SEQ_LEN, max_output_length=2, dtype=dtype, **kwargs
         )
@@ -283,6 +294,20 @@ def main(
             auto_quantize_bits=auto_quantize_bits,
         )

+    # Apply sparse attention if requested
+    if sparse_cfg:
+        model.load()
+
+        if is_attn_sparsified(model.model):
+            warnings.warn(
+                "Skipping sparse attention: model already has sparse attention applied."
+            )
+        else:
+            sparsify_model(
+                model=model,
+                sparse_cfg=sparse_cfg,
+            )
+
     for subject in tqdm(subjects):
         dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[
             :ntrain
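A hedged sketch of calling mmlu.main with the new keyword; main is exposed through Fire, so these keywords mirror CLI flags. Only sparse_cfg, data_dir, and ntrain appear in this diff; model_name and model_path are assumed kwargs forwarded to select_model, and the values are placeholders:

# Hypothetical programmatic call (equivalent to the Fire CLI invocation).
main(
    data_dir="data/mmlu",                              # placeholder path
    ntrain=5,
    model_name="causal",                               # assumed kwarg for select_model
    model_path="meta-llama/Llama-3.1-8B-Instruct",     # assumed kwarg, placeholder checkpoint
    sparse_cfg="SPARSE_CONSERVATIVE",                  # custom config from sparse_attention_utils
)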

examples/llm_eval/modeling.py

Lines changed: 5 additions & 0 deletions
@@ -179,6 +179,7 @@ class SeqToSeqModel(EvalModel):
     lora_path: str = ""
     device: str = "cuda"
     load_8bit: bool = False
+    attn_implementation: str | None = None

     def load(self):
         if self.model is None:
@@ -188,6 +189,8 @@ def load(self):
             if self.load_8bit:
                 args.update(device_map="auto", load_in_8bit=True)
             args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+            if self.attn_implementation:
+                args["attn_implementation"] = self.attn_implementation
             self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path, **args)
             print_gpu_utilization()
         if self.lora_path:
@@ -241,6 +244,8 @@ def load(self):
             if self.load_8bit:
                 args.update(device_map="auto", load_in_8bit=True)
             args.update(torch_dtype=getattr(torch, self.dtype) if self.dtype != "auto" else "auto")
+            if self.attn_implementation:
+                args["attn_implementation"] = self.attn_implementation
             self.model = AutoModelForCausalLM.from_pretrained(
                 self.model_path, trust_remote_code=True, **args
             )
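A hedged sketch of the new field in use; SeqToSeqModel is the wrapper shown above (the causal-LM wrapper gains the same field), and the checkpoint name is a placeholder:

# Construct an eval wrapper with eager attention so sparse attention can hook in.
model = SeqToSeqModel(
    model_path="google/flan-t5-large",   # placeholder checkpoint
    dtype="bfloat16",
    attn_implementation="eager",         # new field added in this commit
)
model.load()  # forwards attn_implementation into from_pretrained(**args)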

examples/llm_eval/sparse_attention_utils.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for sparse attention integration with llm_eval."""
+
+import modelopt.torch.sparsity.attention_sparsity as mtsa
+
+# Custom sparse attention configurations
+CUSTOM_SPARSE_CONFIG = {
+    "SPARSE_CONSERVATIVE": {
+        "sparse_cfg": {
+            "*attn*": {
+                "method": "flash_skip_softmax",
+                "threshold": {"prefill": 5e-4, "decode": 1e-5},
+                "br": 128,
+                "bc": 128,
+                "backend": "pytorch",
+                "enable": True,
+            },
+            "default": {"enable": False},
+        },
+    },
+    "SPARSE_AGGRESSIVE": {
+        "sparse_cfg": {
+            "*attn*": {
+                "method": "flash_skip_softmax",
+                "threshold": {"prefill": 5e-3, "decode": 5e-4},
+                "br": 128,
+                "bc": 128,
+                "backend": "pytorch",
+                "enable": True,
+            },
+            "default": {"enable": False},
+        },
+    },
+}
+
+
+def _extract_model(model_obj):
+    """Extract actual model from wrapper (HFLM or EvalModel)."""
+    if hasattr(model_obj, "gpt2"):
+        return model_obj.gpt2
+    elif hasattr(model_obj, "model"):
+        return model_obj.model
+    else:
+        return model_obj
+
+
+def sparsify_model(
+    model,
+    sparse_cfg: str,
+    backend=None,
+):
+    """Apply sparse attention to model with optional RULER calibration.
+
+    Args:
+        model: Model wrapper (HFLM or EvalModel) or raw model
+        sparse_cfg: Sparse attention config name or dict
+        backend: Backend to use (optional, overrides config backend)
+
+    Returns:
+        The model with sparse attention applied
+
+    Note:
+        Calibration is automatically triggered if the config contains a 'calibration' field.
+        The calibration will auto-generate RULER dataset from the model's tokenizer.
+    """
+    # Extract actual model
+    net = _extract_model(model)
+
+    # Resolve config
+    if isinstance(sparse_cfg, str):
+        # Try custom configs first
+        mtsa_cfg = CUSTOM_SPARSE_CONFIG.get(sparse_cfg)
+        if mtsa_cfg is None:
+            # Try predefined configs
+            mtsa_cfg = getattr(mtsa, sparse_cfg, None)
+        if mtsa_cfg is None:
+            raise ValueError(f"Unknown sparse_cfg: {sparse_cfg}")
+    else:
+        mtsa_cfg = sparse_cfg
+
+    # Override backend if specified
+    if backend:
+        if isinstance(mtsa_cfg, dict) and "sparse_cfg" in mtsa_cfg:
+            modified_sparse_cfg = {}
+            for pattern, cfg in mtsa_cfg["sparse_cfg"].items():
+                modified_cfg = cfg.copy() if isinstance(cfg, dict) else cfg
+                if isinstance(modified_cfg, dict):
+                    modified_cfg["backend"] = backend
+                modified_sparse_cfg[pattern] = modified_cfg
+            mtsa_cfg = {"sparse_cfg": modified_sparse_cfg}
+
+    # Apply sparsification
+    print(f"\nApplying sparse attention with config: {sparse_cfg}")
+    mtsa.sparsify(net, mtsa_cfg)
+    print("Sparse attention applied successfully!")
+
+    return model
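A usage sketch for the helper above; the object passed in can be an HFLM/EvalModel wrapper or a raw module, and preset names such as SKIP_SOFTMAX_DEFAULT come from the --sparse_cfg help text earlier in this commit:

# Minimal sketch, assuming `model` is an already-loaded eval wrapper or nn.Module.
from sparse_attention_utils import sparsify_model

model = sparsify_model(model, sparse_cfg="SPARSE_CONSERVATIVE")
# Or use a preset name and override the backend declared in the config
# ("pytorch" is the backend used by the custom configs above):
model = sparsify_model(model, sparse_cfg="SKIP_SOFTMAX_DEFAULT", backend="pytorch")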

examples/llm_sparsity/attention_sparsity/hf_sa.py

Lines changed: 16 additions & 25 deletions
@@ -196,29 +196,6 @@ def generate_text(model, inputs, args, tokenizer):
     print("\nOutputs differ")


-def sparsify_model(model, args):
-    """Apply sparse attention to the model with optional calibration."""
-    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
-    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
-
-    # Create modified config with selected backend
-    modified_sparse_cfg = {}
-    for pattern, cfg in base_config["sparse_cfg"].items():
-        modified_cfg = cfg.copy()
-        modified_cfg["backend"] = args.backend
-        modified_sparse_cfg[pattern] = modified_cfg
-
-    # Create new config with modified settings
-    sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg)
-
-    # Sparsify the model
-    model = mtsa.sparsify(model, config=sparse_config)
-
-    print("Sparse attention applied successfully!")
-
-    return model
-
-
 def main(args):
     """Main function to run the selected mode."""
     if not torch.cuda.is_available():
@@ -249,8 +226,22 @@ def main(args):
     model = model.cuda()
     print("Model moved to CUDA")

-    # Apply sparse attention to the model (with calibration if configured)
-    model = sparsify_model(model, args)
+    # Apply sparse attention with optional calibration
+    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
+    base_config = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]
+
+    # Create modified config with selected backend
+    modified_sparse_cfg = {}
+    for pattern, cfg in base_config["sparse_cfg"].items():
+        modified_cfg = cfg.copy()
+        modified_cfg["backend"] = args.backend
+        modified_sparse_cfg[pattern] = modified_cfg
+
+    # Create config and apply sparsification (calibration happens automatically)
+    sparse_config = SparseAttentionConfig(sparse_cfg=modified_sparse_cfg)
+    model = mtsa.sparsify(model, config=sparse_config)
+
+    print("Sparse attention applied successfully!")

     # Verify outputs if requested (compares baseline vs calibrated sparse model)
     if args.verify_output:

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 from transformers import AutoTokenizer

 from ..config import CalibrationConfig
+from ..conversion import print_sparse_attention_summary
 from ..sparse_attention import SparseAttentionModule
 from .calibrator import DynamicThresholdCalibrator
 from .dataset import RulerDatasetBuilder
@@ -162,6 +163,10 @@ def calibrate_sparse_attention(
     )
     calibration_result = calibrator.calibrate(model, forward_loop)

+    # Print calibration statistics (regardless of success/failure for debugging)
+    print("\nCalibration complete!")
+    print_sparse_attention_summary(model)
+
     if "scale_factor" not in calibration_result:
         warnings.warn("Calibration did not produce valid results")
         return {}

modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py

Lines changed: 8 additions & 7 deletions
@@ -185,7 +185,7 @@ def __init__(
         self,
         samples: int,
         max_seqlen: int,
-        tokenizer_name_or_path: str,
+        tokenizer_name_or_path: str | object,
         seed: int = 42,
         num_length_bins: int = 4,
         max_length_filter: int = 65536,
@@ -195,7 +195,7 @@
         Args:
             samples: Total number of samples to generate (distributed evenly across length bins)
             max_seqlen: Maximum sequence length (length bins auto-generated as powers of 2)
-            tokenizer_name_or_path: HuggingFace tokenizer path
+            tokenizer_name_or_path: HuggingFace tokenizer path or tokenizer object
             seed: Random seed for reproducibility
             num_length_bins: Number of length bins to generate (default: 4)
            max_length_filter: Maximum sequence length to keep (default: 65536)
@@ -229,8 +229,11 @@ def __init__(
         # Distribute samples evenly across lengths
         self.samples_per_length = [samples // len(self.target_lengths)] * len(self.target_lengths)

-        # Initialize tokenizer and seed
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+        # Initialize tokenizer
+        if isinstance(tokenizer_name_or_path, str):
+            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+        else:
+            self.tokenizer = tokenizer_name_or_path
         random.seed(seed)

     def build_calibration_dataset(self) -> list[dict[str, Any]]:
@@ -247,9 +250,7 @@ def build_calibration_dataset(self) -> list[dict[str, Any]]:
             desc="Generating RULER calibration samples",
             total=len(self.target_lengths),
         ):
-            samples_per_task = num_samples // len(self.subtasks)
-            if samples_per_task <= 0:
-                continue
+            samples_per_task = max(num_samples // len(self.subtasks), 1)

             # Generate equal samples for each task
             for task_name in self.subtasks:
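A short sketch of what the relaxed signature allows: the builder can now take a tokenizer object directly instead of a HuggingFace path. Parameter values and the tokenizer checkpoint are illustrative:

# Assumed context: RulerDatasetBuilder imported from the calibration.dataset module.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # placeholder
builder = RulerDatasetBuilder(
    samples=32,
    max_seqlen=8192,
    tokenizer_name_or_path=tokenizer,  # tokenizer object, no longer restricted to a path
)
calib_samples = builder.build_calibration_dataset()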
