Skip to content

Commit 6da1ec3

Browse files
megeminico63oc
authored andcommitted
[Typing] 识别标题中的 debug 关键字进行类型检查 (PaddlePaddle#65319)
1 parent 11e579d commit 6da1ec3

File tree

4 files changed

+39
-12
lines changed

4 files changed

+39
-12
lines changed

paddle/scripts/paddle_build.sh

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3562,11 +3562,20 @@ function exec_type_checking() {
35623562

35633563
# check all sample code
35643564
TITLE_CHECK_ALL=`curl -s https://github.com/PaddlePaddle/Paddle/pull/${GIT_PR_ID} | grep "<title>" | grep -i "typing all" || true`
3565+
DEBUG_MODE=`curl -s https://github.com/PaddlePaddle/Paddle/pull/${GIT_PR_ID} | grep "<title>" | grep -i "[debug]" || true`
35653566

35663567
if [[ ${TITLE_CHECK_ALL} ]]; then
3567-
python type_checking.py --full-test; type_checking_error=$?
3568+
if [[ ${DEBUG_MODE} ]]; then
3569+
python type_checking.py --debug --full-test; type_checking_error=$?
3570+
else
3571+
python type_checking.py --full-test; type_checking_error=$?
3572+
fi
35683573
else
3569-
python type_checking.py; type_checking_error=$?
3574+
if [[ ${DEBUG_MODE} ]]; then
3575+
python type_checking.py --debug; type_checking_error=$?
3576+
else
3577+
python type_checking.py; type_checking_error=$?
3578+
fi
35703579
fi
35713580

35723581
if [ "$type_checking_error" != "0" ];then

python/paddle/_typing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
NestedStructure as NestedStructure,
2424
Numberic as Numberic,
2525
NumbericSequence as NumbericSequence,
26+
PaddingMode as PaddingMode,
2627
TensorIndex as TensorIndex,
2728
TensorLike as TensorLike,
2829
TensorOrTensors as TensorOrTensors,

python/paddle/nn/functional/pooling.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
1519
import numpy as np
1620

1721
from paddle import _C_ops, in_dynamic_mode
@@ -33,6 +37,11 @@
3337
convert_to_list,
3438
)
3539

40+
if TYPE_CHECKING:
41+
from paddle import Tensor
42+
43+
from ..._typing import PaddingMode, Size1, Size2
44+
3645
__all__ = []
3746

3847

@@ -173,14 +182,14 @@ def _expand_low_nd_padding(padding):
173182

174183

175184
def avg_pool1d(
176-
x,
177-
kernel_size,
178-
stride=None,
179-
padding=0,
180-
exclusive=True,
181-
ceil_mode=False,
182-
name=None,
183-
):
185+
x: Tensor,
186+
kernel_size: Size1,
187+
stride: Size1 | None = None,
188+
padding: PaddingMode | Size1 | Size2 = 0,
189+
exclusive: bool = True,
190+
ceil_mode: bool = False,
191+
name: str | None = None,
192+
) -> Tensor:
184193
"""
185194
This API implements average pooling 1d operation,
186195
See more details in :ref:`api_paddle_nn_AvgPool1d` .

tools/type_checking.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,16 @@ class TestResult:
7070

7171
class MypyChecker(TypeChecker):
7272
def __init__(
73-
self, config_file: str, cache_dir: str, *args: Any, **kwargs: Any
73+
self,
74+
config_file: str,
75+
cache_dir: str,
76+
debug: bool = False,
77+
*args: Any,
78+
**kwargs: Any,
7479
) -> None:
7580
self.config_file = config_file
7681
self.cache_dir = cache_dir
82+
self.debug = debug
7783
super().__init__(*args, **kwargs)
7884

7985
def run(self, api_name: str, codeblock: str) -> TestResult:
@@ -99,7 +105,8 @@ def run(self, api_name: str, codeblock: str) -> TestResult:
99105
)
100106

101107
normal_report, error_report, exit_status = mypy_api.run(
102-
[
108+
(["--show-traceback"] if self.debug else [])
109+
+ [
103110
f'--config-file={self.config_file}',
104111
f'--cache-dir={self.cache_dir}',
105112
'-c',
@@ -313,5 +320,6 @@ def run_type_checker(
313320
cache_dir=(
314321
args.cache_dir if args.cache_dir else (base_path / '.mypy_cache')
315322
),
323+
debug=args.debug,
316324
)
317325
run_type_checker(args, mypy_checker)

0 commit comments

Comments
 (0)