首页 > 解决方案 > 在自定义环境中运行的 Tensorflow C51 示例代码会出现形状错误

问题描述

我正在尝试从本教程中获取 Tensorflow 示例代码:https : //www.tensorflow.org/agents/tutorials/9_c51_tutorial 在使用 pychrono 引擎的自定义环境中运行。目前,一个简单的测试代码运行随机操作,但是当我尝试运行完整代码时,我得到了错误:

InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] 
or updates.shape = [], got updates.shape [1], indices.shape [1], params.shape [100000,1] [Op:ResourceScatterUpdate]

自定义环境的代码是:

import pychrono as chrono
from pychrono import irrlicht as chronoirr
import numpy as np

#from __future__ import absolute_import
#from __future__ import division
#from __future__ import print_function

#import abc
#import tensorflow as tf
#import numpy as np

from tf_agents.environments import py_environment
#from tf_agents.environments import tf_environment
#from tf_agents.environments import tf_py_environment
#from tf_agents.environments import utils
from tf_agents.specs import array_spec
#from tf_agents.environments import wrappers
#from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

from typing import Text #Any, Optional, 
chrono.SetChronoDataPath('/Users/maxwelllittle/opt/anaconda3/pkgs/pychrono-6.0.0-py37_0/share/chrono/data/')



