aboutsummaryrefslogtreecommitdiff
path: root/gantools/latent_space.py
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-08-21 22:21:10 -0400
committerVee9ahd1 <[email protected]>2019-08-21 22:21:10 -0400
commitf6cbf568ff9815c180095733560a567dcb70e859 (patch)
treea6edd6071ff65537ac83124edfdf2df2dc8b83c8 /gantools/latent_space.py
parentb7b043ac983613f02166b0b42bee70daff1539ef (diff)
added a check to make sure the correct amount of keys are passed to cubic interpolation (providing a more meaningful error message than TypeError('m > k must hold'))
Diffstat (limited to 'gantools/latent_space.py')
-rw-r--r--gantools/latent_space.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/gantools/latent_space.py b/gantools/latent_space.py
index 44c64a3..64e3dec 100644
--- a/gantools/latent_space.py
+++ b/gantools/latent_space.py
@@ -21,8 +21,11 @@ def cubic_spline_interp(points, step_count):
tck = interpolate.splrep(x, y, s=0)
xnew = np.linspace(0., 1., step_count)
return interpolate.splev(xnew, tck, der=0)
+ if points.shape[0] < 4:
+ raise ValueError('Too few points for cubic interpolation: need 4, got {}'.format(points.shape[0]))
return np.apply_along_axis(cubic_spline_interp1d, 0, points)
+
# TODO: the math in this function is embarrasingly bad. fix at some point.
def sequence_keyframes(keyframes, num_frames, batch_size=1, interp_method='linear', loop=False):
interp_fn = {