diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/biggan_test.py | 26 | ||||
-rw-r--r-- | test/ganbreeder_test.py | 16 | ||||
-rw-r--r-- | test/latent_space_test.py | 50 |
3 files changed, 92 insertions, 0 deletions
diff --git a/test/biggan_test.py b/test/biggan_test.py new file mode 100644 index 0000000..db3e59f --- /dev/null +++ b/test/biggan_test.py @@ -0,0 +1,26 @@ +import unittest +from scipy.stats import truncnorm +import numpy as np +from gantools import biggan + +def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123): + def one_hot(index, dim): + y = np.zeros((1,dim)) + if index < dim: + y[0,index] = 1.0 + return y + random_state = np.random.RandomState(rand_seed) + vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state) + #TODO random labels + labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,)) + return vectors, labels, truncation + +class TestBigGAN(unittest.TestCase): + def test_biggan_sample(self): + gan = biggan.BigGAN() + vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size) + ims = gan.sample(vectors, labels, truncation) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/ganbreeder_test.py b/test/ganbreeder_test.py new file mode 100644 index 0000000..816e077 --- /dev/null +++ b/test/ganbreeder_test.py @@ -0,0 +1,16 @@ +import unittest +from gantools import ganbreeder +from test import secrets + +class TestGanbreeder(unittest.TestCase): + def test_get_info(self): + username = secrets.username + password = secrets.password + sid = ganbreeder.login(username, password) + self.assertNotEqual(sid, '', 'login() failed to produce an sid. check internet connection.') + key = 'd62c507ab4bea4ed7b70c64a' #some arbitrary ganbreeder key + ganbreeder.get_info(sid, key) + +if __name__ == '__main__': + unittest.main() + diff --git a/test/latent_space_test.py b/test/latent_space_test.py new file mode 100644 index 0000000..0655923 --- /dev/null +++ b/test/latent_space_test.py @@ -0,0 +1,50 @@ +import unittest +import numpy as np +from scipy.stats import truncnorm +from gantools import latent_space +from gantools import biggan +import PIL.Image + +def create_random_keyframe(n_vector, n_label): + truncation = (0.9 - 0.1)*np.random.random() + 0.1 + random_state = np.random.RandomState() + vectors = truncnorm.rvs(-2, 2, size=(n_vector,), random_state=random_state) + keyframe = { + 'vector': vectors.tolist(), + 'label': latent_space.one_hot(np.random.randint(0, n_label), n_label), + 'truncation': truncation, + } + return keyframe + +#### TMP +def save_image(arr, fp): + image = PIL.Image.fromarray(arr) + image.save(fp, format='JPEG', quality=90) + +def save_ims(ims): + i = 0 + for im in ims: + path = './GAN_'+str(i).zfill(3)+'.jpeg' + save_image(im, path) + i += 1 +#### + +class TestLatentSpace(unittest.TestCase): + def test_sequence_keyframes(self): + num_frames = 20 + batch_size = 3 + keyframe_count = 3 + dim_z = 128 + dim_label = 1000 + keyframes = [create_random_keyframe(dim_z,dim_label) for i in range(keyframe_count)] + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, num_frames, batch_size) + self.assertIs(len(z_seq), num_frames) + self.assertIs(len(label_seq), num_frames) + gan = biggan.BigGAN() + ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + save_ims(ims) + + + +if __name__ == '__main__': + unittest.main() |