tensorflow - 使用 Tensorflow 在 For 循环中创建神经网络
问题描述
我正在尝试从包含所需网络结构的列表中使用 Tensorflow 在 for 循环中创建一个简单的 MLP,例如 structure = [100, 50, 20, 1]。在这个列表中,100 代表输入大小,1 代表输出大小。(我需要这个用于预测应用程序,但这与我的问题并不严格相关。)
我还没有看到在 Tensorflow 中创建网络的类似方法。由于部分原因让我无法理解,人们似乎建议最好分别声明每个变量,例如 layer_1 = x1 * w1 + b1 然后 layer_2 = x2 * w2 + b2。创建我在 for 循环 [for i in range(len(structure)-1):] 中使用的网络的动态方法是否错误?对我来说,网络似乎工作正常,张量板上显示的网络结构似乎是正确的。
您认为这种创建网络的方式好吗?您是否认为我在不知不觉中陷入了任何 Tensorflow / Context Manager 问题?
import tensorflow as tf
class Model(object):
def __init__(self, structure, lr=0.01):
assert structure[-1] == 1
input_size = structure[0]
act_fun = tf.nn.tanh
G = tf.Graph()
with G.as_default():
self.X = tf.placeholder(tf.float32, shape=[None, input_size])
self.Y = tf.placeholder(tf.float32, shape=[None, 1])
X_out = self.X
for i in range(len(structure)-1):
from_, to = structure[i], structure[i+1]
initializer = tf.variance_scaling_initializer()
w = tf.Variable(initializer([from_, to]), dtype=tf.float32, name=f'W{i}')
b = tf.Variable(tf.zeros(to), name=f'B{i}')
if to != 1:
X_out = act_fun(tf.matmul(X_out, w) + b)
else:
X_out = tf.matmul(X_out, w) + b
self.forecast_layer = X_out
self.loss = tf.losses.mean_squared_error(self.Y, self.forecast_layer)
self.trainer = tf.train.GradientDescentOptimizer(lr).minimize(self.loss)
self.init = tf.global_variables_initializer()
self.session = tf.Session(graph=G)
self.session.run(self.init)
self.session.graph.finalize()
def fit(self, X, Y):
self.session.run(self.trainer, feed_dict={self.X:X, self.Y:Y})
def forecast(self, X):
return self.forecast_layer.eval(feed_dict={self.X:X}, session=self.session)
def evaluate_loss(self, X, Y):
return self.loss.eval(feed_dict={self.Y:Y, self.forecast_layer:self.forecast(X)}, session=self.session)
M = Model([100, 50, 20, 1], lr=0.001)
解决方案
这似乎很好,虽然我会使用tf.layers.dense
:
X_out = self.X
layers = [50, 20, 1]
for layer_size in layers:
X_out = tf.layers.dense(X_out, layer_size, activation=act_fun if x != 1 else None)
笔记:
警告:此功能已弃用。它将在未来的版本中删除。更新说明:改用 keras.layers.dense。
推荐阅读
- php - 如何解析数据
- python - 如何反转顺序以弹出 Python 3.6.4
- sql - 您可以在 DB2 SQL 中订购 listagg 返回值吗?
- c# - 从图像适配器中的 sqlite 检索图像路径(Xamarin Android C# 上的 GridView)
- c++ - 我如何获得正确的 BMI?
- android - 如何在flutter中调用另一个类的方法
- c - 如何在不触发错误的情况下将数组的大小包含到 for 循环中
- node.js - 当响应为异步时,如何在 Node.js 中处理多个同时请求?
- java - 从 .json 获取 JSON 值
- docker - 在启动 docker 容器时,我必须在 docker 容器中执行脚本