文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

训练提速17%,第四范式开源强化学习研究框架,支持单、多智能体训练

2024-11-30 14:47

关注

目前,OpenRL 已经在 GitHub 开源:

项目地址:https://github.com/OpenRL-Lab/openrl

OpenRL 初体验

OpenRL 目前可以通过 pip 进行安装:

pip install openrl

也可以通过 conda 安装:

conda install -c openrl openrl

OpenRL 为强化学习入门用户提供了简单易用的接口, 下面是一个使用 PPO 算法训练 CartPole 环境的例子:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
env = make ("CartPole-v1", env_num=9) # 创建环境,并设置环境并行数为 9
net = Net (env) # 创建神经网络
agent = Agent (net) # 初始化智能体
agent.train (total_time_steps=20000) # 开始训练,并设置环境运行总步数为 20000

使用 OpenRL 训练智能体只需要简单的四步:创建环境 => 初始化模型 => 初始化智能体 => 开始训练

在普通笔记本电脑上执行以上代码,只需要几秒钟,便可以完成该智能体的训练:


此外,对于多智能体、自然语言等任务的训练,OpenRL 也提供了同样简单易用的接口。例如,对于多智能体任务中的 MPE 环境,OpenRL 也只需要调用几行代码便可以完成训练:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
def train ():
    # 创建 MPE 环境,使用异步环境,即每个智能体独立运行
    env = make (
        "simple_spread",
        env_num=100,
        asynchrnotallow=True,
    )
    # 创建 神经网络,使用 GPU 进行训练
    net = Net (env, device="cuda")
    agent = Agent (net) # 初始化训练器
    # 开始训练
    agent.train (total_time_steps=5000000)
    # 保存训练完成的智能体
    agent.save ("./ppo_agent/")
if __name__ == "__main__":
    train ()

下图展示了通过 OpenRL 训练前后智能体的表现:

加载配置文件

此外,OpenRL 还同时支持从命令行和配置文件对训练参数进行修改。比如,用户可以通过执行 python train_ppo.py --lr 5e-4 来快速修改训练时候的学习率。

当配置参数非常多的时候,OpenRL 还支持用户编写自己的配置文件来修改训练参数。例如,用户可以自行创建以下配置文件 (mpe_ppo.yaml),并修改其中的参数:

# mpe_ppo.yaml
seed: 0 # 设置 seed,保证每次实验结果一致
lr: 7e-4 # 设置学习率
episode_length: 25 # 设置每个 episode 的长度
use_recurrent_policy: true # 设置是否使用 RNN
use_joint_action_loss: true # 设置是否使用 JRPO 算法
use_valuenorm: true # 设置是否使用 value normalization

最后,用户只需要在执行程序的时候指定该配置文件即可:

python train_ppo.py --config mpe_ppo.yaml

训练与测试可视化

此外,通过 OpenRL,用户还可以方便地使用 wandb 来可视化训练过程:

OpenRL 还提供了各种环境可视化的接口,方便用户对并行环境进行可视化。用户可以在创建并行环境的时候设置环境的渲染模式为 "group_human",便可以同时对多个并行环境进行可视化:

env = make ("simple_spread", env_num=9, render_mode="group_human")

此外,用户还可以通过引入 GIFWrapper 来把环境运行过程保存为 gif 动画:

from openrl.envs.wrappers import GIFWrapper
env = GIFWrapper (env, "test_simple_spread.gif")

智能体的保存和加载

OpenRL 提供 agent.save () 和 agent.load () 接口来保存和加载训练好的智能体,并通过 agent.act () 接口来获取测试时的智能体动作:

# test_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.envs.wrappers import GIFWrapper # 用于生成 gif
def test ():
    # 创建 MPE 环境
    env = make ( "simple_spread", env_num=4)
    # 使用 GIFWrapper,用于生成 gif
    env = GIFWrapper (env, "test_simple_spread.gif")
    agent = Agent (Net (env)) # 创建 智能体
    # 保存智能体
    agent.save ("./ppo_agent/")    
    # 加载智能体
    agent.load ('./ppo_agent/')
    # 开始测试
    obs, _ = env.reset ()
    while True:
        # 智能体根据 observation 预测下一个动作
        action, _ = agent.act (obs)
        obs, r, done, info = env.step (action)
        if done.any ():
            break
    env.close ()
if __name__ == "__main__":
    test ()

执行该测试代码,便可以在同级目录下找到保存好的环境运行动画文件 (test_simple_spread.gif):

训练自然语言对话任务

最近的研究表明,强化学习也可以用于训练语言模型, 并且能显著提升模型的性能。目前,OpenRL 已经支持自然语言对话任务的强化学习训练。OpenRL 通过模块化设计,支持用户加载自己的数据集 ,自定义训练模型,自定义奖励模型,自定义 wandb 信息输出以及一键开启混合精度训练等。

