python - keras在拟合模型时是否深度复制数据?
问题描述
当我运行我的模型(用于图像分割的 Unet)时,我有 ram 内存错误弹出:
2020-11-19 11:25:18.027748: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 11998593024 exceeds 10% of free system memory.
2020-11-19 11:25:32.991088: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 11998593024 exceeds 10% of free system memory.
2020-11-19 11:25:46.109554: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 11998593024 exceeds 10% of free system memory.
分配的内存图:
我想知道 tensorflow 是否在深度复制数据,如果是这样,有没有办法避免它(不使用 DataGenerator)。
主脚本:
from data_preprocessing import data_utils,DataGenerator
from model import model_utils,loss_utils
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from sklearn.model_selection import train_test_split
import tensorflow as tf
if __name__ == "__main__":
X,Y = data_utils.load_all()
print("Checkpoint 1")
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,Y, test_size = 1/5, shuffle = True)
print("Checkpoint 2")
unet = model_utils.unet(input_size=(256,256,1))
print("Checkpoint 3")
checkpointer = ModelCheckpoint('image_segm.hdf5',monitor='loss',verbose=1,save_best_only=True)
historic = unet.fit(Xtrain,Ytrain,epochs=1,callbacks=[checkpointer],batch_size= 5)
print("End")
编辑:在 conda 环境中使用 tensorflow-gpu 2.20.0
解决方案
查看这篇文章,它将帮助您解决 Datagen 问题https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
推荐阅读
- flutter - 嵌入式视频(来自互联网)播放器颤动
- c# - 访问使用 DBContext 创建的数据库
- facebook-graph-api - 使用 Facebook Graph API Explorer 安排图像时出现“不支持的发布请求”
- snowflake-cloud-data-platform - 如何删除雪花表中的重复项但只保留一条记录?除了使用 rownumber() 插入另一个表之外,还有其他方法吗?
- excel - 如何使我的 VBA 代码运行得更快而不使屏幕闪烁?
- r - 在 col C 的某些值下,如何根据 col A 是否大于 col B 的平均值来改变新列?
- reactjs - 如何在 Redux 中存储或访问有效负载数据
- javascript - 使用 math.random() 而不重复数组中的那些随机数
- python - 在列表列表中,我需要将每个单独列表中的元素限制为前 100 个
- javascript - 我如何申请 onInvalid 但适用于整个表格?