Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
'c_reducescatter',
'c_softmax_with_cross_entropy',
'decayed_adagrad',
'distributed_lookup_table',
'dpsgd',
'embedding_grad_sparse',
'ftrl',
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,15 @@
data_type : fpn_rois
optional : rois_num, multi_level_rois_num

- op : distributed_lookup_table
args : (Tensor[] ids, Tensor w, int table_id = 0, bool is_distributed = false, str lookup_table_version = "lookup_table", int64_t padding_idx = -1, int dtype = 5, bool is_test = false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
args : (Tensor[] ids, Tensor w, int table_id = 0, bool is_distributed = false, str lookup_table_version = "lookup_table", int64_t padding_idx = -1, int dtype = 5, bool is_test = false)
args : (Tensor[] ids, Tensor w, int table_id = 0, bool is_distributed = false, str lookup_table_version = "lookup_table", int64_t padding_idx = -1, DataType dtype = DataType::FLOAT32, bool is_test = false)

output : Tensor[](outputs){ids.size()}
infer_meta :
func : DistributeLookupTableInferMeta
kernel :
func : distributed_lookup_table
data_type : ids
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data_type : ids
data_type : dtype

这里需要和DistributedLookupTableOp::GetExpectedKernelType 保持一致

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

谢谢,已修改~


- op : divide
args : (Tensor x, Tensor y)
output : Tensor(out)
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3477,6 +3477,12 @@
multi_level_rois_num: MultiLevelRoIsNum
restore_index: RestoreIndex

- op: distributed_lookup_table
inputs:
{ids: Ids, w: W}
outputs:
outputs: Outputs

- op: dpsgd
inputs:
{param: Param,grad: Grad,learning_rate: LearningRate}
Expand Down
43 changes: 43 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,49 @@ void DistInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void DistributeLookupTableInferMeta(
const std::vector<const phi::MetaTensor*>& ids,
const MetaTensor& w,
int table_id,
bool is_distributed,
const std::string& lookup_table_version,
int64_t padding_idx,
int dtype,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int dtype,
DataType dtype,

bool is_test,
std::vector<MetaTensor*> outputs) {
auto table_dims = w.dims();

PADDLE_ENFORCE_EQ(w.dims().size(),
2,
errors::InvalidArgument(
"Only 2 dimensions of the 'Embedding' is supported."));

for (auto& id : ids) {
PADDLE_ENFORCE_EQ(id->dims().size(),
2,
errors::InvalidArgument(
"The dimension of the 'Ids' tensor must be 2."));
}

// for fluid.embedding
for (size_t i = 0; i < ids.size(); ++i) {
MetaTensor* output = outputs[i];
auto id_dims = ids[i]->dims();
if (lookup_table_version == "lookup_table") {
output->set_dims(common::make_ddim({id_dims[0], table_dims[1]}));
output->share_lod(*ids[i]);
output->set_dtype(w.dtype());
} else if (lookup_table_version == "lookup_table_v2") {
output->set_dims(
common::make_ddim({static_cast<int64_t>(id_dims[0]),
static_cast<int64_t>(id_dims[1]),
static_cast<int64_t>(table_dims[1])}));
output->share_lod(*ids[i]);
output->set_dtype(w.dtype());
}
}
}

void DistributeFpnProposalsInferMeta(
const MetaTensor& fpn_rois,
const MetaTensor& rois_num,
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,17 @@ void DistInferMeta(const MetaTensor& x,
float p,
MetaTensor* out);

void DistributeLookupTableInferMeta(
const std::vector<const phi::MetaTensor*>& ids,
const MetaTensor& w,
int table_id,
bool is_distributed,
const std::string& lookup_table_version,
int64_t padding_idx,
int dtype,
bool is_test,
std::vector<MetaTensor*> outputs);

void DistributeFpnProposalsInferMeta(
const MetaTensor& fpn_rois,
const MetaTensor& rois_num,
Expand Down
2 changes: 2 additions & 0 deletions test/ir/pir/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ string(REPLACE ".py" "" TEST_INTERP_CASES "${TEST_INTERP_CASES}")

set(DISTRIBUTED_OP_TRANSLATOR_TEST test_c_reduce_min_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_c_allreduce_prod_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
test_distributed_lookup_table_translate)

if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_INTERP_CASES ${DISTRIBUTED_OP_TRANSLATOR_TEST})
Expand Down
52 changes: 52 additions & 0 deletions test/ir/pir/translator/test_distributed_lookup_table_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2023 -> 2024

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base.layer_helper import LayerHelper


class TestDistributedLookupTableOpTranslator(
test_op_translator.TestOpTranslator
):
def append_op(self):
self.op_type = "distributed_lookup_table"
ids = paddle.ones(shape=(1, 1), dtype='float32')
w = paddle.ones(shape=(1, 1), dtype='float32')
out = paddle.ones(shape=(1, 1), dtype='float32')
attrs = {
'table_id': 0,
'is_distributed': False,
'lookup_table_version': 'lookup_table',
'padding_idx': -1,
'dtype': 5,
'is_test': False,
}
helper = LayerHelper(self.op_type)
helper.append_op(
type=self.op_type,
inputs={"Ids": [ids], "W": w},
outputs={"Outputs": [out]},
attrs=attrs,
)

def test_translator(self):
self.check()


if __name__ == "__main__":
unittest.main()