Java GUI实战:Swing 实现可视化马尔可夫决策(二)

  • Post author:
  • Post category:java




Java GUI实战:Swing 实现可视化马尔可夫决策(二)



核心算法

基于马尔可夫决策过程的强化学习算法

Agent 处在某一节点时,若选择向某个方向运动,就会得到一个状态期望值:对每个可能到达的后继节点,用到达该节点的概率乘以(到达该节点获得的奖励值 + 折现后的该节点当前状态值),再对所有后继节点求和。

用公式表示就是

$$\sum_{i} P_i\,\bigl(r[i] + \gamma\, v[i]\bigr)$$

其中 $P_i$ 是到达第 $i$ 个后继节点的概率,$r[i]$ 是到达该节点获得的奖励值,$v[i]$ 是该节点当前的状态值;$\gamma$ 是折现系数,满足 $0<\gamma<1$(若不乘 $\gamma$,则系统不收敛)。具体内容见下图:

在这里插入图片描述

对于一个节点(视为一个状态),Agent 有上、下、左、右四个可选的运动方向,每个方向对应一个期望值。在下面的代码中,沿选定方向前进的概率为 0.6,偏向该方向两侧斜角格子的概率各为 0.2;若目标格子越界或是石头,则留在原地。每迭代一次,该节点的状态值更新为这四个期望值中最大的那个。



马尔可夫决策代码

/**
 * Runs one synchronous value-iteration sweep over a square grid world and writes the
 * updated state values back into {@code values}.
 *
 * <p>Cells are stored row-major: cell (i, j) lives at index {@code i * side + j}, where
 * {@code side} is the square root of {@code values.length} (5 for the 5x5 board).
 * Cells whose reward is {@code 1} (exit), {@code -1} (hole) or {@code 2} (stone) get value 0,
 * and a stone can never be entered.
 *
 * <p>For every other cell, each of the four actions (up, down, left, right) reaches the
 * intended neighbour with probability 0.6 and each of the two diagonals beside it with
 * probability 0.2. A move that would leave the grid or enter a stone keeps the agent in place.
 * An outcome is worth {@code gamma * value + reward} of the cell actually reached. The cell's
 * new value is the largest of the four action expectations.
 *
 * @param values current state values; overwritten in place with the new values
 * @param rewards reward collected on entering each cell, same layout as {@code values}
 * @param gamma discount factor (the article requires 0 < gamma < 1)
 * @throws IllegalArgumentException if the arrays differ in length or do not form a square grid
 */
public static void nextValues(double[] values, double[] rewards, double gamma) {
    int cells = values.length;
    int side = (int) Math.round(Math.sqrt(cells));
    if (side * side != cells || rewards.length != cells) {
        throw new IllegalArgumentException(
                "values/rewards must describe the same square grid, got lengths "
                        + cells + " and " + rewards.length);
    }
    // Write into a fresh array so every cell is updated from the previous sweep's values.
    double[] newVal = new double[cells];
    for (int i = 0; i < side; i++) {
        for (int j = 0; j < side; j++) {
            int k = i * side + j;
            if (rewards[k] == 1 || rewards[k] == -1 || rewards[k] == 2) {
                newVal[k] = 0;
                continue;
            }
            double upLeft = outcomeValue(values, rewards, gamma, side, k, i - 1, j - 1);
            double upMid = outcomeValue(values, rewards, gamma, side, k, i - 1, j);
            double upRight = outcomeValue(values, rewards, gamma, side, k, i - 1, j + 1);
            double downLeft = outcomeValue(values, rewards, gamma, side, k, i + 1, j - 1);
            double downMid = outcomeValue(values, rewards, gamma, side, k, i + 1, j);
            double downRight = outcomeValue(values, rewards, gamma, side, k, i + 1, j + 1);
            double midLeft = outcomeValue(values, rewards, gamma, side, k, i, j - 1);
            double midRight = outcomeValue(values, rewards, gamma, side, k, i, j + 1);

            // Each diagonal outcome is shared by the two actions beside it
            // (e.g. upLeft contributes to both "up" and "left").
            double up = 0.2 * upLeft + 0.6 * upMid + 0.2 * upRight;
            double down = 0.2 * downLeft + 0.6 * downMid + 0.2 * downRight;
            double left = 0.2 * upLeft + 0.2 * downLeft + 0.6 * midLeft;
            double right = 0.2 * upRight + 0.2 * downRight + 0.6 * midRight;

            // A later action replaces the running best only when strictly greater.
            double best = up;
            if (best < down) {
                best = down;
            }
            if (best < left) {
                best = left;
            }
            if (best < right) {
                best = right;
            }
            newVal[k] = best;
        }
    }
    System.arraycopy(newVal, 0, values, 0, cells);
}

/**
 * Returns {@code gamma * value + reward} of the cell reached when moving from {@code from}
 * towards (ni, nj). The agent stays on {@code from} if the target is off-grid or a stone (reward 2).
 */
