Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ paddle/fluid/API_DEV.spec
paddle/fluid/API_PR.spec
paddle/fluid/op_use_default_grad_maker_DEV.spec
paddle/fluid/op_use_default_grad_maker_PR.spec
paddle/pten/api/*/api*

*.DS_Store
*.vs
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/api/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
add_subdirectory(lib)

cc_library(pten_api SRCS all.cc DEPS linalg_api math_api creation_api manipulation_api utils_api)
cc_library(pten_api SRCS all.cc DEPS pten_function_api utils_api)
5 changes: 1 addition & 4 deletions paddle/pten/api/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@ limitations under the License. */
#endif

// new pten apis
#include "paddle/pten/api/include/creation.h"
#include "paddle/pten/api/include/linalg.h"
#include "paddle/pten/api/include/manipulation.h"
#include "paddle/pten/api/include/math.h"
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/api/include/utils.h"

Expand Down
49 changes: 0 additions & 49 deletions paddle/pten/api/include/creation.h

This file was deleted.

30 changes: 0 additions & 30 deletions paddle/pten/api/include/linalg.h

This file was deleted.

28 changes: 0 additions & 28 deletions paddle/pten/api/include/manipulation.h

This file was deleted.

48 changes: 0 additions & 48 deletions paddle/pten/api/include/math.h

This file was deleted.

25 changes: 21 additions & 4 deletions paddle/pten/api/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,25 @@ cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS pten_tensor device_conte

cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)

cc_library(math_api SRCS math.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(linalg_api SRCS linalg.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(creation_api SRCS creation.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(manipulation_api SRCS manipulation.cc DEPS pten_tensor pten kernel_dispatch)
set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)

set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/include/api.h)
set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/pten/api/lib/api.cc)
set(api_header_file_tmp ${api_header_file}.tmp)
set(api_source_file_tmp ${api_source_file}.tmp)

add_custom_command(
OUTPUT ${api_header_file} ${api_source_file}
COMMAND python ${api_gen_file}
--api_yaml_path ${api_yaml_file}
--api_header_path ${api_header_file_tmp}
--api_source_path ${api_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file}
COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
DEPENDS ${api_yaml_file}
VERBATIM)

cc_library(utils_api SRCS utils.cc DEPS pten_tensor pten kernel_dispatch)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch)
135 changes: 0 additions & 135 deletions paddle/pten/api/lib/creation.cc

This file was deleted.

41 changes: 41 additions & 0 deletions paddle/pten/api/lib/kernel_dispatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,46 @@ paddle::platform::DeviceContext* GetDeviceContextByBackend(
return pool.Get(pten::TransToFluidPlace(backend));
}

DataType ParseDataType(DataType dtype) { return dtype; }
DataType ParseDataType(const Tensor& tensor) { return tensor.type(); }
DataType ParseDataType(const std::vector<Tensor>& tensors) {
if (tensors.empty()) {
return DataType::UNDEFINED;
}
DataType dtype = tensors[0].type();
auto n = tensors.size();
for (size_t i = 1; i < n; ++i) {
if (tensors[i].type() != dtype) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The data_type of input tensor in list isn't consistent, "
"the first tensor is %s, but %dth tensor is %s.",
dtype,
i,
tensors[i].type()));
}
}
return dtype;
}

DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) {
return dtype != DataType::UNDEFINED ? dtype : ParseDataType(tensor);
}

Backend ParseBackend(Backend backend) { return backend; }
Backend ParseBackend(const Tensor& tensor) {
return pten::TransToPtenBackend(tensor.inner_place());
}

Backend ParseBackendWithInputOrder(Backend backend, const Tensor& tensor) {
return backend != Backend::UNDEFINED ? backend : ParseBackend(tensor);
}

DataLayout ParseLayout(DataLayout layout) { return layout; }
DataLayout ParseLayout(const Tensor& tensor) { return tensor.layout(); }

DataLayout ParseLayoutWithInputOrder(DataLayout layout, const Tensor& tensor) {
return layout != DataLayout::UNDEFINED ? layout : ParseLayout(tensor);
}

} // namespace experimental
} // namespace paddle
Loading