From 620535b58307e43a80e7f123e590e4fc31b1a755 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Sat, 8 Jun 2019 17:51:12 -0400 Subject: added image save callback to sample/render loop --- gantools/biggan.py | 24 +++++++++++++++++------- gantools/cli.py | 21 ++++++++++++--------- gantools/image_utils.py | 19 +++++++++++++++++++ 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/gantools/biggan.py b/gantools/biggan.py index 0ed6bf7..571387b 100644 --- a/gantools/biggan.py +++ b/gantools/biggan.py @@ -25,7 +25,8 @@ class BigGAN(object): self.sess = tf.Session() self.sess.run(initializer) - def sample(self, vectors, labels, truncation=0.5, batch_size=1): + # NOTE: use save callback to save images once per batch. return type changes to None. + def sample(self, vectors, labels, truncation=0.5, batch_size=1, save_callback=None): num = vectors.shape[0] # deal with scalar input case @@ -38,11 +39,20 @@ class BigGAN(object): for batch_start, trunc in zip(range(0, num, batch_size), truncation): s = slice(batch_start, min(num, batch_start + batch_size)) feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc} - ims.append(self.sess.run(self.output, feed_dict=feed_dict)) - ims = np.concatenate(ims, axis=0) - assert ims.shape[0] == num - ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255) - ims = np.uint8(ims) - return ims + ims_batch = self.sess.run(self.output, feed_dict=feed_dict) + ims_batch = np.clip(((ims_batch + 1) / 2.0) * 256, 0, 255) + ims_batch = np.uint8(ims_batch) + if save_callback is None: + ims.append(ims_batch) + else: + save_callback(ims_batch) + if save_callback is None: + ims = np.concatenate(ims, axis=0) + assert ims.shape[0] == num + return ims + else: + return None + + # TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping # them in memory. diff --git a/gantools/cli.py b/gantools/cli.py index 3e42f2e..a4159f5 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -20,7 +20,7 @@ def handle_args(argv=None): what it does.).', default=1) parser.add_argument('-o', '--output-dir', help='Directory path for output images.') parser.add_argument('--prefix', help='File prefix for output images.') - parser.add_argument('--interp', help='Set interpolation method.', choices=['linear', 'cubic']) + parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.') group_loop = parser.add_mutually_exclusive_group(required=False) group_loop.add_argument('--loop', dest='loop', action='store_true', default=True, help='Loop the animation.') group_loop.add_argument('--no-loop', dest='loop', action='store_false', help='Don\'t loop the animation.') @@ -34,25 +34,28 @@ def handle_args(argv=None): # create entrypoints for cli tools def main(): - handle_args() + args = handle_args() # get animation keyframes from ganbreeder print('Downloading keyframe info from ganbreeder...') keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys) # interpolate path through input space print('Interpolating path through input space...') - z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, args.nbatch) + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes( + keyframes, + args.nframes, + batch_size=args.nbatch, + interp_method=args.interp, + loop=args.loop) # sample the GAN print('Loading bigGAN...') gan = biggan.BigGAN() - print('Sampling from bigGAN...') - ims = gan.sample(z_seq, label_seq, truncation_seq, args.nbatch) - # save images to file path = '' if args.output_dir == None else str(args.output_dir) prefix = '' if args.prefix == None else str(args.prefix) - print('Saving image files: '+path + prefix) - image_utils.save_images(ims, output_dir=output_dir, prefix=prefix) - + saver = image_utils.ImageSaver(output_dir=path, prefix=prefix) + print('Saving image files to: '+path + prefix) + print('Sampling from bigGAN...') + gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save) print('Done.') diff --git a/gantools/image_utils.py b/gantools/image_utils.py index 72130b0..a875a91 100644 --- a/gantools/image_utils.py +++ b/gantools/image_utils.py @@ -9,3 +9,22 @@ def save_images(ims, output_dir='', prefix='', format='JPEG'): for i, im in enumerate(ims): full_path = os.path.join(output_dir, prefix + str(i).zfill(4) + '.' + format.lower()) save_image(im, full_path, format) + +class ImageSaver(object): + def __init__(self, output_dir='', prefix='', image_format='JPEG'): + self.output_dir = str(output_dir) + self.prefix = str(prefix) + self.image_format = str(image_format) + self.index = int(0) + self.quality = 90 + + def save(self, ims): + for i, im in enumerate(ims): + full_path = os.path.join( + self.output_dir, + self.prefix + str(self.index).zfill(4) + '.' + self.image_format.lower() + ) + image = PIL.Image.fromarray(im) + image.save(full_path, format=self.image_format, quality=self.quality) + self.index += 1 + -- cgit v1.2.1