class PyEnv(py_environment.PyEnvironment):

  def __init__(self):
      self.render = 0
      
      self._action_spec = array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
      self._observation_spec = array_spec.BoundedArraySpec(
        shape=(1,1,4), dtype=np.float64, minimum=-1000, maximum=1000, name='observation')
      self._time_step_spec = ts.TimeStep(array_spec.ArraySpec(shape=(1,), dtype = np.int32, name='step_type'), 
                                         array_spec.ArraySpec(shape=(1,), dtype = np.float32, name='reward'), 
                                         array_spec.ArraySpec(shape=(1,), dtype = np.float32, name='discount'), 
                                         self._observation_spec)
      self._episode_ended = False
      
      self.timestep = 0.01
      self.omega = 0
    # ---------------------------------------------------------------------
    #
    #  Create the simulation system and add items
    #
      
      self.rev_pend_sys = chrono.ChSystemNSC()

      chrono.ChCollisionModel.SetDefaultSuggestedEnvelope(0.001)
      chrono.ChCollisionModel.SetDefaultSuggestedMargin(0.001)

    #rev_pend_sys.SetSolverType(chrono.ChSolver.Type_BARZILAIBORWEIN) # precise, more slow
      self.rev_pend_sys.SetSolverMaxIterations(70)



    # Create a contact material (surface property)to share between all objects.
      self.rod_material = chrono.ChMaterialSurfaceNSC()
      self.rod_material.SetFriction(0.5)
      self.rod_material.SetDampingF(0.2)
      self.rod_material.SetCompliance (0.0000001)
      self.rod_material.SetComplianceT(0.0000001)



    # Create the set of rods in a vertical stack, along Y axis


      self.size_rod_y = 2.0
      self.radius_rod = 0.05
      self.density_rod = 50;    # kg/m^3

      self.mass_rod = self.density_rod * self.size_rod_y *chrono.CH_C_PI* (self.radius_rod**2);  
      self.inertia_rod_y = (self.radius_rod**2) * self.mass_rod/2;
      self.inertia_rod_x = (self.mass_rod/12)*((self.size_rod_y**2)+3*(self.radius_rod**2))
      
      self.size_table_x = 0.3;
      self.size_table_y = 0.3;
      self.size_table_z = 0.3;

      if self.render:
             
             self.myapplication = chronoirr.ChIrrApp(self.rev_pend_sys)
             self.myapplication.AddShadowAll();
             self.myapplication.SetTimestep(0.01)
             self. myapplication.SetTryRealtime(True)
             
             self.myapplication.AddTypicalSky()
             self.myapplication.AddTypicalLogo(chrono.GetChronoDataFile('logo_pychrono_alpha.png'))
             self.myapplication.AddTypicalCamera(chronoirr.vector3df(0.5,0.5,1.0))
             self.myapplication.AddLightWithShadow(chronoirr.vector3df(2,4,2),    # point
                                            chronoirr.vector3df(0,0,0),    # aimpoint
                                            9,                 # radius (power)
                                            1,9,               # near, far
                                            30)                # angle of FOV


  def _reset(self):
      
       #print("reset")
      self.isdone = False
      self.rev_pend_sys.Clear()
            # create it
      self.body_rod = chrono.ChBody()
    # set initial position
      self.body_rod.SetPos(chrono.ChVectorD(0, self.size_rod_y/2, 0 ))
    # set mass properties
      self.body_rod.SetMass(self.mass_rod)

      self.body_rod.SetInertiaXX(chrono.ChVectorD(self.inertia_rod_x,self.inertia_rod_y,self.inertia_rod_x))




    # Visualization shape, for rendering animation

      self.cyl_base1= chrono.ChVectorD(0, -self.size_rod_y/2, 0 )
      self.cyl_base2= chrono.ChVectorD(0, self.size_rod_y/2, 0 )

      self.body_rod_shape = chrono.ChCylinderShape()
      self.body_rod_shape.GetCylinderGeometry().p1= self.cyl_base1
      self.body_rod_shape.GetCylinderGeometry().p2= self.cyl_base2
      self.body_rod_shape.GetCylinderGeometry().rad= self.radius_rod

      self.body_rod.AddAsset(self.body_rod_shape)
      self.rev_pend_sys.Add(self.body_rod)


      self.body_floor = chrono.ChBody()
      self.body_floor.SetBodyFixed(True)
      self.body_floor.SetPos(chrono.ChVectorD(0, -5, 0 ))



      if self.render:
          
             self.body_floor_shape = chrono.ChBoxShape()
             self.body_floor_shape.GetBoxGeometry().Size = chrono.ChVectorD(3, 1, 3)
             self.body_floor.GetAssets().push_back(self.body_floor_shape)
             self.body_floor_texture = chrono.ChTexture()
             self.body_floor_texture.SetTextureFilename(chrono.GetChronoDataFile('textures/concrete.jpg'))
             self.body_floor.GetAssets().push_back(self.body_floor_texture)

      self.rev_pend_sys.Add(self.body_floor)



      self.body_table = chrono.ChBody()
      self.body_table.SetPos(chrono.ChVectorD(0, -self.size_table_y/2, 0 ))


      if self.render:
             self.body_table_shape = chrono.ChBoxShape()
             self.body_table_shape.GetBoxGeometry().Size = chrono.ChVectorD(self.size_table_x/2, self.size_table_y/2, self.size_table_z/2)
             self.body_table_shape.SetColor(chrono.ChColor(0.4,0.4,0.5))
             self.body_table.GetAssets().push_back(self.body_table_shape)
       
             self.body_table_texture = chrono.ChTexture()
             self.body_table_texture.SetTextureFilename(chrono.GetChronoDataFile('textures/concrete.jpg'))
             self.body_table.GetAssets().push_back(self.body_table_texture)
      self.body_table.SetMass(0.1)
      self.rev_pend_sys.Add(self.body_table)



      self.link_slider = chrono.ChLinkLockPrismatic()
      z2x = chrono.ChQuaternionD()
      z2x.Q_from_AngAxis(-chrono.CH_C_PI / 2 , chrono.ChVectorD(0, 1, 0))

      self.link_slider.Initialize(self.body_table, self.body_floor, chrono.ChCoordsysD(chrono.ChVectorD(0, 0, 0), z2x))
      self.rev_pend_sys.Add(self.link_slider)


      self.act_initpos = chrono.ChVectorD(0,0,0)
      self.actuator = chrono.ChLinkMotorLinearForce()
      self.actuator.Initialize(self.body_table, self.body_floor, chrono.ChFrameD(self.act_initpos))
      self.rev_pend_sys.Add(self.actuator)

      self.rod_pin = chrono.ChMarker()
      self.body_rod.AddMarker(self.rod_pin)
      self.rod_pin.Impose_Abs_Coord(chrono.ChCoordsysD(chrono.ChVectorD(0,0,0)))

      self.table_pin = chrono.ChMarker()
      self.body_table.AddMarker(self.table_pin)
      self.table_pin.Impose_Abs_Coord(chrono.ChCoordsysD(chrono.ChVectorD(0,0,0)))

      self.pin_joint = chrono.ChLinkLockRevolute()
      self.pin_joint.Initialize(self.rod_pin, self.table_pin)
      self.rev_pend_sys.Add(self.pin_joint)
      
      if self.render:

           # ---------------------------------------------------------------------
           #
           #  Create an Irrlicht application to visualize the system
           #
           # ==IMPORTANT!== Use this function for adding a ChIrrNodeAsset to all items
           # in the system. These ChIrrNodeAsset assets are 'proxies' to the Irrlicht meshes.
           # If you need a finer control on which item really needs a visualization proxy
           # Irrlicht, just use application.AssetBind(myitem); on a per-item basis.
       
             self.myapplication.AssetBindAll();
       
                       # ==IMPORTANT!== Use this function for 'converting' into Irrlicht meshes the assets
                       # that you added to the bodies into 3D shapes, they can be visualized by Irrlicht!
       
             self.myapplication.AssetUpdateAll();

      self.omega = self.pin_joint.GetRelWvel().Length()  
      
      self.isdone= False
      self.steps= 0
      self.timeStepVal = 0
      
      self._state = [self.link_slider.GetDist(), self.link_slider.GetDist_dt(), self.pin_joint.GetRelAngle(), self.omega]
      #return np.array(self.state)
      return ts.restart(np.array([self._state], dtype=np.float64))
  

  def _step(self, action):
       
       forceToApply = 0.0
       self.timeStepVal += 1
       if action == 1:
           forceToApply = chrono.ChFunction_Const(1)
       else:
           forceToApply = chrono.ChFunction_Const(-1)
           
       self.actuator.SetForceFunction(forceToApply)
       self.omega = self.pin_joint.GetRelWvel().Length()  
       
       if self.render:
              self.myapplication.GetDevice().run()
              self.myapplication.BeginScene()
              self.myapplication.DrawAll()
              self.myapplication.DoStep()
       else:
              self.rev_pend_sys.DoStepDynamics(self.timestep)
                  
       if self.render:
              self.myapplication.EndScene()
       
       self._state = [self.link_slider.GetDist(), self.link_slider.GetDist_dt(), self.pin_joint.GetRelAngle(), self.omega]
       
       
       
       if abs(self.link_slider.GetDist()) > 2 or abs(self.pin_joint.GetRelAngle()) >  0.2: #Went out of bounds
            #return np.array(self.state), np.array([0]), np.array([1])
            return ts.termination(np.array([self._state], dtype=np.float64), reward=0.0)
       elif self.steps > 100: #Finished
            #return np.array(self.state), np.array([1]), np.array([1])
            return ts.termination(np.array([self._state], dtype=np.float64), reward=1.0)
       else: #Not done yet
            return ts.transition(np.array([self._state], dtype=np.float64), reward=0.0, discount=1.0) #Discount = 1.0 means ignore this reward.
            #return np.array(self.state), np.array([0]), np.array([1])
        
            
  def action_spec(self):
    return self._action_spec

  def observation_spec(self):
    return self._observation_spec

  def time_step_spec(self):
      return self._time_step_spec
  
    
   
    #These two arent in py_environment.PyEnvironment, the rest are
  def close(self):
        if self.render:
            self.myapplication.GetDevice().closeDevice()
            print('Destructor called, Device deleted.')
        else:
            print('Destructor called, No device to delete.')
            
  def render(self, mode: Text = 'rgb_array'):
        try: 
            self.myapplication.SetVideoframeSave(True)
            #self.myapplication.SetVideoframeSaveInterval(interval)
              
        except:
            print('No ChIrrApp found. Cannot save video frames.')
            
    

