36 changes: 36 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -445,3 +445,39 @@ kl_div,torch,full,speed,ms,V,vocab size,16384,11.124671936035156,11.122162818908
kl_div,torch,full,speed,ms,V,vocab size,32768,23.052032470703125,23.050334930419922,23.052589416503906,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,65536,46.063167572021484,46.05990219116211,46.06643295288086,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
kl_div,torch,full,speed,ms,V,vocab size,131072,92.06393432617188,92.06393432617188,92.06393432617188,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-04 12:59:48,0.2.1
jsd,liger,full,memory,MB,V,vocab size,4096,768.0029296875,768.0029296875,768.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,8192,1536.0029296875,1536.0029296875,1536.0029296875,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,16384,3072.0048828125,3072.0048828125,3072.0048828125,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,32768,6144.0087890625,6144.0087890625,6144.0087890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,65536,12288.0166015625,12288.0166015625,12288.0166015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,liger,full,memory,MB,V,vocab size,131072,24576.015625,24576.015625,24576.015625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:17,0.3.0
jsd,torch,full,memory,MB,V,vocab size,4096,1664.0009765625,1664.0009765625,1664.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,8192,3328.0009765625,3328.0009765625,3328.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,16384,6656.0009765625,6656.0009765625,6656.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,32768,13312.0009765625,13312.0009765625,13312.0009765625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,65536,26624.0,26624.0,26624.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,torch,full,memory,MB,V,vocab size,131072,53248.0,53248.0,53248.0,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:19,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,4096,0.4657920002937317,0.4644480049610138,0.4670400023460388,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,8192,0.9084159731864929,0.9064639806747437,0.9099519848823547,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,16384,9.939423561096191,9.933785438537598,9.945216178894043,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,32768,20.06915283203125,20.05768394470215,20.087200164794922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,65536,38.88547134399414,38.880577087402344,38.89036560058594,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,liger,forward,speed,ms,V,vocab size,131072,77.7418212890625,77.7418212890625,77.7418212890625,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:22,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,4096,2.1717119216918945,2.1697471141815186,2.173452854156494,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,8192,4.2592315673828125,4.255411148071289,4.2608771324157715,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,16384,8.363903999328613,8.359071731567383,8.36620807647705,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,32768,16.591264724731445,16.588390350341797,16.595033645629883,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,65536,33.06208038330078,33.06206130981445,33.06536102294922,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,torch,forward,speed,ms,V,vocab size,131072,66.0923843383789,66.0923843383789,66.0923843383789,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:23,0.3.0
jsd,liger,full,speed,ms,V,vocab size,4096,1.5683839321136475,1.4662528038024902,1.7244799137115479,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,8192,2.0588159561157227,2.055116891860962,2.093465566635132,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,16384,11.944751739501953,11.936684608459473,11.961983680725098,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,32768,24.27791976928711,24.254375457763672,24.299558639526367,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,65536,47.206687927246094,47.17191696166992,47.241458892822266,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,liger,full,speed,ms,V,vocab size,131072,94.15420532226562,94.15420532226562,94.15420532226562,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:25,0.3.0
jsd,torch,full,speed,ms,V,vocab size,4096,4.875328063964844,4.873446464538574,4.878073692321777,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,8192,9.582816123962402,9.57910442352295,9.58505630493164,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,16384,18.931264877319336,18.92802619934082,18.934911727905273,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,32768,38.07579040527344,38.07549285888672,38.076087951660156,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,65536,75.97628784179688,75.97628784179688,75.97628784179688,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
jsd,torch,full,speed,ms,V,vocab size,131072,151.8501739501953,151.8501739501953,151.8501739501953,"{""B"": 4, ""T"": 2048}",NVIDIA H100 PCIe,2024-09-25 20:52:28,0.3.0
130 changes: 130 additions & 0 deletions benchmark/scripts/benchmark_jsd.py
@@ -0,0 +1,130 @@
import torch
import torch.nn as nn
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.transformers.jsd import LigerJSD


class TorchJSD(torch.nn.Module):
def __init__(self):
super(TorchJSD, self).__init__()
self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, log_p: torch.Tensor, log_q: torch.Tensor):
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = 0.5 * (torch.exp(log_p) + torch.exp(log_q))
log_m = torch.log(m)
loss = 0.5 * (self.kl(log_m, log_p) + self.kl(log_m, log_q))
return loss


def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
torch_jsd = TorchJSD()
liger_jsd = LigerJSD()

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
dim=-1
)
target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
return liger_jsd(_input, target)
else:
return torch_jsd(_input, target)

if input.kernel_operation_mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif input.kernel_operation_mode == "backward":
y = fwd()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
elif input.kernel_operation_mode == "full":

def full():
y = fwd()
y.backward(retain_graph=True)

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full, quantiles=QUANTILES, rep=100
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
torch_jsd = TorchJSD()
liger_jsd = LigerJSD()

V = input.x
B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

_input = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
dim=-1
)
target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)

def fwd():
if input.kernel_provider == "liger":
return liger_jsd(_input, target)
else:
return torch_jsd(_input, target)

def full():
y = fwd()
y.backward(retain_graph=True)

mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()
common_args = {
"kernel_name": "jsd",
"x_name": "V",
"x_label": "vocab size",
"x_values": [2**i for i in range(12, 18)],
"kernel_providers": ["liger", "torch"],
"extra_benchmark_configs": [{"B": 4, "T": 2048}],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_memory_jsd,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_args,
)

run_benchmarks(
bench_test_fn=bench_speed_jsd,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_args,
)
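
Before comparing timings, a quick numerical agreement check between the two providers is worthwhile. A minimal sketch, assuming a CUDA device and reusing `TorchJSD` from the script above (shapes are illustrative; the tolerances are a reasonable guess for fp32):

```python
import torch

from liger_kernel.transformers.jsd import LigerJSD

torch_jsd = TorchJSD()  # reference module defined in benchmark_jsd.py above
liger_jsd = LigerJSD()

B, T, V = 4, 2048, 4096
_input = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)  # prediction, log Q
target = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)  # ground truth, log P

# Both modules take (prediction, target) as log-probabilities and reduce with "batchmean"
assert torch.allclose(
    liger_jsd(_input, target), torch_jsd(_input, target), rtol=1e-4, atol=1e-4
)
```

Running the script itself produces the rows added to `benchmark/data/all_benchmark_data.csv` above via `run_benchmarks`.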
119 changes: 119 additions & 0 deletions src/liger_kernel/ops/jsd.py
@@ -0,0 +1,119 @@
import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def _jsd_kernel(
X_ptr, # input in logspace, X = log Q
X_stride,
Y_ptr, # ground truth in logspace, Y = log P
Y_stride,
loss_ptr,
loss_stride,
dX_ptr,
dX_stride,
n_rows,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
    # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, where M = (1/2) * (P + Q) = (1/2) * (e^Y + e^X)
    #             = sum(P * log P + Q * log Q - 2 * M * log M) / 2
    #             = sum(e^Y * Y + e^X * X - 2 * M * log M) / 2
    # grad_X_i = 0.5 * Q_i * (X_i - log M_i); the kernel also folds in 1/n_rows for "batchmean"
pid = tl.program_id(0).to(tl.int64)
X_ptr += pid * X_stride
dX_ptr += pid * dX_stride
Y_ptr += pid * Y_stride
loss_ptr += pid * loss_stride

for i in range(0, n_cols, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_cols
X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

Q = tl.exp(X)
P = tl.exp(Y)
M = 0.5 * P + 0.5 * Q
log_M = tl.log(M)

loss = 0.5 * (P * Y + Q * X - 2 * M * log_M)
tl.store(loss_ptr + offsets, loss, mask=mask)

dX = 0.5 * Q * (X - log_M) / n_rows
tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 65536


def jsd_forward(_input, target):
BT, V = _input.shape
n_rows = BT
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # unreduced, per-element loss; reduced to a "batchmean" scalar below
loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
dX = torch.empty_like(_input)

_jsd_kernel[(n_rows,)](
X_ptr=_input, # input in logspace, X = log Q
X_stride=_input.stride(-2),
Y_ptr=target, # ground truth in logspace, Y = log P
Y_stride=target.stride(-2),
loss_ptr=loss,
loss_stride=loss.stride(-2),
dX_ptr=dX,
dX_stride=dX.stride(-2),
n_rows=n_rows,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
)
# reduction == "batchmean"
loss = torch.sum(loss) / n_rows
return loss.to(_input.dtype), dX


def jsd_backward(dX, grad_output):
    # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
return dX
else:
return grad_output * dX


class LigerJSDFunction(torch.autograd.Function):
"""
Class implementing the forward and backward pass for the JS Divergence using Triton, as defined by the following formula:

Parameters:
_input (tensor): predict values with shape (BT, V) in logspace
target (tensor): gournd truth values with shape (BT, V) in logspace

Returns:
loss (tensor): JSD
"""

@staticmethod
@ensure_contiguous
def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:

loss, dX = jsd_forward(_input, target)
ctx.save_for_backward(dX)
return loss

@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
(dX,) = ctx.saved_tensors
dX = jsd_backward(dX, grad_output)
return (
dX,
None,
)
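
For reference, the closed form behind the comments in `_jsd_kernel`, written out: with $P = e^Y$, $Q = e^X$, and $M = \tfrac{1}{2}(P + Q)$,

$$
\mathrm{JSD}(P \,\|\, Q) = \tfrac{1}{2}\big(\mathrm{KL}(P \,\|\, M) + \mathrm{KL}(Q \,\|\, M)\big) = \tfrac{1}{2} \sum_i \big(P_i Y_i + Q_i X_i - 2 M_i \log M_i\big).
$$

Differentiating with respect to $X_i$ (so $\partial Q_i / \partial X_i = Q_i$) gives $\partial (Q_i X_i) / \partial X_i = Q_i (X_i + 1)$ and $\partial (2 M_i \log M_i) / \partial X_i = Q_i (\log M_i + 1)$, hence

$$
\frac{\partial\, \mathrm{JSD}}{\partial X_i} = \tfrac{1}{2}\, Q_i \big(X_i - \log M_i\big),
$$

which is the `dX` the kernel stores, scaled by $1 / n_{\text{rows}}$ for the `batchmean` reduction.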
11 changes: 11 additions & 0 deletions src/liger_kernel/transformers/jsd.py
@@ -0,0 +1,11 @@
import torch.nn as nn

from liger_kernel.ops.jsd import LigerJSDFunction


class LigerJSD(nn.Module):
    """Jensen-Shannon Divergence with "batchmean" reduction.

    Both arguments are log-probabilities: the first is the prediction
    (X = log Q), the second the ground truth (Y = log P), matching
    LigerJSDFunction.
    """

    def forward(self, log_q, log_p):
        return LigerJSDFunction.apply(log_q, log_p)
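
A minimal usage sketch (assuming a CUDA device; both inputs must already be log-probabilities, e.g. via `log_softmax`, with shapes as in the benchmark script):

```python
import torch

from liger_kernel.transformers.jsd import LigerJSD

jsd = LigerJSD()

B, T, V = 4, 2048, 4096
log_q = torch.randn(B * T, V, device="cuda", requires_grad=True).log_softmax(dim=-1)  # prediction
log_p = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)  # ground truth

loss = jsd(log_q, log_p)  # scalar; "batchmean" reduction over the B * T rows
loss.backward()
```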