From f6cbf568ff9815c180095733560a567dcb70e859 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Wed, 21 Aug 2019 22:21:10 -0400 Subject: added a check to make sure the correct amount of keys are passed to cubic interpolation (providing a more meaningful error message than TypeError('m > k must hold')) --- gantools/cli.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) (limited to 'gantools/cli.py') diff --git a/gantools/cli.py b/gantools/cli.py index a4159f5..bb7946d 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -1,4 +1,4 @@ -import sys, argparse +import sys, os, argparse from gantools import ganbreeder from gantools import biggan from gantools import latent_space @@ -18,7 +18,7 @@ def handle_args(argv=None): parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \ (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \ what it does.).', default=1) - parser.add_argument('-o', '--output-dir', help='Directory path for output images.') + parser.add_argument('-o', '--output-dir', help='Directory path for output images.', default=os.getcwd()) parser.add_argument('--prefix', help='File prefix for output images.') parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.') group_loop = parser.add_mutually_exclusive_group(required=False) @@ -41,12 +41,18 @@ def main(): # interpolate path through input space print('Interpolating path through input space...') - z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes( - keyframes, - args.nframes, - batch_size=args.nbatch, - interp_method=args.interp, - loop=args.loop) + try: + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes( + keyframes, + args.nframes, + batch_size=args.nbatch, + interp_method=args.interp, + loop=args.loop) + except ValueError as e: + print(e) + print('ERROR: Interpolation failed. Make sure you are using at least 3 keys (4 if --no-loop is enabled)') + print('If you would like to use fewer keys, try using the --interp linear argument') + return 1 # sample the GAN print('Loading bigGAN...') @@ -55,7 +61,11 @@ def main(): path = '' if args.output_dir == None else str(args.output_dir) prefix = '' if args.prefix == None else str(args.prefix) saver = image_utils.ImageSaver(output_dir=path, prefix=prefix) - print('Saving image files to: '+path + prefix) + print('Image files will be saved to: '+path + prefix) print('Sampling from bigGAN...') gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save) print('Done.') + return 0 + +if __name__ == '__main__': + sys.exit(main()) -- cgit v1.2.1