Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion cmake/external/flashattn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,26 @@ else()
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
if(WITH_FLASHATTN_V3)
set(FLASHATTN_V3_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/bin/libflashattnv3${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
endif()
else()
set(FLASHATTN_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
if(WITH_FLASHATTN_V3)
set(FLASHATTN_V3_LIBRARIES
"${FLASHATTN_INSTALL_DIR}/lib/libflashattnv3${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn Library" FORCE)
endif()
endif()

set(BUILD_BYPRODUCTS_LIST ${FLASHATTN_LIBRARIES})
if(WITH_FLASHATTN_V3)
add_definitions(-DPADDLE_WITH_FLASHATTN_V3)
list(APPEND BUILD_BYPRODUCTS_LIST ${FLASHATTN_V3_LIBRARIES})
endif()

if(NOT DEFINED FA_JOB_POOLS_COMPILE)
Expand Down Expand Up @@ -179,16 +195,20 @@ else()
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DCMAKE_JOB_POOLS:STRING=compile=${FA_JOB_POOLS_COMPILE}
-DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
-DWITH_FLASHATTN_V3=${WITH_FLASHATTN_V3}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_LIBRARIES})
BUILD_BYPRODUCTS ${BUILD_BYPRODUCTS_LIST})

endif()

message(STATUS "flash-attn library: ${FLASHATTN_LIBRARIES}")
if(WITH_FLASHATTN_V3)
message(STATUS "flash-attn-v3 library: ${FLASHATTN_V3_LIBRARIES}")
endif()
get_filename_component(FLASHATTN_LIBRARY_PATH ${FLASHATTN_LIBRARIES} DIRECTORY)
include_directories(${FLASHATTN_INCLUDE_DIR})

Expand Down
8 changes: 8 additions & 0 deletions cmake/inference_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ function(copy_part_of_third_party TARGET DST)
DSTS ${dst_dir} ${dst_dir}/lib)
endif()

