Merged · Changes from 20 commits
108 changes: 108 additions & 0 deletions cmake/external/flashattn.cmake
@@ -0,0 +1,108 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

add_definitions(-DPADDLE_WITH_FLASHATTN)

set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG f0edf243a813a65d05c75fcb331b2a95faf96bbc)

set(FLASHATTN_INCLUDE_DIR
    "${FLASHATTN_INSTALL_DIR}/include"
    CACHE PATH "flash-attn Directory" FORCE)
set(FLASHATTN_LIB_DIR
    "${FLASHATTN_INSTALL_DIR}/lib"
    CACHE PATH "flash-attn Library Directory" FORCE)

if(WIN32)
  set(FLASHATTN_LIBRARIES
      "${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn Library" FORCE)
else()
  set(FLASHATTN_LIBRARIES
      "${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
   OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
   OR WIN32)
  set(USE_OMP OFF)
else()
  set(USE_OMP ON)
endif()

if(WIN32)
  set(FLASHATTN_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_C_FLAGS_DEBUG
      $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_C_FLAGS_RELEASE
      $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_CXX_FLAGS_RELEASE
      $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_CXX_FLAGS_DEBUG
      $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
  set(FLASHATTN_C_FLAGS ${CMAKE_C_FLAGS})
  set(FLASHATTN_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
  set(FLASHATTN_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
  set(FLASHATTN_CXX_FLAGS ${CMAKE_CXX_FLAGS})
  set(FLASHATTN_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
  set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

ExternalProject_Add(
  extern_flashattn
  ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
  GIT_REPOSITORY ${FLASHATTN_REPOSITORY}
  GIT_TAG ${FLASHATTN_TAG}
  PREFIX ${FLASHATTN_PREFIX_DIR}
  SOURCE_SUBDIR ${FLASHATTN_SOURCE_SUBDIR}
  UPDATE_COMMAND ""
  PATCH_COMMAND ""
  #BUILD_ALWAYS 1
  CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
             -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
             -DCMAKE_C_FLAGS=${FLASHATTN_C_FLAGS}
             -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_C_FLAGS_DEBUG}
             -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_C_FLAGS_RELEASE}
             -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS}
             -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE}
             -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG}
             -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR}
             -DWITH_GPU=${WITH_GPU}
             -DWITH_ROCM=${WITH_ROCM}
             -DWITH_OMP=${USE_OMP}
             -DBUILD_SHARED=ON
             -DCMAKE_POSITION_INDEPENDENT_CODE=ON
             -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
             ${EXTERNAL_OPTIONAL_ARGS}
  CMAKE_CACHE_ARGS
    -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
    -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
    -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
  BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})

message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
include_directories(${FLASHATTN_INCLUDE_DIR})

add_library(flashattn INTERFACE)
#set_property(TARGET flashattn PROPERTY IMPORTED_LOCATION ${FLASHATTN_LIBRARIES})
add_dependencies(flashattn extern_flashattn)
3 changes: 3 additions & 0 deletions cmake/third_party.cmake
@@ -531,6 +531,9 @@ if(WITH_GPU
    include(external/cutlass) # download, build, install cutlass
    list(APPEND third_party_deps extern_cutlass)
    set(WITH_CUTLASS ON)
    include(external/flashattn)
    list(APPEND third_party_deps extern_flashattn)
    set(WITH_FLASHATTN ON)
  endif()
endif()

10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -508,6 +508,16 @@
    func : fill_diagonal_tensor_grad
  inplace : (out_grad -> x_grad)

- backward_op : flash_attn_grad
  forward : flash_attn (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false) -> Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
  args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
  output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
  infer_meta :
    func : FlashAttnGradInferMeta
    param : [q, k, v]
  kernel :
    func : flash_attn_grad

- backward_op : flip_grad
  forward : flip (Tensor x, int[] axis) -> Tensor(out)
  args : (Tensor out_grad, int[] axis)
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -482,6 +482,16 @@
  inplace : (x -> out)
  backward : fill_diagonal_tensor_grad

- op : flash_attn
  args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false)
  output : Tensor(out), Tensor(softmax_lse), Tensor(softmax), Tensor(seed_offset)
  infer_meta :
    func : FlashAttnInferMeta
    param : [q, k, v]
  kernel :
    func : flash_attn
  backward : flash_attn_grad

- op : flip
  args : (Tensor x, int[] axis)
  output : Tensor (out)
7 changes: 7 additions & 0 deletions paddle/phi/backends/dynload/CMakeLists.txt
@@ -92,6 +92,13 @@ if(WITH_MKLML)
    DEPS phi_dynamic_loader mklml)
endif()

if(WITH_FLASHATTN)
  cc_library(
    phi_dynload_flashattn
    SRCS flashattn.cc
    DEPS phi_dynamic_loader flashattn)
endif()

cc_library(
  phi_dynload_lapack
  SRCS lapack.cc
14 changes: 14 additions & 0 deletions paddle/phi/backends/dynload/dynamic_loader.cc
@@ -484,6 +484,20 @@ void* GetWarpRNNTDsoHandle() {
#endif
}

void* GetFlashAttnDsoHandle() {
  std::string flashattn_dir = "";
  if (!s_py_site_pkg_path.path.empty()) {
    flashattn_dir = s_py_site_pkg_path.path;
  }
#if defined(__APPLE__) || defined(__OSX__)
  return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.dylib");
#elif defined(_WIN32)
  return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn.dll");
#else
  return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn.so");
#endif
}

void* GetNCCLDsoHandle() {
#ifdef PADDLE_WITH_HIP
  std::string warning_msg(
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/dynamic_loader.h
@@ -36,6 +36,7 @@ void* GetNVRTCDsoHandle();
void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetFlashAttnDsoHandle();
void* GetNCCLDsoHandle();
void* GetHCCLDsoHandle();
void* GetTensorRtDsoHandle();
25 changes: 25 additions & 0 deletions paddle/phi/backends/dynload/flashattn.cc
@@ -0,0 +1,25 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/backends/dynload/flashattn.h"

namespace phi {
namespace dynload {

std::once_flag flashattn_dso_flag;
void* flashattn_dso_handle = nullptr;

#define DEFINE_WRAP(__name) DynLoad__##__name __name

FLASHATTN_ROUTINE_EACH(DEFINE_WRAP);

} // namespace dynload
} // namespace phi
56 changes: 56 additions & 0 deletions paddle/phi/backends/dynload/flashattn.h
@@ -0,0 +1,56 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <mutex> // NOLINT

#include "flashattn/include/flash_attn.h"
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"

namespace phi {
namespace dynload {

extern std::once_flag flashattn_dso_flag;
extern void* flashattn_dso_handle;

#define DYNAMIC_LOAD_FLASHATTN_WRAP(__name)                            \
  struct DynLoad__##__name {                                           \
    template <typename... Args>                                        \
    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {   \
      using flashattnFunc = decltype(&::__name);                       \
      std::call_once(flashattn_dso_flag, []() {                        \
        flashattn_dso_handle = phi::dynload::GetFlashAttnDsoHandle();  \
      });                                                              \
      static void* p_##__name = dlsym(flashattn_dso_handle, #__name);  \
      return reinterpret_cast<flashattnFunc>(p_##__name)(args...);     \
    }                                                                  \
  };                                                                   \
  extern DynLoad__##__name __name

#define DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP(__name) \
  DYNAMIC_LOAD_FLASHATTN_WRAP(__name)

#define FLASHATTN_ROUTINE_EACH(__macro) \
  __macro(flash_attn_fwd);              \
  __macro(flash_attn_bwd);              \
  __macro(flash_attn_error);

FLASHATTN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_WRAP);

#undef DYNAMIC_LOAD_FLASHATTN_WRAP

} // namespace dynload
} // namespace phi
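
Note on the pattern above: each DYNAMIC_LOAD_FLASHATTN_WRAP expansion yields a functor that opens libflashattn once (via std::call_once) and caches the dlsym result in a function-local static, so only the first call pays the lookup cost. Below is a minimal, self-contained sketch of that mechanism; the symbol name my_routine and its int(int) signature are hypothetical stand-ins, not the real flash_attn_fwd/flash_attn_bwd interfaces declared in flashattn/include/flash_attn.h.

#include <dlfcn.h>  // dlopen, dlsym

#include <mutex>

namespace sketch {

std::once_flag dso_flag;
void* dso_handle = nullptr;

// Assumed signature, for illustration only.
using my_routine_t = int (*)(int);

int my_routine(int x) {
  // Open the shared library exactly once, on first use.
  std::call_once(dso_flag,
                 [] { dso_handle = dlopen("libflashattn.so", RTLD_LAZY); });
  // Resolve the symbol once and cache it, mirroring p_##__name above.
  static auto fn =
      reinterpret_cast<my_routine_t>(dlsym(dso_handle, "my_routine"));
  return fn(x);
}

}  // namespace sketch

Every later call through the functor is then an ordinary indirect call; no locking or symbol lookup is repeated.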
29 changes: 29 additions & 0 deletions paddle/phi/infermeta/ternary.cc
@@ -255,6 +255,35 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
  output_box->set_dtype(target_box.dtype());
}

void FlashAttnInferMeta(const MetaTensor& q,
                        const MetaTensor& k,
                        const MetaTensor& v,
                        MetaTensor* out,
                        MetaTensor* softmax_lse,
                        MetaTensor* softmax,
                        MetaTensor* seed_offset) {
  out->set_dims(q.dims());
  out->set_dtype(q.dtype());
  out->set_layout(q.layout());
}

void FlashAttnGradInferMeta(const MetaTensor& q,
                            const MetaTensor& k,
                            const MetaTensor& v,
                            MetaTensor* dq,
                            MetaTensor* dk,
                            MetaTensor* dv) {
  if (dq) {
    dq->share_meta(q);
  }
  if (dk && k) {
    dk->share_meta(k);
  }
  if (dv && v) {
    dv->share_meta(v);
  }
}

void ArangeInferMeta(const MetaTensor& start,
                     const MetaTensor& end,
                     const MetaTensor& step,
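
A side note on FlashAttnGradInferMeta above: share_meta propagates the source tensor's metadata onto the gradient, so dq/dk/dv end up with the same dims, dtype and layout as q/k/v, just as out mirrors q in the forward infer-meta. A rough sketch of that equivalence follows, with a toy Meta type standing in for phi's MetaTensor (an illustrative assumption, not the real API):

#include <array>
#include <cassert>
#include <cstdint>

// Toy stand-in for MetaTensor; the real class carries more state (e.g. lod).
struct Meta {
  std::array<int64_t, 4> dims{};
  int dtype = 0;   // placeholder for phi::DataType
  int layout = 0;  // placeholder for phi::DataLayout
  void share_meta(const Meta& src) { *this = src; }
};

int main() {
  Meta q{{2, 128, 8, 64}, /*dtype=*/1, /*layout=*/0};
  Meta dq;
  dq.share_meta(q);  // dq now mirrors q, as the grad infer-meta arranges
  assert(dq.dims == q.dims && dq.dtype == q.dtype && dq.layout == q.layout);
  return 0;
}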
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/ternary.h
@@ -63,6 +63,21 @@ void BoxCoderInferMeta(const MetaTensor& prior_box,
                       MetaTensor* output_box,
                       MetaConfig config = MetaConfig());

void FlashAttnInferMeta(const MetaTensor& q,
                        const MetaTensor& k,
                        const MetaTensor& v,
                        MetaTensor* out,
                        MetaTensor* softmax_lse,
                        MetaTensor* softmax,
                        MetaTensor* seed_offset);

void FlashAttnGradInferMeta(const MetaTensor& q,
                            const MetaTensor& k,
                            const MetaTensor& v,
                            MetaTensor* dq,
                            MetaTensor* dk,
                            MetaTensor* dv);

void InstanceNormInferMeta(const MetaTensor& x,
                           const MetaTensor& scale,
                           const MetaTensor& bias,
4 changes: 4 additions & 0 deletions paddle/phi/kernels/CMakeLists.txt
@@ -79,6 +79,10 @@ set(COMMON_KERNEL_DEPS
    utf8proc
    gather_scatter_functor)

if(WITH_FLASHATTN)
  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_dynload_flashattn)
endif()

set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group)
if(WITH_NCCL OR WITH_RCCL)
  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl)
7 changes: 7 additions & 0 deletions paddle/phi/kernels/arange_kernel.h
@@ -25,4 +25,11 @@ void ArangeKernel(const Context& dev_ctx,
                  const DenseTensor& step,
                  DenseTensor* out);

template <typename T, typename Context>
void ArangeNullaryKernel(const Context& dev_ctx,
                         const T start,
                         const T end,
                         const T step,
                         DenseTensor* out);

} // namespace phi
37 changes: 37 additions & 0 deletions paddle/phi/kernels/flash_attn_grad_kernel.h
@@ -0,0 +1,37 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void FlashAttnGradKernel(const Context& ctx,
                         const DenseTensor& q,
                         const DenseTensor& k,
                         const DenseTensor& v,
                         const DenseTensor& out,
                         const DenseTensor& softmax_lse,
                         const DenseTensor& seed_offset,
                         const DenseTensor& dout,
                         float dropout,
                         bool causal,
                         DenseTensor* dq,
                         DenseTensor* dk,
                         DenseTensor* dv);

} // namespace phi