aboutsummaryrefslogtreecommitdiff
path: root/gantools/latent_space.py
diff options
context:
space:
mode:
Diffstat (limited to 'gantools/latent_space.py')
-rw-r--r--gantools/latent_space.py50
1 files changed, 39 insertions, 11 deletions
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
index b89756b..e8e997a 100644
--- a/gantools/latent_space.py
+++ b/gantools/latent_space.py
@@ -1,24 +1,34 @@
import numpy as np
-from scipy import signal
+from scipy import signal, interpolate
def one_hot(index, dim):
- y = np.zeros((1,dim))
+ y = np.zeros((1, dim))
if index < dim:
- y[0,index] = 1.0
+ y[0, index] = 1.0
return y
-def interpolate(begin, end, step_count):
+# linear interpolation
+def linear_interp(begin, end, step_count):
initial = np.tile(begin, (step_count, 1))
delta = np.tile((end - begin)/step_count, (step_count, 1))
g = np.tile(np.arange(step_count), (begin.size, 1)).transpose()
return (delta * g) + initial
+def cubic_spline_interp(points, step_count):
+ def cubic_spline_interp1d(y):
+ x = np.linspace(0., 1., len(y))
+ tck = interpolate.splrep(x, y, s=0)
+ xnew = np.linspace(0., 1., step_count)
+ return interpolate.splev(xnew, tck, der=0)
+ p = np.asarray(points)
+ return np.apply_along_axis(cubic_spline_interp1d, 0, points)
+
# TODO: the math in this function is embarrasingly bad. fix at some point.
def sequence_keyframes(keyframes, num_frames, batch_size=1):
- div = int(num_frames//len(keyframes))
- rem = int(num_frames - (div*len(keyframes)))
+ div = int(num_frames // len(keyframes))
+ rem = int(num_frames - (div * len(keyframes)))
frame_counts = np.full((len(keyframes),), div) + \
- np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes)-rem,), dtype=int))
+ np.append(np.ones((rem,), dtype=int), np.zeros((len(keyframes) - rem,), dtype=int))
batch_div = int(num_frames//batch_size)
batch_rem = 1 if int(num_frames%batch_size) > 0 else 0
batch_count = batch_div + batch_rem
@@ -26,15 +36,33 @@ def sequence_keyframes(keyframes, num_frames, batch_size=1):
keyframes.append(keyframes[0])# seq returns to start
readahead = iter(keyframes)
next(readahead)
+ '''
z_seq, label_seq, truncation_seq = [], [], []
for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts):
- z_begin = np.asarray(begin['vector'])*begin['truncation']
- z_end = np.asarray(end['vector'])*end['truncation']
- z_seq.extend(interpolate(z_begin, z_end, frame_count))
+ z_begin = np.asarray(begin['vector']) * begin['truncation']
+ z_end = np.asarray(end['vector']) * end['truncation']
+ z_seq.extend(linear_interp(z_begin, z_end, frame_count))
label_begin = np.asarray(begin['label'])
label_end = np.asarray(end['label'])
- label_seq.extend(interpolate(label_begin, label_end, frame_count))
+ label_seq.extend(linear_interp(label_begin, label_end, frame_count))
+ truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count))
+ '''
+ # cubic interp
+ truncation_keys = np.asarray([keyframe['truncation'] for keyframe in keyframes])
+ z_keys = np.asarray([np.asarray(keyframe['vector']) * keyframe['truncation'] for keyframe in keyframes])
+ label_keys = np.asarray([keyframe['label'] for keyframe in keyframes])
+ z_seq = cubic_spline_interp(z_keys, num_frames)
+ label_seq = cubic_spline_interp(label_keys, num_frames)
+ truncation_seq = []
+ for (begin, end, frame_count) in zip(keyframes, readahead, frame_counts):
truncation_seq.extend(np.linspace(begin['truncation'], end['truncation'], frame_count))
+ """
+ truncation_seq = np.reshape(\
+ [np.linspace(trunc_begin, trunc_end, frame_count)\
+ for (trunc_begin, trunc_end, frame_count)\
+ in zip(truncation_keys[:-1], truncation_keys[1:], frame_counts)],\
+ (-1,))
+ """
# you can only change trunc once per batch
truncation_seq_resampled = signal.resample(truncation_seq, batch_count)
return np.asarray(z_seq), np.asarray(label_seq), truncation_seq_resampled