python - TensorFlow 2.0 教程问题
问题描述
我正在关注https://www.tensorflow.org/alpha/tutorials/sequences/text_classification_rnn上的官方教程,但遇到了问题。以下行导致错误:
train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)
回溯(最后一次调用):文件“main.py”,第 30 行,在 train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes) AttributeError: 'ShuffleDataset' object has no attribute 'output_shapes
我错过了什么?这是我完成一半的代码:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow_datasets as tfds
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.keras
def plot_graphs(history, string):
plt.plot(history.history[string])
plt.plot(history.history['val_'+string])
plt.xlabel("Epochs")
plt.ylabel(string)
plt.legend([string, 'val_'+string])
plt.show()
dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,
as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
tokenizer = info.features['text'].encoder
print ('Vocabulary size: {}'.format(tokenizer.vocab_size))
# sample_string = 'TensorFlow is cool.'
# tokenized_string = tokenizer.encode(sample_string)
# print ('Tokenized string is {}'.format(tokenized_string))
# original_string = tokenizer.decode(tokenized_string)
# print ('The original string: {}'.format(original_string))
# assert original_string == sample_string
# for ts in tokenized_string:
# print ('{} ----> {}'.format(ts, tokenizer.decode([ts])))
BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)
test_dataset = test_dataset.padded_batch(BATCH_SIZE, test_dataset.output_shapes)
解决方案
请更换
train_dataset = train_dataset.padded_batch(BATCH_SIZE, train_dataset.output_shapes)
和
train_dataset = train_dataset.padded_batch(BATCH_SIZE, tf.compat.v1.data.get_output_shapes(train_dataset))
并替换
test_dataset = test_dataset.padded_batch(BATCH_SIZE, test_dataset.output_shapes)
和
test_dataset = test_dataset.padded_batch(BATCH_SIZE, tf.compat.v1.data.get_output_shapes(test_dataset))
推荐阅读
- parsing - Dart 语言是如何解析的(从左到右,接收等)?
- spring-security-oauth2 - OAuth2 身份验证 Web 服务调用后端客户端
- javascript - 使用静态类属性时 Node.js 的巨大性能问题
- ionic-framework - Ionic 4 页面导航
- google-cloud-platform - GCP和GSC的多个Google-site-verification元标记?
- c# - 如何将相关表添加到不相关表?
- google-apps-script - 注销后清除导航堆栈 - Gmail 插件
- sql - PostgreSQL: CAST() as money: 指定货币
- c++ - 检测到您的柯南配置文件设置和 CMake 之间的编译器版本不匹配
- python - 如何在 python 中打印完整的相关输出数组?另外,如果你能告诉我如何解释它们,那就太好了/