python - 使用 tensorflow 2.0 进行模型预测会导致 python 内核死机并重新启动
问题描述
我将模型从 tensorflow 1.15 移植到了 2.0,改动只是用 tf.keras 替换了 keras。在 tf 1.15 下,该模型即使使用 GPU 也能正常工作;但在 tf 2.0 或 2.1 下,只要调用 model.predict(),python 内核就会死掉并重新启动,而且没有任何错误信息。tensorflow 2.1 下的现象完全相同。
模型的创建、编译以及从文件加载权重都正常,问题是在调用 model.predict() 时出现的。
有什么提示吗?
import numpy as np
from IPython.display import clear_output
import time
import tensorflow as tf
import os
os.environ["PATH"] += os.pathsep + "C:/Program Files (x86)/Graphviz2.38/bin/"
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Input, Dense, Conv2D, Conv1D, Flatten, BatchNormalization, Activation, LeakyReLU, add, Subtract
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Activation, Flatten, Input, Concatenate, Reshape
from tensorflow.keras.optimizers import Adam, Nadam
from tensorflow.keras.utils import plot_model
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import regularizers
import tensorflow.keras.backend as K
def residual_layer(input_block, filters, kernel_size, use_bias=True, use_reg=True,
                   reg_const=0.0001, bn_axis=-1):
    """Append a residual block: Conv2D -> BatchNormalization -> add(skip) -> LeakyReLU.

    The convolution uses ``data_format="channels_first"`` and ``padding='same'``.
    The skip connection ``add([input_block, x])`` therefore requires
    ``input_block`` to already have ``filters`` channels.

    Args:
        input_block: Keras tensor fed to the convolution and used as the skip input.
        filters: Number of convolution filters (must match the input's channels).
        kernel_size: Convolution kernel size.
        use_bias: Whether the convolution has a bias term.
        use_reg: If True, apply an L2(``reg_const``) kernel regularizer.
        reg_const: L2 factor; only used when ``use_reg`` is True.
        bn_axis: Axis normalized by BatchNormalization.
            The default of -1 preserves the original behaviour. Changing the
            axis changes the BatchNormalization weight shapes, so weight files
            saved from the original model would no longer load.
            With channels_first data the channel axis is 1, which is what
            ``conv_layer`` and ``value_head`` use; prefer ``bn_axis=1`` for
            newly trained models.

    Returns:
        The output Keras tensor.
    """
    # Conv2D's default kernel_regularizer is None, so passing None when
    # use_reg is False matches the unregularized layer exactly.
    regularizer = regularizers.l2(reg_const) if use_reg else None
    x = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        data_format="channels_first",
        padding='same',
        use_bias=use_bias,
        activation='linear',
        kernel_regularizer=regularizer,
    )(input_block)
    # NOTE: with channels_first tensors the default axis=-1 normalizes over the
    # last spatial (width) axis, not over the channels.
    x = BatchNormalization(axis=bn_axis)(x)
    x = add([input_block, x])
    return LeakyReLU()(x)
def conv_layer(x, filters, kernel_size, padding='same', use_bias=True, use_reg=True,
               reg_const=0.0001, use_batchnorm=True, use_LeakyReLU=True):
    """Apply a ReLU Conv2D (channels_first), optionally followed by BN and LeakyReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        padding: Conv2D padding mode.
        use_bias: Whether the convolution has a bias term.
        use_reg: If True, attach an L2(``reg_const``) kernel regularizer.
        reg_const: L2 factor; only used when ``use_reg`` is True.
        use_batchnorm: If True, add BatchNormalization over axis 1 (the channel
            axis for channels_first data).
        use_LeakyReLU: If True, finish with a LeakyReLU layer.

    Returns:
        The output Keras tensor.
    """
    conv_kwargs = dict(
        filters=filters,
        kernel_size=kernel_size,
        data_format="channels_first",
        padding=padding,
        use_bias=use_bias,
        activation='relu',
    )
    # Only pass a regularizer when requested; otherwise Conv2D keeps its default.
    if use_reg:
        conv_kwargs['kernel_regularizer'] = regularizers.l2(reg_const)

    out = Conv2D(**conv_kwargs)(x)
    if use_batchnorm:
        out = BatchNormalization(axis=1)(out)
    if use_LeakyReLU:
        out = LeakyReLU()(out)
    return out
def LeakyReLU_layer(x):
    """Apply a default-configured LeakyReLU layer to ``x`` and return the result."""
    activation = LeakyReLU()
    return activation(x)
