Golang源码学习(一)waitGroup实现原理

  • Post author:
  • Post category:golang


A WaitGroup waits for a collection of goroutines to finish. The main goroutine calls Add to set the number of goroutines to wait for. Then each of the goroutines runs and calls Done when finished. At the same time, Wait can be used to block until all goroutines have finished.

A WaitGroup must not be copied after first use.


WaitGroup用于等待一组协程(goroutine)运行至结束,主协程(相对而言)调用Add方法来设置需要阻塞等待的协程(goroutine)的数目。 每一个协程(goroutine)运行各自的逻辑,当执行完成时调用Done方法通知协程执行结束。同时,使用Wait方法阻塞主协程直到所有的协程(goroutine)均执行结束。

WaitGroup被首次使用后不可复制使用



1、代码位置

sync/waitgroup.go



2、使用方式

  • Add(delta int) :添加任务数,增加counter
  • Wait():阻塞等待所有任务的完成,增加waiter
  • Done():完成任务,减少counter

应用示例:

package main

import (
	"sync"
)

func main() {
	wg := &sync.WaitGroup{}

	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			//do anything
		}()
	}

	wg.Wait()
}



3、源码分析

type WaitGroup struct {
	noCopy noCopy //标识该对象首次使用后不能拷贝使用

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}



noCopy

标识该对象首次使用后不能拷贝使用,sync包中的很多锁相关逻辑都用到了这个标识,如:Lock



state1 成员变量

state1是一个uint32的数组,存储了64bit的数据,高32bit作为

执行者数(

counter



,低32位作为

等待者(

waiter



的计数。

因为64bit的原子操作需要64位对齐,所以32位的编译器不适用上述规则,所以将分配12字节并对齐8个字节实现协程和等待者的计数。 最后的4个字节(64bit)用于存储

信号量(

sema




在这里插入图片描述

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}



Add(delta int)

Add方法用来变更counter计数,当counter为正数,则waiter将在调用处阻塞直至counter为0,当counter为负数则触发panic。


// Add adds delta, which may be negative, to the WaitGroup counter.
// If the counter becomes zero, all goroutines blocked on Wait are released.
// If the counter goes negative, Add panics.
//
// Note that calls with a positive delta that occur when the counter is zero
// must happen before a Wait. Calls with a negative delta, or calls with a
// positive delta that start when the counter is greater than zero, may happen
// at any time.
// Typically this means the calls to Add should execute before the statement
// creating the goroutine or other event to be waited for.
// If a WaitGroup is reused to wait for several independent sets of events,
// new Add calls must happen after all previous Wait calls have returned.
// See the WaitGroup example.
func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state() // 获取state的值(counter、waiter)和信号量的值
	if race.Enabled {
		_ = *statep // trigger nil deref early
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	state := atomic.AddUint64(statep, uint64(delta)<<32) //使用原子加实现不加锁的安全修改,增加counter
	v := int32(state >> 32) //counter的int数值
	w := uint32(state)  // 获取waiter数量
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(semap))
	}
	if v < 0 { // 如果counter小于0,则Panic
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) { //异常情况检查,存在w>0且counter和delta相同,说明存在并发调用Add和Wait的情况
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 { //正常情况,直接返回,结束Add
		return
	}
	//以下为WaitGroup误用检查,counter为0,waiter>0 ,此需唤醒所有waiter
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}



Done() – 协程任务完成信号

// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1) // 通过传入-1的delta实现counter的减1
}



Wait() – 阻塞至counter为0

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // trigger nil deref early
		race.Disable()
	}
	for {
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32) //计算出counter数量
		w := uint32(state)  //计算waiter的数量
		if v == 0 { // 如果counter数量为0,则结束阻塞,继续执行Wait后续逻辑
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// Increment waiters count.
		if atomic.CompareAndSwapUint64(statep, state, state+1) { // counter>0则用CAS方式对waiters增加1
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(semap))
			}
			runtime_Semacquire(semap) // 阻塞等待唤起
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

附: 关于CAS

CAS是主流CPU均支持的原子指令, 全称为Compare and Swap,意为比较并交换。 golang中的方法签名如下:

func CompareAndSwapUint64(addr *uint64, old, new uint64) (swapped bool)

这是一种无锁算法,乐观锁的思想。修改值之前先判断待修改的值是否与当前调用方携带的旧值相等,若相等则执行交换操作,完成对指定变量的更新操作,否则返回操作失败的结果。



版权声明:本文为wxs19970115原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。