diff options
-rw-r--r-- | gantools/latent_space.py | 50 |
1 files changed, 39 insertions, 11 deletions
diff --git a/gantools/latent_space.py b/gantools/latent_space.py index b89756b..e8e997a 100644 --- a/gantools/latent_space.py +++ b/gantools/latent_space.py @@ -1,24 +1,34 @@ import numpy as np -from scipy import signal +from scipy import signal, interpolate def one_hot(index, dim): - y = np.zeros((1,dim)) + y = np.zeros((1, dim)) if index < dim: - y[0,index] = 1.0 + y[0, index] = 1.0 return y -def interpolate(begin, end, step_count): +# linear interpolation +def linear_interp(begin, end, step_count): initial = np.tile(begin, (step_count, 1)) delta = np.tile((end - begin)/step_count, (step_count, 1)) g = np.tile(np.arange(step_count), (begin.size, 1)).transpose() return (delta * g) + initial +def cubic_spline_interp(points, step_count): + def cubic_spline_interp1d(y): + x = np.linspace(0., 1., len(y)) + tck = interpolate.splrep(x, y, s=0) + xnew = np.linspace(0., 1., step_count) + return interpolate.splev(xnew, tck, der=0) + p = np.asarray(points) + return np.apply_along_axis(cubic_spline_interp1d, 0, points) + # TODO: the math in this function is embarrasingly bad. fix at some point. def sequence_keyframes(keyframes, num_frames, batch_size=1): - div = int(num_frames//len(keyframes)) - rem = int(num_frames - (div*len(keyframes))) + div = int(num_frames // len(keyframes)) + rem = int(num_frames - (div * len(keyframes))) frame_counts = np.full((len(keyframes),), div) + \ - np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes)-rem,), dtype=int)) + np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem,), dtype=int)) batch_div = int(num_frames//batch_size) batch_rem = 1 if int(num_frames%batch_size) > 0 else 0 batch_count = batch_div + batch_rem @@ -26,15 +36,33 @@ def sequence_keyframes(keyframes, num_frames, batch_size=1): keyframes.append(keyframes[0])# seq returns to start readahead = iter(keyframes) next(readahead) + ''' z_seq, label_seq, truncation_seq = [], [], [] for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): - z_begin = np.asarray(begin['vector'])*begin['truncation'] - z_end = np.asarray(end['vector'])*end['truncation'] - z_seq.extend(interpolate(z_begin, z_end, frame_count)) + z_begin = np.asarray(begin['vector']) * begin['truncation'] + z_end = np.asarray(end['vector']) * end['truncation'] + z_seq.extend(linear_interp(z_begin, z_end, frame_count)) label_begin = np.asarray(begin['label']) label_end = np.asarray(end['label']) - label_seq.extend(interpolate(label_begin, label_end, frame_count)) + label_seq.extend(linear_interp(label_begin, label_end, frame_count)) + truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) + ''' + # cubic interp + truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes]) + z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes]) + label_keys = np.asarray([keyframe['label'] for keyframe in keyframes]) + z_seq = cubic_spline_interp(z_keys, num_frames) + label_seq = cubic_spline_interp(label_keys, num_frames) + truncation_seq = [] + for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts): truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count)) + """ + truncation_seq = np.reshape(\ + [np.linspace(trunc_begin, trunc_end, frame_count)\ + for (trunc_begin, trunc_end, frame_count)\ + in zip(truncation_keys[:-1], truncation_keys[1:], frame_counts)],\ + (-1,)) + """ # you can only change trunc once per batch truncation_seq_resampled = signal.resample(truncation_seq, batch_count) return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled |