Network does not change its weights during training, pytorch

Problem description

I am implementing DDPG and got stuck training my two networks.

To summarize, I have four networks called actor, actor_target, critic and critic_target. I train the actor and the critic in my training loop and perform soft updates on the other two:

def update_weights(self, source, tau):
        for target, source in zip(self.parameters(), source.parameters()):
            target.data.copy_(tau * source.data + (1 - tau) * target.data)

My training loop looks like this:

tensor_next_states = torch.tensor(next_states).view(-1, 1)
prediction_target = self.actor_target(tensor_next_states).data.numpy()
target_critic_output = self.critic_target(
    construct_tensor(next_states, prediction_target))
y = torch.tensor(rewards).view(-1,1) + \
    self.gamma * target_critic_output
output_critic = self.critic(
    torch.tensor(construct_tensor(states, actions), dtype=torch.float))

# compute loss and update critic
self.critic.zero_grad()
loss_critic = self.criterion_critic(y, output_critic)
loss_critic.backward()
self.critic_optim.step()

# compute loss and update actor
tensor_states = torch.tensor(states).view(-1, 1)
output_actor = self.actor(tensor_states).data.numpy()
self.actor.zero_grad()
loss_actor = (-1.) * \
             self.critic(construct_tensor(states, output_actor)).mean()
loss_actor.backward()
self.actor_optim.step()

# update target
self.actor_target.update_weights(self.actor, self.tau)
self.critic_target.update_weights(self.critic, self.tau)

SGD is used as the optimizer, and self.criterion_critic = F.mse_loss.
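
For reference, that setup presumably looks roughly like the following sketch; the learning rates are placeholders and not taken from the original post:

self.critic_optim = optim.SGD(self.critic.parameters(), lr=1e-3)  # assumed lr
self.actor_optim = optim.SGD(self.actor.parameters(), lr=1e-4)    # assumed lr
# only the trained networks' parameters are registered with the optimizers;
# the target networks are updated solely through update_weights
self.criterion_critic = F.mse_loss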

construct_tensor(a, b) constructs a tensor like [a[0], b[0], a[1], b[1], ...].

I noticed that the RMSE on the test set is the same before and after training. So I debugged a lot and noticed that after update_weights the weights of the trained networks and the target networks are identical, which led me to conclude that training has no effect on the weights of the trained networks. I already checked that the computed loss is not zero (it is still an ordinary float), tried replacing the zero_grad() calls, and moved the computed losses to self, none of which had any effect.
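
One minimal way to verify whether an optimizer step changes a network at all (a sketch of such a check, not the exact code used here) is to snapshot the parameters before the step and compare them afterwards:

# snapshot the critic's parameters, run one step, then report what moved
before = [p.detach().clone() for p in self.critic.parameters()]
self.critic_optim.step()
for (name, p), old in zip(self.critic.named_parameters(), before):
    print(name, 'changed' if not torch.equal(p.detach(), old) else 'UNCHANGED')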

Has anyone encountered this behavior before and/or does anyone have a hint or know how to solve it?

Update: full code:

import datetime
import random
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def combine_tensors(s, a):
    """
    Combines the two given tensors
    :param s: tensor1
    :param a: tensor2
    :return: combined tensor
    """
    target = []
    if not len(a[0].shape) == 0:
        for i in range(len(s)):
            target.append(torch.cat((s[i], a[i])).data.numpy())
    else:
        for i in range(len(s)):
            target.append(torch.cat((s[i], a[i].float().view(-1))) \
                          .data.numpy())
    return torch.tensor(target, device=device)


class actor(nn.Module):
    """
    Actor - gets a state (2-dim) and returns probabilities about which
    action to take (4 actions -> 4 outputs)
    """

    def __init__(self):
        super(actor, self).__init__()

        # define net structure
        self.input_layer = nn.Linear(2, 4)
        self.hidden_layer_1 = nn.Linear(4, 8)
        self.hidden_layer_2 = nn.Linear(8, 16)
        self.hidden_layer_3 = nn.Linear(16, 32)
        self.output_layer = nn.Linear(32, 4)

        # initialize them
        nn.init.xavier_uniform_(self.input_layer.weight)
        nn.init.xavier_uniform_(self.hidden_layer_1.weight)
        nn.init.xavier_uniform_(self.hidden_layer_2.weight)
        nn.init.xavier_uniform_(self.hidden_layer_3.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)

        nn.init.constant_(self.input_layer.bias, 0.1)
        nn.init.constant_(self.hidden_layer_1.bias, 0.1)
        nn.init.constant_(self.hidden_layer_2.bias, 0.1)
        nn.init.constant_(self.hidden_layer_3.bias, 0.1)
        nn.init.constant_(self.output_layer.bias, 0.1)

    def forward(self, state):
        state = F.relu(self.input_layer(state))
        state = F.relu(self.hidden_layer_1(state))
        state = F.relu(self.hidden_layer_2(state))
        state = F.relu(self.hidden_layer_3(state))
        state = F.softmax(self.output_layer(state), dim=0)
        return state


class critic(nn.Module):
    """
    Critic - gets a state (2-dim) and an action and returns value
    """

    def __init__(self):
        super(critic, self).__init__()
        # define net structure
        self.input_layer = nn.Linear(3, 8)
        self.hidden_layer_1 = nn.Linear(8, 16)
        self.hidden_layer_2 = nn.Linear(16, 32)
        self.hidden_layer_3 = nn.Linear(32, 16)
        self.output_layer = nn.Linear(16, 1)

        # initialize them
        nn.init.xavier_uniform_(self.input_layer.weight)
        nn.init.xavier_uniform_(self.hidden_layer_1.weight)
        nn.init.xavier_uniform_(self.hidden_layer_2.weight)
        nn.init.xavier_uniform_(self.hidden_layer_3.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)

        nn.init.constant_(self.input_layer.bias, 0.1)
        nn.init.constant_(self.hidden_layer_1.bias, 0.1)
        nn.init.constant_(self.hidden_layer_2.bias, 0.1)
        nn.init.constant_(self.hidden_layer_3.bias, 0.1)
        nn.init.constant_(self.output_layer.bias, 0.1)

    def forward(self, state_, action_):
        state_ = combine_tensors(state_, action_)
        state_ = F.relu(self.input_layer(state_))
        state_ = F.relu(self.hidden_layer_1(state_))
        state_ = F.relu(self.hidden_layer_2(state_))
        state_ = F.relu(self.hidden_layer_3(state_))
        state_ = self.output_layer(state_)
        return state_


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """
    Memory
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


def compute_action(actor_trained, state, eps=0.1):
    """
    Computes an action given the current policy, the state and an eps.
    Eps is responsible for the amount of exploration.
    :param actor_trained: current policy
    :param state:
    :param eps: float in [0,1]
    :return:
    """
    denoise = random.random()

    if denoise > eps:
        action_probs = actor_trained(state.float())
        return torch.argmax(action_probs).view(1).int()
    else:
        return torch.randint(0, 4, (1,)).view(1).int()


def compute_next_state(_action, _state):
    """
    Computes the next state given an action and a state
    :param _action:
    :param _state:
    :return:
    """
    state_ = _state.clone()
    if _action.item() == 0:
        state_[1] += 1
    elif _action.item() == 1:
        state_[1] -= 1
    elif _action.item() == 2:
        state_[0] -= 1
    elif _action.item() == 3:
        state_[0] += 1

    return state_


def update_weights(target, source, tau):
    """
    Soft-Update of weights
    :param target:
    :param source:
    :param tau:
    :return:
    """
    for target, source in zip(target.parameters(), source.parameters()):
        target.data.copy_(tau * source.data + (1 - tau) * target.data)


def update(transition__, replay_memory, batch_size_, gamma_):
    """
    Performs one update step
    :param transition__:
    :param replay_memory:
    :param batch_size_:
    :param gamma_:
    :return:
    """
    replay_memory.push(*transition__)

    if replay_memory.__len__() < batch_size_:
        return

    transitions = replay_memory.sample(batch_size_)
    batch = Transition(*zip(*transitions))

    states = torch.stack(batch.state)
    actions = torch.stack(batch.action)
    rewards = torch.stack(batch.reward)
    next_states = torch.stack(batch.next_state)

    action_target = torch.argmax(actor_target(next_states.float()), 1).int()
    y = (
            rewards.float().view(-1, 1) +
            gamma_ * critic_target(next_states.float(), action_target.float())
            .float()
    )

    critic_trained.zero_grad()
    crit_ = critic_trained(states.float(), actions.float())
    # nn stuff does not work here! -> doing mse myself..
    # loss_critic = (torch.sum((y.float() - crit_.float()) ** 2.)
    #                / y.data.nelement())
    loss_critic = F.l1_loss(y.float(), crit_.float())
    loss_critic.backward()
    optimizer_critic.step()

    actor_trained.zero_grad()
    loss_actor = ((-1.) * critic_trained(states.float(),
                                         torch.argmax(
                                             actor_trained(states.float()), 1
                                         ).int().float())).mean()
    loss_actor.backward()
    optimizer_actor.step()


def get_eps(epoch):
    """
    Computes the eps for action selection depending on the epoch
    :param epoch: number of epoch
    :return:
    """
    if epoch <= 10:
        eps_ = 1.
    elif epoch <= 20:
        eps_ = 0.8
    elif epoch <= 40:
        eps_ = 0.6
    elif epoch <= 60:
        eps_ = 0.4
    elif epoch <= 80:
        eps_ = 0.2
    else:
        eps_ = 0.1
    return eps_


def compute_reward_2(state_, next_state_, terminal_state_):
    """
    Better (?) reward function than "compute_reward"
    If next_state == terminal_state -> reward = 100
    If next_state illegal           -> reward = -100
    if next_state is further away from terminal_state than state_ -> -2
    else 1
    :param state_:
    :param next_state_:
    :param terminal_state_:
    :return:
    """
    if torch.eq(next_state_, terminal_state_).all():
        reward_ = 100
    elif torch.eq(next_state_.abs(), 15).any():
        reward_ = -100
    else:
        if (state_.abs() > next_state_.abs()).any():
            reward_ = 1.
        else:
            reward_ = -2
    return torch.tensor(reward_, device=device, dtype=torch.float)


def compute_reward(next_state_, terminal_state_):
    """
    Computes some reward
    :param next_state_:
    :param terminal_state_:
    :return:
    """
    if torch.eq(next_state_, terminal_state_).all():
        return torch.tensor(100., device=device, dtype=torch.float)
    elif next_state_[0] == 15 or next_state_[1] == 15:
        return torch.tensor(-100., device=device, dtype=torch.float)
    else:
        return (-1.) * next_state_.abs().sum().float()


def fill_memory_2():
    """
    Fills the memory with random transitions for which a "good" action was chosen
    """
    terminal_state_ = torch.tensor([0, 0], device=device, dtype=torch.int)
    while replay_memory_.__len__() < batch_size:
        state_ = torch.randint(-4, 4, (2,)).to(device).int()
        if state_[0].item() == 0 and state_[1].item() == 0:
            continue

        # try to find a "good" action
        if state_[0].item() == 0:
            if state_[1].item() > 0:
                action_ = torch.tensor(1, device=device, dtype=torch.int)
            else:
                action_ = torch.tensor(0, device=device, dtype=torch.int)
        elif state_[1].item() == 0:
            if state_[0].item() > 0:
                action_ = torch.tensor(2, device=device, dtype=torch.int)
            else:
                action_ = torch.tensor(3, device=device, dtype=torch.int)
        else:
            random_bit = random.random()
            if random_bit > 0.5:
                if state_[1].item() > 0:
                    action_ = torch.tensor(1, device=device, dtype=torch.int)
                else:
                    action_ = torch.tensor(0, device=device, dtype=torch.int)
            else:
                if state_[0].item() > 0:
                    action_ = torch.tensor(2, device=device, dtype=torch.int)
                else:
                    action_ = torch.tensor(3, device=device, dtype=torch.int)

        action_ = action_.view(1).int()
        next_state_ = compute_next_state(action_, state_)
        reward_ = compute_reward_2(state_, next_state_, terminal_state_)

        transition__ = Transition(state=state_, action=action_,
                                  reward=reward_, next_state=next_state_)
        replay_memory_.push(*transition__)


def fill_memory():
    """
    Fills the memory with random transitions
    """
    while replay_memory_.__len__() < batch_size:
        state_ = torch.randint(-14, 15, (2,)).to(device).int()
        if state_[0].item() == 0 and state_[1].item() == 0:
            continue
        terminal_state_ = torch.tensor([0, 0], device=device, dtype=torch.int)
        action_ = torch.randint(0, 4, (1,)).view(1).int()
        next_state_ = compute_next_state(action_, state_)
        reward_ = compute_reward_2(state_, next_state_, terminal_state_)

        transition__ = Transition(state=state_, action=action_,
                                  reward=reward_, next_state=next_state_)
        replay_memory_.push(*transition__)


if __name__ == '__main__':
    # get device if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # set seed
    seed_ = 0
    random.seed(seed_)  # seed of python
    if device.type == "cuda":
        # cuda seed
        torch.cuda.manual_seed(seed_)
    else:
        # cpu seed
        torch.manual_seed(seed_)

    # initialize the nets
    actor_trained = actor().to(device)
    actor_target = actor().to(device)
    # copy so that _target starts out equal to _trained
    actor_target.load_state_dict(actor_trained.state_dict())
    optimizer_actor = optim.RMSprop(actor_trained.parameters())
    # move them to the device
    critic_trained = critic().to(device)
    critic_target = critic().to(device)
    critic_target.load_state_dict(critic_trained.state_dict())
    # used optimizer
    optimizer_critic = optim.RMSprop(critic_trained.parameters(),
                                     momentum=0.9, weight_decay=0.001)

    # replay memory
    capacity_replay_memory = 16384
    replay_memory_ = ReplayMemory(capacity_replay_memory)

    # hyperparams
    batch_size = 1024
    gamma = 0.7
    tau = 0.01
    num_epochs = 256

    # fill replay memory such that batching is possible
    fill_memory_2()

    # Print params
    printing_while_training = True
    printing_while_testing = False

    print('######################## Training ########################')
    starting_time = datetime.datetime.now()
    for i in range(num_epochs):
        # random state
        starting_state = torch.randint(-14, 15, (2,)).to(device).int()
        # skip if terminal state
        if starting_state[0].item() == 0 and starting_state[1].item() == 0:
            continue
        state = starting_state.clone()
        # terminal state
        terminal_state = torch.tensor([0, 0], device=device, dtype=torch.int)
        iteration = 0

        # get eps for exploring
        eps = get_eps(i)

        running_reward = 0.

        # training loop
        while True:
            # compute action and next state
            action = compute_action(actor_trained, state, eps)
            next_state = compute_next_state(action, state)

            # finished if next state is terminal state
            if torch.eq(next_state, terminal_state).all():
                reward = compute_reward_2(state, next_state, terminal_state)
                running_reward += reward.item()
                transition_ = Transition(state=state, action=action,
                                         reward=reward, next_state=next_state)
                replay_memory_.push(*transition_)
                if printing_while_training:
                    print('{}: Finished after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(i + 1, iteration + 1, running_reward,
                                  next_state.data.numpy(),
                                  starting_state.data.numpy()))
                break
            # abort if illegal state
            elif torch.eq(next_state.abs(), 15).any() or iteration == 99:
                reward = compute_reward_2(state, next_state, terminal_state)
                running_reward += reward
                transition_ = Transition(state=state, action=action,
                                         reward=reward, next_state=next_state)
                replay_memory_.push(*transition_)
                if printing_while_training:
                    print('{}: Aborted after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(i + 1, iteration + 1, running_reward,
                                  next_state.data.numpy(),
                                  starting_state.data.numpy()))
                break

            # compute immediate reward
            reward = compute_reward_2(state, next_state, terminal_state)
            # save it - only for logging purposes
            running_reward += reward.item()

            # construct transition
            transition_ = Transition(state=state, action=action, reward=reward,
                                     next_state=next_state)

            # update model
            update(transition_, replay_memory_, batch_size, gamma)
            # perform soft updates
            update_weights(actor_target, actor_trained, tau)
            update_weights(critic_target, critic_trained, tau)

            state = next_state
            iteration += 1
    print('Ended after: {}'.format(datetime.datetime.now() - starting_time))

    print('######################## Testing ########################')
    starting_time = datetime.datetime.now()
    test_states = [torch.tensor([i, j], device=device, dtype=torch.int)
                   for i in range(-15, 14) for j in range(-15, 14)]
    finished = 0
    aborted = 0
    aborted_reward = []
    finished_reward = []

    for starting_state in test_states:
        state = starting_state.clone()
        terminal_state = torch.tensor([0, 0], device=device, dtype=torch.int)
        iteration = 0
        reward = 0.

        while True:
            action = torch.argmax(actor_target(state.float())).view(1).int()
            next_state = compute_next_state(action, state)

            if torch.eq(next_state, terminal_state).all():
                reward += compute_reward_2(state, next_state,
                                           terminal_state)
                finished_reward.append(reward.item())
                if printing_while_testing:
                    print('{}: Finished after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(starting_state.data.numpy(), iteration + 1,
                                  reward.item(), next_state.data.numpy(),
                                  starting_state.data.numpy()))
                finished += 1
                break
            elif torch.eq(next_state.abs(), 15).any():
                reward += compute_reward_2(state, next_state,
                                           terminal_state)
                aborted_reward.append(reward.item())
                if printing_while_testing:
                    print('{}: Aborted after {} iterations with reward {} '
                          'in state {} starting from {}'
                          .format(starting_state.data.numpy(), iteration + 1,
                                  reward.item(), next_state.data.numpy(),
                                  starting_state.data.numpy()))
                aborted += 1
                break
            elif iteration > 500:
                if printing_while_testing:
                    print('Aborting due to more than 500 iterations! '
                          'Started from {}'.format(
                        starting_state.data.numpy()))
                aborted += 1
                break
            reward += compute_reward_2(state, next_state, terminal_state)
            state = next_state
            iteration += 1

    print('Ended after: {}'.format(datetime.datetime.now() - starting_time))
    print('Finished: {}, aborted: {}'.format(finished, aborted))
    print('Reward mean finished: {}, aborted: {}'
          .format(np.mean(finished_reward), np.mean(aborted_reward)))

I have already tried other reward functions, but that had no effect either...

I also tried using less aggressive exploration and optim.SGD instead of optim.RMSprop; neither had any effect.

Tags: python, deep-learning, pytorch, reinforcement-learning

Solution


This may not be a direct answer or a recipe that makes your code work right away, but I have some initial concerns that might help you debug it.

I believe the biggest problem is that you perform many conversions between tensors and non-tensor data types. For example, you call the combine_tensors() function several times; it converts the given tensors to numpy() and creates brand-new tensors for the return value. In other places you call your networks to do a forward pass and hand them tensors converted with float() as arguments, and there are also int() calls on tensors. All of these calls can cause the tensor's operation graph, the graph used to compute the gradients on the backward() call, to be lost. This is described in the PyTorch documentation and should be understood before writing RL algorithms in this framework. It is important to work with tensors throughout the whole training function, from converting the batch of experience into tensors all the way to calling the backward function.
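
As a minimal, standalone illustration of that point (the variable names below are generic examples, not taken from the posted code): a round trip through .data / .numpy() followed by torch.tensor(...) produces a fresh leaf tensor with no history, so backward() can no longer reach the original parameters, whereas staying with tensor operations keeps the graph intact.

import torch

x = torch.randn(4, requires_grad=True)

# graph lost: the numpy round trip plus torch.tensor() creates a new leaf tensor
y_broken = torch.tensor(x.data.numpy() * 2.0)
print(y_broken.requires_grad)   # False, no gradient can flow back to x

# graph kept: pure tensor operations
y_ok = (x * 2.0).sum()
y_ok.backward()
print(x.grad)                   # tensor([2., 2., 2., 2.])

# a graph-preserving way to build a critic input from batched tensors,
# instead of concatenating row by row through numpy
states = torch.randn(8, 2)
actions = torch.randn(8, 1)
critic_input = torch.cat([states, actions], dim=1)   # shape (8, 3)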

This alone still does not guarantee that learning will proceed correctly. For example, when you use the target networks to estimate the critic's loss target, you should detach the result to prevent gradient computation through the target networks (although, if your optimizer only has the critic's parameters registered, this is more of a performance issue, since the step() call will not update the target networks' parameters anyway).
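
Applied to the posted update() function, that detaching could look roughly like this (a sketch of the target computation only; the rest of the critic update stays as it is, and calling .detach() on y would have the same effect as the no_grad() block):

# build the TD target without tracking gradients through the target networks
with torch.no_grad():
    action_target = torch.argmax(actor_target(next_states.float()), 1).int()
    y = (
            rewards.float().view(-1, 1) +
            gamma_ * critic_target(next_states.float(), action_target.float())
    )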

Once these two issues are fixed in your code, you will probably observe more correct behavior. An additional comment: I do not fully understand parts of your code, and I do not think this is a correct DDPG implementation (for example, you apply argmax() to the actor network's output and feed the result to the critic network, which does not look like the right approach).
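
For comparison, in standard DDPG the actor produces a continuous action directly (typically through a bounded, e.g. tanh, output layer rather than a softmax over discrete actions), and that output is fed straight into the critic, so the actor loss stays differentiable end to end. A rough sketch of that pattern, not a drop-in fix for the posted discrete-action setup:

# DDPG-style actor update: no argmax() between actor and critic, because
# argmax() is not differentiable and cuts the gradient to the actor
actor_trained.zero_grad()
actor_actions = actor_trained(states.float())            # continuous actions
loss_actor = -critic_trained(states.float(), actor_actions).mean()
loss_actor.backward()
optimizer_actor.step()

For a genuinely discrete action space such as the grid world above, an algorithm designed for discrete actions (DQN, for instance) is usually a more natural fit than DDPG.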

I suggest taking a step back to get a better understanding of the PyTorch framework and its concepts, and looking at some baseline DDPG implementations to make sure you know how the computation should be performed step by step.

