@@ -12,52 +12,20 @@ namespace py = pybind11;
12
12
// elem_div<T>: pybind11 entry point that forwards three Python buffer objects
// (num, den, dst) plus a scalar zeroDivDefault to the `div` routine.
//
// NOTE(review): this text is a rendered diff hunk. Lines starting with '-' appear
// to be the lines this commit removes (format / ndim / shape validation); the bare
// numbers are old/new line-number residue from the diff view, not code.
//
// With the '-' lines gone, nothing visible in this function checks that the buffers
// hold elements of type T or share a shape before their raw pointers are cast to T*
// and passed on with dst_out.size as the element count -- confirm the validation
// now lives elsewhere (caller or a shared helper).
template <typename T>
13
13
void elem_div (py::buffer num, py::buffer den, py::buffer dst, T zeroDivDefault) {
14
14
// Obtain buffer_info for all three objects; only dst is requested with `true`
// (pybind11's `writable` flag), since it is the one passed as the output pointer.
py::buffer_info src_num = num.request (), src_den = den.request (), dst_out = dst.request (true );
15
// Removed check: every buffer's format string must equal format_descriptor<T>::format(),
// otherwise py::type_error is thrown.
- const std::vector<std::string> types = {src_num.format , src_den.format , dst_out.format };
16
- for (auto &type : types) {
17
- if (type != py::format_descriptor<T>::format ()) throw py::type_error (" unexpected type" );
18
- }
19
// Removed check: num, den and dst must agree in ndim and in every dimension's extent,
// otherwise py::index_error is thrown.
- if (src_num.ndim != src_den.ndim ) throw py::index_error (" inputs must have same ndim" );
20
- if (src_num.ndim != dst_out.ndim ) throw py::index_error (" output must have same ndim" );
21
- for (size_t i = 0 ; i < src_num.ndim ; i++) {
22
- if (src_num.shape [i] != src_den.shape [i]) throw py::index_error (" inputs must have same shape" );
23
- if (src_num.shape [i] != dst_out.shape [i]) throw py::index_error (" output must have same shape" );
24
- }
25
-
26
15
// Argument order is (dst, num, den, element count, zeroDivDefault); the count is taken
// from the destination buffer. Presumably `div` stores num/den element-wise and uses
// zeroDivDefault where den is zero -- TODO confirm against div's definition.
div (static_cast <T *>(dst_out.ptr ), static_cast <T *>(src_num.ptr ), static_cast <T *>(src_den.ptr ),
27
16
dst_out.size , zeroDivDefault);
28
17
// Presumably surfaces any pending CUDA error as a Python exception -- defined outside this hunk.
CUDA_PyErr ();
29
18
}
30
19
31
20
// elem_mul<T>: pybind11 entry point that forwards three Python buffer objects
// (a, b, dst) to the `mul` routine.
//
// NOTE(review): this text is a rendered diff hunk. Lines starting with '-' appear
// to be the lines this commit removes (format / ndim / shape validation); the bare
// numbers are old/new line-number residue from the diff view, not code.
//
// With the '-' lines gone, nothing visible in this function checks that the buffers
// hold elements of type T or share a shape before their raw pointers are cast to T*
// and passed on with dst_out.size as the element count -- confirm the validation
// now lives elsewhere (caller or a shared helper).
template <typename T> void elem_mul (py::buffer a, py::buffer b, py::buffer dst) {
32
21
// Obtain buffer_info for all three objects; only dst is requested with `true`
// (pybind11's `writable` flag), since it is the one passed as the output pointer.
py::buffer_info src_a = a.request (), src_b = b.request (), dst_out = dst.request (true );
33
// Removed check: every buffer's format string must equal format_descriptor<T>::format(),
// otherwise py::type_error is thrown.
- const std::vector<std::string> types = {src_a.format , src_b.format , dst_out.format };
34
- for (auto &type : types) {
35
- if (type != py::format_descriptor<T>::format ()) throw py::type_error (" unexpected type" );
36
- }
37
// Removed check: a, b and dst must agree in ndim and in every dimension's extent,
// otherwise py::index_error is thrown.
- if (src_a.ndim != src_b.ndim ) throw py::index_error (" inputs must have same ndim" );
38
- if (src_a.ndim != dst_out.ndim ) throw py::index_error (" output must have same ndim" );
39
- for (size_t i = 0 ; i < src_a.ndim ; i++) {
40
- if (src_a.shape [i] != src_b.shape [i]) throw py::index_error (" inputs must have same shape" );
41
- if (src_a.shape [i] != dst_out.shape [i]) throw py::index_error (" output must have same shape" );
42
- }
43
-
44
22
// Argument order is (dst, a, b, element count); the count is taken from the destination
// buffer. Presumably `mul` stores a*b element-wise -- TODO confirm against mul's definition.
mul (static_cast <T *>(dst_out.ptr ), static_cast <T *>(src_a.ptr ), static_cast <T *>(src_b.ptr ),
45
23
dst_out.size );
46
24
// Presumably surfaces any pending CUDA error as a Python exception -- defined outside this hunk.
CUDA_PyErr ();
47
25
}
48
26
49
27
template <typename T> void elem_add (py::buffer a, py::buffer b, py::buffer dst) {
50
28
py::buffer_info src_a = a.request (), src_b = b.request (), dst_out = dst.request (true );
51
- const std::vector<std::string> types = {src_a.format , src_b.format , dst_out.format };
52
- for (auto &type : types) {
53
- if (type != py::format_descriptor<T>::format ()) throw py::type_error (" unexpected type" );
54
- }
55
- if (src_a.ndim != src_b.ndim ) throw py::index_error (" inputs must have same ndim" );
56
- if (src_a.ndim != dst_out.ndim ) throw py::index_error (" output must have same ndim" );
57
- for (size_t i = 0 ; i < src_a.ndim ; i++) {
58
- if (src_a.shape [i] != src_b.shape [i]) throw py::index_error (" inputs must have same shape" );
59
- if (src_a.shape [i] != dst_out.shape [i]) throw py::index_error (" output must have same shape" );
60
- }
61
29
add (static_cast <T *>(dst_out.ptr ), static_cast <T *>(src_a.ptr ), static_cast <T *>(src_b.ptr ),
62
30
dst_out.size );
63
31
CUDA_PyErr ();
0 commit comments