diff --git a/download_from_google_storage.py b/download_from_google_storage.py index a45f387ce..d3f743691 100755 --- a/download_from_google_storage.py +++ b/download_from_google_storage.py @@ -11,8 +11,10 @@ import optparse import os import Queue import re +import shutil import stat import sys +import tarfile import threading import time @@ -49,7 +51,6 @@ def GetNormalizedPlatform(): return 'win32' return sys.platform - # Common utilities class Gsutil(object): """Call gsutil with some predefined settings. This is a convenience object, @@ -186,8 +187,19 @@ def enumerate_work_queue(input_filename, work_queue, directory, return work_queue_size +def _validate_tar_file(tar, prefix): + def _validate(tarinfo): + """Returns false if the tarinfo is something we explicitly forbid.""" + if tarinfo.issym() or tarinfo.islnk(): + return False + if '..' in tarinfo.name or not tarinfo.name.startswith(prefix): + return False + return True + return all(map(_validate, tar.getmembers())) + def _downloader_worker_thread(thread_num, q, force, base_url, - gsutil, out_q, ret_codes, verbose): + gsutil, out_q, ret_codes, verbose, extract, + delete=True): while True: input_sha1_sum, output_filename = q.get() if input_sha1_sum is None: @@ -218,7 +230,8 @@ def _downloader_worker_thread(thread_num, q, force, base_url, # Fetch the file. out_q.put('%d> Downloading %s...' % (thread_num, output_filename)) try: - os.remove(output_filename) # Delete the file if it exists already. + if delete: + os.remove(output_filename) # Delete the file if it exists already. except OSError: if os.path.exists(output_filename): out_q.put('%d> Warning: deleting %s failed.' % ( @@ -228,6 +241,34 @@ def _downloader_worker_thread(thread_num, q, force, base_url, out_q.put('%d> %s' % (thread_num, err)) ret_codes.put((code, err)) + if extract: + if (not tarfile.is_tarfile(output_filename) + or not output_filename.endswith('.tar.gz')): + out_q.put('%d> Error: %s is not a tar.gz archive.' % ( + thread_num, output_filename)) + ret_codes.put((1, '%s is not a tar.gz archive.' % (output_filename))) + continue + with tarfile.open(output_filename, 'r:gz') as tar: + dirname = os.path.dirname(os.path.abspath(output_filename)) + extract_dir = output_filename[0:len(output_filename)-7] + if not _validate_tar_file(tar, os.path.basename(extract_dir)): + out_q.put('%d> Error: %s contains files outside %s.' % ( + thread_num, output_filename, extract_dir)) + ret_codes.put((1, '%s contains invalid entries.' % (output_filename))) + continue + if os.path.exists(extract_dir): + try: + shutil.rmtree(extract_dir) + out_q.put('%d> Removed %s...' % (thread_num, extract_dir)) + except OSError: + out_q.put('%d> Warning: Can\'t delete: %s' % ( + thread_num, extract_dir)) + ret_codes.put((1, 'Can\'t delete %s.' % (extract_dir))) + continue + out_q.put('%d> Extracting %d entries from %s to %s' % + (thread_num, len(tar.getmembers()),output_filename, + extract_dir)) + tar.extractall(path=dirname) # Set executable bit. if sys.platform == 'cygwin': # Under cygwin, mark all files as executable. The executable flag in @@ -258,7 +299,7 @@ def printer_worker(output_queue): def download_from_google_storage( input_filename, base_url, gsutil, num_threads, directory, recursive, - force, output, ignore_errors, sha1_file, verbose, auto_platform): + force, output, ignore_errors, sha1_file, verbose, auto_platform, extract): # Start up all the worker threads. all_threads = [] download_start = time.time() @@ -270,7 +311,7 @@ def download_from_google_storage( t = threading.Thread( target=_downloader_worker_thread, args=[thread_num, work_queue, force, base_url, - gsutil, stdout_queue, ret_codes, verbose]) + gsutil, stdout_queue, ret_codes, verbose, extract]) t.daemon = True t.start() all_threads.append(t) @@ -358,6 +399,13 @@ def main(args): '(linux|mac|win). If so, the script will only ' 'process files that are in the paths that ' 'that matches the current platform.') + parser.add_option('-u', '--extract', + action='store_true', + help='Extract a downloaded tar.gz file. ' + 'Leaves the tar.gz file around for sha1 verification' + 'If a directory with the same name as the tar.gz ' + 'file already exists, is deleted (to get a ' + 'clean state in case of update.)') parser.add_option('-v', '--verbose', action='store_true', help='Output extra diagnostic and progress information.') @@ -451,7 +499,8 @@ def main(args): return download_from_google_storage( input_filename, base_url, gsutil, options.num_threads, options.directory, options.recursive, options.force, options.output, options.ignore_errors, - options.sha1_file, options.verbose, options.auto_platform) + options.sha1_file, options.verbose, options.auto_platform, + options.extract) if __name__ == '__main__': diff --git a/tests/download_from_google_storage_unittests.py b/tests/download_from_google_storage_unittests.py index a8af63b0c..0420ad942 100755 --- a/tests/download_from_google_storage_unittests.py +++ b/tests/download_from_google_storage_unittests.py @@ -11,6 +11,7 @@ import os import Queue import shutil import sys +import tarfile import tempfile import threading import unittest @@ -59,6 +60,21 @@ class GsutilMock(object): return (0, '', '') +class ChangedWorkingDirectory(object): + def __init__(self, working_directory): + self._old_cwd = '' + self._working_directory = working_directory + + def __enter__(self): + self._old_cwd = os.getcwd() + print "Enter directory = ", self._working_directory + os.chdir(self._working_directory) + + def __exit__(self, *_): + print "Enter directory = ", self._old_cwd + os.chdir(self._old_cwd) + + class GstoolsUnitTests(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.mkdtemp(prefix='gstools_test') @@ -68,6 +84,49 @@ class GstoolsUnitTests(unittest.TestCase): def cleanUp(self): shutil.rmtree(self.temp_dir) + def test_validate_tar_file(self): + lorem_ipsum = os.path.join(self.base_path, 'lorem_ipsum.txt') + with ChangedWorkingDirectory(self.base_path): + # Sanity ok check. + tar_dir = 'ok_dir' + os.makedirs(os.path.join(self.base_path, tar_dir)) + tar = 'good.tar.gz' + lorem_ipsum_copy = os.path.join(tar_dir, 'lorem_ipsum.txt') + shutil.copyfile(lorem_ipsum, lorem_ipsum_copy) + with tarfile.open(tar, 'w:gz') as tar: + tar.add(lorem_ipsum_copy) + self.assertTrue( + download_from_google_storage._validate_tar_file(tar, tar_dir)) + + # Test no links. + tar_dir_link = 'for_tar_link' + os.makedirs(tar_dir_link) + link = os.path.join(tar_dir_link, 'link') + os.symlink(lorem_ipsum, link) + tar_with_links = 'with_links.tar.gz' + with tarfile.open(tar_with_links, 'w:gz') as tar: + tar.add(link) + self.assertFalse( + download_from_google_storage._validate_tar_file(tar, tar_dir_link)) + + # Test not outside. + tar_dir_outside = 'outside_tar' + os.makedirs(tar_dir_outside) + tar_with_outside = 'with_outside.tar.gz' + with tarfile.open(tar_with_outside, 'w:gz') as tar: + tar.add(lorem_ipsum) + self.assertFalse( + download_from_google_storage._validate_tar_file(tar, + tar_dir_outside)) + # Test no .. + tar_with_dotdot = 'with_dotdot.tar.gz' + dotdot_file = os.path.join(tar_dir, '..', tar_dir, 'lorem_ipsum.txt') + with tarfile.open(tar_with_dotdot, 'w:gz') as tar: + tar.add(dotdot_file) + self.assertFalse( + download_from_google_storage._validate_tar_file(tar, + tar_dir)) + def test_gsutil(self): gsutil = download_from_google_storage.Gsutil(GSUTIL_DEFAULT_PATH, None) self.assertEqual(gsutil.path, GSUTIL_DEFAULT_PATH) @@ -164,7 +223,7 @@ class DownloadTests(unittest.TestCase): stdout_queue = Queue.Queue() download_from_google_storage._downloader_worker_thread( 0, self.queue, False, self.base_url, self.gsutil, - stdout_queue, self.ret_codes, True) + stdout_queue, self.ret_codes, True, False) expected_calls = [ ('check_call', ('ls', input_filename)), @@ -190,13 +249,53 @@ class DownloadTests(unittest.TestCase): stdout_queue = Queue.Queue() download_from_google_storage._downloader_worker_thread( 0, self.queue, False, self.base_url, self.gsutil, - stdout_queue, self.ret_codes, True) + stdout_queue, self.ret_codes, True, False) expected_output = [ '0> File %s exists and SHA1 matches. Skipping.' % output_filename ] self.assertEqual(list(stdout_queue.queue), expected_output) self.assertEqual(self.gsutil.history, []) + def test_download_extract_archive(self): + # By design we make this not match + sha1_hash = '61223e1ad3d86901a57629fee38313db5ec106ff' + input_filename = '%s/%s' % (self.base_url, sha1_hash) + # Generate a gzipped tarfile + output_filename = os.path.join(self.base_path, 'subfolder.tar.gz') + output_dirname = os.path.join(self.base_path, 'subfolder') + extracted_filename = os.path.join(output_dirname, 'subfolder_text.txt') + with tarfile.open(output_filename, 'w:gz') as tar: + tar.add(output_dirname, arcname='subfolder') + shutil.rmtree(output_dirname) + print(output_dirname) + self.queue.put((sha1_hash, output_filename)) + self.queue.put((None, None)) + stdout_queue = Queue.Queue() + download_from_google_storage._downloader_worker_thread( + 0, self.queue, False, self.base_url, self.gsutil, + stdout_queue, self.ret_codes, True, True, delete=False) + expected_calls = [ + ('check_call', + ('ls', input_filename)), + ('check_call', + ('cp', input_filename, output_filename))] + if sys.platform != 'win32': + expected_calls.append( + ('check_call', + ('stat', + 'gs://sometesturl/61223e1ad3d86901a57629fee38313db5ec106ff'))) + expected_output = [ + '0> Downloading %s...' % output_filename] + expected_output.extend([ + '0> Extracting 3 entries from %s to %s' % (output_filename, + output_dirname)]) + expected_ret_codes = [] + self.assertEqual(list(stdout_queue.queue), expected_output) + self.assertEqual(self.gsutil.history, expected_calls) + self.assertEqual(list(self.ret_codes.queue), expected_ret_codes) + self.assertTrue(os.path.exists(output_dirname)) + self.assertTrue(os.path.exists(extracted_filename)) + def test_download_worker_skips_not_found_file(self): sha1_hash = '7871c8e24da15bad8b0be2c36edc9dc77e37727f' input_filename = '%s/%s' % (self.base_url, sha1_hash) @@ -207,7 +306,7 @@ class DownloadTests(unittest.TestCase): self.gsutil.add_expected(1, '', '') # Return error when 'ls' is called. download_from_google_storage._downloader_worker_thread( 0, self.queue, False, self.base_url, self.gsutil, - stdout_queue, self.ret_codes, True) + stdout_queue, self.ret_codes, True, False) expected_output = [ '0> Failed to fetch file %s for %s, skipping. [Err: ]' % ( input_filename, output_filename), @@ -242,7 +341,8 @@ class DownloadTests(unittest.TestCase): ignore_errors=False, sha1_file=False, verbose=True, - auto_platform=False) + auto_platform=False, + extract=False) expected_calls = [ ('check_call', ('ls', input_filename)), @@ -273,7 +373,8 @@ class DownloadTests(unittest.TestCase): ignore_errors=False, sha1_file=False, verbose=True, - auto_platform=False) + auto_platform=False, + extract=False) expected_calls = [ ('check_call', ('ls', input_filename)), diff --git a/tests/upload_to_google_storage_unittests.py b/tests/upload_to_google_storage_unittests.py index 3bac038d7..9a13e6a64 100755 --- a/tests/upload_to_google_storage_unittests.py +++ b/tests/upload_to_google_storage_unittests.py @@ -11,6 +11,7 @@ import Queue import shutil import StringIO import sys +import tarfile import tempfile import threading import unittest @@ -19,6 +20,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import upload_to_google_storage from download_from_google_storage_unittests import GsutilMock +from download_from_google_storage_unittests import ChangedWorkingDirectory # ../third_party/gsutil/gsutil GSUTIL_DEFAULT_PATH = os.path.join( @@ -63,6 +65,30 @@ class UploadTests(unittest.TestCase): os.remove(output_filename) self.assertEqual(code, 0) + def test_create_archive(self): + work_dir = os.path.join(self.base_path, 'download_test_data') + with ChangedWorkingDirectory(work_dir): + dirname = 'subfolder' + dirs = [dirname] + tar_gz_file = '%s.tar.gz' % dirname + self.assertTrue(upload_to_google_storage.validate_archive_dirs(dirs)) + upload_to_google_storage.create_archives(dirs) + self.assertTrue(os.path.exists(tar_gz_file)) + with tarfile.open(tar_gz_file, 'r:gz') as tar: + content = map(lambda x: x.name, tar.getmembers()) + self.assertTrue(dirname in content) + self.assertTrue(os.path.join(dirname, 'subfolder_text.txt') in content) + self.assertTrue( + os.path.join(dirname, 'subfolder_text.txt.sha1') in content) + + def test_validate_archive_dirs_fails(self): + work_dir = os.path.join(self.base_path, 'download_test_data') + with ChangedWorkingDirectory(work_dir): + symlink = 'link' + os.symlink(os.path.join(self.base_path, 'subfolder'), symlink) + self.assertFalse(upload_to_google_storage.validate_archive_dirs([symlink])) + self.assertFalse(upload_to_google_storage.validate_archive_dirs(['foobar'])) + def test_upload_single_file_remote_exists(self): filenames = [self.lorem_ipsum] output_filename = '%s.sha1' % self.lorem_ipsum diff --git a/upload_to_google_storage.py b/upload_to_google_storage.py index 4cf9d1a6e..26bc6b148 100755 --- a/upload_to_google_storage.py +++ b/upload_to_google_storage.py @@ -12,6 +12,7 @@ import Queue import re import stat import sys +import tarfile import threading import time @@ -207,11 +208,38 @@ def upload_to_google_storage( return max_ret_code +def create_archives(dirs): + archive_names = [] + for name in dirs: + tarname = '%s.tar.gz' % name + with tarfile.open(tarname, 'w:gz') as tar: + tar.add(name) + archive_names.append(tarname) + return archive_names + + +def validate_archive_dirs(dirs): + # We don't allow .. in paths in our archives. + if any(map(lambda x: '..' in x, dirs)): + return False + # We only allow dirs. + if any(map(lambda x: not os.path.isdir(x), dirs)): + return False + # We don't allow sym links in our archives. + if any(map(os.path.islink, dirs)): + return False + # We required that the subdirectories we are archiving are all just below + # cwd. + return not any(map(lambda x: x not in next(os.walk('.'))[1], dirs)) + + def main(): parser = optparse.OptionParser(USAGE_STRING) parser.add_option('-b', '--bucket', help='Google Storage bucket to upload to.') parser.add_option('-e', '--boto', help='Specify a custom boto file.') + parser.add_option('-z', '--archive', action='store_true', + help='Archive directory as a tar.gz file') parser.add_option('-f', '--force', action='store_true', help='Force upload even if remote file exists.') parser.add_option('-g', '--gsutil_path', default=GSUTIL_DEFAULT_PATH, @@ -235,6 +263,15 @@ def main(): # Enumerate our inputs. input_filenames = get_targets(args, parser, options.use_null_terminator) + if options.archive: + if not validate_archive_dirs(input_filenames): + parser.error('Only directories just below cwd are valid entries when ' + 'using the --archive argument. Entries can not contain .. ' + ' and entries can not be symlinks. Entries was %s' % + input_filenames) + return 1 + input_filenames = create_archives(input_filenames) + # Make sure we can find a working instance of gsutil. if os.path.exists(GSUTIL_DEFAULT_PATH): gsutil = Gsutil(GSUTIL_DEFAULT_PATH, boto_path=options.boto)