diff options
Diffstat (limited to 'gantools')
-rw-r--r-- | gantools/cli.py | 46 | ||||
-rw-r--r-- | gantools/image_utils.py | 10 |
2 files changed, 55 insertions, 1 deletions
diff --git a/gantools/cli.py b/gantools/cli.py index df648c5..efabcae 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -1,3 +1,47 @@ +import sys, argparse +import ganbreeder +import biggan +import latent_space +import image_utils + # create entrypoints for cli tools def main(): - print("Hello World") + ## handle args + parser = argparse.ArgumentParser(description='GAN tools') + # load from ganbreeder + parser.add_argument('-u', '--username', help='Ganbreeder account email address/username.') + parser.add_argument('-p', '--password', help='Ganbreeder account password.') + parser.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.') + parser.add_argument('-n', '--nframes', metavar='N', type=int, help='Total number of frames in the final animation.', default=10) + parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \ + (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \ + what it does.).', default=1) + parser.add_argument('-p', '--pathprefix', help='Directory path and file prefix for output images.') + parser.add_argument('-r', '--radius', metavar='N', type=int, help='effect radius', default=10) + parser.add_argument('-c', '--count', metavar='N', type=int, help='effect count parameter', default=10) + parser.add_argument('-d', '--duration', metavar='N', type=float, help='frame duration', default=100) + args = parser.parse_args() + # validate args + if args.keys and not (args.username and args.password): + parser.error('The --keys argument requires a --username and --password to login to ganbreeder') + sys.exit(1) + + # get animation keyframes from ganbreeder + print('Downloading keyframe info from ganbreeder...') + keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys) + + # interpolate path through input space + print('Interpolating path through input space...') + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, batch_size) + + # sample the GAN + print('Loading bigGAN...') + gan = BigGAN() + print('Sampling from bigGAN...') + ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + + # save images to file + print('Saving image files: '+args.pathprefix) + save_images(ims, args.pathprefix) + + print('Done.') diff --git a/gantools/image_utils.py b/gantools/image_utils.py new file mode 100644 index 0000000..ca67134 --- /dev/null +++ b/gantools/image_utils.py @@ -0,0 +1,10 @@ +import PIL.image + +def save_image(arr, fp, format='JPEG'): + image = PIL.Image.fromarray(arr) + image.save(fp, format=format, quality=90) + +def save_images(ims, path_prefix='', format='JPEG'): + for i, im in enumerate(ims): + path = str(path_prefix)+str(i).zfill(4)+'.'+str(format).lower() + save_image(im, path) |