对于对话任务训练,OpenRL 提供了同样简单易用的训练接口:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.configs.config import create_config_parser
def train ():
    # 添加读取配置文件的代码
    cfg_parser = create_config_parser ()
    cfg = cfg_parser.parse_args ()
    # 创建 NLP 环境
    env = make ("daily_dialog",env_num=2,asynchrnotallow=True,cfg=cfg,)
    net = Net (env, cfg=cfg, device="cuda")
    agent = Agent (net)
    agent.train (total_time_steps=5000000)
if __name__ == "__main__":
    train ()

可以看出,OpenRL 训练对话任务和其他强化学习任务一样,都是通过创建交互环境的方式进行训练。

加载自定义数据集

训练对话任务,需要对话数据集。这里我们可以使用 Hugging Face 上的公开数据集(用户可以替换成自己的数据集)。加载数据集,只需要在配置文件中传入数据集的名称或者路径即可:

# nlp_ppo.yaml
data_path: daily_dialog # 数据集路径
env: # 环境所用到的参数
    args: {'tokenizer_path': 'gpt2'} # 读取 tokenizer 的路径
seed: 0 # 设置 seed,保证每次实验结果一致
lr: 1e-6 # 设置 policy 模型的学习率
critic_lr: 1e-6 # 设置 critic 模型的学习率
episode_length: 20 # 设置每个 episode 的长度
use_recurrent_policy: true

上述配置文件中的 data_path 可以设置为 Hugging Face 数据集名称或者本地数据集路径。此外,环境参数中的 tokenizer_path 用于指定加载文字编码器的 Hugging Face 名称或者本地路径。

自定义训练模型

在 OpenRL 中,我们可以使用 Hugging Face 上的模型来进行训练。为了加载 Hugging Face 上的模型,我们首先需要在配置文件 nlp_ppo.yaml 中添加以下内容:

# nlp_ppo.yaml
# 预训练模型路径
model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog 
use_share_model: true # 策略网络和价值网络是否共享模型
ppo_epoch: 5 # ppo 训练迭代次数

data_path: daily_dialog # 数据集名称或者路径
env: # 环境所用到的参数
    args: {'tokenizer_path': 'gpt2'} # 读取 tokenizer 的路径
lr: 1e-6 # 设置 policy 模型的学习率
critic_lr: 1e-6 # 设置 critic 模型的学习率
episode_length: 128 # 设置每个 episode 的长度
num_mini_batch: 20

然后在 train_ppo.py 中添加以下代码:

# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.configs.config import create_config_parser
from openrl.modules.networks.policy_value_network_gpt import (
    PolicyValueNetworkGPT as PolicyValueNetwork,
)
def train ():
    # 添加读取配置文件的代码
    cfg_parser = create_config_parser ()
    cfg = cfg_parser.parse_args ()
    # 创建 NLP 环境
    env = make ("daily_dialog",env_num=2,asynchrnotallow=True,cfg=cfg,)
    # 创建自定义神经网络
    model_dict = {"model": PolicyValueNetwork}
    net = Net (env, cfg=cfg, model_dict=model_dict)
    # 创建训练智能体
    agent = Agent (net)
    agent.train (total_time_steps=5000000)
if __name__ == "__main__":
    train ()

通过以上简单几行的修改,用户便可以使用 Hugging Face 上的预训练模型进行训练。如果用户希望分别自定义策略网络和价值网络,可以写好 CustomPolicyNetwork 以及 CustomValueNetwork 后通过以下方式从外部传入训练网络:

model_dict = {
    "policy": CustomPolicyNetwork,
    "critic": CustomValueNetwork,
}
net = Net (env, model_dict=model_dict)

自定义奖励模型

通常,自然语言任务的数据集中并不包含奖励信息。因此,如果需要使用强化学习来训练自然语言任务,就需要使用额外的奖励模型来生成奖励。在该对话任务中,我们可以使用一个复合的奖励模型,它包含以下三个部分:

●意图奖励:即当智能体生成的语句和期望的意图接近时,智能体便可以获得更高的奖励。

●METEOR 指标奖励:METEOR 是一个用于评估文本生成质量的指标,它可以用来衡量生成的语句和期望的语句的相似程度。我们把这个指标作为奖励反馈给智能体,以达到优化生成的语句的效果。

●KL 散度奖励:该奖励用来限制智能体生成的文本偏离预训练模型的程度,防止出现 reward hacking 的问题。

我们最终的奖励为以上三个奖励的加权和,其中 KL 散度奖励的系数是随着 KL 散度的大小动态变化的。想在 OpenRL 中使用该奖励模型,用户无需修改训练代码,只需要在 nlp_ppo.yaml 文件中添加 reward_class 参数即可:

# nlp_ppo.yaml
reward_class:
    id: NLPReward # 奖励模型名称
    args: {
        # 用于意图判断的模型的名称或路径
        "intent_model": rajkumarrrk/roberta-daily-dialog-intent-classifier,
        # 用于计算 KL 散度的预训练模型的名称或路径
        "ref_model": roberta-base, # 用于意图判断的 tokenizer 的名称或路径
    }

OpenRL 支持用户使用自定义的奖励模型。首先,用户需要编写自定义奖励模型 (需要继承 BaseReward 类)。接着,用户需要注册自定义的奖励模型,即在 train_ppo.py 添加以下代码:

