1616
1717#ifdef PADDLE_WITH_XPU
1818
19+ #include < unordered_map>
1920#include < vector>
2021#include " paddle/phi/backends/xpu/enforce_xpu.h"
2122#include " paddle/phi/backends/xpu/xpu_header.h"
@@ -41,29 +42,60 @@ enum XPUFCCalcType {
4142 FC_FLOAT16,
4243};
4344
44- template <typename T>
45- XPUFCCalcType FCCalcType () {
46- const char * xpu_paddle_fc_float16 = std::getenv (" XPU_PADDLE_FC_FLOAT16" );
47- if (xpu_paddle_fc_float16 != nullptr &&
48- (std::is_same<phi::dtype::float16, T>::value ||
49- std::is_same<XPUTypeFP16, T>::value || std::is_same<float , T>::value)) {
50- return XPUFCCalcType::FC_FLOAT16;
51- } else if (std::is_same<phi::dtype::float16, T>::value ||
52- std::is_same<XPUTypeFP16, T>::value) {
53- return XPUFCCalcType::FC_INT16;
54- } else if (std::getenv (" XPU_PADDLE_FC_INT32" ) != nullptr ) {
55- return XPUFCCalcType::FC_INT32;
56- } else if (std::getenv (" XPU_PADDLE_FC_LOCAL_INT16" ) != nullptr ) {
57- return XPUFCCalcType::FC_FLOAT;
58- } else if (std::getenv (" XPU_PADDLE_FC_INT32_WITH_LL" ) != nullptr ) {
59- return XPUFCCalcType::FC_INT32_WITH_LL;
60- } else if ((std::is_same<phi::dtype::bfloat16, T>::value ||
61- std::is_same<XPUTypeBF16, T>::value) ||
62- (std::is_same<float , T>::value &&
63- std::getenv (" XPU_PADDLE_FC_TF32" ) != nullptr )) {
64- return XPUFCCalcType::FC_TF32;
45+ using XPUFCCalcTypeMap = std::vector<std::pair<const char *, XPUFCCalcType>>;
46+
47+ inline XPUFCCalcType GetFCCalcTypeFromEnv (const XPUFCCalcTypeMap& env_map,
48+ XPUFCCalcType default_calc_type) {
49+ for (auto [env_name, calc_type] : env_map) {
50+ if (std::getenv (env_name) != nullptr ) {
51+ return calc_type;
52+ }
6553 }
66- return XPUFCCalcType::FC_INT16;
54+ return default_calc_type;
55+ }
56+
57+ template <typename T>
58+ inline XPUFCCalcType FCCalcType () {
59+ // FLOAT32
60+ XPUFCCalcTypeMap calc_type_map = {
61+ {" XPU_PADDLE_FC_FLOAT" , XPUFCCalcType::FC_FLOAT},
62+ {" XPU_PADDLE_FC_LOCAL_INT16" , XPUFCCalcType::FC_FLOAT},
63+ {" XPU_PADDLE_FC_TF32" , XPUFCCalcType::FC_TF32},
64+ {" XPU_PADDLE_FC_INT16" , XPUFCCalcType::FC_INT16},
65+ {" XPU_PADDLE_FC_INT32" , XPUFCCalcType::FC_INT32},
66+ {" XPU_PADDLE_FC_INT32_WITH_LL" , XPUFCCalcType::FC_INT32_WITH_LL},
67+ };
68+ #ifdef PADDLE_WITH_XPU_XRE5
69+ auto default_calc_type = XPUFCCalcType::FC_TF32;
70+ #else
71+ auto default_calc_type = XPUFCCalcType::FC_INT16;
72+ #endif
73+ return GetFCCalcTypeFromEnv (calc_type_map, default_calc_type);
74+ }
75+
76+ template <>
77+ inline XPUFCCalcType FCCalcType<XPUTypeFP16>() {
78+ XPUFCCalcTypeMap calc_type_map = {
79+ {" XPU_PADDLE_FC_FLOAT16" , XPUFCCalcType::FC_FLOAT16},
80+ {" XPU_PADDLE_FC_INT16" , XPUFCCalcType::FC_INT16},
81+ {" XPU_PADDLE_FC_FLOAT" , XPUFCCalcType::FC_FLOAT},
82+ {" XPU_PADDLE_FC_LOCAL_INT16" , XPUFCCalcType::FC_FLOAT}};
83+ #ifdef PADDLE_WITH_XPU_XRE5
84+ auto default_calc_type = XPUFCCalcType::FC_FLOAT16;
85+ #else
86+ auto default_calc_type = XPUFCCalcType::FC_INT16;
87+ #endif
88+ return GetFCCalcTypeFromEnv (calc_type_map, default_calc_type);
89+ }
90+
91+ template <>
92+ inline XPUFCCalcType FCCalcType<XPUTypeBF16>() {
93+ XPUFCCalcTypeMap calc_type_map = {
94+ // TF32 is the default, do not need to be listed here.
95+ {" XPU_PADDLE_FC_FLOAT" , XPUFCCalcType::FC_FLOAT},
96+ {" XPU_PADDLE_FC_LOCAL_INT16" , XPUFCCalcType::FC_FLOAT}};
97+ auto default_calc_type = XPUFCCalcType::FC_TF32;
98+ return GetFCCalcTypeFromEnv (calc_type_map, default_calc_type);
6799}
68100
69101struct XpuFcInfo {
0 commit comments