Skip to content

Commit 3e52872

Browse files
dfm (Google-ML-Automation)
authored and committed
Clean up some unused GPU linear algebra kernels.
This change removes the legacy `csrlsvqr` and `sytrd` custom calls from jaxlib. These were never covered by the export compatibility policy, and their FFI counterparts have been targeted by JAX for several releases. PiperOrigin-RevId: 766298494
1 parent e9925ee commit 3e52872

File tree

7 files changed

+0
-464
lines changed

7 files changed

+0
-464
lines changed

jaxlib/cuda/BUILD

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -155,24 +155,6 @@ cc_library(
155155
],
156156
)
157157

158-
cc_library(
159-
name = "cusolver_kernels",
160-
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
161-
hdrs = ["//jaxlib/gpu:solver_kernels.h"],
162-
deps = [
163-
":cuda_gpu_kernel_helpers",
164-
":cuda_solver_handle_pool",
165-
":cuda_vendor",
166-
"//jaxlib:kernel_helpers",
167-
"@com_google_absl//absl/status",
168-
"@com_google_absl//absl/status:statusor",
169-
"@local_config_cuda//cuda:cuda_headers",
170-
"@xla//xla/service:custom_call_status",
171-
"@xla//xla/tsl/cuda:cudart",
172-
"@xla//xla/tsl/cuda:cusolver",
173-
],
174-
)
175-
176158
cc_library(
177159
name = "cusolver_interface",
178160
srcs = ["//jaxlib/gpu:solver_interface.cc"],
@@ -223,21 +205,14 @@ nanobind_extension(
223205
features = ["-use_header_modules"],
224206
module_name = "_solver",
225207
deps = [
226-
":cuda_gpu_kernel_helpers",
227-
":cuda_solver_handle_pool",
228208
":cuda_vendor",
229-
":cusolver_kernels",
230209
":cusolver_kernels_ffi",
231210
"//jaxlib:kernel_nanobind_helpers",
232-
"@com_google_absl//absl/container:flat_hash_map",
233-
"@com_google_absl//absl/status:statusor",
234-
"@com_google_absl//absl/strings:str_format",
235211
"@local_config_cuda//cuda:cuda_headers",
236212
"@nanobind",
237213
"@xla//xla/tsl/cuda:cublas",
238214
"@xla//xla/tsl/cuda:cudart",
239215
"@xla//xla/tsl/cuda:cusolver",
240-
"@xla//xla/tsl/python/lib/core:numpy",
241216
],
242217
)
243218

@@ -472,7 +447,6 @@ cc_library(
472447
":cuda_prng_kernels",
473448
":cuda_vendor",
474449
":cudnn_rnn_kernels",
475-
":cusolver_kernels",
476450
":cusolver_kernels_ffi",
477451
":cusparse_kernels",
478452
":triton_kernels",

jaxlib/gpu/BUILD

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ exports_files(srcs = [
5959
"solver_handle_pool.h",
6060
"solver_interface.cc",
6161
"solver_interface.h",
62-
"solver_kernels.cc",
63-
"solver_kernels.h",
6462
"solver_kernels_ffi.cc",
6563
"solver_kernels_ffi.h",
6664
"sparse.cc",

jaxlib/gpu/gpu_kernels.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ limitations under the License.
1919
#include "jaxlib/gpu/linalg_kernels.h"
2020
#include "jaxlib/gpu/prng_kernels.h"
2121
#include "jaxlib/gpu/rnn_kernels.h"
22-
#include "jaxlib/gpu/solver_kernels.h"
2322
#include "jaxlib/gpu/solver_kernels_ffi.h"
2423
#include "jaxlib/gpu/sparse_kernels.h"
2524
#include "jaxlib/gpu/triton_kernels.h"
@@ -40,14 +39,12 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA",
4039
SyrkFfi);
4140
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
4241
GeqrfFfi);
43-
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
4442
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_csrlsvqr_ffi", "CUDA",
4543
CsrlsvqrFfi);
4644
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
4745
OrgqrFfi);
4846
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA",
4947
SyevdFfi);
50-
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA");
5148
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_sytrd_ffi", "CUDA",
5249
SytrdFfi);
5350
XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_gesvd_ffi", "CUDA",

