在mybatis执行SQL语句之前进行拦截处理

  • Post author:
  • Post category:其他


比较适用于在分页时候进行拦截。对分页的SQL语句通过封装处理,处理成不同的分页sql。

实用性比较强。


  1. import

    java.sql.Connection;

  2. import

    java.sql.PreparedStatement;

  3. import

    java.sql.ResultSet;

  4. import

    java.sql.SQLException;

  5. import

    java.util.List;

  6. import

    java.util.Properties;

  7. import

    org.apache.ibatis.executor.parameter.ParameterHandler;

  8. import

    org.apache.ibatis.executor.statement.RoutingStatementHandler;

  9. import

    org.apache.ibatis.executor.statement.StatementHandler;

  10. import

    org.apache.ibatis.mapping.BoundSql;

  11. import

    org.apache.ibatis.mapping.MappedStatement;

  12. import

    org.apache.ibatis.mapping.ParameterMapping;

  13. import

    org.apache.ibatis.plugin.Interceptor;

  14. import

    org.apache.ibatis.plugin.Intercepts;

  15. import

    org.apache.ibatis.plugin.Invocation;

  16. import

    org.apache.ibatis.plugin.Plugin;

  17. import

    org.apache.ibatis.plugin.Signature;

  18. import

    org.apache.ibatis.scripting.defaults.DefaultParameterHandler;

  19. import

    com.yidao.utils.Page;

  20. import

    com.yidao.utils.ReflectHelper;

  21. /**

  22. *

  23. * 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。

  24. * 利用拦截器实现Mybatis分页的原理:

  25. * 要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象,而且对应的Sql语句

  26. * 是在Statement之前产生的,所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的

  27. * prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用

  28. * StatementHandler对象的prepare方法,即调用invocation.proceed()。

  29. * 对于分页而言,在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设

  30. * 置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计。

  31. *

  32. */

  33. @Intercepts

    ({


    @Signature

    (type=StatementHandler.

    class

    ,method=

    “prepare”

    ,args={Connection.

    class

    })})

  34. public


    class

    PageInterceptor

    implements

    Interceptor {

  35. private

    String dialect =

    “”

    ;

    //数据库方言

  36. private

    String pageSqlId =

    “”

    ;

    //mapper.xml中需要拦截的ID(正则匹配)

  37. public

    Object intercept(Invocation invocation)

    throws

    Throwable {

  38. //对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler,

  39. //BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,

  40. //SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是

  41. //处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个

  42. //StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、

  43. //PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。

  44. //我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候

  45. //是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。

  46. if

    (invocation.getTarget()

    instanceof

    RoutingStatementHandler){
  47. RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();
  48. StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler,

    “delegate”

    );
  49. BoundSql boundSql = delegate.getBoundSql();
  50. Object obj = boundSql.getParameterObject();

  51. if

    (obj

    instanceof

    Page<?>) {
  52. Page<?> page = (Page<?>) obj;

  53. //通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
  54. MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate,

    “mappedStatement”

    );

  55. //拦截到的prepare方法参数是一个Connection对象
  56. Connection connection = (Connection)invocation.getArgs()[

    0

    ];

  57. //获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句
  58. String sql = boundSql.getSql();

  59. //给当前的page参数对象设置总记录数

  60. this

    .setTotalRecord(page,
  61. mappedStatement, connection);

  62. //获取分页Sql语句
  63. String pageSql =

    this

    .getPageSql(page, sql);

  64. //利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
  65. ReflectHelper.setFieldValue(boundSql,

    “sql”

    , pageSql);
  66. }
  67. }

  68. return

    invocation.proceed();
  69. }

  70. /**

  71. * 给当前的参数对象page设置总记录数

  72. *

  73. * @param page Mapper映射语句对应的参数对象

  74. * @param mappedStatement Mapper映射语句

  75. * @param connection 当前的数据库连接

  76. */

  77. private


    void

    setTotalRecord(Page<?> page,
  78. MappedStatement mappedStatement, Connection connection) {

  79. //获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。

  80. //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
  81. BoundSql boundSql = mappedStatement.getBoundSql(page);

  82. //获取到我们自己写在Mapper映射语句中对应的Sql语句
  83. String sql = boundSql.getSql();

  84. //通过查询Sql语句获取到对应的计算总记录数的sql语句
  85. String countSql =

    this

    .getCountSql(sql);

  86. //通过BoundSql获取对应的参数映射
  87. List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();

  88. //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
  89. BoundSql countBoundSql =

    new

    BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);

  90. //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
  91. ParameterHandler parameterHandler =

    new

    DefaultParameterHandler(mappedStatement, page, countBoundSql);

  92. //通过connection建立一个countSql对应的PreparedStatement对象。
  93. PreparedStatement pstmt =

    null

    ;
  94. ResultSet rs =

    null

    ;

  95. try

    {
  96. pstmt = connection.prepareStatement(countSql);

  97. //通过parameterHandler给PreparedStatement对象设置参数
  98. parameterHandler.setParameters(pstmt);

  99. //之后就是执行获取总记录数的Sql语句和获取结果了。
  100. rs = pstmt.executeQuery();

  101. if

    (rs.next()) {

  102. int

    totalRecord = rs.getInt(

    1

    );

  103. //给当前的参数page对象设置总记录数
  104. page.setTotalRecord(totalRecord);
  105. }
  106. }

    catch

    (SQLException e) {
  107. e.printStackTrace();
  108. }

    finally

    {

  109. try

    {

  110. if

    (rs !=

    null

    )
  111. rs.close();

  112. if

    (pstmt !=

    null

    )
  113. pstmt.close();
  114. }

    catch

    (SQLException e) {
  115. e.printStackTrace();
  116. }
  117. }
  118. }

  119. /**

  120. * 根据原Sql语句获取对应的查询总记录数的Sql语句

  121. * @param sql

  122. * @return

  123. */

  124. private

    String getCountSql(String sql) {

  125. int

    index = sql.indexOf(

    “from”

    );

  126. return


    “select count(*) ”

    + sql.substring(index);
  127. }

  128. /**

  129. * 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle

  130. * 其它的数据库都 没有进行分页

  131. *

  132. * @param page 分页对象

  133. * @param sql 原sql语句

  134. * @return

  135. */

  136. private

    String getPageSql(Page<?> page, String sql) {
  137. StringBuffer sqlBuffer =

    new

    StringBuffer(sql);

  138. if

    (

    “mysql”

    .equalsIgnoreCase(dialect)) {

  139. return

    getMysqlPageSql(page, sqlBuffer);
  140. }

    else


    if

    (

    “oracle”

    .equalsIgnoreCase(dialect)) {

  141. return

    getOraclePageSql(page, sqlBuffer);
  142. }

  143. return

    sqlBuffer.toString();
  144. }

  145. /**

  146. * 获取Mysql数据库的分页查询语句

  147. * @param page 分页对象

  148. * @param sqlBuffer 包含原sql语句的StringBuffer对象

  149. * @return Mysql数据库分页语句

  150. */

  151. private

    String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) {

  152. //计算第一条记录的位置,Mysql中记录的位置是从0开始的。

  153. //     System.out.println(“page:”+page.getPage()+”——-“+page.getRows());

  154. int

    offset = (page.getPage() –

    1

    ) * page.getRows();
  155. sqlBuffer.append(

    ” limit ”

    ).append(offset).append(

    “,”

    ).append(page.getRows());

  156. return

    sqlBuffer.toString();
  157. }

  158. /**

  159. * 获取Oracle数据库的分页查询语句

  160. * @param page 分页对象

  161. * @param sqlBuffer 包含原sql语句的StringBuffer对象

  162. * @return Oracle数据库的分页查询语句

  163. */

  164. private

    String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {

  165. //计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的

  166. int

    offset = (page.getPage() –

    1

    ) * page.getRows() +

    1

    ;
  167. sqlBuffer.insert(

    0

    ,

    “select u.*, rownum r from (”

    ).append(

    “) u where rownum < ”

    ).append(offset + page.getRows());
  168. sqlBuffer.insert(

    0

    ,

    “select * from (”

    ).append(

    “) where r >= ”

    ).append(offset);

  169. //上面的Sql语句拼接之后大概是这个样子:

  170. //select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16

  171. return

    sqlBuffer.toString();
  172. }

  173. /**

  174. * 拦截器对应的封装原始对象的方法

  175. */

  176. public

    Object plugin(Object arg0) {

  177. // TODO Auto-generated method stub

  178. if

    (arg0

    instanceof

    StatementHandler) {

  179. return

    Plugin.wrap(arg0,

    this

    );
  180. }

    else

    {

  181. return

    arg0;
  182. }
  183. }

  184. /**

  185. * 设置注册拦截器时设定的属性

  186. */

  187. public


    void

    setProperties(Properties p) {
  188. }

  189. public

    String getDialect() {

  190. return

    dialect;
  191. }

  192. public


    void

    setDialect(String dialect) {

  193. this

    .dialect = dialect;
  194. }

  195. public

    String getPageSqlId() {

  196. return

    pageSqlId;
  197. }

  198. public


    void

    setPageSqlId(String pageSqlId) {

  199. this

    .pageSqlId = pageSqlId;
  200. }
  201. }

