import com.example.demo.entity.PageData;
import com.example.demo.service.IBuserService;
import com.example.demo.service.MenuService;
import com.example.demo.util.SpringContextUtil;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.PathMatchingFilter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import java.util.List;
/**
 * Shiro path filter that authorizes a request by matching the requested path against the
 * paths of the menus the current user has permission entries for.
 * Unauthenticated requests are forwarded to /login.html; requests with no matching
 * permission get an UnauthorizedException stored in the session and a redirect to /unauthorized.
 */
public class MyAccessControlFilter extends PathMatchingFilter {
// NOTE(review): logger is never used; the method below prints with System.out instead.
private Logger logger = LoggerFactory.getLogger(this.getClass());
// Account/user service (looked up lazily in onPreHandle if not injected)
@Autowired
private IBuserService userService;
// Menu service (looked up lazily in onPreHandle if not injected)
@Autowired
private MenuService menuService;
// Custom realm, used to rebuild authorization info for the matched menu
@Autowired
private CustomRealm customRealm;
/**
 * Decides whether the current request may proceed.
 *
 * @param request     the incoming request
 * @param response    the outgoing response
 * @param mappedValue filter configuration for the matched path (unused here)
 * @return true to continue the filter chain, false after redirecting to /unauthorized
 */
@Override
protected boolean onPreHandle(ServletRequest request, ServletResponse response, Object mappedValue) throws Exception {
// Do not remove: this filter is created with `new` in the Shiro config, so Spring does not
// inject the @Autowired fields above; fetch the beans from SpringContextUtil instead.
if (userService==null){
userService= SpringContextUtil.getBean(IBuserService.class);
}
if (menuService==null){
menuService= SpringContextUtil.getBean(MenuService.class);
}
if (customRealm==null){
customRealm= SpringContextUtil.getBean(CustomRealm.class);
}
// Path of the request within the application
String requestURL = getPathWithinApplication(request);
System.out.println("请求的url :"+requestURL);
// Check whether the user is logged in
Subject subject = SecurityUtils.getSubject();
if (!subject.isAuthenticated()){
// Not logged in: forward to the login page and return true to enter the login flow.
// NOTE(review): returning true after forward() lets the filter chain keep processing the
// original request on an already-forwarded response; returning false is probably intended — confirm.
request.getRequestDispatcher("/login.html").forward(request, response);
return true;
}
// Account name stored as the principal at login
String account = (String)subject.getPrincipal();
// PageData is just a generic map-based entity here
// NOTE(review): accountData is not null-checked; a missing account causes an NPE on the next line.
PageData accountData= userService.getAccount(account);
// User id
Long userId = (Long)accountData.get("user_id");
// Load all permission entries for the user.
// NOTE(review): the next line is truncated in this source (generic type and initializer lost).
// From its use below it presumably declared `permissions` (iterated as PageData) and the boolean
// flag `hasPermission`, which is not declared anywhere else — confirm against the original code.
List
for (PageData pd : permissions) {
// Look up the menu (and its path) by menu id
// NOTE(review): menu and its "path" value are not null-checked; either being null throws an NPE.
PageData menu = menuService.getMenuById(pd);
if (menu.getString("path").equals(requestURL)){
// Id of the matched menu
Integer menu_id=(Integer)menu.get("id");
// Re-run CustomRealm's authorization for this menu via getInfo(principals, menuId).
// The returned info is discarded; the call appears to be made only so the realm records the menu id.
AuthorizationInfo info = customRealm.getInfo(SecurityUtils.getSubject().getPrincipals(),menu_id) ;
hasPermission=true;
break;
}
// NOTE(review): this unconditional assignment runs for every non-matching entry, so any user
// with at least one permission entry is granted access to every path. Likely a bug — it
// probably should be removed so only a path match grants access.
hasPermission=true;
}
if (hasPermission){
return true;
}else {
// No permission: store the exception in the session for the error page and redirect.
// NOTE(review): UnauthorizedException and WebUtils are not among the imports visible in this file — confirm they are imported.
UnauthorizedException ex = new UnauthorizedException("当前用户没有访问路径" + requestURL + "的权限");
subject.getSession().setAttribute("ex",ex);
WebUtils.issueRedirect(request, response, "/unauthorized");
return false;
}
}
}
自定义CustomRealm类
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.shiro.authc.*;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.springframework.beans.factory.annotation.Autowired;
import com.example.demo.entity.PageData;
import com.example.demo.service.IBuserService;
import com.example.demo.service.MenuService;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Shiro realm backed by IBuserService/MenuService.
 * Authentication compares the submitted password with the user's stored password.
 * Authorization grants the user's role names and, when a menu id has been selected through
 * getInfo(PrincipalCollection, Integer), the comma-separated permissions each role has on that menu.
 */
