From 2e27eb73344d691b657f72c8e794f81ce47036c6 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Sun, 12 May 2019 01:29:05 -0400 Subject: implemented most of the basic functionality from the prototype script and created some messy tests --- test/biggan_test.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 test/biggan_test.py (limited to 'test/biggan_test.py') diff --git a/test/biggan_test.py b/test/biggan_test.py new file mode 100644 index 0000000..db3e59f --- /dev/null +++ b/test/biggan_test.py @@ -0,0 +1,26 @@ +import unittest +from scipy.stats import truncnorm +import numpy as np +from gantools import biggan + +def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123): + def one_hot(index, dim): + y = np.zeros((1,dim)) + if index < dim: + y[0,index] = 1.0 + return y + random_state = np.random.RandomState(rand_seed) + vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state) + #TODO random labels + labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,)) + return vectors, labels, truncation + +class TestBigGAN(unittest.TestCase): + def test_biggan_sample(self): + gan = biggan.BigGAN() + vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size) + ims = gan.sample(vectors, labels, truncation) + + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.1