Conversation

@AntoinePlumerault

Hello, thank you for the amazing work. I really like the design choices of the library and I would like to contribute to the project.

I noticed that convolution layers are missing, and I think it would be a nice addition.

So I made a draft implementation of Conv & ConvTranspose layers, following what you did with the Linear layer. I would be happy to discuss the implementation.

Collaborator

@danieldjohnson left a comment

Thanks for the PR! Conv layers seem like they could be a pretty useful addition.

I left some comments on specific things I noticed. High level summary:

  • A lot of the changes seem to be reformatting; could you make sure you are using Penzai's formatting conventions?
  • I think it might be simpler to combine Conv and ConvTranspose into one class.
  • I notice you introduced a number of helper functions. Could you add more documentation about what they do and what they return, and also maybe mark them private using a leading underscore if they are not meant to be used by users? (Some of this might also be fixable by just inlining the implementation into the merged Conv/ConvTranspose class.)


 def _nmap_with_doc(
-    fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None
+    fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None
Collaborator

I'm noticing a large amount of reformatting changes that made their way into this PR, which makes it hard to tell what the actual changes are.

Penzai currently uses pyink to format code (see our pyink config and CI checks). Can you reformat your code using Penzai's formatting configuration?

Author

Yes, I was not using the right config. I will fix it.

Comment on lines 164 to 170
-xavier_uniform_initializer = functools.partial(
-    variance_scaling_initializer,
-    scale=1.0,
-    mode="fan_avg",
-    distribution="uniform",
+xavier_uniform_initializer = cast(
+    LinearOperatorWeightInitializer,
+    functools.partial(
+        variance_scaling_initializer,
+        scale=1.0,
+        mode="fan_avg",
+        distribution="uniform",
+    ),
Collaborator

Why is this needed?

Author

My type checker was complaining, but with the right config the cast does not seem necessary after all.

)


def maybe_rename_output_axes(
Collaborator

What does this do, and what does it return? It should have a docstring that explains what it does, especially if this is supposed to be called by users, and it should have a return type annotation. If it is private it should start with an underscore.

Is this refactoring out the logic for axis name conflicts?

Author

Yes, the underscore is missing. You are right: I refactored out the logic for axis name collisions because the same logic is also needed for the convolution layers.

Comment on lines 828 to 882
def prepare_for_conv(
    inputs: NamedArray,
    kernel: NamedArray,
    spatial_axis_names: Sequence[str],
    in_axis_names: Sequence[str],
    out_axis_names: Sequence[str],
):
  """Preprocess lhs and rhs for jax convolution operator"""

  lhs = inputs
  rhs = kernel

  in_axis_name = "in_axis-" + "-".join(in_axis_names)
  out_axis_name = "out_axis-" + "-".join(out_axis_names)

  # merge in axes into one in channel axis for the inputs and the kernel
  lhs = lhs.untag(*in_axis_names).flatten().tag(in_axis_name)
  rhs = rhs.untag(*in_axis_names).flatten().tag(in_axis_name)

  # merge out axes into one out channels axis for jax convolution
  rhs = rhs.untag(*out_axis_names).flatten().tag(out_axis_name)

  # untag spatial axes
  lhs = lhs.untag(*spatial_axis_names, in_axis_name)
  rhs = rhs.untag(*spatial_axis_names, in_axis_name, out_axis_name)
  return lhs, rhs


def get_named_axis_back_after_conv(
    result: NamedArray,
    spatial_axis_names: Sequence[str],
    out_axis_names: Sequence[str],
    out_axis_shape: Sequence[int],
):
  """Postprocess result from jax convolution operator"""
  # Get named axes back
  return (
      result.tag_prefix(*spatial_axis_names)
      .reshape(out_axis_shape)
      .tag(*out_axis_names)
  )


def maybe_broadcast(value: int | Sequence[int], count: int):
  return [value] * count if isinstance(value, int) else value


def get_dimension_numbers(ndim):
  return jax.lax.ConvDimensionNumbers(
      lhs_spec=(0, ndim + 1)
      + tuple(range(1, ndim + 1)),  # BHSpatial -> BCSpatial
      rhs_spec=(ndim + 1, ndim) + tuple(range(ndim)),  # SpatialIO -> OISpatial
      out_spec=(0, ndim + 1)
      + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial
  )
Collaborator

Docstrings for these could be a lot more detailed, and it would be good to add return type annotations. I am not sure from reading this what these functions are supposed to do? Also some of them are so short they could possibly just be inlined into the implementation.

Author

I will add docstrings and underscores where needed, as these functions are mostly utilities shared between Conv and ConvTranspose. The short ones are shortcuts for code that is otherwise quite obscure / not very readable (e.g. get_dimension_numbers). I figured that putting this code inside functions keeps the Conv & ConvTranspose logic less cluttered with low-level implementation details.



@struct.pytree_dataclass
class AbstractGeneralConv(layer_base.Layer):
Collaborator

I'm noticing a lot of duplication between AbstractGeneralConv, Conv and ConvTranspose. What do you think about combining all of these into a single class and having transpose: bool be an attribute?

Author

That makes sense; I will think about it. In other frameworks, Conv and ConvTranspose are usually different modules. For the user it would mean writing Conv(..., transpose=True) instead of ConvTranspose(...), which is a bit heavier and makes upsampling conv-transpose layers somewhat less identifiable.

Collaborator

Yeah I can see arguments either way. Other options:

  • keep them as separate classes but just have a single helper function that does the whole implementation (maybe _conv_call), with an argument transpose
  • implement all of the logic in AbstractGeneralConv and have an abstract class method def _is_transposed_conv(cls) -> bool that tells it which one to use, then just implement that method on the two subclasses (rough sketch below)

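For option 2, the dispatch could look roughly like the following standalone sketch (plain abc only, leaving out the Penzai decorators and the actual convolution logic; all names here are illustrative, not the final API):

import abc


class AbstractGeneralConv(abc.ABC):
  """Shared convolution logic; subclasses only say whether they are transposed."""

  @classmethod
  @abc.abstractmethod
  def _is_transposed_conv(cls) -> bool:
    """Returns True if this layer applies a transposed convolution."""

  def describe(self) -> str:
    # Stand-in for the shared __call__ body, which would branch on
    # _is_transposed_conv() when choosing the jax.lax convolution call.
    return "transposed" if self._is_transposed_conv() else "regular"


class Conv(AbstractGeneralConv):

  @classmethod
  def _is_transposed_conv(cls) -> bool:
    return False


class ConvTranspose(AbstractGeneralConv):

  @classmethod
  def _is_transposed_conv(cls) -> bool:
    return True


assert Conv().describe() == "regular"
assert ConvTranspose().describe() == "transposed"
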
Author

@AntoinePlumerault Jun 14, 2025

Thanks for the ideas, I implemented option 2 in the latest commit.

),
)

def test_conv(self):
Collaborator

Thanks for adding tests!

Collaborator

@danieldjohnson left a comment

Great, thanks much! Left a few more comments.

Comment on lines 874 to 875
in_axis_name = "in_axis-" + "-".join(in_axis_names)
out_axis_name = "out_axis-" + "-".join(out_axis_names)
Collaborator

Minor, but I noticed these are temporary axis names that are only used inside this function. There's actually an existing helper for this, TmpPosAxisMarker, which can be used to give a temporary name to a positional axis in a way that is guaranteed not to conflict with anything else. Maybe you could use this?
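Roughly like this (an untested sketch; it assumes TmpPosAxisMarker is exposed under pz.nx alongside the other named_axes utilities and can be used directly as an axis name):

from penzai import pz
import jax.numpy as jnp

# Use a unique marker object as a temporary axis name instead of a string
# built from the user-provided axis names, so it can never collide.
tmp_name = pz.nx.TmpPosAxisMarker()
arr = pz.nx.wrap(jnp.zeros((2, 3))).tag("batch", tmp_name)
arr = arr.untag(tmp_name)  # back to a positional axis, no name clash possible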

Author

I did not know it existed; I use it now.

Comment on lines 865 to 868
A tuple of two named arrays, the first one is the input with the in axes
merged into a single in channel axis, and the second one is the kernel with
the in axes merged into a single in channel axis and the out axes merged
into a single out channel axis.
Collaborator

If I understand right, the two named arrays have a very specific order of their positional axes, where all spatial axes are converted to positional (in the order given by spatial_axis_names) and then there are extra positional axes for the new input and output axes. Could you add a brief note in the docstring describing this layout?
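For example, the note could say something like: "Both returned arrays have their positional axes laid out as expected by the jax convolution operator: for lhs, the spatial axes (in the order given by spatial_axis_names) followed by the merged input channel axis; for rhs, the spatial axes followed by the merged input channel axis and then the merged output channel axis." (Just a sketch of the wording, based on the untag calls visible in _prepare_for_conv.)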

Author

Yes, you understand correctly; I clarified this in the docstring.

  a single axis before applying the convolution.

  Args:
    result: The result of the jax convolution operator.
Collaborator

Similarly, I think your implementation here assumes that the result has a very specific ordering of its positional axes (spatial axes in the order given by spatial_axis_names and then the combined out axis). Mind documenting it here?

Author

Same here; it is updated.

)


def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
Collaborator

Is it important that the sequence have length count? If so it might make sense to check.
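For instance, something along these lines (a sketch, not necessarily what ended up in the PR):

from collections.abc import Sequence


def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
  # Broadcast a single int to `count` entries, or validate that a
  # user-provided sequence already has the expected length.
  if isinstance(value, int):
    return [value] * count
  if len(value) != count:
    raise ValueError(
        f"Expected an int or a sequence of length {count}, got {value!r}."
    )
  return list(value)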

Author

I added a check and changed the ConvTranspose implementation a bit to make it pass the check.



def _get_dimension_numbers(ndim) -> jax.lax.ConvDimensionNumbers:
  """Returns the dimension numbers for a convolution operator.
Collaborator

Is the idea that this dimension numbers object is specialized to the version where the spatial axes come first, followed by the in/out axes? Maybe you could add a note that this dimension numbers matches the shapes returned by _prepare_for_conv, if that's the goal?

Author

Yes, I clarified this in the docstring.

Comment on lines 1116 to 1121
    return ConvInPlace(
        sublayers=[
            core_layer,
            RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
        ],
    )
Collaborator

I guess this doesn't quite work for ConvTranspose since it won't produce ConvTransposeInPlace?

One option would be to turn from_config into a private helper on AbstractGeneralConv (since you override it anyway), and pass in the in-place class as a keyword argument.

Author

Nice catch, it should be fixed now

)
return core_layer

def _is_transposed(self) -> bool:
Collaborator

mind adding @abc.abstractmethod here?

Author

Done. Thank you for your help :)

Collaborator

@danieldjohnson left a comment

Thanks, this looks great! I'd be happy to merge this if you think it's in good shape (looks like the PR is still a draft PR though).

}
),
)

Collaborator

One thought: do you think it would be valuable to add a test that actually makes sure the convolution operator matches the JAX version? e.g. pass in an array of random values, do the convolution normally with jax on an unwrapped version of it, and check that the results are the same?

This would make sure that nothing went wrong with the implementation somehow. Although I'm not sure this adds that much, since probably most errors would lead to a wrong shape. (Maybe it would be able to catch if you accidentally used a regular conv instead of a transposed conv or something like this?)
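A sketch of what the raw-JAX reference side of such a test could look like (shapes, axis layout, and kernel here are illustrative only):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (1, 8, 8, 3))  # batch, height, width, in
k = jax.random.normal(jax.random.key(1), (3, 3, 3, 5))  # kh, kw, in, out
expected = jax.lax.conv_general_dilated(
    x, k,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
# The test would then run the Penzai Conv layer with the same kernel on a
# wrapped, axis-named copy of x, unwrap the result, and compare it against
# `expected` with numpy.testing.assert_allclose.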

Author

I added tests

@danieldjohnson
Collaborator

Looks like there's a few small type errors right now, actually:

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1140:1: error: in _from_config: bad return type [bad-return-type]
           Expected: Union[Conv, ConvInPlace, ConvTranspose, ConvTransposeInPlace]
  Actually returned: AbstractGeneralConv

    return core_layer

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1194:1: error: in <dictcomp>: No attribute 'convolution_spatial_axis_names' on AbstractGeneralConv [attribute-error]

        if name not in self.convolution_spatial_axis_names

Called from (traceback):
  line 1191, in parallel_axes

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1207:1: error: in <dictcomp>: No attribute 'spatial_axes_names' on AbstractGeneralConv [attribute-error]

        if name in self.spatial_axes_names

Called from (traceback):
  line 1204, in convolution_spatial_axes

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1326:1: error: in from_config: bad option 'Union[ConvTranspose, ConvTransposeInPlace]' in return type [bad-return-type]
           Expected: Union[Conv, ConvInPlace]
  Actually returned: Union[Conv, ConvInPlace, ConvTranspose, ConvTransposeInPlace]

    return super()._from_config(

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1461:1: error: in from_config: bad option 'Union[Conv, ConvInPlace]' in return type [bad-return-type]
           Expected: Union[ConvTranspose, ConvTransposeInPlace]
  Actually returned: Union[Conv, ConvInPlace, ConvTranspose, ConvTransposeInPlace]

    return super()._from_config(

@AntoinePlumerault
Author

It should be fixed with the last commit. I added casts in the Conv and ConvTranspose implementations.

@AntoinePlumerault
Author

I noticed that I did not implement bias for Conv and ConvTranspose. Do you think I should add an "AffineConv" & "AffineConvTranspose" like you did with the Linear layer? Maybe it's too verbose and a use_bias parameter would do the job, but it would not be consistent with the existing Linear & Affine layers. Or we could make Affine a wrapper, but that would require significant changes and would probably break things.

@danieldjohnson
Collaborator

Hmm good question.

At least in terms of the concrete layers that get built, I think it makes sense to have the Conv/ConvTranspose and AddBias layers separate, rather than a single layer that does both.

It could be nice to have a wrapper constructor like Affine.from_config that builds both the conv layer and the bias, but I wonder if it would be simplest to just have users add biases manually? So if you want a conv layer with a bias you could just do something like

pz.nn.Sequential([
    pz.nn.Conv.from_config(
        name="test",
        init_base_rng=rng,
        input_axes={"foo": 3, "baz": 7},
        output_axes={"foo_out": 5, "baz_out": 11},
        convolution_spatial_axes={"height": 3, "width": 3},
        parallel_axes=None,
        parallel_broadcast_axes=None,
        rename_outputs_if_necessary=True,
    ),
    pz.nn.AddBias.from_config(
        name="test_bias",
        init_base_rng=rng,
        biased_axes={"foo_out": 5, "baz_out": 11},
    )
])

(Also, seems like we are still getting type errors in CI, maybe due to outdated type annotations? You can run the typechecking locally with uv run pytype --jobs auto penzai)

@guxm2021
Contributor

Thank you for the effort to add convolution layers; it helps me a lot. However, when I use your convolution layer, I ran into some bugs. Specifically, when I set strides larger than 1, the input and output shapes along the convolution_spatial_axes may differ, and then the shape check fails with an error. After looking at your code, I think the current shape check assumes the dimensions in convolution_spatial_axes are the same in both the input and the output. Here is my test code:

inputs = jax.random.normal(
    key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3)
)
pz_inputs = pz.nx.wrap(inputs.reshape(1, 10, 15, 3)).tag(
    "batch", "height", "width", "channel",
)
simple_layer = Conv.from_config(
    name="test",
    init_base_rng=jax.random.key(1),
    input_axes={"channel": 3},
    output_axes={"embedding": 1152},
    strides = (5, 5),
    convolution_spatial_axes={"height": 5, "width": 5},
    padding = "VALID",
    parallel_axes=None,
    parallel_broadcast_axes=None,
    rename_outputs_if_necessary=True,
)
simple_layer(pz_inputs)

And here is the error:

StructureMismatchError: (test.ke) Mismatch while checking structures:
At root: Named shape mismatch between value {'batch': 1, 'height': 2, 'width': 3, 'embedding': 1152} and pattern {'embedding': 1152, 'batch': var('B')['batch']:=1, 'height': var('B')['height']:=10, 'width': var('B')['width']:=15}:
  Axis 'height': Size 2 does not match previous size 10 for var('B')['height']:=10 from the known variable assignments (argument `known_vars` to check_structure)
  Axis 'width': Size 3 does not match previous size 15 for var('B')['width']:=15 from the known variable assignments (argument `known_vars` to check_structure)
  After inlining var('B') = {'batch': 1, 'height': 10, 'width': 15} from the known variable assignments (argument `known_vars` to check_structure)

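For reference, the output spatial sizes here follow the usual VALID-padding formula, so they genuinely differ from the input sizes:

# With padding="VALID": out = (in - kernel) // stride + 1
height_out = (10 - 5) // 5 + 1  # = 2, as reported in the error
width_out = (15 - 5) // 5 + 1   # = 3, as reported in the error
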
Hope you can check such bugs before merging to the main branch.

@guxm2021
Contributor

guxm2021 commented Jun 24, 2025

Also, the convolution layer does not seem to be compatible with jit_wrapper.Jitted. Here is my code:

inputs = jax.random.normal(
    key=jax.random.PRNGKey(42), shape=(1, 10, 15, 3)
)
pz_inputs = pz.nx.wrap(inputs.reshape(1, 10, 15, 3)).tag(
    "batch", "height", "width", "channel",
)
simple_layer = Conv.from_config(
    name="test",
    init_base_rng=jax.random.key(1),
    input_axes={"channel": 3},
    output_axes={"embedding": 1152},
    strides = (5, 5),
    convolution_spatial_axes={"height": 5, "width": 5},
    padding = "VALID",
    parallel_axes=None,
    parallel_broadcast_axes=None,
    rename_outputs_if_necessary=True,
)
simple_layer = jit_wrapper.Jitted(simple_layer)
simple_layer(pz_inputs)

Here is the error:

TypeError: Error interpreting argument to <function variable_jit.<locals>.inner_fun at 0x115b9b954680> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path layer.padding.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

Update: I found that when I change some lines in the Conv and AbstractGeneralConv classes like this:

strides: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
padding: str | Sequence[tuple[int, int]] = dataclasses.field(
    metadata={"pytree_node": False}
)
kernel_dilation: Sequence[int] = dataclasses.field(metadata={"pytree_node": False})
inputs_dilation: Sequence[int] = dataclasses.field(
    metadata={"pytree_node": False}
)

Then the errors disappear. Hope it helps.

@AntoinePlumerault
Author

For the issue with the stride, it seems the culprit is the shape check that fails:

out_struct = self._output_structure()
shapecheck.check_structure(
    result, out_struct, known_vars=dimvars, error_prefix=error_prefix
)

I guess it is possible to make it work, but it would take some effort to get right considering padding, stride, kernel size, input dilation, and kernel dilation. This check exists for the Linear layer; I am curious about the motivation behind it. Is it necessary, or should such a shape check only live in the tests?

@guxm2021
Contributor

I guess the motivation here is to check the dimensions and make sure they are not inadvertently changed, e.g., if the axis names are not exactly the same, resulting in extra dimensions. But for convolution layers, extra effort may be required. Currently, in my code, I commented out the lines for the check to make it work.

@danieldjohnson
Collaborator

Thanks @guxm2021 for pointing out these issues!

The jit_wrapper.Jitted issues should definitely be fixed before merging this. In general any attribute that affects the shape of the result should be marked with dataclasses.field(metadata={"pytree_node": False}).

The shape checking methods are mostly for convenience and to rule out bugs in the implementation. In this case it doesn't look like the shape annotations are correctly describing the changes to the spatial axes. It might be simplest to just remove the _input_structure and _output_structure methods and also remove the shape checking logic.
