Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions diffsptk/modules/drc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

import numpy as np
import torch
import torchcomp
from torch import nn

from ..typing import Precomputed
from ..typing import Callable, Precomputed
from ..utils.private import filter_values, to, to_2d
from .base import BaseFunctionalModule

Expand Down Expand Up @@ -161,6 +160,8 @@ def _precompute(
device: torch.device | None,
dtype: torch.dtype | None,
) -> Precomputed:
import torchcomp

DynamicRangeCompression._check(
ratio, attack_time, release_time, sample_rate, makeup_gain, abs_max
)
Expand All @@ -179,20 +180,21 @@ def _precompute(
makeup_gain = to(torch.tensor(makeup_gain, device=device), dtype=dtype)
makeup_gain = 10 ** (makeup_gain / 20)
params = torch.stack([threshold, ratio, attack_time, release_time, makeup_gain])
return (abs_max,), None, (params,)
return (abs_max, torchcomp.compexp_gain), None, (params,)

@staticmethod
def _forward(
x: torch.Tensor,
abs_max: float,
compexp_gain: Callable,
params: torch.Tensor,
) -> torch.Tensor:
eps = 1e-10

y = to_2d(x)
y_abs = y.abs() / abs_max + eps

g = torchcomp.compexp_gain(
g = compexp_gain(
y_abs,
params[0],
params[1],
Expand Down
11 changes: 6 additions & 5 deletions diffsptk/modules/griffin.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class GriffinLim(BaseFunctionalModule):
symmetric : bool
If True, the window is symmetric, otherwise periodic.

n_iter : int >= 1
n_iter : int >= 0
The number of iterations for phase reconstruction.

alpha : float >= 0
Expand Down Expand Up @@ -159,8 +159,8 @@ def _check(
beta: float,
gamma: float,
) -> None:
if n_iter <= 0:
raise ValueError("n_iter must be positive.")
if n_iter < 0:
raise ValueError("n_iter must be non-negative.")
if alpha < 0:
raise ValueError("alpha must be non-negative.")
if beta < 0:
Expand Down Expand Up @@ -259,7 +259,8 @@ def _forward(
if logger is not None:
logger.info(f"alpha: {alpha}, beta: {beta}, gamma: {gamma}")

s = torch.sqrt(y)
eps = 1e-16
s = torch.sqrt(y + eps)
angle = torch.exp(1j * phase_generator(s))

t_prev = d_prev = 0 # This suppresses F821 and F841.
Expand All @@ -275,7 +276,7 @@ def _forward(
c = t + alpha * diff
d = t + beta * diff

angle = c / (c.abs() + 1e-16)
angle = c / (c.abs() + eps)
t_prev = t
d_prev = d

Expand Down
13 changes: 9 additions & 4 deletions diffsptk/modules/poledf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
# ------------------------------------------------------------------------ #

import torch
from torchlpc import sample_wise_lpc

from ..typing import Precomputed
from ..typing import Callable, Precomputed
from ..utils.private import check_size, filter_values
from .base import BaseFunctionalModule
from .linear_intpl import LinearInterpolation
Expand Down Expand Up @@ -105,11 +104,17 @@ def _precompute(
filter_order: int, frame_period: int, ignore_gain: bool = False
) -> Precomputed:
AllPoleDigitalFilter._check(filter_order, frame_period)
return (frame_period, ignore_gain)
from torchlpc import sample_wise_lpc

return (frame_period, ignore_gain, sample_wise_lpc)

@staticmethod
def _forward(
x: torch.Tensor, a: torch.Tensor, frame_period: int, ignore_gain: bool
x: torch.Tensor,
a: torch.Tensor,
frame_period: int,
ignore_gain: bool,
sample_wise_lpc: Callable,
) -> torch.Tensor:
check_size(x.size(-1), a.size(-2) * frame_period, "sequence length")

Expand Down
2 changes: 1 addition & 1 deletion diffsptk/utils/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,5 @@ def write(
>>> diffsptk.write("out.wav", x, sr)

"""
x = x.cpu().numpy() if torch.is_tensor(x) else x
x = x.detach().cpu().numpy() if torch.is_tensor(x) else x
sf.write(filename, x.T if channel_first else x, sr, **kwargs)