diff options
Diffstat (limited to 'gantools/cli.py')
-rw-r--r-- | gantools/cli.py | 46 |
1 files changed, 45 insertions, 1 deletions
diff --git a/gantools/cli.py b/gantools/cli.py index df648c5..efabcae 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -1,3 +1,47 @@ +import sys, argparse +import ganbreeder +import biggan +import latent_space +import image_utils + # create entrypoints for cli tools def main(): - print("Hello World") + ## handle args + parser = argparse.ArgumentParser(description='GAN tools') + # load from ganbreeder + parser.add_argument('-u', '--username', help='Ganbreeder account email address/username.') + parser.add_argument('-p', '--password', help='Ganbreeder account password.') + parser.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.') + parser.add_argument('-n', '--nframes', metavar='N', type=int, help='Total number of frames in the final animation.', default=10) + parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \ + (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \ + what it does.).', default=1) + parser.add_argument('-p', '--pathprefix', help='Directory path and file prefix for output images.') + parser.add_argument('-r', '--radius', metavar='N', type=int, help='effect radius', default=10) + parser.add_argument('-c', '--count', metavar='N', type=int, help='effect count parameter', default=10) + parser.add_argument('-d', '--duration', metavar='N', type=float, help='frame duration', default=100) + args = parser.parse_args() + # validate args + if args.keys and not (args.username and args.password): + parser.error('The --keys argument requires a --username and --password to login to ganbreeder') + sys.exit(1) + + # get animation keyframes from ganbreeder + print('Downloading keyframe info from ganbreeder...') + keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys) + + # interpolate path through input space + print('Interpolating path through input space...') + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, batch_size) + + # sample the GAN + print('Loading bigGAN...') + gan = BigGAN() + print('Sampling from bigGAN...') + ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + + # save images to file + print('Saving image files: '+args.pathprefix) + save_images(ims, args.pathprefix) + + print('Done.') |