mujoco+spinningup进行强化学习训练快速入门

  • Post author:
  • Post category:其他




1、搭建env

目标:使用强化学习做力控,使机器人(UR5)平稳地运动到指定位置,并保持该状态。



env的几个要素

*:

​ 必要方法:step(action)、reset()、render()

​ 必要元素:action_space、observation_space

​ 我们在写环境就是实现这几个基本要素的过程



(1)初始化MuJoCo相关的组件

 self.model = mp.load_model_from_path('ur5.xml') # 从路径加载模型。
 self.sim = mp.MjSim(self.model) # MjSim 表示一个正在运行的模拟,包括其状态。
 self.viewer = mp.MjViewer(self.sim) #显示器



(2) 设置动作和状态空间

self.action_high = np.array([2, 2, 2, 1, 1, 1], dtype=np.float32)
self.action_space = spaces.Box(-self.action_high, self.action_high, dtype=np.float32)
self.observation_high = np.array([np.finfo(np.float32).max] * 12)
self.observation_space = spaces.Box(-self.observation_high, self.observation_high, dtype=np.float32)

action_space是由控制器的范围直接决定的,我们希望做一个力控,也就是把动作的输入直接给控制器上。

action_high即三个-2到2,三个-1到1,对称的

观测的是六个关节,共12个量,即对每个关节的位置和速度信息进行观测



(3) step实现

对于step的实现,直接将action输出到力控的接口

self.sim.data.ctrl[i] = action[i]

reward设为关节位置偏差求和。偏差越大,reward的值越小,反之越大

r -= abs(self.sim.data.qpos[i] - self.target[i])

state设置为各关节的位置和速度组成的向量

self.state[i] = self.sim.data.qpos[i]
self.state[i + 6] = self.sim.data.qvel[i]



(4) render

 self.viewer.render()



(5) reset

关节位置速度归零

self.sim.data.qpos[i] = 0
self.sim.data.qvel[i] = 0
self.state = [0] * 12

(6)完整环境代码

# UR5_Controller.py
import mujoco_py as mp
import numpy as np
from gym import spaces


class UR5_Controller(object):
    def __init__(self, target): #传入目标位置
        self.model = mp.load_model_from_path('ur5.xml') # 从路径加载模型。
        self.sim = mp.MjSim(self.model) # MjSim 表示一个正在运行的模拟,包括其状态。
        self.viewer = mp.MjViewer(self.sim) #显示器
        self.state = [0] * 12
        self.action_high = np.array([2, 2, 2, 1, 1, 1], dtype=np.float32)
        self.action_space = spaces.Box(-self.action_high, self.action_high, dtype=np.float32)
        self.observation_high = np.array([np.finfo(np.float32).max] * 12)
        self.observation_space = spaces.Box(-self.observation_high, self.observation_high, dtype=np.float32)

        self.target = target

    def step(self, action):
        for i in range(6): 
            self.sim.data.ctrl[i] = action[i] # 直接将action输出到控制器接口
        self.sim.step() # step的调用
        r = 0
        for i in range(6):
            r -= abs(self.sim.data.qpos[i] - self.target[i]) # 6个关节的偏差进行一个累加
            # r -= abs(self.sim.data.qvel[i])
            self.state[i] = self.sim.data.qpos[i] # state更新成每次的速度和位置
            self.state[i + 6] = self.sim.data.qvel[i]
        done = False
        # if r < -50:
        #     done = True
        # print(r)
        return np.array(self.state, dtype=np.float32), r, done, {} # 返回状态(observaton观测到的一些东西),reward,是否完成过程的一个标记,返回一个info(包括一些信息,一般是自检的形式。比如希望返回是否在一个高速的状态,可以在里面写上一些具体的信息,自己定义的,一般可以在输出上使用))

    def render(self):
        self.viewer.render()

    def reset(self): # 从摄像机渲染视图,并将图像作为 numpy.ndarray 返回。
        self.sim.reset()
        for i in range(6):
            self.sim.data.qpos[i] = 0
            self.sim.data.qvel[i] = 0
        self.state = [0] * 12
        return np.array(self.state, dtype=np.float32) # 返回reset后的状态

    def close(self):
        pass



2、spinning up框架介绍

这是一个性能上比较优秀的强化学习框架,它的实验细节比baseline好很多。spinningup的训练效率也比baseline好很多。


支持mpi多线程

Ubuntu: sudo apt-get update && sudp apt-get install libopenmpi-dev
Mac OSX: brew install openmpi

注意:这里不建议使用conda,可能会出问题


安装spinningup

git clone https://github.com/openai/spinningup.git
cd spinningup
pip install -e .


安装测试

python -m spinup.run ppo --hid "[32,32]"--env LunarLander-v2 --exp_name installtest --gamma 0.999
python -m spinup.run test_policy data/installtest/installtest_s0
python -m spinup.run plot data/installtest/installtest_so



3、训练+运行


使用spinningup强化学习算法库编写的运行程序

# main.py
from spinup import ppo_pytorch as ppo
from UR5_Controller import UR5_Controller
from spinup.utils.test_policy import load_policy_and_env, run_policy
import torch

TRAIN = 0  # 0:运行已有信息      1:训练模式

target = [0, -1.57, 1.57, 0, 0, 0] # 目标位置
env = lambda : UR5_Controller(target) # 创建环境

if TRAIN:
    ac_kwargs = dict(hidden_sizes=[64,64], activation=torch.nn.ReLU) # 网络参数:隐藏层,激活函数
    logger_kwargs = dict(output_dir='log', exp_name='ur5_goToTarget')# 记录和输出信息的函数

    ppo(env, ac_kwargs=ac_kwargs, logger_kwargs=logger_kwargs,
        steps_per_epoch=5000, epochs=4000)

else:
    _, get_action = load_policy_and_env('log') # 读取log中的信息
    env_test = UR5_Controller(target)
    run_policy(env_test, get_action)



版权声明:本文为shouchen1原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。