public class CustomRealm extends AuthorizingRealm {
private Logger logger = LoggerFactory.getLogger(this.getClass());
@Autowired
private IBuserService userService;
@Autowired
private MenuService menuService;
// Menu id selected by getInfo(PrincipalCollection, Integer); 0 means "no menu", so only roles are granted.
// NOTE(review): this is mutable state on the realm, which is presumably a singleton bean. Concurrent
// requests can overwrite each other's menu id, and it is never reset, so later authorization calls
// reuse the last menu id. Passing the id as a method argument would avoid both problems.
private Integer menu_id=0;
/**
 * Public wrapper that builds authorization info for the given principals by calling
 * doGetAuthorizationInfo directly (a fresh build on every call).
 */
public AuthorizationInfo getInfo(PrincipalCollection principalCollection) {
return this.doGetAuthorizationInfo(principalCollection);
}
/**
 * Builds roles (and, if menu_id != 0, menu permissions) for the primary principal.
 *
 * @param principalCollection principals of the subject; the primary principal is the account name
 * @return the collected roles/permissions; empty if the account is not found
 */
@Override
protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principalCollection) {
SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();
String account =(String) principalCollection.getPrimaryPrincipal();
// Look up the account by name; accounts and users are stored in two separate tables here.
PageData accountData= userService.getAccount(account);
PageData pd = new PageData();
if(null!=accountData) {
Long uid = (Long)accountData.get("user_id");
pd.put("id", uid);
// NOTE(review): userData is not null-checked; a dangling user_id throws an NPE below.
PageData userData=userService.getUserById(pd);
Long userId=(Long)userData.get("id");
// Query the roles the user has.
// NOTE(review): the next line is truncated in this source (generic type and initializer lost);
// it presumably declared `roles` as a list of PageData for userId — confirm against the original code.
List
for (PageData role : roles) {
// Grant the role by name
info.addRole(role.getString("name"));
// Permissions of this role on the selected menu
// NOTE(review): menu_id is an Integer; if it was set to null this unboxing throws an NPE.
if(menu_id!=0) {
// NOTE(review): truncated line; presumably declared `permissions` as a Set of String
// (HashSet is imported) — confirm against the original code.
Set
PageData pageData = new PageData();
pageData.put("menu_id", menu_id);
Long rid = (Long)role.get("id");
pageData.put("role_id", rid);
// NOTE(review): p and its "permission" value are not null-checked; a role with no permission
// row for this menu throws an NPE.
PageData p= menuService.getPermByMenuId(pageData);
String[] split = p.getString("permission").split(",");
for (String str : split) {
// Collect each comma-separated permission string
permissions.add(str);
}
info.addStringPermissions(permissions);
}
}
}
return info;
}
/**
 * Authenticates a login token by comparing the submitted password with the stored one.
 *
 * @param authenticationToken token whose principal is the user name and whose credentials are a char[]
 * @return authentication info holding the user name and stored password
 * @throws AuthenticationException (AccountException) for a null user name or a wrong password
 */
