Skip to content

Commit 4ad8416

Browse files
Add paddle.is_autocast_enabled and get_autocast_gpu_dtype API (#74441)
* Add get_autocast_dtype, get_autocast_gpu_dtype and is_autocast_enabled APIs * refine docs * set default dtype to fp16
1 parent ffef065 commit 4ad8416

File tree

5 files changed

+239
-0
lines changed

5 files changed

+239
-0
lines changed

python/paddle/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@
129129
tensor as tensor,
130130
utils as utils,
131131
)
132+
from .amp import (
133+
get_autocast_cpu_dtype,
134+
get_autocast_dtype,
135+
get_autocast_gpu_dtype,
136+
is_autocast_enabled,
137+
)
132138
from .autograd import (
133139
enable_grad,
134140
grad,
@@ -1233,6 +1239,10 @@
12331239
'nan',
12341240
'pi',
12351241
'e',
1242+
'is_autocast_enabled',
1243+
'get_autocast_dtype',
1244+
'get_autocast_cpu_dtype',
1245+
'get_autocast_gpu_dtype',
12361246
]
12371247

12381248
import os

python/paddle/amp/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
amp_guard,
3434
auto_cast,
3535
decorate,
36+
get_autocast_dtype,
37+
is_autocast_enabled,
3638
)
3739
from .grad_scaler import ( # noqa: F401
3840
AmpScaler,
@@ -46,8 +48,15 @@
4648
'decorate',
4749
'is_float16_supported',
4850
'is_bfloat16_supported',
51+
'is_autocast_enabled',
52+
'get_autocast_dtype',
53+
'get_autocast_cpu_dtype',
54+
'get_autocast_gpu_dtype',
4955
]
5056

57+
get_autocast_cpu_dtype = get_autocast_dtype
58+
get_autocast_gpu_dtype = get_autocast_dtype
59+
5160

