tensorflow - TensorFlow 中的文本数据增强
问题描述
我正在对 tensorflow 中的 IMDB 数据集进行情感分析,并尝试通过使用他们所说的“即插即用”到 tensorflow的textaugment 库来扩充训练数据集。所以它应该很简单,但我是 tf 的新手,所以我不知道该怎么做。根据阅读网站上的教程,这是我所拥有的和正在尝试的。
我试图做一个地图来增加训练数据,但我得到了一个错误。您可以向下滚动到最后一个代码块以查看错误。
pip install -q tensorflow-text
pip install -q tf-models-official
import os
import shutil
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization # to create AdamW Optimizer
import matplotlib.pyplot as plt
tf.get_logger().setLevel('ERROR')
#下载IMDB数据集并制作训练/验证/测试集
url = 'https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
dataset = tf.keras.utils.get_file('aclImdb_v1.tar.gz', url,
untar=True, cache_dir='.',
cache_subdir='')
dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')
train_dir = os.path.join(dataset_dir, 'train')
# remove unused folders to make it easier to load the data
remove_dir = os.path.join(train_dir, 'unsup')
shutil.rmtree(remove_dir)
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32
seed = 42
raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
'aclImdb/train',
batch_size=batch_size,
validation_split=0.2,
subset='training',
seed=seed)
class_names = raw_train_ds.class_names
train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = tf.keras.preprocessing.text_dataset_from_directory(
'aclImdb/train',
batch_size=batch_size,
validation_split=0.2,
subset='validation',
seed=seed)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = tf.keras.preprocessing.text_dataset_from_directory(
'aclImdb/test',
batch_size=batch_size)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)
#setting up the textaugment
try:
import textaugment
except ModuleNotFoundError:
!pip install textaugment
import textaugment
from textaugment import EDA
import nltk
nltk.download('stopwords')
现在这是我得到错误的地方,我在 train_ds 上尝试了一个地图,并尝试在保持类相同的同时向每个元素添加随机交换:
aug_ds = train_ds.map(
lambda x, y: (t.random_swap(x), y))
错误信息:
AttributeError Traceback (most recent call last)
<ipython-input-24-b4af68cc0677> in <module>()
1 aug_ds = train_ds.map(
----> 2 lambda x, y: (t.random_swap(x), y))
10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
668 except Exception as e: # pylint:disable=broad-except
669 if hasattr(e, 'ag_error_metadata'):
--> 670 raise e.ag_error_metadata.to_exception(e)
671 else:
672 raise
AttributeError: in user code:
<ipython-input-24-b4af68cc0677>:2 None *
lambda x, y: (t.random_swap(x), y))
/usr/local/lib/python3.6/dist-packages/textaugment/eda.py:187 random_swap *
self.validate(sentence=sentence, n=n)
/usr/local/lib/python3.6/dist-packages/textaugment/eda.py:74 validate *
if not isinstance(kwargs['sentence'].strip(), str) or len(kwargs['sentence'].strip()) == 0:
AttributeError: 'Tensor' object has no attribute 'strip'
解决方案
我也在尝试做同样的事情。发生错误是因为 textaugment 函数t.random_swap()
应该适用于Python 字符串对象。
在您的代码中,该函数采用 dtype=string 的张量。截至目前,张量对象没有与 Python 字符串相同的方法。因此,错误代码。
NB。tensorflow_text有一些额外的 API 可以处理这种字符串类型的张量。尽管目前它仅限于标记化、检查大小写等。一个冗长的解决方法是使用py_function
包装器,但这会降低性能。干杯,希望这会有所帮助。在我的用例中,我最终选择不使用 textaugment。
NB。tf.strings API具有更多功能,例如正则表达式替换等,但对于您的增强用例来说还不够复杂。看看其他人想出了什么,或者未来是否有 TF 或 textaugment 的更新会很有帮助。
推荐阅读
- r - 调整ggplot2中的第二个y轴
- c# - DateTime.ParseExact 工作日作为文本抛出 FormatException
- python - 多列的分组并通过考虑每个列的开始和结束为每个列分配值(熊猫)
- python - 为什么 Binary_Cross-entropy Loss 给出负值?
- terraform-provider-azure - 如何使用 terraform 将虚拟网络添加到 api 管理?
- javascript - 转换嵌套对象中的所有值
- sql - 为 azure 突触分析创建外部表的问题
- java - installCert.java 中的 SavingTrustManager 函数什么时候调用?
- javascript - Javascript RegExp 未知的重复匹配
- php - 试图从json数组将列数据插入mysql