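The previous chapter gave a brief introduction to integrating Spring Boot with Spring Data JPA and to its basic usage. In real projects, however, the interfaces that Spring Data JPA provides out of the box turn out to be rather thin, which forces us to write a lot of repetitive code. Spring Data JPA actually offers several extension mechanisms, and this chapter introduces one of them. By making a few changes to the previous chapter's code, we will see that using Spring Data JPA in Spring Boot becomes even easier.
Because of its length, this chapter is split into two parts.
1. First, create a base class for all domain classes. It can be empty, or it can declare some common fields, as in the example below: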
import java.io.Serializable;
import java.util.Date;
import javax.persistence.*;
import lombok.Data;

@Data
@MappedSuperclass
public abstract class TableEntity implements Serializable {
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    protected Long id;
    /** Creation time */
    @Temporal(TemporalType.TIMESTAMP)
    @Column(nullable = false, updatable = false)
    protected Date createTime;
    /** Last modification time */
    @Temporal(TemporalType.TIMESTAMP)
    @Column(insertable = false)
    protected Date lastModifyTime;
    /** Version number, used for optimistic locking */
    @Version
    @Column(name = "version", nullable = false)
    protected int version;
    /** Runs before every INSERT and UPDATE: sets createTime once, refreshes lastModifyTime each time */
    @PrePersist
    @PreUpdate
    protected void updateDate() {
        if (createTime == null) {
            createTime = new Date();
        }
        lastModifyTime = new Date();
    }
}
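This abstract domain class declares four fields: the ID, the creation time, the last modification time, and a version number used for optimistic locking. These fields are common in ORM mappings, and they are maintained automatically: the @PrePersist/@PreUpdate callback fills in the two timestamps, and the JPA provider increments version on every update. Note that lastModifyTime is mapped with insertable = false, so it stays NULL in the database until the first update.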
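To make the division of labour concrete, here is a small plain-Java sketch (not JPA code) of what happens to these fields. VersionedStore and its methods are hypothetical: insert sets createTime once and leaves lastModifyTime empty, mirroring insertable = false, and update behaves like the compare-and-increment that an @Version column produces, roughly UPDATE ... SET version = version + 1 WHERE id = ? AND version = ?. Where the sketch returns false, a real JPA provider would throw an OptimisticLockException instead.

```java
import java.util.Date;
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory model of the TableEntity fields; not JPA code.
class VersionedStore {

    static final class Row {
        String payload;
        int version;
        Date createTime;
        Date lastModifyTime;
    }

    private final Map<Long, Row> rows = new HashMap<>();
    private long nextId = 1;

    /** Roughly an INSERT: createTime is set once and the version starts at 0. */
    long insert(String payload) {
        Row row = new Row();
        row.payload = payload;
        row.version = 0;
        row.createTime = new Date();
        // lastModifyTime is mapped with insertable = false, so it stays null until the first update
        row.lastModifyTime = null;
        long id = nextId++;
        rows.put(id, row);
        return id;
    }

    /**
     * Roughly UPDATE ... SET version = version + 1 WHERE id = ? AND version = ?.
     * Returns false if the row is gone or its version changed in the meantime;
     * a JPA provider would throw an OptimisticLockException in that case.
     */
    boolean update(long id, int expectedVersion, String newPayload) {
        Row row = rows.get(id);
        if (row == null || row.version != expectedVersion) {
            return false;
        }
        row.payload = newPayload;
        row.version++;
        row.lastModifyTime = new Date(); // createTime is never touched again
        return true;
    }

    int versionOf(long id) {
        return rows.get(id).version;
    }

    String payloadOf(long id) {
        return rows.get(id).payload;
    }

    Date lastModifyTimeOf(long id) {
        return rows.get(id).lastModifyTime;
    }
}
```

If two users both read a row at version 0, the first save succeeds and bumps the version to 1; the second save still carries version 0, is rejected, and has to reload the row before retrying.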
2. Recall that in the previous chapter our Repository extended JpaRepository. Now we need to extend JpaRepository ourselves. First, write an extension interface that extends JpaRepository:
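import java.io.Serializable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.repository.NoRepositoryBean;

@NoRepositoryBean
public interface BaseJpaRepository<T extends TableEntity, ID extends Serializable> extends JpaRepository<T, ID> {
}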
3. Leave this interface empty for now, and then write an implementation of it:
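import java.io.Serializable;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Transactional;

@NoRepositoryBean
@Transactional(readOnly = true)
public class SimpleBaseJpaRepository<T extends TableEntity, ID extends Serializable>
        extends SimpleJpaRepository<T, ID> implements BaseJpaRepository<T, ID> {
}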
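Note: the @NoRepositoryBean annotation on the two types above marks them as not being repository beans, so Spring Data will not try to create a repository implementation for them. Also, because SimpleJpaRepository has no no-argument constructor, this class only compiles once the constructor from step 7 below has been added.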
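4. Extend the factory class
The extension interface and implementation above are not enough on their own to extend Spring Data JPA. In the previous chapter we only wrote a Repository interface extending JpaRepository and never wrote an implementation class, yet injecting that interface was enough to call its methods. That works because Spring Data JPA uses a factory to create, at runtime, an object that implements the interface (backed by SimpleJpaRepository). So even though we have extended JpaRepository, the extension is still incomplete until we also extend Spring Data JPA's factory. Let's extend the factory class next: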
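import javax.persistence.EntityManager;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactory;
import org.springframework.data.repository.core.RepositoryMetadata;

public class BaseJpaRepositoryFactory extends JpaRepositoryFactory {
    public BaseJpaRepositoryFactory(EntityManager entityManager) {
        super(entityManager);
    }
    @Override
    protected Class<?> getRepositoryBaseClass(RepositoryMetadata metadata) {
        // use our extended implementation instead of SimpleJpaRepository
        return SimpleBaseJpaRepository.class;
    }
}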
Note: this class extends JpaRepositoryFactory, the factory that creates JpaRepository implementations, and overrides getRepositoryBaseClass so that it returns our extended SimpleBaseJpaRepository.
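Overriding getRepositoryBaseClass is enough because, as far as I can tell from recent Spring Data JPA versions, JpaRepositoryFactory does not call new SimpleJpaRepository(...) directly: it instantiates whatever class getRepositoryBaseClass returns reflectively, passing the entity information and the EntityManager as constructor arguments. That is also why step 7 gives SimpleBaseJpaRepository a (JpaEntityInformation, EntityManager) constructor. The plain-Java sketch below imitates the mechanism with hypothetical stand-in types (EntityInfo, FakeEntityManager, DefaultRepo, ExtendedRepo, MiniRepositoryFactory); it is a simplification, not Spring's code.

```java
import java.lang.reflect.Constructor;

// Hypothetical stand-in for JpaEntityInformation.
class EntityInfo {
    final Class<?> javaType;

    EntityInfo(Class<?> javaType) {
        this.javaType = javaType;
    }
}

// Hypothetical stand-in for EntityManager.
class FakeEntityManager {
}

// Hypothetical: plays the role of SimpleJpaRepository, the default base class.
class DefaultRepo {
    final EntityInfo info;
    final FakeEntityManager em;

    public DefaultRepo(EntityInfo info, FakeEntityManager em) {
        this.info = info;
        this.em = em;
    }
}

// Hypothetical: plays the role of SimpleBaseJpaRepository (same constructor shape, extra behaviour).
class ExtendedRepo extends DefaultRepo {
    public ExtendedRepo(EntityInfo info, FakeEntityManager em) {
        super(info, em);
    }

    String describe() {
        return "extended repository for " + info.javaType.getSimpleName();
    }
}

// Hypothetical: plays the role of JpaRepositoryFactory; the base class is an overridable hook.
class MiniRepositoryFactory {
    protected Class<? extends DefaultRepo> getRepositoryBaseClass() {
        return DefaultRepo.class;
    }

    DefaultRepo createTarget(EntityInfo info, FakeEntityManager em) {
        try {
            // Look up the (EntityInfo, FakeEntityManager) constructor on whatever class the hook returns.
            Constructor<? extends DefaultRepo> ctor =
                    getRepositoryBaseClass().getConstructor(EntityInfo.class, FakeEntityManager.class);
            return ctor.newInstance(info, em);
        } catch (ReflectiveOperationException e) {
            // e.g. the base class does not declare the expected constructor
            throw new IllegalStateException("Cannot instantiate repository base class", e);
        }
    }
}

// Hypothetical: plays the role of BaseJpaRepositoryFactory; only the hook is overridden.
class MiniBaseRepositoryFactory extends MiniRepositoryFactory {
    @Override
    protected Class<? extends DefaultRepo> getRepositoryBaseClass() {
        return ExtendedRepo.class;
    }
}
```

The design point is that the factory never names the concrete class in its creation logic, so swapping the base class is a one-method override; the price is that every base class must keep the constructor signature the factory looks for.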
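5. Create a factory for the JpaRepositoryFactory
Once the extended factory class exists, we still need something that creates it, namely a FactoryBean that produces our JpaRepositoryFactory. Example code: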
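import java.io.Serializable;
import javax.persistence.EntityManager;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.support.JpaRepositoryFactoryBean;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;

public class BaseJpaRepositoryFactoryBean<R extends JpaRepository<T, ID>, T extends TableEntity, ID extends Serializable>
        extends JpaRepositoryFactoryBean<R, T, ID> {
    public BaseJpaRepositoryFactoryBean(Class<? extends R> repositoryInterface) {
        super(repositoryInterface);
    }
    @Override
    protected RepositoryFactorySupport createRepositoryFactory(EntityManager entityManager) {
        // hand out our factory instead of the default JpaRepositoryFactory
        return new BaseJpaRepositoryFactory(entityManager);
    }
}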
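Why two levels of factories? The FactoryBean is the Spring bean that @EnableJpaRepositories registers for each repository interface; it creates the repository factory, the factory builds the target object, and that target is wrapped in a proxy implementing your repository interface. Spring Data builds this proxy with Spring AOP and adds its own interceptors (for derived query methods, transactions and so on); the sketch below strips the idea down to a bare JDK dynamic proxy that forwards every call to a target. NameRepository, NameRepositoryTarget and RepositoryProxies are hypothetical names used only for illustration.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

// Hypothetical repository contract, playing the role of BaseJpaRepository.
interface NameRepository {
    String findName(long id);

    long count();
}

// Hypothetical target implementation, playing the role of SimpleBaseJpaRepository.
class NameRepositoryTarget implements NameRepository {
    @Override
    public String findName(long id) {
        return "name-" + id;
    }

    @Override
    public long count() {
        return 42;
    }
}

// Hypothetical helper: wraps a target in a JDK dynamic proxy that forwards every interface call.
final class RepositoryProxies {
    private RepositoryProxies() {
    }

    @SuppressWarnings("unchecked")
    static <R> R proxyFor(Class<R> repositoryInterface, Object target) {
        // Every call on the proxy lands here and is delegated to the target unchanged.
        InvocationHandler handler = (proxy, method, args) -> method.invoke(target, args);
        return (R) Proxy.newProxyInstance(
                repositoryInterface.getClassLoader(),
                new Class<?>[] { repositoryInterface },
                handler);
    }
}
```

In a real application you never call anything like proxyFor yourself: you inject the repository interface, and once the factory from step 4 is wired in, the proxy you receive delegates the CRUD calls to a SimpleBaseJpaRepository instance.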
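6. Create the RepositoryConfig
At this point the skeleton of the Spring Data JPA extension is in place. Before adding extension methods to BaseJpaRepository and SimpleBaseJpaRepository, we need to modify the RepositoryConfig from the previous example: in its @EnableJpaRepositories annotation, set repositoryFactoryBeanClass to our own factory bean, BaseJpaRepositoryFactoryBean.class. The code is as follows: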
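import java.util.HashMap;
import java.util.Map;
import javax.persistence.EntityManagerFactory;
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.jpa.repository.config.EnableJpaRepositories;
import org.springframework.orm.jpa.JpaTransactionManager;
import org.springframework.orm.jpa.LocalContainerEntityManagerFactoryBean;
import org.springframework.orm.jpa.vendor.HibernateJpaVendorAdapter;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.annotation.EnableTransactionManagement;

@Configuration
@EnableTransactionManagement
@EnableJpaRepositories(
        basePackages = {"org.learning.repository"},
        // use our factory bean instead of the default JpaRepositoryFactoryBean
        repositoryFactoryBeanClass = BaseJpaRepositoryFactoryBean.class)
public class RepositoryConfig {
    private static final String HIBERNATE_DIALECT = "hibernate.dialect";
    // the Hibernate property is "hibernate.show_sql" (underscore, not a dot)
    private static final String HIBERNATE_SHOW_SQL = "hibernate.show_sql";
    private static final String HIBERNATE_HBM2DDL_AUTO = "hibernate.hbm2ddl.auto";
    private static final String HIBERNATE_EJB_NAMING_STRATEGY = "hibernate.ejb.naming_strategy";
    @Autowired
    private DataSource dataSource;
    @Value("${spring.jpa.show-sql}")
    private String showSql;
    @Value("${spring.jpa.generate-ddl}")
    private String generateDdl;
    @Value("${spring.jpa.hibernate.ddl-auto}")
    private String hibernateDdl;
    @Value("${spring.jpa.database-platform}")
    private String databasePlatform;
    @Bean
    public LocalContainerEntityManagerFactoryBean entityManagerFactory() {
        HibernateJpaVendorAdapter vendorAdapter = new HibernateJpaVendorAdapter();
        LocalContainerEntityManagerFactoryBean factory = new LocalContainerEntityManagerFactoryBean();
        factory.setJpaVendorAdapter(vendorAdapter);
        factory.setDataSource(dataSource);
        factory.setPackagesToScan("org.learning.entity");
        Map<String, Object> jpaProperties = new HashMap<>();
        jpaProperties.put(HIBERNATE_SHOW_SQL, showSql);
        jpaProperties.put(HIBERNATE_DIALECT, databasePlatform);
        jpaProperties.put(HIBERNATE_HBM2DDL_AUTO, hibernateDdl);
        // maps camelCase properties to snake_case columns (a Hibernate 4.x setting)
        jpaProperties.put(HIBERNATE_EJB_NAMING_STRATEGY, "org.hibernate.cfg.ImprovedNamingStrategy");
        factory.setJpaPropertyMap(jpaProperties);
        // no manual afterPropertiesSet(): Spring initializes the FactoryBean itself
        return factory;
    }
    @Bean
    public PlatformTransactionManager transactionManager(EntityManagerFactory entityManagerFactory) {
        // the EntityManagerFactory produced by the bean above is injected here
        JpaTransactionManager txManager = new JpaTransactionManager();
        txManager.setEntityManagerFactory(entityManagerFactory);
        return txManager;
    }
}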
7. Now we can start filling in the extension methods. First, SimpleBaseJpaRepository needs a constructor:
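import java.io.Serializable;
import javax.persistence.EntityManager;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.data.repository.NoRepositoryBean;
import org.springframework.transaction.annotation.Transactional;

@NoRepositoryBean
@Transactional(readOnly = true)
public class SimpleBaseJpaRepository<T extends TableEntity, ID extends Serializable>
        extends SimpleJpaRepository<T, ID> implements BaseJpaRepository<T, ID> {
    protected EntityManager entityManager;
    private final JpaEntityInformation<T, ?> entityInformation;
    protected Class<T> domainClass;
    /**
     * Calls the parent class constructor and keeps references for the extension methods
     *
     * @param entityInformation
     * @param entityManager
     */
    public SimpleBaseJpaRepository(
            JpaEntityInformation<T, ?> entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.entityManager = entityManager;
        this.entityInformation = entityInformation;
        this.domainClass = entityInformation.getJavaType();
    }
}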
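The constructor must call the parent constructor so that the EntityManager and JpaEntityInformation are initialized; it also keeps its own references to them, plus the domain class, for the extension methods added later.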
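8. A custom dynamic query builder
Spring Data JPA has no ready-made dynamic query builder comparable to Hibernate's Criteria API; it only offers the lower-level Specification interface. Real projects, however, usually need many dynamic queries, so let's wrap one on top of Specification. Example code:
The Criterion interface: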
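import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;

public interface Criterion {
    /**
     * Operators
     */
    enum Operator {
        EQ, NE, LIKE, GT, LT, GTE, LTE, AND, OR, ISNULL, ISNOTNULL, LEFTLIKE, RIGHTLIKE, BETWEEN, IN
    }
    Predicate toPredicate(Root<?> root, CriteriaQuery<?> query, CriteriaBuilder builder);
}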
The Criteria query builder, which implements Specification:
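import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.springframework.data.jpa.domain.Specification;

public class Criteria<T> implements Specification<T> {
    private List<Criterion> criterions = new ArrayList<Criterion>();
    @Override
    public Predicate toPredicate(Root<T> root, CriteriaQuery<?> query, CriteriaBuilder builder) {
        if (!criterions.isEmpty()) {
            List<Predicate> predicates = new ArrayList<Predicate>();
            for (Criterion c : criterions) {
                Predicate predicate = c.toPredicate(root, query, builder);
                // the expression classes return null for "no condition"; skip those
                if (predicate != null) {
                    predicates.add(predicate);
                }
            }
            // combine all conditions with AND
            if (!predicates.isEmpty()) {
                return builder.and(predicates.toArray(new Predicate[0]));
            }
        }
        // no conditions: an always-true predicate
        return builder.conjunction();
    }
    /**
     * Adds a condition expression; null is ignored
     * @param criterion
     */
    public void add(Criterion criterion) {
        if (criterion != null) {
            criterions.add(criterion);
        }
    }
    private boolean isField(Field[] fields, String queryKey) {
        if (fields == null || fields.length == 0) {
            return false;
        }
        for (Field field : fields) {
            if (field.getName().equals(queryKey)) {
                return true;
            }
        }
        return false;
    }
}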
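The rules Criteria follows (null conditions are ignored, the remaining ones are AND-ed, and no conditions at all means "match everything", which is what builder.conjunction() expresses) can be tried out without a database. The sketch below is a hypothetical in-memory analogue built on java.util.function.Predicate instead of the JPA Predicate type; the class name InMemoryCriteria is made up for this example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical in-memory analogue of Criteria: same combination rules, plain Java predicates.
class InMemoryCriteria<T> implements Predicate<T> {
    private final List<Predicate<T>> criterions = new ArrayList<>();

    /** Like Criteria.add: a null condition (e.g. an empty search field) is silently ignored. */
    InMemoryCriteria<T> add(Predicate<T> criterion) {
        if (criterion != null) {
            criterions.add(criterion);
        }
        return this;
    }

    /** All conditions are AND-ed; with no conditions everything matches (like builder.conjunction()). */
    @Override
    public boolean test(T candidate) {
        for (Predicate<T> c : criterions) {
            if (!c.test(candidate)) {
                return false;
            }
        }
        return true;
    }
}
```

The null-skipping rule is what makes the wrapper convenient for search forms: every optional form field can be passed straight through, and the empty ones simply drop out of the query.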
Next, create a few expression classes:
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Expression;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.apache.commons.lang3.StringUtils;

public class SimpleExpression implements Criterion {
    private String fieldName;  // property name
    private Object value;      // value to compare with
    private Object[] values;   // values for IN
    private Operator operator; // operator
    protected SimpleExpression(String fieldName, Object value, Operator operator) {
        this.fieldName = fieldName;
        this.value = value;
        this.operator = operator;
    }
    protected SimpleExpression(String fieldName, Operator operator) {
        this.fieldName = fieldName;
        this.operator = operator;
    }
    protected SimpleExpression(String fieldName, Operator operator, Object... values) {
        this.fieldName = fieldName;
        this.values = values;
        this.operator = operator;
    }
    public String getFieldName() {
        return fieldName;
    }
    public Object getValue() {
        return value;
    }
    public Operator getOperator() {
        return operator;
    }
    @SuppressWarnings({ "rawtypes", "unchecked" })
    @Override
    public Predicate toPredicate(Root<?> root, CriteriaQuery<?> query,
            CriteriaBuilder builder) {
        Path expression;
        // "a.b.c" navigates nested properties: root.get("a").get("b").get("c")
        if (fieldName.contains(".")) {
            String[] names = StringUtils.split(fieldName, ".");
            expression = root.get(names[0]);
            for (int i = 1; i < names.length; i++) {
                expression = expression.get(names[i]);
            }
        } else {
            expression = root.get(fieldName);
        }
        switch (operator) {
            case EQ:
                return builder.equal(expression, value);
            case NE:
                return builder.notEqual(expression, value);
            case LIKE:
                return builder.like((Expression<String>) expression, "%" + value + "%");
            case LEFTLIKE:
                return builder.like((Expression<String>) expression, "%" + value);
            case RIGHTLIKE:
                return builder.like((Expression<String>) expression, value + "%");
            case LT:
                return builder.lessThan(expression, (Comparable) value);
            case GT:
                return builder.greaterThan(expression, (Comparable) value);
            case LTE:
                return builder.lessThanOrEqualTo(expression, (Comparable) value);
            case GTE:
                return builder.greaterThanOrEqualTo(expression, (Comparable) value);
            case ISNULL:
                return builder.isNull(expression);
            case ISNOTNULL:
                return builder.isNotNull(expression);
            case IN:
                // standard JPA API instead of casting to Hibernate's CriteriaBuilderImpl
                return expression.in(values);
            default:
                return null;
        }
    }
}
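The three LIKE variants are easy to mix up: LEFTLIKE puts the wildcard on the left, so it matches values that end with the search term, while RIGHTLIKE matches values that start with it. The hypothetical LikePatterns helper below mirrors how SimpleExpression builds the patterns and adds a simplified matcher (case-sensitive, handling only % and _) so the effect can be checked in plain Java.

```java
import java.util.regex.Pattern;

// Hypothetical helper mirroring SimpleExpression's LIKE patterns, plus a simplified LIKE matcher.
final class LikePatterns {
    private LikePatterns() {
    }

    /** LIKE: the value may appear anywhere. */
    static String contains(String value) {
        return "%" + value + "%";
    }

    /** LEFTLIKE: wildcard on the left, so the column must END with the value. */
    static String endsWith(String value) {
        return "%" + value;
    }

    /** RIGHTLIKE: wildcard on the right, so the column must START with the value. */
    static String startsWith(String value) {
        return value + "%";
    }

    /** Simplified SQL LIKE: '%' matches any run of characters, '_' exactly one; case-sensitive. */
    static boolean matches(String text, String likePattern) {
        StringBuilder regex = new StringBuilder();
        for (char ch : likePattern.toCharArray()) {
            if (ch == '%') {
                regex.append(".*");
            } else if (ch == '_') {
                regex.append('.');
            } else {
                regex.append(Pattern.quote(String.valueOf(ch)));
            }
        }
        return Pattern.compile(regex.toString(), Pattern.DOTALL).matcher(text).matches();
    }
}
```

Two practical consequences: the value is not escaped, so a search term containing % or _ acts as a wildcard (JPA's CriteriaBuilder has a like(Expression<String>, String, char escapeChar) overload if that matters), and case sensitivity in a real query depends on the database collation rather than on this code.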
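import java.util.ArrayList;
import java.util.List;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;

public class LogicalExpression implements Criterion {
    private Criterion[] criterion; // the expressions combined by this logical expression
    private Operator operator;     // operator
    LogicalExpression(Criterion[] criterions, Operator operator) {
        this.criterion = criterions;
        this.operator = operator;
    }
    @Override
    public Predicate toPredicate(Root<?> root, CriteriaQuery<?> query,
            CriteriaBuilder builder) {
        List<Predicate> predicates = new ArrayList<Predicate>();
        for (Criterion aCriterion : this.criterion) {
            // Restrictions returns null for empty values; skip those entries
            if (aCriterion == null) {
                continue;
            }
            Predicate predicate = aCriterion.toPredicate(root, query, builder);
            if (predicate != null) {
                predicates.add(predicate);
            }
        }
        // nothing to combine: return null, the same "no condition" convention Restrictions uses
        if (predicates.isEmpty()) {
            return null;
        }
        switch (operator) {
            case OR:
                return builder.or(predicates.toArray(new Predicate[0]));
            default:
                return builder.and(predicates.toArray(new Predicate[0]));
        }
    }
}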
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.apache.commons.lang3.StringUtils;

public class BetweenExpression implements Criterion {
    private final String fieldName;
    private Object lo;
    private Object hi;
    BetweenExpression(String fieldName, Object lo, Object hi) {
        this.fieldName = fieldName;
        this.lo = lo;
        this.hi = hi;
    }
    @SuppressWarnings({ "rawtypes", "unchecked" })
    @Override
    public Predicate toPredicate(Root<?> root, CriteriaQuery<?> query, CriteriaBuilder builder) {
        Path expression;
        if (fieldName.contains(".")) {
            String[] names = StringUtils.split(fieldName, ".");
            expression = root.get(names[0]);
            for (int i = 1; i < names.length; i++) {
                expression = expression.get(names[i]);
            }
        } else {
            expression = root.get(fieldName);
        }
        // Any comparable bounds work (Date, String, Integer, Long, Double, BigDecimal, ...);
        // the original type-by-type checks silently dropped e.g. Long ranges.
        if (lo instanceof Comparable && hi instanceof Comparable) {
            return builder.between(expression, (Comparable) lo, (Comparable) hi);
        }
        return null;
    }
}
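SQL BETWEEN is inclusive at both ends, so BetweenExpression matches lo <= value <= hi; this matters for date ranges such as the one eqDate builds further down, whose upper bound is the end of the day. Below is a plain-Java restatement using a hypothetical Between helper.

```java
// Hypothetical helper: SQL BETWEEN semantics restated in plain Java.
// "value BETWEEN lo AND hi" means lo <= value AND value <= hi, both ends inclusive.
final class Between {
    private Between() {
    }

    static <Y extends Comparable<? super Y>> boolean between(Y value, Y lo, Y hi) {
        return lo.compareTo(value) <= 0 && value.compareTo(hi) <= 0;
    }
}
```

With reversed bounds (lo greater than hi) nothing matches, in this sketch as with SQL's default BETWEEN, so callers should pass the smaller bound first.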
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.apache.commons.lang3.StringUtils;

public class ColumnExpression implements Criterion {
    private final String fieldNameA;
    private final String fieldNameB;
    private Operator operator; // operator
    ColumnExpression(String fieldNameA, String fieldNameB, Operator operator) {
        this.fieldNameA = fieldNameA;
        this.fieldNameB = fieldNameB;
        this.operator = operator;
    }
    @SuppressWarnings("rawtypes")
    @Override
    public Predicate toPredicate(Root<?> root, CriteriaQuery<?> query, CriteriaBuilder builder) {
        Path expressionA;
        if (fieldNameA.contains(".")) {
            String[] names = StringUtils.split(fieldNameA, ".");
            expressionA = root.get(names[0]);
            for (int i = 1; i < names.length; i++) {
                expressionA = expressionA.get(names[i]);
            }
        } else {
            expressionA = root.get(fieldNameA);
        }
        Path expressionB;
        if (fieldNameB.contains(".")) {
            String[] names = StringUtils.split(fieldNameB, ".");
            expressionB = root.get(names[0]);
            for (int i = 1; i < names.length; i++) {
                expressionB = expressionB.get(names[i]);
            }
        } else {
            expressionB = root.get(fieldNameB);
        }
        // compares two columns of the same row with each other
        switch (operator) {
            case EQ:
                return builder.equal(expressionA, expressionB);
            case NE:
                return builder.notEqual(expressionA, expressionB);
            default:
                return null;
        }
    }
}
Finally, the Restrictions class, which creates the conditions:
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Collection;
import java.util.Date;
import org.apache.commons.lang3.StringUtils;

public class Restrictions {
    /**
     * Equal to
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression eq(String fieldName, Object value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.EQ);
    }
    /**
     * Not equal to
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression ne(String fieldName, Object value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.NE);
    }
    /**
     * Fuzzy match: value anywhere (LIKE %value%)
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression like(String fieldName, String value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.LIKE);
    }
    /**
     * Fuzzy match with a leading wildcard (LIKE %value, i.e. ends with value)
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression leftLike(String fieldName, String value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.LEFTLIKE);
    }
    /**
     * Fuzzy match with a trailing wildcard (LIKE value%, i.e. starts with value)
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression rightLike(String fieldName, String value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.RIGHTLIKE);
    }
    /**
     * Greater than
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression gt(String fieldName, Object value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.GT);
    }
    /**
     * Less than
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression lt(String fieldName, Object value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.LT);
    }
    /**
     * Greater than or equal to
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression gte(String fieldName, Object value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.GTE);
    }
    /**
     * Less than or equal to
     * @param fieldName
     * @param value
     * @return
     */
    public static SimpleExpression lte(String fieldName, Object value) {
        if (value == null) {
            return null;
        }
        return new SimpleExpression(fieldName, value, Criterion.Operator.LTE);
    }
    /**
     * AND
     * @param criterions
     * @return
     */
    public static LogicalExpression and(Criterion... criterions) {
        return new LogicalExpression(criterions, Criterion.Operator.AND);
    }
    /**
     * OR
     * @param criterions
     * @return
     */
    public static LogicalExpression or(Criterion... criterions) {
        return new LogicalExpression(criterions, Criterion.Operator.OR);
    }
    /**
     * In: the field equals one of the values
     * @param fieldName
     * @param value
     * @return
     */
    @SuppressWarnings("rawtypes")
    public static SimpleExpression in(String fieldName, Collection value) {
        // missing or empty collection: no condition (an empty IN list is rejected by many databases)
        if (value == null || value.isEmpty()) {
            return null;
        }
        return new SimpleExpression(fieldName, Criterion.Operator.IN, value.toArray());
    }
    /**
     * Not in: expands to field <> v1 AND field <> v2 AND ...
     * @param fieldName
     * @param value
     * @return
     */
    @SuppressWarnings("rawtypes")
    public static LogicalExpression notIn(String fieldName, Collection value) {
        if (value == null || value.isEmpty()) {
            return null;
        }
        SimpleExpression[] ses = new SimpleExpression[value.size()];
        int i = 0;
        for (Object obj : value) {
            ses[i] = new SimpleExpression(fieldName, obj, Criterion.Operator.NE);
            i++;
        }
        return new LogicalExpression(ses, Criterion.Operator.AND);
    }
    /**
     * Column is null
     * @param fieldName
     * @return
     */
    public static LogicalExpression isNull(String fieldName) {
        SimpleExpression[] ses = new SimpleExpression[1];
        ses[0] = new SimpleExpression(fieldName, Criterion.Operator.ISNULL);
        return new LogicalExpression(ses, Criterion.Operator.ISNULL);
    }
    /**
     * Column is not null
     * @param fieldName
     * @return
     */
    public static LogicalExpression isNotNull(String fieldName) {
        SimpleExpression[] ses = new SimpleExpression[1];
        ses[0] = new SimpleExpression(fieldName, Criterion.Operator.ISNOTNULL);
        return new LogicalExpression(ses, Criterion.Operator.ISNOTNULL);
    }
    /**
     * Range (BETWEEN, both ends inclusive)
     * @param fieldName
     * @return
     */
    public static LogicalExpression between(String fieldName, Object startDate, Object endDate) {
        BetweenExpression[] bes = new BetweenExpression[1];
        bes[0] = new BetweenExpression(fieldName, startDate, endDate);
        return new LogicalExpression(bes, Criterion.Operator.BETWEEN);
    }
    /**
     * Two columns are equal
     * @param fieldNameA
     * @param fieldNameB
     * @return
     */
    public static ColumnExpression eqCol(String fieldNameA, String fieldNameB) {
        if (StringUtils.isBlank(fieldNameA) || StringUtils.isBlank(fieldNameB)) {
            return null;
        }
        return new ColumnExpression(fieldNameA, fieldNameB, Criterion.Operator.EQ);
    }
    /**
     * Two columns are not equal
     * @param fieldNameA
     * @param fieldNameB
     * @return
     */
    public static ColumnExpression neCol(String fieldNameA, String fieldNameB) {
        if (StringUtils.isBlank(fieldNameA) || StringUtils.isBlank(fieldNameB)) {
            return null;
        }
        return new ColumnExpression(fieldNameA, fieldNameB, Criterion.Operator.NE);
    }
    /**
     * Same day: matches any time on the given day (a "yyyy-MM-dd" string or a Date)
     * @param fieldName
     * @return
     */
    public static LogicalExpression eqDate(String fieldName, Object date) {
        Date day;
        if (date instanceof String) {
            day = stringToDateTime(date + " 00:00:00");
        } else if (date instanceof Date) {
            day = (Date) date;
        } else {
            return null;
        }
        if (day == null) {
            // the string could not be parsed
            return null;
        }
        Calendar calendar = Calendar.getInstance();
        calendar.setTime(day);
        // start of the day: 00:00:00.000 (the milliseconds must be cleared too)
        calendar.set(Calendar.HOUR_OF_DAY, 0);
        calendar.set(Calendar.MINUTE, 0);
        calendar.set(Calendar.SECOND, 0);
        calendar.set(Calendar.MILLISECOND, 0);
        Date startDate = calendar.getTime();
        // end of the day: 23:59:59.999
        calendar.set(Calendar.HOUR_OF_DAY, 23);
        calendar.set(Calendar.MINUTE, 59);
        calendar.set(Calendar.SECOND, 59);
        calendar.set(Calendar.MILLISECOND, 999);
        Date endDate = calendar.getTime();
        BetweenExpression[] bes = new BetweenExpression[1];
        bes[0] = new BetweenExpression(fieldName, startDate, endDate);
        return new LogicalExpression(bes, Criterion.Operator.BETWEEN);
    }
    private final static String DATETIME_PATTERN = "yyyy-MM-dd HH:mm:ss";
    private static Date stringToDateTime(String dateString) {
        if (dateString == null) {
            return null;
        }
        try {
            DateFormat df = new SimpleDateFormat(DATETIME_PATTERN);
            df.setLenient(false);
            return df.parse(dateString);
        } catch (ParseException e) {
            return null;
        }
    }
}
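The subtle part of eqDate is the day window. The hypothetical DayRange helper below isolates that computation with java.util.Calendar in the JVM's default time zone: the start is 00:00:00.000 and the end 23:59:59.999 of the same day. Clearing the millisecond field matters, because setting only the hour, minute and second fields leaves the input's milliseconds in place, which would shift the start of the window.

```java
import java.util.Calendar;
import java.util.Date;

// Hypothetical helper: the [start, end] window of one calendar day in the default time zone.
final class DayRange {
    private DayRange() {
    }

    static Date startOfDay(Date date) {
        Calendar c = Calendar.getInstance();
        c.setTime(date);
        c.set(Calendar.HOUR_OF_DAY, 0);
        c.set(Calendar.MINUTE, 0);
        c.set(Calendar.SECOND, 0);
        c.set(Calendar.MILLISECOND, 0); // otherwise the input's milliseconds leak into the start
        return c.getTime();
    }

    static Date endOfDay(Date date) {
        Calendar c = Calendar.getInstance();
        c.setTime(date);
        c.set(Calendar.HOUR_OF_DAY, 23);
        c.set(Calendar.MINUTE, 59);
        c.set(Calendar.SECOND, 59);
        c.set(Calendar.MILLISECOND, 999); // last representable instant of the day
        return c.getTime();
    }
}
```

Because BETWEEN is inclusive, passing these two instants as lo and hi covers the whole day and nothing from the neighbouring days.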
End of this chapter.
This chapter presented the first part of the JPA wrapper code; the next chapter continues with the rest.