@@ -189,45 +189,29 @@ struct BroadcastDataLoader<Index, VecSize, false, kElementwise> {
   }
 };
 
-// Common broadcast data loader.
-template <int Index, int VecSize, bool IsBoundary>
-struct BroadcastDataLoader<Index, VecSize, IsBoundary, kBroadcast> {
-  template <typename Array1, typename Array2, typename Array3, typename ArgsT>
-  static __device__ __forceinline__ void Apply(const Array1 &ins,
-                                               ArgsT *args,
-                                               const Array2 &configs,
-                                               const Array3 &use_broadcast,
-                                               const int block_offset,
-                                               const int num,
-                                               const uint32_t numel) {
+template <int Index, int VecSize>
+struct BroadcastDataInit {
+  template <typename ArgsT>
+  static __device__ __forceinline__ void Apply(ArgsT *args) {
     using Type = std::tuple_element_t<Index, ArgsT>;
-    uint32_t index_bc[VecSize];
 #pragma unroll
     for (int k = 0; k < VecSize; ++k) {
-      index_bc[k] = 0;
       std::get<Index>(args[k]) = static_cast<Type>(1);
     }
+  }
+};
 
-    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
-#pragma unroll
-    for (int k = 0; k < VecSize; ++k) {
-      uint32_t idx = thread_offset + k;
-      if (IsBoundary && idx == numel) {
-        break;
-      }
-#pragma unroll
-      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
-        if (i == configs[0].rank) break;
-        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
-        idx = fast_divmoder.val[0];
-        index_bc[k] += fast_divmoder.val[1] * configs[Index].strides[i];
-      }
-    }
-
+template <int Index, int VecSize>
+struct BroadcastDataSetter {
+  template <typename Array, typename ArgsT>
+  static __device__ __forceinline__ void Apply(const Array &ins,
+                                               ArgsT *args,
+                                               uint32_t index_bc[][VecSize]) {
+    using Type = std::tuple_element_t<Index, ArgsT>;
 #pragma unroll
     for (int k = 0; k < VecSize; ++k) {
       std::get<Index>(args[k]) =
-          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[k]];
+          reinterpret_cast<const _ptr_ Type *>(ins[Index])[index_bc[Index][k]];
     }
   }
 };
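
The old `kBroadcast` specialization of `BroadcastDataLoader` mixed three jobs: defaulting `args`, computing per-input offsets, and loading. This hunk splits the first and last into `BroadcastDataInit` and `BroadcastDataSetter`, so the offset computation can move into the kernel body and be shared across all inputs (next hunk). Both functors are expanded once per input index by an `Unroller`; below is a minimal sketch of such a helper, assuming it simply recurses over `Index` — the real one lives elsewhere in Paddle's kps primitives and may differ in detail.

```cpp
#include <utility>  // std::forward

// Minimal sketch (assumption): instantiate Func<Index, VecSize>::Apply for
// every Index in [0, End), e.g. once per input tensor.
template <template <int, int> class Func, int VecSize, int End, int Begin = 0>
struct Unroller {
  template <typename... Args>
  static __device__ __forceinline__ void step(Args &&...args) {
    Func<Begin, VecSize>::Apply(std::forward<Args>(args)...);
    Unroller<Func, VecSize, End, Begin + 1>::step(std::forward<Args>(args)...);
  }
};

// Recursion terminates when Begin reaches End.
template <template <int, int> class Func, int VecSize, int End>
struct Unroller<Func, VecSize, End, End> {
  template <typename... Args>
  static __device__ __forceinline__ void step(Args &&...args) {}
};
```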
@@ -285,8 +269,30 @@ __device__ void VectorizedBroadcastKernelImpl(
   __simd__ ArgsT args[VecSize];
   __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
 
-  BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
-      ins, args, configs, use_broadcast, block_offset, num, numel);
+  if (LoadType == kBroadcast) {
+    uint32_t index_bc[Arity][VecSize] = {0};
+    Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
+    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
+#pragma unroll
+    for (int k = 0; k < VecSize; ++k) {
+      uint32_t idx = thread_offset + k;
+      if (IsBoundary && idx == numel) break;
+#pragma unroll
+      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
+        if (i == configs[0].rank) break;
+        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
+        idx = fast_divmoder.val[0];
+#pragma unroll
+        for (int j = 0; j < Arity; ++j) {
+          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
+        }
+      }
+    }
+    Unroller<BroadcastDataSetter, VecSize, Arity>::step(ins, args, index_bc);
+  } else {
+    BcUnroller<BroadcastDataLoader, IsBoundary, LoadType, VecSize, Arity>::step(
+        ins, args, configs, use_broadcast, block_offset, num, numel);
+  }
 
   SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                      VecSize,
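
This is the payoff of the refactor: the old per-input loader ran the rank-deep `Divmod` chain once per input, i.e. roughly `Arity × VecSize × rank` divisions, while the fused loop above decomposes each output index once and reuses every per-dimension remainder for all `Arity` inputs, cutting that to `VecSize × rank`. A host-side sketch of the same arithmetic, with plain `/` and `%` standing in for `FastDivMod` (names are illustrative, not Paddle's API):

```cpp
#include <cstdint>
#include <vector>

// Offsets of output element `idx` inside each broadcast input, given the
// output's per-dimension divisors and each input's (possibly zero) strides.
std::vector<uint32_t> BroadcastOffsets(
    uint32_t idx,
    const std::vector<uint32_t> &out_divisors,
    const std::vector<std::vector<uint32_t>> &strides) {
  std::vector<uint32_t> offsets(strides.size(), 0);
  for (size_t i = 0; i < out_divisors.size(); ++i) {
    uint32_t r = idx % out_divisors[i];  // coordinate along dimension i
    idx /= out_divisors[i];              // one divide per dimension...
    for (size_t j = 0; j < strides.size(); ++j) {
      offsets[j] += r * strides[j][i];   // ...reused by every input
    }
  }
  return offsets;
}
```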
@@ -783,11 +789,7 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<OutT,
 };
 #endif
 
-template <ElementwiseType ET,
-          typename OutT,
-          typename Functor,
-          int kArity,
-          int NumOuts = 1>
+template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
 void BroadcastKernelForDifferentVecSize(
     const KPDevice &ctx,
     const std::vector<const DenseTensor *> &ins,
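
The `ElementwiseType ET` parameter is dropped because it duplicated information the functor already carries: `BroadcastKernel` below derives `kArity` from `phi::funcs::FunctionTraits<Functor>`. A sketch of how such a trait can read the arity off `operator()` (illustrative; not Paddle's exact definition):

```cpp
// Deduce arity from a functor's call operator.
template <typename T>
struct FunctionTraits : FunctionTraits<decltype(&T::operator())> {};

template <typename ClassT, typename ReturnT, typename... Args>
struct FunctionTraits<ReturnT (ClassT::*)(Args...) const> {
  static constexpr int arity = sizeof...(Args);
};

struct AddFunctor {
  float operator()(float a, float b) const { return a + b; }
};
static_assert(FunctionTraits<AddFunctor>::arity == 2, "binary functor");
```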
@@ -922,16 +924,12 @@ void BroadcastKernelForDifferentVecSize(
   }
 }
 
-template <ElementwiseType ET,
-          typename InT,
-          typename OutT,
-          typename Functor,
-          int NumOuts = 1>
+template <typename OutT, typename Functor, int NumOuts = 1>
 void BroadcastKernel(const KPDevice &ctx,
                      const std::vector<const DenseTensor *> &ins,
                      std::vector<DenseTensor *> *outs,
-                     int axis,
-                     Functor func) {
+                     Functor func,
+                     int axis = -1) {
   // When there are multiple inputs, the output's rank should equal the
   // maximum rank of all inputs.
   using Traits = phi::funcs::FunctionTraits<Functor>;
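
Besides dropping `ET` and `InT` (both recoverable from `Functor` and the tensors), this moves `axis` behind `func` with a `-1` default, meaning "derive it from the rank difference" via the `axis = axis == -1 ? max_rank - min_rank : axis;` line further down. A hedged call-site sketch (`AddFunctor` and the tensors are illustrative, not from this file):

```cpp
std::vector<const DenseTensor *> ins = {&x, &y};
std::vector<DenseTensor *> outs = {&z};
// Old shape: BroadcastKernel<kBinary, T, OutT, AddFunctor>(ctx, ins, &outs, -1, AddFunctor());
// New shape: the ET/InT arguments are gone and axis may simply be omitted.
BroadcastKernel<float, AddFunctor>(ctx, ins, &outs, AddFunctor());
```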
@@ -968,23 +966,22 @@ void BroadcastKernel(const KPDevice &ctx,
     max_rank = std::max(max_rank, (*outs)[0]->dims().size());
   }
   axis = axis == -1 ? max_rank - min_rank : axis;
-  BroadcastKernelForDifferentVecSize<ET, OutT, Functor, kArity, NumOuts>(
+  BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
       ctx, ins, outs, axis, func);
 }
 
 template <typename Functor, typename T, typename OutType = T>
 void ElementwiseCompute(const GPUContext &dev_ctx,
                         const DenseTensor &x,
                         const DenseTensor &y,
-                        int axis,
                         Functor func,
-                        DenseTensor *z) {
+                        DenseTensor *z,
+                        int axis = -1) {
   std::vector<const DenseTensor *> ins = {&x, &y};
   std::vector<DenseTensor *> outs = {z};
   dev_ctx.template Alloc<OutType>(z);
 
-  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
-      dev_ctx, ins, &outs, axis, func);
+  BroadcastKernel<OutType, Functor, 1>(dev_ctx, ins, &outs, func, axis);
 }
 
 template <typename DeviceContext,
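
The two hunks below update `DefaultElementwiseOperator` mechanically for the new parameter order. Callers that were passing `axis = -1` explicitly can now drop the argument entirely; a hedged before/after (`AddFunctor` and the surrounding variables are assumptions):

```cpp
// Before: axis sat between the inputs and the functor.
//   funcs::ElementwiseCompute<AddFunctor, float>(dev_ctx, x, y, -1, AddFunctor(), z);
// After: axis is a trailing default, so the common case shrinks to:
funcs::ElementwiseCompute<AddFunctor, float>(dev_ctx, x, y, AddFunctor(), z);
```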
@@ -999,7 +996,7 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
   auto x_dims = x.dims();
   auto y_dims = y.dims();
   dev_ctx.template Alloc<T>(z);
-  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
+  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
 }
 
 #else
@@ -1017,10 +1014,10 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
   auto y_dims = y.dims();
   dev_ctx.template Alloc<T>(z);
   if (x_dims.size() >= y_dims.size()) {
-    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
+    funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, Functor(), z, axis);
   } else {
     funcs::ElementwiseCompute<InverseFunctor, T>(
-        dev_ctx, x, y, axis, InverseFunctor(), z);
+        dev_ctx, x, y, InverseFunctor(), z, axis);
   }
 }
 #endif