From b6663ae731b3757715a7cbbe3973420f45b838c7 Mon Sep 17 00:00:00 2001
From: JustAnotherArchivist
Date: Tue, 11 Jan 2022 04:22:09 +0000
Subject: [PATCH] Add concurrency

---
 ia-upload-stream | 102 ++++++++++++++++++++++++++++++-----------------
 1 file changed, 65 insertions(+), 37 deletions(-)

diff --git a/ia-upload-stream b/ia-upload-stream
index 79a76a6..61d7f68 100755
--- a/ia-upload-stream
+++ b/ia-upload-stream
@@ -1,8 +1,10 @@
 #!/usr/bin/env python3
 # Only external dependency: requests
 import argparse
+import asyncio
 import base64
 import collections
+import concurrent.futures
 import configparser
 import contextlib
 import functools
@@ -151,7 +153,42 @@ def maybe_file_progress_bar(progress, f, *args, **kwargs):
 		yield f
 
 
-def upload(item, filename, metadata, *, iaConfigFile = None, partSize = 100*1024*1024, tries = 3, queueDerive = True, keepOldVersion = True, complete = True, uploadId = None, parts = None, progress = True):
+def upload_one(url, uploadId, partNumber, data, contentMd5, size, headers, progress, tries):
+	for attempt in range(1, tries + 1):
+		if attempt > 1:
+			logger.info(f'Retrying part {partNumber}')
+		try:
+			with maybe_file_progress_bar(progress, data, 'read', f'uploading {partNumber}', size = size) as w:
+				r = requests.put(f'{url}?partNumber={partNumber}&uploadId={uploadId}', headers = {**headers, 'Content-MD5': contentMd5}, data = w)
+		except (ConnectionError, requests.exceptions.RequestException) as e:
+			err = f'error {type(e).__module__}.{type(e).__name__} {e!s}'
+		else:
+			if r.status_code == 200:
+				break
+			err = f'status {r.status_code}'
+		sleepTime = min(3 ** attempt, 30)
+		retrying = f', retrying after {sleepTime} seconds' if attempt < tries else ''
+		logger.error(f'Got {err} from IA S3 on uploading part {partNumber}{retrying}')
+		if attempt == tries:
+			raise UploadError(f'Got {err} from IA S3 on uploading part {partNumber}', r = r, uploadId = uploadId) # parts is added in wait_first
+		time.sleep(sleepTime)
+		data.seek(0)
+	return partNumber, r.headers['ETag']
+
+
+async def wait_first(tasks, parts):
+	task = tasks.popleft()
+	try:
+		partNumber, eTag = await task
+	except UploadError as e:
+		# The upload task can't add an accurate parts list, so add that here and reraise
+		e.parts = parts
+		raise
+	parts.append((partNumber, eTag))
+	logger.info(f'Upload of part {partNumber} OK, ETag: {eTag}')
+
+
+async def upload(item, filename, metadata, *, iaConfigFile = None, partSize = 100*1024*1024, tries = 3, concurrency = 1, queueDerive = True, keepOldVersion = True, complete = True, uploadId = None, parts = None, progress = True):
 	f = sys.stdin.buffer
 
 	# Read `ia` config
@@ -177,41 +214,30 @@ def upload(item, filename, metadata, *, iaConfigFile = None, partSize = 100*1024
 	# Upload the data in parts
 	if parts is None:
 		parts = []
-	for partNumber in itertools.count(start = len(parts) + 1):
-		data = io.BytesIO()
-		with maybe_file_progress_bar(progress, data, 'write', 'reading input') as w:
-			readinto_size_limit(f, w, partSize)
-		data.seek(0)
-		size = len(data.getbuffer())
-		if not size:
-			# We're done!
-			break
-		logger.info(f'Uploading part {partNumber} ({size} bytes)')
-		logger.info('Calculating MD5')
-		h = hashlib.md5(data.getbuffer())
-		logger.info(f'MD5: {h.hexdigest()}')
-		contentMd5 = base64.b64encode(h.digest()).decode('ascii')
-		for attempt in range(1, tries + 1):
-			if attempt > 1:
-				logger.info(f'Retrying part {partNumber}')
-			try:
-				with maybe_file_progress_bar(progress, data, 'read', 'uploading', size = size) as w:
-					r = requests.put(f'{url}?partNumber={partNumber}&uploadId={uploadId}', headers = {**headers, 'Content-MD5': contentMd5}, data = w)
-			except (ConnectionError, requests.exceptions.RequestException) as e:
-				err = f'error {type(e).__module__}.{type(e).__name__} {e!s}'
-			else:
-				if r.status_code == 200:
-					break
-				err = f'status {r.status_code}'
-			sleepTime = min(3 ** attempt, 30)
-			retrying = f', retrying after {sleepTime} seconds' if attempt < tries else ''
-			logger.error(f'Got {err} from IA S3 on uploading part {partNumber}{retrying}')
-			if attempt == tries:
-				raise UploadError(f'Got {err} from IA S3 on uploading part {partNumber}', r = r, uploadId = uploadId, parts = parts)
-			time.sleep(sleepTime)
+	tasks = collections.deque()
+	loop = asyncio.get_event_loop()
+	with concurrent.futures.ThreadPoolExecutor(max_workers = concurrency) as executor:
+		for partNumber in itertools.count(start = len(parts) + 1):
+			while len(tasks) >= concurrency:
+				await wait_first(tasks, parts)
+			data = io.BytesIO()
+			with maybe_file_progress_bar(progress, data, 'write', 'reading input') as w:
+				readinto_size_limit(f, w, partSize)
 			data.seek(0)
-		logger.info(f'Upload OK, ETag: {r.headers["ETag"]}')
-		parts.append((partNumber, r.headers['ETag']))
+			size = len(data.getbuffer())
+			if not size:
+				# We're done!
+				break
+			logger.info(f'Uploading part {partNumber} ({size} bytes)')
+			logger.info('Calculating MD5')
+			h = hashlib.md5(data.getbuffer())
+			logger.info(f'MD5: {h.hexdigest()}')
+			contentMd5 = base64.b64encode(h.digest()).decode('ascii')
+
+			task = loop.run_in_executor(executor, upload_one, url, uploadId, partNumber, data, contentMd5, size, headers, progress, tries)
+			tasks.append(task)
+		while tasks:
+			await wait_first(tasks, parts)
 
 	# If --no-complete is used, raise the special error to be caught in main for pretty printing.
 	if not complete:
@@ -293,6 +319,7 @@ def main():
 	parser.add_argument('--clobber', dest = 'keepOldVersion', action = 'store_false', help = 'enable clobbering existing files')
 	parser.add_argument('--ia-config-file', dest = 'iaConfigFile', metavar = 'FILE', help = 'path to the ia CLI config file (default: search the same paths as ia)')
 	parser.add_argument('--tries', type = int, default = 3, metavar = 'N', help = 'retry on S3 errors (default: 3)')
+	parser.add_argument('--concurrency', '--concurrent', type = int, default = 1, metavar = 'N', help = 'upload N parts in parallel (default: 1)')
 	parser.add_argument('--no-complete', dest = 'complete', action = 'store_false', help = 'disable completing the upload when stdin is exhausted')
 	parser.add_argument('--no-progress', dest = 'progress', action = 'store_false', help = 'disable progress bar')
 	parser.add_argument('--upload-id', dest = 'uploadId', help = 'upload ID when resuming or aborting an upload')
@@ -310,20 +337,21 @@ def main():
 	logging.basicConfig(level = logging.INFO, format = '{asctime}.{msecs:03.0f} {levelname} {name} {message}', datefmt = '%Y-%m-%d %H:%M:%S', style = '{')
 	try:
 		if not args.abort:
-			upload(
+			asyncio.run(upload(
 				args.item,
 				args.filename,
 				args.metadata,
 				iaConfigFile = args.iaConfigFile,
 				partSize = args.partSize,
 				tries = args.tries,
+				concurrency = args.concurrency,
 				queueDerive = args.queueDerive,
 				keepOldVersion = args.keepOldVersion,
 				complete = args.complete,
 				uploadId = args.uploadId,
 				parts = args.parts,
 				progress = args.progress,
-			)
+			))
 		else:
 			abort(
 				args.item,
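
For illustration, here is the bounded-concurrency pattern the patch introduces, in a standalone runnable form: a deque holds at most `concurrency` executor futures, and the oldest one is always awaited first. `do_part` is a hypothetical stand-in for upload_one and its blocking requests.put; the deque/drain structure mirrors the new upload loop above.

#!/usr/bin/env python3
# Standalone sketch of the patch's concurrency pattern (illustrative only).
import asyncio
import collections
import concurrent.futures
import time

def do_part(partNumber):
	time.sleep(0.1) # hypothetical stand-in for the blocking per-part PUT
	return partNumber, f'etag-{partNumber}'

async def wait_first(tasks, parts):
	# Await the oldest in-flight part so `parts` stays in part-number order.
	parts.append(await tasks.popleft())

async def main(partCount = 5, concurrency = 2):
	parts = []
	tasks = collections.deque()
	loop = asyncio.get_event_loop()
	with concurrent.futures.ThreadPoolExecutor(max_workers = concurrency) as executor:
		for partNumber in range(1, partCount + 1):
			while len(tasks) >= concurrency: # at most `concurrency` parts in flight
				await wait_first(tasks, parts)
			tasks.append(loop.run_in_executor(executor, do_part, partNumber))
		while tasks: # drain the remaining in-flight parts
			await wait_first(tasks, parts)
	print(parts)

asyncio.run(main())

Awaiting the oldest future first bounds memory to `concurrency` buffered parts at a time and appends to `parts` in part-number order, matching what the sequential version produced.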