xml配置:


  1. <!– MyBatis 接口编程配置  –>

  2. <


    bean


    class

    =

    “org.mybatis.spring.mapper.MapperScannerConfigurer”


    >

  3. <!– basePackage指定要扫描的包,在此包之下的映射器都会被搜索到,可指定多个包,包与包之间用逗号或分号分隔–>

  4. <


    property


    name

    =

    “basePackage”


    value

    =

    “com.yidao.mybatis.dao”


    />

  5. <


    property


    name

    =

    “sqlSessionFactoryBeanName”


    value

    =

    “sqlSessionFactory”


    />

  6. </


    bean


    >

  7. <!– MyBatis 分页拦截器–>

  8. <


    bean


    id

    =

    “paginationInterceptor”


    class

    =

    “com.mybatis.interceptor.PageInterceptor”


    >

  9. <


    property


    name

    =

    “dialect”


    value

    =

    “mysql”


    />

  10. <!– 拦截Mapper.xml文件中,id包含query字符的语句 –>

  11. <


    property


    name

    =

    “pageSqlId”


    value

    =

    “.*query$”


    />

  12. </


    bean


    >

Page类


  1. package

    com.yidao.utils;

  2. /**自己看看,需要什么字段加什么字段吧*/

  3. public


    class

    Page {

  4. private

    Integer rows;

  5. private

    Integer page =

    1

    ;

  6. private

    Integer totalRecord;

  7. public

    Integer getRows() {

  8. return

    rows;
  9. }

  10. public


    void

    setRows(Integer rows) {

  11. this

    .rows = rows;
  12. }

  13. public

    Integer getPage() {

  14. return

    page;
  15. }

  16. public


    void

    setPage(Integer page) {

  17. this

    .page = page;
  18. }

  19. public

    Integer getTotalRecord() {

  20. return

    totalRecord;
  21. }

  22. public


    void

    setTotalRecord(Integer totalRecord) {

  23. this

    .totalRecord = totalRecord;
  24. }
  25. }

