1414# limitations under the License. #
1515# ------------------------------------------------------------------------ #
1616
17+ import numpy as np
1718import pytest
19+ import torch
1820
1921import diffsptk
2022import tests .utils as U
2325@pytest .mark .parametrize ("module" , [False , True ])
2426@pytest .mark .parametrize ("metric" , [0 , 1 , 2 , 3 ])
2527@pytest .mark .parametrize ("p" , [0 , 1 , 2 , 3 , 4 , 5 , 6 ])
26- def test_compatibility (device , dtype , module , metric , p , M = 0 , T1 = 10 , T2 = 10 ):
28+ def test_compatibility (device , dtype , module , metric , p , M = 2 , T1 = 10 , T2 = 10 ):
2729 dtw = U .choice (
2830 module ,
2931 diffsptk .DynamicTimeWarping ,
@@ -35,40 +37,41 @@ def test_compatibility(device, dtype, module, metric, p, M=0, T1=10, T2=10):
3537 tmp1 = "dtw.tmp1"
3638 tmp2 = "dtw.tmp2"
3739 tmp3 = "dtw.tmp3"
40+
41+ def _dtw (x , y ):
42+ distance , indices = dtw (x , y , return_indices = True )
43+ output_distance = distance .item ()
44+ target_distance = np .fromfile (tmp3 , dtype = np .float64 )
45+ assert np .allclose (output_distance , target_distance ), (
46+ f"Output: { output_distance } \n Target: { target_distance } "
47+ )
48+ return torch .cat (diffsptk .functional .dtw_merge (x , y , indices [0 ]), dim = - 1 )
49+
3850 U .check_compatibility (
3951 device ,
4052 dtype ,
41- dtw ,
53+ _dtw ,
4254 [
4355 f"nrand -s 1 -l { T1 * (M + 1 )} { sopr } > { tmp1 } " ,
4456 f"nrand -s 2 -l { T2 * (M + 1 )} { sopr } > { tmp2 } " ,
4557 ],
4658 [f"cat { tmp1 } " , f"cat { tmp2 } " ],
47- (
48- f"dtw -m { M } -d { metric } -p { p } < { tmp1 } { tmp2 } -S { tmp3 } > /dev/null "
49- f"&& cat { tmp3 } "
50- ),
59+ (f"dtw -m { M } -d { metric } -p { p } < { tmp1 } { tmp2 } -S { tmp3 } " ),
5160 [f"rm { tmp1 } { tmp2 } { tmp3 } " ],
5261 dx = M + 1 ,
62+ dy = 2 * (M + 1 ),
5363 )
5464
55- U .check_compatibility (
56- device ,
57- dtype ,
65+ U .check_differentiability (device , dtype , dtw , [(T1 , M + 1 ), (T2 , M + 1 )])
66+
67+
68+ def test_various_shape (T = 10 ):
69+ dtw = diffsptk .DynamicTimeWarping ()
70+ U .check_various_shape (
5871 dtw ,
5972 [
60- f"nrand -s 1 -l { T1 } { sopr } > { tmp1 } " ,
61- f"nrand -s 2 -l { T2 } { sopr } > { tmp2 } " ,
73+ [(T ,), (T ,)],
74+ [(T , 1 ), (T , 1 )],
75+ [(1 , T , 1 ), (1 , T , 1 )],
6276 ],
63- [f"cat { tmp1 } " , f"cat { tmp2 } " ],
64- (
65- f"dtw -m 0 -d { metric } -p { p } < { tmp1 } { tmp2 } -P { tmp3 } > /dev/null "
66- f"&& x2x +id { tmp3 } "
67- ),
68- [f"rm { tmp1 } { tmp2 } { tmp3 } " ],
69- dy = 2 ,
70- opt = {"return_indices" : True },
71- get = [1 , 0 ],
7277 )
73-
74- U .check_differentiability (device , dtype , dtw , [(T1 , M + 1 ), (T2 , M + 1 )])
0 commit comments