import argparse
from datetime import datetime
import random
import sys
import time

import boto3
from botocore.compat import total_seconds
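
# Submits a job to AWS Batch that checks out the requested GluonNLP ref and runs a command,
# optionally tailing its CloudWatch logs until completion.
# Example invocation (script name and argument values are illustrative):
#   python submit-job.py --name bert-test --source-ref master --work-dir scripts/bert \
#       --command 'python train.py | tee stdout.log' --wait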
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--profile', help='profile name of the AWS account.', type=str,
                    default=None)
parser.add_argument('--region', help='default region when creating new connections', type=str,
                    default=None)
parser.add_argument('--name', help='name of the job', type=str, default='dummy')
parser.add_argument('--job-queue', help='name of the job queue to submit this job to', type=str,
                    default='gluon-nlp-jobs')
parser.add_argument('--job-definition', help='name of the job definition', type=str,
                    default='gluon-nlp-jobs:6')
parser.add_argument('--source-ref',
                    help='ref in the GluonNLP main GitHub repo, e.g. master, refs/pull/500/head',
                    type=str, default='master')
parser.add_argument('--work-dir',
                    help='working directory inside the repo, e.g. scripts/sentiment_analysis',
                    type=str, default='scripts/bert')
parser.add_argument('--saved-output',
                    help='output to be saved, relative to the working directory; '
                         'it can be either a single file or a directory',
                    type=str, default='.')
parser.add_argument('--save-path',
                    help='S3 path where the saved files are uploaded.',
                    type=str, default='batch/temp/{}'.format(datetime.now().isoformat()))
parser.add_argument('--conda-env',
                    help='conda environment preset to use.',
                    type=str, default='gpu/py3')
parser.add_argument('--command', help='command to run', type=str,
                    default='git rev-parse HEAD | tee stdout.log')
parser.add_argument('--remote',
                    help='git repo address, e.g. https://github.com/dmlc/gluon-nlp',
                    type=str, default='https://github.com/dmlc/gluon-nlp')
parser.add_argument('--wait', help='block and wait until the job completes; '
                                   'exits with a non-zero code if the job fails.',
                    action='store_true')
parser.add_argument('--timeout', help='job timeout in seconds', default=None, type=int)

args = parser.parse_args()

session = boto3.Session(profile_name=args.profile, region_name=args.region)
batch, cloudwatch = [session.client(service_name=sn) for sn in ['batch', 'logs']]

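
# Print CloudWatch log events from the given stream starting at startTime, following the
# forward pagination token; returns the timestamp of the last event printed.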
def printLogs(logGroupName, logStreamName, startTime):
    kwargs = {'logGroupName': logGroupName,
              'logStreamName': logStreamName,
              'startTime': startTime,
              'startFromHead': True}

    lastTimestamp = 0
    while True:
        logEvents = cloudwatch.get_log_events(**kwargs)

        for event in logEvents['events']:
            lastTimestamp = event['timestamp']
            timestamp = datetime.utcfromtimestamp(lastTimestamp / 1000.0).isoformat()
            print('[{}] {}'.format((timestamp + '.000')[:23] + 'Z', event['message']))

        nextToken = logEvents['nextForwardToken']
        if nextToken and kwargs.get('nextToken') != nextToken:
            kwargs['nextToken'] = nextToken
        else:
            break
    return lastTimestamp

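
# Return the name of the log stream whose name starts with '<jobName>/<jobId>', or '' if the
# stream does not exist yet.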
def getLogStream(logGroupName, jobName, jobId):
    response = cloudwatch.describe_log_streams(
        logGroupName=logGroupName,
        logStreamNamePrefix=jobName + '/' + jobId
    )
    logStreams = response['logStreams']
    if not logStreams:
        return ''
    else:
        return logStreams[0]['logStreamName']


def nowInMillis():
    # Current time as milliseconds since the Unix epoch (int() replaces the Python 2 long()).
    endTime = int(total_seconds(datetime.utcnow() - datetime(1970, 1, 1))) * 1000
    return endTime

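
# Submit the Batch job and, when --wait is set, poll its status and stream its CloudWatch logs
# until the job reaches SUCCEEDED or FAILED.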
def main():
    spin = ['-', '/', '|', '\\', '-', '/', '|', '\\']
    logGroupName = '/aws/batch/job'

    jobName = args.name
    jobQueue = args.job_queue
    jobDefinition = args.job_definition
    command = args.command.split()
    wait = args.wait

    parameters = {
        'SOURCE_REF': args.source_ref,
        'WORK_DIR': args.work_dir,
        'SAVED_OUTPUT': args.saved_output,
        'SAVE_PATH': args.save_path,
        'CONDA_ENV': args.conda_env,
        'COMMAND': args.command,
        'REMOTE': args.remote
    }
    kwargs = dict(
        jobName=jobName,
        jobQueue=jobQueue,
        jobDefinition=jobDefinition,
        parameters=parameters,
    )
    if args.timeout is not None:
        kwargs['timeout'] = {'attemptDurationSeconds': args.timeout}
    submitJobResponse = batch.submit_job(**kwargs)

    jobId = submitJobResponse['jobId']
    print('Submitted job [{} - {}] to the job queue [{}]'.format(jobName, jobId, jobQueue))

    spinner = 0
    running = False
    status_set = set()
    startTime = 0

    while wait:
        time.sleep(random.randint(5, 10))
        describeJobsResponse = batch.describe_jobs(jobs=[jobId])
        status = describeJobsResponse['jobs'][0]['status']
        if status == 'SUCCEEDED' or status == 'FAILED':
            print('=' * 80)
            print('Job [{} - {}] {}'.format(jobName, jobId, status))
            # Exit code 1 if the job failed, 0 if it succeeded.
            sys.exit(1 if status == 'FAILED' else 0)

        elif status == 'RUNNING':
            logStreamName = getLogStream(logGroupName, jobName, jobId)
            if not running:
                running = True
                print('\rJob [{} - {}] is RUNNING.'.format(jobName, jobId))
                if logStreamName:
                    print('Output [{}]:\n {}'.format(logStreamName, '=' * 80))
            if logStreamName:
                startTime = printLogs(logGroupName, logStreamName, startTime) + 1
        elif status not in status_set:
            status_set.add(status)
            # Overwrite the same terminal line with a spinner while the job is still queued.
            print('\rJob [%s - %s] is %-9s... %s'
                  % (jobName, jobId, status, spin[spinner % len(spin)]), end='')
            sys.stdout.flush()
            spinner += 1


if __name__ == '__main__':
    main()