python - 使用 tensorflow 数据集对象训练两个输入模型
问题描述
我正在尝试建立一个分子翻译模型,该模型将尝试根据分子的化学结构图像来预测分子的名称。因此该模型由两部分组成,CNN 部分接受输入图像,RNN 部分在训练时接受序列。我正在尝试构建一个用于训练的 Keras 数据集,因为该数据集非常大,无法存储在内存中。所以数据集有三个张量,第一个是输入图像,第二个是输入序列,然后是预测序列中下一个字符的目标序列。但是由于某种原因,在尝试使用此数据集进行训练时,会抛出一个错误,告诉您只给出了一个输入。有人可以帮我解决我的代码有什么问题。
import numpy as np
import cv2
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
data = pd.read_csv('../input/bms-molecular-translation/train_labels.csv')[:2000]
数据包含两列,一列包含图像 id,其中包含路径,另一列是地面实况预测
标记文本
image_ids = data['image_id']
inchi = data['InChI'].apply(lambda x: '\t' + x[9:] + '\n')
max_len = inchi.apply(len).max()
image_id_tensor = []
sequence_tensor = []
characters = set()
for seq in inchi:
seq_list = list(seq)
for char in seq_list:
characters.add(char)
char_len = len(characters)
characters = ''.join(characters)
tokenizer = Tokenizer(filters=None, lower=False, char_level=True)
加载图像的功能。该函数接收 seq 和 label 参数,以便我可以使用数据集的 map 方法
def load_image(image_id, seq,label):
print(image_id)
image_id = image_id.decode('utf-8')
image_path = f'../input/bms-molecular-translation/train/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'
print(image_path)
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = cv2.resize(image, (224, 224))
image = image/255.0
return image, seq, label
序列预处理
for each, image_id in zip(inchi, image_ids):
for i in range(1, len(each)):
n_seq = each[:i + 1]
sequence_tensor.append(n_seq)
image_id_tensor.append(image_id)
tokenizer.fit_on_texts(characters)
sequence_ = [tokenizer.texts_to_sequences([x])[0] for x in sequence_tensor]
sequence_ = pad_sequences(sequence_, maxlen=max_len, padding='pre')
sequence_ = np.array(sequence_)
y_sequence = sequence_[:, -1]
y_sequence = tf.one_hot(y_sequence, char_len)
sequence_ = sequence_[:, :-1]
数据集声明
dataset = tf.data.Dataset.from_tensor_slices(((image_id_tensor, sequence_),y_sequence))
dataset = dataset.map(lambda item1, item2,item3: tf.numpy_function(load_image, [item1,items2,item3],[tf.float32,tf.int32,tf.int32]),num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
模型
# CNN Input Layers
cnn_input = tf.keras.layers.Input(shape=(224, 224, 1))
x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same',activation='relu')(cnn_input)
x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPool2D((2,2))(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Conv2D(128, kernel_size=3, padding='same',activation='relu')(x)
x = tf.keras.layers.Conv2D(128, kernel_size=3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPool2D((2,2))(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Conv2D(256, kernel_size=3, padding='same',activation='relu')(x)
x = tf.keras.layers.Conv2D(256, kernel_size=3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPool2D((3,3))(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Flatten()(x)
cnn_output = tf.keras.layers.Dense(64,activation='relu')(x)
# RNN Input Layers
input_rnn = tf.keras.layers.Input(shape=(max_len-1))
x = tf.keras.layers.Embedding(char_len,64)(input_rnn)
x = tf.keras.layers.LSTM(64, return_sequences=True)(x,initial_state=[cnn_output,cnn_output])
x = tf.keras.layers.LSTM(64, return_sequences=False)(x)
rnn_output =tf.keras.layers.Dense(char_len ,activation='softmax')(x)
model = tf.keras.models.Model(inputs=[cnn_input, input_rnn], outputs=[rnn_output])
model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
model.fit(dataset, epochs=50)
当我尝试运行模型时,它给出了一个错误,表明只有一个输入,并且它需要两个输入。有人可以帮我写代码。
错误
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-12-8aed413901b4> in <module>
----> 1 model.fit(dataset, epochs=100)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
1098 _r=1):
1099 callbacks.on_train_batch_begin(step)
-> 1100 tmp_logs = self.train_function(iterator)
1101 if data_handler.should_sync:
1102 context.async_wait()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
826 tracing_count = self.experimental_get_tracing_count()
827 with trace.Trace(self._name) as tm:
--> 828 result = self._call(*args, **kwds)
829 compiler = "xla" if self._experimental_compile else "nonXla"
830 new_tracing_count = self.experimental_get_tracing_count()
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
869 # This is the first call of __call__, so we have to initialize.
870 initializers = []
--> 871 self._initialize(args, kwds, add_initializers_to=initializers)
872 finally:
873 # At this point we know that the initialization is complete (or less
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
724 self._concrete_stateful_fn = (
725 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
--> 726 *args, **kwds))
727
728 def invalid_creator_scope(*unused_args, **unused_kwds):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
2967 args, kwargs = None, None
2968 with self._lock:
-> 2969 graph_function, _ = self._maybe_define_function(args, kwargs)
2970 return graph_function
2971
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3359
3360 self._function_cache.missed.add(call_context_key)
-> 3361 graph_function = self._create_graph_function(args, kwargs)
3362 self._function_cache.primary[cache_key] = graph_function
3363
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3204 arg_names=arg_names,
3205 override_flat_arg_shapes=override_flat_arg_shapes,
-> 3206 capture_by_value=self._capture_by_value),
3207 self._function_attributes,
3208 function_spec=self.function_spec,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
988 _, original_func = tf_decorator.unwrap(python_func)
989
--> 990 func_outputs = python_func(*func_args, **func_kwargs)
991
992 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
632 xla_context.Exit()
633 else:
--> 634 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
635 return out
636
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
975 except Exception as e: # pylint:disable=broad-except
976 if hasattr(e, "ag_error_metadata"):
--> 977 raise e.ag_error_metadata.to_exception(e)
978 else:
979 raise
ValueError: in user code:
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function *
return step_function(self, iterator)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:795 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
return fn(*args, **kwargs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:788 run_step **
outputs = model.train_step(data)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:754 train_step
y_pred = self(x, training=True)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:998 __call__
input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py:207 assert_input_compatibility
' input tensors. Inputs received: ' + str(inputs))
ValueError: Layer model expects 2 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=<unknown> dtype=float32>]
解决方案
推荐阅读
- c++ - 为什么 C++ 位域要求我指定类型?
- javascript - 搜索符号时 List.js 不起作用
- ruby-on-rails - 使用 Rails Active Storage 创建对象时在 AWS S3 上智能地命名对象
- c++ - Boost::GIL:使用 alpha 通道读取 *.png 图像缺少抗锯齿
- php - Onesignal Rest API - Laravel
- python - Python evdev 库是否具有特定于事件的抓取或直通?
- python - 如何在熊猫中使用 groupby 或 pivot_table
- java - 一种使用 Java 中的堆栈返回新的反向单链表的方法,保留相同的元素但以相反的顺序打印出来
- javascript - 是否有可能在不使用加载调用的情况下将表单绑定到模式引导窗口?
- wagtail - 我可以在 Wagtail 中使用递归 ParentalKey 吗?