@@ -41,11 +41,7 @@ struct GeluWithoutApproximateFunctor {
// Forward pass of GELU without approximation, shown as a diff: lines marked
// `-` are the removed implementation, the line marked `+` replaces them.
// Converts arg_x from T to MPType, computes in MPType, and casts the result
// back to T.
4141 inline HOSTDEVICE T operator ()(T arg_x) {
4242 // exact GELU (approximate = false)
4343 MPType x = static_cast <MPType>(arg_x);
// Removed version: out = x * half * (one + erf(x * M_SQRT1_2)), using
// explicitly cast `one` and `half` constants.
44- MPType one = static_cast <MPType>(1 );
45- MPType half = static_cast <MPType>(0.5 );
46- MPType erf_out = erf (x * static_cast <MPType>(M_SQRT1_2));
47- MPType out = x * half * (one + erf_out);
48- return static_cast <T>(out);
// New version: a single expression, x * normcdf(x).
// NOTE(review): normcdf is not declared in this view - presumably the CUDA
// math function, with normcdf(x) == 0.5 * (1 + erf(x / sqrt(2))). Confirm it
// is available for every MPType and on the host path implied by HOSTDEVICE.
44+ return static_cast <T>(x * normcdf (x));
4945 }
5046};
5147
@@ -100,12 +96,10 @@ struct GeluWithoutApproximateGradFunctor {
// Backward pass of GELU without approximation, shown as a diff: lines marked
// `-` are the removed implementation, lines marked `+` replace them.
// Converts arg_x and arg_dout from T to MPType, multiplies dout by the
// derivative term evaluated at x, and casts the product back to T.
10096 inline HOSTDEVICE T operator ()(T arg_x, T arg_dout) {
10197 MPType x = static_cast <MPType>(arg_x);
10298 MPType dout = static_cast <MPType>(arg_dout);
// Removed version:
//   ans = half * (one + erf(x * M_SQRT1_2)) + half * kAlpha * x * exp(-half * x * x)
// with kAlpha = M_2_SQRTPI * M_SQRT1_2, returning ans * dout.
103- MPType one = static_cast <MPType>(1 );
104- MPType half = static_cast <MPType>(0.5 );
105- MPType kAlpha = static_cast <MPType>(M_2_SQRTPI * M_SQRT1_2);
106- auto ans = half * (one + erf (x * static_cast <MPType>(M_SQRT1_2))) +
107- half * kAlpha * x * exp (-half * x * x);
108- return static_cast <T>(ans * dout);
// New version: the derivative term is written as cdf + x * pdf.
// kBeta folds the former `half * kAlpha` factor into one constant:
// M_2_SQRTPI * M_SQRT1_2 * 0.5, i.e. 1 / sqrt(2 * pi).
// NOTE(review): `constexpr MPType` requires MPType to be usable in constant
// expressions - confirm this holds for every instantiation.
99+ constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast <MPType>(0.5 );
// NOTE(review): normcdf is not declared in this view - presumably the CUDA
// math function; it takes the place of the removed
// half * (one + erf(x * M_SQRT1_2)) term. Confirm host availability.
100+ const MPType cdf = normcdf (x);
// Same exponent as the removed code (-0.5 * x * x), now scaled by kBeta.
101+ const MPType pdf = exp (static_cast <MPType>(-0.5 ) * x * x) * kBeta ;
// dout * (cdf + x * pdf) mirrors the removed `ans * dout`.
102+ return static_cast <T>(dout * (cdf + x * pdf));
109103 }
110104};
111105
0 commit comments