Skip to content

Commit f81b6bf

Browse files
committed
Rename ParitionIndex as ReplacementPartitionIndex
1 parent 64d4aaa commit f81b6bf

File tree

3 files changed

+55
-52
lines changed

3 files changed

+55
-52
lines changed

include/ck_tile/core/tensor/store_tile.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
6464
template <typename BottomTensorView_,
6565
typename WindowLengths_,
6666
typename TileDistribution_,
67-
typename PartitionIndex_,
67+
typename ReplacementPartitionIndex_,
6868
index_t NumCoord,
6969
typename DataType_>
7070
CK_TILE_DEVICE void
7171
store_tile(tile_window_with_static_distribution<BottomTensorView_,
7272
WindowLengths_,
7373
TileDistribution_,
74-
PartitionIndex_,
74+
ReplacementPartitionIndex_,
7575
NumCoord>& tile_window,
7676
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
7777
{
@@ -81,14 +81,14 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
8181
template <typename BottomTensorView_,
8282
typename WindowLengths_,
8383
typename TileDistribution_,
84-
typename PartitionIndex_,
84+
typename ReplacementPartitionIndex_,
8585
index_t NumCoord,
8686
typename DataType_>
8787
CK_TILE_DEVICE void
8888
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
8989
WindowLengths_,
9090
TileDistribution_,
91-
PartitionIndex_,
91+
ReplacementPartitionIndex_,
9292
NumCoord>& tile_window,
9393
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
9494
{

include/ck_tile/core/tensor/tile_window.hpp

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,23 @@ namespace ck_tile {
2525
* @note This tile window does not support single issue you need to use tile_window_linear
2626
* structure for this purpose
2727
*
28-
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
29-
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
30-
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
31-
* @tparam NumCoord TBD
28+
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
29+
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
30+
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
31+
* @tparam ReplacementPartitionIndex Replacement values of (get_warp_id(), get_lane_id()) tuple
32+
* @tparam NumCoord TBD
3233
*/
3334
template <typename BottomTensorView_,
3435
typename WindowLengths_,
3536
typename StaticTileDistribution_,
36-
typename PartitionIndex,
37+
typename ReplacementPartitionIndex,
3738
index_t NumCoord>
3839
struct tile_window_with_static_distribution
3940
: public tile_window_with_tile_dstr_base<
4041
tile_window_with_static_distribution<BottomTensorView_,
4142
WindowLengths_,
4243
StaticTileDistribution_,
43-
PartitionIndex,
44+
ReplacementPartitionIndex,
4445
NumCoord>,
4546
BottomTensorView_,
4647
WindowLengths_,
@@ -50,7 +51,7 @@ struct tile_window_with_static_distribution
5051
tile_window_with_static_distribution<BottomTensorView_,
5152
WindowLengths_,
5253
StaticTileDistribution_,
53-
PartitionIndex,
54+
ReplacementPartitionIndex,
5455
NumCoord>,
5556
BottomTensorView_,
5657
WindowLengths_,
@@ -88,12 +89,12 @@ struct tile_window_with_static_distribution
8889
}
8990
}
9091

91-
template <typename NewPartitionIndex = PartitionIndex>
92+
template <typename NewReplacementPartitionIndex = ReplacementPartitionIndex>
9293
CK_TILE_DEVICE constexpr auto
9394
prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view,
9495
const typename Base::BottomTensorIndex& window_origin,
9596
const typename Base::TileDstr& tile_distribution,
96-
NewPartitionIndex = {}) const
97+
NewReplacementPartitionIndex = {}) const
9798
{
9899
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
99100
coords;
@@ -102,15 +103,16 @@ struct tile_window_with_static_distribution
102103
tile_distribution.get_ps_ys_to_xs_adaptor(),
103104
container_concat(
104105
// Override partition_index with the corresponding non-negative elements (if
105-
// any) from NewPartitionIndex
106+
// any) from NewReplacementPartitionIndex
106107
[&] {
107108
auto partition_index = detail::get_partition_index(tile_distribution);
108109
static_for<0,
109-
ck_tile::min(partition_index.size(), NewPartitionIndex::size()),
110+
ck_tile::min(partition_index.size(),
111+
NewReplacementPartitionIndex::size()),
110112
1>{}([&](auto idx) {
111-
if constexpr(0 <= NewPartitionIndex{}[idx])
113+
if constexpr(0 <= NewReplacementPartitionIndex{}[idx])
112114
{
113-
partition_index[idx] = NewPartitionIndex{}[idx];
115+
partition_index[idx] = NewReplacementPartitionIndex{}[idx];
114116
}
115117
});
116118
return partition_index;
@@ -1018,20 +1020,20 @@ struct tile_window_with_static_distribution
10181020
template <typename TensorView_,
10191021
typename WindowLengths_,
10201022
typename StaticTileDistribution_,
1021-
typename PartitionIndex = sequence<-1, -1>,
1022-
index_t NumCoord = 1>
1023+
typename ReplacementPartitionIndex = sequence<-1, -1>,
1024+
index_t NumCoord = 1>
10231025
CK_TILE_DEVICE constexpr auto
10241026
make_tile_window(const TensorView_& tensor_view,
10251027
const WindowLengths_& window_lengths,
10261028
const multi_index<TensorView_::get_num_of_dimension()>& origin,
10271029
const StaticTileDistribution_& tile_distribution,
1028-
PartitionIndex = {},
1029-
number<NumCoord> = {})
1030+
ReplacementPartitionIndex = {},
1031+
number<NumCoord> = {})
10301032
{
10311033
return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
10321034
remove_cvref_t<WindowLengths_>,
10331035
remove_cvref_t<StaticTileDistribution_>,
1034-
PartitionIndex,
1036+
ReplacementPartitionIndex,
10351037
NumCoord>{
10361038
tensor_view, window_lengths, origin, tile_distribution};
10371039
}
@@ -1040,20 +1042,20 @@ make_tile_window(const TensorView_& tensor_view,
10401042
template <typename TensorView_,
10411043
typename WindowLengths_,
10421044
typename StaticTileDistribution_,
1043-
typename PartitionIndex = sequence<-1, -1>,
1044-
index_t NumCoord = 1>
1045+
typename ReplacementPartitionIndex = sequence<-1, -1>,
1046+
index_t NumCoord = 1>
10451047
CK_TILE_DEVICE auto
10461048
make_tile_window_raw(const TensorView_& tensor_view,
10471049
const WindowLengths_& window_lengths,
10481050
const multi_index<TensorView_::get_num_of_dimension()>& origin,
10491051
const StaticTileDistribution_& tile_distribution,
1050-
PartitionIndex = {},
1051-
number<NumCoord> = {})
1052+
ReplacementPartitionIndex = {},
1053+
number<NumCoord> = {})
10521054
{
10531055
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
10541056
remove_cvref_t<WindowLengths_>,
10551057
remove_cvref_t<StaticTileDistribution_>,
1056-
PartitionIndex,
1058+
ReplacementPartitionIndex,
10571059
NumCoord>{
10581060
tensor_view, window_lengths, origin, tile_distribution};
10591061
w.init_raw();
@@ -1063,18 +1065,18 @@ make_tile_window_raw(const TensorView_& tensor_view,
10631065
template <typename TensorView_,
10641066
typename WindowLengths_,
10651067
typename StaticTileDistribution_,
1066-
typename PartitionIndex_,
1068+
typename ReplacementPartitionIndex_,
10671069
index_t NumCoord>
10681070
CK_TILE_DEVICE void move_tile_window(
10691071
tile_window_with_static_distribution<TensorView_,
10701072
WindowLengths_,
10711073
StaticTileDistribution_,
1072-
PartitionIndex_,
1074+
ReplacementPartitionIndex_,
10731075
NumCoord>& window,
10741076
const typename tile_window_with_static_distribution<TensorView_,
10751077
WindowLengths_,
10761078
StaticTileDistribution_,
1077-
PartitionIndex_,
1079+
ReplacementPartitionIndex_,
10781080
NumCoord>::BottomTensorIndex& step)
10791081
{
10801082
window.move(step);
@@ -1083,24 +1085,24 @@ CK_TILE_DEVICE void move_tile_window(
10831085
template <typename TensorView_,
10841086
typename WindowLengths_,
10851087
typename StaticTileDistribution_,
1086-
typename PartitionIndex_,
1088+
typename ReplacementPartitionIndex_,
10871089
index_t NumCoord>
10881090
CK_TILE_DEVICE void move_tile_window(
10891091
tuple<tile_window_with_static_distribution<TensorView_,
10901092
WindowLengths_,
10911093
StaticTileDistribution_,
1092-
PartitionIndex_,
1094+
ReplacementPartitionIndex_,
10931095
NumCoord>>& window,
10941096
const typename tile_window_with_static_distribution<TensorView_,
10951097
WindowLengths_,
10961098
StaticTileDistribution_,
1097-
PartitionIndex_,
1099+
ReplacementPartitionIndex_,
10981100
NumCoord>::BottomTensorIndex& step)
10991101
{
11001102
using T = tuple<tile_window_with_static_distribution<TensorView_,
11011103
WindowLengths_,
11021104
StaticTileDistribution_,
1103-
PartitionIndex_,
1105+
ReplacementPartitionIndex_,
11041106
NumCoord>>;
11051107

11061108
static constexpr auto N = T::size();
@@ -1228,45 +1230,45 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
12281230
template <typename TensorView,
12291231
typename WindowLengths,
12301232
typename StaticTileDistribution,
1231-
typename PartitionIndex = sequence<-1, -1>>
1233+
typename ReplacementPartitionIndex = sequence<-1, -1>>
12321234
CK_TILE_DEVICE constexpr auto
12331235
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
12341236
const multi_index<TensorView::get_num_of_dimension()>& origin,
12351237
const StaticTileDistribution& tile_distribution,
1236-
PartitionIndex = {})
1238+
ReplacementPartitionIndex = {})
12371239
{
12381240
return make_tile_window(tile_window.get_bottom_tensor_view(),
12391241
tile_window.get_window_lengths(),
12401242
origin,
12411243
tile_distribution,
1242-
PartitionIndex{});
1244+
ReplacementPartitionIndex{});
12431245
}
12441246

12451247
template <typename TensorView,
12461248
typename WindowLengths,
12471249
typename StaticTileDistribution,
1248-
typename PartitionIndex = sequence<-1, -1>>
1250+
typename ReplacementPartitionIndex = sequence<-1, -1>>
12491251
CK_TILE_DEVICE constexpr auto
12501252
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
12511253
const StaticTileDistribution& tile_distribution,
1252-
PartitionIndex = {})
1254+
ReplacementPartitionIndex = {})
12531255
{
12541256
return make_tile_window(tile_window.get_bottom_tensor_view(),
12551257
tile_window.get_window_lengths(),
12561258
tile_window.get_window_origin(),
12571259
tile_distribution,
1258-
PartitionIndex{});
1260+
ReplacementPartitionIndex{});
12591261
}
12601262

12611263
template <typename TensorView,
12621264
typename WindowLengths,
12631265
typename StaticTileDistribution,
1264-
typename PartitionIndex = sequence<-1, -1>>
1266+
typename ReplacementPartitionIndex = sequence<-1, -1>>
12651267
CK_TILE_DEVICE constexpr auto
12661268
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
12671269
const StaticTileDistribution& tile_distribution)
12681270
{
1269-
auto w = make_tile_window(tile_window, tile_distribution, PartitionIndex{});
1271+
auto w = make_tile_window(tile_window, tile_distribution, ReplacementPartitionIndex{});
12701272
w.init_raw();
12711273
return w;
12721274
}
@@ -1295,21 +1297,22 @@ struct is_tile_window_with_static_distribution : std::false_type
12951297
/**
12961298
* @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`.
12971299
*
1298-
* @tparam BottomTensorView_ Bottom tensor view type of the tile window.
1299-
* @tparam WindowLengths_ Static window lengths.
1300-
* @tparam StaticTileDistribution_ Tile distribution policy.
1301-
* @tparam NumCoord Number of coordinate dimensions.
1300+
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
1301+
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
1302+
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
1303+
* @tparam ReplacementPartitionIndex Replacement values of (get_warp_id(), get_lane_id()) tuple
1304+
* @tparam NumCoord TBD
13021305
*/
13031306
template <typename BottomTensorView_,
13041307
typename WindowLengths_,
13051308
typename StaticTileDistribution_,
1306-
typename PartitionIndex_,
1309+
typename ReplacementPartitionIndex_,
13071310
index_t NumCoord>
13081311
struct is_tile_window_with_static_distribution<
13091312
tile_window_with_static_distribution<BottomTensorView_,
13101313
WindowLengths_,
13111314
StaticTileDistribution_,
1312-
PartitionIndex_,
1315+
ReplacementPartitionIndex_,
13131316
NumCoord>> : std::true_type
13141317
{
13151318
};

