From 044232a0947fcf02118a268e45a615cc1a8065a2 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Sun, 22 Sep 2019 19:43:03 -0400 Subject: patched for artbreeder support and fixed some weird CUDNN_STATUS_INTERNAL_ERROR. if CPU support is broken now, this commit is probably why. --- gantools/ganbreeder.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'gantools/ganbreeder.py') diff --git a/gantools/ganbreeder.py b/gantools/ganbreeder.py index b55bc41..a2546a2 100644 --- a/gantools/ganbreeder.py +++ b/gantools/ganbreeder.py @@ -1,10 +1,11 @@ # client functions for interacting with the ganbreeder api import requests import json +import numpy as np def login(username, password): def get_sid(): - url = 'https://ganbreeder.app/login' + url = 'https://artbreeder.com/login' r = requests.get(url) r.raise_for_status() for c in r.cookies: @@ -13,7 +14,7 @@ def login(username, password): return c.value def login_auth(sid, username, password): - url = 'https://ganbreeder.app/login' + url = 'https://artbreeder.com/login' headers = { 'Content-Type': 'application/json', } @@ -34,15 +35,26 @@ def login(username, password): login_auth(sid, username, password) return sid +def parse_info_dict(info): + keyframe = dict() + keyframe['truncation'] = np.float(info['truncation']) + keyframe['latent'] = np.asarray(info['latent']) + classes = info['classes'] + keyframe['label'] = np.zeros(1000)# length of label ("classes") vector: 1000 + for c in info['classes']: + # artbreeder class entries look like [index, value] where index < 1000 + keyframe['label'][c[0]] = c[1] + return keyframe + def get_info(sid, key): if sid == '': raise Exception('Cannot get info; session ID not defined. Be sure to login() first.') cookies = { 'connect.sid': sid } - r = requests.get('http://ganbreeder.app/info?k='+str(key), cookies=cookies) + r = requests.get('http://artbreeder.com/info?k='+str(key), cookies=cookies) r.raise_for_status() - return(r.json()) + return parse_info_dict(r.json()) def get_info_batch(username, password, keys): l = list() -- cgit v1.2.1