python - Tensorflow random segmentation fault
Problem description
I'm trying to run the demo code from the official TensorFlow website. For convenience, I've attached the full code here (copied and tidied up):
```python
import tensorflow as tf
# print("1")
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import time
import os
# print("2")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# @tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

# @tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * 64))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
print("end")
```
With TensorFlow 2.3.1, this code hits a segmentation fault right at the start, for no apparent reason:
```
>python dummy.py
2021-03-11 17:45:52.231509: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
Segmentation fault (core dumped)
```
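Not part of the original question: because the process dies right after the CUDA runtime library is loaded, one way to narrow things down is to run each import on its own in a fresh interpreter and check how it exits. A minimal sketch, assuming Linux/POSIX (where a negative return code `-N` means the child was killed by signal `N`) and Python 3.7+ for `capture_output`; `run_isolated` is a hypothetical helper name:

```python
import signal
import subprocess
import sys
from typing import Tuple


def run_isolated(snippet: str, timeout: float = 300.0) -> Tuple[str, str]:
    """Run `snippet` in a fresh Python interpreter and classify how it exited.

    Returns (status, stderr); status is "ok", "segfault" or
    "error (exit code N)", where N is negative for other fatal signals.
    """
    proc = subprocess.run(
        # -X faulthandler: the child prints a Python traceback on fatal signals.
        [sys.executable, "-X", "faulthandler", "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # On POSIX, returncode == -N means the child was terminated by signal N.
    if proc.returncode == -signal.SIGSEGV:
        return "segfault", proc.stderr
    if proc.returncode != 0:
        return "error (exit code %d)" % proc.returncode, proc.stderr
    return "ok", proc.stderr


if __name__ == "__main__":
    # Check the imports from dummy.py one at a time.
    for snippet in (
        "import numpy",
        "import tensorflow as tf",
        "from tensorflow import keras",
    ):
        status, stderr = run_isolated(snippet)
        print("%-30s -> %s" % (snippet, status))
        if status == "segfault":
            print(stderr)
```

Running each snippet in a child process keeps a crash from taking down the checking script itself, so all results are reported in one run.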
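Interestingly, if I put some random print statements at the very beginning (the print("1") etc. statements), the code runs all the way to the end and only then hits a segmentation fault (redundant output not shown):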
```
Start of epoch 1
Training loss (for one batch) at step 0: 1.0215
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.9116
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.4894
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.5636
Seen so far: 38464 samples
Training acc over epoch: 0.8416
Validation acc: 0.8296
Time taken: 3.16s
end
Segmentation fault (core dumped)
```
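Not part of the original question: since the crash point shifts when unrelated print statements are added, a Python-level stack dump at the moment of the fault may show whether it happens inside TensorFlow calls or during interpreter shutdown. A minimal sketch using the standard-library `faulthandler` module, placed at the very top of `dummy.py`:

```python
import faulthandler
import sys

# Install handlers for SIGSEGV, SIGFPE, SIGABRT, SIGBUS and SIGILL as early as
# possible, before tensorflow is imported, so a crash during import, training or
# interpreter shutdown still dumps the Python stack of every thread to stderr.
faulthandler.enable(file=sys.stderr, all_threads=True)

# The rest of dummy.py (imports, model definition, training loop) would follow here.
```

The same effect is available without editing the script by running `python -X faulthandler dummy.py` or setting `PYTHONFAULTHANDLER=1`. If the fault is in native code with no Python frame active, the dump will say so rather than point at a line.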
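Another observation: if I uncomment the @tf.function decorators on top of my train_step and test_step functions, the code segfaults again, but only after printing: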
```
trainStep
testStep
Start of epoch 0
```
Can someone explain what is wrong with my TensorFlow package?
Solution
This was caused by an old version of Ubuntu. I was running Ubuntu 14; after upgrading to Ubuntu 18 the problem went away.
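The answer does not say which part of the older system caused the crash. When reporting or comparing a similar crash, it can help to record the OS, C library and package versions. A minimal sketch, assuming Linux (it reads `/etc/os-release`) and Python 3.8+ for `importlib.metadata`; `environment_report` is a hypothetical helper, and it deliberately avoids importing tensorflow, since on an affected system the import itself may crash:

```python
import platform
import sys

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:
    version = None  # older Python: package versions are reported as "unknown"


def environment_report() -> dict:
    """Collect OS and runtime details relevant to crashes in native libraries."""
    libc_name, libc_version = platform.libc_ver()
    report = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "libc": "%s %s" % (libc_name, libc_version) if libc_name else "unknown",
        "distro": "unknown",
        "tensorflow": "unknown" if version is None else "not installed",
    }
    # /etc/os-release is present on most current Linux distributions (assumption).
    try:
        with open("/etc/os-release") as fh:
            for line in fh:
                if line.startswith("PRETTY_NAME="):
                    report["distro"] = line.split("=", 1)[1].strip().strip('"')
    except OSError:
        pass
    # Read the installed version from package metadata instead of importing it.
    if version is not None:
        for dist in ("tensorflow", "tensorflow-gpu"):
            try:
                report["tensorflow"] = "%s %s" % (dist, version(dist))
                break
            except PackageNotFoundError:
                continue
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print("%-10s %s" % (key, value))
```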