Parcourir la source

优化请求日志打印逻辑;优化会话更新逻辑;优化请求头校验逻辑;

Woody il y a 1 semaine
Parent
commit
8751edb21b

+ 5 - 0
framework-base/src/main/java/com/chelvc/framework/base/context/LoggingContextHolder.java

@@ -56,6 +56,7 @@ public final class LoggingContextHolder {
         Platform platform = ObjectUtils.ifNull(session, Session::getPlatform);
         Terminal terminal = ObjectUtils.ifNull(session, Session::getTerminal);
         String version = ObjectUtils.ifNull(session, Session::getVersion);
+        Long timestamp = ObjectUtils.ifNull(session, Session::getTimestamp);
         StringBuilder buffer = new StringBuilder("[");
         if (StringUtils.notEmpty(host)) {
             buffer.append(host);
@@ -85,6 +86,10 @@ public final class LoggingContextHolder {
             buffer.append(rid);
         }
         buffer.append("] [");
+        if (timestamp != null) {
+            buffer.append(timestamp);
+        }
+        buffer.append("] [");
         if (StringUtils.notEmpty(endpoint)) {
             buffer.append(endpoint);
         }

+ 21 - 0
framework-base/src/main/java/com/chelvc/framework/base/context/Session.java

@@ -1,6 +1,7 @@
 package com.chelvc.framework.base.context;
 
 import java.io.Serializable;
+import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
 
@@ -133,4 +134,24 @@ public class Session implements Serializable {
     public void setGroup(@NonNull String scene, @NonNull Caps group) {
         this.groups.put(scene, group);
     }
+
+    /**
+     * 更新会话信息
+     *
+     * @param id          主体标识
+     * @param using       使用类别
+     * @param scope       应用范围
+     * @param mobile      手机号码
+     * @param registering 注册时间戳
+     * @param authorities 权限标识集合
+     */
+    void update(@NonNull Long id, @NonNull Using using, @NonNull String scope, String mobile,
+                Long registering, @NonNull Set<String> authorities) {
+        this.id = id;
+        this.using = using;
+        this.scope = scope;
+        this.mobile = mobile;
+        this.registering = registering;
+        this.authorities = Collections.unmodifiableSet(authorities);
+    }
 }

+ 2 - 28
framework-base/src/main/java/com/chelvc/framework/base/context/SessionContextHolder.java

@@ -3,7 +3,6 @@ package com.chelvc.framework.base.context;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.ArrayDeque;
-import java.util.Collections;
 import java.util.Deque;
 import java.util.Objects;
 import java.util.Set;
@@ -160,34 +159,9 @@ public class SessionContextHolder implements ServletRequestListener {
      */
     public static Session updateSession(@NonNull Long id, @NonNull Using using, @NonNull String scope, String mobile,
                                         Long registering, @NonNull Set<String> authorities) {
-        return updateSession(id, using, scope, mobile, getVersion(), getPlatform(), registering, authorities);
-    }
-
-    /**
-     * 更新会话信息
-     *
-     * @param id          主体标识
-     * @param using       使用类别
-     * @param scope       应用范围
-     * @param mobile      手机号码
-     * @param version     客户端版本
-     * @param platform    客户端平台
-     * @param registering 注册时间戳
-     * @param authorities 权限标识集合
-     * @return 会话信息
-     */
-    public static Session updateSession(@NonNull Long id, @NonNull Using using, @NonNull String scope, String mobile,
-                                        String version, Platform platform, Long registering,
-                                        @NonNull Set<String> authorities) {
-        Deque<Session> deque = SESSION_CONTEXT.get();
-        Session session = deque.poll();
+        Session session = getSession(false);
         if (session != null) {
-            session = Session.builder().id(id).rid(session.getRid()).using(using).host(session.getHost())
-                    .scope(scope).mobile(mobile).device(session.getDevice()).channel(session.getChannel())
-                    .version(version).platform(platform).terminal(session.getTerminal())
-                    .timestamp(session.getTimestamp()).registering(registering)
-                    .authorities(Collections.unmodifiableSet(authorities)).build();
-            deque.push(session);
+            session.update(id, using, scope, mobile, registering, authorities);
         }
         return session;
     }

+ 3 - 8
framework-base/src/main/java/com/chelvc/framework/base/util/HttpUtils.java

@@ -208,13 +208,8 @@ public final class HttpUtils {
             }
         }
 
-        // 通过默认方式获取客户端地址
-        if (StringUtils.isEmpty(host)) {
-            host = request.getRemoteAddr();
-        }
-
-        // 如果地址存在多个则获取第一个
-        if (host.indexOf(',') > 0) {
+        // 通过默认方式获取客户端地址,如果存在多个地址则优先获取公网地址,否则获取第一个地址
+        if (StringUtils.isEmpty(host) && (host = request.getRemoteAddr()) != null && host.indexOf(',') > 0) {
             String[] ips = host.split(",");
             for (String ip : ips) {
                 if (!HostUtils.isIntranet(ip)) {
@@ -223,7 +218,7 @@ public final class HttpUtils {
             }
             host = ips[0];
         }
-        return HostUtils.DEFAULT_LOCAL_ADDRESS_IPV6.equals(host) ? HostUtils.LOCAL_ADDRESS : host;
+        return HostUtils.isLocal(host) ? HostUtils.LOCAL_ADDRESS : host;
     }
 
     /**

+ 12 - 2
framework-common/src/main/java/com/chelvc/framework/common/util/HostUtils.java

@@ -80,7 +80,17 @@ public final class HostUtils {
     }
 
     /**
-     * 判断IP是否属于内网
+     * 判断是否是本地IP
+     *
+     * @param host IP地址
+     * @return true/false
+     */
+    public static boolean isLocal(String host) {
+        return DEFAULT_LOCAL_ADDRESS.equals(host) || DEFAULT_LOCAL_ADDRESS_IPV6.equals(host);
+    }
+
+    /**
+     * 判断是否是内网IP
      *
      * @param host IP地址
      * @return true/false
@@ -88,7 +98,7 @@ public final class HostUtils {
     public static boolean isIntranet(String host) {
         if (StringUtils.isEmpty(host)) {
             return false;
-        } else if (DEFAULT_LOCAL_ADDRESS.contentEquals(host) || DEFAULT_LOCAL_ADDRESS_IPV6.contentEquals(host)) {
+        } else if (isLocal(host)) {
             return true;
         }
         for (Ip ip : INTRANET_ADDRESSES) {

+ 1 - 1
framework-security/src/main/java/com/chelvc/framework/security/context/SecurityContextHolder.java

@@ -109,7 +109,7 @@ public final class SecurityContextHolder {
     }
 
     /**
-     * 是否是客户端登陆
+     * 是否是客户端模式
      *
      * @param jwt Jwt对象
      * @return true/false

+ 1 - 4
framework-security/src/main/java/com/chelvc/framework/security/interceptor/SecurityValidateInterceptor.java

@@ -137,11 +137,8 @@ public class SecurityValidateInterceptor implements HandlerInterceptor, WebMvcCo
         if (security == null || security.header()) {
             // 请求头校验
             Long timestamp = session.getTimestamp();
-            String version = SessionContextHolder.getHeader(SessionContextHolder.HEADER_VERSION);
-            String platform = SessionContextHolder.getHeader(SessionContextHolder.HEADER_PLATFORM);
             if (timestamp == null || session.getPlatform() == null || session.getTerminal() == null
-                    || StringUtils.isEmpty(session.getVersion()) || !Objects.equals(session.getVersion(), version)
-                    || !Objects.equals(session.getPlatform().name(), platform)) {
+                    || StringUtils.isEmpty(session.getVersion())) {
                 String message = ApplicationContextHolder.getMessage("Header.Invalid");
                 if (this.isObservable()) {
                     LoggingContextHolder.warn(log, request, message);

+ 30 - 22
framework-security/src/main/java/com/chelvc/framework/security/session/DefaultSessionValidator.java

@@ -6,8 +6,11 @@ import com.chelvc.framework.base.context.ApplicationContextHolder;
 import com.chelvc.framework.base.context.SessionContextHolder;
 import com.chelvc.framework.base.context.Using;
 import com.chelvc.framework.common.model.Platform;
+import com.chelvc.framework.common.model.Version;
+import com.chelvc.framework.common.util.StringUtils;
 import com.chelvc.framework.security.context.SecurityContextHolder;
 import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.HttpStatus;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
@@ -49,41 +52,46 @@ public class DefaultSessionValidator implements SessionValidator {
     }
 
     /**
-     * 获取版本号
-     *
-     * @param jwt Jwt对象
-     * @return 版本号
-     */
-    protected String getVersion(Jwt jwt) {
-        return SecurityContextHolder.getVersion(jwt);
-    }
-
-    /**
-     * 获取平台信息
+     * 获取用户授权信息
      *
-     * @param jwt Jwt对象
-     * @return 平台信息
+     * @param jwt JWT对象
+     * @return 授权信息集合
      */
-    protected Platform getPlatform(Jwt jwt) {
-        return SecurityContextHolder.getPlatform(jwt);
+    protected Set<String> getAuthorities(Jwt jwt) {
+        return SecurityContextHolder.getAuthorities(jwt);
     }
 
     /**
-     * 获取用户授权信息
+     * 初始化会话主体信息
      *
      * @param jwt JWT对象
-     * @return 授权信息集合
      */
-    protected Set<String> getAuthorities(Jwt jwt) {
-        return SecurityContextHolder.getAuthorities(jwt);
+    protected void initializeSessionPrincipal(Jwt jwt) {
+        Long id = this.getId(jwt);
+        String scope = this.getScope(jwt);
+        Set<String> authorities = this.getAuthorities(jwt);
+        SessionContextHolder.updateSession(id, Using.NORMAL, scope, null, null, authorities);
     }
 
     @Override
     public OAuth2TokenValidatorResult validate(Jwt jwt) {
-        if (!SecurityContextHolder.isClient(jwt)) {
-            SessionContextHolder.updateSession(this.getId(jwt), Using.NORMAL, this.getScope(jwt), null,
-                    this.getVersion(jwt), this.getPlatform(jwt), null, this.getAuthorities(jwt));
+        // 客户端模式
+        if (SecurityContextHolder.isClient(jwt)) {
+            return OAuth2TokenValidatorResult.success();
         }
+
+        // 校验版本号及平台信息
+        String version = SecurityContextHolder.getVersion(jwt);
+        Platform platform = SecurityContextHolder.getPlatform(jwt);
+        if ((StringUtils.notEmpty(version) && !Version.isAfter(SessionContextHolder.getVersion(), version, true))
+                || (platform != null && platform != SessionContextHolder.getPlatform())) {
+            throw new OAuth2AuthenticationException(new OAuth2Error(
+                    HttpStatus.UNAUTHORIZED.name(), ApplicationContextHolder.getMessage("Unauthorized"), null
+            ));
+        }
+
+        // 初始化主体信息
+        this.initializeSessionPrincipal(jwt);
         return OAuth2TokenValidatorResult.success();
     }
 }

+ 2 - 11
framework-security/src/main/java/com/chelvc/framework/security/session/RedisSessionValidator.java

@@ -7,7 +7,6 @@ import java.util.Set;
 import com.chelvc.framework.base.context.ApplicationContextHolder;
 import com.chelvc.framework.base.context.SessionContextHolder;
 import com.chelvc.framework.base.context.Using;
-import com.chelvc.framework.common.model.Platform;
 import com.chelvc.framework.common.util.StringUtils;
 import com.chelvc.framework.redis.context.RedisContextHolder;
 import com.chelvc.framework.redis.context.RedisHashHolder;
@@ -19,7 +18,6 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
 import org.springframework.data.redis.core.RedisTemplate;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.stereotype.Component;
 
@@ -41,11 +39,7 @@ public class RedisSessionValidator extends DefaultSessionValidator {
     }
 
     @Override
-    public OAuth2TokenValidatorResult validate(Jwt jwt) {
-        if (SecurityContextHolder.isClient(jwt)) {
-            return OAuth2TokenValidatorResult.success();
-        }
-
+    protected void initializeSessionPrincipal(Jwt jwt) {
         // 校验令牌有效性
         Long id = this.getId(jwt);
         RedisTemplate<String, Object> template = RedisContextHolder.getDefaultTemplate();
@@ -72,14 +66,11 @@ public class RedisSessionValidator extends DefaultSessionValidator {
         }
 
         // 更新会话信息
-        String version = this.getVersion(jwt);
-        Platform platform = this.getPlatform(jwt);
         String mobile = (String) context.get(SecurityContextHolder.MOBILE);
         Long creating = (Long) context.get(SecurityContextHolder.CREATING);
         Long registering = (Long) context.get(SecurityContextHolder.REGISTERING);
         Using using = Using.from(RedisUserDailyHashHolder.using(template, id), creating, this.usingRefreshInterval);
         Set<String> authorities = SecurityContextHolder.getAuthorities(jwt);
-        SessionContextHolder.updateSession(id, using, scope, mobile, version, platform, registering, authorities);
-        return OAuth2TokenValidatorResult.success();
+        SessionContextHolder.updateSession(id, using, scope, mobile, registering, authorities);
     }
 }