Skip to content

Commit

Permalink
[CI] AWS batch job tool for GluonNLP (Part I) (dmlc#1251)
Browse files Browse the repository at this point in the history
* AWS batch job tool for GluonNLP

* limit range

Co-authored-by: Xingjian Shi <[email protected]>
  • Loading branch information
szha and sxjscience authored Jul 7, 2020
1 parent e06ff01 commit 689eba9
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tools/batch/docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

# System toolchain plus the image/font libraries needed to build GluonNLP
# and its dependencies from source; apt lists are removed to keep the layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    locales \
    cmake \
    git \
    curl \
    vim \
    unzip \
    sudo \
    ca-certificates \
    libjpeg-dev \
    libpng-dev \
    libfreetype6-dev \
    libxft-dev && \
    rm -rf /var/lib/apt/lists/*

# Install Miniconda into /opt/conda.
# NOTE: the original used both `-o ~/miniconda.sh` and `-O`, which gives curl
# two output options for one URL (curl warns and ignores the extra one).
# `-o` alone names the download; `-fSL` fails on HTTP errors, shows errors,
# and follows redirects.
RUN curl -fSL -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
    chmod +x ~/miniconda.sh && \
    ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH

# Clone the repo; the batch entry script fetches the actual ref to test at runtime.
RUN git clone https://github.com/dmlc/gluon-nlp
WORKDIR gluon-nlp
ADD gluon_nlp_job.sh .
33 changes: 33 additions & 0 deletions tools/batch/docker/gluon_nlp_job.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/bin/bash
# Entry point executed inside the AWS Batch container.
# Positional args (supplied via the job definition's parameter substitution):
#   1: SOURCE_REF   git ref to test (e.g. master, refs/pull/500/head)
#   2: CONDA_ENV    conda environment preset (currently unused here)
#   3: WORK_DIR     directory inside the repo to run COMMAND from
#   4: COMMAND      shell command to execute
#   5: SAVED_OUTPUT file or directory (relative to WORK_DIR) to upload to S3
#   6: SAVE_PATH    destination path under s3://gluon-nlp-staging/
#   7: REMOTE       optional alternative git remote URL
date
echo "Args: $@"
env
echo "jobId: $AWS_BATCH_JOB_ID"
echo "jobQueue: $AWS_BATCH_JQ_NAME"
echo "computeEnvironment: $AWS_BATCH_CE_NAME"

SOURCE_REF=$1
CONDA_ENV=$2
WORK_DIR=$3
COMMAND=$4
SAVED_OUTPUT=$5
SAVE_PATH=$6
REMOTE=$7

# Quote expansions throughout: unquoted $REMOTE/$WORK_DIR/$SAVED_OUTPUT would
# word-split on spaces. [ -n "..." ] replaces the fragile `[ ! -z $REMOTE ]`.
if [ -n "$REMOTE" ]; then
    git remote set-url origin "$REMOTE"
fi;

git fetch origin "$SOURCE_REF":working
git checkout working
pip install -v -e .[extras]

# Abort instead of running COMMAND from the wrong directory if cd fails.
cd "$WORK_DIR" || exit 1
/bin/bash -o pipefail -c "$COMMAND"
COMMAND_EXIT_CODE=$?
# Upload the requested artifact (single file or whole directory) to S3.
if [[ -f $SAVED_OUTPUT ]]; then
    aws s3 cp "$SAVED_OUTPUT" "s3://gluon-nlp-staging/$SAVE_PATH";
elif [[ -d $SAVED_OUTPUT ]]; then
    aws s3 cp --recursive "$SAVED_OUTPUT" "s3://gluon-nlp-staging/$SAVE_PATH";
fi;
# Propagate the test command's exit status as the job status.
exit $COMMAND_EXIT_CODE
154 changes: 154 additions & 0 deletions tools/batch/submit-job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import argparse
import random
import re
import sys
import time
from datetime import datetime

import boto3
from botocore.compat import total_seconds

# Command-line interface for submitting a GluonNLP CI job to AWS Batch.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--profile', help='profile name of aws account.', type=str,
                    default=None)
parser.add_argument('--region', help='Default region when creating new connections', type=str,
                    default=None)
parser.add_argument('--name', help='name of the job', type=str, default='dummy')
parser.add_argument('--job-queue', help='name of the job queue to submit this job', type=str,
                    default='gluon-nlp-jobs')
# NOTE: fixed help-string typo ("job job definition" -> "job definition").
parser.add_argument('--job-definition', help='name of the job definition', type=str,
                    default='gluon-nlp-jobs:8')
parser.add_argument('--source-ref',
                    help='ref in GluonNLP main github. e.g. master, refs/pull/500/head',
                    type=str, default='master')
parser.add_argument('--work-dir',
                    help='working directory inside the repo. e.g. scripts/sentiment_analysis',
                    type=str, default='scripts/bert')
parser.add_argument('--saved-output',
                    help='output to be saved, relative to working directory. '
                         'it can be either a single file or a directory',
                    type=str, default='.')
# Default save path is timestamped so concurrent jobs do not clobber each other.
parser.add_argument('--save-path',
                    help='s3 path where files are saved.',
                    type=str, default='batch/temp/{}'.format(datetime.now().isoformat()))
parser.add_argument('--conda-env',
                    help='conda environment preset to use.',
                    type=str, default='gpu/py3')
parser.add_argument('--command', help='command to run', type=str,
                    default='git rev-parse HEAD | tee stdout.log')
parser.add_argument('--remote',
                    help='git repo address. https://github.com/dmlc/gluon-nlp',
                    type=str, default="https://github.com/dmlc/gluon-nlp")
parser.add_argument('--wait', help='block wait until the job completes. '
                    'Non-zero exit code if job fails.', action='store_true')
parser.add_argument('--timeout', help='job timeout in seconds', default=None, type=int)

args = parser.parse_args()

# One boto3 session shared by the Batch client (job control) and the
# CloudWatch Logs client (log tailing).
session = boto3.Session(profile_name=args.profile, region_name=args.region)
batch, cloudwatch = [session.client(service_name=sn) for sn in ['batch', 'logs']]

def printLogs(logGroupName, logStreamName, startTime):
    """Print every event in a CloudWatch log stream from *startTime* on.

    Pages through the stream with ``get_log_events`` and returns the
    timestamp (ms since epoch) of the newest event printed, or 0 if the
    stream had no events.
    """
    request = {'logGroupName': logGroupName,
               'logStreamName': logStreamName,
               'startTime': startTime,
               'startFromHead': True}

    newest = 0
    done = False
    while not done:
        page = cloudwatch.get_log_events(**request)

        for record in page['events']:
            newest = record['timestamp']
            stamp = datetime.utcfromtimestamp(newest / 1000.0).isoformat()
            print('[{}] {}'.format((stamp + '.000')[:23] + 'Z', record['message']))

        # The API signals the last page by repeating the same forward token.
        token = page['nextForwardToken']
        if not token or request.get('nextToken') == token:
            done = True
        else:
            request['nextToken'] = token
    return newest


def getLogStream(logGroupName, jobName, jobId):
    """Return the name of the first log stream for *jobName*/*jobId*.

    AWS Batch names streams ``<jobName>/<jobId>/...``; an empty string is
    returned when no stream exists yet (job not started).
    """
    streams = cloudwatch.describe_log_streams(
        logGroupName=logGroupName,
        logStreamNamePrefix=jobName + '/' + jobId
    )['logStreams']
    return streams[0]['logStreamName'] if streams else ''

def nowInMillis():
    """Return the current UTC time in whole milliseconds since the Unix epoch.

    Bug fix: the original called the Python 2 builtin ``long``, which raises
    ``NameError`` on Python 3; ``int`` is used instead. The stdlib
    ``timedelta.total_seconds()`` also replaces ``botocore.compat.total_seconds``.
    """
    elapsed = datetime.utcnow() - datetime(1970, 1, 1)
    return int(elapsed.total_seconds()) * 1000


def main():
    """Submit an AWS Batch job described by the CLI args; optionally wait.

    With ``--wait``, polls the job status every 5-10 seconds, tails the
    job's CloudWatch log stream once it is RUNNING, and terminates the
    process with exit code 1 if the job FAILED (0 on success).
    """
    spin = ['-', '/', '|', '\\', '-', '/', '|', '\\']
    logGroupName = '/aws/batch/job'  # log group AWS Batch writes to by default

    # AWS Batch job names allow only [A-Za-z0-9_-], max 128 characters.
    # Raw string fixes the invalid '\-' escape in the original pattern.
    jobName = re.sub(r'[^A-Za-z0-9_\-]', '', args.name)[:128]
    jobQueue = args.job_queue
    jobDefinition = args.job_definition
    wait = args.wait

    # Substitutions consumed by the job definition / gluon_nlp_job.sh.
    parameters = {
        'SOURCE_REF': args.source_ref,
        'WORK_DIR': args.work_dir,
        'SAVED_OUTPUT': args.saved_output,
        'SAVE_PATH': args.save_path,
        'CONDA_ENV': args.conda_env,
        'COMMAND': args.command,
        'REMOTE': args.remote
    }
    kwargs = dict(
        jobName=jobName,
        jobQueue=jobQueue,
        jobDefinition=jobDefinition,
        parameters=parameters,
    )
    if args.timeout is not None:
        kwargs['timeout'] = {'attemptDurationSeconds': args.timeout}
    submitJobResponse = batch.submit_job(**kwargs)

    jobId = submitJobResponse['jobId']
    print('Submitted job [{} - {}] to the job queue [{}]'.format(jobName, jobId, jobQueue))

    spinner = 0
    running = False
    status_set = set()
    startTime = 0  # ms timestamp just past the newest log event already printed

    while wait:
        # Randomized interval spreads DescribeJobs calls across concurrent clients.
        time.sleep(random.randint(5, 10))
        describeJobsResponse = batch.describe_jobs(jobs=[jobId])
        status = describeJobsResponse['jobs'][0]['status']
        if status == 'SUCCEEDED' or status == 'FAILED':
            print('=' * 80)
            print('Job [{} - {}] {}'.format(jobName, jobId, status))
            # bool exit status: True (1) on failure, False (0) on success.
            sys.exit(status == 'FAILED')

        elif status == 'RUNNING':
            logStreamName = getLogStream(logGroupName, jobName, jobId)
            if not running:
                running = True
                print('\rJob [{} - {}] is RUNNING.'.format(jobName, jobId))
                if logStreamName:
                    print('Output [{}]:\n {}'.format(logStreamName, '=' * 80))
            if logStreamName:
                # Resume tailing one ms after the last event seen.
                startTime = printLogs(logGroupName, logStreamName, startTime) + 1
        elif status not in status_set:
            # Announce each intermediate status (SUBMITTED, PENDING, ...) once.
            status_set.add(status)
            print('\rJob [%s - %s] is %-9s... %s' % (jobName, jobId, status, spin[spinner % len(spin)]))
            sys.stdout.flush()
            spinner += 1


if __name__ == '__main__':
    main()
93 changes: 93 additions & 0 deletions tools/batch/wait-job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse
from datetime import datetime
import sys
import time

import boto3
from botocore.compat import total_seconds

# Command-line interface: wait on an already-submitted AWS Batch job by id.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--profile', help='profile name of aws account.', type=str,
                    default=None)
parser.add_argument('--job-id', help='job id to check status and wait.', type=str,
                    default=None)

args = parser.parse_args()

# One boto3 session shared by the Batch client (status polling) and the
# CloudWatch Logs client (log tailing).
# NOTE(review): unlike submit-job.py this takes no --region flag, so the
# session uses the profile/environment default region — confirm intended.
session = boto3.Session(profile_name=args.profile)
batch, cloudwatch = [session.client(service_name=sn) for sn in ['batch', 'logs']]

def printLogs(logGroupName, logStreamName, startTime):
    """Dump a CloudWatch log stream to stdout starting at *startTime*.

    Follows forward pagination tokens until the stream is exhausted and
    returns the millisecond timestamp of the last event printed (0 when
    there were no events).
    """
    params = dict(logGroupName=logGroupName,
                  logStreamName=logStreamName,
                  startTime=startTime,
                  startFromHead=True)

    last_seen = 0
    while True:
        chunk = cloudwatch.get_log_events(**params)

        for item in chunk['events']:
            last_seen = item['timestamp']
            iso = datetime.utcfromtimestamp(last_seen / 1000.0).isoformat()
            print('[{}] {}'.format((iso + '.000')[:23] + 'Z', item['message']))

        # A repeated forward token means there are no further pages.
        forward = chunk['nextForwardToken']
        if not (forward and params.get('nextToken') != forward):
            break
        params['nextToken'] = forward
    return last_seen


def getLogStream(logGroupName, jobName, jobId):
    """Find the CloudWatch log stream for an AWS Batch job.

    Streams are prefixed ``<jobName>/<jobId>``; returns the first match,
    or '' when the job has not produced a stream yet.
    """
    reply = cloudwatch.describe_log_streams(
        logGroupName=logGroupName,
        logStreamNamePrefix=jobName + '/' + jobId
    )
    matches = reply['logStreams']
    if not matches:
        return ''
    return matches[0]['logStreamName']

def nowInMillis():
    """Current UTC time as whole milliseconds since the Unix epoch.

    Bug fix: ``long`` is a Python 2 builtin and raises ``NameError`` on
    Python 3; ``int`` replaces it. Stdlib ``timedelta.total_seconds()``
    replaces ``botocore.compat.total_seconds``.
    """
    since_epoch = datetime.utcnow() - datetime(1970, 1, 1)
    return int(since_epoch.total_seconds()) * 1000


def main():
    """Poll an existing AWS Batch job (``--job-id``) until it finishes,
    tailing its CloudWatch log output while it runs."""
    spin = ['-', '/', '|', '\\', '-', '/', '|', '\\']
    logGroupName = '/aws/batch/job'  # log group AWS Batch writes to by default

    jobId = args.job_id

    spinner = 0     # index into the spin animation
    running = False  # set once the RUNNING banner has been printed
    startTime = 0   # ms offset just past the newest log event already printed

    while True:
        time.sleep(1)
        describeJobsResponse = batch.describe_jobs(jobs=[jobId])
        job = describeJobsResponse['jobs'][0]
        status, jobName = job['status'], job['jobName']
        # Terminal states: report and stop polling.
        # NOTE(review): unlike submit-job.py, this exits 0 even when the job
        # FAILED — confirm callers do not rely on the exit code.
        if status == 'SUCCEEDED' or status == 'FAILED':
            print('=' * 80)
            print('Job [{} - {}] {}'.format(jobName, jobId, status))
            break
        elif status == 'RUNNING':
            logStreamName = getLogStream(logGroupName, jobName, jobId)
            if not running and logStreamName:
                running = True
                print('\rJob [{} - {}] is RUNNING.'.format(jobName, jobId))
                print('Output [{}]:\n {}'.format(logStreamName, '=' * 80))
            if logStreamName:
                # Resume tailing one ms after the last event seen.
                startTime = printLogs(logGroupName, logStreamName, startTime) + 1
        else:
            # Intermediate states (SUBMITTED, PENDING, ...): animate a spinner.
            print('\rJob [%s - %s] is %-9s... %s' % (jobName, jobId, status, spin[spinner % len(spin)]),)
            sys.stdout.flush()
            spinner += 1

if __name__ == '__main__':
    main()

0 comments on commit 689eba9

Please sign in to comment.