@Override
protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken authenticationToken) throws AuthenticationException {
System.out.println("-------身份认证方法--------");
String userName = (String) authenticationToken.getPrincipal();
// NOTE(review): assumes the credentials are a char[] (as with a username/password token); null
// credentials throw an NPE here.
String userPwd = new String((char[]) authenticationToken.getCredentials());
// Fetch the stored password for this user name
// NOTE(review) SECURITY: "123" is used as the password when no account is found. If getAccount
// returns null for unknown names (as the null check below implies), ANY unknown user name logs in
// with password "123". An unknown account should be rejected (e.g. UnknownAccountException).
// Passwords are also stored and compared in plaintext.
String password = "123";
PageData accountData= (PageData)userService.getAccount(userName);
if(null!=accountData) {
PageData pd = new PageData();
Long uid = (Long)accountData.get("user_id");
pd.put("id", uid);
PageData userData=userService.getUserById(pd);
password = userData.getString("password");
}
// NOTE(review): userName is null-checked only after it has already been used in getAccount above.
if (userName == null) {
throw new AccountException("用户名不正确");
} else if (!userPwd.equals(password )) {
throw new AccountException("密码不正确");
}
return new SimpleAuthenticationInfo(userName, password,getName());
}
/**
 * Selects a menu id and builds authorization info including that menu's permissions.
 *
 * @param principalCollection principals of the subject
 * @param mid                 menu id to evaluate (stored in the shared menu_id field; see the note there)
 * @return roles plus the roles' permissions on menu mid
 */
