Skip to content

Commit f0a382a

Browse files
authored
Merge pull request #149 from sp-nitech/dtw_merge
Add dtw_merge
2 parents 4220c8f + 11fb454 commit f0a382a

File tree

6 files changed

+88
-34
lines changed

6 files changed

+88
-34
lines changed

diffsptk/functional.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,38 @@ def dtw(
571571
)
572572

573573

574+
def dtw_merge(x: Tensor, y: Tensor, indices: Tensor) -> tuple[Tensor, Tensor]:
    """Gather frames of two sequences along a given alignment path.

    Each row of ``indices`` pairs one position in ``x`` with one position in
    ``y``. Both sequences are indexed along their first axis, so the two
    outputs have the same length as the path.

    Parameters
    ----------
    x : Tensor [shape=(T1, ...)]
        The query vector sequence.

    y : Tensor [shape=(T2, ...)]
        The reference vector sequence.

    indices : Tensor [shape=(T, 2)]
        The indices of the path. Column 0 is used to index ``x`` and
        column 1 is used to index ``y``.

    Returns
    -------
    x_align : Tensor [shape=(T, ...)]
        The aligned query vector sequence.

    y_align : Tensor [shape=(T, ...)]
        The aligned reference vector sequence.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in their number of dimensions, or if
        ``indices`` is not a two-dimensional tensor whose last axis has
        size 2.

    """
    if x.dim() != y.dim():
        raise ValueError("x and y must have the same number of dimensions.")

    # Check the rank first so that size(-1) is never queried on a 0-dim tensor.
    path_is_well_formed = indices.dim() == 2 and indices.size(-1) == 2
    if not path_is_well_formed:
        raise ValueError("The shape of indices must be (T, 2).")

    query_positions = indices[:, 0]
    reference_positions = indices[:, 1]
    return x[query_positions], y[reference_positions]
