准备:自定义注解
最近公司项目想要做成一个云 SaaS 平台,需要让不同租户只能看到各自的数据,做到租户间的数据隔离(由于所有租户共用同一张表、靠字段区分,这实际上是逻辑隔离而非物理隔离)。目前的方案是在每张业务表中增加一个 `platform_id` 字段来区分不同的租户,这意味着系统中原有的增删改查语句都需要带上 `platform_id` 字段作为标识。如果在每条 SQL 上都手动加上这个字段,既麻烦又容易遗漏,所以决定使用 MyBatis 的拦截器 `Interceptor` 来统一实现。
自定义两个注解:`@PlatformTag` 和 `@PlatformTagIgnore`。
为了让代码更灵活,只有标注了 `@PlatformTag` 的 mapper 类才会被自动拦截并添加条件。
/**
 * Opt-in marker for a MyBatis mapper type: statements of a mapper carrying this annotation are
 * meant to have the tenant column ({@code platform_id}) added to them automatically.
 *
 * <p>Applicable to types only, and retained at runtime so it can be discovered via reflection.
 *
 * @author tengwang8
 * @version 1.0
 * @date 2022/2/28 10:04
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface PlatformTag {}
这个注解只作用于类上。但考虑到自定义的复杂 SQL 没办法自动添加条件,所以再增加一个 `@PlatformTagIgnore` 注解,用来标记类中需要跳过自动拦截的方法,由开发者在这些 SQL 中手动添加条件。
/**
 * Opt-out marker for a single mapper method: a method carrying this annotation is meant to be
 * skipped by the automatic tenant-column rewriting even though its mapper type is tagged, so the
 * condition can be written by hand (e.g. for complex custom SQL).
 *
 * <p>Applicable to methods only, and retained at runtime so it can be discovered via reflection.
 *
 * @author tengwang8
 * @version 1.0
 * @date 2022/2/28 10:04
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PlatformTagIgnore {}
拦截器 PlatformInterceptor(下面代码中的类名为 `AreaInterceptor`)
首先考虑 MySQL 中几种基本的语句类型:增、删、改、查。删除暂不考虑(删除一般都是直接根据主键 id 删除);新增和修改分别对应 `insert` 和 `update`,查询对应 `select`,所以目前只需要处理这三种情况。
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
/**
 * MyBatis {@link Executor} plugin that scopes SQL to the current tenant by adding the
 * {@code platform_id} column.
 *
 * <p>Only statements for which {@link ExecutorPluginUtils#isAreaTag(MappedStatement)} returns
 * {@code true} are touched; every other statement is passed through without even being parsed.
 * For a tagged statement:
 * <ul>
 *   <li>{@code INSERT ... (columns) VALUES (...)}: the {@code platform_id} column is appended with
 *       the current tenant id, or its value is overwritten if the statement already lists it.</li>
 *   <li>{@code UPDATE}: {@code platform_id = <tenant id>} is appended to (or overwritten in) the
 *       SET list.</li>
 *   <li>plain {@code SELECT}: {@code platform_id = <tenant id>} is AND-ed onto the WHERE clause,
 *       qualified with the FROM item's alias when it has one. Joined tables and subqueries are
 *       not filtered.</li>
 *   <li>any other statement type (e.g. {@code DELETE}) runs unchanged.</li>
 * </ul>
 *
 * <p>A tagged statement that cannot be rewritten (unparseable SQL, a non-plain SELECT such as a
 * UNION, a multi-row or INSERT ... SELECT insert, or no current tenant id) fails with an exception
 * instead of silently running without the tenant condition. Such statements should be excluded
 * with the method-level ignore annotation and scoped by hand.
 *
 * <p>NOTE(review): UPDATE only <em>sets</em> platform_id; it is not added to the WHERE clause, so
 * an update by primary key can still modify (and re-tag) another tenant's row. The tenant id is
 * spliced into the SQL text, so it must come from a trusted source — confirm how
 * CurrentPlatformIdCache is populated.
 *
 * @author tengwang8
 * @version 1.0
 * @date 2022/3/1 17:27
 */
