@@ -25,22 +25,23 @@ namespace ck_tile {
25
25
* @note This tile window does not support single issue; you need to use the tile_window_linear
26
26
* structure for this purpose
27
27
*
28
- * @tparam BottomTensorView_ Class describing & holding device tensor memory.
29
- * @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
30
- * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
31
- * @tparam NumCoord TBD
28
+ * @tparam BottomTensorView_ Class describing & holding device tensor memory.
29
+ * @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
30
+ * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
31
+ * @tparam ReplacementPartitionIndex Replacement values for the (get_warp_id(), get_lane_id())
+ *         tuple; only non-negative entries override the corresponding value
32
+ * @tparam NumCoord TBD
32
33
*/
33
34
template <typename BottomTensorView_,
34
35
typename WindowLengths_,
35
36
typename StaticTileDistribution_,
36
- typename PartitionIndex ,
37
+ typename ReplacementPartitionIndex ,
37
38
index_t NumCoord>
38
39
struct tile_window_with_static_distribution
39
40
: public tile_window_with_tile_dstr_base<
40
41
tile_window_with_static_distribution<BottomTensorView_,
41
42
WindowLengths_,
42
43
StaticTileDistribution_,
43
- PartitionIndex ,
44
+ ReplacementPartitionIndex ,
44
45
NumCoord>,
45
46
BottomTensorView_,
46
47
WindowLengths_,
@@ -50,7 +51,7 @@ struct tile_window_with_static_distribution
50
51
tile_window_with_static_distribution<BottomTensorView_,
51
52
WindowLengths_,
52
53
StaticTileDistribution_,
53
- PartitionIndex ,
54
+ ReplacementPartitionIndex ,
54
55
NumCoord>,
55
56
BottomTensorView_,
56
57
WindowLengths_,
@@ -88,12 +89,12 @@ struct tile_window_with_static_distribution
88
89
}
89
90
}
90
91
91
- template <typename NewPartitionIndex = PartitionIndex >
92
+ template <typename NewReplacementPartitionIndex = ReplacementPartitionIndex >
92
93
CK_TILE_DEVICE constexpr auto
93
94
prepare_coords (const typename Base::BottomTensorView& bottom_tensor_view,
94
95
const typename Base::BottomTensorIndex& window_origin,
95
96
const typename Base::TileDstr& tile_distribution,
96
- NewPartitionIndex = {}) const
97
+ NewReplacementPartitionIndex = {}) const
97
98
{
98
99
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
99
100
coords;
@@ -102,15 +103,16 @@ struct tile_window_with_static_distribution
102
103
tile_distribution.get_ps_ys_to_xs_adaptor (),
103
104
container_concat (
104
105
// Override partition_index with the corresponding non-negative elements (if
105
- // any) from NewPartitionIndex
106
+ // any) from NewReplacementPartitionIndex
106
107
[&] {
107
108
auto partition_index = detail::get_partition_index (tile_distribution);
108
109
static_for<0 ,
109
- ck_tile::min (partition_index.size (), NewPartitionIndex::size ()),
110
+ ck_tile::min (partition_index.size (),
111
+ NewReplacementPartitionIndex::size ()),
110
112
1 >{}([&](auto idx) {
111
- if constexpr (0 <= NewPartitionIndex {}[idx])
113
+ if constexpr (0 <= NewReplacementPartitionIndex {}[idx])
112
114
{
113
- partition_index[idx] = NewPartitionIndex {}[idx];
115
+ partition_index[idx] = NewReplacementPartitionIndex {}[idx];
114
116
}
115
117
});
116
118
return partition_index;
@@ -1018,20 +1020,20 @@ struct tile_window_with_static_distribution
1018
1020
template <typename TensorView_,
1019
1021
typename WindowLengths_,
1020
1022
typename StaticTileDistribution_,
1021
- typename PartitionIndex = sequence<-1 , -1 >,
1022
- index_t NumCoord = 1 >
1023
+ typename ReplacementPartitionIndex = sequence<-1 , -1 >,
1024
+ index_t NumCoord = 1 >
1023
1025
CK_TILE_DEVICE constexpr auto
1024
1026
make_tile_window (const TensorView_& tensor_view,
1025
1027
const WindowLengths_& window_lengths,
1026
1028
const multi_index<TensorView_::get_num_of_dimension()>& origin,
1027
1029
const StaticTileDistribution_& tile_distribution,
1028
- PartitionIndex = {},
1029
- number<NumCoord> = {})
1030
+ ReplacementPartitionIndex = {},
1031
+ number<NumCoord> = {})
1030
1032
{
1031
1033
return tile_window_with_static_distribution<remove_cvref_t <TensorView_>,
1032
1034
remove_cvref_t <WindowLengths_>,
1033
1035
remove_cvref_t <StaticTileDistribution_>,
1034
- PartitionIndex ,
1036
+ ReplacementPartitionIndex ,
1035
1037
NumCoord>{
1036
1038
tensor_view, window_lengths, origin, tile_distribution};
1037
1039
}
@@ -1040,20 +1042,20 @@ make_tile_window(const TensorView_& tensor_view,
1040
1042
template <typename TensorView_,
1041
1043
typename WindowLengths_,
1042
1044
typename StaticTileDistribution_,
1043
- typename PartitionIndex = sequence<-1 , -1 >,
1044
- index_t NumCoord = 1 >
1045
+ typename ReplacementPartitionIndex = sequence<-1 , -1 >,
1046
+ index_t NumCoord = 1 >
1045
1047
CK_TILE_DEVICE auto
1046
1048
make_tile_window_raw (const TensorView_& tensor_view,
1047
1049
const WindowLengths_& window_lengths,
1048
1050
const multi_index<TensorView_::get_num_of_dimension()>& origin,
1049
1051
const StaticTileDistribution_& tile_distribution,
1050
- PartitionIndex = {},
1051
- number<NumCoord> = {})
1052
+ ReplacementPartitionIndex = {},
1053
+ number<NumCoord> = {})
1052
1054
{
1053
1055
auto w = tile_window_with_static_distribution<remove_cvref_t <TensorView_>,
1054
1056
remove_cvref_t <WindowLengths_>,
1055
1057
remove_cvref_t <StaticTileDistribution_>,
1056
- PartitionIndex ,
1058
+ ReplacementPartitionIndex ,
1057
1059
NumCoord>{
1058
1060
tensor_view, window_lengths, origin, tile_distribution};
1059
1061
w.init_raw ();
@@ -1063,18 +1065,18 @@ make_tile_window_raw(const TensorView_& tensor_view,
1063
1065
template <typename TensorView_,
1064
1066
typename WindowLengths_,
1065
1067
typename StaticTileDistribution_,
1066
- typename PartitionIndex_ ,
1068
+ typename ReplacementPartitionIndex_ ,
1067
1069
index_t NumCoord>
1068
1070
CK_TILE_DEVICE void move_tile_window (
1069
1071
tile_window_with_static_distribution<TensorView_,
1070
1072
WindowLengths_,
1071
1073
StaticTileDistribution_,
1072
- PartitionIndex_ ,
1074
+ ReplacementPartitionIndex_ ,
1073
1075
NumCoord>& window,
1074
1076
const typename tile_window_with_static_distribution<TensorView_,
1075
1077
WindowLengths_,
1076
1078
StaticTileDistribution_,
1077
- PartitionIndex_ ,
1079
+ ReplacementPartitionIndex_ ,
1078
1080
NumCoord>::BottomTensorIndex& step)
1079
1081
{
1080
1082
window.move (step);
@@ -1083,24 +1085,24 @@ CK_TILE_DEVICE void move_tile_window(
1083
1085
template <typename TensorView_,
1084
1086
typename WindowLengths_,
1085
1087
typename StaticTileDistribution_,
1086
- typename PartitionIndex_ ,
1088
+ typename ReplacementPartitionIndex_ ,
1087
1089
index_t NumCoord>
1088
1090
CK_TILE_DEVICE void move_tile_window (
1089
1091
tuple<tile_window_with_static_distribution<TensorView_,
1090
1092
WindowLengths_,
1091
1093
StaticTileDistribution_,
1092
- PartitionIndex_ ,
1094
+ ReplacementPartitionIndex_ ,
1093
1095
NumCoord>>& window,
1094
1096
const typename tile_window_with_static_distribution<TensorView_,
1095
1097
WindowLengths_,
1096
1098
StaticTileDistribution_,
1097
- PartitionIndex_ ,
1099
+ ReplacementPartitionIndex_ ,
1098
1100
NumCoord>::BottomTensorIndex& step)
1099
1101
{
1100
1102
using T = tuple<tile_window_with_static_distribution<TensorView_,
1101
1103
WindowLengths_,
1102
1104
StaticTileDistribution_,
1103
- PartitionIndex_ ,
1105
+ ReplacementPartitionIndex_ ,
1104
1106
NumCoord>>;
1105
1107
1106
1108
static constexpr auto N = T::size ();
@@ -1228,45 +1230,45 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
1228
1230
template <typename TensorView,
1229
1231
typename WindowLengths,
1230
1232
typename StaticTileDistribution,
1231
- typename PartitionIndex = sequence<-1 , -1 >>
1233
+ typename ReplacementPartitionIndex = sequence<-1 , -1 >>
1232
1234
CK_TILE_DEVICE constexpr auto
1233
1235
make_tile_window (const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
1234
1236
const multi_index<TensorView::get_num_of_dimension()>& origin,
1235
1237
const StaticTileDistribution& tile_distribution,
1236
- PartitionIndex = {})
1238
+ ReplacementPartitionIndex = {})
1237
1239
{
1238
1240
return make_tile_window (tile_window.get_bottom_tensor_view (),
1239
1241
tile_window.get_window_lengths (),
1240
1242
origin,
1241
1243
tile_distribution,
1242
- PartitionIndex {});
1244
+ ReplacementPartitionIndex {});
1243
1245
}
1244
1246
1245
1247
template <typename TensorView,
1246
1248
typename WindowLengths,
1247
1249
typename StaticTileDistribution,
1248
- typename PartitionIndex = sequence<-1 , -1 >>
1250
+ typename ReplacementPartitionIndex = sequence<-1 , -1 >>
1249
1251
CK_TILE_DEVICE constexpr auto
1250
1252
make_tile_window (const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
1251
1253
const StaticTileDistribution& tile_distribution,
1252
- PartitionIndex = {})
1254
+ ReplacementPartitionIndex = {})
1253
1255
{
1254
1256
return make_tile_window (tile_window.get_bottom_tensor_view (),
1255
1257
tile_window.get_window_lengths (),
1256
1258
tile_window.get_window_origin (),
1257
1259
tile_distribution,
1258
- PartitionIndex {});
1260
+ ReplacementPartitionIndex {});
1259
1261
}
1260
1262
1261
1263
template <typename TensorView,
1262
1264
typename WindowLengths,
1263
1265
typename StaticTileDistribution,
1264
- typename PartitionIndex = sequence<-1 , -1 >>
1266
+ typename ReplacementPartitionIndex = sequence<-1 , -1 >>
1265
1267
CK_TILE_DEVICE constexpr auto
1266
1268
make_tile_window_raw (const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
1267
1269
const StaticTileDistribution& tile_distribution)
1268
1270
{
1269
- auto w = make_tile_window (tile_window, tile_distribution, PartitionIndex {});
1271
+ auto w = make_tile_window (tile_window, tile_distribution, ReplacementPartitionIndex {});
1270
1272
w.init_raw ();
1271
1273
return w;
1272
1274
}
@@ -1295,21 +1297,22 @@ struct is_tile_window_with_static_distribution : std::false_type
1295
1297
/* *
1296
1298
* @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`.
1297
1299
*
1298
- * @tparam BottomTensorView_ Bottom tensor view type of the tile window.
1299
- * @tparam WindowLengths_ Static window lengths.
1300
- * @tparam StaticTileDistribution_ Tile distribution policy.
1301
- * @tparam NumCoord Number of coordinate dimensions.
1300
+ * @tparam BottomTensorView_ Class describing & holding device tensor memory.
1301
+ * @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
1302
+ * @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
1303
+ * @tparam ReplacementPartitionIndex_ Replacement values of (get_warp_id(), get_lane_id()) tuple
1304
+ * @tparam NumCoord TBD
1302
1305
*/
1303
1306
template <typename BottomTensorView_,
1304
1307
typename WindowLengths_,
1305
1308
typename StaticTileDistribution_,
1306
- typename PartitionIndex_ ,
1309
+ typename ReplacementPartitionIndex_ ,
1307
1310
index_t NumCoord>
1308
1311
struct is_tile_window_with_static_distribution <
1309
1312
tile_window_with_static_distribution<BottomTensorView_,
1310
1313
WindowLengths_,
1311
1314
StaticTileDistribution_,
1312
- PartitionIndex_ ,
1315
+ ReplacementPartitionIndex_ ,
1313
1316
NumCoord>> : std::true_type
1314
1317
{
1315
1318
};
0 commit comments