import unittest

import numpy as np
import PIL.Image
from scipy.stats import truncnorm

from gantools import biggan
from gantools import latent_space


def create_random_keyframe(n_vector, n_label):
    """Build a keyframe dict with a random vector, label and truncation.

    Args:
        n_vector: Length of the latent vector to draw.
        n_label: Number of classes; one class index is drawn uniformly from
            ``[0, n_label)`` and passed to ``latent_space.one_hot``.

    Returns:
        A dict with keys:
            ``'vector'``: list of ``n_vector`` floats drawn from a standard
                normal truncated to [-2, 2].
            ``'label'``: the value returned by ``latent_space.one_hot``.
            ``'truncation'``: a float drawn uniformly from [0.1, 0.9).
    """
    truncation = np.random.uniform(0.1, 0.9)
    # No private RandomState: draw from numpy's global generator (scipy's
    # default when random_state is None), so np.random.seed(...) makes the
    # vector reproducible together with the truncation and label draws.
    vectors = truncnorm.rvs(-2, 2, size=(n_vector,))
    keyframe = {
        'vector': vectors.tolist(),
        'label': latent_space.one_hot(np.random.randint(0, n_label), n_label),
        'truncation': truncation,
    }
    return keyframe


# NOTE: temporary helpers for dumping generated frames to disk during
# manual inspection of test output.

def save_image(arr, fp):
    """Save an image array to *fp* as a JPEG with quality 90.

    Args:
        arr: Array accepted by ``PIL.Image.fromarray``
            (presumably HxWx3 uint8 — TODO confirm against BigGAN.sample).
        fp: Destination path or file object.
    """
    image = PIL.Image.fromarray(arr)
    image.save(fp, format='JPEG', quality=90)


def save_ims(ims):
    """Save each image in *ims* as ./GAN_000.jpeg, ./GAN_001.jpeg, ..."""
    for i, im in enumerate(ims):
        save_image(im, './GAN_' + str(i).zfill(3) + '.jpeg')


class TestLatentSpace(unittest.TestCase):

    def test_sequence_keyframes(self):
        """Interpolate random keyframes, check lengths, then render frames."""
        num_frames = 20
        batch_size = 3
        keyframe_count = 3
        dim_z = 128
        dim_label = 1000

        keyframes = [create_random_keyframe(dim_z, dim_label)
                     for _ in range(keyframe_count)]
        z_seq, label_seq, truncation_seq = latent_space.sequence_keyframes(
            keyframes, num_frames, batch_size)

        # assertEqual, not assertIs: identity comparison of ints only works
        # by accident for CPython's cached small ints (-5..256).
        self.assertEqual(len(z_seq), num_frames)
        self.assertEqual(len(label_seq), num_frames)

        gan = biggan.BigGAN()
        ims = gan.sample(z_seq, label_seq, truncation_seq, batch_size)
        # Side effect: writes ./GAN_###.jpeg into the working directory.
        save_ims(ims)


if __name__ == '__main__':
    unittest.main()