Skip to content

Commit 01f0de5

Browse files
authored
Merge pull request #90 from IWillPull/master
added optional "--model_dir" and "--checkpoint_dir" parameters
2 parents 7621b10 + ecf55f3 commit 01f0de5

File tree

1 file changed

+61
-44
lines changed

1 file changed

+61
-44
lines changed

gpt_2_simple/gpt_2.py

Lines changed: 61 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from gpt_2_simple.src.accumulate import AccumulatingOptimizer
2626

2727

28-
def download_file_with_progress(url_base, sub_dir, file_name):
28+
def download_file_with_progress(url_base, sub_dir, model_name, file_name):
2929
"""General utility for incrementally downloading files from the internet
3030
with progress bar
3131
from url_base / sub_dir / filename
@@ -46,8 +46,7 @@ def download_file_with_progress(url_base, sub_dir, file_name):
4646

4747
# set to download 1MB at a time. This could be much larger with no issue
4848
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
49-
50-
r = requests.get(url_base + "/" + sub_dir + "/" + file_name, stream=True)
49+
r = requests.get(url_base + "/models/" + model_name + "/" + file_name, stream=True)
5150
with open(os.path.join(sub_dir, file_name), 'wb') as f:
5251
file_size = int(r.headers["content-length"])
5352
with tqdm(ncols=100, desc="Fetching " + file_name,
@@ -57,12 +56,15 @@ def download_file_with_progress(url_base, sub_dir, file_name):
5756
pbar.update(DOWNLOAD_CHUNK_SIZE)
5857

5958

60-
def download_gpt2(model_name='117M'):
59+
def download_gpt2(model_dir='models', model_name='117M'):
6160
"""Downloads the GPT-2 model into the current directory
6261
from Google Cloud Storage.
6362
6463
Parameters
6564
----------
65+
model_dir : str
66+
parent directory of model to download
67+
6668
model_name : str
6769
name of the GPT-2 model to download.
6870
As of 22 May 2019 one of "117M" or "345M" but may later include other
@@ -71,16 +73,19 @@ def download_gpt2(model_name='117M'):
7173
Adapted from https://github.com/openai/gpt-2/blob/master/download_model.py
7274
"""
7375

74-
# create the models/<model_name> subdirectory if not present
75-
sub_dir = os.path.join('models', model_name)
76+
# create the <model_dir>/<model_name> subdirectory if not present
77+
sub_dir = os.path.join(model_dir, model_name)
7678
if not os.path.exists(sub_dir):
7779
os.makedirs(sub_dir)
7880
sub_dir = sub_dir.replace('\\', '/') # needed for Windows
7981

8082
for file_name in ['checkpoint', 'encoder.json', 'hparams.json',
8183
'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
8284
'model.ckpt.meta', 'vocab.bpe']:
83-
download_file_with_progress(url_base="https://storage.googleapis.com/gpt-2", sub_dir=sub_dir, file_name=file_name)
85+
download_file_with_progress(url_base="https://storage.googleapis.com/gpt-2",
86+
sub_dir=sub_dir,
87+
model_name=model_name,
88+
file_name=file_name)
8489

8590

8691
def start_tf_sess(threads=-1, server=None):
@@ -104,12 +109,14 @@ def finetune(sess,
104109
dataset,
105110
steps=-1,
106111
model_name='117M',
112+
model_dir='models',
107113
combine=50000,
108114
batch_size=1,
109115
learning_rate=0.0001,
110116
accumulate_gradients=5,
111117
restore_from='latest',
112118
run_name='run1',
119+
checkpoint_dir='checkpoint',
113120
sample_every=100,
114121
sample_length=1023,
115122
sample_num=1,
@@ -124,11 +131,9 @@ def finetune(sess,
124131
Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
125132
See that file for parameter definitions.
126133
"""
127-
128-
CHECKPOINT_DIR = 'checkpoint'
129134
SAMPLE_DIR = 'samples'
130135

131-
checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)
136+
checkpoint_path = os.path.join(checkpoint_dir, run_name)
132137

