diff options
-rw-r--r-- | gantools/biggan.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/gantools/biggan.py b/gantools/biggan.py index 71228ae..fbfde98 100644 --- a/gantools/biggan.py +++ b/gantools/biggan.py @@ -1,19 +1,11 @@ # methods for setting up and interacting with biggan -import tensorflow as tf +import tensorflow.compat.v1 as tf import tensorflow_hub as hub import numpy as np from itertools import cycle -#----------------------------------------------------------------- -# fix "could not create cudnn handle" error -# see: https://github.com/tensorflow/tensorflow/issues/24496 -from tensorflow.compat.v1 import ConfigProto -from tensorflow.compat.v1 import InteractiveSession -config = ConfigProto() -config.gpu_options.allow_growth = True -#----------------------------------------------------------------- -session = InteractiveSession(config=config) +#session = InteractiveSession(config=config) MODULE_PATH = 'https://tfhub.dev/deepmind/biggan-512/2' @@ -22,6 +14,13 @@ class BigGAN(object): def __init__(self, module_path=MODULE_PATH): tf.reset_default_graph() print('Loading BigGAN module from:', module_path) + + #----------------------------------------------------------------- + # fix "RuntimeError: Exporting/importing meta graphs is not + # supported when eager execution is enabled." error when importing + # the tfhub module + tf.disable_eager_execution() + #----------------------------------------------------------------- module = hub.Module(module_path) self.inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k) for k, v in module.get_input_info_dict().items()} @@ -34,7 +33,15 @@ class BigGAN(object): # initialize/instantiate tf variables initializer = tf.global_variables_initializer() - self.sess = tf.Session() + + #----------------------------------------------------------------- + # fix "could not create cudnn handle" error + # see: https://github.com/tensorflow/tensorflow/issues/24496 + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + #----------------------------------------------------------------- + + self.sess = tf.Session(config=config) self.sess.run(initializer) # NOTE: use save callback to save images once per batch. return type changes to None. |