aboutsummaryrefslogtreecommitdiff
path: root/gantools
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-06-03 16:15:53 -0400
committerVee9ahd1 <[email protected]>2019-06-03 16:15:53 -0400
commit64307dc3dbe633d4ff8bfb80a7f5895f00019356 (patch)
treeb96292b6b5b9addb71dc41f19068ce5b2f26ab25 /gantools
parente2dfe1e1dce59723b424ed334234475c0e9b6227 (diff)
implemented selectable interpolation and loop settings
Diffstat (limited to 'gantools')
-rw-r--r--gantools/cli.py29
-rw-r--r--gantools/latent_space.py67
2 files changed, 47 insertions, 49 deletions
diff --git a/gantools/cli.py b/gantools/cli.py
index 0ad3a8d..3e42f2e 100644
--- a/gantools/cli.py
+++ b/gantools/cli.py
@@ -4,26 +4,37 @@ from gantools import biggan
from gantools import latent_space
from gantools import image_utils
-# create entrypoints for cli tools
-def main():
- ## handle args
- parser = argparse.ArgumentParser(description='GAN tools')
+def handle_args(argv=None):
+ parser = argparse.ArgumentParser(
+ description='GAN tools',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
# load from ganbreeder
- parser.add_argument('-u', '--username', help='Ganbreeder account email address/username.')
- parser.add_argument('-p', '--password', help='Ganbreeder account password.')
- parser.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.')
+ ganbreeder_group = parser.add_argument_group(title='GANbreeder login')
+ ganbreeder_group.add_argument('-u', '--username', help='Ganbreeder account email address/username.')
+ ganbreeder_group.add_argument('-p', '--password', help='Ganbreeder account password.')
+ ganbreeder_group.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.')
parser.add_argument('-n', '--nframes', metavar='N', type=int, help='Total number of frames in the final animation.', default=10)
parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \
(note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \
what it does.).', default=1)
parser.add_argument('-o', '--output-dir', help='Directory path for output images.')
parser.add_argument('--prefix', help='File prefix for output images.')
- args = parser.parse_args()
+ parser.add_argument('--interp', help='Set interpolation method.', choices=['linear', 'cubic'])
+ group_loop = parser.add_mutually_exclusive_group(required=False)
+ group_loop.add_argument('--loop', dest='loop', action='store_true', default=True, help='Loop the animation.')
+ group_loop.add_argument('--no-loop', dest='loop', action='store_false', help='Don\'t loop the animation.')
+ args = parser.parse_args(argv)
# validate args
- if args.keys and not (args.username and args.password):
+ if not (lambda l: (not any(l)) or all(l))(\
+ [e is not None and e is not [] for e in [args.username, args.password, args.keys]]):
parser.error('The --keys argument requires a --username and --password to login to ganbreeder')
sys.exit(1)
+ return args
+# create entrypoints for cli tools
+def main():
+ handle_args()
# get animation keyframes from ganbreeder
print('Downloading keyframe info from ganbreeder...')
keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys)
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
index e8e997a..44c64a3 100644
--- a/gantools/latent_space.py
+++ b/gantools/latent_space.py
@@ -7,12 +7,13 @@ def one_hot(index, dim):
y[0, index] = 1.0
return y
-# linear interpolation
-def linear_interp(begin, end, step_count):
- initial = np.tile(begin, (step_count, 1))
- delta = np.tile((end - begin)/step_count, (step_count, 1))
- g = np.tile(np.arange(step_count), (begin.size, 1)).transpose()
- return (delta * g) + initial
+# interpolation methods
+def linear_interp(points, step_count):
+ def linear_interp1d(y):
+ x = np.linspace(0., 1., len(y))
+ xnew = np.linspace(0., 1., step_count)
+ return interpolate.interp1d(x, y)(xnew)
+ return np.apply_along_axis(linear_interp1d, 0, points)
def cubic_spline_interp(points, step_count):
def cubic_spline_interp1d(y):
@@ -20,49 +21,35 @@ def cubic_spline_interp(points, step_count):
tck = interpolate.splrep(x, y, s=0)
xnew = np.linspace(0., 1., step_count)
return interpolate.splev(xnew, tck, der=0)
- p = np.asarray(points)
return np.apply_along_axis(cubic_spline_interp1d, 0, points)
# TODO: the math in this function is embarrasingly bad. fix at some point.
-def sequence_keyframes(keyframes, num_frames, batch_size=1):
+def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linear', loop=False):
+ interp_fn = {
+ 'linear': linear_interp,
+ 'cubic': cubic_spline_interp,
+ }[interp_method]
div = int(num_frames // len(keyframes))
rem = int(num_frames - (div * len(keyframes)))
- frame_counts = np.full((len(keyframes),), div) + \
- np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem,), dtype=int))
- batch_div = int(num_frames//batch_size)
- batch_rem = 1 if int(num_frames%batch_size) > 0 else 0
+ frame_counts = np.full((len(keyframes), ), div) + \
+ np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem, ), dtype=int))
+ batch_div = int(num_frames // batch_size)
+ batch_rem = 1 if int(num_frames % batch_size) > 0 else 0
batch_count = batch_div + batch_rem
- keyframes.append(keyframes[0])# seq returns to start
- readahead = iter(keyframes)
- next(readahead)
- '''
- z_seq, label_seq, truncation_seq = [], [], []
- for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts):
- z_begin = np.asarray(begin['vector']) * begin['truncation']
- z_end = np.asarray(end['vector']) * end['truncation']
- z_seq.extend(linear_interp(z_begin, z_end, frame_count))
- label_begin = np.asarray(begin['label'])
- label_end = np.asarray(end['label'])
- label_seq.extend(linear_interp(label_begin, label_end, frame_count))
- truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count))
- '''
- # cubic interp
+ if loop is True:
+ keyframes.append(keyframes[0])# seq returns to start
+
truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes])
z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes])
label_keys = np.asarray([keyframe['label'] for keyframe in keyframes])
- z_seq = cubic_spline_interp(z_keys, num_frames)
- label_seq = cubic_spline_interp(label_keys, num_frames)
- truncation_seq = []
- for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts):
- truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count))
- """
- truncation_seq = np.reshape(\
- [np.linspace(trunc_begin, trunc_end, frame_count)\
- for (trunc_begin, trunc_end, frame_count)\
- in zip(truncation_keys[:-1], truncation_keys[1:], frame_counts)],\
- (-1,))
- """
+
+ z_seq = interp_fn(z_keys, num_frames)
+ label_seq = interp_fn(label_keys, num_frames)
+ truncation_seq = interp_fn(truncation_keys, num_frames)
+
# you can only change trunc once per batch
- truncation_seq_resampled = signal.resample(truncation_seq, batch_count)
+ truncation_seq_resampled = np.full((1),truncation_seq[0])\
+ if batch_count is 1\
+ else signal.resample(truncation_seq, batch_count)
return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled