python - “元组”对象没有属性“gpu_fraction”
问题描述
我现在正在使用 colab 来重现显着性检测!我是学生,所以请理解我的知识还不够,,,我找到了使用的代码tensorflow
,所以我正在尝试使用该代码来重现该项目。但是,作者说代码是在tensorflow 1.00上编写的,但我不知道tensorflow
我是否只是import tensorflow as tf
来自colab的版本。我收到错误
“元组”对象没有属性“gpu_fraction”
和
模块“tensorflow”没有属性“GPUOptions”
这是我的源代码,请看看我的问题是什么
import tensorflow as tf
import numpy as np
import os
from scipy import misc
import argparse
import sys
g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
output_folder = "./test_output"
def rgba2rgb(img):
return img[:,:,:3]*np.expand_dims(img[:,:,3],2)
def main(args):
if not os.path.exists(output_folder):
os.mkdir(output_folder)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = args.gpu_fraction)
with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./salience_model'))
image_batch = tf.get_collection('image_batch')[0]
pred_mattes = tf.get_collection('mask')[0]
if args.rgb_folder:
rgb_pths = os.listdir(args.rgb_folder)
for rgb_pth in rgb_pths:
rgb = misc.imread(os.path.join(args.rgb_folder,rgb_pth))
if rgb.shape[2]==4:
rgb = rgba2rgb(rgb)
origin_shape = rgb.shape
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)
feed_dict = {image_batch:rgb}
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
misc.imsave(os.path.join(output_folder,rgb_pth),final_alpha)
else:
rgb = misc.imread(args.rgb)
if rgb.shape[2]==4:
rgb = rgba2rgb(rgb)
origin_shape = rgb.shape[:2]
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)
feed_dict = {image_batch:rgb}
pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
misc.imsave(os.path.join(output_folder,'alpha.png'),final_alpha)
def parse_arguments(argv):
parser = argparse.ArgumentParser()
parser.add_argument('--rgb', type=str,
help='input rgb',default = None)
parser.add_argument('--rgb_folder', type=str,
help='input rgb',default = None)
parser.add_argument('--gpu_fraction', type=float,
help='how much gpu is needed, usually 4G is enough',default = 1.0)
return parser.parse_args(argv)
if __name__ == '__main__':
main(parse_arguments(sys.argv[1:]))
解决方案
您可以尝试摆脱显式 GPU 配置,即将您的
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
进入
with tf.Session() as sess:
推荐阅读
- vb.net - Microsoft.DirectX.AudioVideoPlayback,如何设置音轨语言?
- bash - 在bash中选择字符串内的子字符串
- flutter - 如何处理差异响应 API?
- javascript - 在 Vanilla JavaScript 中表格的每一行中使用 AJAX `onKeyUp` 提交表单
- javascript - JavaScript 警报未显示。addEventListener HTML 表单
- c# - 从 TransferUtility 返回的流不能用于填充字节数组
- python - AttributeError:“str”对象没有属性“find_element”
- android - 而谷歌的短信一次性代码自动填充服务,API_NOT_AVAILABLE 状态是什么意思?
- java - Apache POI:一个单元格中不允许有多个单元格注释
- avfoundation - iOS - 如何检查相机胶卷中的类似视频?