Skip to content

Commit 6758c74

Browse files
committed
Helpers: to_numpy/cupy
1 parent 2fe7f8b commit 6758c74

File tree

7 files changed

+158
-15
lines changed

7 files changed

+158
-15
lines changed

MANIFEST.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ recursive-include cmake *
66
recursive-include src *
77
recursive-include tests *
88

9+
# avoid accidentially copying compiled Python files
10+
global-exclude */__pycache__/*
11+
global-exclude *.pyc
12+
913
# see .gitignore
1014
prune cmake-build*
1115
prune .spack-env*

src/amrex/Array4.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
This file is part of pyAMReX
3+
4+
Copyright 2022 AMReX community
5+
Authors: Axel Huebl
6+
License: BSD-3-Clause-LBNL
7+
"""
8+
9+
10+
def array4_to_numpy(self, copy=False, order="F"):
11+
"""
12+
Provide a Numpy view into an Array4.
13+
14+
Note on the order of indices:
15+
By default, this is as in AMReX in Fortran contiguous order, indexing as
16+
x,y,z. This has performance implications for use in external libraries such
17+
as cupy.
18+
The order="C" option will index as z,y,x and perform better with cupy.
19+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
20+
21+
Parameters
22+
----------
23+
self : amrex.Array4_*
24+
An Array4 class in pyAMReX
25+
copy : bool, optional
26+
Copy the data if true, otherwise create a view (default).
27+
order : string, optional
28+
F order (default) or C. C is faster with external libraries.
29+
30+
Returns
31+
-------
32+
np.array
33+
A numpy n-dimensional array.
34+
"""
35+
import numpy as np
36+
37+
if order == "F":
38+
return np.array(self, copy=copy).T
39+
elif order == "C":
40+
return np.array(self, copy=copy)
41+
else:
42+
raise ValueError("The order argument must be F or C.")
43+
44+
45+
def array4_to_cupy(self, copy=False, order="F"):
46+
"""
47+
Provide a Cupy view into an Array4.
48+
49+
Note on the order of indices:
50+
By default, this is as in AMReX in Fortran contiguous order, indexing as
51+
x,y,z. This has performance implications for use in external libraries such
52+
as cupy.
53+
The order="C" option will index as z,y,x and perform better with cupy.
54+
https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
55+
56+
Parameters
57+
----------
58+
self : amrex.Array4_*
59+
An Array4 class in pyAMReX
60+
copy : bool, optional
61+
Copy the data if true, otherwise create a view (default).
62+
order : string, optional
63+
F order (default) or C. C is faster with external libraries.
64+
65+
Returns
66+
-------
67+
cupy.array
68+
A numpy n-dimensional array.
69+
70+
Raises
71+
------
72+
ImportError
73+
Raises an exception if cupy is not installed
74+
"""
75+
import cupy as cp
76+
77+
if order == "F":
78+
return cp.array(self, copy=copy).T
79+
elif order == "C":
80+
return cp.array(self, copy=copy)
81+
else:
82+
raise ValueError("The order argument must be F or C.")
83+
84+
85+
def register_Array4_extension(Array4_type):
86+
"""Array4 helper methods"""
87+
Array4_type.to_numpy = array4_to_numpy
88+
Array4_type.to_cupy = array4_to_cupy

src/amrex/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# -*- coding: utf-8 -*-
22

33
# __version__ is TODO - only in spaceNd submodules
4-
__author__ = (
5-
"Axel Huebl, Ryan Sandberg, Shreyas Ananthan, Remi Lehe, " "Weiqun Zhang, et al."
6-
)
4+
__author__ = "Axel Huebl, Ryan Sandberg, Shreyas Ananthan, Remi Lehe, Andrew Myers, Reva Jambunathan, Edodardo Zoni, Weiqun Zhang"
75
__license__ = "BSD-3-Clause-LBNL"

src/amrex/space1d/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,20 @@
3232
#
3333
def d_decl(x, y, z):
3434
return (x,)
35+
36+
37+
from ..Array4 import register_Array4_extension
38+
39+
register_Array4_extension(Array4_float)
40+
register_Array4_extension(Array4_double)
41+
register_Array4_extension(Array4_longdouble)
42+
43+
register_Array4_extension(Array4_short)
44+
register_Array4_extension(Array4_int)
45+
register_Array4_extension(Array4_long)
46+
register_Array4_extension(Array4_longlong)
47+
48+
register_Array4_extension(Array4_ushort)
49+
register_Array4_extension(Array4_uint)
50+
register_Array4_extension(Array4_ulong)
51+
register_Array4_extension(Array4_ulonglong)

src/amrex/space2d/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,20 @@
3232
#
3333
def d_decl(x, y, z):
3434
return (x, y)
35+
36+
37+
from ..Array4 import register_Array4_extension
38+
39+
register_Array4_extension(Array4_float)
40+
register_Array4_extension(Array4_double)
41+
register_Array4_extension(Array4_longdouble)
42+
43+
register_Array4_extension(Array4_short)
44+
register_Array4_extension(Array4_int)
45+
register_Array4_extension(Array4_long)
46+
register_Array4_extension(Array4_longlong)
47+
48+
register_Array4_extension(Array4_ushort)
49+
register_Array4_extension(Array4_uint)
50+
register_Array4_extension(Array4_ulong)
51+
register_Array4_extension(Array4_ulonglong)

src/amrex/space3d/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,20 @@
3232
#
3333
def d_decl(x, y, z):
3434
return (x, y, z)
35+
36+
37+
from ..Array4 import register_Array4_extension
38+
39+
register_Array4_extension(Array4_float)
40+
register_Array4_extension(Array4_double)
41+
register_Array4_extension(Array4_longdouble)
42+
43+
register_Array4_extension(Array4_short)
44+
register_Array4_extension(Array4_int)
45+
register_Array4_extension(Array4_long)
46+
register_Array4_extension(Array4_longlong)
47+
48+
register_Array4_extension(Array4_ushort)
49+
register_Array4_extension(Array4_uint)
50+
register_Array4_extension(Array4_ulong)
51+
register_Array4_extension(Array4_ulonglong)

tests/test_multifab.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,25 @@ def test_mfab_loop(make_mfab):
4545

4646
# numpy representation: non-copying view, including the
4747
# guard/ghost region
48-
# note: in numpy, indices are in C-order!
49-
marr_np = np.array(marr, copy=False)
48+
marr_np = marr.to_numpy()
5049

5150
# check the values at start/end are the same: first component
5251
assert marr_np[0, 0, 0, 0] == marr[bx.small_end]
53-
assert marr_np[0, -1, -1, -1] == marr[bx.big_end]
52+
assert marr_np[-1, -1, -1, 0] == marr[bx.big_end]
5453
# same check, but for all components
5554
for n in range(mfab.num_comp):
5655
small_end_comp = list(bx.small_end) + [n]
5756
big_end_comp = list(bx.big_end) + [n]
58-
assert marr_np[n, 0, 0, 0] == marr[small_end_comp]
59-
assert marr_np[n, -1, -1, -1] == marr[big_end_comp]
57+
assert marr_np[0, 0, 0, n] == marr[small_end_comp]
58+
assert marr_np[-1, -1, -1, n] == marr[big_end_comp]
6059

6160
# now we do some faster assignments, using range based access
6261
# this should fail as out-of-bounds, but does not
6362
# does Numpy not check array access for non-owned views?
6463
# marr_np[24:200, :, :, :] = 42.
6564

6665
# all components and all indices set at once to 42
67-
marr_np[:, :, :, :] = 42.0
66+
marr_np[()] = 42.0
6867

6968
# values in start & end still match?
7069
assert marr_np[0, 0, 0, 0] == marr[bx.small_end]
@@ -210,10 +209,11 @@ def test_mfab_ops_cuda_cupy(make_mfab_device):
210209
with cupy.profiler.time_range("assign 3 [()]", color_id=0):
211210
for mfi in mfab_device:
212211
bx = mfi.tilebox().grow(ngv)
213-
marr = mfab_device.array(mfi)
214-
marr_cupy = cp.array(marr, copy=False)
212+
marr_cupy = mfab_device.array(mfi).to_cupy(order="C")
215213
# print(marr_cupy.shape) # 1, 32, 32, 32
216214
# print(marr_cupy.dtype) # float64
215+
# performance:
216+
# https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
217217

218218
# write and read into the marr_cupy
219219
marr_cupy[()] = 3.0
@@ -244,8 +244,11 @@ def set_to_five(mm):
244244

245245
for mfi in mfab_device:
246246
bx = mfi.tilebox().grow(ngv)
247-
marr = mfab_device.array(mfi)
248-
marr_cupy = cp.array(marr, copy=False)
247+
marr_cupy = mfab_device.array(mfi).to_cupy(order="F")
248+
# print(marr_cupy.shape) # 32, 32, 32, 1
249+
# print(marr_cupy.dtype) # float64
250+
# performance:
251+
# https://github.com/AMReX-Codes/pyamrex/issues/55#issuecomment-1579610074
249252

250253
# write and read into the marr_cupy
251254
fives_cp = set_to_five(marr_cupy)
@@ -266,8 +269,7 @@ def set_to_seven(x):
266269

267270
for mfi in mfab_device:
268271
bx = mfi.tilebox().grow(ngv)
269-
marr = mfab_device.array(mfi)
270-
marr_cupy = cp.array(marr, copy=False)
272+
marr_cupy = mfab_device.array(mfi).to_cupy(order="C")
271273

272274
# write and read into the marr_cupy
273275
set_to_seven(marr_cupy)

0 commit comments

Comments
 (0)