aboutsummaryrefslogtreecommitdiff
path: root/gantools/biggan.py
diff options
context:
space:
mode:
Diffstat (limited to 'gantools/biggan.py')
-rw-r--r--gantools/biggan.py24
1 files changed, 17 insertions, 7 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py
index 0ed6bf7..571387b 100644
--- a/gantools/biggan.py
+++ b/gantools/biggan.py
@@ -25,7 +25,8 @@ class BigGAN(object):
self.sess = tf.Session()
self.sess.run(initializer)
- def sample(self, vectors, labels, truncation=0.5, batch_size=1):
+ # NOTE: use save callback to save images once per batch. return type changes to None.
+ def sample(self, vectors, labels, truncation=0.5, batch_size=1, save_callback=None):
num = vectors.shape[0]
# deal with scalar input case
@@ -38,11 +39,20 @@ class BigGAN(object):
for batch_start, trunc in zip(range(0, num, batch_size), truncation):
s = slice(batch_start, min(num, batch_start + batch_size))
feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc}
- ims.append(self.sess.run(self.output, feed_dict=feed_dict))
- ims = np.concatenate(ims, axis=0)
- assert ims.shape[0] == num
- ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
- ims = np.uint8(ims)
- return ims
+ ims_batch = self.sess.run(self.output, feed_dict=feed_dict)
+ ims_batch = np.clip(((ims_batch + 1) / 2.0) * 256, 0, 255)
+ ims_batch = np.uint8(ims_batch)
+ if save_callback is None:
+ ims.append(ims_batch)
+ else:
+ save_callback(ims_batch)
+ if save_callback is None:
+ ims = np.concatenate(ims, axis=0)
+ assert ims.shape[0] == num
+ return ims
+ else:
+ return None
+
+
# TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping
# them in memory.