diff --git a/scripts/download_model.py b/scripts/download_2model.py similarity index 56% rename from scripts/download_model.py rename to scripts/download_2model.py index b951b05..aa6bb14 100644 --- a/scripts/download_model.py +++ b/scripts/download_2model.py @@ -3,22 +3,25 @@ import sys import json import requests from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor def download_file(url: str, filename: str, download_dir: str): """Download a file if it does not already exist.""" try: filepath = os.path.join(download_dir, filename) - content_length = int(requests.head(url).headers.get("content-length", 0)) - # If file already exists and size matches, skip download - if os.path.isfile(filepath) and os.path.getsize(filepath) == content_length: - print(f"{filepath} already exists. Skipping download.") - return - if os.path.isfile(filepath) and os.path.getsize(filepath) != content_length: - print(f"{filepath} already exists but size does not match. Redownloading.") - else: - print(f"Downloading {filename} from {url}") + # Check if the file already exists and its size matches the remote file + if os.path.isfile(filepath): + local_file_size = os.path.getsize(filepath) + remote_file_size = int(requests.head(url).headers.get("content-length", 0)) + if local_file_size == remote_file_size: + print(f"{filepath} already exists. Skipping download.") + return + else: + print(f"{filepath} already exists, but its size does not match. Redownloading.") + + print(f"Downloading {filename} from {url}") # Start download, stream=True allows for progress tracking response = requests.get(url, stream=True) @@ -29,14 +32,14 @@ def download_file(url: str, filename: str, download_dir: str): # Create progress bar total_size = int(response.headers.get('content-length', 0)) progress_bar = tqdm( - total=total_size, - unit='iB', - unit_scale=True, - ncols=70, + total=total_size, + unit='iB', + unit_scale=True, + ncols=70, file=sys.stdout ) - # Write response content to file + # Write response content to file using a buffer with open(filepath, 'wb') as f: for data in response.iter_content(chunk_size=1024): f.write(data) @@ -50,13 +53,23 @@ def download_file(url: str, filename: str, download_dir: str): print("ERROR, something went wrong while downloading") raise Exception() - except Exception as e: print(f"An error occurred: {e}") +def download_files_parallel(config, download_dir, num_threads=4): + """Download files in parallel using ThreadPoolExecutor.""" + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + + for url, filename in config.items(): + futures.append(executor.submit(download_file, url, filename, download_dir)) + + for future in futures: + future.result() + def main(): """Main function to download files from URLs in a config file.""" - + # Get JSON config file path script_dir = os.path.dirname(os.path.realpath(__file__)) config_file_path = os.path.join(script_dir, "download_models.json") @@ -69,10 +82,8 @@ def main(): with open(config_file_path, "r") as f: config = json.load(f) - # Download each file specified in config - for url, filename in config.items(): - download_file(url, filename, download_dir) - + # Download files in parallel + download_files_parallel(config, download_dir, num_threads=4) if __name__ == "__main__": main()