From 044232a0947fcf02118a268e45a615cc1a8065a2 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Sun, 22 Sep 2019 19:43:03 -0400 Subject: patched for artbreeder support and fixed some weird CUDNN_STATUS_INTERNAL_ERROR. if CPU support is broken now, this commit is probably why. --- gantools/biggan.py | 12 ++++++++++++ gantools/ganbreeder.py | 20 ++++++++++++++++---- gantools/latent_space.py | 2 +- 3 files changed, 29 insertions(+), 5 deletions(-) (limited to 'gantools') diff --git a/gantools/biggan.py b/gantools/biggan.py index 571387b..71228ae 100644 --- a/gantools/biggan.py +++ b/gantools/biggan.py @@ -4,8 +4,20 @@ import tensorflow_hub as hub import numpy as np from itertools import cycle +#----------------------------------------------------------------- +# fix "could not create cudnn handle" error +# see: https://github.com/tensorflow/tensorflow/issues/24496 +from tensorflow.compat.v1 import ConfigProto +from tensorflow.compat.v1 import InteractiveSession +config = ConfigProto() +config.gpu_options.allow_growth = True +#----------------------------------------------------------------- + +session = InteractiveSession(config=config) + MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2' + class BigGAN(object): def __init__(self, module_path=MODULE_PATH): tf.reset_default_graph() diff --git a/gantools/ganbreeder.py b/gantools/ganbreeder.py index b55bc41..a2546a2 100644 --- a/gantools/ganbreeder.py +++ b/gantools/ganbreeder.py @@ -1,10 +1,11 @@ # client functions for interacting with the ganbreeder api import requests import json +import numpy as np def login(username, password): def get_sid(): - url = 'https://ganbreeder.app/login' + url = 'https://artbreeder.com/login' r = requests.get(url) r.raise_for_status() for c in r.cookies: @@ -13,7 +14,7 @@ def login(username, password): return c.value def login_auth(sid, username, password): - url = 'https://ganbreeder.app/login' + url = 'https://artbreeder.com/login' headers = { 'Content-Type': 'application/json', } @@ -34,15 +35,26 @@ def login(username, password): login_auth(sid, username, password) return sid +def parse_info_dict(info): + keyframe = dict() + keyframe['truncation'] = np.float(info['truncation']) + keyframe['latent'] = np.asarray(info['latent']) + classes = info['classes'] + keyframe['label'] = np.zeros(1000)# length of label ("classes") vector: 1000 + for c in info['classes']: + # artbreeder class entries look like [index, value] where index < 1000 + keyframe['label'][c[0]] = c[1] + return keyframe + def get_info(sid, key): if sid == '': raise Exception('Cannot get info; session ID not defined. Be sure to login() first.') cookies = { 'connect.sid': sid } - r = requests.get('http://ganbreeder.app/info?k='+str(key), cookies=cookies) + r = requests.get('http://artbreeder.com/info?k='+str(key), cookies=cookies) r.raise_for_status() - return(r.json()) + return parse_info_dict(r.json()) def get_info_batch(username, password, keys): l = list() diff --git a/gantools/latent_space.py b/gantools/latent_space.py index 64e3dec..a39e804 100644 --- a/gantools/latent_space.py +++ b/gantools/latent_space.py @@ -44,7 +44,7 @@ def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linea keyframes.append(keyframes[0])# seq returns to start truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes]) - z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes]) + z_keys = np.asarray([np.asarray(keyframe['latent']) * keyframe['truncation'] for keyframe in keyframes]) label_keys = np.asarray([keyframe['label'] for keyframe in keyframes]) z_seq = interp_fn(z_keys, num_frames) -- cgit v1.2.1