失败的代码是:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import PIL.Image
import pyvirtualdisplay

import PyEnv

import tensorflow as tf

from tf_agents.agents.categorical_dqn import categorical_dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import categorical_q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tf.compat.v1.enable_v2_behavior()


# Set up a virtual display for rendering OpenAI gym environments.
#display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

#env_name = "CartPole-v1" # @param {type:"string"}
num_iterations = 15000 # @param {type:"integer"}

initial_collect_steps = 1000  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_capacity = 100000  # @param {type:"integer"}

fc_layer_params = (100,)

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
gamma = 0.99
log_interval = 200  # @param {type:"integer"}

num_atoms = 51  # @param {type:"integer"}
min_q_value = -20  # @param {type:"integer"}
max_q_value = 20  # @param {type:"integer"}
n_step_update = 2  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}


train_py_env = PyEnv.PyEnv()
eval_py_env = PyEnv.PyEnv()

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)



categorical_q_net = categorical_q_network.CategoricalQNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    num_atoms=num_atoms,
    fc_layer_params=fc_layer_params)


optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.compat.v2.Variable(0)

agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()


def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

compute_avg_return(eval_env, random_policy, num_eval_episodes)

# Please also see the metrics module for standard implementations of different
# metrics.


replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
  collect_step(train_env, random_policy)

错误跟踪:

runfile('/Users/maxwelllittle/opt/anaconda3/pkgs/pychrono-6.0.0-py37_0/lib/python3.7/site-packages/pychrono/Auxon1/C51LongTest.py', wdir='/Users/maxwelllittle/opt/anaconda3/pkgs/pychrono-6.0.0-py37_0/lib/python3.7/site-packages/pychrono/Auxon1')
Reloaded modules: PyEnv
Traceback (most recent call last):

  File "/Users/maxwelllittle/opt/anaconda3/pkgs/pychrono-6.0.0-py37_0/lib/python3.7/site-packages/pychrono/Auxon1/C51LongTest.py", line 143, in <module>
    collect_step(train_env, random_policy)

  File "/Users/maxwelllittle/opt/anaconda3/pkgs/pychrono-6.0.0-py37_0/lib/python3.7/site-packages/pychrono/Auxon1/C51LongTest.py", line 140, in collect_step
    replay_buffer.add_batch(traj)

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tf_agents/replay_buffers/replay_buffer.py", line 83, in add_batch
    return self._add_batch(items)

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tf_agents/replay_buffers/tf_uniform_replay_buffer.py", line 205, in _add_batch
    write_data_op = self._data_table.write(write_rows, items)

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tf_agents/replay_buffers/table.py", line 131, in write
    for (slot, value) in zip(flattened_slots, flattened_values)

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tf_agents/replay_buffers/table.py", line 131, in <listcomp>
    for (slot, value) in zip(flattened_slots, flattened_values)

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tensorflow/python/ops/state_ops.py", line 306, in scatter_update
    name=name))

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tensorflow/python/ops/gen_resource_variable_ops.py", line 1120, in resource_scatter_update
    _ops.raise_from_not_ok_status(e, name)

  File "/Users/maxwelllittle/opt/anaconda3/envs/snubh/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 6862, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)

  File "<string>", line 3, in raise_from

InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [1], indices.shape [1], params.shape [100000,1] [Op:ResourceScatterUpdate]

标签: pythontensorflowmachine-learning

解决方案


推荐阅读