【线程同步工具】CyclicBarrier源码分析

  • Post author:
  • Post category:其他




在指定状态点同步任务

Java 并发 API 提供了可以使多个线程在一个指定点同步的工具类 CyclicBarrier,该类前文介绍的 CountDownLatch 有些类似,但是它的一些特殊性使得其更为强大。

CyclicBarrier 类的构造器需要有一个整型参数,这个参数表示在指定点进行同步 的线程个数。当需要同步的线程运行到指定点时,可以调用 CyclicBarrier 对象的 await()方法,然后等待其他线程达到指定点。这个方法使调用线程休眠,等待其他线程的到达。当最后一个需要同步的线程到达并调用 CyclicBarrier 对象的 await()方法 时( 意味着所有的线程均已经到达需要同步的指定点。),所有在此等待的线程都会被唤醒并继续执行。

20181218144511688.gif

与CountDownLatch相比CyclicBarrier 可以有不止一个栅栏,它的栅栏(Barrier)可以重复使用(Cyclic)。

image.png

另外CyclicBarrier 类有一个有趣的优点,它提供的一个重载的构造方法可以额外接收一个 Runnable 类型的初始化参数。这样的话,CyclicBarrier 对象会在所有同步线程到达指定点后,将 Runnable 对象当作一个线程对象来执行。这个特点使得该类适合采用分而治之的编程技术来处理并发任务。



源码分析



主要属性

private static class Generation {
    boolean broken = false;
}

/** The lock for guarding barrier entry */
private final ReentrantLock lock = new ReentrantLock();
/** Condition to wait on until tripped */
private final Condition trip = lock.newCondition();
/** The number of parties */
private final int parties;
/* The command to run when tripped */
private final Runnable barrierCommand;
/** The current generation */
private Generation generation = new Generation();

private int count;

CyclicBarrier的核心属性共有6个,我们将它分为三组。

/** The number of parties */
private final int parties;
/**
* Number of parties still waiting. Counts down from parties to 0
* on each generation.  It is reset to parties on each new
* generation or when broken.
*/
private int count;

注意,这两个属性都是用来表示参与Barrier的线程数量,parties是Barrier要等待的线程数,也就是需要在该Barrier上进行同步的线程总数,是在构造函数中初始化并且不可更改的;count属性和CountDownLatch中的count一样,代表还需要等待的线程数,初始值为parties,每当一个线程到来就减一,如果该值为0,则说明所有的线程都到齐了,大家可以一起通过Barrier了。

第二组:

/** The lock for guarding barrier entry */
private final ReentrantLock lock = new ReentrantLock();
/** Condition to wait on until tripped */
private final Condition trip = lock.newCondition();
/** The current generation */
private Generation generation = new Generation();

这一组代表了CyclicBarrier的基础实现,使用了独占锁ReentrantLock和条件队列Condition。lock和trip是用来管理线程同步的,lock保证同一时刻只有一个线程可以访问障碍器,而trip则是一个等待队列,保存等待的线程。generation,用于标记当前障碍器的Barrier。

generation 这样解释你可能无法理解,再来看个形象点的解释。

由于CyclicBarrier是可重复使用的,所以我们把每一个新的barrier称为 generation。这个怎么理解呢,打个比方:一个过山车有10个座位,景区常常需要等够10个人了,才会去开动过山车。于是我们常常在栏杆(barrier)外面等,等凑够了10个人,工作人员就把栏杆打开,让10个人通过;然后再将栏杆归位,后面新来的人还是要在栏杆外等待。这里,前面已经通过的人就是一个generation,后面再继续等待的一波人就是另外一个 generation,栏杆每打开关闭一次,就产生新的 generation。

第三组:

/* The command to run when tripped */
private final Runnable barrierCommand;

这个属性是一个可选的Runnable,,表示当所有线程都到达障碍器时要执行的任务。如果不需要执行任务,则可以将该属性设置为null。这是一个很有趣的功能,在下面的使用案例中我会使用这个特点采用分而治之的编程技术来处理并发任务。

再用一张图来描绘下 CyclicBarrier 的一些属性和业务逻辑:

image.png




构造函数

CyclicBarrier有两个构造函数:

/**
* 创建一个新的CyclicBarrier,当给定数量的线程正在等待它时会被触发,屏障触发时不执行任何预定义的操作。
*
* @param parties 等待触发屏障的线程数量
* @throws IllegalArgumentException 如果parties小于1
*/
public CyclicBarrier(int parties) {
    this(parties, null); // 调用另一个构造函数,屏障触发时执行的操作为null
}

/**
* 创建一个新的CyclicBarrier,当给定数量的线程正在等待它时会被触发,
* 并且当屏障被触发时会由最后一个进入屏障的线程执行给定的屏障操作。
*
* @param parties 等待触发屏障的线程数量
* @param barrierAction 屏障触发时执行的操作,如果没有操作则为null
* @throws IllegalArgumentException 如果parties小于1
*/
public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    this.parties = parties; // 等待触发屏障的线程数量
    this.count = parties; // 等待触发屏障的线程数量
    this.barrierCommand = barrierAction; // 屏障触发时执行的操作,如果没有操作则为null
}



辅助方法




nextGeneration()

private void nextGeneration() {
    // 唤醒当前这一代中所有等待在条件队列里的线程
    trip.signalAll();
    // 恢复count值,开启新的一代
    count = parties;
    generation = new Generation();
}

该方法用于开启新的“一代”,通常是被最后一个调用await方法的线程调用。在该方法中,我们的主要工作就是唤醒当前这一代中所有等待在条件队列里的线程,将count的值恢复为parties,以及开启新的一代。




breakBarrier()

breakBarrier即打破现有的栅栏,让所有线程通过:

private void breakBarrier() {
    // 标记broken状态
    generation.broken = true;
    // 恢复count值
    count = parties;
    // 唤醒当前这一代中所有等待在条件队列里的线程(因为栅栏已经打破了)
    trip.signalAll();
}

这个breakBarrier怎么理解呢,继续拿上面过上车的例子打比方,有时候某个时间段,景区的人比较少,等待过山车的人数凑不够10个人,眼看后面迟迟没有人再来,这个时候有的工作人员也会打开栅栏,让正在等待的人进来坐过山车。这里工作人员的行为就是breakBarrier,由于并不是在凑够10个人的情况下就开启了栅栏,我们就把这一代的broken状态标记为true。




reset()

public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        breakBarrier();   // break the current generation
        nextGeneration(); // start a new generation
    } finally {
        lock.unlock();
    }
}

CyclicBarrier 类提供的 reset()方法可以实现重置。当调用这个方法后,所有因调用 await()方法而休眠等待的线程,将收到 BrokenBarrierException 异常。在本 案例的异常处理中,会将异常信息栈打印输出。然而,在更复杂的应用中,可以在此处进 行其他处理,比如重启执行或者恢复中断点的操作。

值得注意的是,该方法执行前需要先获得锁。



await 方法

看完前面的辅助方法之后,接下来我们就来看 CyclicBarrier 最核心的 await方法,可以说整个CyclicBarrier最关键的只有它了。它也是一个

集“countDown”和“阻塞等待”于一体的方法。


