Python error "Cannot cast array data from dtype('O') to dtype('float32') according to the rule 'same_kind'"

Problem description

import numpy as np 
import gym
from gym.spaces import Box

import ray
from ray import tune
from ray.rllib.utils import try_import_tf

import ray.rllib.agents.ddpg as ddpg
from ray.tune.logger import pretty_print

tf = try_import_tf()

# gym environment adapter
class SimpleSupplyChain(gym.Env):
    def __init__(self, config):
        self.reset()
        self.action_space = Box(low=0.0, high=1000.0, shape=(self.supply_chain.retail_store_num + 1, ), dtype=np.float32)
        self.observation_space = Box(-1000000.0, 10000000, shape=(len(self.supply_chain.initial_state().to_array()), ), dtype=np.float32)

    def reset(self):
        self.supply_chain = SupplyChainEnvironment()
        self.state = self.supply_chain.initial_state()
        return self.state.to_array()

    def step(self, action):
        action_obj = Action(self.supply_chain.retail_store_num)
        action_obj.production_level = action[0]
        action_obj.shippings_to_retail_stores = action[1:]
        self.state, reward, done = self.supply_chain.step(self.state, action_obj)
        return self.state.to_array(), reward, done, {}
    
ray.shutdown()
ray.init()

def train_ddpg():
    config = ddpg.DEFAULT_CONFIG.copy()
    config["log_level"] = "WARN"
    config["actor_hiddens"] = [512, 512] 
    config["critic_hiddens"] = [512, 512]
    config["gamma"] = 0.95
    config["timesteps_per_iteration"] = 1000
    config["target_network_update_freq"] = 5
    config["buffer_size"] = 10000
    # config['actor_lr']=1e-6
    # config['critic_lr']=1e-6
    print(config)

    import json

    try:
        import cPickle as pickle
    except ImportError:  # Python 3.x
        import pickle

    with open('config.p', 'wb') as fp:
        pickle.dump(config, fp, protocol=pickle.HIGHEST_PROTOCOL)
    
    trainer = ddpg.DDPGTrainer(config=config, env=SimpleSupplyChain)
    for i in range(5):
        result = trainer.train()
        print(pretty_print(result))
        checkpoint = trainer.save('./')
        print("Checkpoint saved at", checkpoint)

train_ddpg()

When I run the code above, I get the following error:

Cannot cast array data from dtype('O') to dtype('float32') according to the rule 'same_kind'

I am using DDPG to search for a solution, and there seems to be a problem with a data type somewhere. The data comes from pandas, and I have already converted it to float64. I have checked and tried everything I can think of, but the error keeps appearing. Can anyone help me?
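
For reference, dtype('O') is NumPy's object dtype, and the message can be reproduced with plain NumPy, independent of RLlib (a small sketch just to show where the cast fails):

import numpy as np

# An object-dtype array, e.g. what NumPy builds when the input holds
# arbitrary Python objects instead of plain floats.
obj_state = np.array([1.0, 2.0, 3.0], dtype=object)

buf = np.zeros(3, dtype=np.float32)
# np.copyto uses the "same_kind" casting rule by default, and object
# arrays cannot be cast to float32 under that rule, so this fails.
np.copyto(buf, obj_state)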

Tags: python, pandas, numpy, types

Solution
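
dtype('O') is the NumPy object dtype, so the observation that reaches the DDPG trainer is an array of Python objects rather than plain numbers. This usually happens when the state vector is assembled from pandas values and NumPy silently falls back to dtype=object, which RLlib then cannot cast into the float32 observation space. One common fix is to cast explicitly inside the gym wrapper; the sketch below assumes the state really is numeric and reuses the SupplyChainEnvironment and Action classes from the question (the helper name _as_float32 is only for illustration):

import numpy as np
import gym
from gym.spaces import Box

def _as_float32(x):
    # np.asarray with an explicit dtype converts numeric Python objects
    # (e.g. values coming out of pandas) and fails loudly on anything
    # genuinely non-numeric, which makes the offending field easy to find.
    return np.asarray(x, dtype=np.float32)

class SimpleSupplyChain(gym.Env):
    def __init__(self, config):
        self.reset()
        self.action_space = Box(low=0.0, high=1000.0,
                                shape=(self.supply_chain.retail_store_num + 1,),
                                dtype=np.float32)
        self.observation_space = Box(-1000000.0, 10000000.0,
                                     shape=(len(self.supply_chain.initial_state().to_array()),),
                                     dtype=np.float32)

    def reset(self):
        self.supply_chain = SupplyChainEnvironment()
        self.state = self.supply_chain.initial_state()
        # Cast the state vector before handing it to RLlib.
        return _as_float32(self.state.to_array())

    def step(self, action):
        action_obj = Action(self.supply_chain.retail_store_num)
        action_obj.production_level = float(action[0])
        action_obj.shippings_to_retail_stores = np.asarray(action[1:], dtype=np.float32)
        self.state, reward, done = self.supply_chain.step(self.state, action_obj)
        return _as_float32(self.state.to_array()), float(reward), bool(done), {}

If _as_float32 itself raises, the traceback will point at the element that is not a plain number; the real fix then belongs in to_array() or in the pandas code that feeds it (for example series.to_numpy(dtype=np.float32) rather than passing object columns through).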

