 import glob
 import inspect
 import queue
-import shutil
 import sys
 import os
 import time
@@ -635,13 +634,26 @@ def get_config(base_model,
                triton_attn=False,
                long_sequence=True,
                return_model=False,
+               raise_exception=False,
                ):
     from accelerate import init_empty_weights
     with init_empty_weights():
         from transformers import AutoConfig
-        config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
-                                            trust_remote_code=trust_remote_code,
-                                            offload_folder=offload_folder)
+        try:
+            config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
+                                                trust_remote_code=trust_remote_code,
+                                                offload_folder=offload_folder)
+        except OSError as e:
+            if raise_exception:
+                raise
+            if 'not a local folder and is not a valid model identifier listed on' in str(
+                    e) or '404 Client Error' in str(e):
+                # e.g. llama, gpjt, etc.
+                # e.g. HF TGI but not model on HF or private etc.
+                # HF TGI server only should really require prompt_type, not HF model state
+                return None, None
+            else:
+                raise
         if triton_attn and 'mpt-' in base_model.lower():
             config.attn_config['attn_impl'] = 'triton'
         if long_sequence:
@@ -738,20 +750,20 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
     hf_client = None
     if headers is None:
         try:
-            print("GR Client Begin: %s" % inference_server)
+            print("GR Client Begin: %s" % inference_server, flush=True)
             # first do sanity check if alive, else gradio client takes too long by default
             requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
             gr_client = GradioClient(inference_server)
-            print("GR Client End: %s" % inference_server)
+            print("GR Client End: %s" % inference_server, flush=True)
         except (OSError, ValueError) as e:
             # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
             gr_client = None
-            print("GR Client Failed %s: %s" % (inference_server, str(e)))
+            print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True)
         except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
                 JSONDecodeError, ReadTimeout2, KeyError) as e:
             t, v, tb = sys.exc_info()
             ex = ''.join(traceback.format_exception(t, v, tb))
-            print("GR Client Failed %s: %s" % (inference_server, str(ex)))
+            print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True)
             if raise_connection_exception:
                 raise
 
@@ -822,28 +834,51 @@ def get_model(
822834 """
823835 if verbose :
824836 print ("Get %s model" % base_model , flush = True )
825- if isinstance (inference_server , str ) and inference_server .startswith ("http" ):
837+
838+ triton_attn = False
839+ long_sequence = True
840+ config_kwargs = dict (use_auth_token = use_auth_token ,
841+ trust_remote_code = trust_remote_code ,
842+ offload_folder = offload_folder ,
843+ triton_attn = triton_attn ,
844+ long_sequence = long_sequence )
845+ config , _ = get_config (base_model , ** config_kwargs , raise_exception = False )
846+
847+ if base_model in non_hf_types :
848+ assert config is None , "Expected config None for %s" % base_model
849+
850+ llama_type_from_config = 'llama' in str (config ).lower ()
851+ llama_type_from_name = "llama" in base_model .lower ()
852+ llama_type = llama_type_from_config or llama_type_from_name
853+ if llama_type :
854+ if verbose :
855+ print ("Detected as llama type from"
856+ " config (%s) or name (%s)" % (llama_type_from_config , llama_type_from_name ), flush = True )
857+
858+ model_loader , tokenizer_loader = get_loaders (model_name = base_model , reward_type = reward_type , llama_type = llama_type )
859+
860+ tokenizer_kwargs = dict (local_files_only = local_files_only ,
861+ resume_download = resume_download ,
862+ use_auth_token = use_auth_token ,
863+ trust_remote_code = trust_remote_code ,
864+ offload_folder = offload_folder ,
865+ padding_side = 'left' ,
866+ config = config ,
867+ )
868+ if not tokenizer_base_model :
869+ tokenizer_base_model = base_model
870+
871+ if config is not None and tokenizer_loader is not None and not isinstance (tokenizer_loader , str ):
872+ tokenizer = tokenizer_loader .from_pretrained (tokenizer_base_model , ** tokenizer_kwargs )
873+ # sets raw (no cushion) limit
874+ set_model_max_len (config , tokenizer , verbose = False )
875+ # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get:
876+ # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233
877+ tokenizer .model_max_length = tokenizer .model_max_length - 50
878+ else :
826879 tokenizer = FakeTokenizer ()
827- try :
828- from transformers import AutoConfig
829- config , _ = get_config (base_model , use_auth_token = use_auth_token ,
830- trust_remote_code = trust_remote_code ,
831- offload_folder = offload_folder )
832- # sets raw (no cushion) limit
833- set_model_max_len (config , tokenizer , verbose = False )
834- # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get:
835- # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233
836- tokenizer .model_max_length = tokenizer .model_max_length - 250
837- except OSError as e :
838- t , v , tb = sys .exc_info ()
839- ex = '' .join (traceback .format_exception (t , v , tb ))
840- if 'not a local folder' in str (ex ) or '404 Client Error' in str (ex ):
841- # e.g. llama, gpjt, etc.
842- pass
843- else :
844- if base_model not in non_hf_types :
845- raise
846880
881+ if isinstance (inference_server , str ) and inference_server .startswith ("http" ):
847882 inference_server , gr_client , hf_client = get_client_from_inference_server (inference_server )
848883 client = gr_client or hf_client
849884 # Don't return None, None for model, tokenizer so triggers
@@ -852,14 +887,65 @@ def get_model(
         assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
         # Don't return None, None for model, tokenizer so triggers
         # include small token cushion
-        tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 100)
+        tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
         return inference_server, tokenizer, inference_server
     assert not inference_server, "Malformed inference_server=%s" % inference_server
     if base_model in non_hf_types:
         from gpt4all_llm import get_model_tokenizer_gpt4all
         model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
         return model, tokenizer, device
 
+    # get local torch-HF model
+    return get_hf_model(load_8bit=load_8bit,
+                        load_4bit=load_4bit,
+                        load_half=load_half,
+                        infer_devices=infer_devices,
+                        base_model=base_model,
+                        tokenizer_base_model=tokenizer_base_model,
+                        lora_weights=lora_weights,
+                        gpu_id=gpu_id,
+
+                        reward_type=reward_type,
+                        local_files_only=local_files_only,
+                        resume_download=resume_download,
+                        use_auth_token=use_auth_token,
+                        trust_remote_code=trust_remote_code,
+                        offload_folder=offload_folder,
+                        compile_model=compile_model,
+
+                        llama_type=llama_type,
+                        config_kwargs=config_kwargs,
+                        tokenizer_kwargs=tokenizer_kwargs,
+
+                        verbose=verbose)
+
+
+def get_hf_model(load_8bit: bool = False,
+                 load_4bit: bool = False,
+                 load_half: bool = True,
+                 infer_devices: bool = True,
+                 base_model: str = '',
+                 tokenizer_base_model: str = '',
+                 lora_weights: str = "",
+                 gpu_id: int = 0,
+
+                 reward_type: bool = None,
+                 local_files_only: bool = False,
+                 resume_download: bool = True,
+                 use_auth_token: Union[str, bool] = False,
+                 trust_remote_code: bool = True,
+                 offload_folder: str = None,
+                 compile_model: bool = True,
+
+                 llama_type: bool = False,
+                 config_kwargs=None,
+                 tokenizer_kwargs=None,
+
+                 verbose: bool = False,
+                 ):
+    assert config_kwargs is not None
+    assert tokenizer_kwargs is not None
+
     if lora_weights is not None and lora_weights.strip():
         if verbose:
             print("Get %s lora weights" % lora_weights, flush=True)
@@ -874,31 +960,13 @@ def get_model(
874960 "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
875961 )
876962
877- from transformers import AutoConfig
878- config = AutoConfig .from_pretrained (base_model , use_auth_token = use_auth_token ,
879- trust_remote_code = trust_remote_code ,
880- offload_folder = offload_folder )
881- llama_type_from_config = 'llama' in str (config ).lower ()
882- llama_type_from_name = "llama" in base_model .lower ()
883- llama_type = llama_type_from_config or llama_type_from_name
884- if llama_type :
885- if verbose :
886- print ("Detected as llama type from"
887- " config (%s) or name (%s)" % (llama_type_from_config , llama_type_from_name ), flush = True )
963+ model_loader , tokenizer_loader = get_loaders (model_name = base_model , reward_type = reward_type , llama_type = llama_type )
888964
889- model_loader , tokenizer_loader = get_loaders (llama_type = llama_type , model_name = base_model , reward_type = reward_type )
890- if not tokenizer_base_model :
891- tokenizer_base_model = base_model
965+ config , _ = get_config (base_model , return_model = False , raise_exception = True , ** config_kwargs )
892966
893967 if tokenizer_loader is not None and not isinstance (tokenizer_loader , str ):
894968 tokenizer = tokenizer_loader .from_pretrained (tokenizer_base_model ,
895- local_files_only = local_files_only ,
896- resume_download = resume_download ,
897- use_auth_token = use_auth_token ,
898- trust_remote_code = trust_remote_code ,
899- offload_folder = offload_folder ,
900- padding_side = 'left' ,
901- )
969+ ** tokenizer_kwargs )
902970 else :
903971 tokenizer = tokenizer_loader
904972
@@ -931,20 +999,11 @@ def get_model(
         model_kwargs.pop('torch_dtype', None)
     pop_unused_model_kwargs(model_kwargs)
 
-    triton_attn = False
-    long_sequence = True
-
-    config_kwargs = dict(use_auth_token=use_auth_token,
-                         trust_remote_code=trust_remote_code,
-                         offload_folder=offload_folder,
-                         triton_attn=triton_attn,
-                         long_sequence=long_sequence)
-
     if not lora_weights:
         with torch.device(device):
 
             if infer_devices:
-                config, model = get_config(base_model, return_model=True, **config_kwargs)
+                config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
                 model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
                                            config, model,
                                            gpu_id=gpu_id,
@@ -982,7 +1041,7 @@ def get_model(
         )
     else:
         with torch.device(device):
-            config, _ = get_config(base_model, **config_kwargs)
+            config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
             model = model_loader.from_pretrained(
                 base_model,
                 config=config,
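For reference, a minimal sketch (assumed caller code, not part of this commit) of the calling pattern the new raise_exception flag enables: a lenient probe of the model first, then strict calls once the model is known to be a local HF model, as get_model/get_hf_model do above.

# Usage sketch (hypothetical caller; names assumed, get_config as defined above).
base_model = 'some/hf-model-id'  # placeholder identifier

# Lenient probe: with raise_exception=False, unknown or non-HF identifiers
# (e.g. gpt4all/llama.cpp names, private repos, pure HF TGI endpoints)
# yield (None, None) instead of raising OSError.
config, _ = get_config(base_model, use_auth_token=False, trust_remote_code=True,
                       offload_folder=None, raise_exception=False)

if config is None:
    pass  # fall back to FakeTokenizer / inference-server handling, as in get_model()
else:
    # Strict call: a valid local HF model is expected here, so any OSError propagates.
    config, model = get_config(base_model, return_model=True, raise_exception=True,
                               use_auth_token=False, trust_remote_code=True,
                               offload_folder=None)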