Skip to content

Commit 4867094

Browse files
authored
[Typing][B-87] Add type annotations for python/paddle/nn/utils/spectral_norm_hook.py (#65810)
1 parent d61ca73 commit 4867094

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

python/paddle/nn/utils/spectral_norm_hook.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,44 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
1519
import paddle
1620

1721
from .. import functional as F
1822
from ..layer.common import Linear
1923
from ..layer.conv import Conv1DTranspose, Conv2DTranspose, Conv3DTranspose
2024

25+
if TYPE_CHECKING:
26+
from typing_extensions import Never
27+
28+
from paddle import Tensor
29+
from paddle.nn import Layer
30+
2131
__all__ = []
2232

2333

24-
def normal_(x, mean=0.0, std=1.0):
34+
def normal_(x: Tensor, mean: float = 0.0, std: float = 1.0) -> Tensor:
2535
temp_value = paddle.normal(mean, std, shape=x.shape)
2636
paddle.assign(temp_value, x)
2737
return x
2838

2939

3040
class SpectralNorm:
31-
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
41+
name: str
42+
dim: int
43+
n_power_iterations: int
44+
eps: float
45+
46+
def __init__(
47+
self,
48+
name: str = 'weight',
49+
n_power_iterations: int = 1,
50+
dim: int = 0,
51+
eps: float = 1e-12,
52+
) -> None:
3253
self.name = name
3354
self.dim = dim
3455
if n_power_iterations <= 0:
@@ -39,7 +60,7 @@ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
3960
self.n_power_iterations = n_power_iterations
4061
self.eps = eps
4162

42-
def reshape_weight_to_matrix(self, weight):
63+
def reshape_weight_to_matrix(self, weight: Tensor) -> Tensor:
4364
weight_mat = weight
4465
if self.dim != 0:
4566
# transpose dim to front
@@ -52,7 +73,7 @@ def reshape_weight_to_matrix(self, weight):
5273

5374
return weight_mat.reshape([height, -1])
5475

55-
def compute_weight(self, layer, do_power_iteration):
76+
def compute_weight(self, layer: Layer, do_power_iteration: bool) -> Tensor:
5677
weight = getattr(layer, self.name + '_orig')
5778
u = getattr(layer, self.name + '_u')
5879
v = getattr(layer, self.name + '_v')
@@ -91,15 +112,17 @@ def compute_weight(self, layer, do_power_iteration):
91112
weight = weight / sigma
92113
return weight
93114

94-
def __call__(self, layer, inputs):
115+
def __call__(self, layer: Layer, inputs: Never) -> None:
95116
setattr(
96117
layer,
97118
self.name,
98119
self.compute_weight(layer, do_power_iteration=layer.training),
99120
)
100121

101122
@staticmethod
102-
def apply(layer, name, n_power_iterations, dim, eps):
123+
def apply(
124+
layer: Layer, name: str, n_power_iterations: int, dim: int, eps: float
125+
) -> SpectralNorm:
103126
for k, hook in layer._forward_pre_hooks.items():
104127
if isinstance(hook, SpectralNorm) and hook.name == name:
105128
raise RuntimeError(
@@ -138,8 +161,12 @@ def apply(layer, name, n_power_iterations, dim, eps):
138161

139162

140163
def spectral_norm(
141-
layer, name='weight', n_power_iterations=1, eps=1e-12, dim=None
142-
):
164+
layer: Layer,
165+
name: str = 'weight',
166+
n_power_iterations: int = 1,
167+
eps: float = 1e-12,
168+
dim: int | None = None,
169+
) -> Layer:
143170
r"""
144171
Applies spectral normalization to a parameter according to the
145172
following Calculation:
@@ -176,7 +203,7 @@ def spectral_norm(
176203
name(str, optional): Name of the weight parameter. Default: 'weight'.
177204
n_power_iterations(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
178205
eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
179-
dim(int, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: None.
206+
dim(int|None, optional): The index of dimension which should be permuted to the first before reshaping Input(Weight) to matrix, it should be set as 0 if Input(Weight) is the weight of fc layer, and should be set as 1 if Input(Weight) is the weight of conv layer. Default: None.
180207
181208
Returns:
182209
Layer, the original layer with the spectral norm hook.

0 commit comments

Comments
 (0)