Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4022,23 +4022,58 @@ def pinv(
return out_2


def solve(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
def _check_right_solve_shape(x, y):
"""check the input shape of x and y for solve when left is False"""
x_shape = x.shape[-2:]
if len(y.shape) == 1:
raise ValueError(
"Incompatible shapes of X and Y for the equation Out * X = Y, "
f"where input X's matrix shape is {x_shape} and"
f"input Y's matrix shape is {list(y.shape).append(1)}"
)
else:
y_shape = y.shape[-2:]
if x_shape[0] != y_shape[1]:
raise ValueError(
"Incompatible shapes of X and Y for the equation Out * X = Y, "
f"where input X's matrix shape is {x_shape} and"
f"input Y's matrix shape is {y_shape}"
)


def _transpose_last_2dim(x):
    """Return ``x`` with its final two axes exchanged; leading axes keep their order."""
    perm = list(range(len(x.shape)))
    # Swap the trailing pair of axis indices in the permutation.
    perm[-2], perm[-1] = perm[-1], perm[-2]
    return transpose(x, perm)


def solve(
x: Tensor, y: Tensor, left: bool = True, name: str | None = None
) -> Tensor:
r"""

Computes the solution of a square system of linear equations with a unique solution for input 'X' and 'Y'.
Let :math:`X` be a square matrix or a batch of square matrices, :math:`Y` be
a vector/matrix or a batch of vectors/matrices, the equation should be:
a vector/matrix or a batch of vectors/matrices. When `left` is True, the equation should be:

.. math::
Out = X^{-1} * Y

When `left` is False, the equation should be:

.. math::
Out = Y * X^{-1}

Specifically, this system of linear equations has one solution if and only if input 'X' is invertible.

Args:
x (Tensor): A square matrix or a batch of square matrices. Its shape should be ``[*, M, M]``, where ``*`` is zero or
more batch dimensions. Its data type should be float32 or float64.
y (Tensor): A vector/matrix or a batch of vectors/matrices. Its shape should be ``[*, M, K]`` when ``left`` is True and ``[*, K, M]`` when ``left`` is False, where ``*`` is zero or
more batch dimensions. Its data type should be float32 or float64.
left (bool, optional): Whether to solve the system :math:`X * Out = Y` or :math:`Out * X = Y`. Default: True.
name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.

Expand All @@ -4064,8 +4099,13 @@ def solve(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
Tensor(shape=[2], dtype=float64, place=Place(cpu), stop_gradient=True,
[2., 3.])
"""
if not left:
_check_right_solve_shape(x, y)
x = _transpose_last_2dim(x)
y = _transpose_last_2dim(y)

if in_dynamic_or_pir_mode():
return _C_ops.solve(x, y)
out = _C_ops.solve(x, y)
else:
inputs = {"X": [x], "Y": [y]}
helper = LayerHelper("solve", **locals())
Expand All @@ -4076,7 +4116,10 @@ def solve(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
helper.append_op(
type="solve", inputs={"X": x, "Y": y}, outputs={"Out": out}
)
return out

if not left:
out = _transpose_last_2dim(out)
return out


def triangular_solve(
Expand Down
298 changes: 298 additions & 0 deletions test/legacy_test/test_solve_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ def test_errors(self):
y7 = paddle.static.data(name="y7", shape=[2, 4, 3], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.solve, x7, y7)

# y must not be 1-D when left = False (a vector y should be passed as a row vector of shape [1, M]).
x8 = paddle.static.data(name="x8", shape=[3, 3], dtype="float64")
y8 = paddle.static.data(name="y8", shape=[3], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.solve, x8, y8, False)

# The height of x should equal the width of y when left = False.
x9 = paddle.static.data(name="x9", shape=[2, 5, 5], dtype="float64")
y9 = paddle.static.data(name="y9", shape=[5, 3], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.solve, x9, y9, False)


# 2D + vector case, FP64
class TestSolveOpAPI_1(unittest.TestCase):
Expand Down Expand Up @@ -570,6 +580,294 @@ def run(place):
run(place)


def np_transpose_last_2dim(x):
    """NumPy helper: return ``x`` with its last two axes exchanged."""
    axes = list(range(len(x.shape)))
    # Only the trailing pair of axes trades places; batch axes stay put.
    axes[-2], axes[-1] = axes[-1], axes[-2]
    return np.transpose(x, axes)


def np_solve_right(x, y):
    """NumPy reference for the right-hand solve ``Out * X = Y``.

    Both operands are transposed on their last two axes, the system is
    solved with ``np.linalg.solve``, and the solution is transposed back.
    """
    solved_t = np.linalg.solve(
        np_transpose_last_2dim(x), np_transpose_last_2dim(y)
    )
    return np_transpose_last_2dim(solved_t)


# 2D + vector right case, FP64
class TestSolveOpAPIRight_1(unittest.TestCase):
    """Compare ``paddle.linalg.solve(x, y, left=False)`` with the NumPy
    reference for a [3, 3] ``x`` and a [1, 3] row-vector ``y`` in float64,
    in both static and dygraph modes."""

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float64"
        # CPU runs when CUDA is unavailable or when the CI flag requests
        # both devices; CUDA runs whenever it is compiled in.
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        # Make sure static mode is active before the program is built.
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[3, 3], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[1, 3], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([3, 3]).astype(self.dtype)
            np_input_y = np.random.random([1, 3]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=1e-05)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([3, 3]).astype(self.dtype)
                input_y_np = np.random.random([1, 3]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=1e-05
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion fails, so a
                # failure here cannot leave dygraph mode on for later tests.
                paddle.enable_static()

        for place in self.place:
            run(place)


# 2D normal right case, FP64
class TestSolveOpAPIRight_2(unittest.TestCase):
    """Compare ``paddle.linalg.solve(x, y, left=False)`` with the NumPy
    reference for a [10, 10] ``x`` and a [4, 10] ``y`` in float64, in both
    static and dygraph modes."""

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float64"
        # CPU runs when CUDA is unavailable or when the CI flag requests
        # both devices; CUDA runs whenever it is compiled in.
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[10, 10], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[4, 10], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([10, 10]).astype(self.dtype)
            np_input_y = np.random.random([4, 10]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=1e-05)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([10, 10]).astype(self.dtype)
                input_y_np = np.random.random([4, 10]).astype(self.dtype)
                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=1e-05
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion fails, so a
                # failure here cannot leave dygraph mode on for later tests.
                paddle.enable_static()

        for place in self.place:
            run(place)


# 2D normal right case, FP32
class TestSolveOpAPIRight_3(unittest.TestCase):
    """Compare ``paddle.linalg.solve(x, y, left=False)`` with the NumPy
    reference for a [10, 10] ``x`` and a [6, 10] ``y`` in float32, in both
    static and dygraph modes (looser ``rtol`` of 1e-4)."""

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float32"
        # CPU runs when CUDA is unavailable or when the CI flag requests
        # both devices; CUDA runs whenever it is compiled in.
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[10, 10], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[6, 10], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([10, 10]).astype(self.dtype)
            np_input_y = np.random.random([6, 10]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=0.0001)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([10, 10]).astype(self.dtype)
                input_y_np = np.random.random([6, 10]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=0.0001
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion fails, so a
                # failure here cannot leave dygraph mode on for later tests.
                paddle.enable_static()

        for place in self.place:
            run(place)


# 3D + y broadcast right case, FP64
class TestSolveOpAPIRight_4(unittest.TestCase):
    """Compare ``paddle.linalg.solve(x, y, left=False)`` with the NumPy
    reference for a batched [2, 3, 3] ``x`` and a [1, 3, 3] ``y`` (batch
    dimension broadcast) in float64, in both static and dygraph modes."""

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float64"
        # CPU runs when CUDA is unavailable or when the CI flag requests
        # both devices; CUDA runs whenever it is compiled in.
        if (
            os.environ.get('FLAGS_CI_both_cpu_and_gpu', 'False').lower()
            in ['1', 'true', 'on']
            or not core.is_compiled_with_cuda()
        ):
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        # Make sure static mode is active before the program is built.
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[2, 3, 3], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[1, 3, 3], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([2, 3, 3]).astype(self.dtype)
            np_input_y = np.random.random([1, 3, 3]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=1e-05)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([2, 3, 3]).astype(self.dtype)
                input_y_np = np.random.random([1, 3, 3]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=1e-05
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion fails, so a
                # failure here cannot leave dygraph mode on for later tests.
                paddle.enable_static()

        for place in self.place:
            run(place)


class TestSolveOpSingularAPI(unittest.TestCase):
# Singular matrix is not invertible
def setUp(self):
Expand Down