import numpy as np
from scipy import signal, interpolate


def one_hot(index, dim):
    """Return a (1, dim) float row vector with 1.0 at column ``index``.

    Any index outside ``[0, dim)`` yields an all-zero vector. Negative
    indices are treated as out of range rather than wrapping to the end of
    the row.
    """
    y = np.zeros((1, dim))
    if 0 <= index < dim:
        y[0, index] = 1.0
    return y


def linear_interp(begin, end, step_count):
    """Step linearly from ``begin`` toward ``end`` in ``step_count`` rows.

    ``begin`` and ``end`` are NumPy arrays of the same shape. Row ``i`` of the
    result is ``begin + i * (end - begin) / step_count``. The first row equals
    ``begin`` and ``end`` itself is excluded, so consecutive segments can be
    concatenated without duplicating the shared keyframe.
    """
    delta = (end - begin) / step_count
    steps = np.arange(step_count)[:, np.newaxis]
    return delta * steps + begin


def cubic_spline_interp(points, step_count):
    """Resample keyframe ``points`` to ``step_count`` samples with a spline.

    Axis 0 of ``points`` indexes keyframes. Each 1-D slice along axis 0 is
    interpolated independently. Keyframes sit at evenly spaced parameters in
    [0, 1] and are sampled at ``step_count`` evenly spaced parameters, so the
    first and last samples reproduce the first and last keyframes.

    A cubic spline is used when there are at least four keyframes. With fewer,
    the degree drops to ``len(points) - 1``, because ``splrep`` needs more
    data points than the spline degree. A single keyframe yields a constant.
    """
    def cubic_spline_interp1d(y):
        degree = min(3, len(y) - 1)
        if degree < 1:
            # A single keyframe: nothing to interpolate, hold it constant.
            return np.full(step_count, y[0], dtype=float)
        x = np.linspace(0., 1., len(y))
        tck = interpolate.splrep(x, y, k=degree, s=0)
        xnew = np.linspace(0., 1., step_count)
        return interpolate.splev(xnew, tck, der=0)

    return np.apply_along_axis(cubic_spline_interp1d, 0, np.asarray(points))


# TODO: the math in this function is embarrassingly bad. Fix at some point.
def sequence_keyframes(keyframes, num_frames, batch_size=1):
    """Build a looping latent/label/truncation sequence through ``keyframes``.

    Each keyframe is a mapping with ``'vector'``, ``'label'`` and
    ``'truncation'`` entries. The sequence visits the keyframes in order and
    returns to the first one at the end.

    Args:
        keyframes: Non-empty sequence of keyframe mappings. It is not
            modified.
        num_frames: Number of frames in the output z/label sequences.
        batch_size: Positive integer number of frames per batch. Truncation
            is resampled to one value per batch.

    Returns:
        ``(z_seq, label_seq, truncation_seq)``:
        - ``z_seq`` has ``num_frames`` rows of spline-interpolated
          ``vector * truncation`` values.
        - ``label_seq`` has ``num_frames`` rows of spline-interpolated
          labels.
        - ``truncation_seq`` holds ``ceil(num_frames / batch_size)`` values
          produced by ``scipy.signal.resample``.

    Raises:
        ValueError: If ``keyframes`` is empty or ``batch_size`` < 1.
    """
    if not keyframes:
        raise ValueError("keyframes must not be empty")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    num_frames = int(num_frames)

    # Close the loop on a copy so the caller's list is left untouched.
    loop = list(keyframes) + [keyframes[0]]

    # Split the frames as evenly as possible across the len(keyframes)
    # segments; the first `rem` segments receive one extra frame.
    div, rem = divmod(num_frames, len(keyframes))
    frame_counts = np.full(len(keyframes), div, dtype=int)
    frame_counts[:rem] += 1

    # One truncation value per (possibly partial) batch: ceil division.
    batch_count = -(-num_frames // batch_size)

    # Latent vectors (pre-scaled by truncation) and labels follow a spline
    # through all keyframes, including the repeated first one.
    # NOTE(review): the spline places keyframes at uniform parameters, while
    # the truncation segments below use `frame_counts`. The two alignments
    # agree only approximately.
    z_keys = np.asarray(
        [np.asarray(key['vector']) * key['truncation'] for key in loop])
    label_keys = np.asarray([key['label'] for key in loop])
    z_seq = cubic_spline_interp(z_keys, num_frames)
    label_seq = cubic_spline_interp(label_keys, num_frames)

    # Truncation is interpolated linearly per segment (endpoints included)...
    truncation_seq = np.concatenate([
        np.linspace(begin['truncation'], end['truncation'], count)
        for begin, end, count in zip(loop[:-1], loop[1:], frame_counts)
    ])
    # ...then resampled, because truncation can only change once per batch.
    truncation_seq_resampled = signal.resample(truncation_seq, batch_count)
    return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled