aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-05-12 01:29:05 -0400
committerVee9ahd1 <[email protected]>2019-05-12 01:29:05 -0400
commit2e27eb73344d691b657f72c8e794f81ce47036c6 (patch)
treefafe6255167c74131f03e2a80b105278c28af30d
parent560f86d452277084a1be04fbc4c0e8c5f1206ff5 (diff)
implemented most of the basic functionality from the prototype script and created some messy tests
-rw-r--r--.gitignore3
-rw-r--r--gantools/biggan.py48
-rw-r--r--gantools/cli.py3
-rw-r--r--gantools/ganbreeder.py45
-rw-r--r--gantools/latent_space.py40
-rw-r--r--test/biggan_test.py26
-rw-r--r--test/ganbreeder_test.py16
-rw-r--r--test/latent_space_test.py50
8 files changed, 231 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
index af6d502..3ef98f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+# project-specific
+secrets.py
+.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/gantools/biggan.py b/gantools/biggan.py
new file mode 100644
index 0000000..8a0a71d
--- /dev/null
+++ b/gantools/biggan.py
@@ -0,0 +1,48 @@
+# methods for setting up and interacting with biggan
+import tensorflow as tf
+import tensorflow_hub as hub
+import numpy as np
+from itertools import cycle
+
+MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2'
+
+class BigGAN:
+ def __init__(self, module_path=MODULE_PATH):
+ tf.reset_default_graph()
+ print('Loading BigGAN module from:', module_path)
+ module = hub.Module(module_path)
+ self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
+ for k, v in module.get_input_info_dict().items()}
+ self.input_z = self.inputs['z']
+ self.dim_z = self.input_z.shape.as_list()[1]
+ self.input_y = self.inputs['y']
+ self.vocab_size = self.input_y.shape.as_list()[1] # dimension of y (aka label count)
+ self.input_trunc = self.inputs['truncation']
+ self.output = module(self.inputs)
+
+ # initialize/instantiate tf variables
+ initializer = tf.global_variables_initializer()
+ self.sess = tf.Session()
+ self.sess.run(initializer)
+
+ def sample(self, vectors, labels, truncation=0.5, batch_size=1):
+ num = vectors.shape[0]
+
+ # deal with scalar input case
+ truncation = np.asarray(truncation)
+ if truncation.ndim == 0:# truncation is a scalar
+ #TODO: there has to be a better way to do this...
+ truncation = cycle([truncation])
+
+ ims = []
+ for batch_start, trunc in zip(range(0, num, batch_size), truncation):
+ s = slice(batch_start, min(num, batch_start + batch_size))
+ feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc}
+ ims.append(self.sess.run(self.output, feed_dict=feed_dict))
+ ims = np.concatenate(ims, axis=0)
+ assert ims.shape[0] == num
+ ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255)
+ ims = np.uint8(ims)
+ return ims
+ # TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping
+ # them in memory.
diff --git a/gantools/cli.py b/gantools/cli.py
new file mode 100644
index 0000000..df648c5
--- /dev/null
+++ b/gantools/cli.py
@@ -0,0 +1,3 @@
+# create entrypoints for cli tools
+def main():
+ print("Hello World")
diff --git a/gantools/ganbreeder.py b/gantools/ganbreeder.py
new file mode 100644
index 0000000..2e7f946
--- /dev/null
+++ b/gantools/ganbreeder.py
@@ -0,0 +1,45 @@
+# client functions for interacting with the ganbreeder api
+import requests
+import json
+
+def login(username, password):
+ def get_sid():
+ url = 'https://ganbreeder.app/login'
+ r = requests.get(url)
+ r.raise_for_status()
+ for c in r.cookies:
+ if c.name == 'connect.sid': # find the right cookie
+ print('Session ID: '+str(c.value))
+ return c.value
+
+ def login_auth(sid, username, password):
+ url = 'https://ganbreeder.app/login'
+ headers = {
+ 'Content-Type': 'application/json',
+ }
+ cookies = {
+ 'connect.sid': sid
+ }
+ payload = {
+ 'email': username,
+ 'password': password
+ }
+ r = requests.post(url, headers=headers, cookies=cookies, data=json.dumps(payload))
+ if not r.ok:
+ print('Authentication failed')
+ r.raise_for_status()
+ print('Authenticated')
+
+ sid = get_sid()
+ login_auth(sid, username, password)
+ return sid
+
+def get_info(sid, key):
+ if sid == '':
+ raise Exception('Cannot get info; session ID not defined. Be sure to login() first.')
+ cookies = {
+ 'connect.sid': sid
+ }
+ r = requests.get('http://ganbreeder.app/info?k='+str(key), cookies=cookies)
+ r.raise_for_status()
+ return(r.json())
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
new file mode 100644
index 0000000..b89756b
--- /dev/null
+++ b/gantools/latent_space.py
@@ -0,0 +1,40 @@
+import numpy as np
+from scipy import signal
+
+def one_hot(index, dim):
+ y = np.zeros((1,dim))
+ if index < dim:
+ y[0,index] = 1.0
+ return y
+
+def interpolate(begin, end, step_count):
+ initial = np.tile(begin, (step_count, 1))
+ delta = np.tile((end - begin)/step_count, (step_count, 1))
+ g = np.tile(np.arange(step_count), (begin.size, 1)).transpose()
+ return (delta * g) + initial
+
+# TODO: the math in this function is embarrasingly bad. fix at some point.
+def sequence_keyframes(keyframes, num_frames, batch_size=1):
+ div = int(num_frames//len(keyframes))
+ rem = int(num_frames - (div*len(keyframes)))
+ frame_counts = np.full((len(keyframes),), div) + \
+ np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes)-rem,), dtype=int))
+ batch_div = int(num_frames//batch_size)
+ batch_rem = 1 if int(num_frames%batch_size) > 0 else 0
+ batch_count = batch_div + batch_rem
+
+ keyframes.append(keyframes[0])# seq returns to start
+ readahead = iter(keyframes)
+ next(readahead)
+ z_seq, label_seq, truncation_seq = [], [], []
+ for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts):
+ z_begin = np.asarray(begin['vector'])*begin['truncation']
+ z_end = np.asarray(end['vector'])*end['truncation']
+ z_seq.extend(interpolate(z_begin, z_end, frame_count))
+ label_begin = np.asarray(begin['label'])
+ label_end = np.asarray(end['label'])
+ label_seq.extend(interpolate(label_begin, label_end, frame_count))
+ truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count))
+ # you can only change trunc once per batch
+ truncation_seq_resampled = signal.resample(truncation_seq, batch_count)
+ return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled
diff --git a/test/biggan_test.py b/test/biggan_test.py
new file mode 100644
index 0000000..db3e59f
--- /dev/null
+++ b/test/biggan_test.py
@@ -0,0 +1,26 @@
+import unittest
+from scipy.stats import truncnorm
+import numpy as np
+from gantools import biggan
+
+def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123):
+ def one_hot(index, dim):
+ y = np.zeros((1,dim))
+ if index < dim:
+ y[0,index] = 1.0
+ return y
+ random_state = np.random.RandomState(rand_seed)
+ vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state)
+ #TODO random labels
+ labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,))
+ return vectors, labels, truncation
+
+class TestBigGAN(unittest.TestCase):
+ def test_biggan_sample(self):
+ gan = biggan.BigGAN()
+ vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size)
+ ims = gan.sample(vectors, labels, truncation)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/test/ganbreeder_test.py b/test/ganbreeder_test.py
new file mode 100644
index 0000000..816e077
--- /dev/null
+++ b/test/ganbreeder_test.py
@@ -0,0 +1,16 @@
+import unittest
+from gantools import ganbreeder
+from test import secrets
+
+class TestGanbreeder(unittest.TestCase):
+ def test_get_info(self):
+ username = secrets.username
+ password = secrets.password
+ sid = ganbreeder.login(username, password)
+ self.assertNotEqual(sid, '', 'login() failed to produce an sid. check internet connection.')
+ key = 'd62c507ab4bea4ed7b70c64a' #some arbitrary ganbreeder key
+ ganbreeder.get_info(sid, key)
+
+if __name__ == '__main__':
+ unittest.main()
+
diff --git a/test/latent_space_test.py b/test/latent_space_test.py
new file mode 100644
index 0000000..0655923
--- /dev/null
+++ b/test/latent_space_test.py
@@ -0,0 +1,50 @@
+import unittest
+import numpy as np
+from scipy.stats import truncnorm
+from gantools import latent_space
+from gantools import biggan
+import PIL.Image
+
+def create_random_keyframe(n_vector, n_label):
+ truncation = (0.9 - 0.1)*np.random.random() + 0.1
+ random_state = np.random.RandomState()
+ vectors = truncnorm.rvs(-2, 2, size=(n_vector,), random_state=random_state)
+ keyframe = {
+ 'vector': vectors.tolist(),
+ 'label': latent_space.one_hot(np.random.randint(0, n_label), n_label),
+ 'truncation': truncation,
+ }
+ return keyframe
+
+#### TMP
+def save_image(arr, fp):
+ image = PIL.Image.fromarray(arr)
+ image.save(fp, format='JPEG', quality=90)
+
+def save_ims(ims):
+ i = 0
+ for im in ims:
+ path = './GAN_'+str(i).zfill(3)+'.jpeg'
+ save_image(im, path)
+ i += 1
+####
+
+class TestLatentSpace(unittest.TestCase):
+ def test_sequence_keyframes(self):
+ num_frames = 20
+ batch_size = 3
+ keyframe_count = 3
+ dim_z = 128
+ dim_label = 1000
+ keyframes = [create_random_keyframe(dim_z,dim_label) for i in range(keyframe_count)]
+ z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, num_frames, batch_size)
+ self.assertIs(len(z_seq), num_frames)
+ self.assertIs(len(label_seq), num_frames)
+ gan = biggan.BigGAN()
+ ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size)
+ save_ims(ims)
+
+
+
+if __name__ == '__main__':
+ unittest.main()