aboutsummaryrefslogtreecommitdiff
path: root/gantools
diff options
context:
space:
mode:
Diffstat (limited to 'gantools')
-rw-r--r--gantools/cli.py46
-rw-r--r--gantools/image_utils.py10
2 files changed, 55 insertions, 1 deletions
diff --git a/gantools/cli.py b/gantools/cli.py
index df648c5..efabcae 100644
--- a/gantools/cli.py
+++ b/gantools/cli.py
@@ -1,3 +1,47 @@
+import sys, argparse
+import ganbreeder
+import biggan
+import latent_space
+import image_utils
+
# create entrypoints for cli tools
def main():
- print("Hello World")
+ ## handle args
+ parser = argparse.ArgumentParser(description='GAN tools')
+ # load from ganbreeder
+ parser.add_argument('-u', '--username', help='Ganbreeder account email address/username.')
+ parser.add_argument('-p', '--password', help='Ganbreeder account password.')
+ parser.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.')
+ parser.add_argument('-n', '--nframes', metavar='N', type=int, help='Total number of frames in the final animation.', default=10)
+ parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \
+ (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \
+ what it does.).', default=1)
+ parser.add_argument('-p', '--pathprefix', help='Directory path and file prefix for output images.')
+ parser.add_argument('-r', '--radius', metavar='N', type=int, help='effect radius', default=10)
+ parser.add_argument('-c', '--count', metavar='N', type=int, help='effect count parameter', default=10)
+ parser.add_argument('-d', '--duration', metavar='N', type=float, help='frame duration', default=100)
+ args = parser.parse_args()
+ # validate args
+ if args.keys and not (args.username and args.password):
+ parser.error('The --keys argument requires a --username and --password to login to ganbreeder')
+ sys.exit(1)
+
+ # get animation keyframes from ganbreeder
+ print('Downloading keyframe info from ganbreeder...')
+ keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys)
+
+ # interpolate path through input space
+ print('Interpolating path through input space...')
+ z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, batch_size)
+
+ # sample the GAN
+ print('Loading bigGAN...')
+ gan = BigGAN()
+ print('Sampling from bigGAN...')
+ ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size)
+
+ # save images to file
+ print('Saving image files: '+args.pathprefix)
+ save_images(ims, args.pathprefix)
+
+ print('Done.')
diff --git a/gantools/image_utils.py b/gantools/image_utils.py
new file mode 100644
index 0000000..ca67134
--- /dev/null
+++ b/gantools/image_utils.py
@@ -0,0 +1,10 @@
+import PIL.image
+
+def save_image(arr, fp, format='JPEG'):
+ image = PIL.Image.fromarray(arr)
+ image.save(fp, format=format, quality=90)
+
+def save_images(ims, path_prefix='', format='JPEG'):
+ for i, im in enumerate(ims):
+ path = str(path_prefix)+str(i).zfill(4)+'.'+str(format).lower()
+ save_image(im, path)