diff options
Diffstat (limited to 'test/biggan_test.py')
-rw-r--r-- | test/biggan_test.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/test/biggan_test.py b/test/biggan_test.py new file mode 100644 index 0000000..db3e59f --- /dev/null +++ b/test/biggan_test.py @@ -0,0 +1,26 @@ +import unittest +from scipy.stats import truncnorm +import numpy as np +from gantools import biggan + +def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123): + def one_hot(index, dim): + y = np.zeros((1,dim)) + if index < dim: + y[0,index] = 1.0 + return y + random_state = np.random.RandomState(rand_seed) + vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state) + #TODO random labels + labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,)) + return vectors, labels, truncation + +class TestBigGAN(unittest.TestCase): + def test_biggan_sample(self): + gan = biggan.BigGAN() + vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size) + ims = gan.sample(vectors, labels, truncation) + + +if __name__ == '__main__': + unittest.main() |