diff --git a/viz/capture_widget.py b/viz/capture_widget.py index 48e1373..a63be95 100644 --- a/viz/capture_widget.py +++ b/viz/capture_widget.py @@ -31,7 +31,6 @@ class CaptureWidget: viz = self.viz try: _height, _width, channels = image.shape - assert channels in [1, 3] assert image.dtype == np.uint8 os.makedirs(self.path, exist_ok=True) file_id = 0 @@ -43,7 +42,7 @@ class CaptureWidget: if channels == 1: pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') else: - pil_image = PIL.Image.fromarray(image, 'RGB') + pil_image = PIL.Image.fromarray(image[:, :, :3], 'RGB') pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) except: viz.result.error = renderer.CapturedException()