diff options
Diffstat (limited to 'gantools')
-rw-r--r-- | gantools/cli.py | 21 | ||||
-rw-r--r-- | gantools/ganbreeder.py | 7 | ||||
-rw-r--r-- | gantools/image_utils.py | 2 |
3 files changed, 17 insertions, 13 deletions
diff --git a/gantools/cli.py b/gantools/cli.py index efabcae..b8e3d4b 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -1,8 +1,8 @@ import sys, argparse -import ganbreeder -import biggan -import latent_space -import image_utils +from gantools import ganbreeder +from gantools import biggan +from gantools import latent_space +from gantools import image_utils # create entrypoints for cli tools def main(): @@ -16,10 +16,7 @@ def main(): parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \ (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \ what it does.).', default=1) - parser.add_argument('-p', '--pathprefix', help='Directory path and file prefix for output images.') - parser.add_argument('-r', '--radius', metavar='N', type=int, help='effect radius', default=10) - parser.add_argument('-c', '--count', metavar='N', type=int, help='effect count parameter', default=10) - parser.add_argument('-d', '--duration', metavar='N', type=float, help='frame duration', default=100) + parser.add_argument('-f', '--pathprefix', help='Directory path and file prefix for output images.') args = parser.parse_args() # validate args if args.keys and not (args.username and args.password): @@ -32,16 +29,16 @@ def main(): # interpolate path through input space print('Interpolating path through input space...') - z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, batch_size) + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, args.nbatch) # sample the GAN print('Loading bigGAN...') - gan = BigGAN() + gan = biggan.BigGAN() print('Sampling from bigGAN...') - ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + ims = gan.sample(z_seq, label_seq, truncation_seq, args.nbatch) # save images to file print('Saving image files: '+args.pathprefix) - save_images(ims, args.pathprefix) + image_utils.save_images(ims, args.pathprefix) print('Done.') diff --git a/gantools/ganbreeder.py b/gantools/ganbreeder.py index 2e7f946..22d1adf 100644 --- a/gantools/ganbreeder.py +++ b/gantools/ganbreeder.py @@ -43,3 +43,10 @@ def get_info(sid, key): r = requests.get('http://ganbreeder.app/info?k='+str(key), cookies=cookies) r.raise_for_status() return(r.json()) + +def get_info_batch(username, password, keys): + l = list() + sid = login(username, password) + for key in keys: + l.append(get_info(sid, key)) + return l diff --git a/gantools/image_utils.py b/gantools/image_utils.py index ca67134..016891f 100644 --- a/gantools/image_utils.py +++ b/gantools/image_utils.py @@ -1,4 +1,4 @@ -import PIL.image +import PIL.Image def save_image(arr, fp, format='JPEG'): image = PIL.Image.fromarray(arr) |