@@ -682,7 +682,7 @@ def autoround_quantize(
682682 enable_full_range : bool = False , ##for symmetric, TODO support later
683683 bs : int = 8 ,
684684 amp : bool = True ,
685- device = "cuda:0" ,
685+ device = None ,
686686 lr_scheduler = None ,
687687 dataloader = None , ## to support later
688688 dataset_name : str = "NeelNanda/pile-10k" ,
@@ -703,7 +703,6 @@ def autoround_quantize(
703703 dynamic_max_gap : int = -1 ,
704704 data_type : str = "int" , ##only support data_type
705705 scale_dtype = "fp16" ,
706- export_args : dict = {"format" : None , "inplace" : True },
707706 ** kwargs ,
708707):
709708 """Run autoround weight-only quantization.
@@ -726,8 +725,8 @@ def autoround_quantize(
726725 }
727726 enable_full_range (bool): Whether to enable full range quantization (default is False).
728727 bs (int): Batch size for training (default is 8).
729- amp (bool): Whether to use automatic mixed precision (default is True).
730- device: The device to be used for tuning (default is "cuda:0") .
728+ amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
729+ device: The device to be used for tuning (default is None). Automatically detect and set.
731730 lr_scheduler: The learning rate scheduler to be used.
732731 dataloader: The dataloader for input data (to be supported in future).
733732 dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
@@ -747,8 +746,6 @@ def autoround_quantize(
747746 not_use_best_mse (bool): Whether to use mean squared error (default is False).
748747 dynamic_max_gap (int): The dynamic maximum gap (default is -1).
749748 data_type (str): The data type to be used (default is "int").
750- export_args (dict): The arguments for exporting compressed model, default is {"format": None, "inplace": True}.
751- Supported format: "itrex", "auto_gptq".
752749 **kwargs: Additional keyword arguments.
753750
754751 Returns:
@@ -790,11 +787,4 @@ def autoround_quantize(
790787 ** kwargs ,
791788 )
792789 qdq_model , weight_config = rounder .quantize ()
793- if export_args ["format" ] is not None :
794- output_dir = export_args .get ("output_dir" , None )
795- format = export_args ["format" ]
796- inplace = export_args .get ("inplace" , True )
797- use_triton = export_args .get ("use_triton" , False )
798- model = rounder .save_quantized (output_dir = output_dir , format = format , inplace = inplace , use_triton = use_triton )
799- return model , weight_config
800790 return qdq_model , weight_config