You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
voidadam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
15
-
CHECK_INPUT(p);
16
-
if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
17
-
CHECK_INPUT(m);
18
-
CHECK_INPUT(v);
19
-
CHECK_INPUT(g);
20
-
int64_t num_elem = p.numel();
21
-
AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
22
-
AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
23
-
AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
24
-
AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
0 commit comments