首页 > 解决方案 > 通过 multiprocessing 模块传递包含 keras 模型的类对象时出现 pickle 错误

问题描述

我正在尝试使用 multiprocessing 运行一个并行处理应用程序,其中我传递了一个类对象,该类包含一个使用 keras 构建的神经网络模型。但是,当通过 multiprocessing 模块的 starmap 方法传递该对象时,出现了 pickle(序列化)错误。下面提供了一个玩具示例,其中智能体(agent)在 cartpole 环境中并行运行 10 个回合(episode):

import itertools
from multiprocessing import Pool

import gym
import numpy as np
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

class Agent:
    """Policy backed by a small fully connected Keras model.

    The network is one ``Input`` of width ``input_dim``, one ReLU ``Dense``
    layer per entry of ``hidden_dims``, and a final softmax ``Dense`` layer
    of width ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        """Store the layer sizes and build the model.

        Args:
            input_dim: Width of the input vector.
            hidden_dims: Iterable of hidden-layer widths, applied in order.
            output_dim: Width of the softmax output layer.
        """
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim

        # Functional API: each layer instance is called on the previous
        # tensor and returns the next one.
        model_input = Input(shape=(self.input_dim,))
        tensor = model_input
        for width in self.hidden_dims:
            hidden_layer = Dense(width, activation='relu')
            tensor = hidden_layer(tensor)

        output_layer = Dense(self.output_dim, activation='softmax')
        probabilities = output_layer(tensor)

        # The resulting model has len(hidden_dims) hidden Dense layers plus
        # the softmax output layer.
        self.model = Model(inputs=model_input, outputs=probabilities)

    def act(self, state):
        """Return the index of the largest value in the model's prediction.

        Args:
            state: Observation holding ``input_dim`` values; it is reshaped
                to a single-row batch before prediction.
        """
        batch = np.reshape(state, [1, self.input_dim])
        scores = self.model.predict(batch)
        return np.argmax(scores)

# Module-level CartPole environment; run_agent() steps this shared object.
env = gym.make("CartPole-v1")
# First dimension of the observation shape; used as the Agent's input width.
observation_space = env.observation_space.shape[0]
# Number of discrete actions; used as the Agent's output width.
action_space = env.action_space.n

def run_agent(num_episode, agent):
    """Play one episode on the module-level ``env`` and return its total reward.

    Args:
        num_episode: Episode index supplied by the caller; it is not used
            inside the function body.
        agent: Object whose ``act(state)`` returns the action to take.

    Returns:
        The sum of the per-step rewards, where the reward of the terminal
        step is negated before being added.
    """
    state = env.reset()
    total_reward = 0
    terminal = False
    while not terminal:
        env.render()
        action = agent.act(state)
        next_observation, reward, terminal, info = env.step(action)
        if terminal:
            # The step that ends the episode counts as a penalty.
            reward = -reward
        total_reward += reward
        state = np.reshape(next_observation, [1, observation_space])
    return total_reward

def run_parallel(agent, num_episodes=10, processes=4):
    """Run episodes of ``run_agent`` in a process pool and print the rewards.

    Args:
        agent: Object passed to every ``run_agent`` call. ``Pool.starmap``
            pickles each argument tuple to hand it to a worker process, so
            the agent must be picklable.
        num_episodes: Number of episodes to schedule (default 10).
        processes: Number of worker processes in the pool (default 4).
    """
    # One (episode_index, agent) argument tuple per scheduled episode.
    args_to_func = [(episode, agent) for episode in range(num_episodes)]

    with Pool(processes=processes) as pool:
        reward_agent = pool.starmap(run_agent, args_to_func)
        # starmap has already returned every result; close/join let the
        # workers finish before the context manager tears the pool down.
        pool.close()
        pool.join()
    print(reward_agent)

if __name__ == "__main__":
    # Build an agent with a single 32-unit hidden layer sized to the
    # environment, then run it across the worker pool.
    run_parallel(
        Agent(
            input_dim=observation_space,
            hidden_dims=[32],
            output_dim=action_space,
        )
    )

错误如下:

Traceback (most recent call last):
  File "example_ev_alg.py", line 63, in <module>
    run_parallel(agent)
  File "example_ev_alg.py", line 56, in run_parallel
    reward_agent = pool.starmap(run_agent, args_to_func)
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 274, in starmap
    return self._map_async(func, iterable, starmapstar, chunksize).get()
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 424, in _handle_tasks
    put(task)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "/usr/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: can't pickle _thread.RLock objects

标签: python-3.x, keras, pickle, python-multiprocessing

解决方案


推荐阅读