Update and rename download_model.py to download_2model.py

Added the ability to download files in parallel using a ThreadPoolExecutor, which can speed up the download process.
Improved handling of existing files — if a local file already exists, its size is now compared against the remote file's size before deciding whether to skip or redownload.
Changed the error handling to provide more informative error messages.
Added a default of 4 threads for parallel downloads, but you can adjust this value as needed.
pull/407/head
Chetan G 2 years ago committed by GitHub
parent 336f120ce1
commit d3b14239c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,22 +3,25 @@ import sys
import json
import requests
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
def download_file(url: str, filename: str, download_dir: str):
    """Download ``url`` into ``download_dir/filename``, skipping up-to-date files.

    A HEAD request fetches the remote content length first; a local file of the
    same size is skipped, while a size mismatch triggers a redownload (so an
    interrupted partial download is retried instead of silently kept).
    Progress is shown with a tqdm bar on stdout.

    Errors are caught, reported to stdout, and NOT re-raised.

    Args:
        url: Remote URL to fetch.
        filename: Name to save the file as inside ``download_dir``.
        download_dir: Existing directory that receives the file.
    """
    try:
        filepath = os.path.join(download_dir, filename)

        # Compare the local file (if any) against the remote size before
        # downloading. Servers that omit Content-Length report 0 here, which
        # simply forces a redownload.
        if os.path.isfile(filepath):
            local_file_size = os.path.getsize(filepath)
            remote_file_size = int(requests.head(url).headers.get("content-length", 0))
            if local_file_size == remote_file_size:
                print(f"{filepath} already exists. Skipping download.")
                return
            print(f"{filepath} already exists, but its size does not match. Redownloading.")

        # NOTE(review): the original diff showed "(unknown)" here — almost
        # certainly a scrape artifact for the filename placeholder; confirm.
        print(f"Downloading {filename} from {url}")

        # stream=True lets us consume the body in chunks and track progress.
        response = requests.get(url, stream=True)
        # NOTE(review): the diff omitted the HTTP status check here (unchanged
        # hunk); fail fast on 4xx/5xx as the omitted code presumably did.
        response.raise_for_status()

        # Progress bar sized from the response's Content-Length (0 if absent).
        total_size = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(
            total=total_size,
            unit='iB',
            unit_scale=True,
            ncols=70,
            file=sys.stdout
        )

        # Write the response content to file chunk by chunk, updating the bar.
        with open(filepath, 'wb') as f:
            for data in response.iter_content(chunk_size=1024):
                progress_bar.update(len(data))
                f.write(data)
        progress_bar.close()

        # A short read means the transfer was truncated.
        if total_size != 0 and progress_bar.n != total_size:
            print("ERROR, something went wrong while downloading")
            raise Exception(f"incomplete download of {url}")
    except Exception as e:
        # Deliberate best-effort: report and continue so one bad URL does not
        # abort the remaining downloads.
        print(f"An error occurred: {e}")
def download_files_parallel(config, download_dir, num_threads=4):
    """Download every ``url -> filename`` entry of ``config`` concurrently.

    Spawns up to ``num_threads`` worker threads, blocks until every download
    has finished, and re-raises the first exception any worker produced.

    Args:
        config: Mapping of remote URL to local filename.
        download_dir: Directory that receives the downloaded files.
        num_threads: Maximum number of concurrent downloads (default 4).
    """
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Fan out one task per config entry.
        pending = [
            pool.submit(download_file, url, filename, download_dir)
            for url, filename in config.items()
        ]
        # Wait for completion; result() propagates any worker exception.
        for task in pending:
            task.result()
def main():
    """Read ``download_models.json`` next to this script and fetch every
    file it lists, using four parallel download threads."""
    # Resolve the config path relative to the script, not the CWD.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(script_dir, "download_models.json")

    # NOTE(review): the lines that originally defined/created download_dir
    # fall inside an unchanged diff hunk and are not visible here; this
    # reconstruction uses a "models" folder beside the script — confirm
    # against the pre-diff file.
    download_dir = os.path.join(script_dir, "models")
    os.makedirs(download_dir, exist_ok=True)

    # Config maps each URL to the filename it should be saved as.
    with open(config_file_path, "r") as f:
        config = json.load(f)

    # Download all files in parallel (replaces the old sequential loop).
    download_files_parallel(config, download_dir, num_threads=4)


if __name__ == "__main__":
    main()
Loading…
Cancel
Save