public AuthorizationInfo getInfo(PrincipalCollection principalCollection, Integer mid) {
menu_id=mid;
return this.doGetAuthorizationInfo(principalCollection);
}
}
3. shiro配置类
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import at.pollux.thymeleaf.shiro.dialect.ShiroDialect;
import java.util.linkedHashMap;
import java.util.Map;
import javax.servlet.Filter;
@Configuration
public class ShiroConfig {
@Bean(name = "shiroFilter")
public ShiroFilterFactoryBean shiroFilter(SecurityManager securityManager) {
ShiroFilterFactoryBean shiroFilterFactoryBean = new ShiroFilterFactoryBean();
shiroFilterFactoryBean.setSecurityManager(securityManager);
//自定义拦截器
Map
Map
filtersMap.put("myAccessControlFilter", new MyAccessControlFilter());
//登录url
shiroFilterFactoryBean.setLoginUrl("/login.html");
//成功url
shiroFilterFactoryBean.setSuccessUrl("/index.html");
//
filterChainDefinitionMap.put("/buser/login", "anon");
filterChainDefinitionMap.put("/logout", "anon");
filterChainDefinitionMap.put("/js
/**
 * Creates the web security manager used by Shiro, wired to the application's custom realm.
 * Inside this configuration class, calling customRealm() returns the realm bean managed by Spring.
 */
@Bean
public SecurityManager securityManager() {
    DefaultWebSecurityManager webSecurityManager = new DefaultWebSecurityManager();
    CustomRealm realm = customRealm();
    webSecurityManager.setRealm(realm);
    return webSecurityManager;
}
/**
 * Registers the custom realm as a Spring bean so that its @Autowired services are injected.
 */
@Bean
public CustomRealm customRealm() {
    return new CustomRealm();
}
/**
 * Exposes the Thymeleaf Shiro dialect as a bean so templates can use it.
 */
@Bean
public ShiroDialect getShiroDialect() {
    ShiroDialect dialect = new ShiroDialect();
    return dialect;
}
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import javax.servlet.http.HttpServletRequest;
/**
 * Loosely typed, map-based data holder used as a generic entity / query-parameter object.
 *
 * <p>Although this class extends {@link HashMap}, all entries live in the delegate {@link #map};
 * the inherited hash table is always empty. Every map operation callers rely on must therefore be
 * overridden to delegate. Otherwise it silently behaves as if the map were empty (this is why
 * {@code getOrDefault}, {@code putIfAbsent} and {@code forEach} are overridden below).
 */
@SuppressWarnings("all")
public class PageData extends HashMap implements Map {
    private static final long serialVersionUID = 1L;

    /** Backing store for all entries. */
    Map map = null;

    /**
     * Request this instance was built from, or {@code null}. Transient because servlet requests
     * are typically not serializable; {@link #get} already copes with a {@code null} request.
     */
    transient HttpServletRequest request;

    /**
     * Builds a PageData from all parameters of the given request. Multi-valued parameters are
     * joined into one comma-separated string, and a {@code null} value becomes the empty string.
     *
     * @param request the current servlet request; retained so {@link #get} can resolve array values
     */
    public PageData(HttpServletRequest request) {
        this.request = request;
        Map properties = request.getParameterMap();
        Map returnMap = new HashMap();
        Iterator entries = properties.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry entry = (Map.Entry) entries.next();
            String name = (String) entry.getKey();
            Object valueObj = entry.getValue();
            String value;
            if (null == valueObj) {
                value = "";
            } else if (valueObj instanceof String[]) {
                // Join every value (not just the last one); an empty array yields "".
                value = String.join(",", (String[]) valueObj);
            } else {
                value = valueObj.toString();
            }
            returnMap.put(name, value);
        }
        map = returnMap;
    }

    /** Creates an empty PageData that is not tied to any request. */
    public PageData() {
        map = new HashMap();
    }

    /**
     * Returns the value stored under {@code key}. If that value is an array, the whole array is
     * returned when there is no request or the request has no parameter named {@code key};
     * otherwise only its first element is returned.
     */
    @Override
    public Object get(Object key) {
        Object obj = null;
        if (map.get(key) instanceof Object[]) {
            Object[] arr = (Object[]) map.get(key);
            obj = request == null ? arr : (request.getParameter((String) key) == null ? arr : arr[0]);
        } else {
            obj = map.get(key);
        }
        return obj;
    }

    /** Returns {@link #get} cast to String; throws ClassCastException for non-String values. */
    public String getString(Object key) {
        return (String) get(key);
    }

    /** Returns {@link #get} cast to Number; throws ClassCastException for non-Number values. */
    public Number getNumber(Object key) {
        return (Number) get(key);
    }

    /**
     * Returns {@link #get} when the key is present, otherwise {@code defaultValue}. The inherited
     * HashMap version would consult the always-empty superclass table.
     */
    @Override
    public Object getOrDefault(Object key, Object defaultValue) {
        return map.containsKey(key) ? get(key) : defaultValue;
    }

    @SuppressWarnings("unchecked")
    @Override
    public Object put(Object key, Object value) {
        return map.put(key, value);
    }

    /** Delegates to the backing map (the inherited version would ignore it). */
    @Override
    public Object putIfAbsent(Object key, Object value) {
        return map.putIfAbsent(key, value);
    }

    @Override
    public Object remove(Object key) {
        return map.remove(key);
    }

    @Override
    public void clear() {
        map.clear();
    }

    @Override
    public boolean containsKey(Object key) {
        return map.containsKey(key);
    }

    @Override
    public boolean containsValue(Object value) {
        return map.containsValue(value);
    }

    @Override
    public Set entrySet() {
        return map.entrySet();
    }

    /** Iterates the backing map's raw entries (the inherited version would visit nothing). */
    @Override
    public void forEach(java.util.function.BiConsumer action) {
        map.forEach(action);
    }

    @Override
    public boolean isEmpty() {
        return map.isEmpty();
    }

    @Override
    public Set keySet() {
        return map.keySet();
    }

    @SuppressWarnings("unchecked")
    @Override
    public void putAll(Map t) {
        map.putAll(t);
    }

    @Override
    public int size() {
        return map.size();
    }

    @Override
    public Collection values() {
        return map.values();
    }
}
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
@Component
public class SpringContextUtil implements ApplicationContextAware {
private static ApplicationContext applicationContext;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}
public static ApplicationContext getApplicationContext() {
return applicationContext;
}
public static HttpServletRequest getHttpServletRequest() {
return ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
}
public static String getDomain(){
HttpServletRequest request = getHttpServletRequest();
StringBuffer url = request.getRequestURL();
return url.delete(url.length() - request.getRequestURI().length(), url.length()).toString();
}
public static String getOrigin(){
HttpServletRequest request = getHttpServletRequest();
return request.getHeader("Origin");
}
public static Object getBean(String name) {
return getApplicationContext().getBean(name);
}
public static
return getApplicationContext().getBean(clazz);
}
public static
return getApplicationContext().getBean(name, clazz);
}
}
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)