ThreadLocal笔记

  • Post author:
  • Post category:其他




1. 概述

ThreadLocal中文名叫做线程本地变量,用于解决变量在多线程问题下的并发安全问题,其思路是每一个Thread都保存自己的一份变量副本,该变量副本仅本Thread可以访问和修改,其它线程没有权限访问和修改。

ThreadLocal是一个泛型类,其泛型T就是当前Thread保存的变量类型。Thread类中持有一个类型为ThreadLocal.ThreadLocalMap的成员变量threadLocals,其初始化是在ThreadLocal的set()方法调用时完成的。ThreadLocal.ThreadLocalMap是一个特殊的哈希表,仅在ThreadLocal和Thread中被使用。和HashMap使用链表法解决哈希冲突不同,ThreadLocal.ThreadLocalMap使用线性探测法解决哈希冲突。ThreadLocal.ThreadLocalMap中的键值对被封装成为一个Entry对象,为了减少内存泄漏问题,该Entry对象的键是一个WeakReference。

我们可以先来看一下JDK给我们提供的ThreadLocal使用示例:

public class ThreadId {
  
    // Atomic integer containing the next thread ID to be assigned
    private static final AtomicInteger nextId = new AtomicInteger(0);

    // Thread local variable containing each thread's ID
    private static final ThreadLocal<Integer> threadId = ThreadLocal.withInitial(nextId::getAndIncrement);

    // Returns the current thread's unique ID, assigning it if necessary
    public static int get() {
        return threadId.get();
    }
  
}

其中nextId是ThreadId的一个静态成员变量,每次调用threadId.get()方法会先从当前线程的threadLocals里去寻找以threadId为key的键值对ThreadLocalMap.Entry,如果没有找到,则会调用初始化方法往threadLocals里添加以threadId为key的键值对ThreadLocalMap.Entry,该键值对的value值是调用nextId::getAndIncrement方法生成的。



2. ThreadLocal存在的坑



2.1 内存泄漏问题

ThreadLocal可能存在的内存泄漏问题。为什么ThreadLocal.ThreadLocalMap的key需要是弱引用WeakReference?

public class ThreadLocalDemo1 {

    public static void main(String[] args) throws InterruptedException {
        System.out.println("==========No GC==========");
        Thread t1 = new Thread(() -> printThreadLocalMapEntries("我爱北京天安门", false));
        t1.start();
        t1.join();
        System.out.println("==========GC==========");
        Thread t2 = new Thread(() -> printThreadLocalMapEntries("天安门上太阳升", true));
        t2.start();
        t2.join();
    }

