
Commit 6b372ef

Fix/lora delta serving (#19)
* wip: fix output correctness
* fix output correctness
* Remove debug print statement in toppings_manager.py
1 parent 0e18096 commit 6b372ef

File tree: 9 files changed (+95, -45 lines)

scratchpad/managers/toppings_manager.py

Lines changed: 61 additions & 7 deletions
@@ -124,7 +124,6 @@ def init_toppings(self):
         self.origin_target_modules = set()
         for name, top in self.available_toppings.items():
             self.configs[name] = ToppingConfig(topping_type=top[0], path=top[1])
-
             self.origin_target_modules = set(self.origin_target_modules) | set(
                 self.configs[name].hf_config["target_modules"]
             )
@@ -134,6 +133,8 @@ def init_toppings(self):
                 self.base_model.get_module_name(module)
                 for module in self.origin_target_modules
             }
+            # remove down_proj from target modules
+            logger.info(f"Target modules: {self.target_modules}")
         else:
             logger.warning(
                 "WARNING: get_module_name() is not defined, "
@@ -166,6 +167,7 @@ def init_toppings(self):
         self.lora_id = {}
         self.deltas: List[DeltaAdapter] = []
         self.delta_id = {}
+
         for name in self.available_toppings.keys():
             t_type = self.available_toppings[name][0]
             logger.info(f"Loading {t_type} {name}")
@@ -189,10 +191,15 @@ def init_toppings(self):
             self.deltas[-1].initialize_weights()

         # misc lora configs
-        self.max_lora_dim = max(
-            [x.hf_config["r"] for x in self.configs.values() if "r" in x.hf_config]
-        )
-        self.scaling = self.loras[0].scaling
+        self.max_lora_dim = [
+            x.hf_config["r"] for x in self.configs.values() if "r" in x.hf_config
+        ]
+        if len(self.max_lora_dim) == 0:
+            self.max_lora_dim = 0
+            self.scaling = 0
+        else:
+            self.max_lora_dim = max(self.max_lora_dim)
+            self.scaling = self.loras[0].scaling
         # FIXME remove the restrictions
         assert all(
             x.hf_config["r"] == self.max_lora_dim
@@ -215,6 +222,9 @@ def print_available_toppings(self):
     def set_topping_module(self, module_name, module):
         topping_module = get_topping_layer(module)
         replace_submodule(self.base_model, module_name, topping_module)
+        logger.info(
+            f"Replaced {module_name} with topping module {type(topping_module)}"
+        )
         return topping_module

     def prepare_topping_batch(self, forward_batch: ForwardBatch):
@@ -288,7 +298,6 @@ def prepare_topping_batch(self, forward_batch: ForwardBatch):
             dtype=torch.int64,
             device=forward_batch.input_ids.device,
         )
-        print(f"weight_indices: {weight_indices}")
         for module_name, module in self.topping_modules:
             layer_id = get_layer_id(module_name)
             if "lm_head" in module_name:
@@ -327,6 +336,42 @@ def prepare_topping_batch(self, forward_batch: ForwardBatch):
                         self.scales_buffer["kv_proj"][layer_id][:len_deltas],
                     ),
                 )
+            elif "down_proj" in module_name:
+                weight_name = self.get_weight_name(module_name, 0)
+                module.set_topping_info(
+                    bs,
+                    weight_indices,
+                    lora_buffer=(
+                        (
+                            self.A_buffer[weight_name][layer_id][:len_loras]
+                            if weight_name in self.A_buffer
+                            else None
+                        ),
+                        (
+                            self.B_buffer[weight_name][layer_id][:len_loras]
+                            if weight_name in self.B_buffer
+                            else None
+                        ),
+                    ),
+                    delta_buffer=(
+                        (
+                            self.qweight_buffer[weight_name][layer_id][:len_deltas]
+                            if weight_name in self.qweight_buffer
+                            else None
+                        ),
+                        (
+                            self.meta_buffer[weight_name][layer_id][:len_deltas]
+                            if weight_name in self.meta_buffer
+                            else None
+                        ),
+                        (
+                            self.scales_buffer[weight_name][layer_id][:len_deltas]
+                            if weight_name in self.scales_buffer
+                            else None
+                        ),
+                    ),
+                    debug=False,
+                )
             else:
                 weight_name = self.get_weight_name(module_name, 0)
                 module.set_topping_info(
@@ -375,6 +420,7 @@ def load_topping(self, uid, buffer_id):
         """
         This function loads topping from CPU -> GPU memory
         """
+
         if uid not in self.available_toppings:
             logger.error(f"Topping {uid} not registered")
             return
@@ -420,6 +466,7 @@ def _load_delta(self, uid, buffer_id):

         for i in range(num_layer):
             layer_weights = self.deltas[self.delta_id[uid]].layers[i].weights
+            # load to buffer space
             for name, weights in layer_weights.items():
                 if (
                     "qkv_proj" in name
@@ -445,7 +492,7 @@ def _load_delta(self, uid, buffer_id):
                         self.scales_buffer[kv_proj_name][i][buffer_id].copy_(
                             weights[:, q_dim:]
                         )
-                    else:
+                    elif "meta" in name:
                         q_proj_name = "q_proj"
                         kv_proj_name = "kv_proj"
                         q_dim = self.meta_buffer[q_proj_name][i][buffer_id].shape[0]
@@ -455,23 +502,30 @@ def _load_delta(self, uid, buffer_id):
                         self.meta_buffer[kv_proj_name][i][buffer_id].copy_(
                             weights[q_dim:, :]
                         )
+                    else:
+                        print("Unknown delta weight name: {name}")
                 else:
                     if "qweight" in name:
                         weight_name = self.get_delta_weight_name(name)
                         if weight_name:
                             self.qweight_buffer[weight_name][i][buffer_id].copy_(
                                 weights
                             )
+                        else:
+                            print("Unknown delta weight name: {name}")
+
                     elif "scales" in name:
                         weight_name = self.get_delta_weight_name(name)
                         if weight_name:
                             self.scales_buffer[weight_name][i][buffer_id].copy_(weights)
+
                     elif "meta" in name:
                         weight_name = self.get_delta_weight_name(name)
                         if weight_name:
                             self.meta_buffer[weight_name][i][buffer_id].copy_(weights)
                     else:
                         print("Unknown delta weight name: {name}")
+                        raise ValueError(f"Unknown delta weight name: {name}")

         for name, outside_module in self.deltas[
             self.delta_id[uid]
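Note on the max_lora_dim change above: when only delta toppings are registered, none of the topping configs carry an "r" entry, so the previous max(...) over an empty list raised a ValueError at startup. A minimal standalone sketch of the same guard pattern, using hypothetical config dictionaries rather than the real ToppingConfig objects:

# Hypothetical stand-ins for ToppingConfig.hf_config dictionaries.
lora_configs = [{"r": 16, "scaling": 2.0}]       # LoRA toppings carry an "r" rank
delta_only_configs = [{"bits": 4}, {"bits": 4}]  # delta toppings carry no "r" key


def resolve_max_lora_dim(configs):
    ranks = [c["r"] for c in configs if "r" in c]
    if not ranks:
        # Delta-only serving: no LoRA ranks registered, avoid max([]) -> ValueError.
        return 0, 0.0
    return max(ranks), configs[0].get("scaling", 1.0)


print(resolve_max_lora_dim(lora_configs))        # (16, 2.0)
print(resolve_max_lora_dim(delta_only_configs))  # (0, 0.0) instead of a ValueError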

scratchpad/memory/topping_pool.py

Lines changed: 6 additions & 6 deletions
@@ -141,32 +141,32 @@ def __init__(
         stacked_dim = dimensions[1] * stack_factor

         self.qweight_buffer[module] = [
-            torch.zeros(
+            torch.empty(
                 self.max_toppings_per_batch,
                 dimensions[0] // (pack_factor * sparse_factor * 2),
                 stacked_dim * 2,
                 dtype=delta_dtypes["qweight"],
                 device="cuda",
             )
-            for _ in range(num_layers)
+            for i in range(num_layers)
         ]
         self.meta_buffer[module] = [
-            torch.zeros(
+            torch.empty(
                 self.max_toppings_per_batch,
                 stacked_dim,
                 dimensions[0] // (pack_factor * sparse_factor),
                 dtype=delta_dtypes["meta"],
                 device="cuda",
             )
-            for _ in range(num_layers)
+            for i in range(num_layers)
         ]
         self.scales_buffer[module] = [
-            torch.zeros(
+            torch.empty(
                 self.max_toppings_per_batch,
                 1,
                 stacked_dim,
                 dtype=delta_dtypes["scales"],
                 device="cuda",
             )
-            for _ in range(num_layers)
+            for i in range(num_layers)
         ]
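Note on the torch.zeros -> torch.empty change: these pool buffers are always fully overwritten by copy_() when a delta is loaded in _load_delta, so zero-filling them at allocation time is wasted work; torch.empty only reserves memory. A small CPU-only sketch of the pattern, with made-up shapes:

import torch

# Made-up sizes; the real pool derives them from pack/sparse factors and stack_factor.
max_toppings_per_batch, rows, cols, num_layers = 2, 128, 256, 4

# torch.empty allocates without initializing, so the initial contents are arbitrary.
qweight_buffer = [
    torch.empty(max_toppings_per_batch, rows, cols, dtype=torch.int32)
    for _ in range(num_layers)
]

# Loading a topping overwrites its slot completely, so nothing ever reads the
# uninitialized values.
new_weights = torch.randint(0, 2**31 - 1, (rows, cols), dtype=torch.int32)
qweight_buffer[0][1].copy_(new_weights)
assert torch.equal(qweight_buffer[0][1], new_weights)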

scratchpad/model_executor/forward_info.py

Lines changed: 0 additions & 1 deletion
@@ -212,5 +212,4 @@ def init_new(
         # Init lora information
         if model_runner.server_args.enable_toppings:
             model_runner.topping_manager.prepare_topping_batch(ret)
-
         return ret

scratchpad/nn/models/llama.py

Lines changed: 2 additions & 2 deletions
@@ -302,11 +302,11 @@ def forward(
     def get_hidden_dim(self, module_name):
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
             return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
+        elif module_name in ["kv_proj", "k_proj", "v_proj"]:
             return self.config.hidden_size, self.config.hidden_size // (
                 self.config.num_attention_heads // self.config.num_key_value_heads
             )
-        elif module_name == "gate_up_proj":
+        elif module_name in ["gate_up_proj", "up_proj", "gate_proj"]:
             return self.config.hidden_size, self.config.intermediate_size
         elif module_name == "down_proj":
             return self.config.intermediate_size, self.config.hidden_size
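Note on get_hidden_dim: for the kv_proj/k_proj/v_proj branch, the output width is the hidden size divided by the GQA group size (attention heads per KV head). A worked example with assumed Llama-3.2-1B-style values (hidden_size 2048, 32 attention heads, 8 KV heads, intermediate_size 8192; the real values come from the model config). The function below mirrors the branch structure in the diff:

# Assumed config values, for illustration only.
hidden_size = 2048
num_attention_heads = 32
num_key_value_heads = 8
intermediate_size = 8192


def get_hidden_dim(module_name):
    # Mirrors the branch structure of LlamaForCausalLM.get_hidden_dim above.
    if module_name in ["q_proj", "o_proj", "qkv_proj"]:
        return hidden_size, hidden_size
    elif module_name in ["kv_proj", "k_proj", "v_proj"]:
        return hidden_size, hidden_size // (num_attention_heads // num_key_value_heads)
    elif module_name in ["gate_up_proj", "up_proj", "gate_proj"]:
        return hidden_size, intermediate_size
    elif module_name == "down_proj":
        return intermediate_size, hidden_size


print(get_hidden_dim("k_proj"))   # (2048, 512): 2048 // (32 // 8)
print(get_hidden_dim("up_proj"))  # (2048, 8192)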

scratchpad/nn/toppings/topping_layer.py

Lines changed: 12 additions & 26 deletions
@@ -2,6 +2,7 @@

 import torch
 from torch import nn
+from torch.nn import functional as F
 from typing import Union
 from scratchpad.nn.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -137,6 +138,7 @@ def forward(self, input_: torch.Tensor):
         qweight_dim = self.qweight_buffer.shape[2] // 2
         metas_dim = self.metas_buffer.shape[1] // 2
         scales_dim = self.scales_buffer.shape[2] // 2
+
         for i in range(2):
             output = ldmm(
                 indices=self.weight_indices,
@@ -198,18 +200,6 @@ def set_topping_info(
             self.meta_buffer_kv = torch.zeros(0, 0, 0)
             self.scales_buffer_kv = torch.zeros(0, 0, 0)

-        # q,k,v have the same input dimensions
-        # k,v have the same output dimensions
-        # q has a different output dimension than k,v
-
-        # (A_buffer_qkv: bsz, dim1, rank*2)
-        # (B_buffer_q: bsz, rank, dim2*2)
-        # (B_buffer_kv: bsz, rank, dim3*2)
-
-        # (qweight_buffer: bsz,_, _*3)
-        # (meta_buffer: bsz,_, _*3)
-        # (scales_buffer: bsz, _, _*3)
-
     def forward(self, input_: torch.Tensor):
         base_output = self.base_layer(input_)[0]
         rank = self.A_buffer_qkv.shape[2] // 3
@@ -248,25 +238,27 @@ def forward(self, input_: torch.Tensor):
                 ],
             )
             base_output[:, i * b_dim_kv : (i + 1) * b_dim_kv] += output
-
         return base_output, None


 class RowParallelLinearWithTopping(BaseLayerWithTopping):
     def __init__(self, base_layer: RowParallelLinear, config: Dict) -> None:
         super().__init__(base_layer, config)

-    def set_topping_info(self, bs, weight_indices, lora_buffer=None, delta_buffer=None):
+    def set_topping_info(
+        self, bs, weight_indices, lora_buffer=None, delta_buffer=None, debug=False
+    ):
         self.weight_indices = weight_indices
         self.bs = bs
-        if lora_buffer != None:
+        self.debug = debug
+        if lora_buffer is not None:
             self.A_buffer = lora_buffer[0]
             self.B_buffer = lora_buffer[1]
         else:
             self.A_buffer = torch.zeros(0, 0, 0)
             self.B_buffer = torch.zeros(0, 0, 0)

-        if delta_buffer != None:
+        if delta_buffer is not None:
             self.qweight_buffer = delta_buffer[0]
             self.metas_buffer = delta_buffer[1]
             self.scales_buffer = delta_buffer[2]
@@ -276,7 +268,7 @@ def set_topping_info(self, bs, weight_indices, lora_buffer=None, delta_buffer=No
             self.scales_buffer = torch.zeros(0, 0, 0)

     def forward(self, input_: torch.Tensor):
-        base_output = torch.matmul(input_, self.base_layer.weight.T)
+        base_output = F.linear(input_, self.base_layer.weight, self.base_layer.bias)
         delta_output = ldmm(
             indices=self.weight_indices,
             x=input_,
@@ -285,16 +277,10 @@ def forward(self, input_: torch.Tensor):
             DeltaW=self.qweight_buffer,
             metas=self.metas_buffer,
             ss=self.scales_buffer,
+            debug=self.debug,
         )
-        print(f"weight_indices: {self.weight_indices}")
-        print(f"A_buffer.shape: {self.A_buffer.shape}")
-        print(f"base_output.shape: {base_output.shape}")
-        print(f"delta_output.shape: {delta_output.shape}")
-        print(f"base: {base_output}")
-        print(f"max delta: {torch.max(abs(delta_output))}")
-        assert base_output.shape == delta_output.shape
+        # assert base_output.shape == delta_output.shape
         output_ = base_output + delta_output
-        # output_ = base_output
         if not self.base_layer.skip_bias_add:
             output = (
                 output_ + self.base_layer.bias
@@ -364,7 +350,7 @@ def _get_logits(
         assert len(unique_indices) == 1, f"Prefill stage only supports one index"
         w_idx = unique_indices[0]
         if w_idx == -1:
-            w = weight.T
+            w = weight
         else:
             w = self.delta_buffer[w_idx]
         output = nn.functional.linear(last_hidden, w)
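Note on the F.linear changes in this file: torch.nn.functional.linear expects a weight of shape (out_features, in_features), exactly how nn.Linear stores it, and computes x @ W.T + b internally. That is why RowParallelLinearWithTopping.forward no longer transposes the weight by hand and why the lm_head path now passes weight instead of weight.T. A quick equivalence check:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)       # (batch, in_features)
weight = torch.randn(8, 16)  # (out_features, in_features), the layout nn.Linear stores
bias = torch.randn(8)

manual = torch.matmul(x, weight.T) + bias
fused = F.linear(x, weight, bias)  # same math; the transpose happens inside F.linear
assert torch.allclose(manual, fused, atol=1e-5)

# For the lm_head path: F.linear already applies the transpose, so pre-transposing a
# (vocab, hidden) weight would undo it and break the output shape.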

scratchpad/nn/toppings/topping_module.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def initialize_weights(self):
             delta_config = json.load(f)
         self.pack_factor = 32 // delta_config["compress_config"]["bits"]
         self.sparse_factor = int(1 / delta_config["compress_config"]["sparsity"])
+
         weight_path = os.path.join(local_path, "deltazip-compressed.safetensors")
         with st.safe_open(weight_path, framework="torch", device="cpu") as f:
             keys = f.keys()
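Note on the two factors above, read from the delta's compress_config: pack_factor is how many quantized values fit in one 32-bit word, and sparse_factor is the inverse of the kept-weight fraction. With the 4-bit, 2:4-sparse artifacts used in the serve scripts (the 4b_2n4m_128bs checkpoints), the assumed arithmetic is:

# Assumed compress_config for a 4-bit, 2:4-sparse deltazip artifact (values are
# illustrative; the real ones are read from the checkpoint's config file).
compress_config = {"bits": 4, "sparsity": 0.5}

pack_factor = 32 // compress_config["bits"]           # 8 four-bit values per 32-bit word
sparse_factor = int(1 / compress_config["sparsity"])  # 2: half of the values are kept

print(pack_factor, sparse_factor)  # 8 2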
Lines changed: 8 additions & 2 deletions
@@ -1,4 +1,10 @@
 export PROMETHEUS_MULTIPROC_DIR=.local
 sp serve meta-llama/Llama-3.2-1B-Instruct --host 0.0.0.0 --port 8080 \
-    --enable-system-controller --use-heterogeneous-pool \
-    --init-toppings lora:ketchup123/llama-3.2-1B-instruct-gsm8k:ketchup123/llama-3.2-1B-instruct-gsm8k
+    --enable-system-controller \
+    --use-heterogeneous-pool \
+    --enable-toppings \
+    --init-toppings lora:ketchup123/llama-3.2-1B-instruct-gsm8k:ketchup123/llama-3.2-1B-instruct-gsm8k,delta:deltazip/meta-llama.Llama-3.2-1B-Instruct.4b_2n4m_128bs:deltazip/meta-llama.Llama-3.2-1B-Instruct.4b_2n4m_128bs-1,delta:deltazip/meta-llama.Llama-3.2-1B-Instruct.4b_2n4m_128bs:deltazip/meta-llama.Llama-3.2-1B-Instruct.4b_2n4m_128bs-2 \
+    --attention-backend triton \
+    --sampling-backend pytorch \
+    --max-toppings-per-batch 2 \
+    --disable-cuda-graph

scripts/serve_llama_1b_with_toppings_torch.sh

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
 export PROMETHEUS_MULTIPROC_DIR=.local
-sp serve meta-llama/Llama-3.2-3B-Instruct --host 0.0.0.0 --port 8080 \
+sp serve meta-llama/Llama-3.2-3B --host 0.0.0.0 --port 8080 \
     --enable-system-controller \
+    --tokenizer-path meta-llama/Llama-3.2-3B-Instruct \
     --use-heterogeneous-pool \
+    --enable-toppings \
     --init-toppings lora:eltorio/Llama-3.2-3B-appreciation:eltorio/Llama-3.2-3B-appreciation-1,lora:eltorio/Llama-3.2-3B-appreciation:eltorio/Llama-3.2-3B-appreciation-2,delta:deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs:deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs-1,delta:deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs:deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs-2 \
     --attention-backend triton \
     --sampling-backend pytorch \

tools/utils/test_concurrency.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ def main(args):
         # "eltorio/Llama-3.2-3B-appreciation-2",
         "deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs-1",
         "deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs-2",
+        # "meta-llama/Llama-3.2-3B"
     ]
     prompts = np.random.choice(prompts, args.num_req, replace=True)
     models = np.random.choice(models, args.num_req, replace=True)
@@ -29,6 +30,7 @@ def main(args):
     ]
     responses = asyncio.run(make_requests(args.endpoint, reqs))
     for resp in responses:
+        print(f"---")
         print(resp["choices"][0]["message"]["content"])

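For reference, a hedged sketch of how a single request against one of the served toppings could look. The endpoint path and payload shape are assumptions inferred from the resp["choices"][0]["message"]["content"] access above (an OpenAI-style chat completions response); they are not confirmed by this diff, so adjust to the actual server API:

import requests

# Assumed OpenAI-compatible endpoint on the port used by the serve scripts;
# tools/utils/test_concurrency.py takes the endpoint via --endpoint.
url = "http://localhost:8080/v1/chat/completions"
payload = {
    # Select a topping by the name it was registered under in --init-toppings.
    "model": "deltazip/meta-llama.Llama-3.2-3B-Instruct.4b_2n4m_128bs-1",
    "messages": [{"role": "user", "content": "Write a one-line greeting."}],
}

resp = requests.post(url, json=payload, timeout=60).json()
print(resp["choices"][0]["message"]["content"])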
