diff options
Diffstat (limited to 'gantools/biggan.py')
-rw-r--r-- | gantools/biggan.py | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py index 0ed6bf7..571387b 100644 --- a/gantools/biggan.py +++ b/gantools/biggan.py @@ -25,7 +25,8 @@ class BigGAN(object): self.sess = tf.Session() self.sess.run(initializer) - def sample(self, vectors, labels, truncation=0.5, batch_size=1): + # NOTE: use save callback to save images once per batch. return type changes to None. + def sample(self, vectors, labels, truncation=0.5, batch_size=1, save_callback=None): num = vectors.shape[0] # deal with scalar input case @@ -38,11 +39,20 @@ class BigGAN(object): for batch_start, trunc in zip(range(0, num, batch_size), truncation): s = slice(batch_start, min(num, batch_start + batch_size)) feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc} - ims.append(self.sess.run(self.output, feed_dict=feed_dict)) - ims = np.concatenate(ims, axis=0) - assert ims.shape[0] == num - ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255) - ims = np.uint8(ims) - return ims + ims_batch = self.sess.run(self.output, feed_dict=feed_dict) + ims_batch = np.clip(((ims_batch + 1) / 2.0) * 256, 0, 255) + ims_batch = np.uint8(ims_batch) + if save_callback is None: + ims.append(ims_batch) + else: + save_callback(ims_batch) + if save_callback is None: + ims = np.concatenate(ims, axis=0) + assert ims.shape[0] == num + return ims + else: + return None + + # TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping # them in memory. |