首页 > 解决方案 > 使用 tensorflow 数据集对象训练两个输入模型

问题描述

我正在尝试建立一个分子翻译模型,该模型将尝试根据分子的化学结构图像来预测分子的名称。因此该模型由两部分组成,CNN 部分接受输入图像,RNN 部分在训练时接受序列。我正在尝试构建一个用于训练的 Keras 数据集,因为该数据集非常大,无法存储在内存中。所以数据集有三个张量,第一个是输入图像,第二个是输入序列,然后是预测序列中下一个字符的目标序列。但是由于某种原因,在尝试使用此数据集进行训练时,会抛出一个错误,告诉您只给出了一个输入。有人可以帮我解决我的代码有什么问题。

import numpy as np
import cv2
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
data = pd.read_csv('../input/bms-molecular-translation/train_labels.csv')[:2000]

数据包含两列,一列包含图像 id,其中包含路径,另一列是地面实况预测

标记文本

image_ids = data['image_id']
inchi = data['InChI'].apply(lambda x: '\t' + x[9:] + '\n')
max_len = inchi.apply(len).max()
image_id_tensor = []
sequence_tensor = []
characters = set()
for seq in inchi:
    seq_list = list(seq)
    for char in seq_list:
        characters.add(char)
char_len = len(characters)
characters = ''.join(characters)

tokenizer = Tokenizer(filters=None, lower=False, char_level=True)

加载图像的功能。该函数接收 seq 和 label 参数,以便我可以使用数据集的 map 方法

def load_image(image_id, seq,label):
    print(image_id)
    image_id = image_id.decode('utf-8')
    image_path = f'../input/bms-molecular-translation/train/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'
    print(image_path)
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = cv2.resize(image, (224, 224))
    image = image/255.0
    return image, seq, label

序列预处理

for each, image_id in zip(inchi, image_ids):
    for i in range(1, len(each)):
        n_seq = each[:i + 1]
        sequence_tensor.append(n_seq)
        image_id_tensor.append(image_id)
tokenizer.fit_on_texts(characters)
sequence_ = [tokenizer.texts_to_sequences([x])[0] for x in sequence_tensor]
sequence_ = pad_sequences(sequence_, maxlen=max_len, padding='pre')
sequence_ = np.array(sequence_)
y_sequence = sequence_[:, -1]
y_sequence = tf.one_hot(y_sequence, char_len)
sequence_ = sequence_[:, :-1]

数据集声明

dataset = tf.data.Dataset.from_tensor_slices(((image_id_tensor, sequence_),y_sequence))
dataset = dataset.map(lambda item1, item2,item3: tf.numpy_function(load_image, [item1,items2,item3],[tf.float32,tf.int32,tf.int32]),num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

模型


# CNN Input Layers

cnn_input = tf.keras.layers.Input(shape=(224, 224, 1))
x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same',activation='relu')(cnn_input)
x = tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPool2D((2,2))(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Conv2D(128, kernel_size=3, padding='same',activation='relu')(x)
x = tf.keras.layers.Conv2D(128, kernel_size=3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPool2D((2,2))(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Conv2D(256, kernel_size=3, padding='same',activation='relu')(x)
x = tf.keras.layers.Conv2D(256, kernel_size=3, padding='same', activation='relu')(x)
x = tf.keras.layers.MaxPool2D((3,3))(x)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Flatten()(x)
cnn_output = tf.keras.layers.Dense(64,activation='relu')(x)
# RNN Input Layers
input_rnn = tf.keras.layers.Input(shape=(max_len-1))
x = tf.keras.layers.Embedding(char_len,64)(input_rnn)
x = tf.keras.layers.LSTM(64, return_sequences=True)(x,initial_state=[cnn_output,cnn_output])
x = tf.keras.layers.LSTM(64, return_sequences=False)(x)
rnn_output =tf.keras.layers.Dense(char_len ,activation='softmax')(x)

model = tf.keras.models.Model(inputs=[cnn_input, input_rnn], outputs=[rnn_output])

model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])

model.fit(dataset, epochs=50)

当我尝试运行模型时,它给出了一个错误,表明只有一个输入,并且它需要两个输入。有人可以帮我写代码。

错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-12-8aed413901b4> in <module>
----> 1 model.fit(dataset, epochs=100)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    869       # This is the first call of __call__, so we have to initialize.
    870       initializers = []
--> 871       self._initialize(args, kwds, add_initializers_to=initializers)
    872     finally:
    873       # At this point we know that the initialization is complete (or less

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    724     self._concrete_stateful_fn = (
    725         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 726             *args, **kwds))
    727 
    728     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2967       args, kwargs = None, None
   2968     with self._lock:
-> 2969       graph_function, _ = self._maybe_define_function(args, kwargs)
   2970     return graph_function
   2971 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3359 
   3360           self._function_cache.missed.add(call_context_key)
-> 3361           graph_function = self._create_graph_function(args, kwargs)
   3362           self._function_cache.primary[cache_key] = graph_function
   3363 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3204             arg_names=arg_names,
   3205             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3206             capture_by_value=self._capture_by_value),
   3207         self._function_attributes,
   3208         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    988         _, original_func = tf_decorator.unwrap(python_func)
    989 
--> 990       func_outputs = python_func(*func_args, **func_kwargs)
    991 
    992       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    632             xla_context.Exit()
    633         else:
--> 634           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    635         return out
    636 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    975           except Exception as e:  # pylint:disable=broad-except
    976             if hasattr(e, "ag_error_metadata"):
--> 977               raise e.ag_error_metadata.to_exception(e)
    978             else:
    979               raise

ValueError: in user code:

    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:795 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:788 run_step  **
        outputs = model.train_step(data)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:754 train_step
        y_pred = self(x, training=True)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:998 __call__
        input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/input_spec.py:207 assert_input_compatibility
        ' input tensors. Inputs received: ' + str(inputs))

    ValueError: Layer model expects 2 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=<unknown> dtype=float32>]

标签: pythontensorflowmachine-learningkerasdeep-learning

解决方案


推荐阅读