2525from gpt_2_simple .src .accumulate import AccumulatingOptimizer
2626
2727
28- def download_file_with_progress (url_base , sub_dir , file_name ):
28+ def download_file_with_progress (url_base , sub_dir , model_name , file_name ):
2929 """General utility for incrementally downloading files from the internet
3030 with progress bar
3131 from url_base / sub_dir / filename
@@ -46,8 +46,7 @@ def download_file_with_progress(url_base, sub_dir, file_name):
4646
4747 # set to download 1MB at a time. This could be much larger with no issue
4848 DOWNLOAD_CHUNK_SIZE = 1024 * 1024
49-
50- r = requests .get (url_base + "/" + sub_dir + "/" + file_name , stream = True )
49+ r = requests .get (url_base + "/models/" + model_name + "/" + file_name , stream = True )
5150 with open (os .path .join (sub_dir , file_name ), 'wb' ) as f :
5251 file_size = int (r .headers ["content-length" ])
5352 with tqdm (ncols = 100 , desc = "Fetching " + file_name ,
@@ -57,12 +56,15 @@ def download_file_with_progress(url_base, sub_dir, file_name):
5756 pbar .update (DOWNLOAD_CHUNK_SIZE )
5857
5958
60- def download_gpt2 (model_name = '117M' ):
59+ def download_gpt2 (model_dir = 'models' , model_name = '117M' ):
6160 """Downloads the GPT-2 model into the current directory
6261 from Google Cloud Storage.
6362
6463 Parameters
6564 ----------
65+ model_dir : str
66+ parent directory of model to download
67+
6668 model_name : str
6769 name of the GPT-2 model to download.
6870 As of 22 May 2019 one of "117M" or "345M" but may later include other
@@ -71,16 +73,19 @@ def download_gpt2(model_name='117M'):
7173 Adapted from https://github.com/openai/gpt-2/blob/master/download_model.py
7274 """
7375
74- # create the models /<model_name> subdirectory if not present
75- sub_dir = os .path .join ('models' , model_name )
76+ # create the <model_dir> /<model_name> subdirectory if not present
77+ sub_dir = os .path .join (model_dir , model_name )
7678 if not os .path .exists (sub_dir ):
7779 os .makedirs (sub_dir )
7880 sub_dir = sub_dir .replace ('\\ ' , '/' ) # needed for Windows
7981
8082 for file_name in ['checkpoint' , 'encoder.json' , 'hparams.json' ,
8183 'model.ckpt.data-00000-of-00001' , 'model.ckpt.index' ,
8284 'model.ckpt.meta' , 'vocab.bpe' ]:
83- download_file_with_progress (url_base = "https://storage.googleapis.com/gpt-2" , sub_dir = sub_dir , file_name = file_name )
85+ download_file_with_progress (url_base = "https://storage.googleapis.com/gpt-2" ,
86+ sub_dir = sub_dir ,
87+ model_name = model_name ,
88+ file_name = file_name )
8489
8590
8691def start_tf_sess (threads = - 1 , server = None ):
@@ -104,12 +109,14 @@ def finetune(sess,
104109 dataset ,
105110 steps = - 1 ,
106111 model_name = '117M' ,
112+ model_dir = 'models' ,
107113 combine = 50000 ,
108114 batch_size = 1 ,
109115 learning_rate = 0.0001 ,
110116 accumulate_gradients = 5 ,
111117 restore_from = 'latest' ,
112118 run_name = 'run1' ,
119+ checkpoint_dir = 'checkpoint' ,
113120 sample_every = 100 ,
114121 sample_length = 1023 ,
115122 sample_num = 1 ,
@@ -124,11 +131,9 @@ def finetune(sess,
124131 Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
125132 See that file for parameter definitions.
126133 """
127-
128- CHECKPOINT_DIR = 'checkpoint'
129134 SAMPLE_DIR = 'samples'
130135
131- checkpoint_path = os .path .join (CHECKPOINT_DIR , run_name )
136+ checkpoint_path = os .path .join (checkpoint_dir , run_name )
132137
133138 def maketree (path ):
134139 try :
@@ -141,7 +146,7 @@ def maketree(path):
141146 for file in ['hparams.json' , 'encoder.json' , 'vocab.bpe' ]:
142147 if file not in files :
143148 try :
144- shutil .copyfile (os .path .join ('models' , model_name , file ),
149+ shutil .copyfile (os .path .join (model_dir , model_name , file ),
145150 os .path .join (checkpoint_path , file ))
146151 except FileNotFoundError as fnf_error :
147152 print ("You need to download the GPT-2 model first via download_gpt2()" )
@@ -209,10 +214,10 @@ def maketree(path):
209214 if ckpt is None :
210215 # Get fresh GPT weights if new run.
211216 ckpt = tf .train .latest_checkpoint (
212- os .path .join ('models' , model_name ))
217+ os .path .join (model_dir , model_name ))
213218 elif restore_from == 'fresh' :
214219 ckpt = tf .train .latest_checkpoint (
215- os .path .join ('models' , model_name ))
220+ os .path .join (model_dir , model_name ))
216221 else :
217222 ckpt = tf .train .latest_checkpoint (restore_from )
218223 print ('Loading checkpoint' , ckpt )
@@ -324,14 +329,13 @@ def sample_batch():
324329
325330
326331def load_gpt2 (sess ,
327- run_name = "run1" ):
332+ run_name = "run1" ,
333+ checkpoint_dir = "checkpoint" ):
328334 """Loads the model checkpoint into a TensorFlow session
329335 for repeated predictions.
330336 """
331337
332- CHECKPOINT_DIR = 'checkpoint'
333-
334- checkpoint_path = os .path .join (CHECKPOINT_DIR , run_name )
338+ checkpoint_path = os .path .join (checkpoint_dir , run_name )
335339
336340 hparams = model .default_hparams ()
337341 with open (os .path .join (checkpoint_path , 'hparams.json' )) as f :
@@ -350,6 +354,8 @@ def load_gpt2(sess,
350354
351355def generate (sess ,
352356 run_name = 'run1' ,
357+ checkpoint_dir = 'checkpoint' ,
358+ sample_dir = 'samples' ,
353359 return_as_list = False ,
354360 truncate = None ,
355361 destination_path = None ,
@@ -378,10 +384,7 @@ def generate(sess,
378384 if prefix == '' :
379385 prefix = None
380386
381- CHECKPOINT_DIR = 'checkpoint'
382- SAMPLE_DIR = 'samples'
383-
384- checkpoint_path = os .path .join (CHECKPOINT_DIR , run_name )
387+ checkpoint_path = os .path .join (checkpoint_dir , run_name )
385388
386389 enc = encoder .get_encoder (checkpoint_path )
387390 hparams = model .default_hparams ()
@@ -448,6 +451,7 @@ def generate(sess,
448451
449452def generate_to_file (sess ,
450453 run_name = 'run1' ,
454+ checkpoint_dir = 'checkpoint' ,
451455 truncate = None ,
452456 destination_path = 'gpt_2_gen_texts.txt' ,
453457 sample_delim = '=' * 20 + '\n ' ,
@@ -467,21 +471,22 @@ def generate_to_file(sess,
467471 Adapted from https://github.com/minimaxir/textgenrnn/blob/master/textgenrnn/textgenrnn.py
468472 """
469473
470- generate (sess ,
471- run_name ,
472- False ,
473- truncate ,
474- destination_path ,
475- sample_delim ,
476- prefix ,
477- seed ,
478- nsamples ,
479- batch_size ,
480- length ,
481- temperature ,
482- top_k ,
483- top_p ,
484- include_prefix )
474+ generate (sess = sess ,
475+ run_name = run_name ,
476+ checkpoint_dir = checkpoint_dir ,
477+ return_as_list = False ,
478+ truncate = truncate ,
479+ destination_path = destination_path ,
480+ sample_delim = sample_delim ,
481+ prefix = prefix ,
482+ seed = seed ,
483+ nsamples = nsamples ,
484+ batch_size = batch_size ,
485+ length = length ,
486+ temperature = temperature ,
487+ top_k = top_k ,
488+ top_p = top_p ,
489+ include_prefix = include_prefix )
485490
486491
487492def mount_gdrive ():
@@ -552,13 +557,13 @@ def copy_file_from_gdrive(file_path):
552557 shutil .copyfile ("/content/drive/My Drive/" + file_path , file_path )
553558
554559
555- def is_gpt2_downloaded (model_name = '117M' ):
560+ def is_gpt2_downloaded (model_dir = 'models' , model_name = '117M' ):
556561 """Checks if the original model + associated files are present in folder."""
557562
558563 for filename in ['checkpoint' , 'encoder.json' , 'hparams.json' ,
559564 'model.ckpt.data-00000-of-00001' , 'model.ckpt.index' ,
560565 'model.ckpt.meta' , 'vocab.bpe' ]:
561- if not os .path .isfile (os .path .join ("models" , model_name , filename )):
566+ if not os .path .isfile (os .path .join (model_dir , model_name , filename )):
562567 return False
563568 return True
564569
@@ -579,7 +584,7 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,
579584 w .write (start_token + row [0 ] + end_token + "\n " )
580585
581586
582- def encode_dataset (file_path , out_path = 'text_encoded.npz' ,
587+ def encode_dataset (file_path , model_dir = 'models' , out_path = 'text_encoded.npz' ,
583588 model_name = "117M" ,
584589 combine = 50000 ):
585590 """Preencodes a text document into chunks and compresses it,
@@ -588,7 +593,7 @@ def encode_dataset(file_path, out_path='text_encoded.npz',
588593 Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/encode.py
589594 """
590595
591- model_path = os .path .join ('models' , model_name )
596+ model_path = os .path .join (model_dir , model_name )
592597 enc = encoder .get_encoder (model_path )
593598 print ('Reading files' )
594599 chunks = load_dataset (enc , file_path , combine )
@@ -610,9 +615,15 @@ def cmd():
610615 parser .add_argument (
611616 '--run_name' , help = "[finetune/generate] Run number to save/load the model" ,
612617 nargs = '?' , default = 'run1' )
618+ parser .add_argument (
619+ '--checkpoint_dir' , help = "[finetune] Path of the checkpoint directory" ,
620+ nargs = '?' , default = 'checkpoint' )
613621 parser .add_argument (
614622 '--model_name' , help = "[finetune] Name of the GPT-2 model to finetune" ,
615623 nargs = '?' , default = '117M' )
624+ parser .add_argument (
625+ '--model_dir' , help = "[finetune] Path of directory of the GPT-2 model to finetune" ,
626+ nargs = '?' , default = 'models' )
616627 parser .add_argument (
617628 '--dataset' , help = "[finetune] Path to the source text." ,
618629 nargs = '?' , default = None )
@@ -683,7 +694,9 @@ def cmd():
683694 assert args .dataset is not None , "You need to provide a dataset."
684695
685696 cmd_finetune (dataset = args .dataset , run_name = args .run_name ,
697+ checkpoint_dir = args .checkpoint_dir ,
686698 model_name = args .model_name ,
699+ model_dir = args .model_dir ,
687700 steps = args .steps , restore_from = args .restore_from ,
688701 sample_every = args .sample_every ,
689702 save_every = args .save_every ,
@@ -696,20 +709,23 @@ def cmd():
696709 prefix = args .prefix , truncate = args .truncate ,
697710 include_prefix = args .include_prefix ,
698711 sample_delim = args .sample_delim , run_name = args .run_name ,
712+ checkpoint_dir = args .checkpoint_dir ,
699713 top_k = args .top_k , top_p = args .top_p )
700714
701715
702- def cmd_finetune (dataset , run_name , model_name , steps ,
716+ def cmd_finetune (dataset , run_name , checkpoint_dir , model_name , model_dir , steps ,
703717 restore_from , sample_every ,
704718 save_every , print_every , overwrite ):
705719 """Wrapper script for finetuning the model via the CLI."""
706720
707- if not is_gpt2_downloaded (model_name = model_name ):
708- download_gpt2 (model_name = model_name )
721+ if not is_gpt2_downloaded (model_dir = model_dir , model_name = model_name ):
722+ download_gpt2 (model_dir = model_dir , model_name = model_name )
709723
710724 sess = start_tf_sess ()
711725 finetune (sess , dataset = dataset , run_name = run_name ,
726+ checkpoint_dir = checkpoint_dir ,
712727 model_name = model_name ,
728+ model_dir = model_dir ,
713729 steps = steps , restore_from = restore_from ,
714730 sample_every = sample_every , save_every = save_every ,
715731 print_every = print_every ,
@@ -720,14 +736,15 @@ def cmd_generate(nfiles, nsamples, folder,
720736 length , temperature , batch_size ,
721737 prefix , truncate , include_prefix ,
722738 sample_delim , run_name ,
739+ checkpoint_dir ,
723740 top_k , top_p ):
724741 """Wrapper script for generating text via the CLI.
725742 The files are generated into a folder, which can be downloaded
726743 recursively by downloading the entire folder.
727744 """
728745
729746 sess = start_tf_sess ()
730- load_gpt2 (sess , run_name = run_name )
747+ load_gpt2 (sess , run_name = run_name , checkpoint_dir = checkpoint_dir )
731748
732749 try :
733750 os .mkdir (folder )
0 commit comments