Merged
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/test_layers.py
@@ -942,7 +942,7 @@ def test_spectral_norm(self):
                 lod_level=1,
                 append_batch_size=False,
             )
-            spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
+            spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
             ret = spectralNorm(Weight)
             static_ret2 = self.get_static_graph_result(
                 feed={
@@ -955,7 +955,7 @@ def test_spectral_norm(self):
             )[0]

         with self.dynamic_graph():
-            spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
+            spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)
             dy_ret = spectralNorm(base.to_variable(input))
             dy_rlt_value = dy_ret.numpy()
@@ -154,7 +154,7 @@ class TestDygraphSpectralNormOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
             shape = (2, 4, 3, 3)
-            spectralNorm = paddle.nn.SpectralNorm(shape, axis=1, power_iters=2)
+            spectralNorm = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2)

             def test_Variable():
                 weight_1 = np.random.random((2, 4)).astype("float32")
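The two test updates above are mechanical renames at the call sites. For reference, a minimal before/after sketch of the constructor change, reusing the shape from the error test (the pre-PR call is kept as a comment so the snippet runs against the renamed API):

import paddle

shape = (2, 4, 3, 3)
# Before this PR: paddle.nn.SpectralNorm(shape, axis=1, power_iters=2, epsilon=1e-12)
sn = paddle.nn.SpectralNorm(shape, dim=1, power_iters=2, eps=1e-12)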
4 changes: 2 additions & 2 deletions python/paddle/nn/layer/common.py
@@ -1737,8 +1737,8 @@ def __init__(self, start_axis=1, stop_axis=-1):
         self.start_axis = start_axis
         self.stop_axis = stop_axis

-    def forward(self, x):
+    def forward(self, input):
         out = paddle.flatten(
-            x, start_axis=self.start_axis, stop_axis=self.stop_axis
+            input, start_axis=self.start_axis, stop_axis=self.stop_axis
         )
         return out
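Because the layer is normally invoked positionally through `__call__`, renaming `forward`'s argument from `x` to `input` does not affect positional call sites (a keyword call such as `layer(x=...)` would need updating). A quick usage sketch of the unchanged behavior, assuming a working Paddle install:

import paddle

flatten = paddle.nn.Flatten(start_axis=1, stop_axis=-1)
x = paddle.rand((2, 3, 4, 5))
out = flatten(x)  # positional call; unaffected by the rename
print(out.shape)  # [2, 60] -- dims 1..-1 collapsed into one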
24 changes: 12 additions & 12 deletions python/paddle/nn/layer/norm.py
@@ -1812,7 +1812,7 @@ class SpectralNorm(Layer):

     Step 1:
         Generate vector U in shape of [H], and V in shape of [W].
-        While H is the :attr:`axis` th dimension of the input weights,
+        While H is the :attr:`dim` th dimension of the input weights,
         and W is the product result of remaining dimensions.

     Step 2:
@@ -1839,9 +1839,9 @@

     Parameters:
         weight_shape(list or tuple): The shape of weight parameter.
-        axis(int, optional): The index of the dimension which should be permuted to the first before reshaping Input(Weight) to a matrix; it should be set to 0 if Input(Weight) is the weight of an fc layer, and to 1 if it is the weight of a conv layer. Default: 0.
+        dim(int, optional): The index of the dimension which should be permuted to the first before reshaping Input(Weight) to a matrix; it should be set to 0 if Input(Weight) is the weight of an fc layer, and to 1 if it is the weight of a conv layer. Default: 0.
         power_iters(int, optional): The number of power iterations to calculate spectral norm. Default: 1.
-        epsilon(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
+        eps(float, optional): The epsilon for numerical stability in calculating norms. Default: 1e-12.
         name (str, optional): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to :ref:`api_guide_Name` .
         dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
@@ -1854,7 +1854,7 @@ class SpectralNorm(Layer):
             import paddle
             x = paddle.rand((2,8,32,32))

-            spectral_norm = paddle.nn.SpectralNorm(x.shape, axis=1, power_iters=2)
+            spectral_norm = paddle.nn.SpectralNorm(x.shape, dim=1, power_iters=2)
             spectral_norm_out = spectral_norm(x)

             print(spectral_norm_out.shape) # [2, 8, 32, 32]
@@ -1864,25 +1864,25 @@
     def __init__(
         self,
         weight_shape,
-        axis=0,
+        dim=0,
         power_iters=1,
-        epsilon=1e-12,
+        eps=1e-12,
         dtype='float32',
     ):
         super().__init__()
         self._power_iters = power_iters
-        self._epsilon = epsilon
-        self._dim = axis
+        self._epsilon = eps
+        self._dim = dim
         self._dtype = dtype

         self._weight_shape = list(weight_shape)
         assert (
             np.prod(self._weight_shape) > 0
         ), "Any dimension of `weight_shape` cannot be equal to 0."
-        assert axis < len(self._weight_shape), (
-            "The input `axis` should be less than the "
-            "length of `weight_shape`, but received axis="
-            "{}".format(axis)
+        assert dim < len(self._weight_shape), (
+            "The input `dim` should be less than the "
+            "length of `weight_shape`, but received dim="
+            "{}".format(dim)
         )
         h = self._weight_shape[self._dim]
         w = np.prod(self._weight_shape) // h
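For readers skimming the hunks above: the docstring's Step 2 and Step 3 are collapsed in this diff, so the sketch below assumes the standard power-iteration formulation of spectral normalization; `spectral_norm_ref` is a hypothetical NumPy reference helper, not Paddle code. It mirrors Step 1 as quoted (permute `dim` to the front, reshape to [H, W]) and the renamed `dim`/`eps` arguments:

import numpy as np

def spectral_norm_ref(weight, dim=0, power_iters=1, eps=1e-12):
    # Step 1: permute `dim` to the front and flatten to a matrix [H, W],
    # where H is the size of the `dim`-th axis and W is the product of
    # the remaining axes.
    perm = [dim] + [i for i in range(weight.ndim) if i != dim]
    w_mat = weight.transpose(perm).reshape(weight.shape[dim], -1)
    u = np.random.normal(size=w_mat.shape[0])  # vector U in shape [H]
    v = np.random.normal(size=w_mat.shape[1])  # vector V in shape [W]

    # Step 2 (assumed, standard formulation): power iteration toward the
    # leading left/right singular vectors of the weight matrix.
    for _ in range(power_iters):
        v = w_mat.T @ u
        v /= np.linalg.norm(v) + eps
        u = w_mat @ v
        u /= np.linalg.norm(u) + eps

    # Step 3 (assumed): sigma(W) = u^T W v approximates the largest
    # singular value; normalize the weight by it.
    sigma = u @ w_mat @ v
    return weight / sigma

# Conv-style weight with dim=1, matching the tests in this PR.
out = spectral_norm_ref(np.random.rand(2, 4, 3, 3), dim=1, power_iters=2)
print(out.shape)  # (2, 4, 3, 3)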