Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -445,39 +445,39 @@ kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908
kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4657920002937317,0.4644480049610138,0.4670400023460388,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,8192,0.9084159731864929,0.9064639806747437,0.9099519848823547,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,16384,9.939423561096191,9.933785438537598,9.945216178894043,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,32768,20.06915283203125,20.05768394470215,20.087200164794922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,65536,38.88547134399414,38.880577087402344,38.89036560058594,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,131072,77.7418212890625,77.7418212890625,77.7418212890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,4096,2.1717119216918945,2.1697471141815186,2.173452854156494,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,8192,4.2592315673828125,4.255411148071289,4.2608771324157715,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,16384,8.363903999328613,8.359071731567383,8.36620807647705,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,32768,16.591264724731445,16.588390350341797,16.595033645629883,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,65536,33.06208038330078,33.06206130981445,33.06536102294922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,131072,66.0923843383789,66.0923843383789,66.0923843383789,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,liger,full,speed,ms,V,vocab size,4096,1.5683839321136475,1.4662528038024902,1.7244799137115479,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,8192,2.0588159561157227,2.055116891860962,2.093465566635132,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,16384,11.944751739501953,11.936684608459473,11.961983680725098,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,32768,24.27791976928711,24.254375457763672,24.299558639526367,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,65536,47.206687927246094,47.17191696166992,47.241458892822266,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,131072,94.15420532226562,94.15420532226562,94.15420532226562,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,torch,full,speed,ms,V,vocab size,4096,4.875328063964844,4.873446464538574,4.878073692321777,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,8192,9.582816123962402,9.57910442352295,9.58505630493164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,16384,18.931264877319336,18.92802619934082,18.934911727905273,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,32768,38.07579040527344,38.07549285888672,38.076087951660156,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,65536,75.97628784179688,75.97628784179688,75.97628784179688,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,131072,151.8501739501953,151.8501739501953,151.8501739501953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:31,0.3.1
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:33,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4651840031147003,0.4636736214160919,0.4659839868545532,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,8192,0.927888035774231,0.926751971244812,0.92952960729599,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,16384,10.96003246307373,10.942886352539062,10.970770835876465,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,32768,22.405792236328125,22.390380859375,22.41998863220215,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,65536,43.49095916748047,43.47438049316406,43.50754165649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,liger,forward,speed,ms,V,vocab size,131072,87.0363540649414,87.0363540649414,87.0363540649414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:37,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,4096,2.4744958877563477,2.4725184440612793,2.4764864444732666,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,8192,4.8528642654418945,4.851238250732422,4.854745864868164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,16384,9.532496452331543,9.528634071350098,9.535890579223633,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,32768,18.91379165649414,18.911853790283203,18.919116973876953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,65536,37.70152282714844,37.70074462890625,37.70229721069336,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,torch,forward,speed,ms,V,vocab size,131072,75.37680053710938,75.37680053710938,75.37680053710938,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:38,0.3.1
jsd,liger,full,speed,ms,V,vocab size,4096,1.2074079513549805,1.1739968061447144,1.2760319709777832,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,8192,2.091792106628418,2.0771327018737793,2.106553554534912,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,16384,12.928031921386719,12.8988676071167,12.936230659484863,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,32768,26.55548858642578,26.550823211669922,26.570655822753906,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,65536,51.6833610534668,51.6833610534668,51.6833610534668,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,liger,full,speed,ms,V,vocab size,131072,103.12793731689453,103.12793731689453,103.12793731689453,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:40,0.3.1
jsd,torch,full,speed,ms,V,vocab size,4096,5.397359848022461,5.392876625061035,5.39998722076416,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,8192,10.60153579711914,10.597900390625,10.60470962524414,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,16384,20.9442081451416,20.94247055053711,20.9469051361084,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,32768,42.113216400146484,42.113216400146484,42.113216400146484,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,65536,83.9959716796875,83.9959716796875,83.9959716796875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
jsd,torch,full,speed,ms,V,vocab size,131072,167.94175720214844,167.94175720214844,167.94175720214844,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-02 16:21:43,0.3.1
24 changes: 16 additions & 8 deletions benchmark/scripts/benchmark_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,25 @@
from liger_kernel.transformers.jsd import LigerJSD


