Skip to content

Commit 18edb87

Browse files
authored
Merge pull request #150 from sp-nitech/f0eval2
Add f0eval
2 parents 5b27aa1 + 57bd7eb commit 18edb87

File tree

15 files changed

+268
-15
lines changed

15 files changed

+268
-15
lines changed

diffsptk/functional.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,35 @@ def excite(
671671
)
672672

673673

674+
def f0eval(
    x: Tensor, y: Tensor, reduction: str = "mean", out_format: str = "f0-rmse-cent"
) -> Tensor:
    """Evaluate an F0 sequence against a target F0 sequence.

    This is the functional counterpart of the ``F0Evaluation`` module; the
    whole computation is delegated to ``nn.F0Evaluation._func``.

    Parameters
    ----------
    x : Tensor [shape=(..., N)]
        The input F0 in Hz.

    y : Tensor [shape=(..., N)]
        The target F0 in Hz.

    reduction : ['none', 'mean', 'sum']
        The reduction type.

    out_format : ['f0-rmse-hz', 'f0-rmse-cent', 'f0-rmse-semitone', 'vuv-error-rate', \
                  'vuv-error-percent', 'vuv-macro-f1-score']
        The output format.

    Returns
    -------
    out : Tensor [shape=(...,) or scalar]
        The F0 metric.

    """
    options = {"reduction": reduction, "out_format": out_format}
    return nn.F0Evaluation._func(x, y, **options)