# FlashAttention-3 ships as a separate shared library (libflashattnv3); stage
# it into the same third_party/install/flashattn tree as the v2 library so the
# inference package bundles both. FLASHATTN_INCLUDE_DIR is listed again here
# so this block also works when only v3 staging runs.
if(WITH_FLASHATTN_V3)
set(dst_dir "${DST}/third_party/install/flashattn")
copy(
${TARGET}
SRCS ${FLASHATTN_INCLUDE_DIR} ${FLASHATTN_V3_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib)
endif()

if(NOT PROTOBUF_FOUND OR WIN32)
set(dst_dir "${DST}/third_party/install/protobuf")
copy(
Expand Down
17 changes: 16 additions & 1 deletion cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,22 @@ if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.3)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 90)
set(WITH_FLASHATTN_V3 ON)
break()
endif()
endforeach()
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
list(APPEND third_party_deps extern_flashattn)
set(WITH_FLASHATTN ON)
break()
endif()
endforeach()
elseif(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.4)
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
Expand Down
18 changes: 18 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2023,3 +2023,21 @@ PHI_DEFINE_EXPORTED_int64(multi_block_attention_min_partition_size,
PHI_DEFINE_EXPORTED_bool(save_cf_stack_op,
false,
"Save cf stack op for higher-order derivatives.");

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
/**
 * FlashAttention related FLAG
 * Name: FLAGS_flash_attn_version
 * Value Range: int32, default=2
 * Example: FLAGS_flash_attn_version=3 selects FlashAttention-3.
 * Note: Specify the version of FlashAttention to use, options are 2 or 3.
 * Version 2 requires Ampere architecture or higher,
 * while version 3 requires Hopper architecture.
 */
PHI_DEFINE_EXPORTED_int32(
    flash_attn_version,
    2,
    "Specify the version of FlashAttention to use, options are 2 or 3. "
    "Version 2 requires Ampere architecture or higher, "
    "while version 3 requires Hopper architecture.");
#endif
4 changes: 4 additions & 0 deletions paddle/phi/backends/dynload/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ if(WITH_FLASHATTN)
list(APPEND DYNLOAD_COMMON_SRCS flashattn.cc)
endif()

# Compile the FlashAttention-3 dynamic-load stubs only when the build enables
# v3 (WITH_FLASHATTN_V3 is set by cmake/third_party.cmake).
if(WITH_FLASHATTN_V3)
list(APPEND DYNLOAD_COMMON_SRCS flashattnv3.cc)
endif()

if(WITH_PSCORE)
list(APPEND DYNLOAD_COMMON_SRCS afs_api.cc)
endif()
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/backends/dynload/dynamic_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,20 @@ void* GetFlashAttnDsoHandle() {
#endif
}

// Locates and opens the FlashAttention-3 shared library, preferring the
// registered Python site-packages directory (when one was set) as the
// search root. The library file name is chosen per platform.
void* GetFlashAttnV3DsoHandle() {
  std::string search_dir;
  if (!s_py_site_pkg_path.path.empty()) {
    search_dir = s_py_site_pkg_path.path;
  }
#if defined(__APPLE__) || defined(__OSX__)
  const char* lib_name = "libflashattnv3.dylib";
#elif defined(_WIN32)
  const char* lib_name = "flashattnv3.dll";
#else
  const char* lib_name = "libflashattnv3.so";
#endif
  return GetDsoHandleFromSearchPath(search_dir, lib_name);
}

void* GetAfsApiDsoHandle() {
std::string afsapi_dir = "";
if (!s_py_site_pkg_path.path.empty()) {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/dynamic_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetFlashAttnDsoHandle();
void* GetFlashAttnV3DsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle();
Expand Down
28 changes: 28 additions & 0 deletions paddle/phi/backends/dynload/flashattnv3.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Out-of-line definitions for the FlashAttention-3 dynamic-load wrappers
// declared in flashattnv3.h.

#include "paddle/phi/backends/dynload/flashattnv3.h"

namespace phi {
namespace dynload {

// Lazy-load state shared by every wrapper: the library handle is resolved
// at most once, guarded by the once_flag (see the header's call_once).
std::once_flag flashattnv3_dso_flag;
void* flashattnv3_dso_handle = nullptr;

// Instantiates the DynLoad__<name> functor object that the header declared
// with "extern DynLoad__<name> <name>".
#define DEFINE_WRAP(__name) DynLoad__##__name __name

FLASHATTN_V3_ROUTINE_EACH(DEFINE_WRAP);

} // namespace dynload
} // namespace phi
57 changes: 57 additions & 0 deletions paddle/phi/backends/dynload/flashattnv3.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <mutex>  // NOLINT

#include "flashattn/include/flashv3_api.h"
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/common/port.h"

namespace phi {
namespace dynload {

// Lazy-load state defined in flashattnv3.cc: the shared-library handle is
// resolved at most once, guarded by the once_flag.
extern std::once_flag flashattnv3_dso_flag;
extern void* flashattnv3_dso_handle;

// Declares a callable wrapper DynLoad__<name> whose operator() lazily opens
// libflashattnv3 (via GetFlashAttnV3DsoHandle) and forwards the call through
// dlsym(<name>), preserving the wrapped function's signature.
// NOTE(review): dlsym() may return nullptr for a missing symbol, which would
// crash at the call site without a diagnostic — confirm whether an explicit
// check is wanted, as some other dynload wrappers provide.
#define DYNAMIC_LOAD_FLASHATTN_V3_WRAP(__name)                         \
  struct DynLoad__##__name {                                           \
    template <typename... Args>                                        \
    auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) {   \
      using flashattnFunc = decltype(&::__name);                       \
      std::call_once(flashattnv3_dso_flag, []() {                      \
        flashattnv3_dso_handle = phi::dynload::GetFlashAttnV3DsoHandle(); \
      });                                                              \
      static void* p_##__name = dlsym(flashattnv3_dso_handle, #__name); \
      return reinterpret_cast<flashattnFunc>(p_##__name)(args...);     \
    }                                                                  \
  };                                                                   \
  extern DynLoad__##__name __name

#define DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP(__name) \
  DYNAMIC_LOAD_FLASHATTN_V3_WRAP(__name)

#ifdef PADDLE_WITH_CUDA
#define FLASHATTN_V3_ROUTINE_EACH(__macro) \
  __macro(flash_attn_v3_fwd);              \
  __macro(flash_attn_v3_bwd);

// Bug fix: expand the declaration list only when FLASHATTN_V3_ROUTINE_EACH is
// actually defined. The original invoked it unconditionally after the #endif,
// which fails to compile in any translation unit built without
// PADDLE_WITH_CUDA (e.g. a HIP build that includes this header).
FLASHATTN_V3_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V3_WRAP);
#endif

#undef DYNAMIC_LOAD_FLASHATTN_V3_WRAP

} // namespace dynload
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/calc_reduced_attn_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ struct CalcReducedAttnScoresParams : public FlashAttnParamsBase {
const int _head_size,
const float _scale,
const DataType q_dtype)
: FlashAttnParamsBase(_batch_size,
: FlashAttnParamsBase(/*version=*/2,
/*is_fwd=*/true,
_batch_size,
_max_seqlen_q,
_max_seqlen_k,
_num_heads,
Expand Down
Loading