aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-06-08 17:51:12 -0400
committerVee9ahd1 <[email protected]>2019-06-08 17:51:12 -0400
commit620535b58307e43a80e7f123e590e4fc31b1a755 (patch)
tree047188a39c394b4c3d7dfec67b637730b8714f84
parent16e8cb2ff9cc94265d3ad960d51ac8fa318b1b74 (diff)
added image save callback to sample/render loop
-rw-r--r--gantools/biggan.py24
-rw-r--r--gantools/cli.py21
-rw-r--r--gantools/image_utils.py19
3 files changed, 48 insertions, 16 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py
index 0ed6bf7..571387b 100644
--- a/gantools/biggan.py
+++ b/gantools/biggan.py
@@ -25,7 +25,8 @@ class BigGAN(object):
self.sess = tf.Session()
self.sess.run(initializer)
- def sample(self, vectors, labels, truncation=0.5, batch_size=1):
+ # NOTE: use save callback to save images once per batch. return type changes to None.
+ def sample(self, vectors, labels, truncation=0.5, batch_size=1, save_callback=None):
num = vectors.shape[0]
# deal with scalar input case
@@ -38,11 +39,20 @@ class BigGAN(object):
for batch_start, trunc in zip(range(0, num, batch_size), truncation):
s = slice(batch_start, min(num, batch_start + batch_size))
feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc}
- ims.append(self.sess.run(self.output, feed_dict=feed_dict))
- ims = np.concatenate(ims, axis=0)
- assert ims.shape[0] == num
- ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
- ims = np.uint8(ims)
- return ims
+ ims_batch = self.sess.run(self.output, feed_dict=feed_dict)
+ ims_batch = np.clip(((ims_batch + 1) / 2.0) * 256, 0, 255)
+ ims_batch = np.uint8(ims_batch)
+ if save_callback is None:
+ ims.append(ims_batch)
+ else:
+ save_callback(ims_batch)
+ if save_callback is None:
+ ims = np.concatenate(ims, axis=0)
+ assert ims.shape[0] == num
+ return ims
+ else:
+ return None
+
+
# TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping
# them in memory.
diff --git a/gantools/cli.py b/gantools/cli.py
index 3e42f2e..a4159f5 100644
--- a/gantools/cli.py
+++ b/gantools/cli.py
@@ -20,7 +20,7 @@ def handle_args(argv=None):
what it does.).', default=1)
parser.add_argument('-o', '--output-dir', help='Directory path for output images.')
parser.add_argument('--prefix', help='File prefix for output images.')
- parser.add_argument('--interp', help='Set interpolation method.', choices=['linear', 'cubic'])
+ parser.add_argument('--interp', choices=['linear', 'cubic'], default='cubic', help='Set interpolation method.')
group_loop = parser.add_mutually_exclusive_group(required=False)
group_loop.add_argument('--loop', dest='loop', action='store_true', default=True, help='Loop the animation.')
group_loop.add_argument('--no-loop', dest='loop', action='store_false', help='Don\'t loop the animation.')
@@ -34,25 +34,28 @@ def handle_args(argv=None):
# create entrypoints for cli tools
def main():
- handle_args()
+ args = handle_args()
# get animation keyframes from ganbreeder
print('Downloading keyframe info from ganbreeder...')
keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys)
# interpolate path through input space
print('Interpolating path through input space...')
- z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, args.nframes, args.nbatch)
+ z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(
+ keyframes,
+ args.nframes,
+ batch_size=args.nbatch,
+ interp_method=args.interp,
+ loop=args.loop)
# sample the GAN
print('Loading bigGAN...')
gan = biggan.BigGAN()
- print('Sampling from bigGAN...')
- ims = gan.sample(z_seq, label_seq, truncation_seq, args.nbatch)
- # save images to file
path = '' if args.output_dir == None else str(args.output_dir)
prefix = '' if args.prefix == None else str(args.prefix)
- print('Saving image files: '+path + prefix)
- image_utils.save_images(ims, output_dir=output_dir, prefix=prefix)
-
+ saver = image_utils.ImageSaver(output_dir=path, prefix=prefix)
+ print('Saving image files to: '+path + prefix)
+ print('Sampling from bigGAN...')
+ gan.sample(z_seq, label_seq, truncation=truncation_seq, batch_size=args.nbatch, save_callback=saver.save)
print('Done.')
diff --git a/gantools/image_utils.py b/gantools/image_utils.py
index 72130b0..a875a91 100644
--- a/gantools/image_utils.py
+++ b/gantools/image_utils.py
@@ -9,3 +9,22 @@ def save_images(ims, output_dir='', prefix='', format='JPEG'):
for i, im in enumerate(ims):
full_path = os.path.join(output_dir, prefix + str(i).zfill(4) + '.' + format.lower())
save_image(im, full_path, format)
+
+class ImageSaver(object):
+ def __init__(self, output_dir='', prefix='', image_format='JPEG'):
+ self.output_dir = str(output_dir)
+ self.prefix = str(prefix)
+ self.image_format = str(image_format)
+ self.index = int(0)
+ self.quality = 90
+
+ def save(self, ims):
+ for i, im in enumerate(ims):
+ full_path = os.path.join(
+ self.output_dir,
+ self.prefix + str(self.index).zfill(4) + '.' + self.image_format.lower()
+ )
+ image = PIL.Image.fromarray(im)
+ image.save(full_path, format=self.image_format, quality=self.quality)
+ self.index += 1
+