2222import dpctl .tensor ._tensor_impl as ti
2323import dpctl .tensor ._tensor_reductions_impl as tri
2424
25- from ._reduction import _default_reduction_dtype
26-
2725
2826def _var_impl (x , axis , correction , keepdims ):
2927 nd = x .ndim
@@ -233,22 +231,25 @@ def mean(x, axis=None, keepdims=False):
233231 host_tasks_list .append (ht_e1 )
234232 s_e .append (r_e )
235233 else :
236- tmp_dt = _default_reduction_dtype (inp_dt , q )
237234 tmp = dpt .empty (
238- res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
235+ arr2 . shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
239236 )
240- ht_e_tmp , r_e = tri . _sum_over_axis (
241- src = arr2 , trailing_dims_to_reduce = sum_nd , dst = tmp , sycl_queue = q
237+ ht_e_cpy , cpy_e = ti . _copy_usm_ndarray_into_usm_ndarray (
238+ src = arr2 , dst = tmp , sycl_queue = q
242239 )
243- host_tasks_list .append (ht_e_tmp )
240+ host_tasks_list .append (ht_e_cpy )
244241 res = dpt .empty (
245242 res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
246243 )
247- ht_e1 , c_e = ti ._copy_usm_ndarray_into_usm_ndarray (
248- src = tmp , dst = res , sycl_queue = q , depends = [r_e ]
244+ ht_e_red , r_e = tri ._sum_over_axis (
245+ src = tmp ,
246+ trailing_dims_to_reduce = sum_nd ,
247+ dst = res ,
248+ sycl_queue = q ,
249+ depends = [cpy_e ],
249250 )
250- host_tasks_list .append (ht_e1 )
251- s_e .append (c_e )
251+ host_tasks_list .append (ht_e_red )
252+ s_e .append (r_e )
252253
253254 if keepdims :
254255 res_shape = res_shape + (1 ,) * sum_nd
@@ -257,8 +258,9 @@ def mean(x, axis=None, keepdims=False):
257258
258259 res_shape = res .shape
259260 # in-place divide
261+ den_dt = dpt .finfo (res_dt ).dtype if res_dt .kind == "c" else res_dt
260262 nelems_arr = dpt .asarray (
261- nelems , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
263+ nelems , dtype = den_dt , usm_type = res_usm_type , sycl_queue = q
262264 )
263265 if nelems_arr .shape != res_shape :
264266 nelems_arr = dpt .broadcast_to (nelems_arr , res_shape )
0 commit comments