源码分析:ReentrantLock、Semaphore以及CountDownLatch源码以及对应的设计模式

  • Post author:
  • Post category:其他

最近懵懵懂懂的看完了AQS的源码(源码分析:AQS源码),还是有很多不懂的地方,感觉还是要多来几遍的,为了更深入的理解AQS框架,看一下使用AQS的ReentrantLock、Semaphore以及CountDownLatch,直接上代码吧,解释都在注释里

/**
 * 这里是重入锁,我们需要关注一下重入锁是怎么实现的,两个条件:
 * 1. 在线程获取锁的时候,如果已经获取锁的线程是当前线程的话则直接再次获取成功
 * 2. 由于锁会被获取n次,那么只有锁在被释放同样的n次之后,该锁才算是完全释放成功
 **/
public class ReentrantLock implements Lock, java.io.Serializable {

    /**
     * 默认是非公平的锁,通过以下代码可以看出来
     *    public ReentrantLock() {
     *   sync = new NonfairSync();
     * }
     * 如果要实现公平锁则通过指定参数的方式进行
     *   public ReentrantLock(boolean fair) {
     *   sync = fair ? new FairSync() : new NonfairSync();
     * }
    private final Sync sync;

    /**
     * 实现了AQS同步框架的同步器,ReentrantLock根据自己的需求实现了同步器
     */
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = -5179523762034025860L;

        /**
         * Performs {@link Lock#lock}. The main reason for subclassing
         * is to allow fast path for nonfair version.
         */
        abstract void lock();

        /**
         * 为了支持重入性,增加了额外的处理逻辑,如果该锁已经被线程所占有了,会继续检查占有线程是否为当前线程,
         * 如果是的话,同步状态加1返回true,表示可以再次获取成功。每次重新获取都会对同步状态进行加一的操作
         */
        final boolean nonfairTryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            /**
             * 获取Sync的state状态,在AQS中
             * 如果该锁未被任何线程占有,该锁能被当前线程获取
             **/
            int c = getState();
            if (c == 0) {
                if (compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            }
            /**
             * 若被占有,检查占有线程是否是当前线程
             **/
            else if (current == getExclusiveOwnerThread()) {
                /**
                 * 再次获取,计数加一
                 **/
                int nextc = c + acquires;
                if (nextc < 0) // overflow
                    throw new Error("Maximum lock count exceeded");
                setState(nextc);
                return true;
            }
            return false;
        }

        /**
         * 重入锁的释放必须得等到同步状态为0时锁才算成功释放,否则锁仍未释放。
         * 如果锁被获取n次,释放了n-1次,该锁未完全释放返回false,只有被释放n次才算成功释放,返回true。
         **/
        protected final boolean tryRelease(int releases) {
            /**
             * 同步状态减1
             **/
            int c = getState() - releases;
            if (Thread.currentThread() != getExclusiveOwnerThread())
                throw new IllegalMonitorStateException();
            boolean free = false;
            /**
             * 只有当同步状态为0时,锁成功被释放,返回true
             **/
            if (c == 0) {
                free = true;
                setExclusiveOwnerThread(null);
            }
            /**
             * 锁未被完全释放,返回false
             **/
            setState(c);
            return free;
        }

        /**
         * 当前线程持有锁
         **/
        protected final boolean isHeldExclusively() {
            return getExclusiveOwnerThread() == Thread.currentThread();
        }

        final ConditionObject newCondition() {
            return new ConditionObject();
        }

        /**
         * 获取当前锁持有者
         **/
        final Thread getOwner() {
            return getState() == 0 ? null : getExclusiveOwnerThread();
        }

        /**
         * 获取保持锁定的线程数
         **/
        final int getHoldCount() {
            return isHeldExclusively() ? getState() : 0;
        }