# train_ppo.py
from openrl.rewards.nlp_reward import CustomReward
from openrl.rewards import RewardFactory
RewardFactory.register ("CustomReward", CustomReward)

最后,用户只需要在配置文件中填写自定义的奖励模型即可:

reward_class:
    id: "CustomReward" # 自定义奖励模型名称
    args: {} # 用户自定义奖励函数可能用到的参数

自定义训练过程信息输出

OpenRL 还支持用户自定义 wandb 和 tensorboard 的输出内容。例如,在该任务的训练过程中,我们还需要输出各种类型奖励的信息和 KL 散度系数的信息, 用户可以在 nlp_ppo.yaml 文件中加入 vec_info_class 参数来实现:

# nlp_ppo.yaml
vec_info_class:
    id: "NLPVecInfo" # 调用 NLPVecInfo 类以打印 NLP 任务中奖励函数的信息
# 设置 wandb 信息
wandb_entity: openrl # 这里用于指定 wandb 团队名称,请把 openrl 替换为你自己的团队名称
experiment_name: train_nlp # 这里用于指定实验名称
run_dir: ./run_results/ # 这里用于指定实验数据保存的路径
log_interval: 1 # 这里用于指定每隔多少个 episode 上传一次 wandb 数据
# 自行填写其他参数...

修改完配置文件后,在 train_ppo.py 文件中启用 wandb:

# train_ppo.py
agent.train (total_time_steps=100000, use_wandb=True)

然后执行 python train_ppo.py –config nlp_ppo.yaml,稍后,便可以在 wandb 中看到如下的输出:

从上图可以看到,wandb 输出了各种类型奖励的信息和 KL 散度系数的信息。 

如果用户还需要输出其他信息,还可以参考 NLPVecInfo 类 和 VecInfo 类来实现自己的 CustomVecInfo 类。然后,需要在 train_ppo.py 中注册自定义的 CustomVecInfo 类:

# train_ppo.py # 注册自定义输出信息类 
VecInfoFactory.register ("CustomVecInfo", CustomVecInfo)

最后,只需要在 nlp_ppo.yaml 中填写 CustomVecInfo 类即可启用:

# nlp_ppo.yaml
vec_info_class:
    id: "CustomVecInfo" # 调用自定义 CustomVecInfo 类以输出自定义信息

使用混合精度训练加速

OpenRL 还提供了一键开启混合精度训练的功能。用户只需要在配置文件中加入以下参数即可:

# nlp_ppo.yaml
use_amp: true # 开启混合精度训练

对比评测

下表格展示了使用 OpenRL 训练该对话任务的结果。结果显示使用强化学习训练后,模型各项指标皆有所提升。另外,从下表可以看出,相较于 RL4LMs , OpenRL 的训练速度更快(在同样 3090 显卡的机器上,速度提升 17% ),最终的性能指标也更好:

最后,对于训练好的智能体,用户可以方便地通过 agent.chat () 接口进行对话:

# chat.py
from openrl.runners.common import ChatAgent as Agent
def chat ():
    agent = Agent.load ("./ppo_agent", tokenizer="gpt2",)
    history = []
    print ("Welcome to OpenRL!")
    while True:
        input_text = input ("> User:")
        if input_text == "quit":
            break
        elif input_text == "reset":
            history = []
            print ("Welcome to OpenRL!")
            continue
        response = agent.chat (input_text, history)
        print (f"> OpenRL Agent: {response}")
        history.append (input_text)
        history.append (response)
if __name__ == "__main__":
    chat ()

执行 python chat.py ,便可以和训练好的智能体进行对话了:

总结

OpenRL 框架经过了 OpenRL-Lab 的多次迭代并应用于学术研究和 AI 竞赛,目前已经成为了一个较为成熟的强化学习框架。OpenRL-Lab 团队将持续维护和更新 OpenRL,欢迎大家加入我们的开源社区,一起为强化学习的发展做出贡献。更多关于 OpenRL 的信息,可以参考:

致谢

OpenRL 框架的开发吸取了其他强化学习框架的优点:

未来工作

目前,OpenRL 还处于持续开发和建设阶段,未来 OpenRL 将会开源更多功能:

OpenRL Lab 团队

OpenRL框架是由OpenRL Lab团队开发,该团队是第四范式公司旗下的强化学习研究团队。第四范式长期致力于强化学习的研发和工业应用。为了促进强化学习的产学研一体化,第四范式成立了OpenRL Lab研究团队,目标是先进技术开源和人工智能前沿探索。成立不到一年,OpenRL Lab团队已经在AAMAS发表过三篇论文,参加谷歌足球游戏 11 vs 11比赛并获得第三的成绩。团队提出的TiZero智能体,实现了首个从零开始,通过课程学习、分布式强化学习、自博弈等技术完成谷歌足球全场游戏智能体的训练:

截止 2022 年 10 月 28 日,Tizero 在及第评测平台上排名第一:

来源:机器之心内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