10 changes: 10 additions & 0 deletions docs/source/nn.rst
@@ -632,6 +632,16 @@ Utilities

.. autofunction:: torch.nn.utils.clip_grad_norm

:hidden:`weight_norm`
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.utils.weight_norm

:hidden:`remove_weight_norm`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torch.nn.utils.remove_weight_norm


.. currentmodule:: torch.nn.utils.rnn

23 changes: 23 additions & 0 deletions test/test_nn.py
@@ -746,6 +746,29 @@ def compare_scaling(grads):
scale = compare_scaling(grads)
self.assertEqual(scale, 1)

def test_weight_norm(self):
input = Variable(torch.randn(3, 5))
m = nn.Linear(5, 7)
expected_output = m(input)

# add weight normalization
m = torch.nn.utils.weight_norm(m)
self.assertEqual(m.weight_v.size(), m.weight.size())
self.assertEqual(m.weight_g.size(), (7, 1))
self.assertEqual(m(input), expected_output)

# remove weight norm
m = torch.nn.utils.remove_weight_norm(m)
self.assertFalse(hasattr(m, 'weight_g'))
self.assertFalse(hasattr(m, 'weight_v'))
self.assertEqual(m(input), expected_output)

# test with dim=1
m = torch.nn.utils.weight_norm(m, dim=1)
self.assertEqual(m.weight_v.size(), m.weight.size())
self.assertEqual(m.weight_g.size(), (1, 5))
self.assertEqual(m(input), expected_output)

def test_embedding_padding_idx(self):
embedding = nn.Embedding(10, 20, padding_idx=0)
input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))
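For reviewers, a minimal standalone sketch (not part of the test suite, and assuming the same Variable-era API used in the test above) of the invariant test_weight_norm exercises: after weight_norm, the effective weight is rebuilt from weight_g and weight_v as g * v / ||v||, with the norm taken per output row for the default dim=0.

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn.utils import weight_norm

    m = weight_norm(nn.Linear(5, 7))                # splits m.weight into m.weight_g and m.weight_v
    y = m(Variable(torch.randn(3, 5)))              # the forward pre-hook recomputes m.weight first

    # rebuild the effective weight by hand, mirroring WeightNorm.compute_weight
    norms = m.weight_v.norm(dim=1).view(-1, 1)      # one norm per output row (dim=0)
    reconstructed = m.weight_v * (m.weight_g / norms)
    print((m.weight.data - reconstructed.data).abs().max())  # ~0, up to floating-point error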
1 change: 1 addition & 0 deletions torch/nn/__init__.py
@@ -2,3 +2,4 @@
from .parameter import Parameter
from .parallel import DataParallel
from . import init
from . import utils
21 changes: 21 additions & 0 deletions torch/nn/modules/module.py
@@ -54,6 +54,7 @@ def __init__(self):
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True

@@ -186,6 +187,22 @@ def register_backward_hook(self, hook):
self._backward_hooks[handle.id] = hook
return handle

def register_forward_pre_hook(self, hook):
"""Registers a forward pre-hook on the module.

The hook will be called before :func:`forward` is invoked.
It should have the following signature::

hook(module, input) -> None

The hook should not modify the input.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
"""
handle = hooks.RemovableHandle(self._forward_pre_hooks)
self._forward_pre_hooks[handle.id] = hook
return handle

def register_forward_hook(self, hook):
"""Registers a forward hook on the module.

@@ -203,6 +220,8 @@ def register_forward_hook(self, hook):
return handle

def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
@@ -449,6 +468,8 @@ def named_modules(self, memo=None, prefix=''):
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix):
yield m
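A short usage sketch of the new pre-hook API (the hook and module names here are illustrative, not part of this diff): the hook receives the module and the tuple of positional inputs, runs before forward(), and can be detached again via the returned handle.

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    def log_input(module, input):
        # `input` is the tuple of positional arguments that will be passed to forward()
        print('pre-forward:', type(module).__name__, input[0].size())

    m = nn.Linear(5, 7)
    handle = m.register_forward_pre_hook(log_input)
    m(Variable(torch.randn(3, 5)))   # log_input fires before forward() runs
    handle.remove()                  # the hook is no longer called after this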
1 change: 1 addition & 0 deletions torch/nn/utils/__init__.py
@@ -1,2 +1,3 @@
from . import rnn
from .clip_grad import clip_grad_norm
from .weight_norm import weight_norm, remove_weight_norm
122 changes: 122 additions & 0 deletions torch/nn/utils/weight_norm.py
@@ -0,0 +1,122 @@
"""
Weight Normalization from https://arxiv.org/abs/1602.07868
"""
import torch.utils.hooks as hooks
from torch.nn.parameter import Parameter


class WeightNorm(object):
def __init__(self, name, dim):
self.name = name
self.dim = dim

def compute_weight(self, module):
g = getattr(module, self.name + '_g')
v = getattr(module, self.name + '_v')
return v * (g / self.norm(v))

def norm(self, p):
"""Computes the norm over all dimensions except dim"""
if self.dim is None:
return p.norm()
if self.dim != 0:
p = p.transpose(0, self.dim)
output_size = (p.size(0),) + (1,) * (p.dim() - 1)
p = p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
if self.dim != 0:
p = p.transpose(0, self.dim)
return p

@staticmethod
def apply(module, name, dim):
fn = WeightNorm(name, dim)

weight = getattr(module, name)

# remove w from parameter list
del module._parameters[name]

# add g and v as new parameters and express w as g/||v|| * v
module.register_parameter(name + '_g', Parameter(fn.norm(weight).data))
module.register_parameter(name + '_v', Parameter(weight.data))
setattr(module, name, fn.compute_weight(module))

handle = hooks.RemovableHandle(module._forward_pre_hooks)
module._forward_pre_hooks[handle.id] = fn
fn.handle = handle

return fn

def remove(self, module):
weight = self.compute_weight(module)

self.handle.remove()
delattr(module, self.name)
del module._parameters[self.name + '_g']
del module._parameters[self.name + '_v']
module.register_parameter(self.name, Parameter(weight.data))

def __call__(self, module, inputs):
setattr(module, self.name, self.compute_weight(module))


def weight_norm(module, name='weight', dim=0):
"""Applies weight normalization to a parameter in the given module.

.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

Weight normalization is a reparameterization that decouples the magnitude
of a weight tensor from its direction. This replaces the parameter specified
by `name` (e.g. "weight") with two parameters: one specifying the magnitude
(e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
Weight normalization is implemented via a hook that recomputes the weight
tensor from the magnitude and direction before every :meth:`~Module.forward`
call.

By default, with `dim=0`, the norm is computed independently per output
channel/plane. To compute a norm over the entire weight tensor, use
`dim=None`.

See https://arxiv.org/abs/1602.07868

Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
dim (int, optional): dimension over which to compute the norm

Returns:
The original module with the weight norm hook

Example::

>>> m = weight_norm(nn.Linear(20, 40), name='weight')
>>> m
Linear (20 -> 40)
>>> m.weight_g.size()
torch.Size([40, 1])
>>> m.weight_v.size()
torch.Size([40, 20])

"""
WeightNorm.apply(module, name, dim)
return module


def remove_weight_norm(module, name='weight'):
"""Removes the weight normalization reparameterization from a module.

Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter

Example::

>>> m = weight_norm(nn.Linear(20, 40))
>>> remove_weight_norm(m)
"""
for hook in module._forward_pre_hooks.values():
if isinstance(hook, WeightNorm) and hook.name == name:
hook.remove(module)
return module

raise ValueError("weight_norm of '{}' not found in {}"
.format(name, module))
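Finally, a round-trip sketch tying the two entry points together; the printed shapes follow the docstring example and test_weight_norm above rather than anything new.

    import torch.nn as nn
    from torch.nn.utils import weight_norm, remove_weight_norm

    m = weight_norm(nn.Linear(20, 40))   # default dim=0: one magnitude per output row
    print(m.weight_g.size())             # torch.Size([40, 1])
    print(m.weight_v.size())             # torch.Size([40, 20])

    m = remove_weight_norm(m)            # folds g and v back into a single 'weight' Parameter
    print(hasattr(m, 'weight_g'))        # False

    m = weight_norm(m, dim=1)            # dim=1: one magnitude per input column
    print(m.weight_g.size())             # torch.Size([1, 20])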