Skip to content

Commit de0c183

Browse files
committed
switch freqs to linspace for rotary embedding
1 parent bf652d3 commit de0c183

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'uformer-pytorch',
55
packages = find_packages(),
6-
version = '0.0.7',
6+
version = '0.0.8',
77
license='MIT',
88
description = 'Uformer - Pytorch',
99
author = 'Phil Wang',

uformer_pytorch/uformer_pytorch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
# constants
1212

13-
LayerNorm = partial(nn.InstanceNorm2d, affine = True)
1413
List = nn.ModuleList
1514

1615
# helpers
@@ -44,7 +43,7 @@ class AxialRotaryEmbedding(nn.Module):
4443
def __init__(self, dim, max_freq = 10):
4544
super().__init__()
4645
self.dim = dim
47-
scales = torch.logspace(0., log(max_freq / 2) / log(2), self.dim // 4, base = 2)
46+
scales = torch.linspace(1., max_freq / 2, self.dim // 4)
4847
self.register_buffer('scales', scales)
4948

5049
def forward(self, x):

0 commit comments

Comments
 (0)