问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

深入理解PPO算法:从原理到实现

创作时间:
作者:
@小白创作中心

深入理解PPO算法:从原理到实现

引用
CSDN
1.
https://m.blog.csdn.net/qq_56683019/article/details/145860133

PPO(Proximal Policy Optimization,近端策略优化)是强化学习领域中一种广泛使用且表现优异的算法。它由OpenAI提出,旨在解决策略优化中不稳定和样本效率低的问题。本文将从原理到实现,深入探讨PPO算法的核心思想、具体实现步骤以及其在实际场景中的应用。

1. 引言

在强化学习领域,PPO(Proximal Policy Optimization,近端策略优化)是一种广泛使用且表现优异的算法。它由OpenAI提出,旨在解决策略优化中不稳定和样本效率低的问题。与传统策略梯度方法相比,PPO稳定性更强,且在诸多任务上表现优异。

2. PPO算法的背景

强化学习中的策略优化方法大体可以分为两类:基于值的算法(如DQN)和基于策略的算法(如策略梯度方法)。策略梯度方法直接优化策略函数,使智能体能够在复杂、高维的环境中获得良好的决策能力。然而,直接优化策略可能会导致策略更新过大,导致学习过程不稳定或样本效率低下。

为了解决这个问题,出现了TRPO(Trust Region Policy Optimization)算法,它通过限制策略更新的范围,避免过度更新。然而,TRPO的优化过程复杂且计算开销较大。PPO在此基础上进行改进,通过引入“剪切”(Clipping)等技术简化了优化过程,大幅度提升了算法的稳定性和样本效率。

3. PPO算法的核心思想

PPO的核心思想是限制策略更新的范围,使其不会偏离旧策略太远。PPO主要通过两种方法来实现策略的限制更新:剪切法(Clipping)和KL散度惩罚法(KL Penalty)。其中,剪切法是PPO最常用的实现方式。

具体来说,PPO的优化目标函数为:

这里的符号解释如下:

  • :策略更新比率,表示新策略和旧策略之间的差异。
  • :优势函数,用于衡量当前动作在当前状态下的好坏。
  • :控制策略更新的幅度,一般为一个小值,如0.2。

目标函数的工作原理是:限制策略更新的范围,如果策略的更新比率超过了预设的范围(即大于1+ϵ或小于1−ϵ),则该更新将被裁剪,以防止策略发生剧烈变化。

4. PPO算法的实现步骤

图1 PPO算法的基本架构图

  1. 采样数据:使用当前策略与环境交互,采集若干个轨迹,得到状态、动作、奖励和优势函数。
  2. 计算优势函数:通常使用时序差分(Temporal Difference)方法或广义优势估计(GAE)来计算优势函数。
  3. 计算更新比率:根据旧策略和当前策略,计算比率。
  4. 更新策略参数:最小化剪切目标函数中的期望值,使策略尽可能接近“最佳策略”,并确保策略更新不会超出限定范围。
  5. 重复采样和更新:不断重复采样和策略更新,直到收敛或达到设定的迭代次数。

4.1 PPO代码实现

这里是PPO的简单实现,包括策略更新和优势估计部分。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Hyperparameters
learning_rate = 3e-4
gamma = 0.99          # Discount factor
lmbda = 0.95          # GAE lambda
eps_clip = 0.2        # PPO clip parameter
K_epoch = 3           # PPO update epochs
T_horizon = 20        # Rollout length

# Policy Network
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc_pi = nn.Linear(256, action_dim)    # Actor output
        self.fc_v = nn.Linear(256, 1)              # Critic output

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        pi = torch.softmax(self.fc_pi(x), dim=0)
        v = self.fc_v(x)
        return pi, v

    def act(self, state):
        pi, _ = self.forward(state)
        action = torch.multinomial(pi, 1).item()
        return action

# PPO Algorithm
class PPO:
    def __init__(self, state_dim, action_dim):
        self.model = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    def compute_advantage(self, rewards, values):
        deltas = [r + gamma * v_next - v for r, v_next, v in zip(rewards, values[1:], values[:-1])]
        advantages = []
        advantage = 0.0
        for delta in reversed(deltas):
            advantage = delta + gamma * lmbda * advantage
            advantages.insert(0, advantage)
        return advantages

    def update(self, rollout):
        states, actions, rewards, old_log_probs, values = rollout
        advantages = self.compute_advantage(rewards, values)
        for _ in range(K_epoch):
            pi, v = self.model(states)
            log_probs = torch.log(pi.gather(1, actions))
            ratios = torch.exp(log_probs - old_log_probs)
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = nn.functional.mse_loss(v, rewards)
            loss = actor_loss + 0.5 * critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

4.2 代码说明

  1. 策略网络(Policy Network):ActorCritic类包含策略网络(fc_pi)和价值网络(fc_v),可以同时输出动作概率和状态值。
  2. PPO更新过程:
  • 通过compute_advantage函数计算广义优势估计(GAE)。
  • update函数使用剪切目标函数进行策略更新,其中surr1和surr2表示未剪切和剪切后的损失值,取其最小值来控制策略更新幅度。
  1. 运行与优化:在K_epoch次循环中重复更新,以使策略能够最大化累积奖励。

5. 为什么PPO效果如此出色?

  1. 更新限制:PPO通过限制策略的更新幅度,避免了过度更新带来的不稳定性问题。这种限制让PPO的训练更加平滑,学习过程更加稳定。
  2. 简单高效:相比TRPO,PPO不需要进行复杂的约束优化,而是通过简单的剪切操作实现约束,从而降低了计算复杂度和资源消耗。
  3. 广泛适用:PPO适用于离散和连续动作空间,并在不同类型的任务上取得了良好效果,如机器人控制、视频游戏等。

5.1 PPO的优势函数与GAE

PPO通常使用广义优势估计(Generalized Advantage Estimation, GAE)来计算优势函数。GAE是一种平衡偏差与方差的估计方法,通过衰减参数来控制估计的偏差和方差。GAE的优势在于可以更稳定地估计动作的优势值,使得策略更新的效果更好。

5.2 PPO的变体:PPO-Clip和PPO-KL

  1. PPO-Clip:即经典的剪切法,通过将更新比率限制在的范围内,确保策略更新不超过预设范围。
  2. PPO-KL:通过在损失函数中加入KL散度惩罚项来控制更新幅度。在这种方法中,如果新旧策略之间的KL散度过大,则增加惩罚项,使得更新更加保守。尽管PPO-KL在一些应用中表现良好,但大多数场景下PPO-Clip更常用。

6. PPO算法的应用场景

PPO算法已成功应用于多个实际场景,包括但不限于以下几个领域:

  • 游戏AI:PPO在复杂的游戏环境中表现出色,如《Dota 2》和《Atari》游戏。其稳定性和高效性使其成为游戏AI训练中的重要选择。
  • 机器人控制:在机器人操作中,PPO被广泛用于控制机器人的手臂、腿等部位。它的高样本效率使机器人能够在模拟环境中快速学习,减少了真实环境的训练成本。
  • 自动驾驶:PPO被用于训练自动驾驶中的决策模块。通过学习不同的驾驶场景,PPO可以帮助自动驾驶车辆更好地应对复杂路况。

7. 总结

PPO是一种简单且有效的策略优化算法,通过限制策略更新的范围,实现了稳定和高效的策略优化。它不仅在计算上更简单,还在多个复杂任务中取得了优异的表现。随着强化学习的不断发展,PPO已成为解决复杂决策问题的一项强大工具,未来可能会被应用到更多实际场景中。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号