diff options
author | Vee9ahd1 <[email protected]> | 2019-05-12 01:29:05 -0400 |
---|---|---|
committer | Vee9ahd1 <[email protected]> | 2019-05-12 01:29:05 -0400 |
commit | 2e27eb73344d691b657f72c8e794f81ce47036c6 (patch) | |
tree | fafe6255167c74131f03e2a80b105278c28af30d /test/latent_space_test.py | |
parent | 560f86d452277084a1be04fbc4c0e8c5f1206ff5 (diff) |
implemented most of the basic functionality from the prototype script and created some messy tests
Diffstat (limited to 'test/latent_space_test.py')
-rw-r--r-- | test/latent_space_test.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/test/latent_space_test.py b/test/latent_space_test.py new file mode 100644 index 0000000..0655923 --- /dev/null +++ b/test/latent_space_test.py @@ -0,0 +1,50 @@ +import unittest +import numpy as np +from scipy.stats import truncnorm +from gantools import latent_space +from gantools import biggan +import PIL.Image + +def create_random_keyframe(n_vector, n_label): + truncation = (0.9 - 0.1)*np.random.random() + 0.1 + random_state = np.random.RandomState() + vectors = truncnorm.rvs(-2, 2, size=(n_vector,), random_state=random_state) + keyframe = { + 'vector': vectors.tolist(), + 'label': latent_space.one_hot(np.random.randint(0, n_label), n_label), + 'truncation': truncation, + } + return keyframe + +#### TMP +def save_image(arr, fp): + image = PIL.Image.fromarray(arr) + image.save(fp, format='JPEG', quality=90) + +def save_ims(ims): + i = 0 + for im in ims: + path = './GAN_'+str(i).zfill(3)+'.jpeg' + save_image(im, path) + i += 1 +#### + +class TestLatentSpace(unittest.TestCase): + def test_sequence_keyframes(self): + num_frames = 20 + batch_size = 3 + keyframe_count = 3 + dim_z = 128 + dim_label = 1000 + keyframes = [create_random_keyframe(dim_z,dim_label) for i in range(keyframe_count)] + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, num_frames, batch_size) + self.assertIs(len(z_seq), num_frames) + self.assertIs(len(label_seq), num_frames) + gan = biggan.BigGAN() + ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + save_ims(ims) + + + +if __name__ == '__main__': + unittest.main() |