Merge pull request #79 from ochafik/mps-support

Support & instructions for MPS (MacOS w/ GPU support on M1/M2) and CPU
Committed by Xingang Pan via GitHub, 2 years ago
commit c3fca90444

@@ -35,7 +35,19 @@
 ## Requirements
-Please follow the requirements of [NVlabs/stylegan3](https://github.com/NVlabs/stylegan3#requirements).
+If you have a CUDA graphics card, please follow the requirements of [NVlabs/stylegan3](https://github.com/NVlabs/stylegan3#requirements).
+Otherwise (for GPU acceleration on MacOS with Apple Silicon M1/M2, or just CPU), try the following:
+
+```sh
+cat environment.yml | \
+  grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
+  conda env create -f environment-no-nvidia.yml
+conda activate stylegan3
+
+# On MacOS
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+```

 ## Download pre-trained StyleGAN2 weights

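A quick way to confirm what the CUDA-free environment ends up using, before running any of the scripts changed below (a minimal sketch, not part of this PR):

```python
# Sanity check: report which backend PyTorch will select on this machine.
import torch

if torch.cuda.is_available():
    print('Using CUDA:', torch.cuda.get_device_name(0))
elif torch.backends.mps.is_available():
    print('Using MPS (Apple Silicon GPU)')
else:
    print('Falling back to CPU')
```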
@@ -5,11 +5,12 @@ channels:
 dependencies:
   - python >= 3.8
   - pip
-  - numpy>=1.20
+  - numpy>=1.25
   - click>=8.0
-  - pillow=8.3.1
-  - scipy=1.7.1
-  - pytorch=1.9.1
+  - pillow=9.4.0
+  - scipy=1.11.0
+  - pytorch>=2.0.1
+  - torchvision>=0.15.2
   - cudatoolkit=11.1
   - requests=2.26.0
   - tqdm=4.62.2
@@ -17,8 +18,10 @@ dependencies:
   - matplotlib=3.4.2
   - imageio=2.9.0
   - pip:
-    - imgui==1.3.0
-    - glfw==2.2.0
+    - imgui==2.0.0
+    - glfw==2.6.1
+    - gradio==3.35.2
     - pyopengl==3.1.5
     - imageio-ffmpeg==0.4.3
-    - pyspng
+    # pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
+    - pyspng-seunglab

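On the pyspng swap: the upstream stylegan3 code typically treats pyspng as optional and falls back to PIL when it is missing, so replacing it with the pyspng-seunglab fork (which is expected to expose the same module name) should be transparent to the rest of the code. A minimal sketch of that optional-import pattern, for illustration only:

```python
# Optional fast PNG decoding, as in the upstream stylegan3 data-loading code (illustrative sketch).
try:
    import pyspng  # pyspng-seunglab is assumed to install under the same module name
except ImportError:
    pyspng = None

import io
import numpy as np
import PIL.Image

def decode_png(data: bytes) -> np.ndarray:
    if pyspng is not None:
        return pyspng.load(data)                        # fast path
    return np.array(PIL.Image.open(io.BytesIO(data)))  # PIL fallback
```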
@@ -103,9 +103,10 @@ def generate_images(
     """
     print('Loading networks from "%s"...' % network_pkl)
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
         # import pickle
         # G = legacy.load_network_pkl(f)
         # output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
@@ -126,7 +127,7 @@ def generate_images(
     # Generate images.
     for seed_idx, seed in enumerate(seeds):
         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
-        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)

         # Construct an inverse rotation/translation matrix and pass to the generator. The
         # generator expects this matrix as an inverse to avoid potentially failing numerical

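Every entry point touched in this PR repeats the same two lines: a cuda → mps → cpu device cascade, plus a dtype switch because the MPS backend has no float64 support. A hypothetical helper that captures the pattern (not part of this PR, just a restatement of what the diff does):

```python
# Hypothetical helper summarizing the repeated device/dtype selection in this PR.
import torch

def pick_device_and_dtype():
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')  # Apple Silicon GPU
    else:
        device = torch.device('cpu')
    # MPS cannot run float64 kernels, so drop to float32 there.
    dtype = torch.float32 if device.type == 'mps' else torch.float64
    return device, dtype

device, dtype = pick_device_and_dtype()
```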
@@ -63,9 +63,10 @@ def generate_images(
     else:
         import torch
-        device = torch.device('cuda')
+        device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+        dtype = torch.float32 if device.type == 'mps' else torch.float64
         with dnnlib.util.open_url(network_pkl) as f:
-            G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+            G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore

     os.makedirs(outdir, exist_ok=True)
@@ -92,7 +93,7 @@ def generate_images(
         else: ## stylegan v2/v3
             label = torch.zeros([1, G.c_dim], device=device)
-            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
+            z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
             if target_z.size==0:
                 target_z= z.cpu()
             else:

@@ -116,9 +116,10 @@ def main(
 ):
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore

     outdir = os.path.join(outdir)
     if not os.path.exists(outdir):
@@ -132,8 +133,8 @@ def main(
         print('Require two seeds, randomly generate two now.')
         seeds = [seeds[0],random.randint(0,10000)]
-    z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device)
-    z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device)
+    z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device, dtype=dtype)
+    z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device, dtype=dtype)
     img1 = generate_image_from_z(G, z1, noise_mode, truncation_psi, device)
     img2 = generate_image_from_z(G, z2, noise_mode, truncation_psi, device)
     img1.save(f'{outdir}/seed{seeds[0]:04d}.png')

@@ -49,16 +49,17 @@ def generate_style_mix(
 ):
     print('Loading networks from "%s"...' % network_pkl)
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        G = legacy.load_network_pkl(f)['G_ema'].to(device)
+        G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
     os.makedirs(outdir, exist_ok=True)

     print('Generating W vectors...')
     all_seeds = list(set(row_seeds + col_seeds))
     all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
-    all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
+    all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
     w_avg = G.mapping.w_avg
     all_w = w_avg + (all_w - w_avg) * truncation_psi
     w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

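For context on the unchanged lines around these hunks: `w_avg + (w - w_avg) * truncation_psi` is the standard truncation trick, interpolating each mapped latent toward the mapping network's running mean. An illustrative, self-contained restatement (not code from this PR):

```python
# Truncation trick as it appears in the surrounding code (illustrative values).
import torch

w = torch.randn(4, 18, 512)   # mapped latents, e.g. [batch, num_ws, w_dim]
w_avg = torch.zeros(512)      # running mean, tracked as G.mapping.w_avg
truncation_psi = 0.7          # 1.0 = no truncation, 0.0 = collapse to w_avg

w_truncated = w_avg + (w - w_avg) * truncation_psi
```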
@@ -65,9 +65,10 @@ def style_mixing_video(network_pkl: str,
     print('col_seeds: ', dst_seeds)
     num_frames = int(np.rint(duration_sec * mp4_fps))
     print('Loading networks from "%s"...' % network_pkl)
-    device = torch.device('cuda')
+    device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+    dtype = torch.float32 if device.type == 'mps' else torch.float64
     with dnnlib.util.open_url(network_pkl) as f:
-        Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
+        Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
     print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
     max_style = int(2 * np.log2(Gs.img_resolution)) - 3
@@ -80,14 +81,14 @@ def style_mixing_video(network_pkl: str,
     src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2- 1), mode="wrap")
     src_z /= np.sqrt(np.mean(np.square(src_z)))
     # Map into the detangled latent space W and do truncation trick
-    src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
+    src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
     w_avg = Gs.mapping.w_avg
     src_w = w_avg + (src_w - w_avg) * truncation_psi
     # Top row latents (fixed reference)
     print('Generating Destination W vectors...')
     dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
-    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
+    dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
     dst_w = w_avg + (dst_w - w_avg) * truncation_psi
     # Get the width and height of each image:
     H = Gs.img_resolution  # 1024
@@ -120,7 +121,7 @@ def style_mixing_video(network_pkl: str,
         for col, dst_image in enumerate(list(dst_images)):
             # Select the pertinent latent w column:
             w_col = np.stack([dst_w[col].cpu()])  # [18, 512] -> [1, 18, 512]
-            w_col = torch.from_numpy(w_col).to(device)
+            w_col = torch.from_numpy(w_col).to(device, dtype=dtype)
             # Replace the values defined by col_styles:
             w_col[:, col_styles] = src_w[frame_idx, col_styles]#.cpu()
             # Generate these synthesized images:

@@ -69,7 +69,8 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
 class Renderer:
     def __init__(self, disable_timing=False):
-        self._device = torch.device('cuda')
+        self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
+        self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
         self._pkl_data = dict()    # {pkl: dict | CapturedException, ...}
         self._networks = dict()    # {cache_key: torch.nn.Module, ...}
         self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
@@ -241,7 +242,7 @@ class Renderer:
         if self.w_load is None:
             # Generate random latents.
-            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device).float()
+            z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
             # Run mapping network.
             label = torch.zeros([1, G.c_dim], device=self._device)

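One operational note on the `export PYTORCH_ENABLE_MPS_FALLBACK=1` step from the README above: the flag tells PyTorch to run ops the MPS backend does not implement yet on the CPU instead of raising an error. The usual advice is to have it in the environment before torch is imported; when setting it from Python instead of the shell, a sketch like this is the safe ordering (an assumption on my part, not code from this PR):

```python
# Enable CPU fallback for MPS-unsupported ops; set before importing torch to be safe.
import os
os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

import torch
print('MPS available:', torch.backends.mps.is_available())
```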