604+
605+
574606
def entropy(p: Tensor, out_format: str = "nat") -> Tensor:
575607
"""Calculate the entropy of a probability distribution.
576608

diffsptk/modules/dtw.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def _soft_dtw_core(
3232
B, T1, T2 = D.shape
3333

3434
R = torch.full_like(D, float("inf"))
35-
R[:, 0, 0] = D[:, 0, 0]
3635
R_ = R.clone() if has_two_step_transition else None
36+
R[:, 0, 0] = D[:, 0, 0]
3737

3838
if return_indices:
3939
P = torch.full((B, T1, T2, 2), -1, device=D.device, dtype=torch.long)
@@ -300,14 +300,13 @@ def _forward(
300300
dist_func: Callable,
301301
dtw_func: Callable,
302302
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
303-
d = x.dim()
304-
if d == 1:
303+
if x.dim() == 1:
305304
x = x.view(1, -1, 1)
306305
y = y.view(1, -1, 1)
307-
elif d == 2:
306+
elif x.dim() == 2:
308307
x = x.unsqueeze(0)
309308
y = y.unsqueeze(0)
310-
else:
309+
if x.dim() != 3:
311310
raise ValueError("x and y must be 1D, 2D, or 3D tensor.")
312311
if x.dim() != y.dim():
313312
raise ValueError("x and y must have the same number of dimensions.")

docs/source/modules/dtw.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ dtw
99
:members:
1010

1111
.. autofunction:: diffsptk.functional.dtw
12+
13+
.. seealso::
14+
15+
:ref:`dtw_merge`

docs/source/modules/dtw_merge.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
.. _dtw_merge:
2+
3+
dtw_merge
4+
=========
5+
6+
.. autofunction:: diffsptk.functional.dtw_merge
7+
8+
.. seealso::
9+
10+
:ref:`dtw`

tests/test_dtw.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
# limitations under the License. #
1515
# ------------------------------------------------------------------------ #
1616

17+
import numpy as np
1718
import pytest
19+
import torch
1820

1921
import diffsptk
2022
import tests.utils as U
@@ -23,7 +25,7 @@
2325
@pytest.mark.parametrize("module", [False, True])
2426
@pytest.mark.parametrize("metric", [0, 1, 2, 3])
2527
@pytest.mark.parametrize("p", [0, 1, 2, 3, 4, 5, 6])
26-
def test_compatibility(device, dtype, module, metric, p, M=0, T1=10, T2=10):
28+
def test_compatibility(device, dtype, module, metric, p, M=2, T1=10, T2=10):
2729
dtw = U.choice(
2830
module,
2931
diffsptk.DynamicTimeWarping,
@@ -35,40 +37,41 @@ def test_compatibility(device, dtype, module, metric, p, M=0, T1=10, T2=10):
3537
tmp1 = "dtw.tmp1"
3638
tmp2 = "dtw.tmp2"
3739
tmp3 = "dtw.tmp3"
40+
41+
def _dtw(x, y):
42+
distance, indices = dtw(x, y, return_indices=True)
43+
output_distance = distance.item()
44+
target_distance = np.fromfile(tmp3, dtype=np.float64)
45+
assert np.allclose(output_distance, target_distance), (
46+
f"Output: {output_distance}\nTarget: {target_distance}"
47+
)
48+
return torch.cat(diffsptk.functional.dtw_merge(x, y, indices[0]), dim=-1)
49+
3850
U.check_compatibility(
3951
device,
4052
dtype,
41-
dtw,
53+
_dtw,
4254
[
4355
f"nrand -s 1 -l {T1 * (M + 1)} {sopr} > {tmp1}",
4456
f"nrand -s 2 -l {T2 * (M + 1)} {sopr} > {tmp2}",
4557
],
4658
[f"cat {tmp1}", f"cat {tmp2}"],
47-
(
48-
f"dtw -m {M} -d {metric} -p {p} < {tmp1} {tmp2} -S {tmp3} > /dev/null "
49-
f"&& cat {tmp3}"
50-
),
59+
(f"dtw -m {M} -d {metric} -p {p} < {tmp1} {tmp2} -S {tmp3}"),
5160
[f"rm {tmp1} {tmp2} {tmp3}"],
5261
dx=M + 1,
62+
dy=2 * (M + 1),
5363
)
5464

55-
U.check_compatibility(
56-
device,
57-
dtype,
65+
U.check_differentiability(device, dtype, dtw, [(T1, M + 1), (T2, M + 1)])
66+
67+
68+
def test_various_shape(T=10):
69+
dtw = diffsptk.DynamicTimeWarping()
70+
U.check_various_shape(
5871
dtw,
5972
[
60-
f"nrand -s 1 -l {T1} {sopr} > {tmp1}",
61-
f"nrand -s 2 -l {T2} {sopr} > {tmp2}",
73+
[(T,), (T,)],
74+
[(T, 1), (T, 1)],
75+
[(1, T, 1), (1, T, 1)],
6276
],
63-
[f"cat {tmp1}", f"cat {tmp2}"],
64-
(
65-
f"dtw -m 0 -d {metric} -p {p} < {tmp1} {tmp2} -P {tmp3} > /dev/null "
66-
f"&& x2x +id {tmp3}"
67-
),
68-
[f"rm {tmp1} {tmp2} {tmp3}"],
69-
dy=2,
70-
opt={"return_indices": True},
71-
get=[1, 0],
7277
)
73-
74-
U.check_differentiability(device, dtype, dtw, [(T1, M + 1), (T2, M + 1)])

tests/utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,6 @@ def check_compatibility(
148148
if dy is not None:
149149
y = y.reshape(-1, dy)
150150

151-
for cmd in teardown:
152-
call(cmd, get=False)
153-
154151
module = compose(*modules)
155152
if len(key) == 0:
156153
y_hat = module(*x, **opt)
@@ -180,6 +177,9 @@ def check_compatibility(
180177
else:
181178
assert eq(y_hat, y, **kwargs), f"Output: {y_hat}\nTarget: {y}"
182179

180+
for cmd in teardown:
181+
call(cmd, get=False)
182+
183183

184184
def check_confidence(
185185
device,
@@ -256,12 +256,18 @@ def check_differentiability(
256256

257257

258258
def check_various_shape(module, shapes, *, preprocess=None):
259-
x = torch.randn(*shapes[0])
259+
if is_array(shapes[0][0]):
260+
xs = [torch.randn(*shape) for shape in shapes[0]]
261+
else:
262+
xs = [torch.randn(*shapes[0])]
260263
if preprocess is not None:
261-
x = preprocess(x)
264+
xs = [preprocess(x) for x in xs]
262265
for i, shape in enumerate(shapes):
263-
x = x.view(shape)
264-
y = module(x).view(-1)
266+
if is_array(shapes[0][0]):
267+
x = [x.view(*shape) for x, shape in zip(xs, shape)]
268+
else:
269+
x = [xs[0].view(shape)]
270+
y = module(*x).view(-1)
265271
if i == 0:
266272
target = y
267273
else:

0 commit comments

Comments
 (0)