 
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
 import paddle
 from paddle.base import core, framework
 from paddle.base.backward import gradients_with_optimizer  # noqa: F401
 
+if TYPE_CHECKING:
+    from .. import Tensor
+
+
 __all__ = []
 
 
 @framework.dygraph_only
 def backward(
-    tensors: list[paddle.Tensor],
-    grad_tensors: list[paddle.Tensor] | None = None,
-    retain_graph: bool | None = False,
+    tensors: list[Tensor],
+    grad_tensors: list[Tensor | None] | None = None,
+    retain_graph: bool = False,
 ) -> None:
     """
     Compute the backward gradients of given tensors.
@@ -38,7 +44,7 @@ def backward(
             If None, all the gradients of the ``tensors`` is the default value which is filled with 1.0.
             Defaults to None.
 
-        retain_graph(bool|False, optional): If False, the graph used to compute grads will be freed. If you would
+        retain_graph(bool, optional): If False, the graph used to compute grads will be freed. If you would
             like to add more ops to the built graph after calling this method( :code:`backward` ), set the parameter
             :code:`retain_graph` to True, then the grads will be retained. Thus, setting it to False is much more memory-efficient.
             Defaults to False.
@@ -87,8 +93,8 @@ def backward(
     """
 
     def check_tensors(
-        in_out_list: list[paddle.Tensor], name: str
-    ) -> list[paddle.Tensor]:
+        in_out_list: list[Tensor] | tuple[Tensor, ...] | Tensor, name: str
+    ) -> list[Tensor] | tuple[Tensor, ...]:
         assert in_out_list is not None, f"{name} should not be None"
 
         if isinstance(in_out_list, (list, tuple)):
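
For reference, a minimal usage sketch of the paddle.autograd.backward API that this diff annotates. The tensor values, shapes, and the retain_graph usage below are illustrative and not taken from the patch:

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]], stop_gradient=False)
y = paddle.to_tensor([[3.0, 2.0], [3.0, 4.0]])
z = paddle.matmul(x, y)

# Seed gradient for z; passing grad_tensors=None would default to all-ones.
grad_z = paddle.to_tensor([[1.0, 1.0], [1.0, 1.0]])

# retain_graph=True keeps the graph alive so backward can be called on it again.
paddle.autograd.backward([z], [grad_z], retain_graph=True)
print(x.grad)  # gradients accumulated on the leaf tensor x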