await方法有两种版本,一种带超时机制,一种不带,然而从源码上看,它们最终调用的都是带超时机制的dowait方法。

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}
public int await(long timeout, TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException {
    return dowait(true, unit.toNanos(timeout));
}

CyclicBarrier 对象可以处于一个特殊的状态,被称为“损坏状态”。当多个线程 因调用 await() 方法而等待时,若其中一个被中断了,则此线程会收到 InterruptedException 异常,而其他线程将收到 BrokenBarrierException 异常,并且 CyclicBarrier 对象会进入损坏状态。 CyclicBarrier 类提供了 isBroken()方法。如果 CyclicBarrier 对象处于损坏 状态,则调用该方法将返回 true;否则,将返回 false。



dowait 方法

/**
* Barrier 主要代码,包含各种执行策略
*/
private int dowait(boolean timed, long nanos) throws InterruptedException, BrokenBarrierException, TimeoutException {
    final ReentrantLock lock = this.lock;
    // 所有执行await方法的线程必须是已经持有了锁,所以这里必须先获取锁
    lock.lock();
    try {
        final Generation g = generation;

        // 前面说过,调用breakBarrier会将当前“代”的broken属性设为true
        // 如果一个正在await的线程发现barrier已经被break了,则将直接抛出BrokenBarrierException异常
        if (g.broken)
            throw new BrokenBarrierException();

        // 如果当前线程被中断了,则先将栅栏打破,再抛出InterruptedException
        // 这么做的原因是,所以等待在barrier的线程都是相互等待的,如果其中一个被中断了,那其他的就不用等了。
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        // 当前线程已经来到了栅栏前,先将等待的线程数减一
        int index = --count;
        
        // 如果等待的线程数为0了,说明所有的parties都到齐了
        // 则可以唤醒所有等待的线程,让大家一起通过栅栏,并重置栅栏
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    // 如果创建CyclicBarrier时传入了barrierCommand
                    // 说明通过栅栏前有一些额外的工作要做
                    command.run(); 
                ranAction = true;
                // 唤醒所有线程,开启新一代
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // 如果count数不为0,就将当前线程挂起,直到所有的线程到齐,或者超时,或者中断发生
        for (;;) {
            try {
                // 如果没有设定超时机制,则直接调用condition的await方法
                if (!timed)
                    trip.await();  // 当前线程在这里被挂起
                else if (nanos > 0L)
                    // 如果设了超时,则等待指定的时间
                    nanos = trip.awaitNanos(nanos); // 当前线程在这里被挂起,超时时间到了就会自动唤醒
            } catch (InterruptedException ie) {
                // 执行到这里说明线程被中断了
                // 如果线程被中断时还处于当前这一“代”,并且当前这一代还没有被broken,则先打破栅栏
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // 注意来到这里有两种情况
                    // 一种是g!=generation,说明新的一代已经产生了,所以我们没有必要处理这个中断,只要再自我中断一下就好,交给后续的人处理
                    // 一种是g.broken = true, 说明中断前栅栏已经被打破了,既然中断发生时栅栏已经被打破了,也没有必要再处理这个中断了
                    Thread.currentThread().interrupt();
                }
            }

            // 注意,执行到这里是对应于线程从await状态被唤醒了
            
            // 这里先检测broken状态,能使broken状态变为true的,只有breakBarrier()方法,到这里对应的场景是
            // 1. 其他执行await方法的线程在挂起前就被中断了
            // 2. 其他执行await方法的线程在还处于等待中时被中断了
            // 2. 最后一个到达的线程在执行barrierCommand的时候发生了错误
            // 4. reset()方法被调用
            if (g.broken)
                throw new BrokenBarrierException();

            // 如果线程被唤醒时,新一代已经被开启了,说明一切正常,直接返回
            if (g != generation)
                return index;

            // 如果是因为超时时间到了被唤醒,则打破栅栏,返回TimeoutException
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

dowait()方法里的整个逻辑分成两部分:

(1)最后一个线程走上面的逻辑,当count减为0的时候,打破栅栏,它调用nextGeneration()方法通知条件队列中的等待线程转移到AQS的队列中等待被唤醒,并进入下一代。

(2)非最后一个线程走下面的for循环逻辑,这些线程会阻塞在condition的await()方法处,它们会加入到条件队列中,等待被通知,当它们唤醒的时候已经更新换“代”了,这时候返回。

值得注意的是,

await方法是有返回值的,代表了线程到达的顺序

,第一个到达的线程的index为 parties – 1,最后一个到达的线程的index为 0



使用案例



案例一:赛马



Horse类

public class Horse implements Runnable {
    // 静态变量,用来记录马的数量
    private static int counter = 0;
    // 每匹马的编号
    private final int id = counter++;
    // 记录马已经走过的距离
    private int strides = 0;
    // 随机数生成器,用来生成每次随机跑几步
    private static Random rand = new Random(47);
    // CyclicBarrier 对象,用来控制所有马同时开始比赛
    private static CyclicBarrier barrier;

    // 构造函数,传入 CyclicBarrier 对象
    public Horse(CyclicBarrier b) { barrier = b; }

    // 实现 Runnable 接口中的 run() 方法
    @Override
    public void run() {
        try {
            // 只要线程没有被中断,就一直循环
            while(!Thread.interrupted()) {
                // 随机跑几步
                strides += rand.nextInt(3);
                // 等待所有的马准备好,才开始比赛
                barrier.await();
            }
        } catch(Exception e) {
            e.printStackTrace();
        }
    }

    // 返回当前马已经走过的距离的字符串表示
    public String tracks() {
        StringBuilder s = new StringBuilder();
        for(int i = 0; i < getStrides(); i++) {
            s.append("*");
        }
        s.append("Horse-" + id);
        return s.toString();
    }

    // 返回当前马已经走过的距离
    public int getStrides() { return strides; }
    // 返回当前马的编号
    public String toString() { return "Horse " + id + " "; }
}



HorseRace类

public class HorseRace implements Runnable {

    private static final int FINISH_LINE = 80;

    private List<Horse> horses; // 赛马列表
    private ExecutorService exec; // 线程池

    public HorseRace(List<Horse> horses, ExecutorService exec) {
        this.horses = horses;
        this.exec = exec;
    }

    @Override
    public void run() {
        StringBuilder s = new StringBuilder();
        // 打印赛道边界
        for (int i = 0; i < FINISH_LINE; i++) {
            s.append("=");
        }
        System.out.println(s);
        // 打印赛马轨迹
        for (Horse horse : horses) {
            System.out.println(horse.tracks());
        }
        // 判断是否结束
        for (Horse horse : horses) {
            if (horse.getStrides() >= FINISH_LINE) {
                System.out.println(horse + "won!");
                exec.shutdownNow(); // 中断线程池中的所有线程
                return;
            }
        }
        // 休息指定时间再到下一轮
        try {
            TimeUnit.MILLISECONDS.sleep(200);
        } catch (InterruptedException e) {
            System.out.println("barrier-action sleep interrupted");
        }
    }

}



Main类

public class Main {

    // 创建一个马的列表,使用ArrayList存储
    private static List<Horse> horses = new ArrayList<Horse>();

    // 创建一个线程池,使用newCachedThreadPool方法创建一个可缓存的线程池
    private static ExecutorService exec = Executors.newCachedThreadPool();

    public static void main(String[] args) {
        // 创建一个CyclicBarrier对象,用于等待所有马准备就绪后开始比赛
        // 这里设置7个参与者,当有7个线程到达await方法时,会执行HorseRace的run方法
        CyclicBarrier barrier = new CyclicBarrier(7, new HorseRace(horses, exec));

        // 循环创建7匹马,将它们添加到马列表中,并且提交到线程池中执行
        for(int i = 0; i < 7; i++) {
            Horse horse = new Horse(barrier);
            horses.add(horse);
            exec.execute(horse);
        }
    }
}

该赛马程序主要是通过在控制台不停的打印各赛马的当前轨迹,以此达到动态显示的效果。整场比赛有多个轮次,每一轮次各个赛马都会随机走上几步然后调用await方法进行等待,当所有赛马走完一轮的时候将会执行任务将所有赛马的当前轨迹打印到控制台上。这样每一轮下来各赛马的轨迹都在不停的增长,当其中某个赛马的轨迹最先增长到指定的值的时候将会结束整场比赛,该赛马成为整场比赛的胜利者!

程序的运行结果如下:

20181218144511713.gif



案例二:分而治之



MatrixMock 类

/**
 * 该类生成一个介于1和10之间的整数随机矩阵
 */
public class MatrixMock {

	/**
	 * 包含随机数的二维数组
	 */
	private final int data[][];

	/**
	 * 类的构造函数。生成二维数组。
	 * 在生成数组时,计算出指定要查找的数字出现的次数,以便检查CyclicBarrier类的工作是否良好。
	 * @param size 数组的行数
	 * @param length 数组的列数
	 * @param number 要查找的数字
	 */
	public MatrixMock(int size, int length, int number){

		int counter=0; // 计数器初始化为0
		data=new int[size][length]; // 生成指定行数和列数的二维整数数组
		Random random=new Random(); // 创建一个Random对象
		for (int i=0; i<size; i++) {
			for (int j=0; j<length; j++){
				data[i][j]=random.nextInt(10); // 用随机数填充数组
				if (data[i][j]==number){ // 如果数组中的元素等于指定数字
					counter++; // 计数器递增
				}
			}
		}
		System.out.printf("Mock: 在生成的数据中有%d个数字%d的出现。\n", counter, number); // 输出指定数字出现的次数
	}


	/**
	 * 该方法返回二维数组的一行
	 * @param row 要返回的行号
	 * @return 所选行
	 */
	public int[] getRow(int row){
		if ((row>=0)&&(row<data.length)){ // 如果指定的行号在数组的索引范围内
			return data[row]; // 返回指定行的数组
		}
		return null; // 否则返回null
	}


}



Results类

/**
 * 该类用于存储我们在二维数组的每行中查找的数字的出现次数
 */
public class Results {

	/**
	 * 用于存储每行中数字出现次数的数组
	 */
	private final int data[];

	/**
	 * 该类的构造函数,用于初始化它的属性
	 * @param size 用于存储结果的数组的大小
	 */
	public Results(int size){
		data=new int[size];
	}

	/**
	 * 设置结果数组中一个位置的值
	 * @param position 数组中的位置
	 * @param value 要设置的值
	 */
	public void setData(int position, int value){
		data[position]=value;
	}

	/**
	 * 返回结果数组
	 * @return 结果数组
	 */
	public int[] getData(){
		return data;
	}
}



Search类

/**
 * 在二维数组的一组行中查找指定的数字
 */
public class Searcher implements Runnable {

	/**
	 * 需要查找的数字
	 */
	private final int number;

	/**
	 * 待查找的二维数组
	 */
	private final MatrixMock mock;

	/**
	 * 结果数组
	 */
	private final Results results;

	/**
	 * 需要查找的行范围
	 */
	private final int firstRow, lastRow;

	/**
	 * 同步屏障,用于控制任务的执行
	 */
	private final CyclicBarrier barrier;

	/**
	 * 构造函数,用于初始化类属性
	 * @param firstRow 查找的起始行号
	 * @param lastRow 查找的结束行号
	 * @param mock 待查找的二维数组
	 * @param results 存储结果的结果数组
	 * @param number 需要查找的数字
	 * @param barrier 同步屏障,用于控制任务的执行
	 */
	public Searcher(int firstRow, int lastRow, MatrixMock mock, Results results, int number, CyclicBarrier barrier) {
		this.firstRow = firstRow;
		this.lastRow = lastRow;
		this.mock = mock;
		this.results = results;
		this.number = number;
		this.barrier = barrier;
	}

	/**
	 * 每行查找数字并将结果存储到结果数组中
	 */
	@Override
	public void run() {
		int counter;
		System.out.printf("%s: 正在处理从第%d行到第%d行的数据\n", Thread.currentThread().getName(), firstRow, lastRow);
		for (int i = firstRow; i < lastRow; i++) {
			int row[] = mock.getRow(i); // 获取第i行
			counter = 0;
			for (int j = 0; j < row.length; j++) {
				if (row[j] == number) { // 在行中查找数字
					counter++; // 数字出现次数+1
				}
			}
			results.setData(i, counter); // 将数字出现次数存储到结果数组中
		}
		System.out.printf("%s: 处理完毕\n", Thread.currentThread().getName());
		try {
			barrier.await(); // 等待其他任务完成
		} catch (InterruptedException e) {
			e.printStackTrace();
		} catch (BrokenBarrierException e) {
			e.printStackTrace();
		}
	}
}

Group类

/**
 * 将每个Searcher的结果分组。将存储在Results对象中的值相加
 * 当所有Searcher完成其工作时,此类的对象会被CyclicBarrier自动执行
 */
public class Grouper implements Runnable {

	/**
	 * 存储每一行中数字出现次数的Results对象
	 */
	private final Results results;

	/**
	 * 类的构造方法。初始化它的属性
	 * @param results 存储每一行数字出现次数的Results对象
	 */
	public Grouper(Results results){
		this.results=results;
	}

	/**
	 * Grouper的主方法。将存储在Results对象中的值相加
	 */
	@Override
	public void run() {
		int finalResult=0;
		System.out.printf("Grouper: 数据个数统计中...\n");
		int data[]=results.getData();
		for (int number:data){
			finalResult+=number;
		}
		System.out.printf("Grouper: 数据出现的个数: %d.\n",finalResult);
	}

}

Main类

/**
 * 示例的主类
 */
public class Main {
	/**
	 * 示例的主方法
	 * @param args 命令行参数
	 */
	public static void main(String[] args) {
		/*
		 * 初始化二维数据数组
		 * 		10000行
		 * 		每行1000个数字
		 * 		查找数字5
		 */
		final int ROWS=10000;
		final int NUMBERS=1000;
		final int SEARCH=5;
		final int PARTICIPANTS=5;
		final int LINES_PARTICIPANT=2000;
		MatrixMock mock=new MatrixMock(ROWS, NUMBERS,SEARCH);

		// 初始化结果对象
		Results results=new Results(ROWS);

		// 创建 Grouper 对象
		Grouper grouper=new Grouper(results);

		// 创建 CyclicBarrier 对象。它有5个参与者,当他们完成时,CyclicBarrier 将执行 grouper 对象
		CyclicBarrier barrier=new CyclicBarrier(PARTICIPANTS,grouper);

		// 创建、初始化并启动5个 Searcher 对象
		Searcher searchers[]=new Searcher[PARTICIPANTS];
		for (int i=0; i<PARTICIPANTS; i++){
			searchers[i]=new Searcher(i*LINES_PARTICIPANT, (i*LINES_PARTICIPANT)+LINES_PARTICIPANT, mock, results, 5,barrier);
			Thread thread=new Thread(searchers[i]);
			thread.start();
		}
		System.out.printf("Main: 主线程完成.\n");
	}
}

这个案例所解决的问题比较简单,有一个大矩阵,其元素为 0~9 的随机整数,想要知道某个数字在其中出现的次数。为了提升查询效率,可以采用分而治之的方法,将矩阵划 分为 5 个子集,采用多线程的方式分别统计每个子集中某个数字出现的次数。这些线程执行的任务为 Searcher 对象。 在这里,可以采用一个 CyclicBarrier 对象来同步 5 个线程的完成状态 ( 使得在 5 个线程均完成统计任务后,执行 Grouper 任务。)。

对每个线程统计结果进行汇总得到最后结果。 正如之前介绍,CyclicBarrier 类定义了内部计数器,它对需要在同步状态点进行同步的线程数进行控制。每当一个线程执行到同步状态点时,它会通过调用 await()通知 CyclicBarrier 对象有一个线程已达到同步状态点,这时 CyclicBarrier 对象将内部计数器的值减 1,并将调用线程休眠,直到所有线程均达到同步状态点。 当所有线程均达到同步状态点后,CyclicBarrier 对象将唤醒所有因调用该对象 await()方法而休眠等待的线程,并且该对象还会创建一个新的线程来执行在CyclicBarrier 构造器中传入的Runnable 对象(在本问中,这指Grouper 对象)。



运行结果:

image.png



参考资料



版权声明:本文为weixin_45645846原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。