1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from __future__ import annotations
1415
1516import re
17+ from typing import TYPE_CHECKING
1618
1719from paddle import _C_ops , _legacy_C_ops
1820
2729 in_dynamic_or_pir_mode ,
2830)
2931
32+ if TYPE_CHECKING :
33+ from paddle import Tensor
34+
3035__all__ = []
3136
3237
@@ -46,7 +51,7 @@ def _convert_(name):
4651 return re .sub ('([a-z0-9])([A-Z])' , r'\1_\2' , s1 ).lower ()
4752
4853
49- def generate_layer_fn (op_type ):
54+ def generate_layer_fn (op_type : str ):
5055 """Register the Python layer for an Operator.
5156
5257 Args:
@@ -124,7 +129,7 @@ def infer_and_check_dtype(op_proto, *args, **kwargs):
124129 dtype = core .VarDesc .VarType .FP32
125130 return dtype
126131
127- def func (* args , ** kwargs ):
132+ def func (* args , ** kwargs ) -> Tensor :
128133 helper = LayerHelper (op_type , ** kwargs )
129134
130135 dtype = infer_and_check_dtype (op_proto , * args , ** kwargs )
@@ -160,7 +165,7 @@ def func(*args, **kwargs):
160165 return func
161166
162167
163- def generate_activation_fn (op_type ):
168+ def generate_activation_fn (op_type : str ):
164169 """Register the Python layer for an Operator without Attribute.
165170
166171 Args:
@@ -171,7 +176,7 @@ def generate_activation_fn(op_type):
171176
172177 """
173178
174- def func (x , name = None ):
179+ def func (x , name : str | None = None ) -> Tensor :
175180 if in_dynamic_or_pir_mode ():
176181 if hasattr (_C_ops , op_type ):
177182 op = getattr (_C_ops , op_type )
0 commit comments