@@ -13,101 +13,19 @@ See the License for the specific language governing permissions and
13
13
limitations under the License.
14
14
==============================================================================*/
15
15
16
- #include < stdexcept>
17
- #include < utility>
18
-
19
- #include " absl/container/flat_hash_map.h"
20
- #include " absl/status/statusor.h"
21
- #include " absl/strings/str_format.h"
22
16
#include " nanobind/nanobind.h"
23
- #include " nanobind/stl/pair.h" // IWYU pragma: keep
24
- #include " jaxlib/gpu/gpu_kernel_helpers.h"
25
- #include " jaxlib/gpu/solver_handle_pool.h"
26
- #include " jaxlib/gpu/solver_kernels.h"
27
17
#include " jaxlib/gpu/solver_kernels_ffi.h"
28
18
#include " jaxlib/gpu/vendor.h"
29
19
#include " jaxlib/kernel_nanobind_helpers.h"
30
- #include " xla/tsl/python/lib/core/numpy.h"
31
20
32
21
namespace jax {
33
22
namespace JAX_GPU_NAMESPACE {
34
23
namespace {
35
24
36
25
namespace nb = nanobind;
37
26
38
- // Converts a NumPy dtype to a Type.
39
- SolverType DtypeToSolverType (const dtype& np_type) {
40
- static auto * types =
41
- new absl::flat_hash_map<std::pair<char , int >, SolverType>({
42
- {{' f' , 4 }, SolverType::F32},
43
- {{' f' , 8 }, SolverType::F64},
44
- {{' c' , 8 }, SolverType::C64},
45
- {{' c' , 16 }, SolverType::C128},
46
- });
47
- auto it = types->find ({np_type.kind (), np_type.itemsize ()});
48
- if (it == types->end ()) {
49
- nb::str repr = nb::repr (np_type);
50
- throw std::invalid_argument (
51
- absl::StrFormat (" Unsupported dtype %s" , repr.c_str ()));
52
- }
53
- return it->second ;
54
- }
55
-
56
- #ifdef JAX_GPU_CUDA
57
-
58
- // csrlsvqr: Linear system solve via Sparse QR
59
-
60
- // Returns a descriptor for a csrlsvqr operation.
61
- nb::bytes BuildCsrlsvqrDescriptor (const dtype& dtype, int n, int nnzA,
62
- int reorder, double tol) {
63
- SolverType type = DtypeToSolverType (dtype);
64
- return PackDescriptor (CsrlsvqrDescriptor{type, n, nnzA, reorder, tol});
65
- }
66
-
67
- #endif // JAX_GPU_CUDA
68
-
69
- // Returns the workspace size and a descriptor for a geqrf operation.
70
- std::pair<int , nb::bytes> BuildSytrdDescriptor (const dtype& dtype, bool lower,
71
- int b, int n) {
72
- SolverType type = DtypeToSolverType (dtype);
73
- auto h = SolverHandlePool::Borrow (/* stream=*/ nullptr );
74
- JAX_THROW_IF_ERROR (h.status ());
75
- auto & handle = *h;
76
- int lwork;
77
- gpusolverFillMode_t uplo =
78
- lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
79
- switch (type) {
80
- case SolverType::F32:
81
- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnSsytrd_bufferSize (
82
- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
83
- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
84
- break ;
85
- case SolverType::F64:
86
- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnDsytrd_bufferSize (
87
- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
88
- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
89
- break ;
90
- case SolverType::C64:
91
- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnChetrd_bufferSize (
92
- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
93
- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
94
- break ;
95
- case SolverType::C128:
96
- JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpusolverDnZhetrd_bufferSize (
97
- handle.get (), uplo, n, /* A=*/ nullptr , /* lda=*/ n, /* D=*/ nullptr ,
98
- /* E=*/ nullptr , /* tau=*/ nullptr , &lwork)));
99
- break ;
100
- }
101
- return {lwork, PackDescriptor (SytrdDescriptor{type, uplo, b, n, n, lwork})};
102
- }
103
-
104
27
nb::dict Registrations () {
105
28
nb::dict dict;
106
- dict[JAX_GPU_PREFIX " solver_sytrd" ] = EncapsulateFunction (Sytrd);
107
-
108
- #ifdef JAX_GPU_CUDA
109
- dict[" cusolver_csrlsvqr" ] = EncapsulateFunction (Csrlsvqr);
110
- #endif // JAX_GPU_CUDA
111
29
112
30
dict[JAX_GPU_PREFIX " solver_getrf_ffi" ] = EncapsulateFfiHandler (GetrfFfi);
113
31
dict[JAX_GPU_PREFIX " solver_geqrf_ffi" ] = EncapsulateFfiHandler (GeqrfFfi);
@@ -127,12 +45,7 @@ nb::dict Registrations() {
127
45
}
128
46
129
47
NB_MODULE (_solver, m) {
130
- tsl::ImportNumpy ();
131
48
m.def (" registrations" , &Registrations);
132
- m.def (" build_sytrd_descriptor" , &BuildSytrdDescriptor);
133
- #ifdef JAX_GPU_CUDA
134
- m.def (" build_csrlsvqr_descriptor" , &BuildCsrlsvqrDescriptor);
135
- #endif // JAX_GPU_CUDA
136
49
}
137
50
138
51
} // namespace
0 commit comments