TensorFlow while_loop frame concatenation

Problem description

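I'm having trouble concatenating frames inside a while_loop. First, here is a version that works, but it uses a plain Python for loop. This version is slow, which is why I want to run the same code with tf.while_loop.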

import gym
import tensorflow as tf 
import matplotlib.pyplot as plt
import time

env = gym.make("Pong-v0")

def preprocess(frame):
    with tf.variable_scope('frame_process'):
        output_frame = tf.image.rgb_to_grayscale(frame)
        output_frame = tf.image.crop_to_bounding_box(output_frame, 35, 0, 160, 160)
        output_frame = tf.image.resize_images(output_frame,[80,80],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.squeeze(output_frame)

start_time = time.time()
frame = env.reset()
state = preprocess(frame)
stack = tf.stack(4 * [state], axis=2)
step_op = lambda action : env.step(action)[:3]

with tf.Session() as session:

    session.run(tf.global_variables_initializer())

    for i in range(100):
        action = tf.py_func(lambda : env.action_space.sample(),[],[tf.int64])
        frame_, reward, done = tf.py_func(step_op,[action], [tf.uint8, tf.double, tf.bool])
        state_ = preprocess(frame_)
        stack = tf.concat([stack[:,:,1:], tf.expand_dims(session.run(state_),2)], axis=2)

    print("Done in --- %s seconds ---" % (time.time() - start_time))
    stack2 = session.run(stack)

    for x in range(4):
        plt.imshow(stack2[:,:,x], cmap='gray')
        plt.show()

Note: in this version I have to call session.run(state_) before passing the result to tf.expand_dims. If I don't, I get an Illegal Instruction! and the image is corrupted. I don't know why...
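For reference, the stack update used in the loop above (drop the oldest channel, append the newest frame) can be sketched in plain NumPy. This is only an illustration of the intended result, not a diagnosis of the Illegal Instruction! crash. The helper name push_frame and the constant-valued fake frames are assumptions made for the example; they are not part of the original code.

```python
import numpy as np

def push_frame(stack, frame):
    """Drop the oldest channel of `stack` and append `frame` as the newest one."""
    # Same shape logic as tf.concat([stack[:,:,1:], tf.expand_dims(state_, 2)], axis=2)
    return np.concatenate([stack[:, :, 1:], frame[:, :, np.newaxis]], axis=2)

# Fake 80x80 frames (hypothetical stand-ins for preprocessed Pong frames).
first = np.zeros((80, 80), dtype=np.uint8)
stack = np.stack(4 * [first], axis=2)  # shape (80, 80, 4), like tf.stack(4 * [state], axis=2)

# Push six frames, each filled with its own step number.
for step in range(1, 7):
    stack = push_frame(stack, np.full((80, 80), step, dtype=np.uint8))

# Read one pixel per channel to see which frame each channel holds.
frames_in_stack = [int(stack[0, 0, c]) for c in range(4)]
print(stack.shape, frames_in_stack)  # (80, 80, 4) [3, 4, 5, 6]
```

After six pushes the stack holds the four most recent frames, oldest first, and its shape never changes. That fixed shape is what tf.while_loop expects of a loop variable by default.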

Here is my second version, using while_loop:

import tensorflow as tf
import gym
import matplotlib.pyplot as plt
import time

env = gym.make("Pong-v0")

def preprocess(frame):
    with tf.variable_scope('frame_process'):
        output_frame = tf.image.rgb_to_grayscale(frame)
        output_frame = tf.image.crop_to_bounding_box(output_frame, 35, 0, 160, 160)
        output_frame = tf.image.resize_images(output_frame,[80,80],method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.squeeze(output_frame)

def body(index,stack):
    action = tf.py_func(lambda : env.action_space.sample(),[],[tf.int64])
    frame_, reward, done = tf.py_func(step_op,[action], [tf.uint8, tf.double, tf.bool])
    state_ = preprocess(frame_)
    state_.set_shape((80,80))
    next_stack = tf.concat([stack[:,:,1:], tf.expand_dims(state_,2)], axis=2)
    return tf.add(index, 1), next_stack


start_time = time.time()

frame = env.reset()
state = preprocess(frame)
stack = tf.stack(4 * [state], axis=2)

i = tf.constant(0)
STEPS = tf.constant(100)

while_condition = lambda i, stack: tf.less(i, STEPS)
step_op = lambda action : env.step(action)[:3]
loop_result = tf.while_loop(while_condition, body, (i, stack))

with tf.Session() as session:

    session.run(tf.global_variables_initializer())
    idx, s = session.run(loop_result)

    print("Done in --- %s seconds ---" % (time.time() - start_time))


    for x in range(4):
        plt.imshow(s[:,:,x], cmap='gray')
        plt.show()

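When I run this code, I get an Illegal Instruction! and the images are incorrect. I think this is because I can't evaluate frame_ before expanding it. Is it possible to evaluate frame_ before tf.expand_dims is executed inside tf.while_loop, as in my first example?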
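One way to check whether the images coming out of the loop are correct is to compare them against a reference computed outside the graph. The sketch below approximates preprocess() in NumPy. It assumes a 210x160 RGB input frame, standard luma weights for the grayscale step, and that 2x nearest-neighbor downsampling keeps every second pixel. TensorFlow's exact rounding and sampling may differ slightly, so treat this as a rough visual check rather than a bit-exact reference. The name preprocess_reference and the random test frame are hypothetical.

```python
import numpy as np

def preprocess_reference(frame):
    """NumPy approximation of preprocess(): grayscale, crop, 2x downsample."""
    # Weighted sum over the RGB axis (luma weights are an assumption of this sketch).
    weights = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)
    gray = frame.astype(np.float32) @ weights      # shape (210, 160)
    cropped = gray[35:195, 0:160]                  # rows 35..194, all 160 columns
    small = cropped[::2, ::2]                      # keep every second pixel -> (80, 80)
    return small.astype(np.uint8)

# Synthetic frame standing in for the array returned by env.reset().
rng = np.random.default_rng(0)
fake_frame = rng.integers(0, 256, size=(210, 160, 3), dtype=np.uint8)

reference = preprocess_reference(fake_frame)
print(reference.shape, reference.dtype)
```

Plotting reference next to one channel of the stack returned by session.run can help show whether the corruption appears in preprocessing or in the concatenation step.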

Tags: python, tensorflow, machine-learning, openai-gym

Solution