include/ck_tile/core/tensor/update_tile.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
4040
template <typename BottomTensorView_,
4141
typename WindowLengths_,
4242
typename TileDistribution_,
43-
typename PartitionIndex_,
43+
typename ReplacementPartitionIndex_,
4444
index_t NumCoord,
4545
typename DataType_,
4646
index_t i_access = -1,
@@ -49,7 +49,7 @@ CK_TILE_DEVICE void
4949
update_tile(tile_window_with_static_distribution<BottomTensorView_,
5050
WindowLengths_,
5151
TileDistribution_,
52-
PartitionIndex_,
52+
ReplacementPartitionIndex_,
5353
NumCoord>& tile_window,
5454
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
5555
number<i_access> = {},
@@ -61,7 +61,7 @@ update_tile(tile_window_with_static_distribution<BottomTensorView_,
6161
template <typename BottomTensorView_,
6262
typename WindowLengths_,
6363
typename TileDistribution_,
64-
typename PartitionIndex_,
64+
typename ReplacementPartitionIndex_,
6565
index_t NumCoord,
6666
typename DataType_,
6767
index_t i_access = -1,
@@ -71,7 +71,7 @@ CK_TILE_DEVICE void
7171
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
7272
WindowLengths_,
7373
TileDistribution_,
74-
PartitionIndex_,
74+
ReplacementPartitionIndex_,
7575
NumCoord>& tile_window,
7676
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
7777
number<i_access> = {},

0 commit comments

Comments
 (0)