python - Argparse: expected one argument when changing file names
Problem Description
Hi, I'm running some code and argparse is throwing an error that I can't track down. I run this script:
#!/bin/bash
# This script takes in a network model name (e.g. CWGAN1) and a log directory. It automatically
# chooses the latest model in the directory and renders text2shape outputs for the test set.
# Input text descriptions are in the sentences_fake_match.txt files
# Renderings are in the
# NOTE: If the directory contains model checkpoints 500, 1000, 2500, this script will choose model
# 500. If the directory contains 1000, 2500, 5000, this script will choose model 5000.
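# (This is because plain 'sort' compares lexicographically: the string "500" sorts
# after "2500". See the Solution below for a version-sort alternative.)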
set -e
export TOOLKIT_PATH=/home/griffin/STK/sstk
# Make sure there are two arguments
if [ "$#" -ne 2 ]; then
echo "Illegal number of parameters"
exit 1
fi
# #
# # Setup
# #
echo "--- SETUP ---"
# Model type (e.g. CWGAN1)
echo "Model:" $1
# Log path of model
echo "Directory:" $2
log_path=$2
# #
# # Run test script
# #
# Remove .index portion of string
ckpt_path=$(find $log_path -maxdepth 1 -name 'model.ckpt-*.index' | sort | tail -n1)
ckpt_path=${ckpt_path/.index/}
test_command="python3 main.py --model $1 --test --val_split test --save_outputs --n_minibatch_test 200 --log_path $log_path --ckpt_path $ckpt_path --debug --noise_size 8 --uniform_max 0.5 --dataset shapenet"
echo "Executing:" $test_command
eval $test_command
echo ""
# #
# # Create NRRD files
# #
echo "--- NRRD GENERATION ---"
render_dir=$(basename $ckpt_path)
render_dir=${render_dir/model./}
render_dir=$log_path/$render_dir
gen_nrrd_command="python3 -m tools.scripts.generate_nrrd $render_dir"
echo "Executing:" $gen_nrrd_command
eval $gen_nrrd_command
echo ""
# #
# # Batch render using javascript script
# #
This script runs main.py:
import argparse
import numpy as np
import os
import pprint
import yaml
# HACK: Get logger to print to stdout
import sys
sys.ps1 = '>>> ' # Make it "interactive"
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
from multiprocessing import Queue
from lib.config import cfg_from_file, cfg_from_list, cfg
from lib.data_process import make_data_processes, kill_processes
from lib.solver import Solver
from lib.solver_encoder import TextEncoderSolver, TextEncoderCosDistSolver, LBASolver
from lib.solver_gan import End2EndGANDebugSolver
from lib.solver_classifier import ClassifierSolver
from lib.cwgan import CWGAN
from lib.lba import LBA
from lib.classifier import Classifier
import lib.utils as utils
import models
del sys.ps1 # HACK: Get logger to print to stdout
def parse_args():
"""Parse the arguments.
"""
parser = argparse.ArgumentParser(
description='Main text2voxel train/test file.')
parser.add_argument('--cfg',
dest='cfg_files',
action='append',
help='optional config file',
default=None,
type=str)
parser.add_argument('--dont_save_voxels', dest='dont_save_voxels', action='store_true')
parser.add_argument('--lba_only', dest='lba_only', action='store_true')
parser.add_argument('--metric_learning_only', dest='metric_learning_only', action='store_true')
parser.add_argument('--non_inverted_loss', dest='non_inverted_loss', action='store_true')
parser.add_argument('--synth_embedding', dest='synth_embedding', action='store_true')
parser.add_argument('--all_tuples', dest='all_tuples', action='store_true')
parser.add_argument('--reed_classifier', dest='reed_classifier', action='store_true')
parser.add_argument('--val_split',
dest='split',
help='data split for validation/testing (train, val, test)',
default=None,
type=str)
parser.add_argument('--queue_capacity',
dest='queue_capacity',
help='size of queue',
default=None,
type=int)
parser.add_argument('--n_minibatch_test',
dest='n_minibatch_test',
help='number of minibatches to use for test phase',
default=None,
type=int)
parser.add_argument('--dataset', dest='dataset',
help='dataset',
default=None,
type=str)
parser.add_argument('--improved_wgan', dest='improved_wgan', action='store_true')
parser.add_argument('--debug', dest='is_debug', action='store_true')
parser.add_argument('--rand', dest='randomize',
help='randomize (do not use a fixed seed)',
action='store_true')
parser.add_argument('--tiny_dataset', dest='tiny_dataset',
help='use a tiny dataset (~5 examples)',
action='store_true')
parser.add_argument('--model',
dest='model',
help='name of the network model',
default=None,
type=str)
parser.add_argument('--text_encoder', dest='text_encoder',
help='train/test on text encoder',
action='store_true')
parser.add_argument('--classifier', dest='classifier',
help='train/test on classifier',
action='store_true')
parser.add_argument('--end2end', dest='end2end',
help='train/test using end2end model such as End2EndLBACWGAN',
action='store_true')
parser.add_argument('--shapenet_ct_classifier', dest='shapenet_ct_classifier',
help='chair/table classifier (sets up for classification)',
action='store_true')
parser.add_argument('--noise_size',
dest='noise_size',
help='dimension of the noise',
default=None,
type=int)
parser.add_argument('--noise_dist', dest='noise_dist',
help='noise distribution (uniform, gaussian)',
default=None,
type=str)
parser.add_argument('--validation', dest='validation',
help='run validation while training',
action='store_true')
parser.add_argument('--test', dest='test',
help='test mode',
action='store_true')
parser.add_argument('--test_npy', dest='test_npy',
help='test mode using npy files',
action='store_true')
parser.add_argument('--save_outputs', dest='save_outputs',
help='save the outputs to a file',
action='store_true')
parser.add_argument('--summary_freq',
dest='summary_freq',
help='summary frequency',
default=None,
type=int)
parser.add_argument('--optimizer',
dest='optimizer',
help='name of the optimizer',
default=None,
type=str)
parser.add_argument('--critic_optimizer',
dest='critic_optimizer',
help='name of the critic optimizer',
default=None,
type=str)
parser.add_argument('--batch_size',
dest='batch_size',
help='batch size',
default=None,
type=int)
parser.add_argument('--lba_mode',
dest='lba_mode',
help='LBA mode type (TST, STS, MM)',
default=None,
type=str)
parser.add_argument('--lba_test_mode',
dest='lba_test_mode',
help='LBA test mode (shape, text) - what to input during forward pass',
default=None,
type=str)
parser.add_argument('--visit_weight',
dest='visit_weight',
help='visit weight for lba models',
default=None,
type=float)
parser.add_argument('--lba_unnormalize', dest='lba_unnormalize', action='store_true')
parser.add_argument('--num_critic_steps',
dest='num_critic_steps',
help='number of critic steps per train step',
default=None,
type=int)
parser.add_argument('--intense_training_freq',
dest='intense_training_freq',
help='frequency of intense critic training',
default=None,
type=int)
parser.add_argument('--uniform_max',
dest='uniform_max',
help='absolute max for uniform distribution',
default=None,
type=float)
parser.add_argument('--match_loss_coeff',
dest='match_loss_coeff',
help='coefficient for real match loss',
default=None,
type=float)
parser.add_argument('--fake_match_loss_coeff',
dest='fake_match_loss_coeff',
help='coefficient for fake match loss',
default=None,
type=float)
parser.add_argument('--fake_mismatch_loss_coeff',
dest='fake_mismatch_loss_coeff',
help='coefficient for fake mismatch loss',
default=None,
type=float)
parser.add_argument('--gp_weight',
dest='gp_weight',
help='coefficient for gradient penalty',
default=None,
type=float)
parser.add_argument('--text2text_weight',
dest='text2text_weight',
help='coefficient for text2text loss',
default=None,
type=float)
parser.add_argument('--shape2shape_weight',
dest='shape2shape_weight',
help='coefficient for shape2shape loss',
default=None,
type=float)
parser.add_argument('--learning_rate',
dest='learning_rate',
help='learning rate',
default=None,
type=float)
parser.add_argument('--critic_lr_multiplier',
dest='critic_lr_multiplier',
help='critic learning rate multiplier',
default=None,
type=float)
parser.add_argument('--decay_steps',
dest='decay_steps',
help='decay steps',
default=None,
type=int)
parser.add_argument('--num_epochs',
dest='num_epochs',
help='number of epochs',
default=None,
type=int)
parser.add_argument('--augment_max',
dest='augment_max',
help='maximum augmentation perturbation out of 255',
default=None,
type=int)
parser.add_argument('--set',
dest='set_cfgs',
help='set config keys',
default=None,
nargs=argparse.REMAINDER)
parser.add_argument('--ckpt_path', dest='ckpt_path',
help='Initialize network from checkpoint',
default=None)
parser.add_argument('--lba_ckpt_path', dest='lba_ckpt_path',
help='Initialize LBA component of end2endlbawgan network from checkpoint',
default=None)
parser.add_argument('--val_ckpt_path', dest='val_ckpt_path',
help='Initialize validation network from checkpoint',
default=None)
parser.add_argument('--log_path', dest='log_path', help='set log path',
default=None)
args = parser.parse_args()
return args
def modify_args(args):
"""Modify the default config based on the command line arguments.
"""
# modify default config if requested
if args.cfg_files is not None:
for cfg_file in args.cfg_files:
cfg_from_file(cfg_file)
randomize = args.randomize
if args.test: # Always randomize in test phase
randomize = True
if not randomize:
np.random.seed(cfg.CONST.RNG_SEED)
# NOTE: Unfortunately order matters here
if args.lba_only is True:
cfg_from_list(['LBA.COSINE_DIST', False])
if args.metric_learning_only is True:
cfg_from_list(['LBA.NO_LBA', True])
if args.non_inverted_loss is True:
cfg_from_list(['LBA.INVERTED_LOSS', False])
if args.dataset is not None:
cfg_from_list(['CONST.DATASET', args.dataset])
if args.lba_mode is not None:
cfg_from_list(['LBA.MODEL_TYPE', args.lba_mode])
if args.lba_test_mode is not None:
cfg_from_list(['LBA.TEST_MODE', args.lba_test_mode])
# cfg_from_list(['LBA.N_CAPTIONS_PER_MODEL', 1]) # NOTE: Important!
if args.shapenet_ct_classifier is True:
cfg_from_list(['CONST.SHAPENET_CT_CLASSIFIER', args.shapenet_ct_classifier])
if args.visit_weight is not None:
cfg_from_list(['LBA.VISIT_WEIGHT', args.visit_weight])
if args.lba_unnormalize is True:
cfg_from_list(['LBA.NORMALIZE', False])
if args.improved_wgan is True:
cfg_from_list(['CONST.IMPROVED_WGAN', args.improved_wgan])
if args.synth_embedding is True:
cfg_from_list(['CONST.SYNTH_EMBEDDING', args.synth_embedding])
if args.all_tuples is True:
cfg_from_list(['CONST.TEST_ALL_TUPLES', args.all_tuples])
if args.reed_classifier is True:
cfg_from_list(['CONST.REED_CLASSIFIER', args.reed_classifier])
if args.noise_dist is not None:
cfg_from_list(['GAN.NOISE_DIST', args.noise_dist])
if args.uniform_max is not None:
cfg_from_list(['GAN.NOISE_UNIF_ABS_MAX', args.uniform_max])
if args.num_critic_steps is not None:
cfg_from_list(['WGAN.NUM_CRITIC_STEPS', args.num_critic_steps])
if args.intense_training_freq is not None:
cfg_from_list(['WGAN.INTENSE_TRAINING_FREQ', args.intense_training_freq])
if args.match_loss_coeff is not None:
cfg_from_list(['WGAN.MATCH_LOSS_COEFF', args.match_loss_coeff])
if args.fake_match_loss_coeff is not None:
cfg_from_list(['WGAN.FAKE_MATCH_LOSS_COEFF', args.fake_match_loss_coeff])
if args.fake_mismatch_loss_coeff is not None:
cfg_from_list(['WGAN.FAKE_MISMATCH_LOSS_COEFF', args.fake_mismatch_loss_coeff])
if args.gp_weight is not None:
cfg_from_list(['WGAN.GP_COEFF', args.gp_weight])
if args.text2text_weight is not None:
cfg_from_list(['WGAN.TEXT2TEXT_WEIGHT', args.text2text_weight])
if args.shape2shape_weight is not None:
cfg_from_list(['WGAN.SHAPE2SHAPE_WEIGHT', args.shape2shape_weight])
if args.learning_rate is not None:
cfg_from_list(['TRAIN.LEARNING_RATE', args.learning_rate])
if args.critic_lr_multiplier is not None:
cfg_from_list(['GAN.D_LEARNING_RATE_MULTIPLIER', args.critic_lr_multiplier])
if args.decay_steps is not None:
cfg_from_list(['TRAIN.DECAY_STEPS', args.decay_steps])
if args.queue_capacity is not None:
cfg_from_list(['CONST.QUEUE_CAPACITY', args.queue_capacity])
if args.n_minibatch_test is not None:
cfg_from_list(['CONST.N_MINIBATCH_TEST', args.n_minibatch_test])
if args.noise_size is not None:
cfg_from_list(['GAN.NOISE_SIZE', args.noise_size])
if args.batch_size is not None:
cfg_from_list(['CONST.BATCH_SIZE', args.batch_size])
if args.summary_freq is not None:
cfg_from_list(['TRAIN.SUMMARY_FREQ', args.summary_freq])
if args.num_epochs is not None:
cfg_from_list(['TRAIN.NUM_EPOCHS', args.num_epochs])
if args.model is not None:
cfg_from_list(['NETWORK', args.model])
if args.optimizer is not None:
cfg_from_list(['TRAIN.OPTIMIZER', args.optimizer])
if args.critic_optimizer is not None:
cfg_from_list(['GAN.D_OPTIMIZER', args.critic_optimizer])
if args.ckpt_path is not None:
cfg_from_list(['DIR.CKPT_PATH', args.ckpt_path])
if args.lba_ckpt_path is not None:
cfg_from_list(['END2END.LBA_CKPT_PATH', args.lba_ckpt_path])
if args.val_ckpt_path is not None:
cfg_from_list(['DIR.VAL_CKPT_PATH', args.val_ckpt_path])
if args.log_path is not None:
cfg_from_list(['DIR.LOG_PATH', args.log_path])
if args.augment_max is not None:
cfg_from_list(['TRAIN.AUGMENT_MAX', args.augment_max])
if args.test:
cfg_from_list(['TRAIN.AUGMENT_MAX', 0])
cfg_from_list(['CONST.BATCH_SIZE', 1])
cfg_from_list(['LBA.N_CAPTIONS_PER_MODEL', 1]) # NOTE: Important!
cfg_from_list(['LBA.N_PRIMITIVE_SHAPES_PER_CATEGORY', 1]) # NOTE: Important!
if args.test_npy:
cfg_from_list(['CONST.BATCH_SIZE', 1])
# To overwrite default variables, put the set_cfgs after all argument initializations
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)
def get_inputs_dict(args):
"""Gets the input dict for the current model and dataset.
"""
if cfg.CONST.DATASET == 'shapenet':
if (args.text_encoder is True) or (args.end2end is True) or (args.classifier is True):
inputs_dict = utils.open_pickle(cfg.DIR.TRAIN_DATA_PATH)
val_inputs_dict = utils.open_pickle(cfg.DIR.VAL_DATA_PATH)
test_inputs_dict = utils.open_pickle(cfg.DIR.TEST_DATA_PATH)
else: # Learned embeddings
inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TRAIN)
val_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_VAL)
test_inputs_dict = utils.open_pickle(cfg.DIR.SHAPENET_METRIC_EMBEDDINGS_TEST)
elif cfg.CONST.DATASET == 'primitives':
if ((cfg.CONST.SYNTH_EMBEDDING is True) or (args.text_encoder is True) or
(args.classifier is True)):
if args.classifier and not cfg.CONST.REED_CLASSIFIER: # Train on all splits for classifier
tf.logging.info('Using all (train/val/test) splits for training')
inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_ALL_SPLITS_DATA_PATH)
else:
tf.logging.info('Using train split only for training')
inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_TRAIN_DATA_PATH)
val_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_VAL_DATA_PATH)
test_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_TEST_DATA_PATH)
else: # Learned embeddings
inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_TRAIN)
val_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_VAL)
test_inputs_dict = utils.open_pickle(cfg.DIR.PRIMITIVES_METRIC_EMBEDDINGS_TEST)
else:
raise ValueError('Please use a valid dataset (shapenet, primitives).')
if args.tiny_dataset is True:
if ((cfg.CONST.DATASET == 'primitives' and cfg.CONST.SYNTH_EMBEDDING is True)
or (args.text_encoder is True)):
raise NotImplementedError('Tiny dataset not supported for synthetic embeddings.')
ds = 5 # New dataset size
if cfg.CONST.BATCH_SIZE > ds:
raise ValueError('Please use a smaller batch size than {}.'.format(ds))
inputs_dict = utils.change_dataset_size(inputs_dict, new_dataset_size=ds)
val_inputs_dict = utils.change_dataset_size(val_inputs_dict, new_dataset_size=ds)
test_inputs_dict = utils.change_dataset_size(test_inputs_dict, new_dataset_size=ds)
# Select the validation/test split
if args.split == 'train':
split_str = 'train'
val_inputs_dict = inputs_dict
elif (args.split == 'val') or (args.split is None):
split_str = 'val'
val_inputs_dict = val_inputs_dict
elif args.split == 'test':
split_str = 'test'
val_inputs_dict = test_inputs_dict
else:
raise ValueError('Please select a valid split (train, val, test).')
print('Validation/testing on {} split.'.format(split_str))
if (cfg.CONST.DATASET == 'shapenet') and (cfg.CONST.SHAPENET_CT_CLASSIFIER is True):
category_model_list, class_labels = Classifier.set_up_classification(inputs_dict)
val_category_model_list, val_class_labels = Classifier.set_up_classification(val_inputs_dict)
assert class_labels == val_class_labels
# Update inputs dicts
inputs_dict['category_model_list'] = category_model_list
inputs_dict['class_labels'] = class_labels
val_inputs_dict['category_model_list'] = val_category_model_list
val_inputs_dict['class_labels'] = val_class_labels
return inputs_dict, val_inputs_dict
def get_solver(g, net, args, is_training):
if isinstance(net, LBA):
solver = LBASolver(net, g, is_training)
elif args.text_encoder:
solver = TextEncoderSolver(net, g, is_training)
elif isinstance(net, Classifier):
solver = ClassifierSolver(net, g, is_training)
elif isinstance(net, CWGAN):
solver = End2EndGANDebugSolver(net, g, is_training)
else:
raise ValueError('Invalid network.')
return solver
def main():
"""Main text2voxel function.
"""
args = parse_args()
print('Called with args:')
print(args)
if args.save_outputs is True and args.test is False:
raise ValueError('Can only save outputs when testing, not training.')
if args.validation:
assert not args.test
if args.test:
assert args.ckpt_path is not None
modify_args(args)
print('----------------- CONFIG -------------------')
pprint.pprint(cfg)
# Save yaml
os.makedirs(cfg.DIR.LOG_PATH, exist_ok=True)
with open(os.path.join(cfg.DIR.LOG_PATH, 'run_cfg.yaml'), 'w') as out_yaml:
yaml.dump(cfg, out_yaml, default_flow_style=False)
# set up logger
tf.logging.set_verbosity(tf.logging.INFO)
try:
with tf.Graph().as_default() as g: # create graph
# Load data
inputs_dict, val_inputs_dict = get_inputs_dict(args)
# Build network
is_training = not args.test
print('------------ BUILDING NETWORK -------------')
network_class = models.load_model(cfg.NETWORK)
net = network_class(inputs_dict, is_training)
# Prefetching data processes
#
# Create worker and data queue for data processing. For training data, use
# multiple processes to speed up the loading. For validation data, use 1
# since the queue will be popped every TRAIN.NUM_VALIDATION_ITERATIONS.
# set up data queue and start enqueue
np.random.seed(123)
data_process_class = models.get_data_process_pairs(cfg.NETWORK, is_training)
val_data_process_class = models.get_data_process_pairs(cfg.NETWORK, is_training=False)
if is_training:
global train_queue, train_processes
train_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
train_processes = make_data_processes(data_process_class, train_queue, inputs_dict,
cfg.CONST.NUM_WORKERS, repeat=True)
if args.validation:
global val_queue, val_processes
val_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
val_processes = make_data_processes(val_data_process_class, val_queue,
val_inputs_dict, 1, repeat=True)
else:
global test_queue, test_processes
test_inputs_dict = val_inputs_dict
test_queue = Queue(cfg.CONST.QUEUE_CAPACITY)
test_processes = make_data_processes(val_data_process_class, test_queue,
test_inputs_dict, 1, repeat=False)
# Create solver
solver = get_solver(g, net, args, is_training)
# Run solver
if is_training:
if args.validation:
if cfg.DIR.VAL_CKPT_PATH is not None:
assert train_processes[0].iters_per_epoch != 0
assert val_processes[0].iters_per_epoch != 0
solver.train(train_processes[0].iters_per_epoch, train_queue,
val_processes[0].iters_per_epoch, val_queue=val_queue,
val_inputs_dict=val_inputs_dict)
else:
if isinstance(net, LBA):
assert cfg.LBA.TEST_MODE is not None
assert cfg.LBA.TEST_MODE == 'shape'
assert train_processes[0].iters_per_epoch != 0
assert val_processes[0].iters_per_epoch != 0
solver.train(train_processes[0].iters_per_epoch, train_queue,
val_processes[0].iters_per_epoch, val_queue=val_queue,
val_inputs_dict=val_inputs_dict)
else:
assert train_processes[0].iters_per_epoch != 0
assert val_processes[0].iters_per_epoch != 0
solver.train(train_processes[0].iters_per_epoch, train_queue,
val_processes[0].iters_per_epoch, val_queue=val_queue)
else:
solver.train(train_processes[0].iters_per_epoch, train_queue)
else:
solver.test(test_processes[0], test_queue,
num_minibatches=cfg.CONST.N_MINIBATCH_TEST,
save_outputs=args.save_outputs)
finally:
# Clean up the processes and queues
if is_training:
kill_processes(train_queue, train_processes)
if args.validation:
kill_processes(val_queue, val_processes)
else:
kill_processes(test_queue, test_processes)
if __name__ == '__main__':
main()
When I run the code, I get this error:
main.py: error: argument --ckpt_path: expected one argument
The input file names for --ckpt_path look like model.ckpt-2500.index, model.ckpt-20000.index, and so on. I just strip off the .index part so that my NRRD files can get a .txt file name, and I don't understand why argparse is having trouble with this.
Solution
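The message means argparse saw --ckpt_path immediately followed by another option instead of a value. In the script above, the most likely cause is that $ckpt_path expands to the empty string because find matched no model.ckpt-*.index file under $log_path (a wrong directory argument, or no checkpoint written yet). When eval re-splits the command, the empty variable vanishes entirely, so main.py is invoked with --ckpt_path --debug; since the next token starts with a dash, argparse refuses to treat it as the value and reports "expected one argument". Adding echo "$ckpt_path" before the eval would confirm this. A minimal, self-contained illustration of this argparse behavior (not part of the original code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_path', dest='ckpt_path', default=None)
parser.add_argument('--debug', dest='is_debug', action='store_true')

# Fine: a value follows the flag.
print(parser.parse_args(['--ckpt_path', 'model.ckpt-2500', '--debug']))

# Reproduces the failure: the value slot is empty, so the next token is
# another option and argparse exits with
#   error: argument --ckpt_path: expected one argument
parser.parse_args(['--ckpt_path', '--debug'])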
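To fail loudly instead of silently passing an empty value, and to pick the numerically newest checkpoint (plain sort is lexicographic, which is exactly the quirk the NOTE at the top of the script describes), the checkpoint lookup could be rewritten as below. This is a sketch that assumes GNU sort, whose -V flag does a version (natural-number) sort:

# Pick the checkpoint with the highest step number; sort -V orders
# model.ckpt-2500 before model.ckpt-20000, unlike plain lexicographic sort.
ckpt_path=$(find "$log_path" -maxdepth 1 -name 'model.ckpt-*.index' | sort -V | tail -n1)

# Guard against an empty result rather than handing argparse an empty value.
if [ -z "$ckpt_path" ]; then
    echo "No model.ckpt-*.index found in $log_path" >&2
    exit 1
fi

# Strip the trailing .index suffix.
ckpt_path=${ckpt_path%.index}

Quoting "$log_path" also keeps paths with spaces intact, and ${ckpt_path%.index} removes the suffix only at the end of the string, whereas the original ${ckpt_path/.index/} replaces the first occurrence anywhere in the path.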