aboutsummaryrefslogtreecommitdiff
path: root/gantools/latent_space.py
diff options
context:
space:
mode:
Diffstat (limited to 'gantools/latent_space.py')
-rw-r--r--gantools/latent_space.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
new file mode 100644
index 0000000..b89756b
--- /dev/null
+++ b/gantools/latent_space.py
@@ -0,0 +1,40 @@
+import numpy as np
+from scipy import signal
+
+def one_hot(index, dim):
+ y = np.zeros((1,dim))
+ if index < dim:
+ y[0,index] = 1.0
+ return y
+
+def interpolate(begin, end, step_count):
+ initial = np.tile(begin, (step_count, 1))
+ delta = np.tile((end - begin)/step_count, (step_count, 1))
+ g = np.tile(np.arange(step_count), (begin.size, 1)).transpose()
+ return (delta * g) + initial
+
+# TODO: the math in this function is embarrasingly bad. fix at some point.
+def sequence_keyframes(keyframes, num_frames, batch_size=1):
+ div = int(num_frames//len(keyframes))
+ rem = int(num_frames - (div*len(keyframes)))
+ frame_counts = np.full((len(keyframes),), div) + \
+ np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes)-rem,), dtype=int))
+ batch_div = int(num_frames//batch_size)
+ batch_rem = 1 if int(num_frames%batch_size) > 0 else 0
+ batch_count = batch_div + batch_rem
+
+ keyframes.append(keyframes[0])# seq returns to start
+ readahead = iter(keyframes)
+ next(readahead)
+ z_seq, label_seq, truncation_seq = [], [], []
+ for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts):
+ z_begin = np.asarray(begin['vector'])*begin['truncation']
+ z_end = np.asarray(end['vector'])*end['truncation']
+ z_seq.extend(interpolate(z_begin, z_end, frame_count))
+ label_begin = np.asarray(begin['label'])
+ label_end = np.asarray(end['label'])
+ label_seq.extend(interpolate(label_begin, label_end, frame_count))
+ truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count))
+ # you can only change trunc once per batch
+ truncation_seq_resampled = signal.resample(truncation_seq, batch_count)
+ return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled