Support & instructions for MPS (Silicon Mac M1/M2) and CPU

pull/79/head
ochafik 2 years ago
parent d7f7319972
commit f3777f6e5b

@ -35,7 +35,19 @@
## Requirements
Please follow the requirements of [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
If you have CUDA graphic card, please follow the requirements of [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
Otherwise (for GPU acceleration on MacOS with Silicon Mac M1/M2, or just CPU) try the following:
```sh
cat environment.yml | \
grep -v -E 'nvidia|cuda' > environment-no-nvidia.yml && \
conda env create -f environment-no-nvidia.yml
conda activate stylegan3
# On MacOS
export PYTORCH_ENABLE_MPS_FALLBACK=1
```
## Download pre-trained StyleGAN2 weights

@ -5,11 +5,12 @@ channels:
dependencies:
- python >= 3.8
- pip
- numpy>=1.20
- numpy>=1.25
- click>=8.0
- pillow=8.3.1
- scipy=1.7.1
- pytorch=1.9.1
- pillow=9.4.0
- scipy=1.11.0
- pytorch>=2.0.1
- torchvision>=0.15.2
- cudatoolkit=11.1
- requests=2.26.0
- tqdm=4.62.2
@ -17,8 +18,10 @@ dependencies:
- matplotlib=3.4.2
- imageio=2.9.0
- pip:
- imgui==1.3.0
- glfw==2.2.0
- imgui==2.0.0
- glfw==2.6.1
- gradio==3.35.2
- pyopengl==3.1.5
- imageio-ffmpeg==0.4.3
- pyspng
# pyspng is currently broken on MacOS (see https://github.com/nurpax/pyspng/pull/6 for instance)
- pyspng-seunglab

@ -103,9 +103,10 @@ def generate_images(
"""
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
# import pickle
# G = legacy.load_network_pkl(f)
# output = open('checkpoints/stylegan2-car-config-f-pt.pkl', 'wb')
@ -126,7 +127,7 @@ def generate_images(
# Generate images.
for seed_idx, seed in enumerate(seeds):
print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
# Construct an inverse rotation/translation matrix and pass to the generator. The
# generator expects this matrix as an inverse to avoid potentially failing numerical

@ -63,9 +63,10 @@ def generate_images(
else:
import torch
device = torch.device('cuda')
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
os.makedirs(outdir, exist_ok=True)
@ -92,7 +93,7 @@ def generate_images(
else: ## stylegan v2/v3
label = torch.zeros([1, G.c_dim], device=device)
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device, dtype=dtype)
if target_z.size==0:
target_z= z.cpu()
else:

@ -116,9 +116,10 @@ def main(
):
device = torch.device('cuda')
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) # type: ignore
outdir = os.path.join(outdir)
if not os.path.exists(outdir):
@ -132,8 +133,8 @@ def main(
print('Require two seeds, randomly generate two now.')
seeds = [seeds[0],random.randint(0,10000)]
z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device)
z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device)
z1 = torch.from_numpy(np.random.RandomState(seeds[0]).randn(1, G.z_dim)).to(device, dtype=dtype)
z2 = torch.from_numpy(np.random.RandomState(seeds[1]).randn(1, G.z_dim)).to(device, dtype=dtype)
img1 = generate_image_from_z(G, z1, noise_mode, truncation_psi, device)
img2 = generate_image_from_z(G, z2, noise_mode, truncation_psi, device)
img1.save(f'{outdir}/seed{seeds[0]:04d}.png')

@ -49,16 +49,17 @@ def generate_style_mix(
):
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
G = legacy.load_network_pkl(f)['G_ema'].to(device)
G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
os.makedirs(outdir, exist_ok=True)
print('Generating W vectors...')
all_seeds = list(set(row_seeds + col_seeds))
all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
w_avg = G.mapping.w_avg
all_w = w_avg + (all_w - w_avg) * truncation_psi
w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}

@ -65,9 +65,10 @@ def style_mixing_video(network_pkl: str,
print('col_seeds: ', dst_seeds)
num_frames = int(np.rint(duration_sec * mp4_fps))
print('Loading networks from "%s"...' % network_pkl)
device = torch.device('cuda')
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
dtype = torch.float32 if device.type == 'mps' else torch.float64
with dnnlib.util.open_url(network_pkl) as f:
Gs = legacy.load_network_pkl(f)['G_ema'].to(device)
Gs = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
print(Gs.num_ws, Gs.w_dim, Gs.img_resolution)
max_style = int(2 * np.log2(Gs.img_resolution)) - 3
@ -80,14 +81,14 @@ def style_mixing_video(network_pkl: str,
src_z = scipy.ndimage.gaussian_filter(src_z, [smoothing_sec * mp4_fps] + [0] * (2- 1), mode="wrap")
src_z /= np.sqrt(np.mean(np.square(src_z)))
# Map into the detangled latent space W and do truncation trick
src_w = Gs.mapping(torch.from_numpy(src_z).to(device), None)
src_w = Gs.mapping(torch.from_numpy(src_z).to(device, dtype=dtype), None)
w_avg = Gs.mapping.w_avg
src_w = w_avg + (src_w - w_avg) * truncation_psi
# Top row latents (fixed reference)
print('Generating Destination W vectors...')
dst_z = np.stack([np.random.RandomState(seed).randn(Gs.z_dim) for seed in dst_seeds])
dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device), None)
dst_w = Gs.mapping(torch.from_numpy(dst_z).to(device, dtype=dtype), None)
dst_w = w_avg + (dst_w - w_avg) * truncation_psi
# Get the width and height of each image:
H = Gs.img_resolution # 1024
@ -120,7 +121,7 @@ def style_mixing_video(network_pkl: str,
for col, dst_image in enumerate(list(dst_images)):
# Select the pertinent latent w column:
w_col = np.stack([dst_w[col].cpu()]) # [18, 512] -> [1, 18, 512]
w_col = torch.from_numpy(w_col).to(device)
w_col = torch.from_numpy(w_col).to(device, dtype=dtype)
# Replace the values defined by col_styles:
w_col[:, col_styles] = src_w[frame_idx, col_styles]#.cpu()
# Generate these synthesized images:

@ -69,7 +69,8 @@ def add_watermark_np(input_image_array, watermark_text="AI Generated"):
class Renderer:
def __init__(self, disable_timing=False):
self._device = torch.device('cuda')
self._device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
self._dtype = torch.float32 if self._device.type == 'mps' else torch.float64
self._pkl_data = dict() # {pkl: dict | CapturedException, ...}
self._networks = dict() # {cache_key: torch.nn.Module, ...}
self._pinned_bufs = dict() # {(shape, dtype): torch.Tensor, ...}
@ -241,7 +242,7 @@ class Renderer:
if self.w_load is None:
# Generate random latents.
z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device).float()
z = torch.from_numpy(np.random.RandomState(w0_seed).randn(1, 512)).to(self._device, dtype=self._dtype)
# Run mapping network.
label = torch.zeros([1, G.c_dim], device=self._device)

Loading…
Cancel
Save