
Commit 1cbe033

Merge pull request karpathy#85 from python273/export-llama-without-llama
Export llama without llama
2 parents 834233e + 77745f2 commit 1cbe033

File tree

1 file changed: +88 -65 lines changed

export_meta_llama_bin.py

Lines changed: 88 additions & 65 deletions
@@ -1,91 +1,114 @@
 """
 This script exports the Llama 2 weights in llama2c.bin format.
+"""
+import sys
+import struct
+from pathlib import Path
+import json
 
-Place it into the root directory of:
-https://github.com/facebookresearch/llama
+import torch
 
-And then run it similar to their other examples, via torchrun sadly:
-torchrun --nproc_per_node 1 export_meta_llama_bin.py
-"""
+from model import precompute_freqs_cis
 
-from llama import Llama
 
-# -----------------------------------------------------------------------------
-def export(self, filepath='model.bin'):
+def export(p, state_dict, filepath='model.bin'):
     """export the model weights in fp32 into .bin file to be read from C"""
-
     f = open(filepath, 'wb')
-    import struct
-    import numpy as np
 
-    def serialize(t):
-        d = t.detach().cpu().view(-1).numpy().astype(np.float32)
-        b = struct.pack(f'{len(d)}f', *d)
-        f.write(b)
+    def serialize(key):
+        print(f"writing {key}...")
+        t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
+        f.write(memoryview(t))
+        del state_dict[key]
 
     # first write out the header
-    hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
-    p = self.params
-    n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
-    header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
-                                    n_kv_heads, -p.vocab_size, p.max_seq_len)
+    hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
+    p['vocab_size'] = 32000
+    p['max_seq_len'] = 2048
+
+    n_kv_heads = p.get('n_kv_heads') or p['n_heads']
+    header = struct.pack(
+        'iiiiiii',
+        p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
+        n_kv_heads, -p['vocab_size'], p['max_seq_len']
+    )
     # NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
     # in the checkpoint and should be loaded.
     f.write(header)
 
     # next write out the embedding weights
     print("writing tok_embeddings...")
-    serialize(self.tok_embeddings.weight)
-
+    serialize('tok_embeddings.weight')
+
     # now all the layers
     # attention weights
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention_norm layer {i}...")
-        serialize(layer.attention_norm.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wq layer {i}...")
-        serialize(layer.attention.wq.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wk layer {i}...")
-        serialize(layer.attention.wk.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wv layer {i}...")
-        serialize(layer.attention.wv.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing attention.wo layer {i}...")
-        serialize(layer.attention.wo.weight)
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention_norm.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wq.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wk.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wv.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.attention.wo.weight')
     # ffn weights
-    for i, layer in enumerate(self.layers):
-        print(f"writing ffn_norm layer {i}...")
-        serialize(layer.ffn_norm.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing feed_forward.w1 layer {i}...")
-        serialize(layer.feed_forward.w1.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing feed_forward.w2 layer {i}...")
-        serialize(layer.feed_forward.w2.weight)
-    for i, layer in enumerate(self.layers):
-        print(f"writing feed_forward.w3 layer {i}...")
-        serialize(layer.feed_forward.w3.weight)
+    for i in range(p['n_layers']): serialize(f'layers.{i}.ffn_norm.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w1.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w2.weight')
+    for i in range(p['n_layers']): serialize(f'layers.{i}.feed_forward.w3.weight')
+
     # final rmsnorm
-    print("writing final rmsnorm, classifier and freq_cis...")
-    serialize(self.norm.weight)
+    serialize('norm.weight')
     # freqs_cis
-    serialize(self.freqs_cis.real[:p.max_seq_len])
-    serialize(self.freqs_cis.imag[:p.max_seq_len])
+    freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
+    state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
+    state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
+    serialize('freqs_cis.real')
+    serialize('freqs_cis.imag')
+
     # finally write the output weights
-    serialize(self.output.weight)
+    serialize('output.weight')
 
-    # write to binary file
     f.close()
     print(f"wrote {filepath}")
-# -----------------------------------------------------------------------------
-
-# init Llama as normal
-generator = Llama.build(
-    ckpt_dir="llama-2-7b",
-    tokenizer_path="tokenizer.model",
-    max_seq_len=4096,
-    max_batch_size=1,
-)
-export(generator.model, "llama2_7b.bin")
+
+
+def concat_weights(models):
+    state_dict = {}
+    for name in list(models[0]):
+        tensors = [model[name] for model in models]
+        if len(tensors) == 1 or len(tensors[0].shape) == 1:
+            state_dict[name] = tensors[0]
+            continue
+        is_axis_1 = (
+            name.startswith('tok_embeddings.')
+            or name.endswith('.attention.wo.weight')
+            or name.endswith('.feed_forward.w2.weight')
+        )
+        axis = 1 if is_axis_1 else 0
+        state_dict[name] = torch.cat(tensors, dim=axis)
+        for model in models:
+            del model[name]
+    return state_dict
+
+
+def load_and_export(model_path, output_path):
+    with open(model_path + 'params.json') as f:
+        params = json.load(f)
+        print(params)
+
+    model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
+    models = []
+    for i in model_paths:
+        print(f'Loading {i}')
+        models.append(torch.load(i, map_location='cpu'))
+
+    state_dict = concat_weights(models)
+    del models
+    export(params, state_dict, output_path)
+
+
+if __name__ == '__main__':
+    if len(sys.argv) == 1:
+        print('[Llama model folder path] [output path]')
+        exit()
+
+    model_path = sys.argv[1]
+    output_path = sys.argv[2]
+    load_and_export(model_path, output_path)
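
For reference, a minimal usage sketch under assumptions not stated in the diff: the checkpoint folder and output filename below are placeholders, and the struct.unpack read-back is only an illustrative sanity check of the 7-int header that export() writes (load_and_export concatenates model_path with 'params.json', so the folder path needs a trailing slash).

# Hypothetical invocation of the new script (placeholder paths):
#   python export_meta_llama_bin.py /path/to/llama-2-7b/ llama2_7b.bin

import struct

# Read back the seven int32 header fields written by export() above.
with open('llama2_7b.bin', 'rb') as f:
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, neg_vocab, max_seq_len = \
        struct.unpack('iiiiiii', f.read(7 * 4))

# vocab_size is stored negated (-32000) to signal that the classifier
# weights are present in the checkpoint.
print(dim, hidden_dim, n_layers, n_heads, n_kv_heads, -neg_vocab, max_seq_len)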