@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
public class AreaInterceptor implements Interceptor {
    private static final String COLUMN_NAME = "platform_id";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        // Untagged statements are not parsed at all, so SQL that JSqlParser cannot handle only
        // matters for mappers that opted in to tenant scoping.
        if (!ExecutorPluginUtils.isAreaTag(mappedStatement)) {
            return invocation.proceed();
        }
        String processSql = ExecutorPluginUtils.getSqlByInvocation(invocation);
        log.debug("schema替换前:{}", processSql);
        String sql2Reset;
        try {
            sql2Reset = rewrite(CCJSqlParserUtil.parse(processSql), processSql);
        } catch (Exception e) {
            // Fail closed: running a tenant-scoped statement without the tenant condition would
            // read or write other tenants' data.
            log.error("Failed to add {} to statement {}: {}", COLUMN_NAME, mappedStatement.getId(), processSql, e);
            throw e;
        }
        // Only rebuild the MappedStatement when the SQL actually changed.
        if (!sql2Reset.equals(processSql)) {
            log.debug("schema替换后:{}", sql2Reset);
            ExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset);
        }
        return invocation.proceed();
    }

    /**
     * Returns the SQL of {@code statement} with the tenant condition applied, or
     * {@code originalSql} unchanged for statement types this interceptor does not handle.
     */
    private static String rewrite(Statement statement, String originalSql) throws Exception {
        if (statement instanceof Insert) {
            return rewriteInsert((Insert) statement);
        }
        if (statement instanceof Update) {
            return rewriteUpdate((Update) statement);
        }
        if (statement instanceof Select) {
            return rewriteSelect((Select) statement, originalSql);
        }
        // DELETE and any other statement type are executed as-is.
        return originalSql;
    }

    /** Puts platform_id into the column/value lists of a single-row INSERT with explicit columns. */
    private static String rewriteInsert(Insert insert) throws Exception {
        List<Column> columns = insert.getColumns();
        if (columns == null || !(insert.getItemsList() instanceof ExpressionList)) {
            throw new IllegalStateException(
                    "Only single-row INSERT with an explicit column list can be tenant-scoped automatically: " + insert);
        }
        ExpressionList itemsList = (ExpressionList) insert.getItemsList();
        // Work on a copy of the value list and write it back.
        List<Expression> values = new ArrayList<>(itemsList.getExpressions());
        putPlatformId(columns, values);
        itemsList.setExpressions(values);
        return insert.toString();
    }

    /** Puts {@code platform_id = <tenant id>} into the SET list of an UPDATE. */
    private static String rewriteUpdate(Update update) throws Exception {
        // Assumes getColumns()/getExpressions() return the statement's own lists (as in the
        // JSqlParser versions this targets), so mutating them in place is sufficient.
        putPlatformId(update.getColumns(), update.getExpressions());
        return update.toString();
    }

    /**
     * AND-s {@code platform_id = <tenant id>} onto the WHERE clause of a plain SELECT, qualified by
     * the FROM item's alias when present. A SELECT without a FROM item is returned unchanged.
     */
    private static String rewriteSelect(Select select, String originalSql) throws Exception {
        if (!(select.getSelectBody() instanceof PlainSelect)) {
            throw new IllegalStateException(
                    "Only plain SELECT statements can be tenant-scoped automatically: " + select);
        }
        PlainSelect plain = (PlainSelect) select.getSelectBody();
        FromItem fromItem = plain.getFromItem();
        if (fromItem == null) {
            return originalSql;
        }
        String column = fromItem.getAlias() != null
                ? fromItem.getAlias().getName() + "." + COLUMN_NAME
                : COLUMN_NAME;
        StringBuilder condition = new StringBuilder(column).append(" = ").append(currentPlatformId());
        Expression where = plain.getWhere();
        if (where != null) {
            // Parenthesise the existing condition so an OR inside it cannot bypass the tenant filter.
            condition.append(" and ( ").append(where).append(" )");
        }
        plain.setWhere(CCJSqlParserUtil.parseCondExpression(condition.toString()));
        return select.toString();
    }

    /**
     * Sets platform_id to the current tenant id in a parallel column/value list pair. If the column
     * is already present its value is overwritten (appending it again would list the column twice);
     * otherwise column and value are appended.
     */
    private static void putPlatformId(List<Column> columns, List<Expression> values) throws Exception {
        Expression platformId = CCJSqlParserUtil.parseExpression(currentPlatformId());
        for (int i = 0; i < columns.size(); i++) {
            if (isPlatformColumn(columns.get(i))) {
                values.set(i, platformId);
                return;
            }
        }
        columns.add(new Column(COLUMN_NAME));
        values.add(platformId);
    }

    /** Matches platform_id case-insensitively, ignoring backtick and double-quote identifier quoting. */
    private static boolean isPlatformColumn(Column column) {
        String name = column.getColumnName().replace("`", "").replace("\"", "");
        return COLUMN_NAME.equalsIgnoreCase(name);
    }

    /** Returns the current tenant id, failing if none is set rather than producing unscoped SQL. */
    private static String currentPlatformId() {
        String platformId = CurrentPlatformIdCache.getCurrentPlatformId();
        if (platformId == null) {
            throw new IllegalStateException(
                    "No current " + COLUMN_NAME + " is available for a tenant-scoped statement");
        }
        return platformId;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }
}
工具类ExecutorPluginUtils
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.session.RowBounds;
import java.lang.reflect.Method;
import java.sql.SQLException;
/**
* TODO
*
* @author tengwang8
* @version 1.0
* @date 2022/3/2 9:07
*/
public class ExecutorPluginUtils {
/**
* 获取sql语句
* @param invocation
* @return
*/
public static String getSqlByInvocation(Invocation invocation) {
final Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
Object parameterObject = args[1];
BoundSql boundSql = ms.getBoundSql(parameterObject);
return boundSql.getSql();
}
/**
* 包装sql后,重置到invocation中
* @param invocation
* @param sql
* @throws SQLException
*/
public static void resetSql2Invocation(Invocation invocation, String sql) throws SQLException {
final Object[] args = invocation.getArgs();
MappedStatement statement = (MappedStatement) args[0];
Object parameterObject = args[1];
BoundSql boundSql = statement.getBoundSql(parameterObject);
MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql));
MetaObject msObject = MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory());
msObject.setValue("sqlSource.boundSql.sql", sql);
args[0] = newStatement;
}
private static MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder =
new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
StringBuilder keyProperties = new StringBuilder();
for (String keyProperty : ms.getKeyProperties()) {
keyProperties.append(keyProperty).append(",");
}
keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
builder.keyProperty(keyProperties.toString());
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
/**
* 是否标记为区域字段
* @return
*/
public static boolean isAreaTag( MappedStatement mappedStatement) throws ClassNotFoundException {
String id = mappedStatement.getId();
Class> classType = Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf(".")));
//获取对应拦截方法名
String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1);
boolean ignore = false;
for(Method method : classType.getDeclaredMethods()){
if(method.isAnnotationPresent(AreaTagIgnore.class) && mName.equals(method.getName()) ) {
ignore = true;
}
}
if (classType.isAnnotationPresent(AreaTag.class) && !ignore) {
return true;
}
return false;
}
/**
* 是否标记为区域字段
* @return
*/
public static boolean isAreaTagIngore( MappedStatement mappedStatement) throws ClassNotFoundException {
String id = mappedStatement.getId();
Class> classType = Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf(".")));
//获取对应拦截方法名
String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1);
boolean ignore = false;
for(Method method : classType.getDeclaredMethods()){
if(method.isAnnotationPresent(AreaTagIgnore.class) && mName.equals(method.getName()) ) {
ignore = true;
}
}
return ignore;
}
public static String getOperateType(Invocation invocation) {
final Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
SqlCommandType commondType = ms.getSqlCommandType();
if (commondType.compareTo(SqlCommandType.SELECT) == 0) {
return "select";
}
if (commondType.compareTo(SqlCommandType.INSERT) == 0) {
return "insert";
}
if (commondType.compareTo(SqlCommandType.UPDATE) == 0) {
return "update";
}
if (commondType.compareTo(SqlCommandType.DELETE) == 0) {
return "delete";
}
return null;
}
// 定义一个内部辅助类,作用是包装sq
static class BoundSqlSqlSource implements SqlSource {
private BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
}
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)