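1. Overview
ThreadLocal provides thread-local variables. It addresses thread-safety problems by avoiding sharing: each Thread keeps its own copy of the variable, and through the ThreadLocal API that copy can only be read and modified by the owning thread.
ThreadLocal is a generic class; its type parameter T is the type of the value stored for the current Thread. The Thread class holds a field named threadLocals of type ThreadLocal.ThreadLocalMap, which is created lazily the first time the thread calls set() or get() on any ThreadLocal. ThreadLocal.ThreadLocalMap is a special-purpose hash table used only by ThreadLocal and Thread. Unlike HashMap, which resolves hash collisions by separate chaining, ThreadLocal.ThreadLocalMap uses linear probing. Each key-value pair in ThreadLocal.ThreadLocalMap is wrapped in an Entry object, and to reduce memory leaks the Entry extends WeakReference and refers to its key (the ThreadLocal) weakly.
Let's first look at the ThreadLocal usage example that the JDK provides: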
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadId {
// Atomic integer containing the next thread ID to be assigned
private static final AtomicInteger nextId = new AtomicInteger(0);
// Thread local variable containing each thread's ID
private static final ThreadLocal<Integer> threadId = ThreadLocal.withInitial(nextId::getAndIncrement);
// Returns the current thread's unique ID, assigning it if necessary
public static int get() {
return threadId.get();
}
}
Here nextId is a static field of ThreadId. Each call to threadId.get() first looks in the current thread's threadLocals for the ThreadLocalMap.Entry whose key is threadId. If none is found, it runs the initializer and adds an Entry keyed by threadId to threadLocals; the value of that Entry is produced by nextId::getAndIncrement.
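To make the per-thread behavior concrete, the following sketch starts several threads and collects the ID each one sees. The class name ThreadIdDemo and the helper collectIds are made up for this illustration; only the nextId/threadId pair comes from the example above.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

class ThreadIdDemo {
    private static final AtomicInteger nextId = new AtomicInteger(0);
    private static final ThreadLocal<Integer> threadId =
            ThreadLocal.withInitial(nextId::getAndIncrement);

    /** Starts n threads and collects the ID each one observes (hypothetical helper). */
    static Set<Integer> collectIds(int n) {
        Set<Integer> ids = ConcurrentHashMap.newKeySet();
        Thread[] threads = new Thread[n];
        for (int i = 0; i < n; i++) {
            threads[i] = new Thread(() -> {
                int first = threadId.get();   // assigned lazily on first access
                int second = threadId.get();  // same thread, so the same value
                if (first == second) {
                    ids.add(first);
                }
            });
            threads[i].start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return ids;
    }

    public static void main(String[] args) {
        // Each of the 4 threads ends up with its own distinct ID.
        System.out.println(collectIds(4).size());
    }
}
```

Every thread gets a different ID, and repeated get() calls inside one thread keep returning the ID that thread was assigned first.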
2. Pitfalls of ThreadLocal
2.1 Memory leaks
ThreadLocal can leak memory. To see how, start with a question: why is the key of ThreadLocal.ThreadLocalMap a WeakReference?
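import java.lang.reflect.Field;

// This demo reads JDK internals reflectively. On JDK 16 and later it likely needs
// --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/java.lang.ref=ALL-UNNAMED
public class ThreadLocalDemo1 {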
public static void main(String[] args) throws InterruptedException {
System.out.println("==========No GC==========");
Thread t1 = new Thread(() -> printThreadLocalMapEntries("我爱北京天安门", false));
t1.start();
t1.join();
System.out.println("==========GC==========");
Thread t2 = new Thread(() -> printThreadLocalMapEntries("天安门上太阳升", true));
t2.start();
t2.join();
}
private static void printThreadLocalMapEntries(String s, boolean gc) {
try {
new ThreadLocal<>().set(s);
if (gc) {
System.gc();
}
Thread t = Thread.currentThread();
Class<? extends Thread> clz = t.getClass();
Field field = clz.getDeclaredField("threadLocals");
field.setAccessible(true);
Object threadLocalMap = field.get(t);
Class<?> tlmClass = threadLocalMap.getClass();
Field tableField = tlmClass.getDeclaredField("table");
tableField.setAccessible(true);
Object[] arr = (Object[]) tableField.get(threadLocalMap);
for (Object o : arr) {
if (null == o) {
continue;
}
Class<?> entryClass = o.getClass();
Field valueField = entryClass.getDeclaredField("value");
Field referenceField = entryClass.getSuperclass().getSuperclass().getDeclaredField("referent");
valueField.setAccessible(true);
referenceField.setAccessible(true);
System.out.println("弱引用key:" + referenceField.get(o) + ", 值:" + valueField.get(o));
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
The program prints:
==========No GC==========
弱引用key:java.lang.ThreadLocal@18dbec87, 值:我爱北京天安门
==========GC==========
弱引用key:null, 值:天安门上太阳升
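Before the GC runs, the memory layout is as follows: the thread's threadLocals map holds an Entry whose key weakly references the ThreadLocal object and whose value field strongly references the string.
The key of the Entry holds the only reference to the ThreadLocal object, and that reference is weak. Once a GC occurs, the ThreadLocal object is reclaimed: the Entry's key becomes null, while the value is still strongly reachable through the Entry.
This is where the weak reference pays off: it lets the GC reclaim a ThreadLocal object that is no longer used. It also exposes the leak, though, because the Entry object in ThreadLocal.ThreadLocalMap and the value it strongly references are not reclaimed as long as the thread stays alive and the stale Entry has not been cleaned up. This is not really the designers' fault but a matter of incorrect usage: when a ThreadLocal is no longer needed, call its remove() method to delete the Entry. remove() also calls expungeStaleEntry(), which clears the other stale Entry objects (those whose key is null) between the removed slot and the next empty slot; it does not scan the entire table.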
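The usual place where a missing remove() hurts is a thread pool, because worker threads outlive the tasks they run. The sketch below is an illustration under stated assumptions, not code from the JDK: the class RemoveInFinallyDemo, the CONTEXT variable, and the "request-1" value are all hypothetical. It runs two tasks on the same pooled thread and reports what the second task finds.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class RemoveInFinallyDemo {
    private static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Runs two tasks on one pooled thread and returns what the second task sees. */
    static String valueSeenBySecondTask(boolean cleanUp) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Runnable first = () -> {
                CONTEXT.set("request-1");
                try {
                    CONTEXT.get(); // stand-in for real work that reads the value
                } finally {
                    if (cleanUp) {
                        CONTEXT.remove(); // drop the Entry before the thread is reused
                    }
                }
            };
            Callable<String> second = CONTEXT::get;
            pool.submit(first).get();
            return pool.submit(second).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(valueSeenBySecondTask(false)); // the value leaks into the next task
        System.out.println(valueSeenBySecondTask(true));  // null: the Entry was removed
    }
}
```

Wrapping the work in try/finally guarantees that remove() runs even if the task throws, so neither the value nor the Entry stays attached to the reused thread.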
Going back to the ThreadLocalDemo1 example, if we change
new ThreadLocal<>().set(s);
to:
ThreadLocal<String> threadLocal = new ThreadLocal<>();
threadLocal.set(s);
the output becomes:
==========No GC==========
弱引用key:java.lang.ThreadLocal@fe2acd2, 值:我爱北京天安门
==========GC==========
弱引用key:java.lang.ThreadLocal@61efd92c, 值:天安门上太阳升
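This time the ThreadLocal object is reachable in two ways when the GC runs.
Besides the weak reference from the key in ThreadLocal.ThreadLocalMap, the local variable threadLocal on the stack holds a strong reference to the ThreadLocal object, so the object is not collected.
2.2 Thread safety
If the ThreadLocals of several threads all store the same object, ThreadLocal does not make that object thread-safe. Consider this example: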
public class ThreadLocalDemo2 {
private static class Person {
private Integer age = 18;
public Integer getAge() {
return age;
}
public void setAge(Integer age) {
this.age = age;
}
@Override
public String toString() {
return "Person{" + "age=" + age + '}';
}
}
private static final ThreadLocal<Person> THREAD_LOCAL = new ThreadLocal<>();
public static void main(String[] args) throws InterruptedException {
Person p = new Person();
setData(p);
Thread subThread1 = new Thread(() -> {
setData(p);
Person data = getAndPrintData();
if (data != null)
data.setAge(100);
getAndPrintData(); // print again
});
subThread1.start();
subThread1.join();
Thread subThread2 = new Thread(() -> {
setData(p);
getAndPrintData();
});
subThread2.start();
subThread2.join();
// The main thread reads the value bound to it
getAndPrintData();
System.out.println("======== Finish =========");
}
private static void setData(Person person) {
THREAD_LOCAL.set(person);
}
private static Person getAndPrintData() {
// Get the value bound to the current thread and work with it (here we only print it)
Person person = THREAD_LOCAL.get();
System.out.println("get数据,线程名:" + Thread.currentThread().getName() + ",数据为:" + person);
return person;
}
}
The program prints:
get数据,线程名:Thread-0,数据为:Person{age=18}
get数据,线程名:Thread-0,数据为:Person{age=100}
get数据,线程名:Thread-1,数据为:Person{age=100}
get数据,线程名:main,数据为:Person{age=100}
======== Finish =========
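As the output shows, once Thread-0 sets age to 100, the Person seen by Thread-1 and main changes too. The root cause is that every thread calls setData(p) with the same instance, so Thread-0, Thread-1, and main all hold a reference to one Person object.
This pitfall is again a usage error. When a ThreadLocal holds mutable state, give each thread its own instance, either by subclassing ThreadLocal and overriding initialValue() or by using ThreadLocal.withInitial(), instead of calling set() with an object that is shared across threads. The program above can be rewritten as follows: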
public class ThreadLocalDemo3 {
private static class Person {
private Integer age = 18;
public Integer getAge() {
return age;
}
public void setAge(Integer age) {
this.age = age;
}
@Override
public String toString() {
return "Person{" + "age=" + age + '}';
}
}
private static final ThreadLocal<Person> THREAD_LOCAL = ThreadLocal.withInitial(() -> new Person());
public static void main(String[] args) throws InterruptedException {
Thread subThread1 = new Thread(() -> {
Person data = getAndPrintData();
if (data != null)
data.setAge(100);
getAndPrintData(); // print again
});
subThread1.start();
subThread1.join();
Thread subThread2 = new Thread(() -> {
getAndPrintData();
});
subThread2.start();
subThread2.join();
// The main thread reads the value bound to it
getAndPrintData();
System.out.println("======== Finish =========");
}
private static void setData(Person person) {
THREAD_LOCAL.set(person);
}
private static Person getAndPrintData() {
// Get the value bound to the current thread and work with it (here we only print it)
Person person = THREAD_LOCAL.get();
System.out.println("get数据,线程名:" + Thread.currentThread().getName() + ",数据为:" + person);
return person;
}
}
The program prints:
get数据,线程名:Thread-0,数据为:Person{age=18}
get数据,线程名:Thread-0,数据为:Person{age=100}
get数据,线程名:Thread-1,数据为:Person{age=18}
get数据,线程名:main,数据为:Person{age=18}
======== Finish =========
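As the output shows, changing age on the Person in Thread-0 does not affect the Person objects in Thread-1 and main. The root cause is that the first get() call in each thread runs the initializer and creates a new Person, which later get() calls in that thread reuse, so no two threads hold the same Person object.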
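A quick way to check that the initializer runs once per thread rather than once per get() call is to count its invocations. The sketch below does that; the class InitialValueOncePerThreadDemo and its method are hypothetical names for this illustration.

```java
import java.util.concurrent.atomic.AtomicInteger;

class InitialValueOncePerThreadDemo {
    /** Calls get() several times on each of threadCount fresh threads and returns how often the supplier ran. */
    static int supplierCallsAcross(int threadCount, int getsPerThread) {
        AtomicInteger calls = new AtomicInteger();
        ThreadLocal<StringBuilder> buffer = ThreadLocal.withInitial(() -> {
            calls.incrementAndGet();          // count every initializer run
            return new StringBuilder();
        });
        Thread[] threads = new Thread[threadCount];
        for (int i = 0; i < threadCount; i++) {
            threads[i] = new Thread(() -> {
                for (int k = 0; k < getsPerThread; k++) {
                    buffer.get();             // only the first call per thread initializes
                }
            });
            threads[i].start();
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return calls.get();
    }

    public static void main(String[] args) {
        // 3 threads, 5 get() calls each: the supplier runs 3 times in total.
        System.out.println(supplierCallsAcross(3, 5));
    }
}
```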
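3. Implementation details of ThreadLocal
Every ThreadLocal object has a field threadLocalHashCode, which determines its position in the table array maintained by ThreadLocal.ThreadLocalMap. The way threadLocalHashCode is generated is interesting: a static counter starts at 0, and each newly created ThreadLocal takes the current counter value and then advances the counter by HASH_INCREMENT (0x61c88647), so the first ThreadLocal gets hash code 0. The increment is chosen so that consecutive hash codes spread almost evenly over a power-of-two-sized table, which reduces hash collisions.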
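The effect of the increment can be checked with a few lines of arithmetic. The sketch below reproduces only the hash sequence and the index computation; the class HashIncrementDemo and the method slots are hypothetical names, and the code does not touch ThreadLocal itself.

```java
import java.util.Arrays;

class HashIncrementDemo {
    // Same constant as HASH_INCREMENT in the ThreadLocal source.
    private static final int HASH_INCREMENT = 0x61c88647;

    /** Slot index of the first count hash codes in a table of len slots (len must be a power of two). */
    static int[] slots(int count, int len) {
        int[] result = new int[count];
        int hash = 0;                        // the first generated hash code is 0
        for (int i = 0; i < count; i++) {
            result[i] = hash & (len - 1);    // same index computation as ThreadLocalMap
            hash += HASH_INCREMENT;          // int overflow simply wraps around
        }
        return result;
    }

    public static void main(String[] args) {
        // prints [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
        System.out.println(Arrays.toString(slots(16, 16)));
    }
}
```

Because HASH_INCREMENT is odd, 16 consecutive hash codes land in 16 different slots of a 16-slot table, and the same holds for any other power-of-two length.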
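3.1 The get() method
The logic of get() is simple. It first looks in the current thread's threadLocals for the Entry that belongs to this ThreadLocal. If the Entry is found, get() returns its value; otherwise it calls initialValue() and stores the result in the current thread's threadLocals as a new Entry. The code is as follows: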
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
// Put a new Entry into threadLocals
map.set(this, value);
else
// threadLocals has not been initialized yet, so create it
createMap(t, value);
return value;
}
Next, the logic of ThreadLocal.ThreadLocalMap's set() method:
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
// i is the slot where the key would ideally be placed
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
// tab[i] is stale (its key has been collected): call replaceStaleEntry() to replace the stale Entry
replaceStaleEntry(key, value, i);
return;
}
}
// Reaching this point means tab[i] is null
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
// The heuristic cleanup removed no Entry and sz has reached the threshold, so rehash (which may resize)
rehash();
}
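To isolate the linear-probing idea from the stale-entry handling above, here is a toy sketch. It is not the JDK code: the class LinearProbingSketch, its int keys, and the fixed capacity are assumptions made for illustration. It omits resizing and cleanup, so it assumes the table is never filled completely.

```java
class LinearProbingSketch {
    private final int[] hashes;     // key hash stored in each slot
    private final Object[] values;  // a null value marks an empty slot in this toy version

    /** capacity must be a power of two so that (hash & (capacity - 1)) is a valid index. */
    LinearProbingSketch(int capacity) {
        hashes = new int[capacity];
        values = new Object[capacity];
    }

    // Same wrap-around step as nextIndex(i, len) in ThreadLocalMap.
    private static int nextIndex(int i, int len) {
        return (i + 1 < len) ? i + 1 : 0;
    }

    /** Stores a non-null value and returns the slot index that was actually used. */
    int put(int hash, Object value) {
        int len = values.length;
        int i = hash & (len - 1);                      // preferred slot
        while (values[i] != null && hashes[i] != hash) {
            i = nextIndex(i, len);                     // occupied by another key: probe onward
        }
        hashes[i] = hash;
        values[i] = value;
        return i;
    }

    /** Returns the value stored for hash, or null if the probe reaches an empty slot first. */
    Object get(int hash) {
        int len = values.length;
        for (int i = hash & (len - 1); values[i] != null; i = nextIndex(i, len)) {
            if (hashes[i] == hash) {
                return values[i];
            }
        }
        return null;
    }
}
```

With a capacity of 8, keys 1 and 9 both prefer slot 1, so the second one is stored in slot 2; keys 7 and 15 both prefer slot 7, so the second one wraps around to slot 0.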
The replaceStaleEntry() method:
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
// Here slotToExpunge is the earliest stale Entry found by scanning backward from staleSlot (with wrap-around) to the nearest null slot; the Entry at staleSlot itself is stale
for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == key) {
// Swap the Entry objects at index i and index staleSlot
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// The Entry at index i is now the stale one
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// Expunge stale entries from slotToExpunge up to the next null slot,
// then run the heuristic cleanup on the part of the table after that null slot
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
if (k == null && slotToExpunge == staleSlot)
// The current Entry is stale and the backward scan found no stale Entry before staleSlot, so set slotToExpunge to i
slotToExpunge = i;
}
// No Entry with the same key was found before reaching a null slot, so put the new Entry at staleSlot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
The scanning cleanup method expungeStaleEntry():
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
The heuristic cleanup method cleanSomeSlots():
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
// Reset n and start a new round of heuristic scanning
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
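The cleanup is called heuristic because, as long as it finds nothing stale, it inspects only a logarithmic number of slots. The sketch below copies just the loop condition to count those probes; the class ProbeCountDemo is a hypothetical name, and the counter replaces the real slot inspection.

```java
class ProbeCountDemo {
    /** Counts the iterations of do { ... } while ((n >>>= 1) != 0) when no stale Entry is found. */
    static int probes(int n) {
        int count = 0;
        do {
            count++;                 // one slot would be inspected here
        } while ((n >>>= 1) != 0);   // n is halved after every probe
        return count;
    }

    public static void main(String[] args) {
        System.out.println(probes(16)); // 5 probes for n = 16
        System.out.println(probes(10)); // 4 probes for n = 10
    }
}
```

Whenever a stale Entry is found, the real method resets n to the table length, so the scan keeps going for another logarithmic round.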
The rehash() method, which handles resizing:
private void rehash() {
expungeStaleEntries();
if (size >= threshold - threshold / 4)
resize();
}
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
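The two thresholds in rehash() are easier to follow with numbers. The sketch below assumes, based on the JDK 8 source rather than on anything shown in this article, that setThreshold(len) sets threshold to len * 2 / 3 and that the initial table length is 16. The class ResizeThresholdDemo and its method names are hypothetical.

```java
class ResizeThresholdDemo {
    /** Assumed to mirror setThreshold(len) in JDK 8: a load factor of at most 2/3. */
    static int threshold(int len) {
        return len * 2 / 3;
    }

    /** Size at which rehash() goes on to resize(): threshold - threshold / 4. */
    static int resizeTrigger(int len) {
        int t = threshold(len);
        return t - t / 4;
    }

    public static void main(String[] args) {
        // For a table of 16 slots: threshold 10, resize once size >= 8 after cleanup.
        System.out.println(threshold(16) + " " + resizeTrigger(16));
    }
}
```

Under these assumptions, set() calls rehash() once the size reaches 10 in a 16-slot table, and rehash() doubles the table only if at least 8 entries survive the full cleanup, which is half of the table length.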
3.2 The set() method
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
3.3 The remove() method
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
As we can see, remove() calls expungeStaleEntry() to clear stale Entry objects.
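From the caller's point of view, remove() deletes the current thread's Entry, so the next get() has to initialize the value again. The sketch below shows this with a counting initializer; the class RemoveReinitializesDemo and the method observe are hypothetical names.

```java
import java.util.concurrent.atomic.AtomicInteger;

class RemoveReinitializesDemo {
    /** Returns the values seen by get(), a second get(), and a get() after remove(). */
    static int[] observe() {
        AtomicInteger counter = new AtomicInteger();
        ThreadLocal<Integer> local = ThreadLocal.withInitial(counter::incrementAndGet);
        int a = local.get();   // no Entry yet: the initializer runs and yields 1
        int b = local.get();   // Entry found: still 1
        local.remove();        // Entry deleted from the current thread's map
        int c = local.get();   // the initializer runs again and yields 2
        local.remove();        // leave the calling thread clean
        return new int[] {a, b, c};
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(observe())); // prints [1, 1, 2]
    }
}
```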
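4. InheritableThreadLocal
InheritableThreadLocal is a subclass of ThreadLocal that lets a child thread start with the values its parent thread had when the child was created, which a plain ThreadLocal cannot do. Its implementation is very simple, so the code is not reproduced here. The design is the same as ThreadLocal's, except that Thread holds a second field of type ThreadLocal.ThreadLocalMap named inheritableThreadLocals.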
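The difference between the two classes is easiest to see side by side. In the sketch below, the class InheritableDemo, the method valuesSeenByChild, and the string values are made up for illustration; only ThreadLocal and InheritableThreadLocal are JDK classes.

```java
class InheritableDemo {
    /** Returns {value of the plain ThreadLocal, value of the InheritableThreadLocal} as seen by a child thread. */
    static String[] valuesSeenByChild() {
        ThreadLocal<String> plain = new ThreadLocal<>();
        InheritableThreadLocal<String> inheritable = new InheritableThreadLocal<>();
        plain.set("parent-plain");
        inheritable.set("parent-inheritable");

        String[] seen = new String[2];
        // The child receives the parent's inheritable values when it is constructed.
        Thread child = new Thread(() -> {
            seen[0] = plain.get();        // not inherited: null
            seen[1] = inheritable.get();  // inherited from the parent
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            plain.remove();
            inheritable.remove();
        }
        return seen;
    }

    public static void main(String[] args) {
        String[] seen = valuesSeenByChild();
        System.out.println(seen[0] + ", " + seen[1]); // prints null, parent-inheritable
    }
}
```

Note that by default the child receives the same object reference as the parent, so a mutable value is shared between them in the same way as in section 2.2.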
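5. Summary
ThreadLocal's code is not as complex as that of HashMap or AbstractQueuedSynchronizer, but it still has plenty of pitfalls in practice. This article summarized some of the pitfalls in using ThreadLocal and walked through the key parts of its source code, in the hope that it helps readers.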