@@ -429,16 +429,24 @@ def testLowRankCholesky(self):
429
429
matrix = self ._random_batch_psd (dim )
430
430
true_diag = tf .linalg .diag_part (matrix )
431
431
432
- pchol , r = linalg .low_rank_cholesky (matrix , max_rank = 1 )
432
+ pchol , r , residual_diag = linalg .low_rank_cholesky (matrix , max_rank = 1 )
433
433
self .assertEqual (1 , self .evaluate (r ))
434
+ self .assertEqual ((2 , 11 ), residual_diag .shape )
434
435
mat = tf .matmul (pchol , pchol , transpose_b = True )
435
436
diag_diff_prev = self .evaluate (tf .abs (tf .linalg .diag_part (mat ) - true_diag ))
436
437
diff_norm_prev = self .evaluate (
437
438
tf .linalg .norm (mat - matrix , ord = 'fro' , axis = [- 1 , - 2 ]))
439
+ old_residual_trace = None
438
440
for rank in range (2 , dim + 1 ):
439
441
# Specifying trace_rtol forces the full max_rank decomposition.
440
- pchol , r = linalg .low_rank_cholesky (matrix , max_rank = rank , trace_rtol = - 1 )
442
+ pchol , r , residual_diag = linalg .low_rank_cholesky (
443
+ matrix , max_rank = rank , trace_rtol = - 1 )
441
444
self .assertEqual (rank , self .evaluate (r ))
445
+ residual_trace = tf .math .reduce_sum (residual_diag , axis = - 1 )
446
+ if old_residual_trace is not None :
447
+ self .assertTrue (self .evaluate (tf .reduce_all (
448
+ residual_trace < old_residual_trace )))
449
+ old_residual_trace = residual_trace
442
450
# Compared to pivot_cholesky, low_rank_cholesky will sometimes have
443
451
# approximate zeros like 7e-17 or -2.6e-7 where it "should" have a
444
452
# real zero.
@@ -471,7 +479,7 @@ def testGradient(self):
471
479
dim = 11
472
480
473
481
def fn (matrix ):
474
- chol , _ = linalg .low_rank_cholesky (matrix , max_rank = dim // 3 )
482
+ chol , _ , _ = linalg .low_rank_cholesky (matrix , max_rank = dim // 3 )
475
483
return chol
476
484
def grad (matrix ):
477
485
_ , dmatrix = gradient .value_and_gradient (fn , matrix )
@@ -494,7 +502,7 @@ def testGradientTapeCFv2(self):
494
502
def grad (matrix ):
495
503
with tf .GradientTape () as tape :
496
504
tape .watch (matrix )
497
- pchol , _ = linalg .low_rank_cholesky (matrix , max_rank = dim // 3 )
505
+ pchol , _ , _ = linalg .low_rank_cholesky (matrix , max_rank = dim // 3 )
498
506
dmatrix = tape .gradient (
499
507
pchol , matrix , output_gradients = tf .ones_like (pchol ) * .01 )
500
508
return dmatrix
@@ -561,7 +569,7 @@ def testOracleExamples(self, mat, oracle_pchol):
561
569
562
570
mat = np .matmul (mat , mat .T )
563
571
for rank in range (1 , max_rank ):
564
- lr_chol , r = fns [rank ](mat )
572
+ lr_chol , r , _ = fns [rank ](mat )
565
573
self .assertEqual (self .evaluate (r ), rank )
566
574
self .assertAllClose (
567
575
oracle_pchol [..., :rank ], lr_chol [..., :rank ], atol = 1e-4 )
0 commit comments