首页 > 解决方案 > Cartpole ValueError 上的 CEM:输入必须是 1 维或 2 维

问题描述

希望大家都好。我正在使用交叉熵方法制作推车杆,但是当我遇到这个错误时我很困惑。

    def sampleAgents(self):
        self.paramSize = 4
        self.nPop = 100
        self.mu = np.zeros(self.paramSize)
        self.cov = np.ones(self.paramSize)
        
        #Sample parameters from Gaussian dist. (with diagonal cov. matrix) using "self.mu" and "self.cov"
        samp = np.random.multivariate_normal(self.mu, np.diag(self.cov), self.nPop)
        
        #Assign samples to "self.paramSet". (self.paramSet.shape = (self.nPop, self.paramSize))
        self.paramSet = samp

当我运行这个时,我得到了一个错误ValueError: Input must be 1- or 2-d. 但是当我试图最后添加print(self.paramSet)时(看起来像这样)

    def sampleAgents(self):
        self.paramSize = 4
        self.nPop = 100
        self.mu = np.zeros(self.paramSize)
        self.cov = np.ones(self.paramSize)
        
        #Sample parameters from Gaussian dist. (with diagonal cov. matrix) using "self.mu" and "self.cov"
        samp = np.random.multivariate_normal(self.mu, np.diag(self.cov), self.nPop)
        
        #Assign samples to "self.paramSet". (self.paramSet.shape = (self.nPop, self.paramSize))
        self.paramSet = samp
        print(self.paramSet) 

并在该 print(self.paramSet) 中放置断点并进行调试,它工作正常。甚至他们也显示了我想要的大小、暗淡和价值。

有人可以帮我在哪里修复代码吗?提前致谢!

标签: pythonnumpyreinforcement-learningcross-entropy

解决方案


推荐阅读