Commit 3c22222

fix: refactored kernel_init calls so that scale can be overridden from outside
1 parent c076c31 commit 3c22222

3 files changed: +29 -42 lines changed

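The pattern is the same across all three files: module fields stop holding a concrete initializer (kernel_init(1.0)) and instead hold a factory partial (partial(kernel_init, scale=1.0)) that each use site calls (self.kernel_init()), so a caller can pass a partial with a different scale at construction time. Below is a minimal sketch of that pattern, assuming a kernel_init(scale) factory roughly along these lines; the stand-in definition, the TinyBlock module, and the variance_scaling choice are illustrative only, not flaxdiff's actual code.

from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import flax.linen as nn

def kernel_init(scale: float = 1.0):
    # Hypothetical stand-in: build a Flax initializer whose variance is scaled by `scale`.
    return nn.initializers.variance_scaling(scale, mode="fan_in", distribution="truncated_normal")

class TinyBlock(nn.Module):
    features: int
    # Store the factory (a partial), not a ready-made initializer,
    # so the scale stays overridable from outside.
    kernel_init: Callable = partial(kernel_init, scale=1.0)

    @nn.compact
    def __call__(self, x):
        # Call the factory at the use site to get the concrete initializer.
        return nn.Dense(self.features, kernel_init=self.kernel_init())(x)

x = jnp.ones((2, 8))
params_default = TinyBlock(features=4).init(jax.random.PRNGKey(0), x)
# Override the scale from outside, which the old `kernel_init(1.0)` default made impossible:
params_zero = TinyBlock(features=4, kernel_init=partial(kernel_init, scale=0.0)).init(jax.random.PRNGKey(0), x)

Keeping the default as a partial means the default scale lives in one place and a different scale is a plain constructor argument, with no need to subclass or patch the module.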

evaluate.ipynb

Lines changed: 16 additions & 29 deletions
@@ -2672,7 +2672,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 36,
+"execution_count": 40,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -2695,7 +2695,7 @@
 " embedding_dim: int\n",
 " dtype: Any = jnp.float32\n",
 " precision: Any = jax.lax.Precision.HIGH\n",
-" kernel_init: Callable = kernel_init(1.0)\n",
+" kernel_init: Callable = partial(kernel_init, 1.0)\n",
 "\n",
 " @nn.compact\n",
 " def __call__(self, x):\n",
@@ -2706,7 +2706,7 @@
 " kernel_size=(self.patch_size, self.patch_size), \n",
 " strides=(self.patch_size, self.patch_size),\n",
 " dtype=self.dtype,\n",
-" kernel_init=self.kernel_init,\n",
+" kernel_init=self.kernel_init(),\n",
 " precision=self.precision)(x)\n",
 " x = jnp.reshape(x, (batch, -1, self.embedding_dim))\n",
 " return x\n",
@@ -2739,7 +2739,7 @@
 " norm_groups:int=8\n",
 " dtype: Optional[Dtype] = None\n",
 " precision: PrecisionLike = None\n",
-" kernel_init: Callable = partial(kernel_init)\n",
+" kernel_init: Callable = partial(kernel_init, scale=1.0)\n",
 " add_residualblock_output: bool = False\n",
 "\n",
 " def setup(self):\n",
@@ -2758,10 +2758,10 @@
 "\n",
 " # Patch embedding\n",
 " x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, \n",
-" dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)\n",
+" dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)\n",
 " num_patches = x.shape[1]\n",
 " \n",
-" context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0), \n",
+" context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
 " dtype=self.dtype, precision=self.precision)(textcontext)\n",
 " num_text_tokens = textcontext.shape[1]\n",
 " \n",
@@ -2784,32 +2784,32 @@
 " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
 " use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
 " only_pure_attention=False,\n",
-" kernel_init=self.kernel_init(1.0))(x)\n",
+" kernel_init=self.kernel_init())(x)\n",
 " skips.append(x)\n",
 " \n",
 " # Middle block\n",
 " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
 " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
 " use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
 " only_pure_attention=False,\n",
-" kernel_init=self.kernel_init(1.0))(x)\n",
+" kernel_init=self.kernel_init())(x)\n",
 " \n",
 " # # Out blocks\n",
 " for i in range(self.num_layers // 2):\n",
 " x = jnp.concatenate([x, skips.pop()], axis=-1)\n",
-" x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0), \n",
+" x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
 " dtype=self.dtype, precision=self.precision)(x)\n",
 " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
 " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
 " use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
 " only_pure_attention=False,\n",
-" kernel_init=self.kernel_init(1.0))(x)\n",
+" kernel_init=self.kernel_init())(x)\n",
 " \n",
 " # print(f'Shape of x after transformer blocks: {x.shape}')\n",
 " x = self.norm()(x)\n",
 " \n",
 " patch_dim = self.patch_size ** 2 * self.output_channels\n",
-" x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)\n",
+" x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n",
 " x = x[:, 1 + num_text_tokens:, :]\n",
 " x = unpatchify(x, channels=self.output_channels)\n",
 " \n",
@@ -2823,7 +2823,7 @@
 " kernel_size=(3, 3),\n",
 " strides=(1, 1),\n",
 " # activation=jax.nn.mish\n",
-" kernel_init=self.kernel_init(0.0),\n",
+" kernel_init=self.kernel_init(scale=0.0),\n",
 " dtype=self.dtype,\n",
 " precision=self.precision\n",
 " )(x)\n",
@@ -2837,7 +2837,7 @@
 " kernel_size=(3, 3),\n",
 " strides=(1, 1),\n",
 " # activation=jax.nn.mish\n",
-" kernel_init=self.kernel_init(0.0),\n",
+" kernel_init=self.kernel_init(scale=0.0),\n",
 " dtype=self.dtype,\n",
 " precision=self.precision\n",
 " )(x)\n",
@@ -2846,23 +2846,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": 37,
+"execution_count": 42,
 "metadata": {},
-"outputs": [
-{
-"ename": "TypeError",
-"evalue": "kernel_init() missing 1 required positional argument: 'scale'",
-"output_type": "error",
-"traceback": [
-"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-"\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-"Cell \u001b[0;32mIn[37], line 12\u001b[0m\n\u001b[1;32m 3\u001b[0m textcontext \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m77\u001b[39m, \u001b[38;5;241m768\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16) \n\u001b[1;32m 4\u001b[0m vit \u001b[38;5;241m=\u001b[39m UViT(patch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, \n\u001b[1;32m 5\u001b[0m emb_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m768\u001b[39m, \n\u001b[1;32m 6\u001b[0m num_layers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m12\u001b[39m, \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 10\u001b[0m norm_groups\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 11\u001b[0m dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16)\n\u001b[0;32m---> 12\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mvit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtextcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply\u001b[39m(params, x, temb, textcontext):\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m vit\u001b[38;5;241m.\u001b[39mapply(params, x, temb, textcontext)\n",
-" \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n",
-"Cell \u001b[0;32mIn[36], line 122\u001b[0m, in \u001b[0;36mUViT.__call__\u001b[0;34m(self, x, temb, textcontext)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m 121\u001b[0m x \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mconcatenate([x, skips\u001b[38;5;241m.\u001b[39mpop()], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 122\u001b[0m x \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDenseGeneral(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39memb_features, kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, \n\u001b[1;32m 123\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision)(x)\n\u001b[1;32m 124\u001b[0m x \u001b[38;5;241m=\u001b[39m TransformerBlock(heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads, dim_head\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39memb_features \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads, \n\u001b[1;32m 125\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision, use_projection\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_projection, \n\u001b[1;32m 126\u001b[0m use_flash_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_flash_attention, use_self_and_cross\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_self_and_cross, force_fp32_for_softmax\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforce_fp32_for_softmax, \n\u001b[1;32m 127\u001b[0m only_pure_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 128\u001b[0m kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel_init(\u001b[38;5;241m1.0\u001b[39m))(x)\n\u001b[1;32m 130\u001b[0m \u001b[38;5;66;03m# print(f'Shape of x after transformer blocks: {x.shape}')\u001b[39;00m\n",
-"\u001b[0;31mTypeError\u001b[0m: kernel_init() missing 1 required positional argument: 'scale'"
-]
-}
-],
+"outputs": [],
 "source": [
 "x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)\n",
 "temb = jnp.ones((8,), dtype=jnp.bfloat16)\n",
@@ -2874,6 +2860,7 @@
 " dropout_rate=0.1, \n",
 " add_residualblock_output=True,\n",
 " norm_groups=0,\n",
+" kernel_init=partial(kernel_init, scale=1.0),\n",
 " dtype=jnp.bfloat16)\n",
 "params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)\n",
 "\n",

flaxdiff/models/simple_vit.py

Lines changed: 12 additions & 12 deletions
@@ -23,7 +23,7 @@ class PatchEmbedding(nn.Module):
 embedding_dim: int
 dtype: Any = jnp.float32
 precision: Any = jax.lax.Precision.HIGH
-kernel_init: Callable = kernel_init(1.0)
+kernel_init: Callable = partial(kernel_init, 1.0)

 @nn.compact
 def __call__(self, x):
@@ -34,7 +34,7 @@ def __call__(self, x):
 kernel_size=(self.patch_size, self.patch_size),
 strides=(self.patch_size, self.patch_size),
 dtype=self.dtype,
-kernel_init=self.kernel_init,
+kernel_init=self.kernel_init(),
 precision=self.precision)(x)
 x = jnp.reshape(x, (batch, -1, self.embedding_dim))
 return x
@@ -67,7 +67,7 @@ class UViT(nn.Module):
 norm_groups:int=8
 dtype: Optional[Dtype] = None
 precision: PrecisionLike = None
-kernel_init: Callable = partial(kernel_init)
+kernel_init: Callable = partial(kernel_init, scale=1.0)
 add_residualblock_output: bool = False

 def setup(self):
@@ -86,10 +86,10 @@ def __call__(self, x, temb, textcontext=None):

 # Patch embedding
 x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
+dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
 num_patches = x.shape[1]

-context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
+context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
 dtype=self.dtype, precision=self.precision)(textcontext)
 num_text_tokens = textcontext.shape[1]

@@ -112,32 +112,32 @@ def __call__(self, x, temb, textcontext=None):
 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
 only_pure_attention=False,
-kernel_init=self.kernel_init(1.0))(x)
+kernel_init=self.kernel_init())(x)
 skips.append(x)

 # Middle block
 x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
 only_pure_attention=False,
-kernel_init=self.kernel_init(1.0))(x)
+kernel_init=self.kernel_init())(x)

 # # Out blocks
 for i in range(self.num_layers // 2):
 x = jnp.concatenate([x, skips.pop()], axis=-1)
-x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
+x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
 dtype=self.dtype, precision=self.precision)(x)
 x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
 only_pure_attention=False,
-kernel_init=self.kernel_init(1.0))(x)
+kernel_init=self.kernel_init())(x)

 # print(f'Shape of x after transformer blocks: {x.shape}')
 x = self.norm()(x)

 patch_dim = self.patch_size ** 2 * self.output_channels
-x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
+x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
 x = x[:, 1 + num_text_tokens:, :]
 x = unpatchify(x, channels=self.output_channels)

@@ -151,7 +151,7 @@ def __call__(self, x, temb, textcontext=None):
 kernel_size=(3, 3),
 strides=(1, 1),
 # activation=jax.nn.mish
-kernel_init=self.kernel_init(0.0),
+kernel_init=self.kernel_init(scale=0.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
@@ -165,7 +165,7 @@ def __call__(self, x, temb, textcontext=None):
 kernel_size=(3, 3),
 strides=(1, 1),
 # activation=jax.nn.mish
-kernel_init=self.kernel_init(0.0),
+kernel_init=self.kernel_init(scale=0.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
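With the defaults now holding the scale in a partial, callers override it at construction time, which is what the updated notebook cell does. A usage sketch mirroring that cell; the import paths below are assumptions, since the diff does not show where UViT and kernel_init are exported from:

from functools import partial
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import UViT, kernel_init  # import locations assumed

x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)
temb = jnp.ones((8,), dtype=jnp.bfloat16)
textcontext = jnp.ones((8, 77, 768), dtype=jnp.bfloat16)
vit = UViT(patch_size=16,
           emb_features=768,
           num_layers=12,
           dropout_rate=0.1,
           add_residualblock_output=True,
           norm_groups=0,
           kernel_init=partial(kernel_init, scale=1.0),  # the override this commit enables
           dtype=jnp.bfloat16)
params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)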

setup.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 setup(
 name='flaxdiff',
 packages=find_packages(),
-version='0.1.30',
+version='0.1.31',
 description='A versatile and easy to understand Diffusion library',
 long_description=open('README.md').read(),
 long_description_content_type='text/markdown',
