r - 如何将张量转换为 R 数组(在损失函数中,因此无需急切执行)?
问题描述
我有TensorFlow
2.4 版并使用 R 包tensorflow
(2.2.0) 和keras
(2.3.0.0.9000)。我想在损失函数中将张量转换为 R 数组(不要问为什么)。这是一个这样的转换(在损失函数之外)起作用的例子:
library(tensorflow)
library(keras)
x.R <- matrix(1:12, ncol = 3) # dummy R object
x.tensor <- keras_array(x.R) # converting the R object to a tensor
as.array(x.tensor) # converting it back to an R array. This works because...
stopifnot(tf$executing_eagerly()) # ... eager execution is enabled
在模型训练期间,FALSE
虽然会急切执行,因此as.array()
调用失败。为了看到这一点,让我们首先定义一个虚拟神经网络模型和训练数据。
d <- 2 # input and output dimension
in.lay <- layer_input(shape = d)
hid.lay <- layer_dense(in.lay, units = 300, activation = "relu")
out.lay <- layer_dense(hid.lay, units = d, activation = "sigmoid")
model <- keras_model(in.lay, out.lay)
n <- 1200 # number of training samples
data <- matrix(runif(n * d), ncol = d) # training data
现在让我们定义损失函数并用它编译模型。
myloss <- function(x, y) { # x and y are tensors here
stopifnot(!tf$executing_eagerly()) # confirms that eager execution is disabled
x. <- as.array(x) # ... fails with "RuntimeError: Evaluation error: invalid first argument, must be vector (list or atomic)." How can we convert x to an R array?
loss_mean_squared_error(x, y) # just a dummy return value (the MSE)
}
compile(model, optimizer = "adam", loss = myloss)
让我们尝试拟合这个模型(看看它无法通过 将张量转换x
为 R 数组as.array()
)。
prior <- matrix(rexp(n * d), ncol = d) # input sample to train the NN on
n.epoch <- 5 # number of epochs to train
batch.size <- 400 # batch size
fit(model, x = prior, y = data, batch_size = batch.size, epochs = n.epoch) # fails with error message given above
R 包 tensorflow 提供tfe_enable_eager_execution()
了在会话中启用即时执行的功能。但是如果我用TensorFlow
2.4 调用它,那么我得到:
tfe_enable_eager_execution() # "Error in py_get_attr_impl(x, name, silent) : AttributeError: module 'tensorflow' has no attribute 'contrib'"
理想情况下,我不想过多地处理急切执行(不确定副作用),只需将张量转换为数组即可。我的猜测是,除了急切执行之外别无他法,因为只有指针被解析并且 R 包tensorflow
在张量中找到数据并能够将其转换为数组。
这里提到了启用/禁用急切执行的其他想法,但这都是在 Python 中,在 R 中似乎不可用。而这篇这篇文章似乎提出了同样的问题,但在不同的背景下。
解决方案
推荐阅读
- excel - 如何分解一个非常大的excel文件
- javascript - 尝试将数据快照放入我想要的类时出现问题
- django - 如何在 Django 中将 UTC 时间转换为用户的本地时间
- python - 我需要使用变量调用 python 函数
- r - 如何按组执行成对统计检验?
- node.js - Heroku - “没有网络进程正在运行”消息,但服务器已经启动
- debian-buster - 如何在 debian 10 上安装 AMD radeon PRO WX 4100 的驱动程序?
- python - 多个 Python 绘图未在其他子图中显示绘图
- android - Android动态功能卡在安装视图虽然它说它已安装
- c# - 将按钮绑定到子组件 blazor