Skip to content

Commit 9bc0d1e

Browse files
Have low_rank_cholesky return the diagonal of the residual, as it is useful
when preconditioning. PiperOrigin-RevId: 529077632
1 parent 04c1006 commit 9bc0d1e

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

tensorflow_probability/python/math/linalg.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,15 @@ def low_rank_cholesky(matrix, max_rank, trace_atol=0, trace_rtol=0, name=None):
428428
name: Optional name for the op.
429429
430430
Returns:
431-
A pair (LR, r) of a matrix LR such that the rank of LR is r <= max_rank
432-
and if r is < max_rank, trace(matrix - LR * LR^t) < trace_atol.
433-
If matrix is of shape (b1, ..., bn, m, m), then LR will be of shape
434-
(b1, ..., bn, m, r) where r <= max_rank.
431+
A triplet (LR, r, residual_diag) of
432+
LR: a matrix such that LR * LR^t is approximately the input matrix.
433+
If matrix is of shape (b1, ..., bn, m, m), then LR will be of shape
434+
(b1, ..., bn, m, r) where r <= max_rank.
435+
r: the rank of LR. If r is < max_rank, then
436+
trace(matrix - LR * LR^t) < trace_atol, and
437+
residual_diag: The diagonal entries of matrix - LR * LR^t. This is
438+
returned because together with LR, it is useful for preconditioning
439+
the input matrix.
435440
"""
436441
with tf.name_scope(name or 'low_rank_cholesky'):
437442
dtype = dtype_util.common_dtype([matrix, trace_atol, trace_rtol],
@@ -498,14 +503,14 @@ def lr_cholesky_body(i, lr, residual_diag):
498503
lr = tf.zeros(matrix.shape, dtype=matrix.dtype)[..., :max_rank]
499504

500505
mdiag = tf.linalg.diag_part(matrix)
501-
i, lr, _ = tf.while_loop(
506+
i, lr, residual_diag = tf.while_loop(
502507
cond=lr_cholesky_cond,
503508
body=lr_cholesky_body,
504509
loop_vars=(0, lr, mdiag),
505510
maximum_iterations=max_rank
506511
)
507512

508-
return lr, i
513+
return lr, i, residual_diag
509514

510515

511516
def lu_solve(lower_upper, perm, rhs,

tensorflow_probability/python/math/linalg_test.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -429,16 +429,24 @@ def testLowRankCholesky(self):
429429
matrix = self._random_batch_psd(dim)
430430
true_diag = tf.linalg.diag_part(matrix)
431431

432-
pchol, r = linalg.low_rank_cholesky(matrix, max_rank=1)
432+
pchol, r, residual_diag = linalg.low_rank_cholesky(matrix, max_rank=1)
433433
self.assertEqual(1, self.evaluate(r))
434+
self.assertEqual((2, 11), residual_diag.shape)
434435
mat = tf.matmul(pchol, pchol, transpose_b=True)
435436
diag_diff_prev = self.evaluate(tf.abs(tf.linalg.diag_part(mat) - true_diag))
436437
diff_norm_prev = self.evaluate(
437438
tf.linalg.norm(mat - matrix, ord='fro', axis=[-1, -2]))
439+
old_residual_trace = None
438440
for rank in range(2, dim + 1):
439441
# Specifying trace_rtol forces the full max_rank decomposition.
440-
pchol, r = linalg.low_rank_cholesky(matrix, max_rank=rank, trace_rtol=-1)
442+
pchol, r, residual_diag = linalg.low_rank_cholesky(
443+
matrix, max_rank=rank, trace_rtol=-1)
441444
self.assertEqual(rank, self.evaluate(r))
445+
residual_trace = tf.math.reduce_sum(residual_diag, axis=-1)
446+
if old_residual_trace is not None:
447+
self.assertTrue(self.evaluate(tf.reduce_all(
448+
residual_trace < old_residual_trace)))
449+
old_residual_trace = residual_trace
442450
# Compared to pivot_cholesky, low_rank_cholesky will sometimes have
443451
# approximate zeros like 7e-17 or -2.6e-7 where it "should" have a
444452
# real zero.
@@ -471,7 +479,7 @@ def testGradient(self):
471479
dim = 11
472480

473481
def fn(matrix):
474-
chol, _ = linalg.low_rank_cholesky(matrix, max_rank=dim // 3)
482+
chol, _, _ = linalg.low_rank_cholesky(matrix, max_rank=dim // 3)
475483
return chol
476484
def grad(matrix):
477485
_, dmatrix = gradient.value_and_gradient(fn, matrix)
@@ -494,7 +502,7 @@ def testGradientTapeCFv2(self):
494502
def grad(matrix):
495503
with tf.GradientTape() as tape:
496504
tape.watch(matrix)
497-
pchol, _ = linalg.low_rank_cholesky(matrix, max_rank=dim // 3)
505+
pchol, _, _ = linalg.low_rank_cholesky(matrix, max_rank=dim // 3)
498506
dmatrix = tape.gradient(
499507
pchol, matrix, output_gradients=tf.ones_like(pchol) * .01)
500508
return dmatrix
@@ -561,7 +569,7 @@ def testOracleExamples(self, mat, oracle_pchol):
561569

562570
mat = np.matmul(mat, mat.T)
563571
for rank in range(1, max_rank):
564-
lr_chol, r = fns[rank](mat)
572+
lr_chol, r, _ = fns[rank](mat)
565573
self.assertEqual(self.evaluate(r), rank)
566574
self.assertAllClose(
567575
oracle_pchol[..., :rank], lr_chol[..., :rank], atol=1e-4)

0 commit comments

Comments
 (0)