SpringBoot使用MongoTemplate动态切换MongoDB

  • Post author:
  • Post category:其他


SpringBoot使用MongoTemplate动态切换Mongo

当前有一个项目包含一套代码和1个mongo的数据库,但要给2个部门用户使用,而且要做数据隔离,现有做法是部署2套代码和2个mongo数据库,2个域名,这样就能各自访问各自的系统,做到隔离。但为节省资源和减少运维工作,计划将2个系统合并为一个系统,一套代码部署,2个数据库。那么如何使用最少的成本完成系统升级呢。

计划的方案,2个数据库,一个为A,一个为B。在http添加一个header参数来区分需要访问的是A还是B。在代码中获取header参数来切换数据库。

代码是springboot架构。mongo部分代码是Repository->MongoTemplate,MongoTempla做CURD操作,所以准备使用AOP的方式,在调用MongoTemplate做CURD时切换数据库。就是实例化连接2个数据库的MongoTemplate,然后通过AOP在Repository调用时,根据不同的数据库重新赋值MongoTemplate。在网上搜索了下,基本也是这个思路。看来这个方案是可行的。但实际开发中有一个问题,AOP不是线程安全的,而MongoTemplate是单例,这样并发时MongoTemplate会被错误的赋值。所以还有一个急需解决的问题就是线程安全。首先想的办法是同步锁,在AOP时,使用synchronized将赋值和jionpoitn.proceed()同步,这样解决了线程安全的问题。但这样访问都变成的单线程模式,性能大幅下降,只好改变思路,给MongoTemplate赋值的数据库连接保存到线程变量里面,MongTemplate在进行数据库操作时使用当前线程的连接,查看了下MongoTemplate的源码,看到一个getDB(),每次find时通过getDB()获取的数据库连接,于是产生一个思路,写一个MongTemplate的子类,定义一个ThreadLoacl,AOP时将对应数据库的MongoFactory保存到ThreadLocal中,重载getDB(),ThreadLocal.get().getDB(),然后将这个子重新类赋值到Repository中替换掉已有的MongoTemplate。这样保证线程安全,而且性能太大的损失。

切换的AOP代码

import com.mongodb.MongoClient;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.mongo.MongoProperties;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.SimpleMongoDbFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import ucan.xdf.com.config.MongoDBConfigProperties;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

@Component
@Aspect
public class MongoSwitch {
    private final Logger logger = LoggerFactory.getLogger(MongoSwitch.class);
    @Autowired
    private MongoDBConfigProperties mongoDBConfigProperties;
    @Autowired
    private MongoDbFactory mongoDbFactory;
    private Map<String,MongoDbFactory> templateMuliteMap=new HashMap<>();


    @Pointcut("execution(* xxx.xxx.com.dao..*.*(..))")
    public void routeMongoDB() {

    }

    @Around("routeMongoDB()")
    public Object routeMongoDB(ProceedingJoinPoint joinPoint) {
        Object result = null;
        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        String uri = request.getRequestURL().toString();
       //获取需要访问的项目数据库
        String subject = request.getHeader("subject");
        String name = joinPoint.getSignature().getName();
        Object o = joinPoint.getTarget();
        Field[] fields = o.getClass().getDeclaredFields();
        MultiMongoTemplate mongoTemplate = null;

        try {
            for (Field field : fields) {
                field.setAccessible(true);
                Object fieldObject = field.get(o);
                Class fieldclass = fieldObject.getClass();
                 //找到Template的变量
                if (fieldclass == MongoTemplate.class || fieldclass == MultiMongoTemplate.class) {
                    //查找项目对应的MongFactory
                    SimpleMongoDbFactory simpleMongoDbFactory=(SimpleMongoDbFactory)templateMuliteMap.get(subject);
						//实例化
                    if(simpleMongoDbFactory==null){
                        Field propertiesField = MongoDBConfigProperties.class.getDeclaredField(subject);
                        propertiesField.setAccessible(true);
                        MongoProperties properties = (MongoProperties) propertiesField.get(mongoDBConfigProperties);
                        simpleMongoDbFactory = new SimpleMongoDbFactory(new MongoClient(properties.getHost(), properties.getPort()), properties.getDatabase());
                        templateMuliteMap.put(subject,simpleMongoDbFactory);
                    }
						//如果第一次,赋值成自定义的MongoTemplate子类
                    if(fieldclass==MongoTemplate.class){
                        mongoTemplate = new MultiMongoTemplate(simpleMongoDbFactory);
                    }else if(fieldclass==MultiMongoTemplate.class){
                        mongoTemplate=(MultiMongoTemplate)fieldObject;
                    }
      //设置MongoFactory        
       mongoTemplate.setMongoDbFactory(simpleMongoDbFactory);
                     //重新赋值
                    field.set(o, mongoTemplate);
                    break;
                }
            }
            try {
                result = joinPoint.proceed();
                //清理ThreadLocal的变量
                mongoTemplate.removeMongoDbFactory();
            } catch (Throwable t) {
                logger.error("", t);
            }
        } catch (Exception e) {
            logger.error("", e);
        }

        return result;
    }
}

MongoTemplate子类

import com.mongodb.client.MongoDatabase;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.MongoTemplate;

public class MultiMongoTemplate extends MongoTemplate {
    private Logger logger= LoggerFactory.getLogger(MultiMongoTemplate.class);
    private static ThreadLocal<MongoDbFactory> mongoDbFactoryThreadLocal;
    public MultiMongoTemplate(MongoDbFactory mongoDbFactory){
        super(mongoDbFactory);
        if(mongoDbFactoryThreadLocal==null) {
            mongoDbFactoryThreadLocal = new ThreadLocal<>();
        }
    }

    public void setMongoDbFactory(MongoDbFactory factory){
        mongoDbFactoryThreadLocal.set(factory);
    }

    public void removeMongoDbFactory(){
        mongoDbFactoryThreadLocal.remove();
    }

    @Override
    public MongoDatabase getDb() {
        return mongoDbFactoryThreadLocal.get().getDb();
    }
}



版权声明:本文为JavaAndJava原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。