Kaynağa Gözat

优化防火墙黑/白名单处理逻辑

Woody 1 ay önce
ebeveyn
işleme
14aac7a95a

+ 63 - 16
framework-security/src/main/java/com/chelvc/framework/security/interceptor/SecurityFirewallInterceptor.java

@@ -2,10 +2,12 @@ package com.chelvc.framework.security.interceptor;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
 import com.chelvc.framework.base.context.ApplicationContextHolder;
+import com.chelvc.framework.base.context.Session;
 import com.chelvc.framework.base.context.SessionContextHolder;
 import com.chelvc.framework.common.exception.FrameworkException;
 import com.chelvc.framework.common.model.Ip;
@@ -16,6 +18,7 @@ import com.chelvc.framework.redis.context.RedisContextHolder;
 import com.chelvc.framework.security.firewall.FirewallProcessor;
 import com.chelvc.framework.security.firewall.Rule;
 import com.fasterxml.jackson.core.type.TypeReference;
+import com.google.common.collect.Lists;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -78,21 +81,21 @@ public class SecurityFirewallInterceptor implements HandlerInterceptor, WebMvcCo
     }
 
     /**
-     * 获取IP白名单
+     * 获取主体白名单
      *
-     * @return IP列表
+     * @return 主体列表
      */
-    private List<Ip> getWhitelists() {
-        return ApplicationContextHolder.getProperty("security.firewall.whitelists", Ip::host2ips);
+    private List<CharSequence> getWhitelists() {
+        return ApplicationContextHolder.getProperty("security.firewall.whitelists", this::property2principals);
     }
 
     /**
-     * 获取IP黑名单
+     * 获取主体黑名单
      *
-     * @return IP列表
+     * @return 主体列表
      */
-    private List<Ip> getBlacklists() {
-        return ApplicationContextHolder.getProperty("security.firewall.blacklists", Ip::host2ips);
+    private List<CharSequence> getBlacklists() {
+        return ApplicationContextHolder.getProperty("security.firewall.blacklists", this::property2principals);
     }
 
     /**
@@ -127,6 +130,46 @@ public class SecurityFirewallInterceptor implements HandlerInterceptor, WebMvcCo
                 && rule.getExpiration() >= 60 && rule.getExpiration() <= 1800));
     }
 
+    /**
+     * 判断主体是否匹配
+     *
+     * @param principal 主体信息
+     * @param host      IP数字
+     * @return true/false
+     */
+    private boolean matches(CharSequence principal, long host) {
+        if (principal instanceof Ip) {
+            return host > 0 && ((Ip) principal).matches(host);
+        }
+        Session session = SessionContextHolder.getSession(false);
+        return session != null && (Objects.equals(principal, session.getMobile())
+                || Objects.equals(principal, session.getDevice()));
+    }
+
+    /**
+     * 将属性值转换成主体信息
+     *
+     * @param property 属性值
+     * @return 主体信息列表
+     */
+    private List<CharSequence> property2principals(String property) {
+        if (StringUtils.isEmpty(property)) {
+            return Collections.emptyList();
+        }
+
+        List<CharSequence> principals = Lists.newLinkedList();
+        for (String principal : property.split(",")) {
+            if (!(principal = principal.trim()).isEmpty()) {
+                if (StringUtils.isIp(principal)) {
+                    principals.add(new Ip(principal));
+                } else {
+                    principals.add(principal);
+                }
+            }
+        }
+        return principals.isEmpty() ? Collections.emptyList() : principals;
+    }
+
     @Override
     public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler)
             throws Exception {
@@ -134,18 +177,22 @@ public class SecurityFirewallInterceptor implements HandlerInterceptor, WebMvcCo
             return true;
         }
 
-        long host = this.getHostNumber();
-        if (host > 0) {
-            // 如果是IP白名单则放行
-            for (Ip ip : this.getWhitelists()) {
-                if (ip.matches(host)) {
+        // 主体黑/白名单校验
+        List<CharSequence> whitelists = this.getWhitelists();
+        List<CharSequence> blacklists = this.getBlacklists();
+        if (ObjectUtils.notEmpty(whitelists) || ObjectUtils.notEmpty(blacklists)) {
+            long host = this.getHostNumber();
+
+            // 白名单放行
+            for (CharSequence principal : whitelists) {
+                if (this.matches(principal, host)) {
                     return true;
                 }
             }
 
-            // 如果是IP黑名单拒绝访问
-            for (Ip ip : this.getBlacklists()) {
-                if (ip.matches(host)) {
+            // 黑名单拒绝访问
+            for (CharSequence principal : blacklists) {
+                if (this.matches(principal, host)) {
                     String message = ApplicationContextHolder.getMessage("Forbidden");
                     throw new FrameworkException(HttpStatus.FORBIDDEN.name(), null, message);
                 }