89 changes: 59 additions & 30 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -4,12 +4,16 @@
#include <ATen/native/mps/MPSGraphVenturaOps.h>

namespace at::native {

enum class MPSCumulativeOpType : uint8_t {
MPS_CUMSUM = 0,
MPS_CUMPROD = 1,
};
namespace mps {

typedef MPSGraphTensor* (^UnaryOpBlock)(MPSGraph*, MPSGraphTensor*);
using is_noop_p = std::function<bool(const Tensor&)>;


bool is_empty_tensor(const Tensor& self) {
return self.numel() == 0;
}
@@ -415,49 +419,74 @@ Tensor logit_mps(const Tensor& self, c10::optional<double> eps) {
}
}



TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self,
int64_t dim,
c10::optional<ScalarType> dtype,
const Tensor& result) {

void cumulative_op_impl(
const Tensor& self,
int64_t dim,
c10::optional<ScalarType> dtype,
const Tensor& result,
MPSCumulativeOpType cumulativeOpType,
const std::string& op_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")");
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
"Expected wrapped dim to be between 0 and ",
self.ndimension(),
" but got ",
wrapped_dim,
"(original dim is ",
dim,
")");

if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
TORCH_WARN_ONCE(op_name, " supported by MPS on MacOS 13+, please upgrade");
auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
at::_copy_from_and_resize(cpu_result, result);
return;
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;

// issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
// fixed in macOS 13.3
bool castInputData = (isIntegralType(input.scalar_type()) &&
input.scalar_type() != ScalarType::Int &&
// issue #103810551: int8, int16 have a high chance of overflow, cast to int32
bool castInputData = (isIntegralType(input.scalar_type(), false) &&
input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);

TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");
"MPS does not support ", op_name, " op with int64 input. Support has been added in macOS 13.3");

mps::unary_op(input,
result,
op_name + std::to_string(dim),
^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
MPSGraphTensor* cumulativeOpTensor = nil;
switch (cumulativeOpType) {
case MPSCumulativeOpType::MPS_CUMSUM:
cumulativeOpTensor = [mpsGraph cumulativeSumWithTensor:inputTensor axis:dim name:nil];
break;
case MPSCumulativeOpType::MPS_CUMPROD:
cumulativeOpTensor = [mpsGraph cumulativeProductWithTensor:inputTensor axis:dim name:nil];
break;
default:
TORCH_CHECK(false, "Undefined cumulative op type");
}
if ((mps::getMPSDataType(result.scalar_type()) != [cumulativeOpTensor dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, cumulativeOpTensor, result.scalar_type());
}
return cumulativeOpTensor;
});
}

mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
TORCH_IMPL_FUNC(cumsum_out_mps)
(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::MPS_CUMSUM, "cumsum_out");
}

if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
axis: dim
name: nil];
if ((mps::getMPSDataType(result.scalar_type()) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
});
TORCH_IMPL_FUNC(cumprod_out_mps)
Review comment (Owner): @DenisVieriu97, can you just go ahead and add cummin and cummax as well?

(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, const Tensor& result) {
cumulative_op_impl(self, dim, dtype, result, MPSCumulativeOpType::MPS_CUMPROD, "cumprod_out");
}

} // namespace at::native
}
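Regarding the review comment above asking for cummin and cummax: those ops return both values and indices, so they would likely need more than the single-output unary_op path used here. A quick illustration of their Python-level interface (plain PyTorch, any backend; context only, not part of this diff):

```python
import torch

x = torch.tensor([3.0, 1.0, 2.0, 0.5])

# cummax/cummin return a (values, indices) pair, unlike cumsum/cumprod.
values, indices = torch.cummax(x, dim=0)
# values  -> 3., 3., 3., 3.
# indices -> 0, 0, 0, 0

values, indices = torch.cummin(x, dim=0)
# values  -> 3., 1., 1., 0.5
# indices -> 0, 1, 1, 3
```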
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -1850,6 +1850,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: cumprod_out
MPS: cumprod_out_mps

- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
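With the `MPS: cumprod_out_mps` dispatch entry registered above, ordinary cumprod calls on MPS tensors route to the new kernel. A minimal usage sketch (assuming a macOS 13+ machine where the MPS backend is available):

```python
import torch

if torch.backends.mps.is_available():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0], device="mps")
    print(x.cumsum(0))   # 1., 3., 6., 10.
    print(x.cumprod(0))  # 1., 2., 6., 24.
```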
34 changes: 18 additions & 16 deletions test/test_mps.py
@@ -2797,22 +2797,26 @@ def helper(dtype, noncontiguous, dim):
with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
helper(dtype, noncontiguous, dim)

def test_cumsum_all_dtypes(self):
def helper(dtype):
t = torch.tensor([1, 1, 1, 1], device="mps", dtype=dtype)
t_cpu = torch.tensor([1, 1, 1, 1], device="cpu")
def test_cumulative_ops_all_dtypes(self):
def helper(op, dtype):
t = torch.tensor([1, 2, 3, 4], device="mps", dtype=dtype)
t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

a = t.cumsum(0, dtype=dtype)
a_cpu = t_cpu.cumsum(0, dtype=dtype)
a = op(t, 0, dtype=dtype)
a_cpu = op(t_cpu, 0, dtype=dtype)

self.assertEqual(a.cpu(), a_cpu)
[helper(dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

try:
helper(torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
for op_and_name in [[torch.cumsum, "cumsum_out"], [torch.cumprod, "cumprod_out"]]:
op = op_and_name[0]
name = op_and_name[1]
[helper(op, dtype) for dtype in [torch.int8, torch.int16, torch.int32, torch.float32]]

try:
helper(op, torch.int64)
except Exception as e:
e_string = str(e)
self.assertEqual(e_string, f"MPS does not support {name} op with int64 input. Support has been added in macOS 13.3")

def test_gelu_tanh(self):
def helper(shape):
@@ -9582,7 +9586,7 @@ class TestConsistency(TestCaseMPS):
'cov': ['f32', 'i16', 'i32', 'i64', 'u8'],
'cummax': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
'cummin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
'cumprod': ['f32', 'i16', 'i32', 'i64', 'u8'],
'cumprod': ['b8', 'i8', 'f32', 'i16', 'i32', 'i64', 'u8'],
'cumsum': ['i8', 'b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9752,7 +9756,7 @@ class TestConsistency(TestCaseMPS):
'lu_unpack': ['f32'],
'masked.argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'masked.argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'masked.cumprod': ['f32', 'i16', 'i32', 'i64', 'u8'],
'masked.cumprod': ['b8', 'i8', 'f32', 'i16', 'i32', 'i64', 'u8'],
'masked.cumsum': ['f32', 'i16', 'i32', 'i64', 'u8'],
'masked.log_softmax': ['f32'],
'masked.logaddexp': ['f32'],
@@ -10819,7 +10823,6 @@ class TestConsistency(TestCaseMPS):
'copysign': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
'cummax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
'cummin': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
'cumprod': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
'digamma': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
'erfc': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
'erfinv': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
@@ -10882,7 +10885,6 @@ class TestConsistency(TestCaseMPS):
'lu': [torch.float32],
'lu_solve': [torch.float32],
'lu_unpack': [torch.float32],
'masked.cumprod': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
'masked.median': [torch.float32],
'matrix_exp': [torch.float32],
'mode': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],