Skip to content

Commit 51fa719

Browse files
authored
use base_version to check torch version in torch_less_than_1_11 (#16806)
* use base_version * make is_torch_less_than_1_8 match 1_11 Co-authored-by: Nicholas Broad <[email protected]>
1 parent 8d3f952 commit 51fa719

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/pytorch_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
logger = logging.get_logger(__name__)
2525

26-
is_torch_less_than_1_8 = version.parse(torch.__version__) < version.parse("1.8.0")
27-
is_torch_less_than_1_11 = version.parse(torch.__version__) < version.parse("1.11")
26+
is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
27+
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
2828

2929

3030
def torch_int_div(tensor1, tensor2):

0 commit comments

Comments
 (0)