aboutsummaryrefslogtreecommitdiff
path: root/gantools/biggan.py
diff options
context:
space:
mode:
Diffstat (limited to 'gantools/biggan.py')
-rw-r--r--gantools/biggan.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py
new file mode 100644
index 0000000..8a0a71d
--- /dev/null
+++ b/gantools/biggan.py
@@ -0,0 +1,48 @@
+# methods for setting up and interacting with biggan
+import tensorflow as tf
+import tensorflow_hub as hub
+import numpy as np
+from itertools import cycle
+
+MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2'
+
+class BigGAN:
+ def __init__(self, module_path=MODULE_PATH):
+ tf.reset_default_graph()
+ print('Loading BigGAN module from:', module_path)
+ module = hub.Module(module_path)
+ self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
+ for k, v in module.get_input_info_dict().items()}
+ self.input_z = self.inputs['z']
+ self.dim_z = self.input_z.shape.as_list()[1]
+ self.input_y = self.inputs['y']
+ self.vocab_size = self.input_y.shape.as_list()[1] # dimension of y (aka label count)
+ self.input_trunc = self.inputs['truncation']
+ self.output = module(self.inputs)
+
+ # initialize/instantiate tf variables
+ initializer = tf.global_variables_initializer()
+ self.sess = tf.Session()
+ self.sess.run(initializer)
+
+ def sample(self, vectors, labels, truncation=0.5, batch_size=1):
+ num = vectors.shape[0]
+
+ # deal with scalar input case
+ truncation = np.asarray(truncation)
+ if truncation.ndim == 0:# truncation is a scalar
+ #TODO: there has to be a better way to do this...
+ truncation = cycle([truncation])
+
+ ims = []
+ for batch_start, trunc in zip(range(0, num, batch_size), truncation):
+ s = slice(batch_start, min(num, batch_start + batch_size))
+ feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc}
+ ims.append(self.sess.run(self.output, feed_dict=feed_dict))
+ ims = np.concatenate(ims, axis=0)
+ assert ims.shape[0] == num
+ ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
+ ims = np.uint8(ims)
+ return ims
+ # TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping
+ # them in memory.