首页 > 解决方案 > Snake 的遗传算法不收敛

问题描述

我正在尝试用遗传算法训练 AI 玩蛇。我正在使用 Python 库 NEAT 进行培训。问题是训练不收敛,人工智能不学习。这是训练代码:

class SnakeEnv():

def __init__(self, screen):
    self.action_space = np.array([0, 1, 2, 3])
    self.state = None
    pygame.init()
    self.screen = screen
    self.snakes = [] 
    self.total_reward = 0

def reset(self):
    self.__init__()

    
def get_state(self):
    return np.reshape(self.snake.board, (400, 1)).T / 5

def render(self, snake):
    self.screen.fill((0, 0, 0))
    snake.food.render()
    snake.render()
    pygame.display.flip()

def step(self, snake, action):
    snake.move(action)
    self.render(snake)

def close(self):
    pygame.quit()


def eval_genomes(self, genomes, config):
    global nets_g
    nets_g = []
    nets = []
    snakes = []
    global ge_g
    ge_g = []
    ge = []
    for genome_id, genome in genomes:
        genome.fitness = 0
        net = neat.nn.FeedForwardNetwork.create(genome, config)
        nets.append(net)
        snakes.append(Snake(self.screen))
        ge.append(genome)
    
    ge_g = ge.copy()
    nets_g = nets.copy()
    run = True
    #Main loop
    while run:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False
                pygame.quit()
                quit()
                break

        for x, snake in enumerate(snakes):
            if(snake.done):
                continue
            ge[x].fitness += 0.1

            """
            Inputs to the neural net:
            Vertical distance from food to head
            Horizontal distance from food to head
            Vertical distance to nearest wall from head
            Horizontal distance to nearest wall from head
            Distance from head to body segment (default -1)
            """

            snake_x = snake.head.x
            snake_y = snake.head.y
            food_x = snake.food.x 
            food_y = snake.food.y 

            food_vert = snake_y - food_y
            food_horz = snake_x - food_x
            wall_vert = min(snake_y, 600 - snake_y)
            wall_horz = min(snake_x, 600 - snake_x)
            body_front = snake.body_front()
            output = np.argmax(nets[snakes.index(snake)].activate((food_vert, food_horz, wall_vert, wall_horz, body_front)))
            state = snake.move(output)
            if state["Food"] == True:
                ge[snakes.index(snake)].fitness += 1

            if state["Died"] == True:
                ge[snakes.index(snake)].fitness -= 1
                #nets.pop(snakes.index(snake))
                #ge.pop(snakes.index(snake))
                #snakes.pop(snakes.index(snake))
            all_done = [snake.done for snake in snakes]
            if(False not in all_done):
                run = False


def run(self, config_file):
    config = neat.config.Config(neat.DefaultGenome, neat.DefaultReproduction, neat.DefaultSpeciesSet, neat.DefaultStagnation, config_file)
    population = neat.Population(config)
    population.add_reporter(neat.StdOutReporter(True))
    stats = neat.StatisticsReporter()
    population.add_reporter(stats)
    best = population.run(self.eval_genomes, 200)
    print('\nBest genome:\n{!s}'.format(best))
    best_net = nets_g[ge_g.index(best)]
    pickle.dump(best_net, open('best.pkl', 'wb'))

(假装我的代码是缩进的,编辑器由于某种原因无法工作)这是conf.txt文件:

[NEAT]
fitness_criterion     = max
fitness_threshold     = 20
pop_size              = 50
reset_on_extinction   = False

[DefaultGenome]
# node activation options
activation_default      = relu
activation_mutate_rate  = 0.0
activation_options      = relu

# node aggregation options
aggregation_default     = sum
aggregation_mutate_rate = 0.0
aggregation_options     = sum

# node bias options
bias_init_mean          = 0.0
bias_init_stdev         = 1.0
bias_max_value          = 10.0
bias_min_value          = -10.0
bias_mutate_power       = 0.5
bias_mutate_rate        = 0.9
bias_replace_rate       = 0.1

# genome compatibility options
compatibility_disjoint_coefficient = 1.0
compatibility_weight_coefficient   = 0.5

# connection add/remove rates
conn_add_prob           = 0.7
conn_delete_prob        = 0.7

# connection enable options
enabled_default         = True
enabled_mutate_rate     = 0.01

feed_forward            = True
initial_connection      = full

# node add/remove rates
node_add_prob           = 0.7
node_delete_prob        = 0.7

# network parameters
num_hidden              = 0
num_inputs              = 5
num_outputs             = 4

# node response options
response_init_mean      = 1.0
response_init_stdev     = 0.0
response_max_value      = 30.0
response_min_value      = -30.0
response_mutate_power   = 0.0
response_mutate_rate    = 0.0
response_replace_rate   = 0.0

# connection weight options
weight_init_mean        = 0.0
weight_init_stdev       = 1.0
weight_max_value        = 30
weight_min_value        = -30
weight_mutate_power     = 0.5
weight_mutate_rate      = 0.8
weight_replace_rate     = 0.1

[DefaultSpeciesSet]
compatibility_threshold = 3.0

[DefaultStagnation]
species_fitness_func = max
max_stagnation       = 20
species_elitism      = 2

[DefaultReproduction]
elitism            = 2
survival_threshold = 0.2

如您所见,我训练了 200 代。结果很奇怪。这条蛇一直吃到一块食物,但很快就撞到了墙上。这是一种学习,但并不完全。我试图让它训练更多代,但没有区别。我认为问题可能出在我对神经网络的输入上,但我不确定。

编辑:我更改了网络架构,使其现在有 4 个输出节点并relu激活。现在的问题是代码在计算输出的步骤上冻结(output = np.argmax(nets[snakes.index(snake)].activate((food_vert, food_horz, wall_vert, wall_horz, body_front)))

标签: pythonmachine-learninggenetic-algorithm

解决方案


通过浏览您的代码,您似乎有一些错误:

for x, snake in enumerate(snakes):
    ge[x].fitness += 0.1

for循环中,您正在pop()从 thesnakesge列表中获取元素。在 Python 中,您不应该在迭代列表时更改它。稍后在循环中,您将使用snakes.index(snake)insted ofx来索引同一个列表。正因为如此,活着的回报很可能会落到错误的蛇身上。

您可以在迭代之前复制列表,但snakes.index(snake)到处重复也是一种反模式。您需要找到不同的解决方案。例如,您可以使用snake.dead标志。

输出形状

您似乎将单个神经元的输出缩放到整数范围。这使得解决 NN 的任务有点困难(但并非不可能),因为接近的数字实际上并没有映射到类似的动作。

更常见的方法是为每个输出使用单独的神经元,并选择激活度最高的动作。(或者使用 softmax 来选择具有随机概率的动作。这会增加噪音,但会使适应度景观更加平滑,因为即使是权重的微小变化也会对适应度产生一些影响。)

一般建议

你不能期望编写没有错误的代码。当您的代码是优化循环的一部分时,调试非常棘手,因为优化会改变错误的影响。

首先在更简单的设置中运行您的代码。例如,您可以忽略神经网络的输出并始终执行相同的操作(或随机操作)。想想应该发生什么。也许可以逐步手动跟踪一些蛇和它们的奖励,例如使用打印语句。

关键是:减少你同时调试的东西的数量。


推荐阅读