Efficient small-scale H-Net implementation. See details here.
In an external project:

```bash
src=git+https://github.com/main-horse/hnet-impl
uv add "hnet-impl @ $src"
uv add "hnet-impl[build] @ $src" --no-cache --no-build-isolation
```

For local development:

```bash
git clone https://github.com/main-horse/hnet-impl && cd hnet-impl
uv sync && uv sync --extra build
```
Usage:

```python
import torch
from hnet_impl import HNetLM, HNetConfig, completion_sync
from hnet_impl.torchisms import rand_njt_iids

c = HNetConfig.create_reasonable_config(D=[512, 1024], arch=['m4', 'T9'])
with torch.device('cuda'): m = HNetLM(c).bfloat16()

# inference
iids = rand_njt_iids(docs=16, slen=range(128, 1024)).cuda()
logits, _ = m(iids)

# training
lbls = iids.long()  # i.e. torch.randint_like(iids)
(celoss, _), extra = m(iids, lbls)
```
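Because the two-argument call above already returns a cross-entropy loss, a toy training step only needs a standard optimizer wrapped around it. Below is a minimal sketch, assuming `celoss` is a scalar loss tensor; the optimizer choice and hyperparameters are illustrative, not something the package prescribes:

```python
import torch
from hnet_impl import HNetLM, HNetConfig
from hnet_impl.torchisms import rand_njt_iids

c = HNetConfig.create_reasonable_config(D=[512, 1024], arch=['m4', 'T9'])
with torch.device('cuda'):
    m = HNetLM(c).bfloat16()

opt = torch.optim.AdamW(m.parameters(), lr=3e-4)  # illustrative settings

for step in range(100):
    # random nested-jagged input ids, as in the snippet above; swap in real data here
    iids = rand_njt_iids(docs=16, slen=range(128, 1024)).cuda()
    lbls = iids.long()
    (celoss, _), extra = m(iids, lbls)  # same call shape as above
    celoss.backward()
    opt.step()
    opt.zero_grad()
    if step % 10 == 0:
        print(f'{step=} loss={celoss.item():.3f}')
```

Presumably example.py wraps a loop of this kind for the two-GPU torchrun run shown next.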
To train a simple H-Net for dynamically encoding letter repetition (a hypothetical sketch of this kind of toy data follows after the log):

```
$ uv run torchrun --nproc-per-node 2 --standalone example.py
  0%|          | 0/99 [00:32<?, ?it/s]
[failed to decode UTF-8]
step=0: l_avg.item()=5.698 l_ratio.item()=1.064
...
step=990: l_avg.item()=0.462 l_ratio.item()=1.016
 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 98/99 [00:02<00:00, 38.86it/s]
accddddddeeeeffggggggggghhhhhhhhhhiiiiijkkkkkkkkkklllllllllmmmmmmmmmmnnnoooooopppppqqrrrrrrrssssss
```

Where the resultant text should be rendered like:
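For intuition about the toy task: the sampled text above is just letters, each repeated a variable number of times. A hypothetical generator for that kind of data might look like the sketch below; the actual logic lives in example.py and may differ, and the byte-level nested-jagged input format is only inferred from `rand_njt_iids` and the UTF-8 decode message in the log:

```python
import random
import string

import torch

def letter_repetition_doc(max_letters: int = 20, max_repeat: int = 10) -> torch.Tensor:
    """One toy document: a sorted run of letters, each repeated a random number of times."""
    letters = random.sample(string.ascii_lowercase, k=random.randint(5, max_letters))
    text = ''.join(ch * random.randint(1, max_repeat) for ch in sorted(letters))
    return torch.tensor(list(text.encode()), dtype=torch.int32)  # raw UTF-8 bytes as ids

# batch of variable-length docs as a nested-jagged tensor, mirroring rand_njt_iids
docs = [letter_repetition_doc() for _ in range(16)]
iids = torch.nested.nested_tensor(docs, layout=torch.jagged).cuda()
```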
kernels:

```bash
uv run -m hnet_impl.norm
uv run -m hnet_impl.lin
```

```bash
TORCH_LOGS=recompiles uv run -m hnet_impl.xf --s0=9289 --s1=2048 --d0=512 --d1=768 --lm=4 --lt=10
```

2stage_XL fwd equiv check:

```bash
# NOTE: download hnet_2stage_XL from somewhere first.
# cp /path/to/hnet/hnet_2stage_XL.pt /path/to/hnet/config/hnet_2stage_XL.json .
uv run -m hnet_impl.modeling_hnet
```

This package is reasonably useful for researchers who want to train unmodified text-only H-Nets on toy single-node tasks.