Skip to content

Commit d824f83

Browse files
gouzilco63oc
authored andcommitted
[Typing][A-9] Add type annotations for paddle/tensor/ops.py (PaddlePaddle#65249)
1 parent 98e024a commit d824f83

File tree

2 files changed

+60
-50
lines changed

2 files changed

+60
-50
lines changed

python/paddle/tensor/layer_function_generator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
import re
17+
from typing import TYPE_CHECKING
1618

1719
from paddle import _C_ops, _legacy_C_ops
1820

@@ -27,6 +29,9 @@
2729
in_dynamic_or_pir_mode,
2830
)
2931

32+
if TYPE_CHECKING:
33+
from paddle import Tensor
34+
3035
__all__ = []
3136

3237

@@ -46,7 +51,7 @@ def _convert_(name):
4651
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
4752

4853

49-
def generate_layer_fn(op_type):
54+
def generate_layer_fn(op_type: str):
5055
"""Register the Python layer for an Operator.
5156
5257
Args:
@@ -124,7 +129,7 @@ def infer_and_check_dtype(op_proto, *args, **kwargs):
124129
dtype = core.VarDesc.VarType.FP32
125130
return dtype
126131

127-
def func(*args, **kwargs):
132+
def func(*args, **kwargs) -> Tensor:
128133
helper = LayerHelper(op_type, **kwargs)
129134

130135
dtype = infer_and_check_dtype(op_proto, *args, **kwargs)
@@ -160,7 +165,7 @@ def func(*args, **kwargs):
160165
return func
161166

162167

163-
def generate_activation_fn(op_type):
168+
def generate_activation_fn(op_type: str):
164169
"""Register the Python layer for an Operator without Attribute.
165170
166171
Args:
@@ -171,7 +176,7 @@ def generate_activation_fn(op_type):
171176
172177
"""
173178

174-
def func(x, name=None):
179+
def func(x, name: str | None = None) -> Tensor:
175180
if in_dynamic_or_pir_mode():
176181
if hasattr(_C_ops, op_type):
177182
op = getattr(_C_ops, op_type)

0 commit comments

Comments
 (0)