diff options
Diffstat (limited to 'test/latent_space_test.py')
-rw-r--r-- | test/latent_space_test.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/test/latent_space_test.py b/test/latent_space_test.py new file mode 100644 index 0000000..0655923 --- /dev/null +++ b/test/latent_space_test.py @@ -0,0 +1,50 @@ +import unittest +import numpy as np +from scipy.stats import truncnorm +from gantools import latent_space +from gantools import biggan +import PIL.Image + +def create_random_keyframe(n_vector, n_label): + truncation = (0.9 - 0.1)*np.random.random() + 0.1 + random_state = np.random.RandomState() + vectors = truncnorm.rvs(-2, 2, size=(n_vector,), random_state=random_state) + keyframe = { + 'vector': vectors.tolist(), + 'label': latent_space.one_hot(np.random.randint(0, n_label), n_label), + 'truncation': truncation, + } + return keyframe + +#### TMP +def save_image(arr, fp): + image = PIL.Image.fromarray(arr) + image.save(fp, format='JPEG', quality=90) + +def save_ims(ims): + i = 0 + for im in ims: + path = './GAN_'+str(i).zfill(3)+'.jpeg' + save_image(im, path) + i += 1 +#### + +class TestLatentSpace(unittest.TestCase): + def test_sequence_keyframes(self): + num_frames = 20 + batch_size = 3 + keyframe_count = 3 + dim_z = 128 + dim_label = 1000 + keyframes = [create_random_keyframe(dim_z,dim_label) for i in range(keyframe_count)] + z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(keyframes, num_frames, batch_size) + self.assertIs(len(z_seq), num_frames) + self.assertIs(len(label_seq), num_frames) + gan = biggan.BigGAN() + ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size) + save_ims(ims) + + + +if __name__ == '__main__': + unittest.main() |