Skip to content

Commit f36cc0e

Browse files
committed
complete 2D CSR sparse transpose op
1 parent c6c33ff commit f36cc0e

File tree

2 files changed

+124
-16
lines changed

2 files changed

+124
-16
lines changed

paddle/phi/core/sparse_csr_tensor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ class SparseCsrTensor : public TensorBase,
190190
[0, 0, 4, 0],
191191
[0, 5, 0, 6]]
192192
dims_ = (4, 4)
193-
non_zero_elements_ = [1, 2, 3, 4, 5 ,6]
193+
non_zero_elements_ = [1, 2, 3, 4, 5, 6]
194194
non_zero_crows_ = [0, 1, 3, 4, 6]
195195
non_zero_cols_ = [1, 0, 3, 2, 1, 3]
196196
*/
@@ -209,7 +209,7 @@ class SparseCsrTensor : public TensorBase,
209209
[0, 0, 4, 0],
210210
[0, 5, 0, 0]]]
211211
dims_ = (2, 4, 4)
212-
non_zero_elements_ = [1, 2, 3, 4, 5 ,6, 1, 2, 3, 4, 5]
212+
non_zero_elements_ = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5]
213213
non_zero_crows_ = [0, 1, 3, 4, 6, 0, 1, 2, 4, 5]
214214
non_zero_cols_ = [1, 0, 3, 2, 1, 3, 1, 0, 3, 2, 1]
215215
*/

paddle/phi/kernels/sparse/impl/unary_kernel_impl.h

Lines changed: 122 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414

1515
#pragma once
16-
16+
#include <unordered_set>
1717
#include "paddle/phi/core/ddim.h"
1818
#include "paddle/phi/core/meta_tensor.h"
1919
#include "paddle/phi/core/sparse_coo_tensor.h"
@@ -219,46 +219,154 @@ void TransposeCooKernel(const Context& dev_ctx,
219219
DenseTensor* out_indices = out->mutable_indices();
220220
DenseTensor* out_values = out->mutable_non_zero_elements();
221221

222-
int64_t* x_indices_data = x_indices.data<int64_t>();
222+
const int64_t* x_indices_data = x_indices.data<int64_t>();
223223
int64_t* out_indices_data = out_indices->data<int64_t>();
224224
int64_t x_nnz = x.nnz();
225-
std::vector<int> shape;
226225
for (int64_t i = 0; i < dims.size(); ++i) {
227226
for (int64_t j = 0; j < x_nnz; ++j) {
228227
out_indices_data[j + i * x_nnz] = x_indices_data[j + dims[i] * x_nnz];
229228
}
230-
shape.push_back()
231229
}
232230

233-
DDim out_ddim(x.dims());
234-
out_ddim.transpose(dims);
235-
231+
DDim out_dims(x.dims());
232+
out_dims.transpose(dims);
236233
phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
237-
out->Resize(out_ddim, x.sparse_dim(), x_nnz);
234+
out->Resize(out_dims, x.sparse_dim(), x_nnz);
238235
}
239236