5261
def is_float16_supported(device: str | None = None) -> bool:
5362
"""

python/paddle/amp/auto_cast.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from typing_extensions import TypeAlias, TypeGuard
4949

5050
from paddle import Tensor
51+
from paddle._typing import PlaceLike
5152
from paddle._typing.dtype_like import _DTypeLiteral
5253
from paddle.nn import Layer
5354
from paddle.nn.layer.layers import _StateDict
@@ -1322,3 +1323,73 @@ def decorate(
13221323
master_grad,
13231324
excluded_layers,
13241325
)
1326+
1327+
1328+
def is_autocast_enabled(device_type: PlaceLike | None = None) -> bool:
    """
    Check whether auto-mixed-precision is enabled in the current context.

    Args:
        device_type (PlaceLike, optional): The device type to check. This argument is ignored for all devices sharing the same AMP state in paddlepaddle.

    Returns:
        bool: True if auto-mixed-precision is enabled, False otherwise.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:GPU)
            >>> # Demo1: Check if auto-mixed-precision is enabled by default
            >>> import paddle
            >>> paddle.device.set_device('gpu')
            >>> print(paddle.is_autocast_enabled())
            False

            >>> # Demo2: Enable auto-mixed-precision and check again
            >>> with paddle.amp.auto_cast():
            ...     print(paddle.is_autocast_enabled())
            True
    """
    # PIR (static-graph) mode keeps AMP state in global amp attrs; dygraph
    # mode keeps it on the tracer. AMP counts as "on" for any level above O0.
    if in_pir_mode():
        return core._get_amp_attrs()._amp_level != AMP_LEVEL.O0
    tracer = _dygraph_tracer()
    # No active tracer means no AMP context at all.
    return bool(tracer) and tracer._amp_level != core.AmpLevel.O0
1362+
1363+
def get_autocast_dtype(device_type: PlaceLike | None = None) -> _DTypeLiteral:
    """
    Get the auto-mixed-precision dtype in the current context if autocast is enabled else default AMP dtype(float16).

    Args:
        device_type (PlaceLike, optional): The device type to check. This argument is ignored for all devices sharing the same AMP state in paddlepaddle.

    Returns:
        _DTypeLiteral: The current AMP dtype.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env:GPU)
            >>> # Demo1: Get default auto-mixed-precision dtype
            >>> import paddle
            >>> paddle.device.set_device('gpu')
            >>> print(paddle.get_autocast_dtype())
            float16

            >>> # Demo2: Enable auto-mixed-precision and get the dtype
            >>> with paddle.amp.auto_cast():
            ...     print(paddle.get_autocast_dtype())
            float16
    """
    # Outside any AMP context, report the library's default AMP dtype.
    if not is_autocast_enabled():
        return "float16"
    # AMP is active, so the dtype source matches where the state lives:
    # global amp attrs in PIR mode, the tracer in dygraph mode. The tracer
    # is guaranteed non-None here because is_autocast_enabled() returned True.
    if in_pir_mode():
        return core._get_amp_attrs()._amp_dtype
    return _dygraph_tracer()._amp_dtype

test/amp/test_get_autocast_dtype.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
20+
@unittest.skipIf(paddle.device.get_device() == "cpu", "Skip AMP test on CPU")
class TestAutocast(unittest.TestCase):
    """Dygraph-mode tests for paddle.get_autocast_dtype and its aliases."""

    def setUp(self) -> None:
        paddle.disable_static()
        # device_type is documented as ignored, so both None and the real
        # device must yield identical results.
        self.device_list = [None, paddle.device.get_device()]
        self.default_dtype = "float16"

    def do_test(self, device, expected_type):
        # All public aliases (top-level, amp, cpu/gpu variants) share one
        # AMP state and must agree on the reported dtype.
        self.assertTrue(paddle.get_autocast_dtype(device) == expected_type)
        self.assertTrue(paddle.get_autocast_gpu_dtype() == expected_type)
        self.assertTrue(paddle.amp.get_autocast_dtype(device) == expected_type)
        self.assertTrue(paddle.amp.get_autocast_gpu_dtype() == expected_type)
        self.assertTrue(paddle.amp.get_autocast_cpu_dtype() == expected_type)
        self.assertTrue(
            paddle.amp.get_autocast_cpu_dtype(device) == expected_type
        )

    def test_amp_default(self):
        # Outside any autocast context the default dtype is reported.
        for device in self.device_list:
            self.do_test(device, self.default_dtype)

    def test_amp_autocast_fp16(self):
        for device in self.device_list:
            with paddle.amp.auto_cast(True, dtype="float16"):
                self.do_test(device, "float16")
            # Leaving the context restores the default.
            self.do_test(device, self.default_dtype)

    def test_amp_autocast_bf16(self):
        for device in self.device_list:
            with paddle.amp.auto_cast(True, dtype="bfloat16"):
                self.do_test(device, "bfloat16")
            self.do_test(device, self.default_dtype)

    def test_amp_autocast_false_bf16(self):
        # BUGFIX: this was a copy of test_amp_autocast_bf16 using
        # auto_cast(True, ...). With autocast disabled the requested
        # bfloat16 must NOT take effect; the default dtype is reported.
        for device in self.device_list:
            with paddle.amp.auto_cast(False, dtype="bfloat16"):
                self.do_test(device, self.default_dtype)
            self.do_test(device, self.default_dtype)

    def test_amp_nested_context(self):
        # Inner contexts override the dtype; exiting restores the outer one.
        for device in self.device_list:
            with paddle.amp.auto_cast(True, dtype="bfloat16"):
                self.do_test(device, "bfloat16")
                with paddle.amp.auto_cast(True, dtype="float16"):
                    self.do_test(device, "float16")
                self.do_test(device, "bfloat16")
            self.do_test(device, self.default_dtype)
68+
69+
class TestAutocastStatic(TestAutocast):
    """Re-run every TestAutocast case under static-graph (PIR) mode."""

    def setUp(self) -> None:
        # Same fixtures as the dygraph suite; only the execution mode differs.
        self.device_list = [None, paddle.device.get_device()]
        self.default_dtype = "float16"
        paddle.enable_static()
74+
75+
76+
# Standard unittest entry point: discover and run the cases in this module.
if __name__ == "__main__":
    unittest.main()

test/amp/test_is_autocast_enabled.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
20+
@unittest.skipIf(paddle.device.get_device() == "cpu", "Skip AMP test on CPU")
class TestAutocast(unittest.TestCase):
    """Dygraph-mode tests for paddle.is_autocast_enabled."""

    def setUp(self) -> None:
        paddle.disable_static()
        # device_type is documented as ignored, so both None and the real
        # device must yield identical results.
        self.device_list = [None, paddle.device.get_device()]

    def _check(self, device, expected):
        # The top-level API and the paddle.amp alias share one AMP state
        # and must agree.
        self.assertEqual(paddle.is_autocast_enabled(device), expected)
        self.assertEqual(paddle.amp.is_autocast_enabled(device), expected)

    def test_amp_default(self):
        # AMP is off outside any autocast context.
        for device in self.device_list:
            self._check(device, False)

    def test_amp_autocast_true(self):
        for device in self.device_list:
            with paddle.amp.auto_cast(True):
                self._check(device, True)
            # Exiting the context disables AMP again.
            self._check(device, False)

    def test_amp_autocast_false(self):
        for device in self.device_list:
            with paddle.amp.auto_cast(False):
                self._check(device, False)
            self._check(device, False)

    def test_amp_nested_context(self):
        # A nested auto_cast(False) temporarily disables AMP; exiting it
        # restores the enclosing enabled state.
        for device in self.device_list:
            with paddle.amp.auto_cast(True):
                self._check(device, True)
                with paddle.amp.auto_cast(False):
                    self._check(device, False)
                self._check(device, True)
            self._check(device, False)
63+
64+
65+
class TestAutocastStatic(TestAutocast):
    """Re-run every TestAutocast case under static-graph (PIR) mode."""

    def setUp(self) -> None:
        # Same fixtures as the dygraph suite; only the execution mode differs.
        self.device_list = [None, paddle.device.get_device()]
        paddle.enable_static()
69+
70+
71+
# Standard unittest entry point: discover and run the cases in this module.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)