import numpy as np
from scipy import signal


def one_hot(index, dim):
    """Return a (1, dim) float array with a 1.0 in column ``index``.

    An ``index`` outside ``[0, dim)`` yields an all-zero row instead of
    raising.

    Args:
        index: Column to set.
        dim: Length of the vector.

    Returns:
        np.ndarray of shape (1, dim).
    """
    y = np.zeros((1, dim))
    # Reject negatives explicitly; otherwise numpy's negative indexing would
    # silently set a column counted from the end.
    if 0 <= index < dim:
        y[0, index] = 1.0
    return y


def interpolate(begin, end, step_count):
    """Linearly interpolate from ``begin`` toward ``end`` in ``step_count`` rows.

    Row ``k`` equals ``begin + k * (end - begin) / step_count`` for
    ``k = 0 .. step_count - 1``. The value ``end`` itself is excluded, so
    consecutive segments can be concatenated without repeating the shared
    keyframe.

    Args:
        begin: Array-like start value; flattened to a row vector.
        end: Array-like end value with the same number of elements as ``begin``.
        step_count: Number of rows to produce.

    Returns:
        np.ndarray of shape (step_count, begin.size); zero rows when
        ``step_count`` is not positive.
    """
    begin = np.asarray(begin).reshape(1, -1)
    end = np.asarray(end).reshape(1, -1)
    if step_count <= 0:
        # Empty segment: avoid the divide-by-zero below and emit nothing.
        return np.zeros((0, begin.shape[1]))
    delta = (end - begin) / step_count
    steps = np.arange(step_count).reshape(-1, 1)
    # (step_count, 1) * (1, size) broadcasts to one row per step.
    return steps * delta + begin


def sequence_keyframes(keyframes, num_frames, batch_size=1):
    """Build a looping interpolated sequence through ``keyframes``.

    Each keyframe is a mapping with ``'vector'``, ``'label'`` and
    ``'truncation'`` entries. The sequence visits the keyframes in order and
    then returns to the first one. ``num_frames`` frames are spread as evenly
    as possible over the segments; when the division is uneven, the earliest
    segments receive one extra frame each.

    Args:
        keyframes: Non-empty, indexable sequence of keyframe mappings.
            It is not modified.
        num_frames: Total number of frames to generate.
        batch_size: Frames per batch; the truncation curve is resampled to
            one value per batch.

    Returns:
        A tuple ``(z_seq, label_seq, truncation_seq)``:

        - ``z_seq``: one row per frame, interpolated between
          ``vector * truncation`` of neighbouring keyframes.
        - ``label_seq``: one row per frame, interpolated between the
          keyframes' labels.
        - ``truncation_seq``: ``ceil(num_frames / batch_size)`` values, the
          per-frame truncation curve resampled with ``scipy.signal.resample``.

    Raises:
        ZeroDivisionError: if ``keyframes`` is empty.
    """
    key_count = len(keyframes)
    div, rem = divmod(int(num_frames), key_count)
    # The first `rem` segments get one extra frame so the counts sum to num_frames.
    frame_counts = [div + 1 if i < rem else div for i in range(key_count)]
    # One truncation value per batch, counting a final partial batch (ceil division).
    batch_count = -(-int(num_frames) // batch_size)

    # Close the loop (the sequence returns to its start). Build a new list
    # instead of appending, so the caller's list is left untouched.
    path = list(keyframes) + [keyframes[0]]

    z_seq, label_seq, truncation_seq = [], [], []
    for begin, end, frame_count in zip(path, path[1:], frame_counts):
        z_begin = np.asarray(begin['vector']) * begin['truncation']
        z_end = np.asarray(end['vector']) * end['truncation']
        z_seq.extend(interpolate(z_begin, z_end, frame_count))

        label_begin = np.asarray(begin['label'])
        label_end = np.asarray(end['label'])
        label_seq.extend(interpolate(label_begin, label_end, frame_count))

        # endpoint=False matches interpolate(): frame k sits at fraction
        # k / frame_count of the segment. This keeps truncation aligned with
        # z and labels, and avoids emitting the shared keyframe's value twice.
        truncation_seq.extend(
            np.linspace(begin['truncation'], end['truncation'], frame_count,
                        endpoint=False))

    # You can only change truncation once per batch, so resample the
    # per-frame curve to one value per batch. The sequence loops back to its
    # start, which suits the periodic-signal assumption of FFT resampling.
    truncation_seq_resampled = signal.resample(truncation_seq, batch_count)
    return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled