aboutsummaryrefslogtreecommitdiff
path: root/gantools
diff options
context:
space:
mode:
Diffstat (limited to 'gantools')
-rw-r--r--gantools/biggan.py12
-rw-r--r--gantools/ganbreeder.py20
-rw-r--r--gantools/latent_space.py2
3 files changed, 29 insertions, 5 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py
index 571387b..71228ae 100644
--- a/gantools/biggan.py
+++ b/gantools/biggan.py
@@ -4,8 +4,20 @@ import tensorflow_hub as hub
import numpy as np
from itertools import cycle
+#-----------------------------------------------------------------
+# fix "could not create cudnn handle" error
+# see: https://github.com/tensorflow/tensorflow/issues/24496
+from tensorflow.compat.v1 import ConfigProto
+from tensorflow.compat.v1 import InteractiveSession
+config = ConfigProto()
+config.gpu_options.allow_growth = True
+#-----------------------------------------------------------------
+
+session = InteractiveSession(config=config)
+
MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2'
+
class BigGAN(object):
def __init__(self, module_path=MODULE_PATH):
tf.reset_default_graph()
diff --git a/gantools/ganbreeder.py b/gantools/ganbreeder.py
index b55bc41..a2546a2 100644
--- a/gantools/ganbreeder.py
+++ b/gantools/ganbreeder.py
@@ -1,10 +1,11 @@
# client functions for interacting with the ganbreeder api
import requests
import json
+import numpy as np
def login(username, password):
def get_sid():
- url = 'https://ganbreeder.app/login'
+ url = 'https://artbreeder.com/login'
r = requests.get(url)
r.raise_for_status()
for c in r.cookies:
@@ -13,7 +14,7 @@ def login(username, password):
return c.value
def login_auth(sid, username, password):
- url = 'https://ganbreeder.app/login'
+ url = 'https://artbreeder.com/login'
headers = {
'Content-Type': 'application/json',
}
@@ -34,15 +35,26 @@ def login(username, password):
login_auth(sid, username, password)
return sid
+def parse_info_dict(info):
+ keyframe = dict()
+ keyframe['truncation'] = np.float(info['truncation'])
+ keyframe['latent'] = np.asarray(info['latent'])
+ classes = info['classes']
+ keyframe['label'] = np.zeros(1000)# length of label ("classes") vector: 1000
+ for c in info['classes']:
+ # artbreeder class entries look like [index, value] where index < 1000
+ keyframe['label'][c[0]] = c[1]
+ return keyframe
+
def get_info(sid, key):
if sid == '':
raise Exception('Cannot get info; session ID not defined. Be sure to login() first.')
cookies = {
'connect.sid': sid
}
- r = requests.get('http://ganbreeder.app/info?k='+str(key), cookies=cookies)
+ r = requests.get('http://artbreeder.com/info?k='+str(key), cookies=cookies)
r.raise_for_status()
- return(r.json())
+ return parse_info_dict(r.json())
def get_info_batch(username, password, keys):
l = list()
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
index 64e3dec..a39e804 100644
--- a/gantools/latent_space.py
+++ b/gantools/latent_space.py
@@ -44,7 +44,7 @@ def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linea
keyframes.append(keyframes[0])# seq returns to start
truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes])
- z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes])
+ z_keys = np.asarray([np.asarray(keyframe['latent']) * keyframe['truncation'] for keyframe in keyframes])
label_keys = np.asarray([keyframe['label'] for keyframe in keyframes])
z_seq = interp_fn(z_keys, num_frames)