@@ -20,13 +20,6 @@ import numpy as _numpy
2020
2121
2222cdef extern from * nogil:
23- # from CUDA
24- ctypedef enum DataType ' cudaDataType_t' :
25- pass
26- ctypedef enum LibPropType ' libraryPropertyType' :
27- pass
28- ctypedef struct int2 ' int2' :
29- pass
3023
3124 # cuStateVec functions
3225 int custatevecCreate(_Handle* )
@@ -127,6 +120,11 @@ cdef extern from * nogil:
127120 int custatevecSwapIndexBits(
128121 _Handle, void * , DataType, const uint32_t, const int2* , const uint32_t,
129122 const int32_t* , const int32_t* , const uint32_t)
123+ int custatevecMultiDeviceSwapIndexBits(
124+ _Handle* , const uint32_t, void ** , const DataType, const uint32_t,
125+ const uint32_t, const int2* , const uint32_t,
126+ const int32_t* , const int32_t* , const uint32_t,
127+ const _DeviceNetworkType)
130128 int custatevecTestMatrixTypeGetWorkspaceSize(
131129 _Handle, _MatrixType, const void * , DataType, _MatrixLayout,
132130 const uint32_t, const int32_t, _ComputeType, size_t* )
@@ -377,11 +375,10 @@ cpdef abs2sum_array(
377375 - a Python sequence of index bit ordering
378376
379377 bit_ordering_len (uint32_t): The length of ``bit_ordering``.
380- mask_bit_string: A host array for a bit string to specify mask. It can
381- be
378+ mask_bit_string: A host array for specifying mask values. It can be
382379
383380 - an :class:`int` as the pointer address to the array
384- - a Python sequence of index bit ordering
381+ - a Python sequence of mask values
385382
386383 mask_ordering: A host array of mask ordering. It can be
387384
@@ -1622,11 +1619,10 @@ cpdef swap_index_bits(
16221619 - a nested Python sequence of swapped index bits
16231620
16241621 n_swapped_bits (uint32_t): The number of pairs of swapped index bits.
1625- mask_bit_string: A host array for a bit string to specify mask. It can
1626- be
1622+ mask_bit_string: A host array for specifying mask values. It can be
16271623
16281624 - an :class:`int` as the pointer address to the array
1629- - a Python sequence of index bit ordering
1625+ - a Python sequence of mask values
16301626
16311627 mask_ordering: A host array of mask ordering. It can be
16321628
@@ -1690,6 +1686,125 @@ cpdef swap_index_bits(
16901686 check_status(status)
16911687
16921688
1689+ cpdef multi_device_swap_index_bits(
1690+ handles, uint32_t n_handles, sub_svs, int sv_data_type,
1691+ uint32_t n_global_index_bits, uint32_t n_local_index_bits,
1692+ swapped_bits, uint32_t n_swapped_bits,
1693+ mask_bit_string, mask_ordering, uint32_t mask_len,
1694+ int device_network_type):
1695+ """ Swap index bits and reorder statevector elements on multiple devices.
1696+
1697+ Args:
1698+ handles: A host array of the library handles. It can be
1699+
1700+ - an :class:`int` as the pointer address to the array
1701+ - a Python sequence of :class:`int`, each of which is a valid
1702+ library handle
1703+
1704+ n_handles (uint32_t): The number of handles.
1705+ sub_svs: A host array of the sub-statevector pointers. It can be
1706+
1707+ - an :class:`int` as the pointer address to the array
1708+ - a Python sequence of :class:`int`, each of which is a valid
1709+ sub-statevector pointer (on device)
1710+
1711+ sv_data_type (cuquantum.cudaDataType): The data type of the statevectors.
1712+ n_global_index_bits (uint32_t): The number of the global index bits.
1713+ n_local_index_bits (uint32_t): The number of the local index bits.
1714+ swapped_bits: A host array of pairs of swapped index bits. It can be
1715+
1716+ - an :class:`int` as the pointer address to the nested sequence
1717+ - a nested Python sequence of swapped index bits
1718+
1719+ n_swapped_bits (uint32_t): The number of pairs of swapped index bits.
1720+ mask_bit_string: A host array for specifying mask values. It can be
1721+
1722+ - an :class:`int` as the pointer address to the array
1723+ - a Python sequence of mask values
1724+
1725+ mask_ordering: A host array of mask ordering. It can be
1726+
1727+ - an :class:`int` as the pointer address to the array
1728+ - a Python sequence of index bit ordering
1729+
1730+ mask_len (uint32_t): The length of ``mask_ordering``.
1731+ device_network_type (DeviceNetworkType): The device network topology.
1732+
1733+ .. seealso:: `custatevecMultiDeviceSwapIndexBits`
1734+ """
1735+ # handles can be a pointer address, or a Python sequence
1736+ cdef vector[intptr_t] handlesData
1737+ cdef _Handle* handlesPtr
1738+ if cpython.PySequence_Check(handles):
1739+ handlesData = handles
1740+ handlesPtr = < _Handle* > handlesData.data()
1741+ else : # a pointer address
1742+ handlesPtr = < _Handle* >< intptr_t> handles
1743+
1744+ # sub_svs can be a pointer address, or a Python sequence
1745+ cdef vector[intptr_t] subSVsData
1746+ cdef void ** subSVsPtr
1747+ if cpython.PySequence_Check(sub_svs):
1748+ subSVsData = sub_svs
1749+ subSVsPtr = < void ** > subSVsData.data()
1750+ else : # a pointer address
1751+ subSVsPtr = < void ** >< intptr_t> sub_svs
1752+
1753+ # swapped_bits can be:
1754+ # - a plain pointer address
1755+ # - a nested Python sequence (ex: a list of 2-tuples)
1756+ # Note: it cannot be a mix of sequences and ints. It also cannot be a
1757+ # 1D sequence (of ints), because it's inefficient.
1758+ cdef vector[intptr_t] swappedBitsCData
1759+ cdef int2* swappedBitsPtr
1760+ if is_nested_sequence(swapped_bits):
1761+ try :
1762+ # direct conversion
1763+ data = _numpy.asarray(swapped_bits, dtype = _numpy.int32)
1764+ data = data.reshape(- 1 )
1765+ except :
1766+ # unlikely, but let's do it in the stupid way
1767+ data = _numpy.empty(2 * n_swapped_bits, dtype = _numpy.int32)
1768+ for i, (first, second) in enumerate (swapped_bits):
1769+ data[2 * i] = first
1770+ data[2 * i+ 1 ] = second
1771+ assert data.size == 2 * n_swapped_bits
1772+ swappedBitsPtr = < int2* > (< intptr_t> data.ctypes.data)
1773+ elif isinstance (swapped_bits, int ):
1774+ # a pointer address, take it as is
1775+ swappedBitsPtr = < int2* >< intptr_t> swapped_bits
1776+ else :
1777+ raise ValueError (" swapped_bits is provided in an "
1778+ " un-recognized format" )
1779+
1780+ # mask_bit_string can be a pointer address, or a Python sequence
1781+ cdef vector[int32_t] maskBitStringData
1782+ cdef int32_t* maskBitStringPtr
1783+ if cpython.PySequence_Check(mask_bit_string):
1784+ maskBitStringData = mask_bit_string
1785+ maskBitStringPtr = maskBitStringData.data()
1786+ else : # a pointer address
1787+ maskBitStringPtr = < int32_t* >< intptr_t> mask_bit_string
1788+
1789+ # mask_ordering can be a pointer address, or a Python sequence
1790+ cdef vector[int32_t] maskOrderingData
1791+ cdef int32_t* maskOrderingPtr
1792+ if cpython.PySequence_Check(mask_ordering):
1793+ maskOrderingData = mask_ordering
1794+ maskOrderingPtr = maskOrderingData.data()
1795+ else : # a pointer address
1796+ maskOrderingPtr = < int32_t* >< intptr_t> mask_ordering
1797+
1798+ with nogil:
1799+ status = custatevecMultiDeviceSwapIndexBits(
1800+ handlesPtr, n_handles, subSVsPtr, < DataType> sv_data_type,
1801+ n_global_index_bits, n_local_index_bits,
1802+ swappedBitsPtr, n_swapped_bits,
1803+ maskBitStringPtr, maskOrderingPtr, mask_len,
1804+ < _DeviceNetworkType> device_network_type)
1805+ check_status(status)
1806+
1807+
16931808cpdef size_t test_matrix_type_get_workspace_size(
16941809 intptr_t handle, int matrix_type,
16951810 intptr_t matrix, int matrix_data_type, int layout, uint32_t n_targets,
@@ -1698,9 +1813,9 @@ cpdef size_t test_matrix_type_get_workspace_size(
16981813
16991814 Args:
17001815 handle (intptr_t): The library handle.
1701- matrix_type (cuquantum. MatrixType): The matrix type of the gate matrix.
1702- matrix (intptr_t): The pointer address (as Python :class:`int`) to a matrix
1703- (on either host or device).
1816+ matrix_type (MatrixType): The matrix type of the gate matrix.
1817+ matrix (intptr_t): The pointer address (as Python :class:`int`) to a
1818+ matrix (on either host or device).
17041819 matrix_data_type (cuquantum.cudaDataType): The data type of the matrix.
17051820 layout (MatrixLayout): The memory layout the the matrix.
17061821 n_targets (uint32_t): The length of ``targets``.
@@ -1733,9 +1848,9 @@ cpdef double test_matrix_type(
17331848
17341849 Args:
17351850 handle (intptr_t): The library handle.
1736- matrix_type (cuquantum. MatrixType): The matrix type of the gate matrix.
1737- matrix (intptr_t): The pointer address (as Python :class:`int`) to a matrix
1738- (on either host or device).
1851+ matrix_type (MatrixType): The matrix type of the gate matrix.
1852+ matrix (intptr_t): The pointer address (as Python :class:`int`) to a
1853+ matrix (on either host or device).
17391854 matrix_data_type (cuquantum.cudaDataType): The data type of the matrix.
17401855 layout (MatrixLayout): The memory layout the the matrix.
17411856 n_targets (uint32_t): The length of ``targets``.
@@ -1977,6 +2092,11 @@ class SamplerOutput(IntEnum):
19772092 RANDNUM_ORDER = CUSTATEVEC_SAMPLER_OUTPUT_RANDNUM_ORDER
19782093 ASCENDING_ORDER = CUSTATEVEC_SAMPLER_OUTPUT_ASCENDING_ORDER
19792094
2095+ class DeviceNetworkType (IntEnum ):
2096+ """ See `custatevecDeviceNetworkType_t`."""
2097+ SWITCH = CUSTATEVEC_DEVICE_NETWORK_TYPE_SWITCH
2098+ FULLMESH = CUSTATEVEC_DEVICE_NETWORK_TYPE_FULLMESH
2099+
19802100
19812101del IntEnum
19822102
0 commit comments