
Commit 9418d1b

Parameter for loading transformer & tokenizer with local_files_only
Add local_files_only as a parameter to the FoundationCache, in case that winds up being useful. The NONE download_method now doesn't download anything, including HF transformers, rather than adding a new download mode for this.
1 parent 076f7e3 commit 9418d1b
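For context, local_files_only is the standard Hugging Face from_pretrained flag that this commit threads through Stanza's loaders. A rough sketch of its effect (not part of the diff; bert-base-cased is just an example model name):

    from transformers import AutoModel, AutoTokenizer

    # With local_files_only=True, transformers only looks in the local cache
    # (e.g. ~/.cache/huggingface) and raises an error instead of contacting the Hub.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", local_files_only=True)
    model = AutoModel.from_pretrained("bert-base-cased", local_files_only=True)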

File tree

stanza/models/common/bert_embedding.py
stanza/models/common/foundation_cache.py
stanza/models/coref/model.py
stanza/pipeline/core.py

4 files changed: +37 -32 lines changed

stanza/models/common/bert_embedding.py

Lines changed: 5 additions & 4 deletions
@@ -32,7 +32,7 @@ def update_max_length(model_name, tokenizer):
                       'NYTK/electra-small-discriminator-hungarian'):
         tokenizer.model_max_length = 512
 
-def load_tokenizer(model_name, tokenizer_kwargs=None):
+def load_tokenizer(model_name, tokenizer_kwargs=None, local_files_only=False):
     if model_name:
         # note that use_fast is the default
         try:
@@ -44,20 +44,21 @@ def load_tokenizer(model_name, tokenizer_kwargs=None):
             bert_args["add_prefix_space"] = True
         if tokenizer_kwargs:
             bert_args.update(tokenizer_kwargs)
+        bert_args['local_files_only'] = local_files_only
         bert_tokenizer = AutoTokenizer.from_pretrained(model_name, **bert_args)
         update_max_length(model_name, bert_tokenizer)
         return bert_tokenizer
     return None
 
-def load_bert(model_name):
+def load_bert(model_name, tokenizer_kwargs=None, local_files_only=False):
     if model_name:
         # such as: "vinai/phobert-base"
         try:
             from transformers import AutoModel
         except ImportError:
             raise ImportError("Please install transformers library for BERT support! Try `pip install transformers`.")
-        bert_model = AutoModel.from_pretrained(model_name)
-        bert_tokenizer = load_tokenizer(model_name)
+        bert_model = AutoModel.from_pretrained(model_name, local_files_only=local_files_only)
+        bert_tokenizer = load_tokenizer(model_name, tokenizer_kwargs=tokenizer_kwargs, local_files_only=local_files_only)
         return bert_model, bert_tokenizer
     return None, None
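A minimal usage sketch of the updated loaders (not part of the diff; it assumes the model is already in the local HF cache):

    from stanza.models.common.bert_embedding import load_bert, load_tokenizer

    # Load both model and tokenizer strictly from the local HF cache;
    # transformers raises an error instead of downloading if they are missing.
    bert_model, bert_tokenizer = load_bert("vinai/phobert-base", local_files_only=True)

    # or just the tokenizer
    tokenizer = load_tokenizer("vinai/phobert-base", local_files_only=True)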

stanza/models/common/foundation_cache.py

Lines changed: 18 additions & 15 deletions
@@ -16,7 +16,7 @@
 BertRecord = namedtuple('BertRecord', ['model', 'tokenizer', 'peft_ids'])
 
 class FoundationCache:
-    def __init__(self, other=None):
+    def __init__(self, other=None, local_files_only=False):
         if other is None:
             self.bert = {}
             self.charlms = {}
@@ -29,12 +29,13 @@ def __init__(self, other=None):
             self.charlms = other.charlms
             self.pretrains = other.pretrains
             self.lock = other.lock
+        self.local_files_only=local_files_only
 
-    def load_bert(self, transformer_name):
-        m, t, _ = self.load_bert_with_peft(transformer_name, None)
+    def load_bert(self, transformer_name, local_files_only=None):
+        m, t, _ = self.load_bert_with_peft(transformer_name, None, local_files_only=local_files_only)
         return m, t
 
-    def load_bert_with_peft(self, transformer_name, peft_name):
+    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
         """
         Load a transformer only once
 
@@ -44,7 +45,9 @@ def load_bert_with_peft(self, transformer_name, peft_name):
             return None, None, None
         with self.lock:
             if transformer_name not in self.bert:
-                model, tokenizer = bert_embedding.load_bert(transformer_name)
+                if local_files_only is None:
+                    local_files_only = self.local_files_only
+                model, tokenizer = bert_embedding.load_bert(transformer_name, local_files_only=local_files_only)
                 self.bert[transformer_name] = BertRecord(model, tokenizer, {})
             else:
                 logger.debug("Reusing bert %s", transformer_name)
@@ -98,26 +101,26 @@ class NoTransformerFoundationCache(FoundationCache):
     since it will then have the finetuned weights for other models
     which don't want them
     """
-    def load_bert(self, transformer_name):
-        return load_bert(transformer_name)
+    def load_bert(self, transformer_name, local_files_only=None):
+        return load_bert(transformer_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)
 
-    def load_bert_with_peft(self, transformer_name, peft_name):
-        return load_bert_with_peft(transformer_name, peft_name)
+    def load_bert_with_peft(self, transformer_name, peft_name, local_files_only=None):
+        return load_bert_with_peft(transformer_name, peft_name, local_files_only=self.local_files_only if local_files_only is None else local_files_only)
 
-def load_bert(model_name, foundation_cache=None):
+def load_bert(model_name, foundation_cache=None, local_files_only=None):
     """
     Load a bert, possibly using a foundation cache, ignoring the cache if None
     """
     if foundation_cache is None:
-        return bert_embedding.load_bert(model_name)
+        return bert_embedding.load_bert(model_name, local_files_only=local_files_only)
     else:
-        return foundation_cache.load_bert(model_name)
+        return foundation_cache.load_bert(model_name, local_files_only=local_files_only)
 
-def load_bert_with_peft(model_name, peft_name, foundation_cache=None):
+def load_bert_with_peft(model_name, peft_name, foundation_cache=None, local_files_only=None):
     if foundation_cache is None:
-        m, t = bert_embedding.load_bert(model_name)
+        m, t = bert_embedding.load_bert(model_name, local_files_only=local_files_only)
         return m, t, peft_name
-    return foundation_cache.load_bert_with_peft(model_name, peft_name)
+    return foundation_cache.load_bert_with_peft(model_name, peft_name, local_files_only=local_files_only)
 
 def load_charlm(charlm_file, foundation_cache=None, finetune=False):
     if not charlm_file:
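A hedged sketch of how the cache-level default and the per-call override interact, based on the code above (the second model name is just a placeholder):

    from stanza.models.common.foundation_cache import FoundationCache

    # Cache configured for offline use: every transformer loaded through it
    # falls back to local_files_only=True unless the caller says otherwise.
    cache = FoundationCache(local_files_only=True)

    # Uses the cache default (local files only)
    model, tokenizer = cache.load_bert("vinai/phobert-base")

    # Per-call override: explicitly allow a download for this one transformer
    model2, tokenizer2 = cache.load_bert("xlm-roberta-base", local_files_only=False)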

stanza/models/coref/model.py

Lines changed: 2 additions & 1 deletion
@@ -560,7 +560,8 @@ def _build_model(self, foundation_cache):
         tokenizer_kwargs = self.config.tokenizer_kwargs.get(base_bert_name, {})
         if tokenizer_kwargs:
             logger.debug(f"Using tokenizer kwargs: {tokenizer_kwargs}")
-        self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs)
+        # we just downloaded the tokenizer, so for simplicity, we don't make another request to HF
+        self.tokenizer = load_tokenizer(self.config.bert_model, tokenizer_kwargs, local_files_only=True)
 
         if self.config.bert_finetune or (hasattr(self.config, 'lora') and self.config.lora):
             self.bert = self.bert.train()

stanza/pipeline/core.py

Lines changed: 12 additions & 12 deletions
@@ -40,7 +40,7 @@ class DownloadMethod(Enum):
     """
     Determines a couple options on how to download resources for the pipeline.
 
-    NONE will not download anything, probably resulting in failure if the resources aren't already in place.
+    NONE will not download anything, including HF transformers, probably resulting in failure if the resources aren't already in place.
     REUSE_RESOURCES will reuse the existing resources.json and models, but will download any missing models.
     DOWNLOAD_RESOURCES will download a new resources.json and will overwrite any out of date models.
     """
@@ -201,16 +201,9 @@ def __init__(self,
         # set global logging level
         set_logging_level(logging_level, verbose)
 
-        # processors can use this to save on the effort of loading
-        # large sub-models, such as pretrained embeddings, bert, etc
-        if foundation_cache is None:
-            self.foundation_cache = FoundationCache()
-        else:
-            self.foundation_cache = foundation_cache
-
-        download_method = normalize_download_method(download_method)
-        if (download_method is DownloadMethod.DOWNLOAD_RESOURCES or
-            (download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, "resources.json")))):
+        self.download_method = normalize_download_method(download_method)
+        if (self.download_method is DownloadMethod.DOWNLOAD_RESOURCES or
+            (self.download_method is DownloadMethod.REUSE_RESOURCES and not os.path.exists(os.path.join(self.dir, "resources.json")))):
             logger.info("Checking for updates to resources.json in case models have been updated. Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES")
             download_resources_json(self.dir,
                                     resources_url=resources_url,
@@ -219,6 +212,13 @@ def __init__(self,
                                     resources_filepath=resources_filepath,
                                     proxies=proxies)
 
+        # processors can use this to save on the effort of loading
+        # large sub-models, such as pretrained embeddings, bert, etc
+        if foundation_cache is None:
+            self.foundation_cache = FoundationCache(local_files_only=(self.download_method is DownloadMethod.NONE))
+        else:
+            self.foundation_cache = FoundationCache(foundation_cache, local_files_only=(self.download_method is DownloadMethod.NONE))
+
         # process different pipeline parameters
         lang, self.dir, package, processors = process_pipeline_parameters(lang, self.dir, package, processors)
 
@@ -241,7 +241,7 @@ def __init__(self,
         if lang in resources:
            self.load_list = maintain_processor_list(resources, lang, package, processors, maybe_add_mwt=(not kwargs.get("tokenize_pretokenized")))
            self.load_list = add_dependencies(resources, lang, self.load_list)
-            if download_method is not DownloadMethod.NONE:
+            if self.download_method is not DownloadMethod.NONE:
                 # skip processors which aren't downloaded from our collection
                 download_list = [x for x in self.load_list if x[0] in resources.get(lang, {})]
                 # skip variants
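Taken together with the FoundationCache change, a minimal sketch of the intended end-to-end behavior (an assumption about usage, not part of the diff; it presumes the Stanza models and the HF transformer are already on disk):

    import stanza
    from stanza.pipeline.core import DownloadMethod

    # DownloadMethod.NONE: no resources.json refresh, no Stanza model downloads,
    # and the pipeline's FoundationCache is built with local_files_only=True,
    # so HF transformers are also loaded only from the local cache.
    nlp = stanza.Pipeline("en", download_method=DownloadMethod.NONE)
    doc = nlp("Stanza runs fully offline here.")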
