
Commit 3ac421b

If an HF model/tokenizer is available, use it instead of FakeTokenizer (tiktoken), since we see too-large differences and failures even with a 250-token buffer, still off by another 350.
1 parent 9e12513 commit 3ac421b
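In short, get_model() now asks for the HF config up front; when one is available, the exact HF tokenizer is loaded and only a small 50-token cushion is kept, while the tiktoken-based FakeTokenizer (with its larger built-in 250-token cushion) remains only a fallback. A minimal sketch of that decision, assuming FakeTokenizer and the loader objects behave as in the diffs below (pick_tokenizer is a hypothetical helper, not part of the commit):

    def pick_tokenizer(tokenizer_base_model, config, tokenizer_loader, tokenizer_kwargs):
        # Hypothetical condensation of the new get_model() logic in generate.py.
        if config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
            # Real HF tokenizer: counts are exact, so keep only a small cushion.
            tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs)
            tokenizer.model_max_length = tokenizer.model_max_length - 50
        else:
            # No HF config (e.g. gpt4all/llama.cpp models or a bare inference endpoint):
            # fall back to the tiktoken estimate, which already budgets 250 tokens of slack.
            tokenizer = FakeTokenizer()
        return tokenizer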

6 files changed: +127, -66 lines


create_data.py

Lines changed: 1 addition & 1 deletion
@@ -1571,7 +1571,7 @@ def test_check_stats_data():
     llama_type = False
     tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
-    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
+    model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
     local_files_only = False
     resume_download = True
     use_auth_token = False

export_hf_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def do_export():
     llama_type = "llama" in BASE_MODEL
     as_pytorch = False  # False -> HF

-    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=BASE_MODEL, reward_type=False)
+    model_loader, tokenizer_loader = get_loaders(model_name=BASE_MODEL, reward_type=False, llama_type=llama_type)

     tokenizer = tokenizer_loader.from_pretrained(
         BASE_MODEL,

finetune.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def train(
     log("num_gpus: %d" % gpus)
     log("max mem: %s" % max_memory)

-    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
+    model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)

     model = model_loader.from_pretrained(
         base_model,

generate.py

Lines changed: 120 additions & 61 deletions
@@ -4,7 +4,6 @@
 import glob
 import inspect
 import queue
-import shutil
 import sys
 import os
 import time
@@ -635,13 +634,26 @@ def get_config(base_model,
                triton_attn=False,
                long_sequence=True,
                return_model=False,
+               raise_exception=False,
                ):
     from accelerate import init_empty_weights
     with init_empty_weights():
         from transformers import AutoConfig
-        config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
-                                            trust_remote_code=trust_remote_code,
-                                            offload_folder=offload_folder)
+        try:
+            config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
+                                                trust_remote_code=trust_remote_code,
+                                                offload_folder=offload_folder)
+        except OSError as e:
+            if raise_exception:
+                raise
+            if 'not a local folder and is not a valid model identifier listed on' in str(e) or '404 Client Error' in str(e):
+                # e.g. llama, gpjt, etc.
+                # e.g. HF TGI but not model on HF or private etc.
+                # HF TGI server only should really require prompt_type, not HF model state
+                return None, None
+            else:
+                raise
     if triton_attn and 'mpt-' in base_model.lower():
         config.attn_config['attn_impl'] = 'triton'
     if long_sequence:
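The new raise_exception flag lets the same helper serve two callers: get_model() probes for a config and tolerates models with no HF presence, while the HF loading path insists the config exist. A hedged sketch of the two call styles, mirroring the calls added elsewhere in this diff:

    # Probe: returns (None, None) instead of raising for non-HF models or missing repos.
    config, _ = get_config(base_model,
                           use_auth_token=use_auth_token,
                           trust_remote_code=trust_remote_code,
                           offload_folder=offload_folder,
                           raise_exception=False)

    if config is None:
        # e.g. llama.cpp/gpt4all weights or an HF TGI endpoint with no public repo;
        # downstream code then uses FakeTokenizer instead of a real HF tokenizer.
        ...
    else:
        # Strict: a true HF model must resolve, so any OSError propagates.
        config, model = get_config(base_model, return_model=True, raise_exception=True,
                                   use_auth_token=use_auth_token,
                                   trust_remote_code=trust_remote_code,
                                   offload_folder=offload_folder)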
@@ -738,20 +750,20 @@ def get_client_from_inference_server(inference_server, raise_connection_exceptio
     hf_client = None
     if headers is None:
         try:
-            print("GR Client Begin: %s" % inference_server)
+            print("GR Client Begin: %s" % inference_server, flush=True)
             # first do sanity check if alive, else gradio client takes too long by default
             requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
             gr_client = GradioClient(inference_server)
-            print("GR Client End: %s" % inference_server)
+            print("GR Client End: %s" % inference_server, flush=True)
         except (OSError, ValueError) as e:
             # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
             gr_client = None
-            print("GR Client Failed %s: %s" % (inference_server, str(e)))
+            print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True)
         except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
                 JSONDecodeError, ReadTimeout2, KeyError) as e:
             t, v, tb = sys.exc_info()
             ex = ''.join(traceback.format_exception(t, v, tb))
-            print("GR Client Failed %s: %s" % (inference_server, str(ex)))
+            print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True)
             if raise_connection_exception:
                 raise

@@ -822,28 +834,51 @@ def get_model(
     """
     if verbose:
         print("Get %s model" % base_model, flush=True)
-    if isinstance(inference_server, str) and inference_server.startswith("http"):
+
+    triton_attn = False
+    long_sequence = True
+    config_kwargs = dict(use_auth_token=use_auth_token,
+                         trust_remote_code=trust_remote_code,
+                         offload_folder=offload_folder,
+                         triton_attn=triton_attn,
+                         long_sequence=long_sequence)
+    config, _ = get_config(base_model, **config_kwargs, raise_exception=False)
+
+    if base_model in non_hf_types:
+        assert config is None, "Expected config None for %s" % base_model
+
+    llama_type_from_config = 'llama' in str(config).lower()
+    llama_type_from_name = "llama" in base_model.lower()
+    llama_type = llama_type_from_config or llama_type_from_name
+    if llama_type:
+        if verbose:
+            print("Detected as llama type from"
+                  " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
+
+    model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
+
+    tokenizer_kwargs = dict(local_files_only=local_files_only,
+                            resume_download=resume_download,
+                            use_auth_token=use_auth_token,
+                            trust_remote_code=trust_remote_code,
+                            offload_folder=offload_folder,
+                            padding_side='left',
+                            config=config,
+                            )
+    if not tokenizer_base_model:
+        tokenizer_base_model = base_model
+
+    if config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
+        tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs)
+        # sets raw (no cushion) limit
+        set_model_max_len(config, tokenizer, verbose=False)
+        # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get:
+        # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233
+        tokenizer.model_max_length = tokenizer.model_max_length - 50
+    else:
         tokenizer = FakeTokenizer()
-        try:
-            from transformers import AutoConfig
-            config, _ = get_config(base_model, use_auth_token=use_auth_token,
-                                   trust_remote_code=trust_remote_code,
-                                   offload_folder=offload_folder)
-            # sets raw (no cushion) limit
-            set_model_max_len(config, tokenizer, verbose=False)
-            # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get:
-            # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233
-            tokenizer.model_max_length = tokenizer.model_max_length - 250
-        except OSError as e:
-            t, v, tb = sys.exc_info()
-            ex = ''.join(traceback.format_exception(t, v, tb))
-            if 'not a local folder' in str(ex) or '404 Client Error' in str(ex):
-                # e.g. llama, gpjt, etc.
-                pass
-            else:
-                if base_model not in non_hf_types:
-                    raise

+    if isinstance(inference_server, str) and inference_server.startswith("http"):
         inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
         client = gr_client or hf_client
         # Don't return None, None for model, tokenizer so triggers
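Llama detection now happens once in get_model(), from either the fetched config or the model name, and the result is threaded through get_loaders() and get_hf_model() instead of being recomputed later. A small sketch of the check (detect_llama_type is a hypothetical name used only for illustration):

    def detect_llama_type(base_model, config):
        # str(config) mentions the config class, e.g. 'LlamaConfig', when the model is llama-based.
        llama_type_from_config = 'llama' in str(config).lower()
        llama_type_from_name = 'llama' in base_model.lower()
        return llama_type_from_config or llama_type_from_name

    # e.g. detect_llama_type('decapoda-research/llama-7b-hf', None) -> True (matched by name)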
@@ -852,14 +887,65 @@ def get_model(
         assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
         # Don't return None, None for model, tokenizer so triggers
         # include small token cushion
-        tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 100)
+        tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
         return inference_server, tokenizer, inference_server
     assert not inference_server, "Malformed inference_server=%s" % inference_server
     if base_model in non_hf_types:
         from gpt4all_llm import get_model_tokenizer_gpt4all
         model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
         return model, tokenizer, device

+    # get local torch-HF model
+    return get_hf_model(load_8bit=load_8bit,
+                        load_4bit=load_4bit,
+                        load_half=load_half,
+                        infer_devices=infer_devices,
+                        base_model=base_model,
+                        tokenizer_base_model=tokenizer_base_model,
+                        lora_weights=lora_weights,
+                        gpu_id=gpu_id,
+
+                        reward_type=reward_type,
+                        local_files_only=local_files_only,
+                        resume_download=resume_download,
+                        use_auth_token=use_auth_token,
+                        trust_remote_code=trust_remote_code,
+                        offload_folder=offload_folder,
+                        compile_model=compile_model,
+
+                        llama_type=llama_type,
+                        config_kwargs=config_kwargs,
+                        tokenizer_kwargs=tokenizer_kwargs,
+
+                        verbose=verbose)
+
+
+def get_hf_model(load_8bit: bool = False,
+                 load_4bit: bool = False,
+                 load_half: bool = True,
+                 infer_devices: bool = True,
+                 base_model: str = '',
+                 tokenizer_base_model: str = '',
+                 lora_weights: str = "",
+                 gpu_id: int = 0,
+
+                 reward_type: bool = None,
+                 local_files_only: bool = False,
+                 resume_download: bool = True,
+                 use_auth_token: Union[str, bool] = False,
+                 trust_remote_code: bool = True,
+                 offload_folder: str = None,
+                 compile_model: bool = True,
+
+                 llama_type: bool = False,
+                 config_kwargs=None,
+                 tokenizer_kwargs=None,
+
+                 verbose: bool = False,
+                 ):
+    assert config_kwargs is not None
+    assert tokenizer_kwargs is not None
+
     if lora_weights is not None and lora_weights.strip():
         if verbose:
             print("Get %s lora weights" % lora_weights, flush=True)
@@ -874,31 +960,13 @@ def get_model(
             "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
         )

-    from transformers import AutoConfig
-    config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
-                                        trust_remote_code=trust_remote_code,
-                                        offload_folder=offload_folder)
-    llama_type_from_config = 'llama' in str(config).lower()
-    llama_type_from_name = "llama" in base_model.lower()
-    llama_type = llama_type_from_config or llama_type_from_name
-    if llama_type:
-        if verbose:
-            print("Detected as llama type from"
-                  " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
+    model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)

-    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
-    if not tokenizer_base_model:
-        tokenizer_base_model = base_model
+    config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)

     if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
         tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
-                                                     local_files_only=local_files_only,
-                                                     resume_download=resume_download,
-                                                     use_auth_token=use_auth_token,
-                                                     trust_remote_code=trust_remote_code,
-                                                     offload_folder=offload_folder,
-                                                     padding_side='left',
-                                                     )
+                                                     **tokenizer_kwargs)
     else:
         tokenizer = tokenizer_loader

@@ -931,20 +999,11 @@ def get_model(
     model_kwargs.pop('torch_dtype', None)
     pop_unused_model_kwargs(model_kwargs)

-    triton_attn = False
-    long_sequence = True
-
-    config_kwargs = dict(use_auth_token=use_auth_token,
-                         trust_remote_code=trust_remote_code,
-                         offload_folder=offload_folder,
-                         triton_attn=triton_attn,
-                         long_sequence=long_sequence)
-
     if not lora_weights:
         with torch.device(device):

             if infer_devices:
-                config, model = get_config(base_model, return_model=True, **config_kwargs)
+                config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
                 model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
                                            config, model,
                                            gpu_id=gpu_id,
@@ -982,7 +1041,7 @@ def get_model(
                 )
     else:
         with torch.device(device):
-            config, _ = get_config(base_model, **config_kwargs)
+            config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
            model = model_loader.from_pretrained(
                 base_model,
                 config=config,

loaders.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,8 @@
-def get_loaders(llama_type, model_name, reward_type):
+def get_loaders(model_name, reward_type, llama_type=None):
     # NOTE: Some models need specific new prompt_type
     # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
+    if llama_type is None:
+        llama_type = "llama" in model_name.lower()
     if llama_type:
         from transformers import LlamaForCausalLM, LlamaTokenizer
         model_loader = LlamaForCausalLM
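With llama_type now optional and defaulting to name-based detection, most callers only need to pass the model name and reward flag; the explicit keyword remains available when the name alone is misleading. A usage sketch (model names are illustrative examples):

    # llama_type inferred as False from the name:
    model_loader, tokenizer_loader = get_loaders(model_name='h2oai/h2ogpt-oasst1-512-20b',
                                                 reward_type=False)

    # llama_type inferred as True because 'llama' appears in the name:
    model_loader, tokenizer_loader = get_loaders(model_name='decapoda-research/llama-7b-hf',
                                                 reward_type=False)

    # Callers that already know better can still force it:
    model_loader, tokenizer_loader = get_loaders(model_name='some-org/custom-llama-derivative',
                                                 reward_type=False,
                                                 llama_type=True)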

utils.py

Lines changed: 1 addition & 1 deletion
@@ -890,7 +890,7 @@ class FakeTokenizer:
     """

     def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
-        # dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 200
+        # dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250
         self.model_max_length = model_max_length - 250
         self.encoding_name = encoding_name
         # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
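The 250-token reduction exists because FakeTokenizer only estimates lengths with tiktoken's cl100k_base encoding, which can undercount relative to a model's real tokenizer (especially on number-heavy text) by a few hundred tokens. A rough sketch of the estimate it performs (assuming tiktoken is installed; the 2048 limit is just the default shown above):

    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    estimated_tokens = len(enc.encode("some numeric-heavy prompt 123456789 " * 50))
    # The model's own tokenizer may produce a few hundred more tokens than this estimate,
    # so the effective budget is the raw limit minus a 250-token safety margin.
    effective_limit = 2048 - 250
    print(estimated_tokens, "estimated tokens vs budget of", effective_limit)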
