From 620535b58307e43a80e7f123e590e4fc31b1a755 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Sat, 8 Jun 2019 17:51:12 -0400 Subject: added image save callback to sample/render loop --- gantools/cli.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) (limited to 'gantools/cli.py') diff --git a/gantools/cli.py b/gantools/cli.py index 3e42f2e..a4159f5 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -20,7 +20,7 @@ def handle_args(argv=None): what it does.).', default=1) parser.add_argument('-o', '--output-dir', help='Directory path for output images.') parser.add_argument('--prefix', help='File prefix for output images.') - parser.add_argument('--interp', help='Set interpolation method.', choices=['linear', 'cubic']) + parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.') group_loop = parser.add_mutually_exclusive_group(required=False) group_loop.add_argument('--loop', dest='loop', action='store_true', default=True, help='Loop the animation.') group_loop.add_argument('--no-loop', dest='loop', action='store_false', help='Don\'t loop the animation.') @@ -34,25 +34,28 @@ def handle_args(argv=None): # create entrypoints for cli tools def main(): - handle_args() + args = handle_args() # get animation keyframes from ganbreeder print('Downloading keyframe info from ganbreeder...') keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys) # interpolate path through input space print('Interpolating path through input space...') - z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, args.nbatch) + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes( + keyframes, + args.nframes, + batch_size=args.nbatch, + interp_method=args.interp, + loop=args.loop) # sample the GAN print('Loading bigGAN...') gan = biggan.BigGAN() - print('Sampling from bigGAN...') - ims = gan.sample(z_seq, label_seq, truncation_seq, args.nbatch) - # save images to file path = '' if args.output_dir == None else str(args.output_dir) prefix = '' if args.prefix == None else str(args.prefix) - print('Saving image files: '+path + prefix) - image_utils.save_images(ims, output_dir=output_dir, prefix=prefix) - + saver = image_utils.ImageSaver(output_dir=path, prefix=prefix) + print('Saving image files to: '+path + prefix) + print('Sampling from bigGAN...') + gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save) print('Done.') -- cgit v1.2.1