class TorchJSD(torch.nn.Module):
def __init__(self):
class TorchJSD(nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
super(TorchJSD, self).__init__()
self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

def forward(self, log_p: torch.tensor, log_q: torch.tensor):
self.beta = beta
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = 0.5 * (torch.exp(log_p) + torch.exp(log_q))
log_m = torch.log(m)
loss = 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))
return loss
m = self.beta * torch.exp(log_p) + (1 - self.beta) * torch.exp(log_q)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
return loss.to(self.dtype)


def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
Expand Down
19 changes: 12 additions & 7 deletions src/liger_kernel/ops/jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def _jsd_kernel(
loss_stride,
dX_ptr,
dX_stride,
beta,
n_rows,
n_cols,
BLOCK_SIZE: tl.constexpr,
Expand All @@ -37,20 +38,20 @@ def _jsd_kernel(

Q = tl.exp(X)
P = tl.exp(Y)
M = 0.5 * P + 0.5 * Q
M = beta * P + (1 - beta) * Q
log_M = tl.log(M)

loss = 0.5 * (P * Y + Q * X - 2 * M * log_M)
loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
tl.store(loss_ptr + offsets, loss, mask=mask)

dX = 0.5 * Q * (X - log_M) / n_rows
dX = (1 - beta) * Q * (X - log_M) / n_rows
tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 65536


def jsd_forward(_input, target):
def jsd_forward(_input, target, beta):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
Expand All @@ -67,6 +68,7 @@ def jsd_forward(_input, target):
loss_stride=loss.stride(-2),
dX_ptr=dX,
dX_stride=dX.stride(-2),
beta=beta,
n_rows=n_rows,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
Expand All @@ -77,7 +79,7 @@ def jsd_forward(_input, target):


def jsd_backward(dX, grad_output):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
# If jsd is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return dX
else:
Expand All @@ -90,7 +92,8 @@ class LigerJSDFunction(torch.autograd.Function):

Parameters:
_input (tensor): predict values with shape (BT, V) in logspace
target (tensor): gournd truth values with shape (BT, V) in logspace
target (tensor): ground truth values with shape (BT, V) in logspace
beta (float): coefficient beta of generalized JSD in the open interval (0, 1)

Returns:
loss (tensor): JSD
Expand All @@ -102,9 +105,10 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
beta: float = 0.5,
) -> torch.Tensor:

loss, dX = jsd_forward(_input, target)
loss, dX = jsd_forward(_input, target, beta)
ctx.save_for_backward(dX)
return loss

Expand All @@ -116,4 +120,5 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
return (
dX,
None,
None,
)
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
LigerFusedLinearCrossEntropyFunction,
)
from liger_kernel.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.jsd import LigerJSDFunction
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
Expand All @@ -17,3 +18,4 @@
liger_rope = LigerRopeFunction.apply
liger_layer_norm = LigerLayerNormFunction.apply
liger_kl_div = LigerKLDivLossFunction.apply
liger_jsd = LigerJSDFunction.apply
10 changes: 7 additions & 3 deletions src/liger_kernel/transformers/jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@


class LigerJSD(nn.Module):
def __init__(self):
def __init__(self, beta=0.5):
super().__init__()
assert (
beta > 0 and beta < 1
), f"beta must be greater than 0 and less than 1. Got: {beta}"
self.beta = beta

def forward(self, p, q):
return LigerJSDFunction.apply(p, q)
def forward(self, log_q, log_p):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the correct order of input and target (student and teacher) respectively. would it be too confusing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, the name is a bit confusing, or we can add some descriptions here to clarify

return LigerJSDFunction.apply(log_q, log_p, self.beta)
Loading
Loading