    private static void printThreadLocalMapEntries(String s, boolean gc) {
        try {
            new ThreadLocal<>().set(s);
            if (gc) {
                System.gc();
            }
            Thread t = Thread.currentThread();
            Class<? extends Thread> clz = t.getClass();
            Field field = clz.getDeclaredField("threadLocals");
            field.setAccessible(true);
            Object threadLocalMap = field.get(t);
            Class<?> tlmClass = threadLocalMap.getClass();
            Field tableField = tlmClass.getDeclaredField("table");
            tableField.setAccessible(true);
            Object[] arr = (Object[]) tableField.get(threadLocalMap);
            for (Object o : arr) {
                if (null == o) {
                    continue;
                }
                Class<?> entryClass = o.getClass();
                Field valueField = entryClass.getDeclaredField("value");
                Field referenceField = entryClass.getSuperclass().getSuperclass().getDeclaredField("referent");
                valueField.setAccessible(true);
                referenceField.setAccessible(true);
                System.out.println("弱引用key:" + referenceField.get(o) + ", 值:" + valueField.get(o));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}

上述程序的输出结果是:

==========No GC==========
弱引用key:java.lang.ThreadLocal@18dbec87, 值:我爱北京天安门
==========GC==========
弱引用key:null, 值:天安门上太阳升

在GC之前,java内存布局是这样的:

在这里插入图片描述

ThreadLocal.ThreadLocalMap的Entry中的key维持着对ThreadLocal的唯一弱引用,一旦发生GC,该ThreadLocal对象会被回收,java内存布局如下:

在这里插入图片描述

弱引用的好处就体现在这里了,可以帮助GC回收不再使用的对象,但是这也造成了内存泄漏,因为ThreadLocal.ThreadLocalMap中的Entry对象没有被回收。这其实不是设计者的锅,而是我们使用不当,当我们不需要使用某个ThreadLocal对象时,需要调用其remove()方法将该Entry对象移除出去,remove()方法也会扫描ThreadLocal.ThreadLocalMap中所有的Entry,将key为null的Entry对象都移除出去。

回到上述示例代码,如果将

new ThreadLocal<>().set(s);

修改为:

ThreadLocal<String> threadLocal = new ThreadLocal<>();
threadLocal.set(s);

输出结果如下:

==========No GC==========
弱引用key:java.lang.ThreadLocal@fe2acd2, 值:我爱北京天安门
==========GC==========
弱引用key:java.lang.ThreadLocal@61efd92c, 值:天安门上太阳升

此时GC前的java内存布局是这样的:

在这里插入图片描述

除了有ThreadLocal.ThreadLocalMap中的key指向的弱引用外,还有栈中的变量threadLocal指向ThreadLocal对象的引用,因此该ThreadLocal对象不会被GC。



2.2 线程安全问题

如果多个线程的ThreadLocal中保存的是同一个对象,该对象并不具有线程安全性,来看一个例子:

public class ThreadLocalDemo2 {

    private static class Person {

        private Integer age = 18;

        public Integer getAge() {
            return age;
        }

        public void setAge(Integer age) {
            this.age = age;
        }

        @Override
        public String toString() {
            return "Person{" + "age=" + age + '}';
        }

    }

    private static final ThreadLocal<Person> THREAD_LOCAL = new ThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        Person p = new Person();
        setData(p);

        Thread subThread1 = new Thread(() -> {
            setData(p);
            Person data = getAndPrintData();
            if (data != null)
                data.setAge(100);
            getAndPrintData(); // 再打印一次
        });
        subThread1.start();
        subThread1.join();

        Thread subThread2 = new Thread(() -> {
            setData(p);
            getAndPrintData();
        });
        subThread2.start();
        subThread2.join();

        // 主线程获取线程绑定内容
        getAndPrintData();
        System.out.println("======== Finish =========");
    }


    private static void setData(Person person) {
        THREAD_LOCAL.set(person);
    }

    private static Person getAndPrintData() {
        // 拿到当前线程绑定的一个变量,然后做逻辑(本处只打印)
        Person person = THREAD_LOCAL.get();
        System.out.println("get数据,线程名:" + Thread.currentThread().getName() + ",数据为:" + person);
        return person;
    }

}

上述程序的输出结果是:

get数据,线程名:Thread-0,数据为:Person{age=18}
get数据,线程名:Thread-0,数据为:Person{age=100}
get数据,线程名:Thread-1,数据为:Person{age=100}
get数据,线程名:main,数据为:Person{age=100}
======== Finish =========

可以看到,set数据的线程是main,Thread-0将age重新置为100,导致Thread-1和main中的Person也发生了改变,根本原因是Thread-0、Thread-1和main线程中的Person都是同一个对象。

这个坑其实也是我们使用不当,当我们使用ThreadLocal时,应该继承ThreadLocal类并重写其initialValue()方法,而不是调用其set方法来设置值,将上述程序改写如下:

public class ThreadLocalDemo3 {

    private static class Person {

        private Integer age = 18;

        public Integer getAge() {
            return age;
        }

        public void setAge(Integer age) {
            this.age = age;
        }

        @Override
        public String toString() {
            return "Person{" + "age=" + age + '}';
        }

    }

    private static final ThreadLocal<Person> THREAD_LOCAL = ThreadLocal.withInitial(() -> new Person());

    public static void main(String[] args) throws InterruptedException {
        Thread subThread1 = new Thread(() -> {
            Person data = getAndPrintData();
            if (data != null)
                data.setAge(100);
            getAndPrintData(); // 再打印一次
        });
        subThread1.start();
        subThread1.join();

        Thread subThread2 = new Thread(() -> {
            getAndPrintData();
        });
        subThread2.start();
        subThread2.join();

        // 主线程获取线程绑定内容
        getAndPrintData();
        System.out.println("======== Finish =========");
    }


    private static void setData(Person person) {
        THREAD_LOCAL.set(person);
    }

    private static Person getAndPrintData() {
        // 拿到当前线程绑定的一个变量,然后做逻辑(本处只打印)
        Person person = THREAD_LOCAL.get();
        System.out.println("get数据,线程名:" + Thread.currentThread().getName() + ",数据为:" + person);
        return person;
    }
  
}

上述程序的输出如下:

get数据,线程名:Thread-0,数据为:Person{age=18}
get数据,线程名:Thread-0,数据为:Person{age=100}
get数据,线程名:Thread-1,数据为:Person{age=18}
get数据,线程名:main,数据为:Person{age=18}
======== Finish =========

可以看到,Thread-0中Person对象age的改变不影响Thread-1和main线程中的Person对象,根本原因是每次调用get方法时,我们都新建了一个Person对象,因此每个线程中的Person对象不是同一个。



3. ThreadLocal的实现细节

每一个ThreadLocal对象有一个成员变量threadLocalHashCode,根据threadLocalHashCode来确定该ThreadLocal对象在ThreadLocal.ThreadLocalMap维护的table数组的位置。threadLocalHashCode的获取很有意思,初始值是HASH_INCREMENT(0x61c88647),且每新建一个ThreadLocal对象就新增HASH_INCREMENT,这样做其实是为了降低哈希冲突。



3.1 get()方法

get()方法的逻辑其实很简单:先从当前线程的成员变量threadLocals中寻找当前对象对应的键值对Entry。如果找到了,返回对应的value值;如果没找到,则会调用initialValue()方法新建一个Entry对象将其放进当前线程的成员变量threadLocals中,其具体代码如下:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        // 往threadLocals中放入新Entry对象
        map.set(this, value);
    else
        // 此时threadLocals还没有被初始化,直接创建一个即可
        createMap(t, value);
    return value;
}

再来看一下ThreadLocal.ThreadLocalMap的set()方法逻辑:

private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    // 理论上应该放置的位置是 i
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        if (k == key) {
            e.value = value;
            return;
        }
        if (k == null) {
            // 走到这里,说明当前tab[i]已经过期了,调用replaceStaleEntry(),替换过期的Entry
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 走到这里,说明tab[i]为null
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        // 如果启发式清理过期Entry时没有清理掉任何的Entry,且sz大于等于了扩容阈值threshold,需要扩容
        rehash();
}

replaceStaleEntry()方法:

private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;
    // 走到这里,slotToExpunge代表的是[0, staleSlot]中过期的Entry中,索引最小值,staleSlot处的Entry是过期的
    for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == key) {
            // 交换索引i和索引staleSlot处的Entry对象
            e.value = value;
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
						// 此时索引i处的Entry是过期的
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            // 对slotExpunge之后,发现null之前的table数组部分开启扫描式清理
            // 对null之后的table数组部分开启启发式清理
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        if (k == null && slotToExpunge == staleSlot)
            // 如果当前Entry是过期的,且在staleSlot之前没有找到任何过期的Entry,将slotToExpunge置为i
            slotToExpunge = i;
    }
    // 走到这里,说明在遇到tab[i]为null之前,没有遇到相同的key,将当前Entry放置在staleSlot索引位置
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

扫描式清理expungeStaleEntry()方法:

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

启发式清理cleanSomeSlots()方法:

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 重置n,开始新一轮启发式扫描
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

扩容操作rehash()方法:

private void rehash() {
    expungeStaleEntries();
    if (size >= threshold - threshold / 4)
        resize();
}

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }
    setThreshold(newLen);
    size = count;
    table = newTab;
}



3.2 set()方法

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}



3.3 remove()方法:

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

我们可以看到,在remove()方法中会调用expungeStaleEntry()来清除过期的Entry。



4. InheritableThreadLocal

InheritableThreadLocal是ThreadLocal的子类,用于解决ThreadLocal不能在父子线程间共享变量的问题,其实现特别简单,就不再粘贴代码了,其设计思想也和ThreadLocal相同,只是Thread额外持有了另一个ThreadLocal.ThreadLocalMap类型的成员变量inheritableThreadLocals。



5. 总结

ThreadLocal的代码没有HashMap、AbstractQueuedSynchronizer这么复杂,但是在使用上还是存在着许多坑。本文总结了ThreadLocal使用过程中存在的一些小坑,并分析了ThreadLocal关键部分的源码,与读者共勉。



版权声明:本文为qq_41231926原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。