        /**
         * 只要不等于1就说明是上锁了的
         **/
        final boolean isLocked() {
            return getState() != 0;
        }
    }

    /**
     * 非公平的同步器
     */
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = 7316153563782823691L;

        /**
         * 尝试加锁,如果状态改变成功则设置让当前线程持有锁
         * 否则的话进入队列
         */
        final void lock() {
            if (compareAndSetState(0, 1))
                setExclusiveOwnerThread(Thread.currentThread());
            else
                acquire(1);
        }

        /**
         * 此处是非公平重入锁的逻辑
         **/
        protected final boolean tryAcquire(int acquires) {
            return nonfairTryAcquire(acquires);
        }
    }

    /**
     * 公平锁每次获取到锁为同步队列中的第一个节点,保证请求资源时间上的绝对顺序,而非公平锁有可能刚释放锁的线程下次继续获取该锁,
     * 则有可能导致其他线程永远无法获取到锁,造成“饥饿”现象。
     * 
     * 公平锁为了保证时间上的绝对顺序,需要频繁的上下文切换,而非公平锁会降低一定的上下文切换,降低性能开销。
     * 因此,ReentrantLock默认选择的是非公平锁,则是为了减少一部分上下文切换,保证了系统更大的吞吐量
     */
    static final class FairSync extends Sync {
        private static final long serialVersionUID = -3000897897090466540L;

        /**
         * 直接进入队列,不像非公平锁上来先抢一抢
         **/
        final void lock() {
            acquire(1);
        }

        /**
         * hasQueuedPredecessors方法用来判断当前节点在同步队列中是否有前驱节点的判断,如果有前驱节点说明有线程比当前线程更早的请求资源,
         * 根据公平性,当前线程请求资源失败。如果当前节点没有前驱节点的话,再才有做后面的逻辑判断的必要性。
         * 公平锁每次都是从同步队列中的第一个节点获取到锁,而非公平性锁则不一定,有可能刚释放锁的线程能再次获取到锁。
         */
        protected final boolean tryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (!hasQueuedPredecessors() &&
                    compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            }
            else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0)
                    throw new Error("Maximum lock count exceeded");
                setState(nextc);
                return true;
            }
            return false;
        }
    }

    /**
     * ReentrantLock支持两种锁:公平锁和非公平锁。何谓公平性,是针对获取锁而言的,如果一个锁是公平的,那么锁的获取顺序就应该符合请求上的绝对时间顺序,满足FIFO
     * 默认是非公平锁
     */
    public ReentrantLock() {
        sync = new NonfairSync();
    }

    /**
     * 传递参数构造参数
     */
    public ReentrantLock(boolean fair) {
        sync = fair ? new FairSync() : new NonfairSync();
    }

    /**
     * 根据不同的同步器调用lock方法,是抢占的还是直接进入队列的
     */
    public void lock() {
        sync.lock();
    }

    /**
     *  sync.acquireInterruptibly(1);源码:
     *     public final void acquireInterruptibly(int arg)
     *        throws InterruptedException {
     *       if (Thread.interrupted())
     *          throw new InterruptedException();
     *       if (!tryAcquire(arg))
     *          doAcquireInterruptibly(arg);
     *    }
     * 这里调用了AbstractQueuedSynchronizer.acquireInterruptibly方法。
     * 如果线程已被中断则直接抛出异常,否则则尝试获取锁,失败则doAcquireInterruptibly()
     * lock():若lock被thread A取得,thread B会进入block状态,直到取得lock。
     * tryLock():如果当下不能取得lock,thread就会放弃。
     * lockInterruptibly():跟lock()情況一下,但是thread B可以通过interrupt被唤醒处理InterruptedException。
     * https://pandaforme.github.io/2016/12/09/Java-lock%E3%80%81tryLock%E5%92%8ClockInterruptibly%E7%9A%84%E5%B7%AE%E5%88%A5/
     */
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireInterruptibly(1);
    }

    /**
     * 尝试加锁,得不到锁就直接return false,可以看上边的源码
     */
    public boolean tryLock() {
        return sync.nonfairTryAcquire(1);
    }

    /**
     * 在acquireInterruptibly基础上增加了超时等待功能,在超时时间内没有获得同步状态返回false
     */
    public boolean tryLock(long timeout, TimeUnit unit)
            throws InterruptedException {
        return sync.tryAcquireNanos(1, unit.toNanos(timeout));
    }

    /**
     * 释放锁
     */
    public void unlock() {
        sync.release(1);
    }
}

Semaphore可以用于做流量控制,特别公用资源有限的应用场景,比如数据库连接。假如有一个需求,要读取几万个文件的数据,因为都是IO密集型任务,我们可以启动几十个线程并发的读取,但是如果读到内存后,还需要存储到数据库中,而数据库的连接数只有10个,这时我们必须控制只有十个线程同时获取数据库连接保存数据,否则会报错无法获取数据库连接。这个时候,我们就可以使用Semaphore来做流控,它的实现原理和上边的方式大同小异,先看一下是如何限流的:

public class SemaphoreTest {

	private static final int THREAD_COUNT = 30;

	private static ExecutorService threadPool = Executors
			.newFixedThreadPool(THREAD_COUNT);

	private static Semaphore s = new Semaphore(10);

	public static void main(String[] args) {
		for (int i = 0; i < THREAD_COUNT; i++) {
			threadPool.execute(new Runnable() {
				@Override
				public void run() {
					try {
						s.acquire();
						System.out.println("save data");
						s.release();
					} catch (InterruptedException e) {
					}
				}
			});
		}

		threadPool.shutdown();
	}
}

