首页 > 解决方案 > 使用 tf.data 将字符串转换为 csv 中的浮点数组

问题描述

我有一个这样的csv:

kw_text,kw_text_weight
amazon google,0.5 0.5
google facebook microsoft,0.5 0.3 0.2
kw_text kw_text_weight
亚马逊谷歌 0.5 0.5
谷歌 脸书 微软 0.5 0.3 0.2

我想将列转换text_weighttf.data. 但我在 tensorflow 文档网站上一无所获。

标签: tensorflowtensorflow2.0

解决方案


我相信这就是你想要的:

import pandas as pd
import tensorflow as tf

d = {"kw_text": [['amazon', 'google'], ['google', 'facebook', 'microsoft']], 
     "kw_text_weight": [['0.5', '0.5'], ['0.5', '0.3', '0.2']]}

df = pd.DataFrame(d)

# Convert string to float
for i in range(len(df.index)):
    df['kw_text_weight'][i] = [float(s) for s in df['kw_text_weight'][i]]

# Build dataset
rt=tf.ragged.constant(df['kw_text_weight'].tolist())
kw_text_weight_data = tf.data.Dataset.from_tensor_slices(rt)

for feature_batch in kw_text_weight_data:
    print(feature_batch)

输出:

tf.Tensor([0.5 0.5], shape=(2,), dtype=float32)
tf.Tensor([0.5 0.3 0.2], shape=(3,), dtype=float32)

推荐阅读