From d3b14239c0438034ebc554a092a915929a4cf183 Mon Sep 17 00:00:00 2001 From: Chetan G Date: Thu, 26 Oct 2023 18:50:04 +0530 Subject: [PATCH] Update and rename download_model.py to download_2model.py Added the ability to download files in parallel using a ThreadPoolExecutor, which can speed up the download process. Improved handling of existing files - if a local file exists, it now checks if the file size matches the remote file before deciding to skip or redownload. Changed the error handling to provide more informative error messages. Added a default of 4 threads for parallel downloads, but you can adjust this value as needed. --- .../{download_model.py => download_2model.py} | 51 +++++++++++-------- 1 file changed, 31 insertions(+), 20 deletions(-) rename scripts/{download_model.py => download_2model.py} (56%) diff --git a/scripts/download_model.py b/scripts/download_2model.py similarity index 56% rename from scripts/download_model.py rename to scripts/download_2model.py index b951b05..aa6bb14 100644 --- a/scripts/download_model.py +++ b/scripts/download_2model.py @@ -3,22 +3,25 @@ import sys import json import requests from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor def download_file(url: str, filename: str, download_dir: str): """Download a file if it does not already exist.""" try: filepath = os.path.join(download_dir, filename) - content_length = int(requests.head(url).headers.get("content-length", 0)) - # If file already exists and size matches, skip download - if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length: - print(f"{filepath} already exists. Skipping download.") - return - if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length: - print(f"{filepath} already exists but size does not match. Redownloading.") - else: - print(f"Downloading {filename} from {url}") + # Check if the file already exists and its size matches the remote file + if os.path.isfile(filepath): + local_file_size = os.path.getsize(filepath) + remote_file_size = int(requests.head(url).headers.get("content-length", 0)) + if local_file_size == remote_file_size: + print(f"{filepath} already exists. Skipping download.") + return + else: + print(f"{filepath} already exists, but its size does not match. Redownloading.") + + print(f"Downloading {filename} from {url}") # Start download, stream=True allows for progress tracking response = requests.get(url, stream=True) @@ -29,14 +32,14 @@ def download_file(url: str, filename: str, download_dir: str): # Create progress bar total_size = int(response.headers.get('content-length', 0)) progress_bar = tqdm( - total=total_size, - unit='iB', - unit_scale=True, - ncols=70, + total=total_size, + unit='iB', + unit_scale=True, + ncols=70, file=sys.stdout ) - # Write response content to file + # Write response content to file using a buffer with open(filepath, 'wb') as f: for data in response.iter_content(chunk_size=1024): f.write(data) @@ -50,13 +53,23 @@ def download_file(url: str, filename: str, download_dir: str): print("ERROR, something went wrong while downloading") raise Exception() - except Exception as e: print(f"An error occurred: {e}") +def download_files_parallel(config, download_dir, num_threads=4): + """Download files in parallel using ThreadPoolExecutor.""" + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + + for url, filename in config.items(): + futures.append(executor.submit(download_file, url, filename, download_dir)) + + for future in futures: + future.result() + def main(): """Main function to download files from URLs in a config file.""" - + # Get JSON config file path script_dir = os.path.dirname(os.path.realpath(__file__)) config_file_path = os.path.join(script_dir, "download_models.json") @@ -69,10 +82,8 @@ def main(): with open(config_file_path, "r") as f: config = json.load(f) - # Download each file specified in config - for url, filename in config.items(): - download_file(url, filename, download_dir) - + # Download files in parallel + download_files_parallel(config, download_dir, num_threads=4) if __name__ == "__main__": main()