aboutsummaryrefslogtreecommitdiff
path: root/test/biggan_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/biggan_test.py')
-rw-r--r--test/biggan_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/test/biggan_test.py b/test/biggan_test.py
new file mode 100644
index 0000000..db3e59f
--- /dev/null
+++ b/test/biggan_test.py
@@ -0,0 +1,26 @@
+import unittest
+from scipy.stats import truncnorm
+import numpy as np
+from gantools import biggan
+
+def create_random_input(dim_z, vocab_size, batch_size=1, truncation = 0.5, rand_seed = 123):
+ def one_hot(index, dim):
+ y = np.zeros((1,dim))
+ if index < dim:
+ y[0,index] = 1.0
+ return y
+ random_state = np.random.RandomState(rand_seed)
+ vectors = truncation * truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=random_state)
+ #TODO random labels
+ labels = one_hot(0, vocab_size)#np.random.random_sample((vocab_size,))
+ return vectors, labels, truncation
+
+class TestBigGAN(unittest.TestCase):
+ def test_biggan_sample(self):
+ gan = biggan.BigGAN()
+ vectors, labels, truncation = create_random_input(gan.dim_z, gan.vocab_size)
+ ims = gan.sample(vectors, labels, truncation)
+
+
+if __name__ == '__main__':
+ unittest.main()