ReflectHelper类


  1. package

    com.yidao.utils;

  2. import

    java.lang.reflect.Field;

  3. import

    org.apache.commons.lang3.reflect.FieldUtils;

  4. public


    class

    ReflectHelper {

  5. public


    static

    Object getFieldValue(Object obj , String fieldName ){

  6. if

    (obj ==

    null

    ){

  7. return


    null

    ;
  8. }
  9. Field targetField = getTargetField(obj.getClass(), fieldName);

  10. try

    {

  11. return

    FieldUtils.readField(targetField, obj,

    true

    ) ;
  12. }

    catch

    (IllegalAccessException e) {
  13. e.printStackTrace();
  14. }

  15. return


    null

    ;
  16. }

  17. public


    static

    Field getTargetField(Class<?> targetClass, String fieldName) {
  18. Field field =

    null

    ;

  19. try

    {

  20. if

    (targetClass ==

    null

    ) {

  21. return

    field;
  22. }

  23. if

    (Object.

    class

    .equals(targetClass)) {

  24. return

    field;
  25. }
  26. field = FieldUtils.getDeclaredField(targetClass, fieldName,

    true

    );

  27. if

    (field ==

    null

    ) {
  28. field = getTargetField(targetClass.getSuperclass(), fieldName);
  29. }
  30. }

    catch

    (Exception e) {
  31. }

  32. return

    field;
  33. }

  34. public


    static


    void

    setFieldValue(Object obj , String fieldName , Object value ){

  35. if

    (

    null

    == obj){


    return

    ;}
  36. Field targetField = getTargetField(obj.getClass(), fieldName);

  37. try

    {
  38. FieldUtils.writeField(targetField, obj, value) ;
  39. }

    catch

    (IllegalAccessException e) {
  40. e.printStackTrace();
  41. }
  42. }
  43. }