- 
                Notifications
    You must be signed in to change notification settings 
- Fork 5.9k
【Hackathon 5th No.37】为 Paddle 新增 householder_product API -part #58214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
5f627a5
              1dcde7a
              ce7f286
              708fd5d
              75c9e06
              3f30938
              1c49aa4
              8899ce7
              f477ba5
              3bedcfa
              d43aaba
              173a4cd
              ac67f4a
              1645665
              8bfcd24
              b5559d3
              b92a244
              67c350d
              8f64185
              77aafeb
              d8c3ef1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -3724,3 +3724,133 @@ def cdist( | |
| return paddle.linalg.norm( | ||
| x[..., None, :] - y[..., None, :, :], p=p, axis=-1 | ||
| ) | ||
|  | ||
|  | ||
| def householder_product(A, tau, name=None): | ||
|          | ||
| r""" | ||
|  | ||
| Computes the first n columns of a product of Householder matrices. | ||
|  | ||
| This function can get the vector :math:`\omega_{i}` from matrix `A`(m x n), the :math:`i-1` elements are zeros, and the i-th is `1`, the rest of the elements are from i-th column of `A`. | ||
| And with the vector `tau` can calculate the first n columns of a product of Householder matrices. | ||
|  | ||
| :math:`H_i = I_m - \tau_i \omega_i \omega_i^H` | ||
|  | ||
| Args: | ||
| A (Tensor): A tensor with shape (*, m, n) where * is zero or more batch dimensions. | ||
| tau (Tensor): A tensor with shape (*, k) where * is zero or more batch dimensions. | ||
| name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. | ||
|  | ||
| Returns: | ||
| Tensor, the dtype is same as input tensor, the Q in QR decomposition. | ||
|  | ||
| :math:`out = Q = H_1H_2H_3...H_k` | ||
|  | ||
| Examples: | ||
| .. code-block:: python | ||
|  | ||
| >>> import paddle | ||
| >>> A = paddle.to_tensor([[-1.1280, 0.9012, -0.0190], | ||
| ... [ 0.3699, 2.2133, -1.4792], | ||
| ... [ 0.0308, 0.3361, -3.1761], | ||
| ... [-0.0726, 0.8245, -0.3812]]) | ||
| >>> tau = paddle.to_tensor([1.7497, 1.1156, 1.7462]) | ||
| >>> Q = paddle.linalg.householder_product(A, tau) | ||
| >>> Q | ||
|         
                  cocoshe marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| Tensor(shape=[4, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True, | ||
| [[-0.74969995, -0.02181768, 0.31115776], | ||
| [-0.64721400, -0.12367040, -0.21738708], | ||
| [-0.05389076, -0.37562513, -0.84836429], | ||
| [ 0.12702821, -0.91822827, 0.36892807]]) | ||
| """ | ||
|  | ||
| check_dtype( | ||
| A.dtype, | ||
| 'x', | ||
| [ | ||
| 'float32', | ||
| 'float64', | ||
| 'complex64', | ||
| 'complex128', | ||
| ], | ||
| 'householder_product', | ||
| ) | ||
| check_dtype( | ||
| tau.dtype, | ||
| 'tau', | ||
| [ | ||
| 'float32', | ||
| 'float64', | ||
| 'complex64', | ||
| 'complex128', | ||
| ], | ||
| 'householder_product', | ||
| ) | ||
| assert ( | ||
| A.dtype == tau.dtype | ||
| ), "The input A must have the same dtype with input tau.\n" | ||
| assert ( | ||
| len(A.shape) >= 2 | ||
| and len(tau.shape) >= 1 | ||
| and len(A.shape) == len(tau.shape) + 1 | ||
| ), ( | ||
| "The input A must have more than 2 dimensions, and input tau must have more than 1 dimension," | ||
| "and the dimension of A is 1 larger than the dimension of tau\n" | ||
| ) | ||
| assert ( | ||
| A.shape[-2] >= A.shape[-1] | ||
|          | ||
| ), "The rows of input A must be greater than or equal to the columns of input A.\n" | ||
| assert ( | ||
| A.shape[-1] >= tau.shape[-1] | ||
| ), "The last dim of A must be greater than tau.\n" | ||
| for idx, _ in enumerate(A.shape[:-2]): | ||
| assert ( | ||
| A.shape[idx] == tau.shape[idx] | ||
| ), "The input A must have the same batch dimensions with input tau.\n" | ||
|  | ||
| def _householder_product(A, tau): | ||
| m, n = A.shape[-2:] | ||
| k = tau.shape[-1] | ||
| Q = paddle.eye(m).astype(A.dtype) | ||
| for i in range(min(k, n)): | ||
| w = A[i:, i] | ||
| if in_dynamic_mode(): | ||
| w[0] = 1 | ||
| else: | ||
| w = paddle.static.setitem(w, 0, 1) | ||
| w = w.reshape([-1, 1]) | ||
| if in_dynamic_mode(): | ||
| if A.dtype in [paddle.complex128, paddle.complex64]: | ||
| Q[:, i:] = Q[:, i:] - ( | ||
| Q[:, i:] @ w @ paddle.conj(w).T * tau[i] | ||
| ) | ||
| else: | ||
| Q[:, i:] = Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]) | ||
| else: | ||
| Q = paddle.static.setitem( | ||
| Q, | ||
| (slice(None), slice(i, None)), | ||
| Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]) | ||
| if A.dtype in [paddle.complex128, paddle.complex64] | ||
| else Q[:, i:] - (Q[:, i:] @ w @ w.T * tau[i]), | ||
| ) | ||
| return Q[:, :n] | ||
|  | ||
| if len(A.shape) == 2: | ||
| return _householder_product(A, tau) | ||
| m, n = A.shape[-2:] | ||
| org_A_shape = A.shape | ||
| org_tau_shape = tau.shape | ||
| A = A.reshape((-1, org_A_shape[-2], org_A_shape[-1])) | ||
| tau = tau.reshape((-1, org_tau_shape[-1])) | ||
| n_batch = A.shape[0] | ||
| out = paddle.zeros([n_batch, m, n], dtype=A.dtype) | ||
| for i in range(n_batch): | ||
| if in_dynamic_mode(): | ||
| out[i] = _householder_product(A[i], tau[i]) | ||
| else: | ||
| out = paddle.static.setitem( | ||
| out, i, _householder_product(A[i], tau[i]) | ||
| ) | ||
| out = out.reshape(org_A_shape) | ||
| return out | ||
Uh oh!
There was an error while loading. Please reload this page.