diff options
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | gantools/biggan.py | 48 | ||||
-rw-r--r-- | gantools/cli.py | 3 | ||||
-rw-r--r-- | gantools/ganbreeder.py | 45 | ||||
-rw-r--r-- | gantools/latent_space.py | 40 | ||||
-rw-r--r-- | test/biggan_test.py | 26 | ||||
-rw-r--r-- | test/ganbreeder_test.py | 16 | ||||
-rw-r--r-- | test/latent_space_test.py | 50 |
8 files changed, 231 insertions, 0 deletions
@@ -1,3 +1,6 @@ +# project-specific +secrets.py +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/gantools/biggan.py b/gantools/biggan.py new file mode 100644 index 0000000..8a0a71d --- /dev/null +++ b/gantools/biggan.py @@ -0,0 +1,48 @@ +# methods for setting up and interacting with biggan +import tensorflow as tf +import tensorflow_hub as hub +import numpy as np +from itertools import cycle + +MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2' + +class BigGAN: + def __init__(self, module_path=MODULE_PATH): + tf.reset_default_graph() + print('Loading BigGAN module from:', module_path) + module = hub.Module(module_path) + self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k) + for k, v in module.get_input_info_dict().items()} + self.input_z = self.inputs['z'] + self.dim_z = self.input_z.shape.as_list()[1] + self.input_y = self.inputs['y'] + self.vocab_size = self.input_y.shape.as_list()[1] # dimension of y (aka label count) + self.input_trunc = self.inputs['truncation'] + self.output = module(self.inputs) + + # initialize/instantiate tf variables + initializer = tf.global_variables_initializer() + self.sess = tf.Session() + self.sess.run(initializer) + + def sample(self, vectors, labels, truncation=0.5, batch_size=1): + num = vectors.shape[0] + + # deal with scalar input case + truncation = np.asarray(truncation) + if truncation.ndim == 0:# truncation is a scalar + #TODO: there has to be a better way to do this... + truncation = cycle([truncation]) + + ims = [] + for batch_start, trunc in zip(range(0, num, batch_size), truncation): + s = slice(batch_start, min(num, batch_start + batch_size)) + feed_dict = {self.input_z: vectors[s], self.input_y: labels[s], self.input_trunc: trunc} + ims.append(self.sess.run(self.output, feed_dict=feed_dict)) + ims = np.concatenate(ims, axis=0) + assert ims.shape[0] == num + ims = np.clip(((ims + 1) / 2.0) * 256, 0, 255) + ims = np.uint8(ims) + return ims + # TODO: make a version of sample() that includes a callback function to save ims somewhere instead of keeping + # them in memory. diff --git a/gantools/cli.py b/gantools/cli.py new file mode 100644 index 0000000..df648c5 --- /dev/null +++ b/gantools/cli.py @@ -0,0 +1,3 @@ +# create entrypoints for cli tools +def main(): + print("Hello World") diff --git a/gantools/ganbreeder.py b/gantools/ganbreeder.py new file mode 100644 index 0000000..2e7f946 --- /dev/null +++ b/gantools/ganbreeder.py @@ -0,0 +1,45 @@ +# client functions for interacting with the ganbreeder api +import requests +import json + +def login(username, password): + def get_sid(): + url = 'https://ganbreeder.app/login' + r = requests.get(url) + r.raise_for_status() + for c in r.cookies: + if c.name == 'connect.sid': # find the right cookie + print('Session ID: '+str(c.value)) + return c.value + + def login_auth(sid, username, password): + url = 'https://ganbreeder.app/login' + headers = { + 'Content-Type': 'application/json', + } + cookies = { + 'connect.sid': sid + } + payload = { + 'email': username, + 'password': password + } + r = requests.post(url, headers=headers, cookies=cookies, data=json.dumps(payload)) + if not r.ok: + print('Authentication failed') + r.raise_for_status() + print('Authenticated') + + sid = get_sid() + login_auth(sid, username, password) + return sid + +def get_info(sid, key): + if sid == '': + raise Exception('Cannot get info; session ID not defined. Be sure to login() first.') + cookies = { + 'connect.sid': sid + } + r = requests.get('http://ganbreeder.app/info?k='+str(key), cookies=cookies) + r.raise_for_status() + return(r.json()) diff --git a/gantools/latent_space.py b/gantools/latent_space.py new file mode 100644 index 0000000..b89756b --- /dev/null +++ b/gantools/latent_space.py @@ -0,0 +1,40 @@ +import numpy as np +from scipy import signal + +def one_hot(index, dim): + y = np.zeros((1,dim)) + if index < dim: + y[0,index] = 1.0 + return y + +def interpolate(begin, end, step_count): + initial = np.tile(begin, (step_count, 1)) + delta = np.tile((end - begin)/step_count, (step_count, 1)) + g = np.tile(np.arange(step_count), (begin.size, 1)).transpose() + return (delta * g) + initial + +# TODO: the math in this function is embarrasingly bad. fix at some point. +def sequence_keyframes(keyframes, num_frames, batch_size=1): + div = int(num_frames//len(keyframes)) + rem = int(num_frames - (div*len(keyframes))) + frame_counts = np.full((len(keyframes),), div) + \ + np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes)-rem,), dtype=int)) + batch_div = int(num_frames//batch_size) + batch_rem = 1 if int(num_frames%batch_size) > 0 else 0 + batch_count = batch_div + batch_rem + + keyframes.append(keyframes[0])# seq returns to start + readahead = iter(keyframes) + next(readahead) + z_seq, label_seq, truncation_seq = [], [], [] + for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): + z_begin = np.asarray(begin['vector'])*begin['truncation'] + z_end = np.asarray(end['vector'])*end['truncation'] + z_seq.extend(interpolate(z_begin, z_end, frame_count)) + label_begin = np.asarray(begin['label']) + label_end = np.asarray(end['label']) + label_seq.extend(interpolate(label_begin, label_end, frame_count)) + truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) + # you can only change trunc once per batch + truncation_seq_resampled = signal.resample(truncation_seq, batch_count) + return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled diff --git a/test/biggan_test.py b/test/biggan_test.py new file mode 100644 index 0000000..db3e59f --- /dev/null +++ b/test/biggan_test.py @@ -0,0 +1,26 @@ +import unittest +from scipy.stats import truncnorm +import numpy as np +from gantools import biggan + +def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123): + def one_hot(index, dim): + y = np.zeros((1,dim)) + if index < dim: + y[0,index] = 1.0 + return y + random_state = np.random.RandomState(rand_seed) + vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state) + #TODO random labels + labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,)) + return vectors, labels, truncation + +class TestBigGAN(unittest.TestCase): + def test_biggan_sample(self): + gan = biggan.BigGAN() + vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size) + ims = gan.sample(vectors, labels, truncation) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/ganbreeder_test.py b/test/ganbreeder_test.py new file mode 100644 index 0000000..816e077 --- /dev/null +++ b/test/ganbreeder_test.py @@ -0,0 +1,16 @@ +import unittest +from gantools import ganbreeder +from test import secrets + +class TestGanbreeder(unittest.TestCase): + def test_get_info(self): + username = secrets.username + password = secrets.password + sid = ganbreeder.login(username, password) + self.assertNotEqual(sid, '', 'login() failed to produce an sid. check internet connection.') + key = 'd62c507ab4bea4ed7b70c64a' #some arbitrary ganbreeder key + ganbreeder.get_info(sid, key) + +if __name__ == '__main__': + unittest.main() + diff --git a/test/latent_space_test.py b/test/latent_space_test.py new file mode 100644 index 0000000..0655923 --- /dev/null +++ b/test/latent_space_test.py @@ -0,0 +1,50 @@ +import unittest +import numpy as np +from scipy.stats import truncnorm +from gantools import latent_space +from gantools import biggan +import PIL.Image + +def create_random_keyframe(n_vector, n_label): + truncation = (0.9 - 0.1)*np.random.random() + 0.1 + random_state = np.random.RandomState() + vectors = truncnorm.rvs(-2, 2, size=(n_vector,), random_state=random_state) + keyframe = { + 'vector': vectors.tolist(), + 'label': latent_space.one_hot(np.random.randint(0, n_label), n_label), + 'truncation': truncation, + } + return keyframe + +#### TMP +def save_image(arr, fp): + image = PIL.Image.fromarray(arr) + image.save(fp, format='JPEG', quality=90) + +def save_ims(ims): + i = 0 + for im in ims: + path = './GAN_'+str(i).zfill(3)+'.jpeg' + save_image(im, path) + i += 1 +#### + +class TestLatentSpace(unittest.TestCase): + def test_sequence_keyframes(self): + num_frames = 20 + batch_size = 3 + keyframe_count = 3 + dim_z = 128 + dim_label = 1000 + keyframes = [create_random_keyframe(dim_z,dim_label) for i in range(keyframe_count)] + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, num_frames, batch_size) + self.assertIs(len(z_seq), num_frames) + self.assertIs(len(label_seq), num_frames) + gan = biggan.BigGAN() + ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + save_ims(ims) + + + +if __name__ == '__main__': + unittest.main() |