701+
702+
674703
def fbank(
675704
x: Tensor,
676705
n_channel: int,

diffsptk/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .dtw import DynamicTimeWarping as DTW
4747
from .entropy import Entropy
4848
from .excite import ExcitationGeneration
49+
from .f0eval import F0Evaluation
4950
from .fbank import MelFilterBankAnalysis
5051
from .fbank import MelFilterBankAnalysis as FBANK
5152
from .fftcep import CepstralAnalysis

diffsptk/modules/f0eval.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# ------------------------------------------------------------------------ #
2+
# Copyright 2022 SPTK Working Group #
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License"); #
5+
# you may not use this file except in compliance with the License. #
6+
# You may obtain a copy of the License at #
7+
# #
8+
# http://www.apache.org/licenses/LICENSE-2.0 #
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software #
11+
# distributed under the License is distributed on an "AS IS" BASIS, #
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
13+
# See the License for the specific language governing permissions and #
14+
# limitations under the License. #
15+
# ------------------------------------------------------------------------ #
16+
17+
import torch
18+
19+
from ..typing import Precomputed
20+
from ..utils.private import UNVOICED_SYMBOL, filter_values
21+
from .base import BaseFunctionalModule
22+
from .rmse import RootMeanSquareError
23+
24+
25+
class F0Evaluation(BaseFunctionalModule):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/f0eval.html>`_
    for details. Note that the gradients cannot be calculated if the output format
    is related to voiced/unvoiced decision.

    Parameters
    ----------
    reduction : ['none', 'mean', 'sum']
        The reduction type.

    out_format : ['f0-rmse-hz', 'f0-rmse-cent', 'f0-rmse-semitone', 'vuv-error-rate', \
                  'vuv-error-percent', 'vuv-macro-f1-score']
        The output format.

    Raises
    ------
    ValueError
        If ``reduction`` or ``out_format`` is not one of the supported values.

    """

    # Supported option values, kept in one place so that construction-time
    # validation and the documentation stay in sync.
    _REDUCTIONS = ("none", "mean", "sum")
    _OUT_FORMATS = (
        "f0-rmse-hz",
        "f0-rmse-cent",
        "f0-rmse-semitone",
        "vuv-error-rate",
        "vuv-error-percent",
        "vuv-macro-f1-score",
    )

    def __init__(
        self, reduction: str = "mean", out_format: str = "f0-rmse-cent"
    ) -> None:
        super().__init__()

        # NOTE: no extra local variables may be introduced before this line,
        # because ``locals()`` is forwarded to ``_precompute``.
        self.values = self._precompute(**filter_values(locals()))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Calculate F0 metric.

        Parameters
        ----------
        x : Tensor [shape=(..., N)]
            The input F0 in Hz.

        y : Tensor [shape=(..., N)]
            The target F0 in Hz.

        Returns
        -------
        out : Tensor [shape=(...,) or scalar]
            The F0 metric.

        """
        return self._forward(x, y, *self.values)

    @staticmethod
    def _func(x: torch.Tensor, y: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Functional entry point: validate the options, then evaluate."""
        values = F0Evaluation._precompute(*args, **kwargs)
        return F0Evaluation._forward(x, y, *values)

    @staticmethod
    def _takes_input_size() -> bool:
        return False

    @staticmethod
    def _check(
        reduction: "str | None" = None, out_format: "str | None" = None
    ) -> None:
        """Reject unsupported options before any tensor work is done.

        Both arguments are optional so that a bare ``_check()`` call keeps
        working; an argument left as ``None`` is simply not validated.

        Raises
        ------
        ValueError
            If a given option is not supported.

        """
        if reduction is not None and reduction not in F0Evaluation._REDUCTIONS:
            raise ValueError(f"reduction {reduction} is not supported.")
        if out_format is not None and out_format not in F0Evaluation._OUT_FORMATS:
            raise ValueError(f"out_format {out_format} is not supported.")

    @staticmethod
    def _precompute(reduction: str, out_format: str) -> Precomputed:
        """Validate the options and pack them in the order ``_forward`` expects."""
        F0Evaluation._check(reduction, out_format)
        return (reduction, out_format)

    @staticmethod
    def _forward(
        x: torch.Tensor, y: torch.Tensor, reduction: str, out_format: str
    ) -> torch.Tensor:
        """Compute the requested metric from the input and target F0.

        Parameters
        ----------
        x : Tensor [shape=(..., N)]
            The input F0 in Hz.

        y : Tensor [shape=(..., N)]
            The target F0 in Hz.

        reduction : ['none', 'mean', 'sum']
            The reduction type.

        out_format : str
            The output format.

        Returns
        -------
        out : Tensor [shape=(...,) or scalar]
            The F0 metric.

        Raises
        ------
        ValueError
            If ``reduction`` or ``out_format`` is not supported.

        """
        x_is_voiced = x != UNVOICED_SYMBOL
        y_is_voiced = y != UNVOICED_SYMBOL

        if out_format.startswith("f0-rmse"):
            # Multiplier applied to log2(F0); None means "compare in plain Hz".
            if out_format == "f0-rmse-hz":
                log2_scale = None
            elif out_format == "f0-rmse-cent":
                log2_scale = 1200
            elif out_format == "f0-rmse-semitone":
                log2_scale = 12
            else:
                raise ValueError(f"out_format {out_format} is not supported.")

            # Only frames voiced in BOTH sequences are compared. Boolean-mask
            # indexing flattens the selected frames of every leading dimension
            # into a single vector.
            voiced = x_is_voiced & y_is_voiced
            x_voiced = x[voiced]
            y_voiced = y[voiced]
            if log2_scale is not None:
                x_voiced = log2_scale * torch.log2(x_voiced)
                y_voiced = log2_scale * torch.log2(y_voiced)
            out = RootMeanSquareError._func(x_voiced, y_voiced, "none")
        else:
            # Validate first so that an unsupported format does not pay for
            # the reductions below.
            if out_format not in (
                "vuv-error-rate",
                "vuv-error-percent",
                "vuv-macro-f1-score",
            ):
                raise ValueError(f"out_format {out_format} is not supported.")

            # Frames whose voiced/unvoiced decisions disagree. False positives
            # and false negatives are only ever used as a sum, so they are
            # counted together.
            n_mismatch = torch.sum(x_is_voiced != y_is_voiced, dim=-1)
            n_frame = x.shape[-1]

            if out_format == "vuv-error-rate":
                out = n_mismatch / n_frame
            elif out_format == "vuv-error-percent":
                out = 100 * n_mismatch / n_frame
            else:
                n_both_voiced = torch.sum(x_is_voiced & y_is_voiced, dim=-1)
                n_both_unvoiced = torch.sum(~x_is_voiced & ~y_is_voiced, dim=-1)
                # A class that never occurs gives 0 / 0; nan_to_num maps the
                # resulting NaN to an F1 score of zero.
                f1_voiced = torch.nan_to_num(
                    (2 * n_both_voiced) / (2 * n_both_voiced + n_mismatch)
                )
                f1_unvoiced = torch.nan_to_num(
                    (2 * n_both_unvoiced) / (2 * n_both_unvoiced + n_mismatch)
                )
                out = (f1_voiced + f1_unvoiced) / 2

        if reduction == "none":
            pass
        elif reduction == "sum":
            out = out.sum()
        elif reduction == "mean":
            out = out.mean()
        else:
            raise ValueError(f"reduction {reduction} is not supported.")

        return out

docs/source/modules/f0eval.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
.. _f0eval:
2+
3+
f0eval
4+
======
5+
6+
.. autoclass:: diffsptk.F0Evaluation
7+
:members:
8+
9+
.. autofunction:: diffsptk.functional.f0eval
10+
11+
.. seealso::
12+
13+
:ref:`rmse`

docs/source/modules/rmse.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ rmse
1212

1313
.. seealso::
1414

15-
:ref:`snr`
15+
:ref:`snr` :ref:`f0eval`

tests/test_entropy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# ------------------------------------------------------------------------ #
1616

1717
import pytest
18-
import torch
1918

2019
import diffsptk
2120
import tests.utils as U
@@ -42,4 +41,4 @@ def test_compatibility(device, dtype, module, out_format, L=5, B=2):
4241
dx=L,
4342
)
4443

45-
U.check_differentiability(device, dtype, [entropy, torch.abs], [B, L])
44+
U.check_differentiability(device, dtype, entropy, [B, L], nonnegative_input=True)

tests/test_f0eval.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# ------------------------------------------------------------------------ #
2+
# Copyright 2022 SPTK Working Group #
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License"); #
5+
# you may not use this file except in compliance with the License. #
6+
# You may obtain a copy of the License at #
7+
# #
8+
# http://www.apache.org/licenses/LICENSE-2.0 #
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software #
11+
# distributed under the License is distributed on an "AS IS" BASIS, #
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
13+
# See the License for the specific language governing permissions and #
14+
# limitations under the License. #
15+
# ------------------------------------------------------------------------ #
16+
17+
import pytest
18+
import torch
19+
20+
import diffsptk
21+
import tests.utils as U
22+
23+
24+
@pytest.mark.parametrize("module", [False, True])
@pytest.mark.parametrize(
    "out_format",
    [
        "f0-rmse-hz",
        "f0-rmse-cent",
        "f0-rmse-semitone",
        "vuv-error-rate",
        "vuv-error-percent",
    ],
)
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
def test_compatibility(device, dtype, module, reduction, out_format, B=2, L=10):
    """Compare F0 evaluation with the command-line tools and check gradients."""
    is_rmse = out_format.startswith("f0-rmse")

    f0eval = U.choice(
        module,
        diffsptk.F0Evaluation,
        diffsptk.functional.f0eval,
        {"reduction": reduction, "out_format": out_format},
    )

    # Scratch files holding the input and the target F0 sequences.
    tmp1 = "f0eval.tmp1"
    tmp2 = "f0eval.tmp2"

    if out_format == "f0-rmse-hz":
        # Plain RMSE in Hz; `-magic 0` is given so that 0-valued frames are
        # treated specially (presumably skipped -- confirm against SPTK docs).
        cmd = f"rmse -magic 0 {tmp1} {tmp2}"
    else:
        o = 1 if is_rmse else 2
        # The reference output is rescaled by 0.01 for the semitone and the
        # error-rate formats (presumably it is in cents / percent).
        mul = 0.01 if out_format in ("f0-rmse-semitone", "vuv-error-rate") else 1
        cmd = f"f0eval -q 1 -o {o} {tmp1} {tmp2} | sopr -m {mul}"

    setup_cmds = [
        f"echo 0 0 200 210 0 200 0 | x2x +ad > {tmp1}",
        f"echo 0 0 190 180 180 0 0 | x2x +ad > {tmp2}",
    ]
    input_cmds = [f"cat {tmp1}", f"cat {tmp2}"]
    teardown_cmds = [f"rm {tmp1} {tmp2}"]

    U.check_compatibility(
        device,
        dtype,
        f0eval,
        setup_cmds,
        input_cmds,
        cmd,
        teardown_cmds,
    )

    # Only the RMSE formats are checked for differentiability.
    if is_rmse:
        U.check_differentiability(
            device, dtype, f0eval, [(B, L), (B, L)], nonnegative_input=True
        )
70+
71+
72+
def test_f1_score():
    """Check the macro F1 score of the V/UV decision on a hand-made example."""
    # Both voiced in 2 frames, both unvoiced in 3 frames, 4 mismatched frames:
    # F1(voiced) = 4 / 8 = 0.5, F1(unvoiced) = 6 / 10 = 0.6, macro mean = 0.55.
    x = torch.tensor([0, 1, 1, 0, 0, 1, 0, 1, 0])
    y = torch.tensor([0, 1, 0, 0, 1, 0, 0, 1, 1])
    expected = torch.tensor(0.55)

    f0eval = diffsptk.F0Evaluation(out_format="vuv-macro-f1-score")
    f1_score = f0eval(x, y)

    assert U.allclose(f1_score, expected)

tests/test_ifbank.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# ------------------------------------------------------------------------ #
1616

1717
import pytest
18-
import torch
1918

2019
import diffsptk
2120
import tests.utils as U
@@ -70,7 +69,7 @@ def test_compatibility(
7069
)
7170

7271
U.check_differentiability(
73-
device, dtype, [ifbank, fbank, torch.abs], [B, L // 2 + 1]
72+
device, dtype, [ifbank, fbank], [B, L // 2 + 1], nonnegative_input=True
7473
)
7574

7675

tests/test_ignorm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# ------------------------------------------------------------------------ #
1616

1717
import pytest
18-
import torch
1918

2019
import diffsptk
2120
import tests.utils as U
@@ -44,4 +43,4 @@ def test_compatibility(device, dtype, module, gamma, c, M=4, B=2):
4443
dy=M + 1,
4544
)
4645

47-
U.check_differentiability(device, dtype, [ignorm, torch.abs], [B, M + 1])
46+
U.check_differentiability(device, dtype, ignorm, [B, M + 1], nonnegative_input=True)

tests/test_lsp2sp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,6 @@ def test_compatibility(device, dtype, module, M, out_format, L=16, B=2):
5050
dy=L // 2 + 1,
5151
)
5252

53-
U.check_differentiability(device, dtype, [lsp2sp, torch.abs], [B, M + 1])
53+
U.check_differentiability(
54+
device, dtype, [lsp2sp, lambda x: torch.sort(x)[0], torch.abs], [B, M + 1]
55+
)

0 commit comments

Comments
 (0)