2672 | 2672 | },
2673 | 2673 | {
2674 | 2674 | "cell_type": "code",
2675 |      | - "execution_count": 36,
     | 2675 | + "execution_count": 40,
2676 | 2676 | "metadata": {},
2677 | 2677 | "outputs": [],
2678 | 2678 | "source": [

2695 | 2695 | " embedding_dim: int\n",
2696 | 2696 | " dtype: Any = jnp.float32\n",
2697 | 2697 | " precision: Any = jax.lax.Precision.HIGH\n",
2698 |      | - " kernel_init: Callable = kernel_init(1.0)\n",
     | 2698 | + " kernel_init: Callable = partial(kernel_init, 1.0)\n",
2699 | 2699 | "\n",
2700 | 2700 | " @nn.compact\n",
2701 | 2701 | " def __call__(self, x):\n",

2706 | 2706 | " kernel_size=(self.patch_size, self.patch_size), \n",
2707 | 2707 | " strides=(self.patch_size, self.patch_size),\n",
2708 | 2708 | " dtype=self.dtype,\n",
2709 |      | - " kernel_init=self.kernel_init,\n",
     | 2709 | + " kernel_init=self.kernel_init(),\n",
2710 | 2710 | " precision=self.precision)(x)\n",
2711 | 2711 | " x = jnp.reshape(x, (batch, -1, self.embedding_dim))\n",
2712 | 2712 | " return x\n",

2739 | 2739 | " norm_groups:int=8\n",
2740 | 2740 | " dtype: Optional[Dtype] = None\n",
2741 | 2741 | " precision: PrecisionLike = None\n",
2742 |      | - " kernel_init: Callable = partial(kernel_init)\n",
     | 2742 | + " kernel_init: Callable = partial(kernel_init, scale=1.0)\n",
2743 | 2743 | " add_residualblock_output: bool = False\n",
2744 | 2744 | "\n",
2745 | 2745 | " def setup(self):\n",

2758 | 2758 | "\n",
2759 | 2759 | " # Patch embedding\n",
2760 | 2760 | " x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, \n",
2761 |      | - " dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)\n",
     | 2761 | + " dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)\n",
2762 | 2762 | " num_patches = x.shape[1]\n",
2763 | 2763 | " \n",
2764 |      | - " context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0), \n",
     | 2764 | + " context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
2765 | 2765 | " dtype=self.dtype, precision=self.precision)(textcontext)\n",
2766 | 2766 | " num_text_tokens = textcontext.shape[1]\n",
2767 | 2767 | " \n",

2784 | 2784 | " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
2785 | 2785 | " use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
2786 | 2786 | " only_pure_attention=False,\n",
2787 |      | - " kernel_init=self.kernel_init(1.0))(x)\n",
     | 2787 | + " kernel_init=self.kernel_init())(x)\n",
2788 | 2788 | " skips.append(x)\n",
2789 | 2789 | " \n",
2790 | 2790 | " # Middle block\n",
2791 | 2791 | " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
2792 | 2792 | " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
2793 | 2793 | " use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
2794 | 2794 | " only_pure_attention=False,\n",
2795 |      | - " kernel_init=self.kernel_init(1.0))(x)\n",
     | 2795 | + " kernel_init=self.kernel_init())(x)\n",
2796 | 2796 | " \n",
2797 | 2797 | " # # Out blocks\n",
2798 | 2798 | " for i in range(self.num_layers // 2):\n",
2799 | 2799 | " x = jnp.concatenate([x, skips.pop()], axis=-1)\n",
2800 |      | - " x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0), \n",
     | 2800 | + " x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
2801 | 2801 | " dtype=self.dtype, precision=self.precision)(x)\n",
2802 | 2802 | " x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
2803 | 2803 | " dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
2804 | 2804 | " use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
2805 | 2805 | " only_pure_attention=False,\n",
2806 |      | - " kernel_init=self.kernel_init(1.0))(x)\n",
     | 2806 | + " kernel_init=self.kernel_init())(x)\n",
2807 | 2807 | " \n",
2808 | 2808 | " # print(f'Shape of x after transformer blocks: {x.shape}')\n",
2809 | 2809 | " x = self.norm()(x)\n",
2810 | 2810 | " \n",
2811 | 2811 | " patch_dim = self.patch_size ** 2 * self.output_channels\n",
2812 |      | - " x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)\n",
     | 2812 | + " x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n",
2813 | 2813 | " x = x[:, 1 + num_text_tokens:, :]\n",
2814 | 2814 | " x = unpatchify(x, channels=self.output_channels)\n",
2815 | 2815 | " \n",

2823 | 2823 | " kernel_size=(3, 3),\n",
2824 | 2824 | " strides=(1, 1),\n",
2825 | 2825 | " # activation=jax.nn.mish\n",
2826 |      | - " kernel_init=self.kernel_init(0.0),\n",
     | 2826 | + " kernel_init=self.kernel_init(scale=0.0),\n",
2827 | 2827 | " dtype=self.dtype,\n",
2828 | 2828 | " precision=self.precision\n",
2829 | 2829 | " )(x)\n",

2837 | 2837 | " kernel_size=(3, 3),\n",
2838 | 2838 | " strides=(1, 1),\n",
2839 | 2839 | " # activation=jax.nn.mish\n",
2840 |      | - " kernel_init=self.kernel_init(0.0),\n",
     | 2840 | + " kernel_init=self.kernel_init(scale=0.0),\n",
2841 | 2841 | " dtype=self.dtype,\n",
2842 | 2842 | " precision=self.precision\n",
2843 | 2843 | " )(x)\n",

2846 | 2846 | },
2847 | 2847 | {
2848 | 2848 | "cell_type": "code",
2849 |      | - "execution_count": 37,
     | 2849 | + "execution_count": 42,
2850 | 2850 | "metadata": {},
2851 |      | - "outputs": [
2852 |      | - {
2853 |      | - "ename": "TypeError",
2854 |      | - "evalue": "kernel_init() missing 1 required positional argument: 'scale'",
2855 |      | - "output_type": "error",
2856 |      | - "traceback": [
2857 |      | - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
2858 |      | - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
2859 |      | - "Cell \u001b[0;32mIn[37], line 12\u001b[0m\n\u001b[1;32m 3\u001b[0m textcontext \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m77\u001b[39m, \u001b[38;5;241m768\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16) \n\u001b[1;32m 4\u001b[0m vit \u001b[38;5;241m=\u001b[39m UViT(patch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, \n\u001b[1;32m 5\u001b[0m emb_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m768\u001b[39m, \n\u001b[1;32m 6\u001b[0m num_layers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m12\u001b[39m, \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 10\u001b[0m norm_groups\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 11\u001b[0m dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16)\n\u001b[0;32m---> 12\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mvit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtextcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply\u001b[39m(params, x, temb, textcontext):\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m vit\u001b[38;5;241m.\u001b[39mapply(params, x, temb, textcontext)\n",
2860 |      | - " \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n",
2861 |      | - "Cell \u001b[0;32mIn[36], line 122\u001b[0m, in \u001b[0;36mUViT.__call__\u001b[0;34m(self, x, temb, textcontext)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m 121\u001b[0m x \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mconcatenate([x, skips\u001b[38;5;241m.\u001b[39mpop()], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 122\u001b[0m x \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDenseGeneral(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39memb_features, kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, \n\u001b[1;32m 123\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision)(x)\n\u001b[1;32m 124\u001b[0m x \u001b[38;5;241m=\u001b[39m TransformerBlock(heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads, dim_head\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39memb_features \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads, \n\u001b[1;32m 125\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision, use_projection\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_projection, \n\u001b[1;32m 126\u001b[0m use_flash_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_flash_attention, use_self_and_cross\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_self_and_cross, force_fp32_for_softmax\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforce_fp32_for_softmax, \n\u001b[1;32m 127\u001b[0m only_pure_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 128\u001b[0m kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel_init(\u001b[38;5;241m1.0\u001b[39m))(x)\n\u001b[1;32m 130\u001b[0m \u001b[38;5;66;03m# print(f'Shape of x after transformer blocks: {x.shape}')\u001b[39;00m\n",
2862 |      | - "\u001b[0;31mTypeError\u001b[0m: kernel_init() missing 1 required positional argument: 'scale'"
2863 |      | - ]
2864 |      | - }
2865 |      | - ],
     | 2851 | + "outputs": [],
2866 | 2852 | "source": [
2867 | 2853 | "x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)\n",
2868 | 2854 | "temb = jnp.ones((8,), dtype=jnp.bfloat16)\n",

2874 | 2860 | " dropout_rate=0.1, \n",
2875 | 2861 | " add_residualblock_output=True,\n",
2876 | 2862 | " norm_groups=0,\n",
     | 2863 | + " kernel_init=partial(kernel_init, scale=1.0),\n",
2877 | 2864 | " dtype=jnp.bfloat16)\n",
2878 | 2865 | "params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)\n",
2879 | 2866 | "\n",