private static double outcomeValue(
        double[] values, double[] rewards, double gamma, int side, int from, int ni, int nj) {
    int reached = from;
    if (ni >= 0 && ni < side && nj >= 0 && nj < side && rewards[ni * side + nj] != 2) {
        reached = ni * side + nj;
    }
    return gamma * values[reached] + rewards[reached];
}



 /**
  * Tells whether the coordinate (i, j) falls inside the 5x5 board.
  *
  * @param i row index
  * @param j column index
  * @return true when both indices are in [0, 4]
  */
 public static boolean checkBorder(int i, int j) {
     boolean rowInside = i >= 0 && i < 5;
     boolean columnInside = j >= 0 && j < 5;
     return rowInside && columnInside;
 }

 /**
  * Returns the largest of four doubles.
  *
  * <p>The arguments are scanned left to right. A later value replaces the running maximum only
  * when it is strictly greater, so a NaN in {@code d1} is returned unchanged. A NaN in any later
  * position is skipped.
  */
 public static double maxInFour(double d1, double d2, double d3, double d4) {
     double best = d1;
     for (double candidate : new double[] {d2, d3, d4}) {
         if (candidate > best) {
             best = candidate;
         }
     }
     return best;
 }



绑定NEXT按钮

next.addActionListener(new ActionListener() {
    /**
     * Advances value iteration by one sweep, then redraws every cell and the iteration counter.
     * Exit (reward 1) and hole (reward -1) cells show their reward; all others show their value.
     */
    @Override
    public void actionPerformed(ActionEvent e) {
        nextValues(values, rewards, gammaa[0]);
        for (int idx = 0; idx < 25; idx++) {
            boolean showReward = rewards[idx] == 1 || rewards[idx] == -1;
            double shown = showReward ? rewards[idx] : values[idx];
            grids.get(idx).setText(String.valueOf(shown));
        }
        count[0]++;
        showN.setText(String.valueOf(count[0]));
    }
});



自定义地图

edit.addActionListener(new ActionListener() {
    /**
     * Applies the cell type typed into {@code editAttr} to cell ({@code editX}, {@code editY})
     * and then refreshes every cell's label.
     *
     * <p>Type codes: 0 empty, 1 enter (start), 2 exit, 3 gold, 4 hole, 5 stone. Unknown codes
     * leave the cell unchanged. Input that is not an integer, or a coordinate outside [0, 4], is
     * rejected with a message dialog instead of throwing on the event-dispatch thread.
     */
    @Override
    public void actionPerformed(ActionEvent e) {
        int x;
        int y;
        int attr;
        try {
            x = Integer.parseInt(editX.getText().trim());
            y = Integer.parseInt(editY.getText().trim());
            attr = Integer.parseInt(editAttr.getText().trim());
        } catch (NumberFormatException ex) {
            javax.swing.JOptionPane.showMessageDialog(null, "X, Y and type must be integers.");
            return;
        }
        // Without this check, e.g. x=0, y=7 would silently edit cell (1, 2).
        if (x < 0 || x > 4 || y < 0 || y > 4) {
            javax.swing.JOptionPane.showMessageDialog(null, "X and Y must be between 0 and 4.");
            return;
        }
        int idx = x * 5 + y;
        switch (attr) {
            case 0: // empty
                values[idx] = 0;
                rewards[idx] = 0;
                grids.get(idx).setBackground(Color.white);
                break;
            case 1: // enter (start)
                values[idx] = 0;
                rewards[idx] = 0;
                grids.get(idx).setBackground(Color.BLUE);
                // NOTE(review): there should be only one start, but the previous start
                // cell's colour is not cleared here.
                startIndex[0] = idx;
                break;
            case 2: // exit
                values[idx] = 0;
                rewards[idx] = 1;
                grids.get(idx).setBackground(Color.RED);
                break;
            case 3: // gold
                values[idx] = 0;
                rewards[idx] = 0.25;
                grids.get(idx).setBackground(Color.YELLOW);
                break;
            case 4: // hole
                values[idx] = 0;
                rewards[idx] = -1;
                grids.get(idx).setBackground(Color.GRAY);
                break;
            case 5: // stone: reward 2 marks the cell as impassable in nextValues
                values[idx] = 2;
                rewards[idx] = 2;
                grids.get(idx).setBackground(Color.BLACK);
                break;
            default:
                // Unknown type code: leave the cell as it is.
                break;
        }
        for (int i = 0; i < 25; i++) {
            if (rewards[i] == 1 || rewards[i] == -1) {
                grids.get(i).setText(rewards[i] + "");
            } else {
                grids.get(i).setText(values[i] + "");
            }
        }
    }
});



REPLAY按钮

restart.addActionListener(new ActionListener() {
    /** Clears the board back to an all-empty grid and resets both counters and their labels. */
    @Override
    public void actionPerformed(ActionEvent e) {
        for (int cell = 0; cell < 25; cell++) {
            rewards[cell] = 0;
            values[cell] = 0;
            grids.get(cell).setBackground(Color.white);
            grids.get(cell).setText("0.0");
        }
        // Counters restart from zero and their labels are refreshed to match.
        stepCount[0] = 0;
        count[0] = 0;
        showN.setText(String.valueOf(count[0]));
        stepCountShow.setText(String.valueOf(stepCount[0]));
        // Editing becomes available again; stepping stays disabled until re-enabled elsewhere.
        edit.setEnabled(true);
        next.setEnabled(false);
    }
});



版权声明:本文为swy_swy_swy原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。