feat(relax/frontend/torch): Add basic range constraint support #17898
Open: demoncoder-crypto wants to merge 16 commits into apache:main from demoncoder-crypto:fix/relax-pytorch-constraints-v2
Commits (16, all by demoncoder-crypto)
- a4ab52f feat(relax/frontend/torch): Add basic range constraint support from E…
- 71f8a98 Fix: Insert test_dynamic_shape_with_constraints
- f8ad0d7 refactor(test): Refactor constraint test to use verify_model and add …
- e5710b7 fix(test): Define tir.Var for TVMScript parsing in constraint test
- 5c7758c style: Apply black formatting
- 9ca05a6 fix(relax/torch): Handle ExportedProgram range constraints and add tests
- 8ab98aa Merge branch 'main' into fix/relax-pytorch-constraints-v2
- 7201b72 style: Apply formatting fixes to test_frontend_from_exported_program.py
- f7e23f4 style: Fix trailing whitespace in test file
- bcab702 feat(relax): Enhance PyTorch ExportedProgram range constraints support
- 70bff93 feat: Enhance PyTorch range constraints support
- 54885dd style: Fix lint errors reported by CI
- 4717288 style: Apply final lint fixes for translator and test files
- 073ec93 Apply Black code formatting to exported_program_translator.py
- 6162a27 Add logging module for PyTorch frontend
- 249c808 fix: coerce bounds to int and update R.relu to R.nn.relu
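For context, a minimal sketch (illustrative, not part of this PR) of how torch.export produces the range_constraints this change consumes; the Relu module and the [2, 64] bounds are assumptions for the example:

import torch

class Relu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# A bounded symbolic batch dimension; exporting with it records the bounds
# on the resulting ExportedProgram.
batch = torch.export.Dim("batch", min=2, max=64)
ep = torch.export.export(
    Relu(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}}
)
# Maps each shape symbol (e.g. s0) to its allowed range.
print(ep.range_constraints)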
python/tvm/relax/frontend/torch/exported_program_translator.py

@@ -20,12 +20,12 @@
 """PyTorch ExportedProgram of Relax."""
 from collections import ChainMap, OrderedDict
 from functools import partial
-from typing import Callable, Dict, List, Tuple
+from typing import Callable, Dict, List, Tuple, Optional

 import torch
 import tvm
 from tvm import relax

 import tvm.tir as tir  # pylint: disable=unused-import, consider-using-from-import
 from .base_fx_graph_translator import BaseFXGraphImporter

@@ -497,11 +497,16 @@ def create_convert_map(

     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+    ) -> Tuple[
+        Dict[str, relax.Var],
+        Dict[str, relax.Var],
+        Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]],
+    ]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+        relax_range_constraints: Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]] = {}

         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name

@@ -519,13 +524,17 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype

-            # TODO(mshr-h): Support range constraints
-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            # Create SizeVars and map SymInts
+            relax_shape = []
+            for s in torch_shape:
+                if isinstance(s, torch.SymInt):
+                    s_str = str(s)
+                    if s_str not in torch_symbol_to_relax_var:
+                        torch_symbol_to_relax_var[s_str] = tvm.tir.SizeVar(s_str, "int64")
+                    relax_shape.append(torch_symbol_to_relax_var[s_str])
+                else:
+                    relax_shape.append(s)

             dtype = self._convert_data_type(torch_dtype)

             relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))

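As an aside, a standalone sketch of the symbol-to-SizeVar mapping this loop maintains (the symbol strings "s0"/"s1" are illustrative): each distinct torch symbol string gets exactly one tvm.tir.SizeVar, so a symbol shared across several input shapes maps to a single TIR variable. The explicit membership check also avoids building a throwaway SizeVar on repeats, which the previous setdefault form did.

import tvm

torch_symbol_to_relax_var = {}
for sym in ["s0", "s1", "s0"]:  # "s0" appears in two input shapes
    if sym not in torch_symbol_to_relax_var:
        torch_symbol_to_relax_var[sym] = tvm.tir.SizeVar(sym, "int64")

print(len(torch_symbol_to_relax_var))  # 2: the repeated "s0" reuses one SizeVar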
@@ -534,7 +543,55 @@ def create_input_vars(
             else:
                 parameters_buffers_constants[name_hint] = relax_var

-        return parameters_buffers_constants, user_inputs
+        # Extract range constraints for TIR vars
+        if hasattr(exported_program, "range_constraints") and exported_program.range_constraints:
+            for torch_sym_expr, constraint in exported_program.range_constraints.items():
+                # Convert sympy expression to string for mapping
+                torch_sym_expr_str = str(torch_sym_expr)
+
+                if torch_sym_expr_str in torch_symbol_to_relax_var:
+                    relax_tir_var = torch_symbol_to_relax_var[torch_sym_expr_str]
+                    # TODO(sjt): Handle SymFloat, SymBool cases as well.
+                    # Note: min / max could be int or SymInt objects.
+                    # Need to handle symbolic shapes as well.
+                    min_val = constraint.min
+                    max_val = constraint.max
+                    # Call helper to add/refine constraint
+                    self._add_range_constraint(
+                        relax_range_constraints, relax_tir_var, min_val, max_val
+                    )
+                # else:
+                # FIXED Indentation for Black:
+                # TODO: Handle complex expressions (e.g., s0 + 1) for advanced support
+                # print(f"Skipping complex constraint expression: {torch_sym_expr}")
+
+        return parameters_buffers_constants, user_inputs, relax_range_constraints
+
+    # NEW HELPER METHOD
+    def _add_range_constraint(self, constraints_dict, relax_tir_var, min_val, max_val):
+        """Adds or refines a range constraint for a TIR variable."""
+        if relax_tir_var not in constraints_dict:
+            constraints_dict[relax_tir_var] = (min_val, max_val)
+        else:
+            # Refine existing constraints if the new one is tighter
+            existing_min, existing_max = constraints_dict[relax_tir_var]
+            # Merge lower bounds (take the max)
+            if existing_min is None:
+                new_min = min_val
+            elif min_val is None:
+                new_min = existing_min
+            else:
+                new_min = max(existing_min, min_val)
+
+            # Merge upper bounds (take the min)
+            if existing_max is None:
+                new_max = max_val
+            elif max_val is None:
+                new_max = existing_max
+            else:
+                new_max = min(existing_max, max_val)
+
+            constraints_dict[relax_tir_var] = (new_min, new_max)

     def from_exported_program(
         self,

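To make the helper's merge rule concrete, a self-contained sketch of the same logic as a standalone function: lower bounds merge with max, upper bounds with min, and None means unbounded on that side, so repeated constraints on the same variable always yield the tighter range.

from typing import Optional, Tuple

Bound = Tuple[Optional[int], Optional[int]]

def merge(existing: Bound, new: Bound) -> Bound:
    (lo1, hi1), (lo2, hi2) = existing, new
    # Tighter range wins: max of lower bounds, min of upper bounds.
    lo = lo2 if lo1 is None else lo1 if lo2 is None else max(lo1, lo2)
    hi = hi2 if hi1 is None else hi1 if hi2 is None else min(hi1, hi2)
    return (lo, hi)

print(merge((2, None), (None, 64)))  # (2, 64): each side filled in once known
print(merge((2, 100), (8, 64)))      # (8, 64): both sides tightened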
@@ -546,23 +603,47 @@ def from_exported_program(
         """Convert a PyTorch ExportedProgram to a Relax program."""
         from torch import fx  # type: ignore

-        # Create input variables.
-        parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+        # Create input variables and get range constraints.
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            relax_range_constraints,
+        ) = self.create_input_vars(exported_program)
         inputs_vars = user_input_vars.copy()
         inputs_vars.update(parameter_buffer_constant_vars)

         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
-        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+
+        # Prepare function attributes
+        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {}
+
+        # NEW: Add range constraints to function attributes if they exist
+        if relax_range_constraints:
+            lower_bounds = {}
+            upper_bounds = {}
+            for tir_var, (min_val, max_val) in relax_range_constraints.items():
+                if min_val is not None:
+                    lower_bounds[tir_var] = tvm.tir.IntImm("int64", min_val)
+                if max_val is not None:
+                    upper_bounds[tir_var] = tvm.tir.IntImm("int64", max_val)
+
+            if lower_bounds:
+                func_attrs["tir_var_lower_bound"] = lower_bounds
+            if upper_bounds:
+                func_attrs["tir_var_upper_bound"] = upper_bounds
+
+        # Use None if func_attrs is empty, otherwise use the dictionary
+        final_func_attrs = func_attrs if func_attrs else None

         nodes: List[fx.Node] = exported_program.graph.nodes

         # Find all the missing function types
         self._check_unsupported_func_type(nodes)

         with self.block_builder.function(
-            name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
+            name=func_name, params=list(inputs_vars.values()).copy(), attrs=final_func_attrs
         ):
             output = None
             with self.block_builder.dataflow():

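Putting the pieces together, a hedged end-to-end sketch (assumes this PR is applied; the Relu module and the [2, 64] bounds are the same illustrative assumptions as in the earlier example): the bounds recorded by torch.export should surface as tir_var_lower_bound / tir_var_upper_bound attributes on the converted "main" function.

import torch
from tvm.relax.frontend.torch import from_exported_program

class Relu(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

batch = torch.export.Dim("batch", min=2, max=64)
ep = torch.export.export(
    Relu(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}}
)

mod = from_exported_program(ep)
# Per this diff, the attrs should include the recorded bounds, e.g.
#   "tir_var_lower_bound": {batch: 2}, "tir_var_upper_bound": {batch: 64}
print(mod["main"].attrs)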
Review comment: Please remove the unnecessary comment.