1616
1717import numpy as np
1818import torch
19- import torchcomp
2019from torch import nn
2120
22- from ..typing import Precomputed
21+ from ..typing import Callable , Precomputed
2322from ..utils .private import filter_values , to , to_2d
2423from .base import BaseFunctionalModule
2524
@@ -161,6 +160,8 @@ def _precompute(
161160 device : torch .device | None ,
162161 dtype : torch .dtype | None ,
163162 ) -> Precomputed :
163+ import torchcomp
164+
164165 DynamicRangeCompression ._check (
165166 ratio , attack_time , release_time , sample_rate , makeup_gain , abs_max
166167 )
@@ -179,20 +180,21 @@ def _precompute(
179180 makeup_gain = to (torch .tensor (makeup_gain , device = device ), dtype = dtype )
180181 makeup_gain = 10 ** (makeup_gain / 20 )
181182 params = torch .stack ([threshold , ratio , attack_time , release_time , makeup_gain ])
182- return (abs_max ,), None , (params ,)
183+ return (abs_max , torchcomp . compexp_gain ), None , (params ,)
183184
184185 @staticmethod
185186 def _forward (
186187 x : torch .Tensor ,
187188 abs_max : float ,
189+ compexp_gain : Callable ,
188190 params : torch .Tensor ,
189191 ) -> torch .Tensor :
190192 eps = 1e-10
191193
192194 y = to_2d (x )
193195 y_abs = y .abs () / abs_max + eps
194196
195- g = torchcomp . compexp_gain (
197+ g = compexp_gain (
196198 y_abs ,
197199 params [0 ],
198200 params [1 ],
0 commit comments