aboutsummaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-05-12 01:29:05 -0400
committerVee9ahd1 <[email protected]>2019-05-12 01:29:05 -0400
commit2e27eb73344d691b657f72c8e794f81ce47036c6 (patch)
treefafe6255167c74131f03e2a80b105278c28af30d /test
parent560f86d452277084a1be04fbc4c0e8c5f1206ff5 (diff)
implemented most of the basic functionality from the prototype script and created some messy tests
Diffstat (limited to 'test')
-rw-r--r--test/biggan_test.py26
-rw-r--r--test/ganbreeder_test.py16
-rw-r--r--test/latent_space_test.py50
3 files changed, 92 insertions, 0 deletions
diff --git a/test/biggan_test.py b/test/biggan_test.py
new file mode 100644
index 0000000..db3e59f
--- /dev/null
+++ b/test/biggan_test.py
@@ -0,0 +1,26 @@
+import unittest
+from scipy.stats import truncnorm
+import numpy as np
+from gantools import biggan
+
+def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123):
+ def one_hot(index, dim):
+ y = np.zeros((1,dim))
+ if index < dim:
+ y[0,index] = 1.0
+ return y
+ random_state = np.random.RandomState(rand_seed)
+ vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state)
+ #TODO random labels
+ labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,))
+ return vectors, labels, truncation
+
+class TestBigGAN(unittest.TestCase):
+ def test_biggan_sample(self):
+ gan = biggan.BigGAN()
+ vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size)
+ ims = gan.sample(vectors, labels, truncation)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/test/ganbreeder_test.py b/test/ganbreeder_test.py
new file mode 100644
index 0000000..816e077
--- /dev/null
+++ b/test/ganbreeder_test.py
@@ -0,0 +1,16 @@
+import unittest
+from gantools import ganbreeder
+from test import secrets
+
+class TestGanbreeder(unittest.TestCase):
+ def test_get_info(self):
+ username = secrets.username
+ password = secrets.password
+ sid = ganbreeder.login(username, password)
+ self.assertNotEqual(sid, '', 'login() failed to produce an sid. check internet connection.')
+ key = 'd62c507ab4bea4ed7b70c64a' #some arbitrary ganbreeder key
+ ganbreeder.get_info(sid, key)
+
+if __name__ == '__main__':
+ unittest.main()
+
diff --git a/test/latent_space_test.py b/test/latent_space_test.py
new file mode 100644
index 0000000..0655923
--- /dev/null
+++ b/test/latent_space_test.py
@@ -0,0 +1,50 @@
+import unittest
+import numpy as np
+from scipy.stats import truncnorm
+from gantools import latent_space
+from gantools import biggan
+import PIL.Image
+
+def create_random_keyframe(n_vector, n_label):
+ truncation = (0.9 - 0.1)*np.random.random() + 0.1
+ random_state = np.random.RandomState()
+ vectors = truncnorm.rvs(-2, 2, size=(n_vector,), random_state=random_state)
+ keyframe = {
+ 'vector': vectors.tolist(),
+ 'label': latent_space.one_hot(np.random.randint(0, n_label), n_label),
+ 'truncation': truncation,
+ }
+ return keyframe
+
+#### TMP
+def save_image(arr, fp):
+ image = PIL.Image.fromarray(arr)
+ image.save(fp, format='JPEG', quality=90)
+
+def save_ims(ims):
+ i = 0
+ for im in ims:
+ path = './GAN_'+str(i).zfill(3)+'.jpeg'
+ save_image(im, path)
+ i += 1
+####
+
+class TestLatentSpace(unittest.TestCase):
+ def test_sequence_keyframes(self):
+ num_frames = 20
+ batch_size = 3
+ keyframe_count = 3
+ dim_z = 128
+ dim_label = 1000
+ keyframes = [create_random_keyframe(dim_z,dim_label) for i in range(keyframe_count)]
+ z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, num_frames, batch_size)
+ self.assertIs(len(z_seq), num_frames)
+ self.assertIs(len(label_seq), num_frames)
+ gan = biggan.BigGAN()
+ ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size)
+ save_ims(ims)
+
+
+
+if __name__ == '__main__':
+ unittest.main()