133138
def maketree(path):
134139
try:
@@ -141,7 +146,7 @@ def maketree(path):
141146
for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
142147
if file not in files:
143148
try:
144-
shutil.copyfile(os.path.join('models', model_name, file),
149+
shutil.copyfile(os.path.join(model_dir, model_name, file),
145150
os.path.join(checkpoint_path, file))
146151
except FileNotFoundError as fnf_error:
147152
print("You need to download the GPT-2 model first via download_gpt2()")
@@ -209,10 +214,10 @@ def maketree(path):
209214
if ckpt is None:
210215
# Get fresh GPT weights if new run.
211216
ckpt = tf.train.latest_checkpoint(
212-
os.path.join('models', model_name))
217+
os.path.join(model_dir, model_name))
213218
elif restore_from == 'fresh':
214219
ckpt = tf.train.latest_checkpoint(
215-
os.path.join('models', model_name))
220+
os.path.join(model_dir, model_name))
216221
else:
217222
ckpt = tf.train.latest_checkpoint(restore_from)
218223
print('Loading checkpoint', ckpt)
@@ -324,14 +329,13 @@ def sample_batch():
324329

325330

326331
def load_gpt2(sess,
327-
run_name="run1"):
332+
run_name="run1",
333+
checkpoint_dir="checkpoint"):
328334
"""Loads the model checkpoint into a TensorFlow session
329335
for repeated predictions.
330336
"""
331337

332-
CHECKPOINT_DIR = 'checkpoint'
333-
334-
checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)
338+
checkpoint_path = os.path.join(checkpoint_dir, run_name)
335339