jaxlib/gpu/solver.cc

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -13,101 +13,19 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include <stdexcept>
17-
#include <utility>
18-
19-
#include "absl/container/flat_hash_map.h"
20-
#include "absl/status/statusor.h"
21-
#include "absl/strings/str_format.h"
2216
#include "nanobind/nanobind.h"
23-
#include "nanobind/stl/pair.h" // IWYU pragma: keep
24-
#include "jaxlib/gpu/gpu_kernel_helpers.h"
25-
#include "jaxlib/gpu/solver_handle_pool.h"
26-
#include "jaxlib/gpu/solver_kernels.h"
2717
#include "jaxlib/gpu/solver_kernels_ffi.h"
2818
#include "jaxlib/gpu/vendor.h"
2919
#include "jaxlib/kernel_nanobind_helpers.h"
30-
#include "xla/tsl/python/lib/core/numpy.h"
3120

3221
namespace jax {
3322
namespace JAX_GPU_NAMESPACE {
3423
namespace {
3524

3625
namespace nb = nanobind;
3726

38-
// Converts a NumPy dtype to a Type.
39-
SolverType DtypeToSolverType(const dtype& np_type) {
40-
static auto* types =
41-
new absl::flat_hash_map<std::pair<char, int>, SolverType>({
42-
{{'f', 4}, SolverType::F32},
43-
{{'f', 8}, SolverType::F64},
44-
{{'c', 8}, SolverType::C64},
45-
{{'c', 16}, SolverType::C128},
46-
});
47-
auto it = types->find({np_type.kind(), np_type.itemsize()});
48-
if (it == types->end()) {
49-
nb::str repr = nb::repr(np_type);
50-
throw std::invalid_argument(
51-
absl::StrFormat("Unsupported dtype %s", repr.c_str()));
52-
}
53-
return it->second;
54-
}
55-
56-
#ifdef JAX_GPU_CUDA
57-
58-
// csrlsvqr: Linear system solve via Sparse QR
59-
60-
// Returns a descriptor for a csrlsvqr operation.
61-
nb::bytes BuildCsrlsvqrDescriptor(const dtype& dtype, int n, int nnzA,
62-
int reorder, double tol) {
63-
SolverType type = DtypeToSolverType(dtype);
64-
return PackDescriptor(CsrlsvqrDescriptor{type, n, nnzA, reorder, tol});
65-
}
66-
67-
#endif // JAX_GPU_CUDA
68-
69-
// Returns the workspace size and a descriptor for a geqrf operation.
70-
std::pair<int, nb::bytes> BuildSytrdDescriptor(const dtype& dtype, bool lower,
71-
int b, int n) {
72-
SolverType type = DtypeToSolverType(dtype);
73-
auto h = SolverHandlePool::Borrow(/*stream=*/nullptr);
74-
JAX_THROW_IF_ERROR(h.status());
75-
auto& handle = *h;
76-
int lwork;
77-
gpusolverFillMode_t uplo =
78-
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
79-
switch (type) {
80-
case SolverType::F32:
81-
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd_bufferSize(
82-
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
83-
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
84-
break;
85-
case SolverType::F64:
86-
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd_bufferSize(
87-
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
88-
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
89-
break;
90-
case SolverType::C64:
91-
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd_bufferSize(
92-
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
93-
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
94-
break;
95-
case SolverType::C128:
96-
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd_bufferSize(
97-
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
98-
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
99-
break;
100-
}
101-
return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})};
102-
}
103-
10427
nb::dict Registrations() {
10528
nb::dict dict;
106-
dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd);
107-
108-
#ifdef JAX_GPU_CUDA
109-
dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);
110-
#endif // JAX_GPU_CUDA
11129

11230
dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi);
11331
dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi);
@@ -127,12 +45,7 @@ nb::dict Registrations() {
12745
}
12846

12947
NB_MODULE(_solver, m) {
130-
tsl::ImportNumpy();
13148
m.def("registrations", &Registrations);
132-
m.def("build_sytrd_descriptor", &BuildSytrdDescriptor);
133-
#ifdef JAX_GPU_CUDA
134-
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
135-
#endif // JAX_GPU_CUDA
13649
}
13750

13851
} // namespace

0 commit comments

Comments (0)