Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def GenerateNodeCreationCodes(
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
Expand Down Expand Up @@ -734,8 +734,11 @@ def GenerateNodeCreationCodes(
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if IsVectorTensorType(atype):
tw_name = f"api_result[{pos}]"
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"

Expand Down Expand Up @@ -777,7 +780,7 @@ def GenerateNodeCreationCodes(
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result[{pos}]);"
set_retain_grad = f"egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_retain_grad_list.append(set_retain_grad)
set_out_rank_str = "\n".join(set_out_rank_list)
set_history_str = "\n".join(set_history_list)
Expand Down Expand Up @@ -900,7 +903,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
returns_list[0] = f"api_result"
else:
# Tuple api_result
returns_list[pos] = f"api_result[{pos}]"
returns_list[pos] = f"std::get<{pos}>(api_result)"

if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor"
Expand Down Expand Up @@ -1038,7 +1041,8 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"

#include "paddle/phi/api/include/sparse_api.h"
//#include "paddle/phi/api/include/sparse_api.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多了行注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#include "paddle/phi/api/backward/sparse_bw_api.h"
"""
file_contents += node_definition_str
with open(filepath, 'a') as f:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
if is_forward_only:
fwd_function_name = "paddle::experimental::" + namespace_str + fwd_api_name
else:
fwd_function_name = namespace_str + GetForwardFunctionName(fwd_api_name)
fwd_function_name = "::" + namespace_str + GetForwardFunctionName(
fwd_api_name)

python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, fwd_api_name, get_eager_tensor_str, parse_attributes_str,
Expand Down
121 changes: 87 additions & 34 deletions paddle/phi/api/lib/sparse_api_custom_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,46 @@ namespace paddle {
namespace experimental {
namespace sparse {

Tensor to_sparse_coo_impl(const Tensor& x,
Backend backend,
const int64_t sparse_dim) {
Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
if (x.layout() == phi::DataLayout::SPARSE_COO) {
return x;
}

Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是自定义实现的话,不用写得这么复杂,直接把 kernel_key 拿出来用就行

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_coo";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_coo";
}

VLOG(6) << "kernel_name: " << kernel_name;

auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});

VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "add API kernel key: [" << kernel_backend << ", " << kernel_layout
<< ", " << kernel_data_type << "]";
VLOG(6) << "to API kernel: " << kernel;

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto kernel_context = phi::KernelContext(dev_ctx);

// 3. Auto data transform
Expand All @@ -61,19 +78,21 @@ Tensor to_sparse_coo_impl(const Tensor& x,
}

// 4. InferMeta
VLOG(6) << "infer meta.";
auto indices_meta =
phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {-1}, x.layout());
phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

// 5. Prepare outputs
// create empty SparseCooTensor
VLOG(6) << "create empty SparseCooTensor.";
phi::DenseTensor non_zero_indices(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_backend)),
std::move(indices_meta));
phi::DenseTensor non_zero_elements(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_backend)),
std::move(elements_meta));
auto coo = std::make_shared<phi::SparseCooTensor>(
non_zero_indices, non_zero_elements, x.dims());
Expand All @@ -83,32 +102,50 @@ Tensor to_sparse_coo_impl(const Tensor& x,
out.set_impl(coo);

// 6. Call kernel
VLOG(6) << "call kernel ";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种内容比较少的VLOG感觉可以去掉或者把级别调大一些

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已把debug信息去掉。


kernel(&kernel_context);

return out;
}

Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {
Tensor to_sparse_csr_impl(const Tensor& x) {
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
return x;
}
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "dense_to_sparse_csr";
if (x.layout() == phi::DataLayout::SPARSE_COO) {
kernel_name = "sparse_coo_to_csr";
}

auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});

VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "add API kernel key: [" << kernel_backend << ", " << kernel_layout
<< ", " << kernel_data_type << "]";
VLOG(6) << "to API kernel: " << kernel;

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto kernel_context = phi::KernelContext(dev_ctx);

// 3. Auto data transform
Expand All @@ -122,24 +159,24 @@ Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {

// 4. InferMeta
auto crows_meta =
phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
auto cols_meta =
phi::DenseTensorMeta(phi::DataType::INT64, {-1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {-1}, x.layout());
phi::DenseTensorMeta(phi::DataType::INT64, {1}, phi::DataLayout::NCHW);
auto elements_meta = phi::DenseTensorMeta(x.dtype(), {1}, x.layout());

// 5. Prepare outputs
// create empty SparseCooTensor
phi::DenseTensor non_zero_crows(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_backend)),
std::move(crows_meta));
phi::DenseTensor non_zero_cols(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_backend)),
std::move(cols_meta));
phi::DenseTensor non_zero_elements(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_backend)),
std::move(elements_meta));
auto csr = std::make_shared<phi::SparseCsrTensor>(
non_zero_crows, non_zero_cols, non_zero_elements, x.dims());
Expand All @@ -154,28 +191,44 @@ Tensor to_sparse_csr_impl(const Tensor& x, Backend backend) {
return out;
}

Tensor to_dense_impl(const Tensor& x, Backend backend) {
Tensor to_dense_impl(const Tensor& x) {
if (x.layout() != phi::DataLayout::SPARSE_CSR &&
x.layout() != phi::DataLayout::SPARSE_COO) {
return x;
}
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(backend);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
std::string kernel_name = "sparse_coo_to_dense";
if (x.layout() == phi::DataLayout::SPARSE_CSR) {
kernel_name = "sparse_csr_to_dense";
}

auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, kernel_key);
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});

VLOG(6) << "to API kernel key: " << kernel_key;
VLOG(6) << "add API kernel key: [" << kernel_backend << ", " << kernel_layout
<< ", " << kernel_data_type << "]";
VLOG(6) << "to API kernel: " << kernel;

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto kernel_context = phi::KernelContext(dev_ctx);

// 3. Auto data transform
Expand All @@ -194,7 +247,7 @@ Tensor to_dense_impl(const Tensor& x, Backend backend) {
// create empty SparseCooTensor
auto dense_out = std::make_shared<phi::DenseTensor>(
phi::make_intrusive<paddle::experimental::SharedStorage>(
phi::TransToPhiPlace(backend)),
phi::TransToPhiPlace(kernel_backend)),
std::move(dense_meta));

kernel_context.EmplaceBackOutput(dense_out.get());
Expand Down
8 changes: 3 additions & 5 deletions paddle/phi/api/lib/sparse_api_custom_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@ namespace paddle {
namespace experimental {
namespace sparse {

Tensor to_dense_impl(const Tensor& x, Backend backend);
Tensor to_dense_impl(const Tensor& x);

Tensor to_sparse_coo_impl(const Tensor& x,
Backend backend,
const int64_t sparse_dim);
Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim);

Tensor to_sparse_csr_impl(const Tensor& x, Backend backend);
Tensor to_sparse_csr_impl(const Tensor& x);

} // namespace sparse
} // namespace experimental
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/remove.h>

#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
Expand Down Expand Up @@ -93,6 +94,7 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
const DenseTensor& x,
const int64_t sparse_dim,
SparseCooTensor* out) {
VLOG(6) << "enter DenseToSparseCooKernel.";
const T* x_data = x.data<T>();
const auto& x_dims = x.dims();
auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
Expand Down Expand Up @@ -123,6 +125,7 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
phi::DenseTensorMeta(DataType::INT32, {rows}, phi::DataLayout::NCHW);
DenseTensor temp_indexs = phi::Empty(dev_ctx, std::move(temp_indexs_meta));
int* temp_indexs_ptr = temp_indexs.mutable_data<int>(place);
VLOG(6) << "get the number of non-zero elements.";
GetNonZeroNums<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
Expand Down Expand Up @@ -171,6 +174,8 @@ void DenseToSparseCooKernel(const Context& dev_ctx,

dev_ctx.Wait(); // wait the copy

VLOG(6) << "alloc SparseCooTensor";

const auto values_dims =
phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num);
DenseTensorMeta indices_meta(DataType::INT64,
Expand All @@ -189,6 +194,7 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
T* sparse_data = values.mutable_data<T>(place);

// 3. calc indices by indexs and get values by indexs
VLOG(6) << "calc indices..";
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
GetNonZeroElementsAndIndices<<<config.block_per_grid.x,
config.thread_per_block.x,
Expand All @@ -201,7 +207,9 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
temp_indexs_ptr,
indices_data,
sparse_data);
VLOG(6) << "set member";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉这像是调试的log,内容不算规范

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

out->SetMember(indices, values, x_dims, true);
VLOG(6) << "leave DenseToSparseCoo";
}

__global__ void GetBatchSizes(const int64_t* crows,
Expand Down
Loading