diff --git a/download_from_google_storage.py b/download_from_google_storage.py index 5f7b71b14..50b34b8f5 100755 --- a/download_from_google_storage.py +++ b/download_from_google_storage.py @@ -254,9 +254,24 @@ def _downloader_worker_thread(thread_num, q, force, base_url, continue extract_dir = output_filename[:-len('.tar.gz')] if os.path.exists(output_filename) and not force: - if not extract or os.path.exists(extract_dir): - if get_sha1(output_filename) == input_sha1_sum: - continue + skip = get_sha1(output_filename) == input_sha1_sum + if extract: + # Additional condition for extract: + # 1) extract_dir must exist + # 2) .tmp flag file mustn't exist + if not os.path.exists(extract_dir): + out_q.put('%d> Extract dir %s does not exist, re-downloading...' % + (thread_num, extract_dir)) + skip = False + # .tmp file is created just before extraction and removed just after + # extraction. If such file exists, it means the process was terminated + # mid-extraction and therefore needs to be extracted again. + elif os.path.exists(extract_dir + '.tmp'): + out_q.put('%d> Detected tmp flag file for %s, ' + 're-downloading...' % (thread_num, output_filename)) + skip = False + if skip: + continue # Check if file exists. file_url = '%s/%s' % (base_url, input_sha1_sum) (code, _, err) = gsutil.check_call('ls', file_url) @@ -336,7 +351,9 @@ def _downloader_worker_thread(thread_num, q, force, base_url, out_q.put('%d> Extracting %d entries from %s to %s' % (thread_num, len(tar.getmembers()),output_filename, extract_dir)) - tar.extractall(path=dirname) + with open(extract_dir + '.tmp', 'a'): + tar.extractall(path=dirname) + os.remove(extract_dir + '.tmp') # Set executable bit. if sys.platform == 'cygwin': # Under cygwin, mark all files as executable. The executable flag in diff --git a/tests/download_from_google_storage_unittest.py b/tests/download_from_google_storage_unittest.py index 978b545b2..78e09f2d1 100755 --- a/tests/download_from_google_storage_unittest.py +++ b/tests/download_from_google_storage_unittest.py @@ -295,6 +295,8 @@ class DownloadTests(unittest.TestCase): shutil.rmtree(output_dirname) sha1_hash = download_from_google_storage.get_sha1(output_filename) input_filename = '%s/%s' % (self.base_url, sha1_hash) + + # Initial download self.queue.put((sha1_hash, output_filename)) self.queue.put((None, None)) stdout_queue = queue.Queue() @@ -322,6 +324,64 @@ class DownloadTests(unittest.TestCase): self.assertTrue(os.path.exists(output_dirname)) self.assertTrue(os.path.exists(extracted_filename)) + # Test noop download + self.queue.put((sha1_hash, output_filename)) + self.queue.put((None, None)) + stdout_queue = queue.Queue() + download_from_google_storage._downloader_worker_thread(0, + self.queue, + False, + self.base_url, + self.gsutil, + stdout_queue, + self.ret_codes, + True, + True, + delete=False) + + self.assertEqual(list(stdout_queue.queue), []) + self.assertEqual(self.gsutil.history, expected_calls) + self.assertEqual(list(self.ret_codes.queue), []) + self.assertTrue(os.path.exists(output_dirname)) + self.assertTrue(os.path.exists(extracted_filename)) + + # With dirty flag file, previous extraction wasn't complete + with open(os.path.join(self.base_path, 'subfolder.tmp'), 'a'): + pass + + self.queue.put((sha1_hash, output_filename)) + self.queue.put((None, None)) + stdout_queue = queue.Queue() + download_from_google_storage._downloader_worker_thread(0, + self.queue, + False, + self.base_url, + self.gsutil, + stdout_queue, + self.ret_codes, + True, + True, + delete=False) + expected_calls += [('check_call', ('ls', input_filename)), + ('check_call', ('cp', input_filename, output_filename))] + if sys.platform != 'win32': + expected_calls.append( + ('check_call', ('stat', 'gs://sometesturl/%s' % sha1_hash))) + expected_output = [ + '0> Detected tmp flag file for %s, re-downloading...' % + (output_filename), + '0> Downloading %s@%s...' % (output_filename, sha1_hash), + '0> Removed %s...' % (output_dirname), + '0> Extracting 3 entries from %s to %s' % + (output_filename, output_dirname), + ] + expected_ret_codes = [] + self.assertEqual(list(stdout_queue.queue), expected_output) + self.assertEqual(self.gsutil.history, expected_calls) + self.assertEqual(list(self.ret_codes.queue), expected_ret_codes) + self.assertTrue(os.path.exists(output_dirname)) + self.assertTrue(os.path.exists(extracted_filename)) + def test_download_worker_skips_not_found_file(self): sha1_hash = '7871c8e24da15bad8b0be2c36edc9dc77e37727f' input_filename = '%s/%s' % (self.base_url, sha1_hash)