Skip to content

Commit ed69456

Browse files
authored
update solve (#67071)
1 parent 3ad5f26 commit ed69456

File tree

2 files changed

+345
-4
lines changed

2 files changed

+345
-4
lines changed

python/paddle/tensor/linalg.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4019,23 +4019,58 @@ def pinv(
40194019
return out_2
40204020

40214021

4022-
def solve(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
4022+
def _check_right_solve_shape(x, y):
    """Validate the shapes of ``x`` and ``y`` for ``solve`` when ``left`` is False.

    The right-hand system is ``Out * X = Y``, so the last dimension of ``y``
    must equal the second-to-last dimension of ``x``.

    Args:
        x: Object exposing ``shape``; only its last two dimensions are read.
        y: Object exposing ``shape``; it must not be 1-D.

    Raises:
        ValueError: If ``y`` is 1-D, or if the last dimension of ``y`` differs
            from the second-to-last dimension of ``x``.
    """
    x_shape = x.shape[-2:]
    if len(y.shape) == 1:
        # A 1-D ``y`` is always rejected and reported as the matrix ``[N, 1]``.
        # ``list.append`` returns None, so the shape is built by unpacking
        # instead of appending inside the message.
        y_shape = [*y.shape, 1]
        incompatible = True
    else:
        y_shape = y.shape[-2:]
        incompatible = x_shape[0] != y_shape[1]

    if incompatible:
        raise ValueError(
            "Incompatible shapes of X and Y for the equation Out * X = Y, "
            f"where input X's matrix shape is {x_shape} and "
            f"input Y's matrix shape is {y_shape}"
        )
4039+
4040+
4041+
def _transpose_last_2dim(x):
    """Return ``x`` with its two trailing axes exchanged; leading axes keep their order."""
    rank = len(x.shape)
    perm = list(range(rank))
    # Exchange the two trailing axis indices; every other axis stays in place.
    last, second_last = perm[-1], perm[-2]
    perm[-2], perm[-1] = last, second_last
    return transpose(x, perm)
4047+
4048+
4049+
def solve(
4050+
x: Tensor, y: Tensor, left: bool = True, name: str | None = None
4051+
) -> Tensor:
40234052
r"""
40244053
40254054
Computes the solution of a square system of linear equations with a unique solution for input 'X' and 'Y'.
40264055
Let :math:`X` be a square matrix or a batch of square matrices, :math:`Y` be
4027-
a vector/matrix or a batch of vectors/matrices, the equation should be:
4056+
a vector/matrix or a batch of vectors/matrices. When `left` is True, the equation should be:
40284057
40294058
.. math::
40304059
Out = X^{-1} * Y
40314060
4061+
When `left` is False, the equation should be:
4062+
4063+
.. math::
4064+
Out = Y * X^{-1}
4065+
40324066
Specifically, this system of linear equations has one solution if and only if input 'X' is invertible.
40334067
40344068
Args:
40354069
x (Tensor): A square matrix or a batch of square matrices. Its shape should be ``[*, M, M]``, where ``*`` is zero or
40364070
more batch dimensions. Its data type should be float32 or float64.
40374071
y (Tensor): A vector/matrix or a batch of vectors/matrices. Its shape should be ``[*, M, K]`` when ``left`` is True and ``[*, K, M]`` when ``left`` is False, where ``*`` is zero or
40384072
more batch dimensions. Its data type should be float32 or float64.
4073+
left (bool, optional): Whether to solve the system :math:`X * Out = Y` or :math:`Out * X = Y`. Default: True.
40394074
name (str|None, optional): Name for the operation (optional, default is None).
40404075
For more information, please refer to :ref:`api_guide_Name`.
40414076
@@ -4061,8 +4096,13 @@ def solve(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
40614096
Tensor(shape=[2], dtype=float64, place=Place(cpu), stop_gradient=True,
40624097
[2., 3.])
40634098
"""
4099+
if not left:
4100+
_check_right_solve_shape(x, y)
4101+
x = _transpose_last_2dim(x)
4102+
y = _transpose_last_2dim(y)
4103+
40644104
if in_dynamic_or_pir_mode():
4065-
return _C_ops.solve(x, y)
4105+
out = _C_ops.solve(x, y)
40664106
else:
40674107
inputs = {"X": [x], "Y": [y]}
40684108
helper = LayerHelper("solve", **locals())
@@ -4073,7 +4113,10 @@ def solve(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
40734113
helper.append_op(
40744114
type="solve", inputs={"X": x, "Y": y}, outputs={"Out": out}
40754115
)
4076-
return out
4116+
4117+
if not left:
4118+
out = _transpose_last_2dim(out)
4119+
return out
40774120

40784121

40794122
def triangular_solve(

test/legacy_test/test_solve_op.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def test_errors(self):
304304
y7 = paddle.static.data(name="y7", shape=[2, 4, 3], dtype="float64")
305305
self.assertRaises(ValueError, paddle.linalg.solve, x7, y7)
306306

307+
# y must not be 1-D when left=False (a vector right-hand side must be passed as a row vector, e.g. shape [1, M]).
308+
x8 = paddle.static.data(name="x8", shape=[3, 3], dtype="float64")
309+
y8 = paddle.static.data(name="y8", shape=[3], dtype="float64")
310+
self.assertRaises(ValueError, paddle.linalg.solve, x8, y8, False)
311+
312+
# The height of x should equal the width of y when left = False.
313+
x9 = paddle.static.data(name="x9", shape=[2, 5, 5], dtype="float64")
314+
y9 = paddle.static.data(name="y9", shape=[5, 3], dtype="float64")
315+
self.assertRaises(ValueError, paddle.linalg.solve, x9, y9, False)
316+
307317

308318
# 2D + vector case, FP64
309319
class TestSolveOpAPI_1(unittest.TestCase):
@@ -570,6 +580,294 @@ def run(place):
570580
run(place)
571581

572582

583+
def np_transpose_last_2dim(x):
    """NumPy reference: return ``x`` with its two trailing axes exchanged."""
    axes = list(range(len(x.shape)))
    # Exchange the two trailing axis indices; leading (batch) axes stay put.
    last, second_last = axes[-1], axes[-2]
    axes[-2], axes[-1] = last, second_last
    return np.transpose(x, axes)
588+
589+
590+
def np_solve_right(x, y):
    """NumPy reference for the right-hand solve ``Out * x = y``.

    The system is transposed into ``x^T * Out^T = y^T``, solved with
    ``np.linalg.solve``, and the solution is transposed back.
    """

    def _swap_last_two(arr):
        # Permutation that exchanges only the two trailing axes.
        axes = list(range(len(arr.shape)))
        axes[-1], axes[-2] = axes[-2], axes[-1]
        return np.transpose(arr, axes)

    x_t = _swap_last_two(x)
    y_t = _swap_last_two(y)
    return _swap_last_two(np.linalg.solve(x_t, y_t))
596+
597+
598+
# 2D + vector right case, FP64
599+
class TestSolveOpAPIRight_1(unittest.TestCase):
    """Right-hand solve (``left=False``) with a 2-D ``x`` and one row vector ``y``, float64.

    Results of ``paddle.linalg.solve`` are compared with the NumPy reference
    ``np_solve_right`` in both static-graph and dynamic-graph mode.
    """

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float64"
        # CPU runs when CUDA is not compiled in, or when the CI flag asks for both.
        want_cpu_and_gpu = os.environ.get(
            'FLAGS_CI_both_cpu_and_gpu', 'False'
        ).lower() in ['1', 'true', 'on']
        if want_cpu_and_gpu or not core.is_compiled_with_cuda():
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        """Build and run the static program on ``place`` and compare with NumPy."""
        # Enter static mode explicitly so the test does not depend on the mode
        # left behind by whichever test ran before it.
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[3, 3], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[1, 3], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([3, 3]).astype(self.dtype)
            np_input_y = np.random.random([1, 3]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=1e-05)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([3, 3]).astype(self.dtype)
                input_y_np = np.random.random([1, 3]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=1e-05
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion above fails.
                paddle.enable_static()

        for place in self.place:
            run(place)
664+
665+
666+
# 2D normal right case, FP64
667+
class TestSolveOpAPIRight_2(unittest.TestCase):
    """Right-hand solve (``left=False``) with a 2-D ``x`` and a 2-D ``y``, float64.

    Results of ``paddle.linalg.solve`` are compared with the NumPy reference
    ``np_solve_right`` in both static-graph and dynamic-graph mode.
    """

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float64"
        # CPU runs when CUDA is not compiled in, or when the CI flag asks for both.
        want_cpu_and_gpu = os.environ.get(
            'FLAGS_CI_both_cpu_and_gpu', 'False'
        ).lower() in ['1', 'true', 'on']
        if want_cpu_and_gpu or not core.is_compiled_with_cuda():
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        """Build and run the static program on ``place`` and compare with NumPy."""
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[10, 10], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[4, 10], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([10, 10]).astype(self.dtype)
            np_input_y = np.random.random([4, 10]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=1e-05)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([10, 10]).astype(self.dtype)
                input_y_np = np.random.random([4, 10]).astype(self.dtype)
                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=1e-05
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion above fails.
                paddle.enable_static()

        for place in self.place:
            run(place)
732+
733+
734+
# 2D normal right case, FP32
735+
class TestSolveOpAPIRight_3(unittest.TestCase):
    """Right-hand solve (``left=False``) with a 2-D ``x`` and a 2-D ``y``, float32.

    Results of ``paddle.linalg.solve`` are compared with the NumPy reference
    ``np_solve_right`` in both static-graph and dynamic-graph mode, using the
    looser ``rtol=0.0001`` for single precision.
    """

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float32"
        # CPU runs when CUDA is not compiled in, or when the CI flag asks for both.
        want_cpu_and_gpu = os.environ.get(
            'FLAGS_CI_both_cpu_and_gpu', 'False'
        ).lower() in ['1', 'true', 'on']
        if want_cpu_and_gpu or not core.is_compiled_with_cuda():
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        """Build and run the static program on ``place`` and compare with NumPy."""
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[10, 10], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[6, 10], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([10, 10]).astype(self.dtype)
            np_input_y = np.random.random([6, 10]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=0.0001)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([10, 10]).astype(self.dtype)
                input_y_np = np.random.random([6, 10]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=0.0001
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion above fails.
                paddle.enable_static()

        for place in self.place:
            run(place)
801+
802+
803+
# 3D + y broadcast right case, FP64
804+
class TestSolveOpAPIRight_4(unittest.TestCase):
    """Right-hand solve (``left=False``) with a batched ``x`` and a broadcast ``y``, float64.

    ``x`` has shape ``[2, 3, 3]`` and ``y`` has shape ``[1, 3, 3]``. Results of
    ``paddle.linalg.solve`` are compared with the NumPy reference
    ``np_solve_right`` in both static-graph and dynamic-graph mode.
    """

    def setUp(self):
        np.random.seed(2021)
        self.place = []
        self.dtype = "float64"
        # CPU runs when CUDA is not compiled in, or when the CI flag asks for both.
        want_cpu_and_gpu = os.environ.get(
            'FLAGS_CI_both_cpu_and_gpu', 'False'
        ).lower() in ['1', 'true', 'on']
        if want_cpu_and_gpu or not core.is_compiled_with_cuda():
            self.place.append(paddle.CPUPlace())
        if core.is_compiled_with_cuda():
            self.place.append(paddle.CUDAPlace(0))

    def check_static_result(self, place):
        """Build and run the static program on ``place`` and compare with NumPy."""
        # Enter static mode explicitly so the test does not depend on the mode
        # left behind by whichever test ran before it.
        paddle.enable_static()
        with base.program_guard(base.Program(), base.Program()):
            paddle_input_x = paddle.static.data(
                name="input_x", shape=[2, 3, 3], dtype=self.dtype
            )
            paddle_input_y = paddle.static.data(
                name="input_y", shape=[1, 3, 3], dtype=self.dtype
            )
            paddle_result = paddle.linalg.solve(
                paddle_input_x, paddle_input_y, left=False
            )

            np_input_x = np.random.random([2, 3, 3]).astype(self.dtype)
            np_input_y = np.random.random([1, 3, 3]).astype(self.dtype)

            np_result = np_solve_right(np_input_x, np_input_y)

            exe = base.Executor(place)
            fetches = exe.run(
                base.default_main_program(),
                feed={"input_x": np_input_x, "input_y": np_input_y},
                fetch_list=[paddle_result],
            )
            np.testing.assert_allclose(fetches[0], np_result, rtol=1e-05)

    def test_static(self):
        for place in self.place:
            self.check_static_result(place=place)

    def test_dygraph(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                input_x_np = np.random.random([2, 3, 3]).astype(self.dtype)
                input_y_np = np.random.random([1, 3, 3]).astype(self.dtype)

                tensor_input_x = paddle.to_tensor(input_x_np)
                tensor_input_y = paddle.to_tensor(input_y_np)

                numpy_output = np_solve_right(input_x_np, input_y_np)
                paddle_output = paddle.linalg.solve(
                    tensor_input_x, tensor_input_y, left=False
                )
                np.testing.assert_allclose(
                    numpy_output, paddle_output.numpy(), rtol=1e-05
                )
                self.assertEqual(
                    numpy_output.shape, paddle_output.numpy().shape
                )
            finally:
                # Restore static mode even when an assertion above fails.
                paddle.enable_static()

        for place in self.place:
            run(place)
869+
870+
573871
class TestSolveOpSingularAPI(unittest.TestCase):
574872
# Singular matrix is not invertible
575873
def setUp(self):

0 commit comments

Comments
 (0)