|
4 | 4 |
|
5 | 5 | import ../tensor |
6 | 6 | import ./helpers/triangular |
7 | | -import std / [sequtils, bitops] |
| 7 | +import std / [sequtils, bitops, strformat] |
8 | 8 |
|
9 | 9 | proc hilbert*(n: int, T: typedesc[SomeFloat]): Tensor[T] = |
10 | 10 | ## Generates an Hilbert matrix of shape [N, N] |
@@ -129,6 +129,7 @@ proc diagonal*[T](a: Tensor[T], k = 0, anti = false): Tensor[T] {.noInit.} = |
129 | 129 | ## - anti: If true, get the k-th "anti-diagonal" instead of the k-th regular diagonal. |
130 | 130 | ## Result: |
131 | 131 | ## - A copy of the diagonal elements as a rank-1 tensor |
| 132 | + bind `&` |
132 | 133 | assert a.rank == 2, "diagonal() only works on matrices" |
133 | 134 | assert k < a.shape[0], &"Diagonal index ({k=}) exceeds the output matrix height ({a.shape[0]})" |
134 | 135 | assert k < a.shape[1], &"Diagonal index ({k=}) exceeds the output matrix width ({a.shape[1]})" |
@@ -167,6 +168,7 @@ proc set_diagonal*[T](a: var Tensor[T], d: Tensor[T], k = 0, anti = false) = |
167 | 168 | ## - k: The index k of the diagonal that will be changed. The default is 0 (i.e. the main diagonal). |
168 | 169 | ## Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal. |
169 | 170 | ## - anti: If true, set the k-th "anti-diagonal" instead of the k-th regular diagonal. |
| 171 | + bind `&` |
170 | 172 | assert a.rank == 2, "set_diagonal() only works on matrices" |
171 | 173 | assert d.rank == 1, "The diagonal passed to set_diagonal() must be a rank-1 tensor" |
172 | 174 | assert k < a.shape[0], &"Diagonal index ({k=}) exceeds input matrix height ({a.shape[0]})" |
@@ -259,6 +261,7 @@ proc tri*[T](shape: Metadata, k: static int = 0, upper: static bool = false): Te |
259 | 261 | ## diagonal. The default is false. |
260 | 262 | ## Result: |
261 | 263 | ## - The constructed, rank-2 triangular tensor. |
| 264 | + bind `&` |
262 | 265 | assert shape.len == 2, &"tri() requires a rank-2 shape as it's input but a shape of rank {shape.len} was passed" |
263 | 266 | assert k < shape[0], &"tri() received a diagonal index ({k=}) which exceeds the output matrix height ({shape[0]})" |
264 | 267 | assert k < shape[1], &"tri() received a diagonal index ({k=}) which exceeds the output matrix width ({shape[1]})" |
|
0 commit comments