This repository was archived by the owner on May 11, 2023. It is now read-only.

Description
Hi,
Thanks for the great work!
I believe I found an inconsistency in the shape returned by the `log_det` method across operators. In the example below, `ConstantDiagonalLinearOperator` returns a one-dimensional array with a single element (`Float[Array, "1"]`), while `DenseLinearOperator` returns a scalar (`Float[Array, ""]`). The type annotations in the method signature suggest that `Float[Array, "1"]` is the intended shape.
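```python
import jax.numpy as jnp
import jaxlinop

# Constant diagonal operator: a 3x3 identity stored as a single repeated value.
diag = jaxlinop.ConstantDiagonalLinearOperator(value=jnp.array([1.0]), size=3)
log_det_diag = diag.log_det()
print(log_det_diag)  # >>> [0.]   (shape (1,))

# Dense operator wrapping the same 3x3 identity matrix.
dense = jaxlinop.DenseLinearOperator(matrix=jnp.eye(3))
log_det_dense = dense.log_det()
print(log_det_dense)  # >>> 0.0   (shape ())
```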
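A minimal pure-JAX sketch of where the mismatch likely comes from, assuming (not confirmed from the jaxlinop source) that the constant-diagonal operator computes its log-determinant as `size * log(value)` while the dense operator goes through a log-determinant routine such as `jnp.linalg.slogdet`. Because `value` keeps the shape `(1,)` it was constructed with, the result inherits that shape, whereas `slogdet` returns a 0-d array. Whichever shape is chosen as canonical, an explicit `jnp.reshape` makes the convention unambiguous:

```python
import jax.numpy as jnp

# Assumption: constant-diagonal log-det is size * log(value); `value` has shape (1,).
value = jnp.array([1.0])
size = 3
log_det_constant = size * jnp.log(value)  # shape (1,), inherited from `value`

# Dense log-det via slogdet: returns (sign, logabsdet), each a 0-d array here.
sign, log_det_dense = jnp.linalg.slogdet(jnp.eye(3))  # log_det_dense has shape ()

print(log_det_constant.shape, log_det_dense.shape)

# Option A: follow the Float[Array, "1"] annotation and lift the scalar to shape (1,).
dense_as_vector = jnp.reshape(log_det_dense, (1,))

# Option B: treat the log-determinant as a scalar and drop the singleton axis.
constant_as_scalar = jnp.reshape(log_det_constant, ())

print(dense_as_vector.shape, constant_as_scalar.shape)
```

Either option works; the point is that one of the two operators needs an explicit reshape so that every `log_det` implementation returns the same shape.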
I'd be happy to submit a fix if you believe this is a bug.
Side question: have you considered running runtime type checkers (e.g., https://github.com/beartype/beartype) on this code in the CI?
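For illustration, here is the kind of check a runtime type checker would perform in CI. This is a dependency-free, hand-rolled stand-in and not beartype's or jaxtyping's API; in practice one would combine jaxtyping's `jaxtyped` decorator with beartype on the annotated methods. The decorator `returns_shape` and the two helper functions are hypothetical and only mimic the two operators' behaviour under the same `size * log(value)` assumption as above:

```python
import functools

import jax.numpy as jnp


def returns_shape(expected_shape):
    """Hypothetical decorator: raise TypeError if the output shape differs from expected_shape."""
    expected = tuple(expected_shape)

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            out = fn(*args, **kwargs)
            if jnp.shape(out) != expected:
                raise TypeError(
                    f"{fn.__name__} returned shape {jnp.shape(out)}, expected {expected}"
                )
            return out

        return wrapper

    return decorator


@returns_shape(())
def dense_log_det(matrix):
    # Hypothetical helper mimicking the dense operator: slogdet's log|det| is 0-d.
    return jnp.linalg.slogdet(matrix)[1]


@returns_shape(())
def constant_diag_log_det(value, size):
    # Hypothetical helper mimicking the constant-diagonal operator: shape follows `value`.
    return size * jnp.log(value)


dense_log_det(jnp.eye(3))  # passes: shape ()
try:
    constant_diag_log_det(jnp.array([1.0]), 3)  # fails: shape (1,)
except TypeError as err:
    print(err)
```

Run under a test suite, a check like this turns the silent shape mismatch into a hard failure, which is what beartype would do automatically from the existing annotations.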
Thanks,
Vincent