[ONNX] Fix type annotations and enable type checking for all apis #84091
Conversation
Dr. CI: ✅ No Failures (0 Pending) as of commit 54a44a0. Looks good so far; there are no failures yet.
- @symbolic_helper.parse_args("v", "f", "i")
+ @symbolic_helper.parse_args("v", "f", "b")
+ @_beartype.beartype
  def dropout(g, input, p, train):
Changed many "i" to "b" because (1) they are bool types and (2) when translated to ONNX, booleans are cast to int without issues.
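To illustrate the descriptor change, here is a minimal sketch (not the actual dropout symbolic; the op construction is only illustrative): `parse_args` unpacks each traced constant according to its descriptor, so "b" hands the symbolic a Python bool where "i" would hand it an int.

```python
from torch.onnx import symbolic_helper


# "v" = graph value, "f" = float constant, "b" = bool constant.
@symbolic_helper.parse_args("v", "f", "b")
def dropout_sketch(g, input, p, train):
    # `train` arrives as a real bool, so both the annotation and a runtime
    # beartype check can say `bool` instead of `int`.
    if not train:
        return input
    output, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
    return output
```

In the PR, the `@_beartype.beartype` decorator sits beneath `parse_args`, so the values it validates are the already-unpacked Python types.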
torch/onnx/verification.py (Outdated)
- ort_outs: Sequence[np.ndarray],
- pt_outs: Sequence[torch.Tensor],
+ ort_outs: Union[_NumericType, Sequence[_NumericType], Sequence, Dict],
+ pt_outs: Union[_NumericType, Sequence[_NumericType], Sequence, Dict],
Relaxed types here
If ORT can return a dict, it means ONNX also can. Is that really the case? AFAIK, that was a known limitation, wasn't it? ORTModule flattens the dict, exports to ONNX, runs the model and reassembles the dict before returning it to the user. Is that the reason for allowing dict on ORT? My recollection was that we used InferenceSession directly, without ORTModule wrapping it to flatten input/output, so this is unexpected to me.
Done. Removed Dict
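For illustration, a hedged sketch of the kind of comparison the relaxed (Dict-free) annotation admits; the function and type names below are illustrative, not the actual verification.py helpers.

```python
from typing import Sequence, Union

import numpy as np
import torch

# Assumed numeric alias for this sketch; verification.py defines its own.
_NumericType = Union[np.ndarray, torch.Tensor, float, int]


def compare_outputs(
    ort_outs: Union[_NumericType, Sequence[_NumericType]],
    pt_outs: Union[_NumericType, Sequence[_NumericType]],
    rtol: float = 1e-3,
    atol: float = 1e-7,
) -> None:
    # Normalize single values to one-element sequences so both branches of
    # the Union are handled uniformly.
    if not isinstance(ort_outs, (list, tuple)):
        ort_outs = [ort_outs]
    if not isinstance(pt_outs, (list, tuple)):
        pt_outs = [pt_outs]
    for ort_out, pt_out in zip(ort_outs, pt_outs):
        pt_array = (
            pt_out.detach().cpu().numpy()
            if isinstance(pt_out, torch.Tensor)
            else np.asarray(pt_out)
        )
        np.testing.assert_allclose(np.asarray(ort_out), pt_array, rtol=rtol, atol=atol)
```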
  use_external_data: bool = False,
- additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
+ additional_test_inputs: Optional[
+     Sequence[Union[torch.Tensor, Tuple[Any, ...]]]
Expanded
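As an illustration of why the wider annotation is convenient, here is a hedged sketch (the helper name is hypothetical, not the actual verification code): a caller can now pass either a bare tensor or a full args tuple per test case, and the verifier can normalize both before exporting.

```python
from typing import Any, Optional, Sequence, Tuple, Union

import torch


def normalize_test_inputs(
    additional_test_inputs: Optional[
        Sequence[Union[torch.Tensor, Tuple[Any, ...]]]
    ],
) -> Sequence[Tuple[Any, ...]]:
    if additional_test_inputs is None:
        return []
    # Wrap a lone tensor into a single-element args tuple.
    return [
        inputs if isinstance(inputs, tuple) else (inputs,)
        for inputs in additional_test_inputs
    ]
```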
  # NOTE: prim::Constant at this stage usually means something not compatible in ONNX,
  # otherwise it'd be converted to onnx::Constant
- if _is_value(value) and _is_onnx_constant(value):
+ if isinstance(value, _C.Value) and _is_onnx_constant(value):
Changed for mypy checks on _is_onnx_constant
At least add a TODO comment here and on _is_value describing the mypy issue for a future fix. Otherwise this will be forgotten
Done
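One possible shape for the future fix tracked by that TODO (a sketch only, not what this PR does) is to give `_is_value` a `TypeGuard` return type, so mypy narrows the argument the same way an inline `isinstance` check does.

```python
from typing import Any

from torch import _C
from typing_extensions import TypeGuard  # available in `typing` on Python 3.10+


def _is_value(x: Any) -> TypeGuard[_C.Value]:
    return isinstance(x, _C.Value)


def example(value: Any) -> None:
    if _is_value(value):
        # mypy now knows `value` is _C.Value here, so attribute access such
        # as value.node() type-checks without a cast or an isinstance call.
        print(value.node().kind())
```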
- def _maybe_get_const(value: _C.Value, descriptor: _ValueDescriptor):
+ @_beartype.beartype
+ def _maybe_get_const(
+     value: Optional[Union[_C.Value, torch.Tensor, Number, Sequence]],
Relaxed
Generally speaking, hardening + refactoring is preferred over relaxing checks, especially when the goal of the change is to enforce correctness.
What is the reason for relaxing, as opposed to refactoring the code to avoid the relaxation?
I wanted to (1) enable type checking before (2) fixing the code. Modifying the annotations is a way to make the minimal code changes needed to enable checking without modifying current behavior. We can then optimize things from there.
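For readers unfamiliar with beartype, a small hedged illustration of what the decorator buys at runtime; it uses the public beartype package directly, whereas the PR wraps it behind an internal helper.

```python
from beartype import beartype
from beartype.roar import BeartypeException


@beartype
def scale(value: float, factor: int) -> float:
    return value * factor


try:
    scale(1.0, "3")  # wrong type for `factor`: str instead of int
except BeartypeException as err:
    # The mismatch is rejected at call time instead of propagating silently.
    print(f"rejected: {err}")
```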
torch/onnx/_patch_torch.py (Outdated)
  g: torch._C.Graph,
  opname: str,
- *raw_args: torch._C.Value,
+ *raw_args: Union[torch.Tensor, torch._C.Value],
Is this right? Can raw_args be torch.Tensor? (They exist in tests)
I would say no, but I could be wrong.
Which test has it? Let's look into it.
It is expected, as I realized there is a const_if_tensor a few lines down.
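A hedged sketch of the const-if-tensor behavior being referenced (the real helper lives in torch/onnx/_patch_torch.py and may differ in detail): when a raw argument is a torch.Tensor rather than a graph Value, it is wrapped in an onnx::Constant node so the rest of the op-building code only deals with Values.

```python
from typing import Union

import torch
from torch import _C


def _const_if_tensor(graph: _C.Graph, arg: Union[torch.Tensor, _C.Value]) -> _C.Value:
    if isinstance(arg, _C.Value):
        return arg
    # Materialize the tensor as a Constant node in the graph and hand back
    # its output Value.
    node = graph.create("onnx::Constant")
    node.t_("value", arg)
    graph.insertNode(node)
    return node.output()
```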
Can we clean up and remove it (preferably in a separate PR)? I guess the Meta internal tests are likely to scream though :(
Sure thing!
Letting a non-Value go through the most generic function for symbolics opens the door to many issues. IMO we need to fix this, as opposed to relaxing the type checking to allow CI to pass. Hardening the code is always best, even if it needs more refactoring.
In a follow-up PR. This is already too big
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. TODO: profile performance [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. TODO: profile performance [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. TODO: profile performance [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. TODO: profile performance [ghstack-poisoned]
Add a comment in the places we cannot type check so that nobody tries it or does it wrong. Also add a note stating that it is mandatory to add type checking in all symbolics and their helpers; see # Note [Edit Symbolic Files] and # EDITING THIS FILE AND SYMBOLIC_OPSET<VERSION> FILES? READ THIS FIRST!
…g for all apis (#84091)
Test Plan: revert-hammer
Differential Revision: D39084854
Original commit changeset: aeb0e1fc4dd6
Original Phabricator Diff: D39084854
fbshipit-source-id: 703629a64b2c89c5a4278cbc6a2917b8810e86de
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` [ghstack-poisoned]
@thiagocrepaldi PTAL cc @BowenBao
The relaxed types are there to match reality. We can optimize the code after the checks are turned on. The follow-ups are also going to be more focused and easier to review.
Same concern here. I'd rather see a feature completed in a single PR than spread across multiple ones. If it needs to be reverted or cherry-picked in the future, due to a bug or limitation, that is easy to do with a single PR. It is also easier to read a single PR and learn about the introduced feature.
thiagocrepaldi left a comment
Would rather see PRs be code-complete before merge instead of having many follow-up PRs.
That makes reverting unrelated PRs very hard, especially using pytorchbot (in order to disable beartype, many PRs, related and unrelated, would have to be reverted: the ones in between all the beartype PRs).
It is also harder to learn about the feature because there are too many PRs for the same task.
Replacing _is_value with isinstance(tensor, _C.Value) defeats the purpose of having _is_value() entirely. At least a # TODO or an issue should be created to track and properly fix it.
I like the idea of making reverts easier. In this particular case, the bigger the PR becomes and the more places it touches, the harder it is to revert. Future PRs would more likely need to be reverted because they will build on top of this PR once merged, given the extent of the files it touches. IMO this PR stands on its own in the sense that it enables type checking without substantial changes to the actual code. If we lump in the logic changes as well, it will have a higher chance of breaking things and being reverted. Separating changes into bite-sized, self-contained PRs often means a smaller chance of breaking things, so we don't need to revert everything. When we start annotating more functions and refactoring, the effort can be distributed and verified individually. I suggest we keep PRs small and focused. https://google.github.io/eng-practices/review/developer/small-cls.html#cant
PR stacks exist for this. Depending on what we need, we can create some documentation for the feature and do a good job of linking the PRs together.
Will do. Thanks!
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` [ghstack-poisoned]
…ll apis" Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that does not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to makes unit tests green. Profile: export `torchvision.models.alexnet(pretrained=True)` ``` with runtime type checking: 21.314 / 10 passes without runtime type checking: 20.797 / 10 passes + 2.48% ``` [ghstack-poisoned]
@pytorchbot merge -g
@pytorchbot successfully started a merge job. Check the current status here.
…4091) (#84091)

Summary: Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that do not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to make unit tests green.

Profile: export `torchvision.models.alexnet(pretrained=True)`

```
with runtime type checking:    21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes
+ 2.48%
```

Pull Request resolved: #84091
Approved by: https://github.com/BowenBao, https://github.com/thiagocrepaldi
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/388368b6996479f6eca484d4e60a6250b2535dec
Reviewed By: mehtanirav
Differential Revision: D39277677
fbshipit-source-id: 6836efdd15c3b2479bac68807c65ea7c5609295f
Stack from ghstack (oldest at bottom):
Enable runtime type checking for all torch.onnx public apis, symbolic functions and most helpers (minus two that do not have a checkable type: `_.JitType` does not exist) by adding the beartype decorator. Fix type annotations to make unit tests green.

Profile: export `torchvision.models.alexnet(pretrained=True)`

with runtime type checking: 21.314 / 10 passes
without runtime type checking: 20.797 / 10 passes
+ 2.48%
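A rough sketch of how such an export profile could be reproduced (this is an assumed harness, not the one that produced the numbers above): time ten export passes of AlexNet, once against a build with the beartype decorators applied and once without, and compare the totals.

```python
import io
import time

import torch
import torchvision


def time_export(passes: int = 10) -> float:
    # Random weights to avoid a download; the description above used pretrained=True.
    model = torchvision.models.alexnet().eval()
    dummy_input = torch.randn(1, 3, 224, 224)
    start = time.perf_counter()
    for _ in range(passes):
        # Export to an in-memory buffer so disk I/O does not dominate the timing.
        torch.onnx.export(model, dummy_input, io.BytesIO())
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"{time_export():.3f} / 10 passes")
```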