This repository was archived by the owner on May 11, 2023. It is now read-only.

bug: inconsistency in return value log_det #21

@vdutor

Hi,

Thanks for the great work!

I believe I found an inconsistency in the shape of the value returned by log_det across different operators. In the example below, ConstantDiagonalLinearOperator returns a one-dimensional array with a single element, Float[Array, "1"], while DenseLinearOperator returns a scalar (zero-dimensional) array, Float[Array, ""]. The type annotations in the function signature suggest that Float[Array, "1"] is the intended one.

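```python
import jax.numpy as jnp
import jaxlinop

# Constant diagonal operator 1.0 * I_3; `value` is passed as a shape-(1,) array.
diag = jaxlinop.ConstantDiagonalLinearOperator(value=jnp.array([1.0]), size=3)
log_det_diag = diag.log_det()
print(log_det_diag)  # >>> [0.]

# Dense operator wrapping the same 3x3 identity matrix.
dense = jaxlinop.DenseLinearOperator(matrix=jnp.eye(3))
log_det_dense = dense.log_det()
print(log_det_dense)  # >>> 0.0
```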
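A minimal sketch of where the extra dimension plausibly comes from, and one way to normalize it. It assumes (not confirmed from the jaxlinop source) that the constant-diagonal log-determinant is computed as size * log(value) with value kept at shape (1,); for c * I_n we have det = c**n, so log det = n * log(c). The helpers log_det_constant_diagonal and log_det_dense are hypothetical stand-ins, not jaxlinop's API.

```python
import jax.numpy as jnp


def log_det_constant_diagonal(value, size):
    # Hypothetical helper: log-determinant of value * I_size, i.e. size * log(value).
    # Reshaping the single-element `value` to shape () before the arithmetic
    # makes the result a 0-d array instead of a shape-(1,) array.
    c = jnp.reshape(value, ())
    return size * jnp.log(c)


def log_det_dense(matrix):
    # Hypothetical dense counterpart using slogdet; the sign is discarded,
    # which is only valid for matrices with a positive determinant (e.g. PD).
    _, logabsdet = jnp.linalg.slogdet(matrix)
    return logabsdet


value = jnp.array([1.0])
print((3 * jnp.log(value)).shape)                 # (1,), the shape seen in the issue
print(log_det_constant_diagonal(value, 3).shape)  # ()
print(log_det_dense(jnp.eye(3)).shape)            # ()
```

Whether the scalar Float[Array, ""] or Float[Array, "1"] is the right contract is the maintainers' call; the point is that both operators should agree.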

I'd be happy to submit a fix if you believe this is a bug.

Side question: have you considered running runtime type checkers (e.g., https://github.com/beartype/beartype) on this code in the CI?

Thanks,
Vincent

Metadata

Labels: bug (Something isn't working), good first issue (Good for newcomers)
