Skip to content

Commit f870ac4

Browse files
csutertensorflower-gardener
authored andcommitted
Remove dependency on deprecated scipy.misc.derivative
PiperOrigin-RevId: 757824889
1 parent 1429fdf commit f870ac4

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

tensorflow_probability/python/distributions/exp_gamma_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# ============================================================================
1515
from absl.testing import parameterized
1616
import numpy as np
17-
from scipy import misc as sp_misc
1817
from scipy import special as sp_special
1918
from scipy import stats as sp_stats
2019

@@ -671,8 +670,13 @@ def gen_samples(concentration, rate):
671670
def expected_grad(s, c, r):
672671
u = sp_special.gammainc(c, s * r)
673672
delta = 1e-4
674-
return sp_misc.derivative(
675-
lambda x: sp_special.gammaincinv(x, u), c, dx=delta * c) / r
673+
674+
def _finite_diff(f, x, dx):
675+
return (f(x + dx) - f(x - dx)) / 2 / dx
676+
677+
return _finite_diff(
678+
f=lambda x: sp_special.gammaincinv(x, u), x=c, dx=delta * c,
679+
) / r
676680

677681
self.assertAllClose(
678682
concentration_grad,

tensorflow_probability/python/distributions/gamma_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from absl.testing import parameterized
1919
import numpy as np
20-
from scipy import misc as sp_misc
2120
from scipy import special as sp_special
2221
from scipy import stats as sp_stats
2322

@@ -945,8 +944,12 @@ def gen_samples(concentration, rate):
945944
def expected_grad(s, c, r):
946945
u = sp_special.gammainc(c, s * r)
947946
delta = 1e-4
948-
return sp_misc.derivative(
949-
lambda x: sp_special.gammaincinv(x, u), c, dx=delta * c) / r
947+
def _finite_diff(f, x, dx):
948+
return (f(x + dx) - f(x - dx)) / 2 / dx
949+
950+
return _finite_diff(
951+
f=lambda x: sp_special.gammaincinv(x, u), x=c, dx=delta * c,
952+
) / r
950953

951954
self.assertAllClose(
952955
concentration_grad,

0 commit comments

Comments
 (0)