336340
hparams = model.default_hparams()
337341
with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
@@ -350,6 +354,8 @@ def load_gpt2(sess,
350354

351355
def generate(sess,
352356
run_name='run1',
357+
checkpoint_dir='checkpoint',
358+
sample_dir='samples',
353359
return_as_list=False,
354360
truncate=None,
355361
destination_path=None,
@@ -378,10 +384,7 @@ def generate(sess,
378384
if prefix == '':
379385
prefix = None
380386

381-
CHECKPOINT_DIR = 'checkpoint'
382-
SAMPLE_DIR = 'samples'
383-
384-
checkpoint_path = os.path.join(CHECKPOINT_DIR, run_name)
387+
checkpoint_path = os.path.join(checkpoint_dir, run_name)
385388

386389
enc = encoder.get_encoder(checkpoint_path)
387390
hparams = model.default_hparams()
@@ -448,6 +451,7 @@ def generate(sess,
448451

449452
def generate_to_file(sess,
450453
run_name='run1',
454+
checkpoint_dir='checkpoint',
451455
truncate=None,
452456
destination_path='gpt_2_gen_texts.txt',
453457
sample_delim='=' * 20 + '\n',
@@ -467,21 +471,22 @@ def generate_to_file(sess,
467471
Adapted from https://github.com/minimaxir/textgenrnn/blob/master/textgenrnn/textgenrnn.py
468472
"""
469473

470-
generate(sess,
471-
run_name,
472-
False,
473-
truncate,
474-
destination_path,
475-
sample_delim,
476-
prefix,
477-
seed,
478-
nsamples,
479-
batch_size,
480-
length,
481-
temperature,
482-
top_k,
483-
top_p,
484-
include_prefix)
474+
generate(sess=sess,
475+
run_name=run_name,
476+
checkpoint_dir=checkpoint_dir,
477+
return_as_list=False,
478+
truncate=truncate,
479+
destination_path=destination_path,
480+
sample_delim=sample_delim,
481+
prefix=prefix,
482+
seed=seed,
483+
nsamples=nsamples,
484+
batch_size=batch_size,
485+
length=length,
486+
temperature=temperature,
487+
top_k=top_k,
488+
top_p=top_p,
489+
include_prefix=include_prefix)
485490

486491

487492
def mount_gdrive():
@@ -552,13 +557,13 @@ def copy_file_from_gdrive(file_path):
552557
shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)
553558

554559

555-
def is_gpt2_downloaded(model_name='117M'):
560+
def is_gpt2_downloaded(model_dir='models', model_name='117M'):
556561
"""Checks if the original model + associated files are present in folder."""
557562

558563
for filename in ['checkpoint', 'encoder.json', 'hparams.json',
559564
'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
560565
'model.ckpt.meta', 'vocab.bpe']:
561-
if not os.path.isfile(os.path.join("models", model_name, filename)):
566+
if not os.path.isfile(os.path.join(model_dir, model_name, filename)):
562567
return False
563568
return True
564569

@@ -579,7 +584,7 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,
579584
w.write(start_token + row[0] + end_token + "\n")
580585

581586

582-
def encode_dataset(file_path, out_path='text_encoded.npz',
587+
def encode_dataset(file_path, model_dir='models', out_path='text_encoded.npz',
583588
model_name="117M",
584589
combine=50000):
585590
"""Preencodes a text document into chunks and compresses it,
@@ -588,7 +593,7 @@ def encode_dataset(file_path, out_path='text_encoded.npz',
588593
Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/encode.py
589594
"""
590595

591-
model_path = os.path.join('models', model_name)
596+
model_path = os.path.join(model_dir, model_name)
592597
enc = encoder.get_encoder(model_path)
593598
print('Reading files')
594599
chunks = load_dataset(enc, file_path, combine)
@@ -610,9 +615,15 @@ def cmd():
610615
parser.add_argument(
611616
'--run_name', help="[finetune/generate] Run number to save/load the model",
612617
nargs='?', default='run1')
618+
parser.add_argument(
619+
'--checkpoint_dir', help="[finetune] Path of the checkpoint directory",
620+
nargs='?', default='checkpoint')
613621
parser.add_argument(
614622
'--model_name', help="[finetune] Name of the GPT-2 model to finetune",
615623
nargs='?', default='117M')
624+
parser.add_argument(
625+
'--model_dir', help="[finetune] Path of directory of the GPT-2 model to finetune",
626+
nargs='?', default='models')
616627
parser.add_argument(
617628
'--dataset', help="[finetune] Path to the source text.",
618629
nargs='?', default=None)
@@ -683,7 +694,9 @@ def cmd():
683694
assert args.dataset is not None, "You need to provide a dataset."
684695

685696
cmd_finetune(dataset=args.dataset, run_name=args.run_name,
697+
checkpoint_dir=args.checkpoint_dir,
686698
model_name=args.model_name,
699+
model_dir=args.model_dir,
687700
steps=args.steps, restore_from=args.restore_from,
688701
sample_every=args.sample_every,
689702
save_every=args.save_every,
@@ -696,20 +709,23 @@ def cmd():
696709
prefix=args.prefix, truncate=args.truncate,
697710
include_prefix=args.include_prefix,
698711
sample_delim=args.sample_delim, run_name=args.run_name,
712+
checkpoint_dir=args.checkpoint_dir,
699713
top_k=args.top_k, top_p=args.top_p)
700714

701715

702-
def cmd_finetune(dataset, run_name, model_name, steps,
716+
def cmd_finetune(dataset, run_name, checkpoint_dir, model_name, model_dir, steps,
703717
restore_from, sample_every,
704718
save_every, print_every, overwrite):
705719
"""Wrapper script for finetuning the model via the CLI."""
706720

707-
if not is_gpt2_downloaded(model_name=model_name):
708-
download_gpt2(model_name=model_name)
721+
if not is_gpt2_downloaded(model_dir=model_dir, model_name=model_name):
722+
download_gpt2(model_dir=model_dir, model_name=model_name)
709723

710724
sess = start_tf_sess()
711725
finetune(sess, dataset=dataset, run_name=run_name,
726+
checkpoint_dir=checkpoint_dir,
712727
model_name=model_name,
728+
model_dir=model_dir,
713729
steps=steps, restore_from=restore_from,
714730
sample_every=sample_every, save_every=save_every,
715731
print_every=print_every,
@@ -720,14 +736,15 @@ def cmd_generate(nfiles, nsamples, folder,
720736
length, temperature, batch_size,
721737
prefix, truncate, include_prefix,
722738
sample_delim, run_name,
739+
checkpoint_dir,
723740
top_k, top_p):
724741
"""Wrapper script for generating text via the CLI.
725742
The files are generated into a folder, which can be downloaded
726743
recursively by downloading the entire folder.
727744
"""
728745

729746
sess = start_tf_sess()
730-
load_gpt2(sess, run_name=run_name)
747+
load_gpt2(sess, run_name=run_name, checkpoint_dir=checkpoint_dir)
731748

732749
try:
733750
os.mkdir(folder)

0 commit comments

Comments
 (0)