13
13
14
14
try :
15
15
# optional functionality
16
- import pymc3 as pm
17
- import theano .tensor as tt
16
+ import pymc as pm
17
+ import pytensor .tensor as pt
18
18
import arviz as az
19
19
from fastprogress import fastprogress
20
20
except (ImportError , AttributeError ) as e :
@@ -77,8 +77,8 @@ def sample_sinusoid(time, flux, const, slope, f_n, a_n, ph_n, c_err, sl_err, f_n
77
77
"""
78
78
# setup
79
79
time_t = time .reshape (- 1 , 1 ) # transposed time
80
- t_mean = tt .as_tensor_variable (np .mean (time ))
81
- t_mean_s = tt .as_tensor_variable (np .array ([np .mean (time [s [0 ]:s [1 ]]) for s in i_chunks ]))
80
+ t_mean = pt .as_tensor_variable (np .mean (time ))
81
+ t_mean_s = pt .as_tensor_variable (np .array ([np .mean (time [s [0 ]:s [1 ]]) for s in i_chunks ]))
82
82
lin_shape = (len (const ),)
83
83
sin_shape = (len (f_n ),)
84
84
# progress bar
@@ -97,7 +97,7 @@ def sample_sinusoid(time, flux, const, slope, f_n, a_n, ph_n, c_err, sl_err, f_n
97
97
slope_pm = pm .Normal ('slope' , mu = slope , sigma = sl_err , shape = lin_shape , testval = slope )
98
98
# piece-wise linear curve
99
99
linear_curves = [const_pm [k ] + slope_pm [k ] * (time [s [0 ]:s [1 ]] - t_mean_s [k ]) for k , s in enumerate (i_chunks )]
100
- model_linear = tt .concatenate (linear_curves )
100
+ model_linear = pt .concatenate (linear_curves )
101
101
# sinusoid parameter models
102
102
f_n_pm = pm .TruncatedNormal ('f_n' , mu = f_n , sigma = f_n_err , lower = 0 , shape = sin_shape , testval = f_n )
103
103
a_n_pm = pm .TruncatedNormal ('a_n' , mu = a_n , sigma = a_n_err , lower = 0 , shape = sin_shape , testval = a_n )
@@ -111,8 +111,7 @@ def sample_sinusoid(time, flux, const, slope, f_n, a_n, ph_n, c_err, sl_err, f_n
111
111
112
112
# do the sampling
113
113
with lc_model :
114
- inf_data = pm .sample (draws = 1000 , tune = 1000 , init = 'adapt_diag' , cores = 1 , progressbar = (logger is not None ),
115
- return_inferencedata = True )
114
+ inf_data = pm .sample (draws = 1000 , tune = 1000 , init = 'adapt_diag' , cores = 1 , progressbar = (logger is not None ))
116
115
117
116
if logger is not None :
118
117
az .summary (inf_data , round_to = 2 , circ_var_names = ['ph_n' ])
@@ -201,8 +200,8 @@ def sample_sinusoid_h(time, flux, p_orb, const, slope, f_n, a_n, ph_n, p_err, c_
201
200
"""
202
201
# setup
203
202
time_t = time .reshape (- 1 , 1 ) # transposed time
204
- t_mean = tt .as_tensor_variable (np .mean (time ))
205
- t_mean_s = tt .as_tensor_variable (np .array ([np .mean (time [s [0 ]:s [1 ]]) for s in i_chunks ]))
203
+ t_mean = pt .as_tensor_variable (np .mean (time ))
204
+ t_mean_s = pt .as_tensor_variable (np .array ([np .mean (time [s [0 ]:s [1 ]]) for s in i_chunks ]))
206
205
harmonics , harmonic_n = star_shine .core .frequency_sets .find_harmonics_from_pattern (f_n , 1 / p_orb , f_tol = 1e-9 )
207
206
non_harm = np .delete (np .arange (len (f_n )), harmonics )
208
207
lin_shape = (len (const ),)
@@ -224,7 +223,7 @@ def sample_sinusoid_h(time, flux, p_orb, const, slope, f_n, a_n, ph_n, p_err, c_
224
223
slope_pm = pm .Normal ('slope' , mu = slope , sigma = sl_err , shape = lin_shape , testval = slope )
225
224
# piece-wise linear curve
226
225
linear_curves = [const_pm [k ] + slope_pm [k ] * (time [s [0 ]:s [1 ]] - t_mean_s [k ]) for k , s in enumerate (i_chunks )]
227
- model_linear = tt .concatenate (linear_curves )
226
+ model_linear = pt .concatenate (linear_curves )
228
227
# sinusoid parameter models
229
228
f_n_pm = pm .TruncatedNormal ('f_n' , mu = f_n [non_harm ], sigma = f_n_err [non_harm ], lower = 0 , shape = sin_shape ,
230
229
testval = f_n [non_harm ])
@@ -250,8 +249,7 @@ def sample_sinusoid_h(time, flux, p_orb, const, slope, f_n, a_n, ph_n, p_err, c_
250
249
251
250
# do the sampling
252
251
with lc_model :
253
- inf_data = pm .sample (draws = 1000 , tune = 1000 , init = 'adapt_diag' , cores = 1 , progressbar = (logger is not None ),
254
- return_inferencedata = True )
252
+ inf_data = pm .sample (draws = 1000 , tune = 1000 , init = 'adapt_diag' , cores = 1 , progressbar = (logger is not None ))
255
253
256
254
if logger is not None :
257
255
az .summary (inf_data , round_to = 2 , circ_var_names = ['ph_n' ])
0 commit comments