aboutsummaryrefslogtreecommitdiff
path: root/gantools
diff options
context:
space:
mode:
authorVee9ahd1 <[email protected]>2019-06-03 12:17:29 -0400
committerVee9ahd1 <[email protected]>2019-06-03 12:17:29 -0400
commit69eb6678fc1c61fdb77f84f25f29ba9f777cd3dd (patch)
treef4320244e50ef7bb6ae34e216d9266f6c67ae1dc /gantools
parentdcc12c31b79ac9a17738d75adba65556624c412e (diff)
output directory and file prefix (base name) are now handled by different arguments
Diffstat (limited to 'gantools')
-rw-r--r--gantools/cli.py10
-rw-r--r--gantools/image_utils.py7
2 files changed, 10 insertions, 7 deletions
diff --git a/gantools/cli.py b/gantools/cli.py
index 73213ed..0ad3a8d 100644
--- a/gantools/cli.py
+++ b/gantools/cli.py
@@ -16,7 +16,8 @@ def main():
parser.add_argument('-b', '--nbatch', metavar='N', type=int, help='Number of frames in each \'batch\' \
(note: the truncation value can only change once per batch. Don\'t fuck with this unless you know \
what it does.).', default=1)
- parser.add_argument('-f', '--pathprefix', help='Directory path and file prefix for output images.')
+ parser.add_argument('-o', '--output-dir', help='Directory path for output images.')
+ parser.add_argument('--prefix', help='File prefix for output images.')
args = parser.parse_args()
# validate args
if args.keys and not (args.username and args.password):
@@ -38,8 +39,9 @@ def main():
ims = gan.sample(z_seq, label_seq, truncation_seq, args.nbatch)
# save images to file
- pathprefix = '' if args.pathprefix == None else str(args.pathprefix)
- print('Saving image files: '+pathprefix)
- image_utils.save_images(ims, pathprefix)
+ path = '' if args.output_dir == None else str(args.output_dir)
+ prefix = '' if args.prefix == None else str(args.prefix)
+ print('Saving image files: '+path + prefix)
+ image_utils.save_images(ims, output_dir=output_dir, prefix=prefix)
print('Done.')
diff --git a/gantools/image_utils.py b/gantools/image_utils.py
index 016891f..72130b0 100644
--- a/gantools/image_utils.py
+++ b/gantools/image_utils.py
@@ -1,10 +1,11 @@
+import os
import PIL.Image
def save_image(arr, fp, format='JPEG'):
image = PIL.Image.fromarray(arr)
image.save(fp, format=format, quality=90)
-def save_images(ims, path_prefix='', format='JPEG'):
+def save_images(ims, output_dir='', prefix='', format='JPEG'):
for i, im in enumerate(ims):
- path = str(path_prefix)+str(i).zfill(4)+'.'+str(format).lower()
- save_image(im, path)
+ full_path = os.path.join(output_dir, prefix + str(i).zfill(4) + '.' + format.lower())
+ save_image(im, full_path, format)