Feature Request: Add convolution layers #120
Conversation
Thanks for the PR! Conv layers seem like they could be a pretty useful addition.
I left some comments on specific things I noticed. High level summary:
- A lot of the changes seem like reformatting, could you make sure you are using Penzai's formatting conventions?
- I think it might be simpler to combine Conv and ConvTranspose into one class.
- I notice you introduced a number of helper functions. Could you add more documentation about what they do and what they return, and also maybe mark them private using a leading underscore if they are not meant to be used by users? (Some of this might also be fixable by just inlining the implementation into the merged Conv/ConvTranspose class.)
penzai/core/named_axes.py

```python
def _nmap_with_doc(
    fun: Callable[..., Any], fun_name: str, fun_doc: str | None = None
```
I'm noticing a large amount of reformatting changes that made their way into this PR, which makes it hard to tell what the actual changes are.
Penzai currently uses pyink to format code (see our pyink config and CI checks). Can you reformat your code using Penzai's formatting configuration?
Yes, I was not using the right config; I will fix it.
penzai/nn/linear_and_affine.py

```diff
-xavier_uniform_initializer = functools.partial(
-    variance_scaling_initializer,
-    scale=1.0,
-    mode="fan_avg",
-    distribution="uniform",
+xavier_uniform_initializer = cast(
+    LinearOperatorWeightInitializer,
+    functools.partial(
+        variance_scaling_initializer,
+        scale=1.0,
+        mode="fan_avg",
+        distribution="uniform",
+    ),
```
Why is this needed?
My type checker was complaining, but with the right config the cast does not seem necessary indeed.
penzai/nn/linear_and_affine.py

```python
def maybe_rename_output_axes(
```
What does this do, and what does it return? It should have a docstring that explains what it does, especially if this is supposed to be called by users, and it should have a return type annotation. If it is private it should start with an underscore.
Is this refactoring out the logic for axis name conflicts?
Yes, the underscore is missing. You are right, I refactored out the logic for axis name collisions because the same logic is also needed for the convolution layers.
penzai/nn/linear_and_affine.py

```python
def prepare_for_conv(
    inputs: NamedArray,
    kernel: NamedArray,
    spatial_axis_names: Sequence[str],
    in_axis_names: Sequence[str],
    out_axis_names: Sequence[str],
):
  """Preprocess lhs and rhs for jax convolution operator"""

  lhs = inputs
  rhs = kernel

  in_axis_name = "in_axis-" + "-".join(in_axis_names)
  out_axis_name = "out_axis-" + "-".join(out_axis_names)

  # merge in axes into one in channel axis for the inputs and the kernel
  lhs = lhs.untag(*in_axis_names).flatten().tag(in_axis_name)
  rhs = rhs.untag(*in_axis_names).flatten().tag(in_axis_name)

  # merge out axes into one out channels axis for jax convolution
  rhs = rhs.untag(*out_axis_names).flatten().tag(out_axis_name)

  # untag spatial axes
  lhs = lhs.untag(*spatial_axis_names, in_axis_name)
  rhs = rhs.untag(*spatial_axis_names, in_axis_name, out_axis_name)
  return lhs, rhs


def get_named_axis_back_after_conv(
    result: NamedArray,
    spatial_axis_names: Sequence[str],
    out_axis_names: Sequence[str],
    out_axis_shape: Sequence[int],
):
  """Postprocess result from jax convolution operator"""
  # Get named axes back
  return (
      result.tag_prefix(*spatial_axis_names)
      .reshape(out_axis_shape)
      .tag(*out_axis_names)
  )


def maybe_broadcast(value: int | Sequence[int], count: int):
  return [value] * count if isinstance(value, int) else value


def get_dimension_numbers(ndim):
  return jax.lax.ConvDimensionNumbers(
      lhs_spec=(0, ndim + 1)
      + tuple(range(1, ndim + 1)),  # BHSpatial -> BCSpatial
      rhs_spec=(ndim + 1, ndim) + tuple(range(ndim)),  # SpatialIO -> OISpatial
      out_spec=(0, ndim + 1)
      + tuple(range(1, ndim + 1)),  # BSpatialC -> BCSpatial
  )
```
Docstrings for these could be a lot more detailed, and it would be good to add return type annotations. I am not sure from reading this what these functions are supposed to do? Also some of them are so short they could possibly just be inlined into the implementation.
I will add docstrings and underscores where needed, as these functions are mostly utils shared between Conv and ConvTranspose. The short ones are shortcuts for code that is otherwise quite obscure / not very readable (e.g. getting the dimension numbers). I assumed that putting this code inside functions keeps the Conv & ConvTranspose logic less cluttered with low-level implementation details.
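For illustration, one of these helpers after the requested changes (leading underscore, a fuller docstring, and a return type annotation) might end up looking roughly like this; this is a sketch of the direction, not the committed code:

```python
def _get_named_axis_back_after_conv(
    result: NamedArray,
    spatial_axis_names: Sequence[str],
    out_axis_names: Sequence[str],
    out_axis_shape: Sequence[int],
) -> NamedArray:
  """Restores named axes on the result of the JAX convolution operator.

  Args:
    result: Result of the JAX convolution, whose positional axes are assumed
      to be the spatial axes (in the order of `spatial_axis_names`) followed
      by the single merged output-channel axis.
    spatial_axis_names: Names to re-attach to the spatial axes.
    out_axis_names: Names of the output axes after un-merging.
    out_axis_shape: Sizes to un-merge the combined output-channel axis into.

  Returns:
    A NamedArray with its spatial and output axes named again.
  """
  return (
      result.tag_prefix(*spatial_axis_names)
      .reshape(out_axis_shape)
      .tag(*out_axis_names)
  )
```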
```python
@struct.pytree_dataclass
class AbstractGeneralConv(layer_base.Layer):
```
I'm noticing a lot of duplication between AbstractGeneralConv, Conv and ConvTranspose. What do you think about combining all of these into a single class and having transpose: bool be an attribute?
That makes sense, I will think about it. In other frameworks, Conv and ConvTranspose are usually different modules. For the user it would mean writing Conv(..., transpose=True) instead of ConvTranspose(...), which is a bit heavier and somewhat makes upsampling conv-transpose layers less identifiable.
Yeah, I can see arguments either way. Other options:
- keep them as separate classes but just have a single helper function that does the whole implementation (maybe _conv_call), with an argument transpose
- implement all of the logic in AbstractGeneralConv and have an abstract class method def _is_transposed_conv(cls) -> bool that tells it which one to use, then just implement that method on the two subclasses (a rough sketch of this option follows below)
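For concreteness, a minimal toy sketch of option 2 (deliberately simplified, without the actual Penzai dataclass and layer machinery) could look like:

```python
import abc


class AbstractGeneralConv(abc.ABC):
  """Toy skeleton: all shared convolution logic would live on this base class."""

  @classmethod
  @abc.abstractmethod
  def _is_transposed_conv(cls) -> bool:
    """Tells the shared implementation which convolution variant to apply."""

  def describe(self) -> str:
    # The shared implementation branches on the abstract class method.
    return "transposed conv" if self._is_transposed_conv() else "regular conv"


class Conv(AbstractGeneralConv):
  @classmethod
  def _is_transposed_conv(cls) -> bool:
    return False


class ConvTranspose(AbstractGeneralConv):
  @classmethod
  def _is_transposed_conv(cls) -> bool:
    return True
```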
Thanks for the ideas, I implemented option 2 in the latest commit.
tests/nn/linear_and_affine_test.py

```python
def test_conv(self):
```
Thanks for adding tests!
Great, thanks much! Left a few more comments.
penzai/nn/linear_and_affine.py

```python
in_axis_name = "in_axis-" + "-".join(in_axis_names)
out_axis_name = "out_axis-" + "-".join(out_axis_names)
```
Minor, but I noticed these are temporary axis names that are only used inside this function. There's actually an existing helper for this, TmpPosAxisMarker, which can be used to give a temporary name to a positional axis in a way that is guaranteed not to conflict with anything else. Maybe you could use this?
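For reference, a rough sketch of what that swap could look like inside the helper, assuming TmpPosAxisMarker takes no constructor arguments and can be used wherever an axis name is expected (worth double-checking against the actual helper):

```python
# Sketch only: replace the string-based temporary names with marker objects
# that are guaranteed not to collide with user-provided axis names.
in_axis_name = TmpPosAxisMarker()
out_axis_name = TmpPosAxisMarker()

# Merge the input axes into one temporary channel axis on both operands.
lhs = lhs.untag(*in_axis_names).flatten().tag(in_axis_name)
rhs = rhs.untag(*in_axis_names).flatten().tag(in_axis_name)

# Merge the output axes into one temporary channel axis on the kernel.
rhs = rhs.untag(*out_axis_names).flatten().tag(out_axis_name)
```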
I did not know it existed; I use it now.
penzai/nn/linear_and_affine.py

```python
    A tuple of two named arrays, the first one is the input with the in axes
    merged into a single in channel axis, and the second one is the kernel with
    the in axes merged into a single in channel axis and the out axes merged
    into a single out channel axis.
```
If I understand right, the two named arrays have a very specific order of their positional axes, where all spatial axes are converted to positional (in the order given by spatial_axis_names) and then there are extra positional axes for the new input and output axes. Could you add a brief note in the docstring describing this layout?
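Something along these lines in the Returns section could capture that layout (suggested wording only):

```python
  Returns:
    A tuple (lhs, rhs). In both arrays the spatial axes have been converted to
    positional axes, in the order given by `spatial_axis_names`, followed by
    the merged input-channel axis; the kernel additionally has the merged
    output-channel axis as its last positional axis.
```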
Yes, you understand correctly; I added that detail to the docstring.
```python
    a single axis before applying the convolution.

  Args:
    result: The result of the jax convolution operator.
```
Similarly, I think your implementation here assumes that the result has a very specific ordering of its positional axes (spatial axes in the order given by spatial_axis_names and then the combined out axis). Mind documenting it here?
Same here, it is updated.
```python
def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
```
Is it important that the sequence have length count? If so it might make sense to check.
I added a check and changed the ConvTranspose implementation a bit to make it pass the check.
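For reference, such a length check might look roughly like this (a sketch of the idea, not the exact committed change; assumes Sequence is imported as in the existing module):

```python
def _maybe_broadcast(value: int | Sequence[int], count: int) -> Sequence[int]:
  """Broadcasts an int to `count` entries, or checks a sequence has length `count`."""
  if isinstance(value, int):
    return [value] * count
  if len(value) != count:
    raise ValueError(
        f"Expected a sequence of length {count}, got {len(value)}: {value!r}"
    )
  return list(value)
```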
```python
def _get_dimension_numbers(ndim) -> jax.lax.ConvDimensionNumbers:
  """Returns the dimension numbers for a convolution operator.
```
Is the idea that this dimension numbers object is specialized to the version where the spatial axes come first, followed by the in/out axes? Maybe you could add a note that this dimension numbers matches the shapes returned by _prepare_for_conv, if that's the goal?
Yes, I added that detail to the docstring.
penzai/nn/linear_and_affine.py

```python
  return ConvInPlace(
      sublayers=[
          core_layer,
          RenameAxes(old=tuple(primed_names), new=tuple(original_names)),
      ],
  )
```
I guess this doesn't quite work for ConvTranspose since it won't produce ConvTransposeInPlace?
One option would be to turn from_config into a private helper on AbstractGeneralConv (since you override it anyway), and pass in the in-place class as a keyword argument.
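As a toy illustration of that suggestion (names and structure here are simplified stand-ins, not the actual Penzai classes), the shared builder could take the in-place wrapper class as an argument:

```python
import dataclasses
from typing import Type, Union


@dataclasses.dataclass
class _InPlaceWrapper:
  """Toy stand-in for the ...InPlace combinators (core layer plus rename step)."""
  sublayers: list


class _ConvInPlace(_InPlaceWrapper):
  pass


class _ConvTransposeInPlace(_InPlaceWrapper):
  pass


class _AbstractGeneralConv:
  """Toy base class: owns the shared from_config logic."""

  @classmethod
  def _from_config(
      cls, *, in_place_cls: Type[_InPlaceWrapper], needs_rename: bool
  ) -> Union["_AbstractGeneralConv", _InPlaceWrapper]:
    core_layer = cls()
    if needs_rename:
      # In the real code this would also append a RenameAxes sublayer.
      return in_place_cls(sublayers=[core_layer])
    return core_layer


class _Conv(_AbstractGeneralConv):
  @classmethod
  def from_config(cls, needs_rename: bool = True):
    return cls._from_config(in_place_cls=_ConvInPlace, needs_rename=needs_rename)


class _ConvTranspose(_AbstractGeneralConv):
  @classmethod
  def from_config(cls, needs_rename: bool = True):
    return cls._from_config(
        in_place_cls=_ConvTransposeInPlace, needs_rename=needs_rename
    )
```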
Nice catch, it should be fixed now
```python
    return core_layer

  def _is_transposed(self) -> bool:
```
mind adding @abc.abstractmethod here?
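That is, something along these lines:

```python
@abc.abstractmethod
def _is_transposed(self) -> bool:
  ...
```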
Done. Thank you for your help :)
Thanks, this looks great! I'd be happy to merge this if you think it's in good shape (looks like the PR is still a draft PR though).
One thought: do you think it would be valuable to add a test that actually makes sure the convolution operator matches the JAX version? e.g. pass in an array of random values, do the convolution normally with jax on an unwrapped version of it, and check that the results are the same?
This would make sure that nothing went wrong with the implementation somehow. Although I'm not sure this adds that much, since probably most errors would lead to a wrong shape. (Maybe it would be able to catch if you accidentally used a regular conv instead of a transposed conv or something like this?)
I added tests
Looks like there's a few small type errors right now, actually:

```
/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1140:1: error: in _from_config: bad return type [bad-return-type]
  Expected: Union[Conv, ConvInPlace, ConvTranspose, ConvTransposeInPlace]
  Actually returned: AbstractGeneralConv
    return core_layer

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1194:1: error: in <dictcomp>: No attribute 'convolution_spatial_axis_names' on AbstractGeneralConv [attribute-error]
    if name not in self.convolution_spatial_axis_names
  Called from (traceback):
    line 1191, in parallel_axes

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1207:1: error: in <dictcomp>: No attribute 'spatial_axes_names' on AbstractGeneralConv [attribute-error]
    if name in self.spatial_axes_names
  Called from (traceback):
    line 1204, in convolution_spatial_axes

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1326:1: error: in from_config: bad option 'Union[ConvTranspose, ConvTransposeInPlace]' in return type [bad-return-type]
  Expected: Union[Conv, ConvInPlace]
  Actually returned: Union[Conv, ConvInPlace, ConvTranspose, ConvTransposeInPlace]
    return super()._from_config(

/home/runner/work/penzai/penzai/penzai/nn/linear_and_affine.py:1461:1: error: in from_config: bad option 'Union[Conv, ConvInPlace]' in return type [bad-return-type]
  Expected: Union[ConvTranspose, ConvTransposeInPlace]
  Actually returned: Union[Conv, ConvInPlace, ConvTranspose, ConvTransposeInPlace]
    return super()._from_config(
```
It should be fixed with the last commit. I added casts in the Conv and ConvTranspose implementations.
I noticed that I did not implement bias for Conv and ConvTranspose. Do you think I should add an AffineConv & AffineConvTranspose like you did with the Linear layer? Maybe it's too verbose and a use_bias parameter would do the job, but it would not be consistent with the existing Linear & Affine layers. Or we could make Affine a wrapper, but that would require significant changes and will probably break things.
Hmm good question. At least in terms of the concrete layers that get built, I think it makes sense to have the Conv/ConvTranspose and AddBias layers separate, rather than a single layer that does both. It could be nice to have a wrapper constructor like

```python
pz.nn.Sequential([
    pz.nn.Conv.from_config(
        name="test",
        init_base_rng=rng,
        input_axes={"foo": 3, "baz": 7},
        output_axes={"foo_out": 5, "baz_out": 11},
        convolution_spatial_axes={"height": 3, "width": 3},
        parallel_axes=None,
        parallel_broadcast_axes=None,
        rename_outputs_if_necessary=True,
    ),
    pz.nn.AddBias.from_config(
        name="test_bias",
        init_base_rng=rng,
        biased_axes={"foo_out": 5, "baz_out": 11},
    )
])
```

(Also, seems like we are still getting type errors in CI, maybe due to outdated type annotations? You can run the typechecking locally with […].)
Thank you for the efforts to add convolution layers, which help me a lot. However, when I use your convolution layer, I ran into some bugs. Specifically, when I set […]. And here is the error: […]. Hope you can check such bugs before merging to the […].
And the […]. Here is the error: […]. Update: I find that when I change some lines in class […], the errors disappear. Hope it helps.
For the issue with the stride, it seems the culprit is the shape check that fails:

```python
out_struct = self._output_structure()
shapecheck.check_structure(
    result, out_struct, known_vars=dimvars, error_prefix=error_prefix
)
```

I guess it is possible to make it work, but it would involve some effort to get it right considering padding, stride, kernel size, input dilation and kernel dilation. This check exists for the Linear layer; I am curious about the motivation behind it. Is it necessary, or should such a shape check only live in the tests?
I guess the motivation here is to check the dimensions and make sure they are not inadvertently changed, e.g., when the axis names are not exactly the same, resulting in extra dimensions. But for convolution layers, extra effort may be required. Currently, in my code, I commented out the lines for the check to make it work.
Thanks @guxm2021 for pointing out these issues! The […]. The shape checking methods are mostly for convenience and to rule out bugs in the implementation. In this case it doesn't look like the shape annotations are correctly describing the changes to the spatial axes. It might be simplest to just remove the […].
Hello, thank you for the amazing work. I really like the design choices of the library and I would like to contribute to the project.
I noticed that convolution layers are missing, and I think it would be a nice addition.
So I made a draft implementation of Conv & ConvTranspose layers following what you did with the Linear layer. I would be happy to discuss the implementation.