@@ -383,104 +383,6 @@ def weight_loader(self, param: torch.nn.Parameter,
383
383
tp_rank = tp_rank )
384
384
return
385
385
386
- def _load_input_scales (self , param : torch .nn .Parameter ,
387
- loaded_weight : torch .Tensor , expert_id : int ):
388
- param_data = param .data
389
-
390
- # Input scales can be loaded directly and should be equal.
391
- param_data [expert_id ] = loaded_weight
392
-
393
- def weight_loader (self , param : torch .nn .Parameter ,
394
- loaded_weight : torch .Tensor , weight_name : str ,
395
- shard_id : str , expert_id : int ) -> None :
396
-
397
- if shard_id not in ("w1" , "w2" , "w3" ):
398
- raise ValueError (f"shard_id must be ['w1','w2','w3'] but "
399
- f"got { shard_id } ." )
400
-
401
- WEIGHT_SCALE_SUPPORTED = [
402
- e .value for e in FusedMoeWeightScaleSupported
403
- ]
404
- # Fetch the dim to shard the parameter/loaded weight
405
- # based on the shard id. This will be whatever
406
- # dimension intermediate_size is used.
407
- SHARD_ID_TO_SHARDED_DIM = {"w1" : 0 , "w2" : 1 , "w3" : 0 }
408
-
409
- expert_data = param .data [expert_id ]
410
- tp_rank = get_tensor_model_parallel_rank ()
411
-
412
- # is_transposed: whether or not the parameter is transposed on disk
413
- # If transposed, the loaded weight will be transposed and the dim
414
- # to shard the loaded weight will be flipped.
415
- is_transposed = getattr (param , "is_transposed" , False )
416
- shard_dim = SHARD_ID_TO_SHARDED_DIM [shard_id ]
417
- if is_transposed :
418
- loaded_weight = loaded_weight .t ().contiguous ()
419
- shard_dim = ~ shard_dim
420
-
421
- # Case weight_scales
422
- if "weight_scale" in weight_name :
423
- # load the weight scaling based on the quantization scheme
424
- # supported weight scales can be found in
425
- # FusedMoeWeightScaleSupported
426
- # TODO @dsikka: once hardened, refactor to use vLLM Parameters
427
- # specific to each case
428
- quant_method = getattr (param , "quant_method" , None )
429
- if quant_method == FusedMoeWeightScaleSupported .CHANNEL .value :
430
- self ._load_per_channel_weight_scale (
431
- shard_id = shard_id ,
432
- shard_dim = shard_dim ,
433
- loaded_weight = loaded_weight ,
434
- expert_data = expert_data ,
435
- tp_rank = tp_rank )
436
- elif quant_method == FusedMoeWeightScaleSupported .GROUP .value :
437
- self ._load_model_weight_or_group_weight_scale (
438
- shard_id = shard_id ,
439
- shard_dim = shard_dim ,
440
- loaded_weight = loaded_weight ,
441
- expert_data = expert_data ,
442
- tp_rank = tp_rank )
443
- elif quant_method == FusedMoeWeightScaleSupported .TENSOR .value :
444
- self ._load_per_tensor_weight_scale (shard_id = shard_id ,
445
- param = param ,
446
- loaded_weight = loaded_weight ,
447
- expert_id = expert_id )
448
- else :
449
- raise ValueError (
450
- f"quant method must be one of { WEIGHT_SCALE_SUPPORTED } " )
451
- return
452
-
453
- if "weight_shape" in weight_name :
454
- self ._load_single_value (param = param ,
455
- loaded_weight = loaded_weight ,
456
- expert_id = expert_id )
457
- return
458
-
459
- # Case input scale
460
- if "input_scale" in weight_name :
461
- # Note: input_scale loading is only supported for fp8
462
- if param .data [expert_id ] != 1 and (param .data [expert_id ] -
463
- loaded_weight ).abs () > 1e-5 :
464
- raise ValueError (
465
- "input_scales of w1 and w3 of a layer "
466
- f"must be equal. But got { param .data [expert_id ]} "
467
- f"vs. { loaded_weight } " )
468
-
469
- self ._load_single_value (param = param ,
470
- loaded_weight = loaded_weight ,
471
- expert_id = expert_id )
472
- return
473
-
474
- # Case model weights
475
- if "weight" in weight_name :
476
- self ._load_model_weight_or_group_weight_scale (
477
- shard_id = shard_id ,
478
- shard_dim = shard_dim ,
479
- loaded_weight = loaded_weight ,
480
- expert_data = expert_data ,
481
- tp_rank = tp_rank )
482
- return
483
-
484
386
@staticmethod
485
387
def select_experts (hidden_states : torch .Tensor ,
486
388
router_logits : torch .Tensor ,
@@ -574,4 +476,4 @@ def _load_fp8_scale(self, param: torch.nn.Parameter,
574
476
param_data [expert_id ][idx ] = loaded_weight
575
477
# If we are in the row parallel case (down_proj)
576
478
else :
577
- param_data [expert_id ] = loaded_weight
479
+ param_data [expert_id ] = loaded_weight
0 commit comments