From 64307dc3dbe633d4ff8bfb80a7f5895f00019356 Mon Sep 17 00:00:00 2001 From: Vee9ahd1 Date: Mon, 3 Jun 2019 16:15:53 -0400 Subject: implemented selectable interpolation and loop settings --- gantools/latent_space.py | 67 +++++++++++++++++++----------------------------- 1 file changed, 27 insertions(+), 40 deletions(-) (limited to 'gantools/latent_space.py') diff --git a/gantools/latent_space.py b/gantools/latent_space.py index e8e997a..44c64a3 100644 --- a/gantools/latent_space.py +++ b/gantools/latent_space.py @@ -7,12 +7,13 @@ def one_hot(index, dim): y[0, index] = 1.0 return y -# linear interpolation -def linear_interp(begin, end, step_count): - initial = np.tile(begin, (step_count, 1)) - delta = np.tile((end - begin)/step_count, (step_count, 1)) - g = np.tile(np.arange(step_count), (begin.size, 1)).transpose() - return (delta * g) + initial +# interpolation methods +def linear_interp(points, step_count): + def linear_interp1d(y): + x = np.linspace(0., 1., len(y)) + xnew = np.linspace(0., 1., step_count) + return interpolate.interp1d(x, y)(xnew) + return np.apply_along_axis(linear_interp1d, 0, points) def cubic_spline_interp(points, step_count): def cubic_spline_interp1d(y): @@ -20,49 +21,35 @@ def cubic_spline_interp(points, step_count): tck = interpolate.splrep(x, y, s=0) xnew = np.linspace(0., 1., step_count) return interpolate.splev(xnew, tck, der=0) - p = np.asarray(points) return np.apply_along_axis(cubic_spline_interp1d, 0, points) # TODO: the math in this function is embarrasingly bad. fix at some point. -def sequence_keyframes(keyframes, num_frames, batch_size=1): +def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linear', loop=False): + interp_fn = { + 'linear': linear_interp, + 'cubic': cubic_spline_interp, + }[interp_method] div = int(num_frames // len(keyframes)) rem = int(num_frames - (div * len(keyframes))) - frame_counts = np.full((len(keyframes),), div) + \ - np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem,), dtype=int)) - batch_div = int(num_frames//batch_size) - batch_rem = 1 if int(num_frames%batch_size) > 0 else 0 + frame_counts = np.full((len(keyframes), ), div) + \ + np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem, ), dtype=int)) + batch_div = int(num_frames // batch_size) + batch_rem = 1 if int(num_frames % batch_size) > 0 else 0 batch_count = batch_div + batch_rem - keyframes.append(keyframes[0])# seq returns to start - readahead = iter(keyframes) - next(readahead) - ''' - z_seq, label_seq, truncation_seq = [], [], [] - for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): - z_begin = np.asarray(begin['vector']) * begin['truncation'] - z_end = np.asarray(end['vector']) * end['truncation'] - z_seq.extend(linear_interp(z_begin, z_end, frame_count)) - label_begin = np.asarray(begin['label']) - label_end = np.asarray(end['label']) - label_seq.extend(linear_interp(label_begin, label_end, frame_count)) - truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) - ''' - # cubic interp + if loop is True: + keyframes.append(keyframes[0])# seq returns to start + truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes]) z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes]) label_keys = np.asarray([keyframe['label'] for keyframe in keyframes]) - z_seq = cubic_spline_interp(z_keys, num_frames) - label_seq = cubic_spline_interp(label_keys, num_frames) - truncation_seq = [] - for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): - truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) - """ - truncation_seq = np.reshape(\ - [np.linspace(trunc_begin, trunc_end, frame_count)\ - for (trunc_begin, trunc_end, frame_count)\ - in zip(truncation_keys[:-1], truncation_keys[1:], frame_counts)],\ - (-1,)) - """ + + z_seq = interp_fn(z_keys, num_frames) + label_seq = interp_fn(label_keys, num_frames) + truncation_seq = interp_fn(truncation_keys, num_frames) + # you can only change trunc once per batch - truncation_seq_resampled = signal.resample(truncation_seq, batch_count) + truncation_seq_resampled = np.full((1),truncation_seq[0])\ + if batch_count is 1\ + else signal.resample(truncation_seq, batch_count) return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled -- cgit v1.2.1