def value_head(x, use_bias=True, use_reg=True,
               reg_const=0.0001, use_batchnorm=True, use_LeakyReLU=True,
               hidden_units=29 * 29 * 2):
    """Build the scalar value head on top of ``x``.

    Pipeline:
        1x1 Conv2D (1 filter, channels_first, linear) -> [BN(axis=1)] -> [LeakyReLU]
        -> Flatten -> Dense(hidden_units, relu) -> [LeakyReLU]
        -> Dense(1, tanh, name='value_head').

    Args:
        x: Input Keras tensor.
        use_bias: Whether the conv and dense layers have bias terms.
        use_reg: If True, apply an L2(``reg_const``) kernel regularizer to the
            conv and both dense layers. Previously the dense layers were
            regularized even when this was False.
        reg_const: L2 factor; only used when ``use_reg`` is True.
        use_batchnorm: If True, add BatchNormalization after the conv.
        use_LeakyReLU: If True, add LeakyReLU after the conv and the hidden dense.
        hidden_units: Width of the hidden Dense layer (original hard-coded
            value ``29*29*2`` is the default).

    Returns:
        Keras tensor of the final Dense layer, named ``'value_head'``. That name
        is the key used by the loss dict in ``build_model_v01_only_value``.
    """
    def _kernel_reg():
        # A fresh regularizer per layer, as the original did; None is
        # Conv2D/Dense's default and means "no regularization".
        return regularizers.l2(reg_const) if use_reg else None

    x = Conv2D(
        filters=1,
        kernel_size=(1, 1),
        data_format="channels_first",
        padding='same',
        use_bias=use_bias,
        activation='linear',
        kernel_regularizer=_kernel_reg(),
    )(x)
    if use_batchnorm:
        x = BatchNormalization(axis=1)(x)
    if use_LeakyReLU:
        x = LeakyReLU()(x)
    x = Flatten()(x)
    x = Dense(
        hidden_units,
        use_bias=use_bias,
        activation='relu',
        kernel_regularizer=_kernel_reg(),
    )(x)
    if use_LeakyReLU:
        # NOTE(review): the Dense above already applies ReLU, so its output is
        # >= 0 and LeakyReLU is the identity here. Kept so the layer graph
        # stays the same as the original model.
        x = LeakyReLU()(x)
    x = Dense(
        1,
        use_bias=use_bias,
        activation='tanh',
        kernel_regularizer=_kernel_reg(),
        name='value_head',
    )(x)
    return x
def build_model_v01_only_value(weightsFileName=None, plot_file='gb_modelv_01.png',
                               n_residual=5):
    """Build and compile the value-only network, optionally loading weights.

    Architecture:
        Input (3, 19, 19) -> conv 5x5 -> 3x conv 3x3 (64 filters each)
        -> ``n_residual`` residual blocks (64 filters, 3x3) -> value head.
    The model is compiled with Adam, MSE loss on the ``'value_head'`` output,
    and an MAE metric.

    Args:
        weightsFileName: Path to a weights file for ``load_weights``; skipped
            when None.
        plot_file: Where ``plot_model`` writes the architecture diagram. Pass
            None to skip plotting; plotting requires pydot and Graphviz.
        n_residual: Number of residual blocks (original fixed value: 5). Must
            match the architecture the weights file was saved from.

    Returns:
        The compiled ``Model``.
    """
    main_input = Input(shape=(3, 19, 19,), name='main_input')
    x = conv_layer(main_input, 64, (5, 5), padding='same')
    for _ in range(3):
        x = conv_layer(x, 64, (3, 3), padding='same')
    for _ in range(n_residual):
        x = residual_layer(x, 64, (3, 3))
    vh = value_head(x)

    modelv = Model(inputs=[main_input], outputs=[vh])
    # The loss key must match the name of the final Dense layer in value_head.
    modelv.compile(loss={'value_head': 'mean_squared_error'},
                   optimizer=Adam(), metrics=['mae'])
    if plot_file is not None:
        plot_model(modelv, plot_file, show_shapes=True)
    if weightsFileName is not None:
        modelv.load_weights(weightsFileName)
    return modelv
# Build the value network and load its trained weights.
vm = build_model_v01_only_value(weightsFileName="memoryweights/vm_0018.hdf5")
# One (3, 19, 19) input plane stack per evaluated position;
# noOfMirrorsForEval is defined outside this snippet.
Xin = np.zeros(shape=(noOfMirrorsForEval + 1, 3, 19, 19), dtype=np.float64)
# ... some calculation of Xin
# NOTE(review): the reported kernel crash happens on this call under TF 2.x
# with a GPU. GPU memory allocation / cuDNN setup is a commonly reported
# cause (see tf.config.experimental.set_memory_growth); unverified here.
v_predict = vm.predict(x=Xin).flatten()
解决方案
推荐阅读
- laravel - 如何在 Laravel 中绑定服务?
- python - django.contrib.gis.geos.error.GEOSException:无法解析版本信息字符串
- php - 当用户在 woocommerce 上单击添加到购物车时,我想添加一个弹出确认
- xml - XSL:FO 应用 html 定义规则生成 PDF 的问题
- javascript - vue路由器组件变量未定义
- go - 使用 math/rand 库时,go 中出现错误 undefined: math
- python - NameError:名称“状态”未定义
- c# - MongoDB C# - 不调用 ISupportInitialize 方法
- python - 运行 PyInstaller (python 2.7) 创建的可执行文件时出错
- python - Flask render_template 用于为 iframe 发送参数时显示错误