diff options
author | Vee9ahd1 <[email protected]> | 2019-06-03 16:15:53 -0400 |
---|---|---|
committer | Vee9ahd1 <[email protected]> | 2019-06-03 16:15:53 -0400 |
commit | 64307dc3dbe633d4ff8bfb80a7f5895f00019356 (patch) | |
tree | b96292b6b5b9addb71dc41f19068ce5b2f26ab25 /gantools | |
parent | e2dfe1e1dce59723b424ed334234475c0e9b6227 (diff) |
implemented selectable interpolation and loop settings
Diffstat (limited to 'gantools')
-rw-r--r-- | gantools/cli.py | 29 | ||||
-rw-r--r-- | gantools/latent_space.py | 67 |
2 files changed, 47 insertions, 49 deletions
diff --git a/gantools/cli.py b/gantools/cli.py index 0ad3a8d..3e42f2e 100644 --- a/gantools/cli.py +++ b/gantools/cli.py @@ -4,26 +4,37 @@ from gantools import biggan from gantools import latent_space from gantools import image_utils -# create entrypoints for cli tools -def main(): - ## handle args - parser = argparse.ArgumentParser(description='GAN tools') +def handle_args(argv=None): + parser = argparse.ArgumentParser( + description='GAN tools', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) # load from ganbreeder - parser.add_argument('-u', '--username', help='Ganbreeder account email address/username.') - parser.add_argument('-p', '--password', help='Ganbreeder account password.') - parser.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.') + ganbreeder_group = parser.add_argument_group(title='GANbreeder login') + ganbreeder_group.add_argument('-u', '--username', help='Ganbreeder account email address/username.') + ganbreeder_group.add_argument('-p', '--password', help='Ganbreeder account password.') + ganbreeder_group.add_argument('-k', '--keys', nargs='+', help='Ganbreeder keys.') parser.add_argument('-n', '--nframes', metavar='N', type=int, help='Total number of frames in the final animation.', default=10) parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \ (note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \ what it does.).', default=1) parser.add_argument('-o', '--output-dir', help='Directory path for output images.') parser.add_argument('--prefix', help='File prefix for output images.') - args = parser.parse_args() + parser.add_argument('--interp', help='Set interpolation method.', choices=['linear', 'cubic']) + group_loop = parser.add_mutually_exclusive_group(required=False) + group_loop.add_argument('--loop', dest='loop', action='store_true', default=True, help='Loop the animation.') + group_loop.add_argument('--no-loop', dest='loop', action='store_false', help='Don\'t loop the animation.') + args = parser.parse_args(argv) # validate args - if args.keys and not (args.username and args.password): + if not (lambda l: (not any(l)) or all(l))(\ + [e is not None and e is not [] for e in [args.username, args.password, args.keys]]): parser.error('The --keys argument requires a --username and --password to login to ganbreeder') sys.exit(1) + return args +# create entrypoints for cli tools +def main(): + handle_args() # get animation keyframes from ganbreeder print('Downloading keyframe info from ganbreeder...') keyframes = ganbreeder.get_info_batch(args.username, args.password, args.keys) diff --git a/gantools/latent_space.py b/gantools/latent_space.py index e8e997a..44c64a3 100644 --- a/gantools/latent_space.py +++ b/gantools/latent_space.py @@ -7,12 +7,13 @@ def one_hot(index, dim): y[0, index] = 1.0 return y -# linear interpolation -def linear_interp(begin, end, step_count): - initial = np.tile(begin, (step_count, 1)) - delta = np.tile((end - begin)/step_count, (step_count, 1)) - g = np.tile(np.arange(step_count), (begin.size, 1)).transpose() - return (delta * g) + initial +# interpolation methods +def linear_interp(points, step_count): + def linear_interp1d(y): + x = np.linspace(0., 1., len(y)) + xnew = np.linspace(0., 1., step_count) + return interpolate.interp1d(x, y)(xnew) + return np.apply_along_axis(linear_interp1d, 0, points) def cubic_spline_interp(points, step_count): def cubic_spline_interp1d(y): @@ -20,49 +21,35 @@ def cubic_spline_interp(points, step_count): tck = interpolate.splrep(x, y, s=0) xnew = np.linspace(0., 1., step_count) return interpolate.splev(xnew, tck, der=0) - p = np.asarray(points) return np.apply_along_axis(cubic_spline_interp1d, 0, points) # TODO: the math in this function is embarrasingly bad. fix at some point. -def sequence_keyframes(keyframes, num_frames, batch_size=1): +def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linear', loop=False): + interp_fn = { + 'linear': linear_interp, + 'cubic': cubic_spline_interp, + }[interp_method] div = int(num_frames // len(keyframes)) rem = int(num_frames - (div * len(keyframes))) - frame_counts = np.full((len(keyframes),), div) + \ - np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem,), dtype=int)) - batch_div = int(num_frames//batch_size) - batch_rem = 1 if int(num_frames%batch_size) > 0 else 0 + frame_counts = np.full((len(keyframes), ), div) + \ + np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem, ), dtype=int)) + batch_div = int(num_frames // batch_size) + batch_rem = 1 if int(num_frames % batch_size) > 0 else 0 batch_count = batch_div + batch_rem - keyframes.append(keyframes[0])# seq returns to start - readahead = iter(keyframes) - next(readahead) - ''' - z_seq, label_seq, truncation_seq = [], [], [] - for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): - z_begin = np.asarray(begin['vector']) * begin['truncation'] - z_end = np.asarray(end['vector']) * end['truncation'] - z_seq.extend(linear_interp(z_begin, z_end, frame_count)) - label_begin = np.asarray(begin['label']) - label_end = np.asarray(end['label']) - label_seq.extend(linear_interp(label_begin, label_end, frame_count)) - truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) - ''' - # cubic interp + if loop is True: + keyframes.append(keyframes[0])# seq returns to start + truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes]) z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes]) label_keys = np.asarray([keyframe['label'] for keyframe in keyframes]) - z_seq = cubic_spline_interp(z_keys, num_frames) - label_seq = cubic_spline_interp(label_keys, num_frames) - truncation_seq = [] - for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): - truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) - """ - truncation_seq = np.reshape(\ - [np.linspace(trunc_begin, trunc_end, frame_count)\ - for (trunc_begin, trunc_end, frame_count)\ - in zip(truncation_keys[:-1], truncation_keys[1:], frame_counts)],\ - (-1,)) - """ + + z_seq = interp_fn(z_keys, num_frames) + label_seq = interp_fn(label_keys, num_frames) + truncation_seq = interp_fn(truncation_keys, num_frames) + # you can only change trunc once per batch - truncation_seq_resampled = signal.resample(truncation_seq, batch_count) + truncation_seq_resampled = np.full((1),truncation_seq[0])\ + if batch_count is 1\ + else signal.resample(truncation_seq, batch_count) return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled |