Skip to content

Commit 33fb527

Browse files
authored
Merge pull request #143 from sp-nitech/griffin
Maintenance
2 parents 04e974b + 67fc0f6 commit 33fb527

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

diffsptk/modules/drc.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616

1717
import numpy as np
1818
import torch
19-
import torchcomp
2019
from torch import nn
2120

22-
from ..typing import Precomputed
21+
from ..typing import Callable, Precomputed
2322
from ..utils.private import filter_values, to, to_2d
2423
from .base import BaseFunctionalModule
2524

@@ -161,6 +160,8 @@ def _precompute(
161160
device: torch.device | None,
162161
dtype: torch.dtype | None,
163162
) -> Precomputed:
163+
import torchcomp
164+
164165
DynamicRangeCompression._check(
165166
ratio, attack_time, release_time, sample_rate, makeup_gain, abs_max
166167
)
@@ -179,20 +180,21 @@ def _precompute(
179180
makeup_gain = to(torch.tensor(makeup_gain, device=device), dtype=dtype)
180181
makeup_gain = 10 ** (makeup_gain / 20)
181182
params = torch.stack([threshold, ratio, attack_time, release_time, makeup_gain])
182-
return (abs_max,), None, (params,)
183+
return (abs_max, torchcomp.compexp_gain), None, (params,)
183184

184185
@staticmethod
185186
def _forward(
186187
x: torch.Tensor,
187188
abs_max: float,
189+
compexp_gain: Callable,
188190
params: torch.Tensor,
189191
) -> torch.Tensor:
190192
eps = 1e-10
191193

192194
y = to_2d(x)
193195
y_abs = y.abs() / abs_max + eps
194196

195-
g = torchcomp.compexp_gain(
197+
g = compexp_gain(
196198
y_abs,
197199
params[0],
198200
params[1],

diffsptk/modules/griffin.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class GriffinLim(BaseFunctionalModule):
5454
symmetric : bool
5555
If True, the window is symmetric, otherwise periodic.
5656
57-
n_iter : int >= 1
57+
n_iter : int >= 0
5858
The number of iterations for phase reconstruction.
5959
6060
alpha : float >= 0
@@ -159,8 +159,8 @@ def _check(
159159
beta: float,
160160
gamma: float,
161161
) -> None:
162-
if n_iter <= 0:
163-
raise ValueError("n_iter must be positive.")
162+
if n_iter < 0:
163+
raise ValueError("n_iter must be non-negative.")
164164
if alpha < 0:
165165
raise ValueError("alpha must be non-negative.")
166166
if beta < 0:
@@ -259,7 +259,8 @@ def _forward(
259259
if logger is not None:
260260
logger.info(f"alpha: {alpha}, beta: {beta}, gamma: {gamma}")
261261

262-
s = torch.sqrt(y)
262+
eps = 1e-16
263+
s = torch.sqrt(y + eps)
263264
angle = torch.exp(1j * phase_generator(s))
264265

265266
t_prev = d_prev = 0 # This suppresses F821 and F841.
@@ -275,7 +276,7 @@ def _forward(
275276
c = t + alpha * diff
276277
d = t + beta * diff
277278

278-
angle = c / (c.abs() + 1e-16)
279+
angle = c / (c.abs() + eps)
279280
t_prev = t
280281
d_prev = d
281282

diffsptk/modules/poledf.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
# ------------------------------------------------------------------------ #
1616

1717
import torch
18-
from torchlpc import sample_wise_lpc
1918

20-
from ..typing import Precomputed
19+
from ..typing import Callable, Precomputed
2120
from ..utils.private import check_size, filter_values
2221
from .base import BaseFunctionalModule
2322
from .linear_intpl import LinearInterpolation
@@ -105,11 +104,17 @@ def _precompute(
105104
filter_order: int, frame_period: int, ignore_gain: bool = False
106105
) -> Precomputed:
107106
AllPoleDigitalFilter._check(filter_order, frame_period)
108-
return (frame_period, ignore_gain)
107+
from torchlpc import sample_wise_lpc
108+
109+
return (frame_period, ignore_gain, sample_wise_lpc)
109110

110111
@staticmethod
111112
def _forward(
112-
x: torch.Tensor, a: torch.Tensor, frame_period: int, ignore_gain: bool
113+
x: torch.Tensor,
114+
a: torch.Tensor,
115+
frame_period: int,
116+
ignore_gain: bool,
117+
sample_wise_lpc: Callable,
113118
) -> torch.Tensor:
114119
check_size(x.size(-1), a.size(-2) * frame_period, "sequence length")
115120

diffsptk/utils/public.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,5 +190,5 @@ def write(
190190
>>> diffsptk.write("out.wav", x, sr)
191191
192192
"""
193-
x = x.cpu().numpy() if torch.is_tensor(x) else x
193+
x = x.detach().cpu().numpy() if torch.is_tensor(x) else x
194194
sf.write(filename, x.T if channel_first else x, sr, **kwargs)

0 commit comments

Comments
 (0)