看看实现的源码:

 /**
     * 调用AQS方法sync.acquireSharedInterruptibly(1)
     **/
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * AQS方法
     **/
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();//线程被中断抛出中断异常
        if (tryAcquireShared(arg) < 0)//尝试获取锁
            doAcquireSharedInterruptibly(arg);//调用方法doAcquireSharedInterruptibly
    }

    /**
     * 同步器自己方法
     **/
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        /**
         * 就是设置了state的值
         * Sync(int permits) {
         *   setState(permits);
         * }
         **/
        NonfairSync(int permits) {
            super(permits);
        }

        /**
         * AQS中调用tryAcquireShared
         **/
        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    /**
     * 进一步调用这个方法
     * compareAndSetState(available, remaining)将减完之后的值存入,成功的话返回剩余值
     **/
    final int nonfairTryAcquireShared(int acquires) {
        for (;;) {
            int available = getState();
            int remaining = available - acquires;
            if (remaining < 0 ||
                compareAndSetState(available, remaining))
                return remaining;
        }
    }

    /**
     * AQS方法
     * 方法为一个自旋方法会尝试一直去获取同步状态
     * 尝试着解除限制
     * https://zhuanlan.zhihu.com/p/38165635
     **/
    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        /**
         * 将当前线程包装为类型为 Node.SHARED 的节点,标示这是一个共享节点。
         **/
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                /**
                 * 首次循环 拿到SHARED节点的前置节点,当然就是head节点了
                 **/
                final Node p = node.predecessor();
                if (p == head) {
                    /**
                     * 再次查看state状态
                     * tryAcquireShared 返回就两个值 如果需要阻塞,r=-1,不需要阻塞,r=1
                     **/
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

    /**
     * 该方法主要靠前驱节点判断当前线程是否应该被阻塞
     **/
    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        //前驱节点
        int ws = pred.waitStatus;
        //状态为signal,表示当前线程处于等待状态,直接放回true
        if (ws == Node.SIGNAL)
            return true;
        //前驱节点状态 > 0 ,则为Cancelled,表明该节点已经超时或者被中断了,需要从同步队列中取消
        if (ws > 0) {
            do {
                node.prev = pred = pred.prev;
            } while (pred.waitStatus > 0);
            pred.next = node;
        }
        //前驱节点状态为Condition、propagate
        else {
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }

    /**
     * 挂起当前线程
     **/
    private final boolean parkAndCheckInterrupt() {
        LockSupport.park(this);
        return Thread.interrupted();
    }

再看一下CountDownLatch的使用,然后看下源码:

/**
 * 常用的方法有下边三个
 **/
CountDownLatch(int count) //实例化一个倒计数器,count指定计数个数
countDown() // 计数减一
await() //等待,当计数减到0时,所有线程并行执行

/**
 * 计数数量为10,这表示需要有10个线程来完成任务,等待在CountDownLatch上的线程才能继续执行。
 * latch.countDown();方法作用是通知CountDownLatch有一个线程已经准备完毕,倒计数器可以减一了。
 * latch.await()方法要求主线程等待所有10个检查任务全部准备好才一起并行执行。
 **/
public class CountDownLatchDemo implements Runnable{

    static final CountDownLatch latch = new CountDownLatch(10);
    static final CountDownLatchDemo demo = new CountDownLatchDemo();

    @Override
    public void run() {
        // 模拟检查任务
        try {
            Thread.sleep(new Random().nextInt(10) * 1000);
            System.out.println("check complete");
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            //计数减一
            //放在finally避免任务执行过程出现异常,导致countDown()不能被执行
            latch.countDown();
        }
    }


    public static void main(String[] args) throws InterruptedException {
        ExecutorService exec = Executors.newFixedThreadPool(10);
        for (int i=0; i<10; i++){
            exec.submit(demo);
        }

        // 等待检查
        latch.await();

        // 发射火箭
        System.out.println("Fire!");
        // 关闭线程池
        exec.shutdown();
    }
}

源码:

    /**
     * 实现同步器AQS
     **/
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        /**
         * 设置初始值,一般就是我们设置的countdown的个数
         **/
        Sync(int count) {
            setState(count);
        }

        /**
         * 获取个数
         **/
        int getCount() {
            return getState();
        }

        /**
         * 同样是实现了父类的方法,子类调用
         **/
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        /**
         * 每次减值就是原状态减去1
         **/
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    /**
     * 调用了同步器的构造方法
     **/
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    /**
     * 触发当前线程进入wait状态,state减为0后唤醒,关于这个方法的具体代码可以看信号量里的
     **/
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * countDown释放1
     **/
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * AQS方法,调用子类tryReleaseShared的实现
     **/
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }

这三个类里边都有一种方法,他们在父类里边定义,在子类中实现,其他类在引用子类时,使用此方法,比如这个tryAcquireShared。我们要知道子类继承了父类的所有方法,同时同名情况下子类的实现会覆盖父类的实现,这是什么设计模式呢?模版方法设计模式:模板方法模式是基于”继承“的;1、将相同部分的代码放在抽象的父类中,而将不同的代码放入不同的子类中;2、通过一个父类调用其子类的操作,通过对子类的具体实现扩展不同的行为,实现了反向控制


版权声明:本文为maoyeqiu原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。