Merge pull request #65 from PDillis/main

General fixes
pull/79/head^2
Xingang Pan 2 years ago committed by GitHub
commit ccd84ffab9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -177,15 +177,16 @@ class Visualizer(imgui_window.ImguiWindow):
if self.result.init_net: if self.result.init_net:
self.drag_widget.reset_point() self.drag_widget.reset_point()
if self.check_update_mask(**self.args):
h, w, _ = self.result.image.shape
self.drag_widget.init_mask(w, h)
# Display. # Display.
max_w = self.content_width - self.pane_w max_w = self.content_width - self.pane_w
max_h = self.content_height max_h = self.content_height
pos = np.array([self.pane_w + max_w / 2, max_h / 2]) pos = np.array([self.pane_w + max_w / 2, max_h / 2])
if 'image' in self.result: if 'image' in self.result:
# Reset mask after loading a new pickle or changing seed.
if self.check_update_mask(**self.args):
h, w, _ = self.result.image.shape
self.drag_widget.init_mask(w, h)
if self._tex_img is not self.result.image: if self._tex_img is not self.result.image:
self._tex_img = self.result.image self._tex_img = self.result.image
if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img): if self._tex_obj is None or not self._tex_obj.is_compatible(image=self._tex_img):

@ -31,7 +31,7 @@ class CaptureWidget:
viz = self.viz viz = self.viz
try: try:
_height, _width, channels = image.shape _height, _width, channels = image.shape
assert channels in [1, 3] print(viz.result)
assert image.dtype == np.uint8 assert image.dtype == np.uint8
os.makedirs(self.path, exist_ok=True) os.makedirs(self.path, exist_ok=True)
file_id = 0 file_id = 0
@ -43,8 +43,9 @@ class CaptureWidget:
if channels == 1: if channels == 1:
pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') pil_image = PIL.Image.fromarray(image[:, :, 0], 'L')
else: else:
pil_image = PIL.Image.fromarray(image, 'RGB') pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB')
pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) pil_image.save(os.path.join(self.path, f'{file_id:05d}.png'))
np.save(os.path.join(self.path, f'{file_id:05d}.npy'), viz.result.w)
except: except:
viz.result.error = renderer.CapturedException() viz.result.error = renderer.CapturedException()

@ -382,5 +382,6 @@ class Renderer:
img = img.cpu().numpy() img = img.cpu().numpy()
img = Image.fromarray(img) img = Image.fromarray(img)
res.image = img res.image = img
res.w = ws.detach().cpu().numpy()
#---------------------------------------------------------------------------- #----------------------------------------------------------------------------

Loading…
Cancel
Save