aboutsummaryrefslogtreecommitdiff
path: root/gantools/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'gantools/cli.py')
-rw-r--r--gantools/cli.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/gantools/cli.py b/gantools/cli.py
index 3e42f2e..a4159f5 100644
--- a/gantools/cli.py
+++ b/gantools/cli.py
@@ -20,7 +20,7 @@ def handle_args(argv=None):
what it does.).', default=1)
parser.add_argument('-o', '--output-dir', help='Directory path for output images.')
parser.add_argument('--prefix', help='File prefix for output images.')
- parser.add_argument('--interp', help='Set interpolation method.', choices=['linear', 'cubic'])
+ parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.')
group_loop = parser.add_mutually_exclusive_group(required=False)
group_loop.add_argument('--loop', dest='loop', action='store_true', default=True, help='Loop the animation.')
group_loop.add_argument('--no-loop', dest='loop', action='store_false', help='Don\'t loop the animation.')
@@ -34,25 +34,28 @@ def handle_args(argv=None):
# create entrypoints for cli tools
def main():
- handle_args()
+ args = handle_args()
# get animation keyframes from ganbreeder
print('Downloading keyframe info from ganbreeder...')
keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys)
# interpolate path through input space
print('Interpolating path through input space...')
- z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, args.nbatch)
+ z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(
+ keyframes,
+ args.nframes,
+ batch_size=args.nbatch,
+ interp_method=args.interp,
+ loop=args.loop)
# sample the GAN
print('Loading bigGAN...')
gan = biggan.BigGAN()
- print('Sampling from bigGAN...')
- ims = gan.sample(z_seq, label_seq, truncation_seq, args.nbatch)
- # save images to file
path = '' if args.output_dir == None else str(args.output_dir)
prefix = '' if args.prefix == None else str(args.prefix)
- print('Saving image files: '+path + prefix)
- image_utils.save_images(ims, output_dir=output_dir, prefix=prefix)
-
+ saver = image_utils.ImageSaver(output_dir=path, prefix=prefix)
+ print('Saving image files to: '+path + prefix)
+ print('Sampling from bigGAN...')
+ gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save)
print('Done.')