Skip to content

Commit 0a6da93

Browse files
committed
drop duplicate error handling
1 parent 06ad50b commit 0a6da93

File tree

1 file changed

+0
-32
lines changed

1 file changed

+0
-32
lines changed

numcu/src/numcu.cu

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,52 +12,20 @@ namespace py = pybind11;
1212
template <typename T>
1313
void elem_div(py::buffer num, py::buffer den, py::buffer dst, T zeroDivDefault) {
1414
py::buffer_info src_num = num.request(), src_den = den.request(), dst_out = dst.request(true);
15-
const std::vector<std::string> types = {src_num.format, src_den.format, dst_out.format};
16-
for (auto &type : types) {
17-
if (type != py::format_descriptor<T>::format()) throw py::type_error("unexpected type");
18-
}
19-
if (src_num.ndim != src_den.ndim) throw py::index_error("inputs must have same ndim");
20-
if (src_num.ndim != dst_out.ndim) throw py::index_error("output must have same ndim");
21-
for (size_t i = 0; i < src_num.ndim; i++) {
22-
if (src_num.shape[i] != src_den.shape[i]) throw py::index_error("inputs must have same shape");
23-
if (src_num.shape[i] != dst_out.shape[i]) throw py::index_error("output must have same shape");
24-
}
25-
2615
div(static_cast<T *>(dst_out.ptr), static_cast<T *>(src_num.ptr), static_cast<T *>(src_den.ptr),
2716
dst_out.size, zeroDivDefault);
2817
CUDA_PyErr();
2918
}
3019

3120
template <typename T> void elem_mul(py::buffer a, py::buffer b, py::buffer dst) {
3221
py::buffer_info src_a = a.request(), src_b = b.request(), dst_out = dst.request(true);
33-
const std::vector<std::string> types = {src_a.format, src_b.format, dst_out.format};
34-
for (auto &type : types) {
35-
if (type != py::format_descriptor<T>::format()) throw py::type_error("unexpected type");
36-
}
37-
if (src_a.ndim != src_b.ndim) throw py::index_error("inputs must have same ndim");
38-
if (src_a.ndim != dst_out.ndim) throw py::index_error("output must have same ndim");
39-
for (size_t i = 0; i < src_a.ndim; i++) {
40-
if (src_a.shape[i] != src_b.shape[i]) throw py::index_error("inputs must have same shape");
41-
if (src_a.shape[i] != dst_out.shape[i]) throw py::index_error("output must have same shape");
42-
}
43-
4422
mul(static_cast<T *>(dst_out.ptr), static_cast<T *>(src_a.ptr), static_cast<T *>(src_b.ptr),
4523
dst_out.size);
4624
CUDA_PyErr();
4725
}
4826

4927
template <typename T> void elem_add(py::buffer a, py::buffer b, py::buffer dst) {
5028
py::buffer_info src_a = a.request(), src_b = b.request(), dst_out = dst.request(true);
51-
const std::vector<std::string> types = {src_a.format, src_b.format, dst_out.format};
52-
for (auto &type : types) {
53-
if (type != py::format_descriptor<T>::format()) throw py::type_error("unexpected type");
54-
}
55-
if (src_a.ndim != src_b.ndim) throw py::index_error("inputs must have same ndim");
56-
if (src_a.ndim != dst_out.ndim) throw py::index_error("output must have same ndim");
57-
for (size_t i = 0; i < src_a.ndim; i++) {
58-
if (src_a.shape[i] != src_b.shape[i]) throw py::index_error("inputs must have same shape");
59-
if (src_a.shape[i] != dst_out.shape[i]) throw py::index_error("output must have same shape");
60-
}
6129
add(static_cast<T *>(dst_out.ptr), static_cast<T *>(src_a.ptr), static_cast<T *>(src_b.ptr),
6230
dst_out.size);
6331
CUDA_PyErr();

0 commit comments

Comments
 (0)