Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
713bdbe
add random routing op
sljlp Mar 23, 2022
e5e0d3a
# This is a combination of 10 commits.
liyagit21 Sep 24, 2021
cda9c51
add assign pos op
sljlp Mar 20, 2022
7a782d3
fix upper num name
sljlp Mar 20, 2022
db110c9
add api _assign pos
sljlp Mar 20, 2022
f310a54
add ut for assign pos op
sljlp Mar 21, 2022
f3ff5f6
update date
sljlp Mar 21, 2022
8968bf1
add op about moe gate
sljlp Mar 21, 2022
40a8a50
fix for win
sljlp Mar 22, 2022
37950fa
fix bugs in test_limit_by_capacity_op
sljlp Mar 23, 2022
83a1361
update ut
sljlp Mar 23, 2022
7f704c7
update for test (timeout)
sljlp Mar 23, 2022
7db0bb6
fix ut
sljlp Mar 23, 2022
441ea29
update
sljlp Mar 23, 2022
b636217
update(fix) ut for win
sljlp Mar 23, 2022
79be77c
moe apis in incubate
sljlp Mar 23, 2022
44f1064
# This is a combination of 10 commits.
liyagit21 Sep 24, 2021
6f46b07
add assign pos op
sljlp Mar 20, 2022
a300850
fix upper num name
sljlp Mar 20, 2022
f8b1852
add api _assign pos
sljlp Mar 20, 2022
c0f2701
add ut for assign pos op
sljlp Mar 21, 2022
2e464c2
update date
sljlp Mar 21, 2022
bb266f8
fix for win
sljlp Mar 22, 2022
53e494d
update for test (timeout)
sljlp Mar 23, 2022
3dd05d1
fix ut
sljlp Mar 23, 2022
d4c0cf5
update
sljlp Mar 23, 2022
1d19ed0
fix ut for number count
sljlp Mar 23, 2022
d517d5e
Merge branch 'assign_pos_op' into limit_by_capacity
sljlp Mar 23, 2022
04f39ef
Merge branch 'limit_by_capacity' into moe_apis
sljlp Mar 23, 2022
f00621b
Merge branch 'random_routing' into moe_apis
sljlp Mar 23, 2022
00054fc
add apis and utils
sljlp Mar 23, 2022
a7d9506
add gate apis
sljlp Mar 23, 2022
cfc9fc9
add moe and grad clip apis
sljlp Mar 23, 2022
913ec54
update moe apis
sljlp Mar 23, 2022
e95b2f7
add ops for moe gate
sljlp Mar 23, 2022
6df2be4
fix
sljlp Mar 23, 2022
d079246
Merge branch 'limit_by_capacity' into moe_apis
sljlp Mar 24, 2022
2945176
update for base moe layer api
sljlp Mar 24, 2022
0e97c7c
add random routing op
sljlp Mar 23, 2022
d94fd57
fix for dygraph
sljlp Mar 25, 2022
39926a0
update with ranodm routing
sljlp Mar 25, 2022
e56d26c
Merge branch 'random_routing' into limit_by_capacity
sljlp Mar 25, 2022
0e65fe8
update
sljlp Mar 25, 2022
3bce091
fix ut for limit by capacity
sljlp Mar 25, 2022
3ac3ff2
Merge branch 'limit_by_capacity' into moe_apis
sljlp Mar 25, 2022
7d49289
update
sljlp Mar 25, 2022
02f777d
Merge branch 'limit_by_capacity' into moe_apis
sljlp Mar 25, 2022
c495374
update limit by capacity for easily to switch to single thread mode
sljlp Mar 29, 2022
c0924dc
Merge branch 'moe_apis' into moe_api2
sljlp Mar 29, 2022
18510d9
update api docs
sljlp Mar 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions paddle/fluid/operators/limit_by_capacity_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,17 @@
namespace paddle {
namespace operators {

#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

template <typename T>
__global__ void limit_by_capacity_impl(const T* expc, T* cap, T* out,
const int n_expert, const int n_worker) {
int eid = blockIdx.y;
int wid = blockIdx.x * blockDim.x + threadIdx.x;
if (wid < n_worker) {
int eid, wid;
CUDA_KERNEL_LOOP(i, (n_expert * n_worker)) {
wid = i / n_expert;
eid = i % n_expert;
auto proposal = expc[wid * n_expert + eid];
// int cap_left = atomicSub(cap + eid, proposal);
auto cap_left = paddle::platform::CudaAtomicAdd(cap + eid, proposal * (-1));
if (cap_left >= proposal) {
out[wid * n_expert + eid] = proposal;
Expand All @@ -54,12 +52,11 @@ class LimitByCapacityOpCUDAKernel : public framework::OpKernel<T> {
auto out = context.Output<Tensor>("Out");

auto n_expert = expert_count->numel() / n_worker;
// std::cout << "n_expert" << n_expert << std::endl;
const auto place = context.GetPlace();
const auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();

dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
dim3 grid_dim(256);
dim3 block_dim(1024);
auto out_data = out->mutable_data<T>(place);
const T* ec_data = expert_count->data<T>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
22 changes: 10 additions & 12 deletions python/paddle/incubate/distributed/models/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute
from paddle import fluid

__all__ = ["MoeLayer"]


def _local_scatter(inp, pos):
if pos.shape != [0]:
Expand Down Expand Up @@ -71,7 +69,7 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
'ring_id', ring_id, 'nranks', nranks)


class MOEScatter(PyLayer):
class MoEScatter(PyLayer):
r"""
Scatter input samples from [batch x sequences] to contiguous alone experts.
If `world_size` is greater than 1, the samples will first be locally
Expand Down Expand Up @@ -117,10 +115,10 @@ def backward(ctx, grad):
return grad_in, None, None, None


class MOEGather(PyLayer):
class MoEGather(PyLayer):
r"""
Gather output samples from contiguous alone experts back to [batch x
sequences]. Works symmetrically with MOEScatter.
sequences]. Works symmetrically with MoEScatter.
"""

@staticmethod
Expand Down Expand Up @@ -225,8 +223,8 @@ def prepare_forward(gate, num_expert, world_size, moe_group):
fwd_batch_size, )


class MoeLayer(nn.Layer):
"""Moe Layer
class MoELayer(nn.Layer):
"""MoE Layer
Args:
d_model: (int) model dimention
experts: (nn.LayerList) expert networks list
Expand All @@ -243,7 +241,7 @@ class MoeLayer(nn.Layer):
Examples:
.. code-block:: python
from paddle.nn import layer, LayerList
from paddle.distributed.moe import Moelayer
from paddle.distributed.moe import MoElayer
from paddle.distributed.collective import Group
from paddle.distributed import fleet

Expand Down Expand Up @@ -279,7 +277,7 @@ def forward(self, x):
exp_layer = ExpertLayer(d_model, dim_feedforward // top_k, windex=expi, num_expert=num_experts)
experts_list.append(exp_layer)

moeLayer = MoeLayer(d_model = d_model,
moeLayer = MoELayer(d_model = d_model,
experts=experts_list,
gate=gate_config,
moe_group=moe_group,
Expand All @@ -295,7 +293,7 @@ def __init__(self,
moe_group=None,
mp_group=None,
**kwargs):
super(MoeLayer, self).__init__()
super(MoELayer, self).__init__()

recompute_interval = kwargs.get("recompute_interval", 0)

Expand Down Expand Up @@ -385,7 +383,7 @@ def forward(self, inp):
temp_pos = pos
assert topk == self.top_k

x = MOEScatter.apply(inp, temp_pos, local_expert_count,
x = MoEScatter.apply(inp, temp_pos, local_expert_count,
global_expert_count, fwd_batch_size,
self.world_size, self.group)

Expand Down Expand Up @@ -416,7 +414,7 @@ def experts_fwd(x, fwd_expert_count, experts):
if len(gate.shape) == 2:
out_batch_size *= gate.shape[1]

x = MOEGather.apply(x, pos, local_expert_count, global_expert_count,
x = MoEGather.apply(x, pos, local_expert_count, global_expert_count,
out_batch_size, self.world_size, self.group)

x = x.reshape([-1, self.top_k, d_model])
Expand Down