240237
template <typename T, typename Context>
241238
void TransposeCsrKernel(const Context& dev_ctx,
242239
const SparseCsrTensor& x,
243240
const std::vector<int>& dims,
244241
SparseCsrTensor* out) {
245-
out->set_dims(x.dims());
246-
242+
int n_dim = dims.size();
243+
DDim out_dims(x.dims());
244+
out_dims.transpose(dims);
245+
out->set_dims(out_dims);
246+
out->Resize(out_dims, x.nnz());
247247
const DenseTensor& x_crows = x.crows();
248248
const DenseTensor& x_cols = x.cols();
249249
const DenseTensor& x_values = x.non_zero_elements();
250250
DenseTensor* out_crows = out->mutable_crows();
251251
DenseTensor* out_cols = out->mutable_cols();
252252
DenseTensor* out_values = out->mutable_non_zero_elements();
253253

254-
*out_crows = x_crows;
255-
*out_cols = x_cols;
254+
// return a copy of x
255+
if (dims[0] == 0 && dims[1] == 1 && (n_dim == 2 || dims[2] == 2)) {
256+
*out_crows = x_crows;
257+
*out_cols = x_cols;
258+
phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
259+
return;
260+
}
256261

257262
int* out_crows_data = out_crows->data<int>();
258263
int* out_cols_data = out_cols->data<int>();
264+
T* out_values_data = out_values->data<T>();
265+
const int* x_crows_data = x_crows.data<int>();
266+
const int* x_cols_data = x_cols.data<int>();
267+
const T* x_values_data = x_values.data<T>();
259268

260-
phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
261-
out->Resize(phi::make_ddim(shape), x_values.dims()[0]);
269+
if (n_dim == 2) { // dims == {1, 0}
270+
// compute out_crows_data by x_cols_data
271+
for (int i = 0; i < out_dims[0]; ++i) {
272+
out_crows_data[i] = 0;
273+
}
274+
out_crows_data[out_dims[0]] = x.nnz();
275+
for (int i = 0; i < x.nnz(); ++i) {
276+
int j = x_cols_data[i];
277+
out_crows_data[j + 1]++;
278+
}
279+
for (int i = 1; i < out_dims[0]; ++i) {
280+
out_crows_data[i] += out_crows_data[i - 1];
281+
}
282+
// compute out_cols_data and out_values_data by out_crows_data and x
283+
std::unordered_set<int> cols_ptr;
284+
for (int i = 0; i < x.dims()[0]; ++i) {
285+
int start = x_crows_data[i];
286+
int end = x_crows_data[i + 1];
287+
for (int j = start; j < end; ++j) {
288+
int jj = x_cols_data[j];
289+
int jjj = out_crows_data[jj];
290+
int jjj_ptr = jjj + cols_ptr.count();
291+
out_cols_data[jjj_ptr] = i;
292+
out_values_data[jjj_ptr] = x_values_data[j];
293+
cols_ptr.insert(jjj);
294+
}
295+
}
296+
} else { // n_dim == 3
297+
for (int k = 0; k < out_dims[0]; ++k) {
298+
if (dims[0] == 0) { // dims == {0, 2, 1}
299+
int out_n_rows = out_dims[1];
300+
// compute out_crows_data by x_cols_data
301+
for (int i = 0; i < out_n_rows; ++i) {
302+
out_crows_data[i] = 0;
303+
}
304+
out_crows_data[out_n_rows] = x_crows_data[x.dims()[1]];
305+
for (int i = 0; i < out_crows_data[out_n_rows]; ++i) {
306+
int j = x_cols_data[i];
307+
out_crows_data[j + 1]++;
308+
}
309+
for (int i = 1; i < out_n_rows; ++i) {
310+
out_crows_data[i] += out_crows_data[i - 1];
311+
}
312+
// compute out_cols_data and out_values_data by out_crows_data and x
313+
std::unordered_set<int> cols_ptr;
314+
for (int i = 0; i < x.dims()[1]; ++i) {
315+
int start = x_crows_data[i];
316+
int end = x_crows_data[i + 1];
317+
for (int j = start; j < end; ++j) {
318+
int jj = x_cols_data[j];
319+
int jjj = out_crows_data[jj];
320+
int jjj_ptr = jjj + cols_ptr.count();
321+
out_cols_data[jjj_ptr] = i;
322+
out_values_data[jjj_ptr] = x_values_data[j];
323+
cols_ptr.insert(jjj);
324+
}
325+
}
326+
// x offset
327+
x_crows_data += x.dims()[1] + 1;
328+
x_cols_data += x_crows_data[x.dims()[1]];
329+
x_values_data += x_crows_data[x.dims()[1]];
330+
} else if (dims[0] == 1) {
331+
int out_n_rows = out_dims[1];
332+
// compute out_crows_data by x_cols_data
333+
for (int i = 0; i < out_n_rows; ++i) {
334+
out_crows_data[i] = 0;
335+
}
336+
// out_crows_data[out_n_rows] = x_crows_data[x.dims()[1]];
337+
int x_cols_offset = 0;
338+
int out_cols_offset = 0;
339+
for (int i = 0; i < x.dims()[0]; ++i) {
340+
int x_crows_index = i * (x.dims()[1] + 1);
341+
int start = x_crows_data[x_crows_index];
342+
int end = x_crows_data[x_crows_index + 1];
343+
out_crows_data[i] = end - start;
344+
for (int j = start; j < end; ++j) {
345+
out_cols_data[j - start] = x_cols_data[x_cols_offset + j];
346+
out_values_data[j - start] = x_values_data[x_cols_offset + j];
347+
x_cols_offset += x_crows_data[x_crows_index + x.dims()[1]];
348+
out_cols_offset += out_crows_data[... + out_dims[1]];
349+
}
350+
}
351+
352+
for (int i = 0; i < out_crows_data[out_n_rows]; ++i) {
353+
int j = x_cols_data[i];
354+
out_crows_data[j + 1]++;
355+
}
356+
for (int i = 1; i < out_n_rows; ++i) {
357+
out_crows_data[i] += out_crows_data[i - 1];
358+
}
359+
360+
// x offset
361+
x_crows_data += 1;
362+
} else {
363+
}
364+
// out offset
365+
out_crows_data += out_dims[1] + 1;
366+
out_cols_data += x_crows_data[out_dims[1]];
367+
out_values_data += x_crows_data[out_dims[1]];
368+
}
369+
}
262370
}
263371

264372
} // namespace sparse

0 commit comments

Comments
 (0)