Skip to content

Commit 6457911

Browse files
authored
Add QAIRT MHA2SHA transformation support (#2064)
## Describe your changes - Add MHA2SHA pass for QAIRT SDK that enables better performance on Qualcomm NPU ## Checklist before requesting a review - [x] Add unit tests for this change. - [x] Make sure all tests can pass. - [x] Update documents if necessary. - [x] Lint and apply fixes to your code by running `lintrunner -a` - [x] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [x] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.
1 parent 6cb62ae commit 6457911

File tree

4 files changed

+298
-0
lines changed

4 files changed

+298
-0
lines changed

olive/olive_config.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,14 @@
472472
"supported_algorithms": [ ],
473473
"supported_quantization_encodings": [ ]
474474
},
475+
"QairtMHA2SHA": {
476+
"module_path": "olive.passes.onnx.qairt.mha2sha.QairtMHA2SHA",
477+
"supported_providers": [ "QNNExecutionProvider" ],
478+
"supported_accelerators": [ "npu" ],
479+
"supported_precisions": [ "*" ],
480+
"supported_algorithms": [ ],
481+
"supported_quantization_encodings": [ ]
482+
},
475483
"QLoRA": {
476484
"module_path": "olive.passes.pytorch.lora.QLoRA",
477485
"supported_providers": [ "*" ],

olive/passes/onnx/qairt/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------

olive/passes/onnx/qairt/mha2sha.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
6+
import logging
7+
from copy import deepcopy
8+
from pathlib import Path
9+
from typing import Any
10+
11+
from olive.hardware import AcceleratorSpec
12+
from olive.model import ONNXModelHandler
13+
from olive.passes.olive_pass import Pass
14+
from olive.passes.pass_config import BasePassConfig, PassConfigParam
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class QairtMHA2SHA(Pass):
    """Apply the QAIRT SDK's MHA to SHA transformation to an ONNX model and save the result.

    The transformation itself is delegated to the ``OnnxModel`` class shipped with the
    QAIRT SDK (``qti.aisw.tools``); this pass only loads the model, invokes the
    transformation and exports the transformed model.
    """

    @classmethod
    def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
        """Return the pass parameters.

        ``mha2sha_kwargs`` is the only parameter. It defaults to ``None``, in which
        case the transformation is invoked without any keyword arguments.
        """
        mha2sha_kwargs_param = PassConfigParam(
            type_=dict[str, Any],
            default_value=None,
            description="Additional parameters to be passed to the MHA2SHA transformation function.",
        )
        return {"mha2sha_kwargs": mha2sha_kwargs_param}

    def _run_for_config(
        self,
        model: ONNXModelHandler,
        config: type[BasePassConfig],
        output_model_path: str,
    ) -> ONNXModelHandler:
        """Transform ``model`` with the QAIRT SDK and export it to ``output_model_path``.

        Raises:
            NotImplementedError: if ``model`` is not an ``ONNXModelHandler``.
            ImportError: if ``OnnxModel`` cannot be imported from either of the
                known QAIRT SDK locations.

        """
        if not isinstance(model, ONNXModelHandler):
            raise NotImplementedError(
                f"QairtMHA2SHA pass only supports ONNXModelHandler, but received type {type(model)}"
            )

        # The SDK is imported lazily so that this module can be imported without it.
        try:
            from qti.aisw.tools.core.utilities.framework.frameworks.onnx import OnnxModel
        except ImportError:
            try:
                # Backwards compatibility with older locations of OnnxModel in <= QAIRT 2.36.1
                from qti.aisw.tools.core.utilities.framework.onnx import OnnxModel
            except ImportError as e:
                raise ImportError("Please install qti.aisw.tools and all dependencies to use QairtMHA2SHA.") from e

        sdk_model = OnnxModel.load(model_path=model.model_path)

        # Both transformation entry points receive the same user supplied keyword arguments.
        transform_kwargs = {} if config.mha2sha_kwargs is None else config.mha2sha_kwargs
        try:
            sdk_model.mha2sha_v2(**transform_kwargs)
        except AttributeError:
            # Backwards compatibility with older definitions of OnnxModel in <= QAIRT 2.37
            # NOTE(review): an AttributeError raised *inside* mha2sha_v2 also lands here and
            # triggers the V1 fallback; it is not limited to the method being absent.
            logger.warning("MHA2SHA V2 is not available for this SDK version, defaulting to MHA2SHA V1")
            sdk_model.mha2sha(**transform_kwargs)

        # Exported file(s) are prefixed with the input model's file name without its extension.
        export_prefix = Path(model.model_path).stem
        sdk_model.export(output_model_path, prefix=export_prefix)

        return ONNXModelHandler(
            model_path=output_model_path,
            onnx_file_name=model.onnx_file_name,
            model_attributes=deepcopy(model.model_attributes),
        )
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
6+
import builtins
7+
import os
8+
from pathlib import Path
9+
from unittest.mock import MagicMock, patch
10+
11+
import pytest
12+
13+
from olive.model import ONNXModelHandler
14+
from olive.passes.olive_pass import create_pass_from_dict
15+
from olive.passes.onnx.qairt.mha2sha import QairtMHA2SHA
16+
from olive.passes.pass_config import PassConfigParam
17+
from test.utils import get_onnx_model
18+
19+
20+
# Mock OnnxModel for external QAIRT SDK
21+
class MockOnnxModelInstance:
    """Test double for the QAIRT SDK ``OnnxModel`` instance returned by ``OnnxModel.load``.

    Each public method forwards to a ``MagicMock`` attribute so tests can inspect the
    calls, and ``export`` additionally writes a small placeholder file to disk.
    """

    def __init__(self, model_path):
        # ``model_path`` is accepted to mirror the ``load(model_path=...)`` call; it is not stored.
        self.has_mha2sha_v2 = True  # Tests flip this to False to exercise the V1 fallback
        self.mock_mha2sha_v2 = MagicMock()
        self.mock_mha2sha = MagicMock()
        self._export_mock_internal = MagicMock()

    def mha2sha_v2(self, **kwargs):
        """Record a V2 call, or raise ``AttributeError`` when V2 is flagged as unavailable."""
        if not self.has_mha2sha_v2:
            # Simulate an SDK whose OnnxModel has no mha2sha_v2
            raise AttributeError("mha2sha_v2 not available")
        self.mock_mha2sha_v2(**kwargs)

    def mha2sha(self, **kwargs):
        """Record a V1 call."""
        self.mock_mha2sha(**kwargs)

    def export(self, output_path, prefix):
        """Record the export call and write ``<output_path>/<prefix>.onnx`` as a text placeholder."""
        self._export_mock_internal(output_path, prefix)

        target_dir = Path(output_path)
        target_dir.mkdir(parents=True, exist_ok=True)
        target_file = target_dir / f"{prefix}.onnx"

        placeholder_lines = [
            f"This is a dummy ONNX file for {prefix}.onnx\n",
            f"Generated by MockOnnxModelInstance at {target_file}\n",
        ]
        with open(target_file, "w") as f:
            f.writelines(placeholder_lines)
53+
54+
55+
@pytest.fixture(name="qairt_pass_instance")
def qairt_pass_instance_fixture():
    """Build a QairtMHA2SHA pass from an empty config with search disabled."""
    empty_config = {}
    return create_pass_from_dict(QairtMHA2SHA, empty_config, disable_search=True)
59+
60+
61+
@pytest.fixture(name="mock_accelerator_spec")
def mock_accelerator_spec_fixture():
    """Return a ``MagicMock`` that stands in for an AcceleratorSpec."""
    accelerator_spec = MagicMock()
    return accelerator_spec
65+
66+
67+
@pytest.fixture(name="tmp_output_dir")
def tmp_output_dir_fixture(tmp_path):
    """Expose pytest's ``tmp_path`` as a plain string for use as an output directory."""
    output_dir = str(tmp_path)
    return output_dir
71+
72+
73+
@pytest.fixture(name="mock_qairt_sdk_classes")
def mock_qairt_sdk_classes_fixture():
    """Install fake QAIRT SDK modules in ``sys.modules`` for the duration of a test.

    Both import locations of ``OnnxModel`` used by the pass are replaced with mock
    modules whose ``OnnxModel.load`` produces a fresh ``MockOnnxModelInstance``.

    Yields the mock ``OnnxModel`` class registered under the new import path, since
    that is the one the pass tries first. Tests may replace its ``load.side_effect``.
    """
    new_module_name = "qti.aisw.tools.core.utilities.framework.frameworks.onnx"
    old_module_name = "qti.aisw.tools.core.utilities.framework.onnx"

    onnx_model_class_new = MagicMock()
    onnx_model_class_old = MagicMock()
    for onnx_model_class in (onnx_model_class_new, onnx_model_class_old):
        # Every load() call builds a new MockOnnxModelInstance from the given model_path
        onnx_model_class.load.side_effect = MockOnnxModelInstance

    fake_modules = {
        new_module_name: MagicMock(OnnxModel=onnx_model_class_new),
        old_module_name: MagicMock(OnnxModel=onnx_model_class_old),
    }

    # The context manager restores sys.modules once the test has finished
    with patch.dict("sys.modules", fake_modules):
        yield onnx_model_class_new
110+
111+
112+
def test_mha2sha_default_config(mock_accelerator_spec):
    """Check that the default config exposes ``mha2sha_kwargs`` with a ``None`` default."""
    default_config = QairtMHA2SHA._default_config(mock_accelerator_spec)  # pylint: disable=protected-access

    assert "mha2sha_kwargs" in default_config
    kwargs_param = default_config["mha2sha_kwargs"]
    assert isinstance(kwargs_param, PassConfigParam)
    assert kwargs_param.default_value is None
118+
119+
120+
def test_mha2sha_for_onnx_model_handler(qairt_pass_instance, tmp_output_dir, mock_qairt_sdk_classes):
    """Test run with a single ONNXModelHandler.

    Verifies the returned handler, that the model is loaded exactly once through the
    SDK, that only the V2 transformation is invoked (without kwargs) and that the
    transformed model is exported to the requested output path.
    """
    # Keep a reference to every instance handed out by OnnxModel.load. Going through
    # ``load.call_args`` would only yield a ``call`` object, on which any chained
    # ``assert_*`` "call" silently succeeds without checking anything.
    loaded_instances = []
    create_instance = mock_qairt_sdk_classes.load.side_effect

    def recording_load(model_path):
        instance = create_instance(model_path)
        loaded_instances.append(instance)
        return instance

    mock_qairt_sdk_classes.load.side_effect = recording_load

    input_model = get_onnx_model()
    transformed_model = qairt_pass_instance.run(input_model, tmp_output_dir)

    # Returned handler
    assert isinstance(transformed_model, ONNXModelHandler)
    assert os.path.dirname(transformed_model.model_path) == tmp_output_dir
    assert transformed_model.onnx_file_name == input_model.onnx_file_name

    # QAIRT SDK calls
    mock_qairt_sdk_classes.load.assert_called_once_with(model_path=input_model.model_path)
    assert len(loaded_instances) == 1
    loaded_qairt_instance = loaded_instances[0]

    loaded_qairt_instance.mock_mha2sha_v2.assert_called_once_with()  # No kwargs passed by default
    loaded_qairt_instance.mock_mha2sha.assert_not_called()  # Ensure V1 is not called
    # The mock's export() forwards (output_path, prefix) positionally to its tracking mock;
    # the pass uses the stem of the input model path as the prefix.
    loaded_qairt_instance._export_mock_internal.assert_called_once_with(  # pylint: disable=protected-access
        tmp_output_dir, Path(input_model.model_path).stem
    )
138+
139+
140+
def test_mha2sha_v1_fallback(qairt_pass_instance, tmp_output_dir, mock_qairt_sdk_classes):
    """Test that the pass falls back to mha2sha (v1) if v2 is not available."""
    dummy_model = get_onnx_model()

    # Capture the instance returned by OnnxModel.load so the assertions below run against
    # the real MagicMock attributes. ``load.call_args`` is a ``call`` object and any
    # ``assert_*`` chained on it would pass without checking anything.
    loaded_instances = []
    create_instance = mock_qairt_sdk_classes.load.side_effect

    def load_without_v2(model_path):
        instance = create_instance(model_path)
        instance.has_mha2sha_v2 = False  # mha2sha_v2() will now raise AttributeError
        loaded_instances.append(instance)
        return instance

    mock_qairt_sdk_classes.load.side_effect = load_without_v2

    _ = qairt_pass_instance.run(dummy_model, output_model_path=tmp_output_dir)

    # Verify V1 was called and V2 was not
    assert len(loaded_instances) == 1
    loaded_qairt_instance = loaded_instances[0]
    loaded_qairt_instance.mock_mha2sha.assert_called_once_with()
    loaded_qairt_instance.mock_mha2sha_v2.assert_not_called()
159+
160+
161+
def test_mha2sha_kwargs_passed(tmp_output_dir, mock_qairt_sdk_classes):
    """Test that additional kwargs are passed to mha2sha_v2/mha2sha.

    The pass is run twice with the same configuration: once with V2 available and once
    with V2 unavailable, checking each time that the configured kwargs reach the
    transformation that is actually used.
    """
    dummy_model = get_onnx_model()
    mha2sha_pass = create_pass_from_dict(
        QairtMHA2SHA, {"mha2sha_kwargs": {"param1": "value1", "param2": 123}}, disable_search=True
    )

    # Instances are captured explicitly: ``load.call_args`` is a ``call`` object, so
    # ``assert_*`` chained on it would pass without checking anything.
    loaded_instances = []
    create_instance = mock_qairt_sdk_classes.load.side_effect

    def make_recording_load(has_mha2sha_v2):
        def _load(model_path):
            instance = create_instance(model_path)
            instance.has_mha2sha_v2 = has_mha2sha_v2
            loaded_instances.append(instance)
            return instance

        return _load

    # Test with v2 available
    mock_qairt_sdk_classes.load.side_effect = make_recording_load(has_mha2sha_v2=True)
    mha2sha_pass.run(dummy_model, output_model_path=tmp_output_dir)

    assert len(loaded_instances) == 1
    loaded_qairt_instance_v2 = loaded_instances[-1]
    loaded_qairt_instance_v2.mock_mha2sha_v2.assert_called_once_with(param1="value1", param2=123)
    loaded_qairt_instance_v2.mock_mha2sha.assert_not_called()

    # Test with v1 fallback: every load() returns a fresh instance, so only the call
    # record of load() itself needs resetting before the second run.
    mock_qairt_sdk_classes.load.reset_mock()
    mock_qairt_sdk_classes.load.side_effect = make_recording_load(has_mha2sha_v2=False)
    mha2sha_pass.run(dummy_model, output_model_path=tmp_output_dir)

    assert len(loaded_instances) == 2
    loaded_qairt_instance_v1 = loaded_instances[-1]
    loaded_qairt_instance_v1.mock_mha2sha.assert_called_once_with(param1="value1", param2=123)
    loaded_qairt_instance_v1.mock_mha2sha_v2.assert_not_called()
195+
196+
197+
def test_import_error_qti_aisw(qairt_pass_instance, tmp_output_dir):
    """Test that ImportError is raised if qti.aisw.tools cannot be imported.

    This test runs *without* the `mock_qairt_sdk_classes` fixture. Instead,
    ``builtins.__import__`` is patched so that both import locations of ``OnnxModel``
    fail, while every other import is delegated to the real import function.
    """
    blocked_modules = (
        "qti.aisw.tools.core.utilities.framework.frameworks.onnx",
        "qti.aisw.tools.core.utilities.framework.onnx",
    )
    real_import = builtins.__import__

    def failing_import(name, *args, **kwargs):
        if name not in blocked_modules:
            return real_import(name, *args, **kwargs)
        raise ImportError("Mock import error")

    with patch("builtins.__import__", side_effect=failing_import):
        dummy_model = get_onnx_model()
        with pytest.raises(ImportError):
            qairt_pass_instance.run(dummy_model, output_model_path=tmp_output_dir)

0 commit comments

Comments
 (0)