Sfoglia il codice sorgente

Merge branch 'develop' of http://132.232.92.186:3000/XinXiBu/OA2023 into develop

yuanrf 6 ore fa
parent
commit
45adcab47a

+ 19 - 0
OASystem/OASystem.Api/Controllers/AuthController.cs

@@ -1,5 +1,7 @@
 
 using Microsoft.AspNetCore.SignalR;
+using Microsoft.EntityFrameworkCore.Metadata.Internal;
+using NPOI.SS.Formula.Functions;
 using OASystem.API.OAMethodLib;
 using OASystem.API.OAMethodLib.Hub.HubClients;
 using OASystem.API.OAMethodLib.Hub.Hubs;
@@ -472,6 +474,23 @@ namespace OASystem.API.Controllers
             return Ok(JsonView(view));
         }
 
+        /// <summary>
+        /// 限流测试
+        /// </summary>
+        /// <returns></returns>
+        [HttpGet("rate-test")]
+        [ProducesResponseType(typeof(JsonView), StatusCodes.Status200OK)]
+        public IActionResult RateTest()
+        {
+            return Ok(JsonView(true, $"限流测试 - IP:{HttpContext.Connection.RemoteIpAddress?.ToString()}", null, GetRequestCount()));
+        }
+
+        private static int _requestCount = 0;
+        private int GetRequestCount()
+        {
+            return ++_requestCount;
+        }
+
         ///// <summary>
         ///// 员工信息 迁移
         ///// Old OA To New OA

+ 337 - 95
OASystem/OASystem.Api/Middlewares/RateLimitMiddleware.cs

@@ -1,11 +1,8 @@
-// Middlewares/RateLimitMiddleware.cs
-using AspNetCoreRateLimit;
-using Microsoft.AspNetCore.Http;
+using IdentityModel.OidcClient;
 using Microsoft.Extensions.Caching.Memory;
 using Microsoft.Extensions.Options;
-using System;
-using System.Linq;
-using System.Threading.Tasks;
+using System.Text.Json;
+using System.Text.Json.Serialization;
 
 namespace OASystem.API.Middlewares
 {
@@ -15,6 +12,7 @@ namespace OASystem.API.Middlewares
         private readonly IMemoryCache _cache;
         private readonly RateLimitConfig _config;
         private readonly ILogger<RateLimitMiddleware> _logger;
+        private readonly JsonSerializerOptions _jsonOptions;
 
         public RateLimitMiddleware(
             RequestDelegate next,
@@ -24,64 +22,108 @@ namespace OASystem.API.Middlewares
         {
             _next = next;
             _cache = cache;
-            _config = config.Value;
+            _config = config?.Value ?? new RateLimitConfig(); // 获取配置值
             _logger = logger;
+
+            // 配置JSON序列化选项
+            _jsonOptions = new JsonSerializerOptions
+            {
+                Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping, // 允许中文
+                PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
+                WriteIndented = false
+            };
+
+            // 日志输出配置信息,用于调试
+            LogConfiguration();
+        }
+
+        private void LogConfiguration()
+        {
+            _logger.LogInformation("=== 限流中间件初始化 ===");
+            _logger.LogInformation($"启用状态: {_config.Enabled}");
+            _logger.LogInformation($"默认限制: {_config.DefaultLimit}次/{_config.DefaultPeriod}秒");
+
+            if (_config.Endpoints != null && _config.Endpoints.Any())
+            {
+                _logger.LogInformation($"配置端点数量: {_config.Endpoints.Count}");
+                foreach (var endpoint in _config.Endpoints.Take(5)) // 只显示前5个
+                {
+                    _logger.LogInformation($"  {endpoint.Method} {endpoint.Path} -> {endpoint.Limit}次/{endpoint.Period}秒");
+                }
+                if (_config.Endpoints.Count > 5)
+                {
+                    _logger.LogInformation($"  ... 还有{_config.Endpoints.Count - 5}个配置");
+                }
+            }
+            else
+            {
+                _logger.LogWarning("未配置具体端点,将使用默认规则");
+            }
+            _logger.LogInformation("========================");
         }
 
         public async Task InvokeAsync(HttpContext context)
         {
+            // 记录请求开始信息(调试用)
+            var requestInfo = new RequestInfo
+            {
+                Path = context.Request.Path,
+                Method = context.Request.Method,
+                ClientIp = GetClientIp(context),
+                Timestamp = DateTime.UtcNow
+            };
+
+            _logger.LogDebug($"收到请求: {requestInfo.Method} {requestInfo.Path} from {requestInfo.ClientIp}");
+
             if (!_config.Enabled)
             {
+                _logger.LogDebug("限流功能已禁用,跳过检查");
                 await _next(context);
                 return;
             }
 
-            var path = context.Request.Path.ToString().ToLower();
+            var path = context.Request.Path.ToString();
             var method = context.Request.Method.ToUpper();
 
             // 跳过不需要限流的路径
             if (ShouldSkipRateLimit(path))
             {
+                _logger.LogDebug($"跳过限流检查: {path}");
                 await _next(context);
                 return;
             }
 
             // 查找匹配的配置
-            EndpointRateLimit endpointConfig = null;
-
-            foreach (var configItem in _config.Endpoints)
-            {
-                if (IsPathMatch(path, configItem.Path) &&
-                    IsMethodMatch(method, configItem.Method))
-                {
-                    endpointConfig = configItem;
-                    break;
-                }
-            }
+            var endpointConfig = FindMatchingConfig(path, method);
 
             if (endpointConfig == null)
             {
                 // 使用默认配置
-                if (!CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP))
+                _logger.LogDebug($"使用默认限流配置: {method} {path}");
+
+                if (!CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP, requestInfo))
                 {
-                    await ReturnRateLimitedResponse(context);
+                    await ReturnRateLimitedResponse(context, requestInfo, "默认规则");
                     return;
                 }
             }
             else
             {
                 var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
-                if (!CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy))
+                _logger.LogDebug($"匹配限流配置: {endpointKey} -> {endpointConfig.Limit}次/{endpointConfig.Period}秒");
+
+                if (!CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy, requestInfo))
                 {
-                    await ReturnRateLimitedResponse(context, endpointConfig);
+                    await ReturnRateLimitedResponse(context, requestInfo, $"{method} {path}");
                     return;
                 }
             }
 
+            _logger.LogDebug($"请求通过限流检查: {method} {path}");
             await _next(context);
         }
 
-        private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period, RateLimitPolicy policy)
+        private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period, RateLimitPolicy policy, RequestInfo requestInfo)
         {
             var identifier = GetIdentifier(context, policy);
             var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
@@ -90,6 +132,7 @@ namespace OASystem.API.Middlewares
 
             if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter))
             {
+                // 首次请求
                 counter = new RateLimitCounter
                 {
                     Count = 1,
@@ -97,33 +140,49 @@ namespace OASystem.API.Middlewares
                     LastRequestTime = now
                 };
 
-                _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period));
+                _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period + 1));
+
+                _logger.LogDebug($"新建限流计数器: {cacheKey}, 计数: 1/{limit}");
                 return true;
             }
 
             // 如果第一个请求已经超过时间窗口,重置计数器
             if (counter.FirstRequestTime < windowStart)
             {
+                var oldCount = counter.Count;
                 counter.Count = 1;
                 counter.FirstRequestTime = now;
+
+                _logger.LogDebug($"重置限流计数器: {cacheKey}, 旧计数: {oldCount}, 新计数: 1/{limit}");
             }
             else if (counter.Count >= limit)
             {
+                // 触发限流!
                 var remainingSeconds = (int)(counter.FirstRequestTime.AddSeconds(period) - now).TotalSeconds;
-                _logger.LogWarning($"限流触发 - Endpoint: {endpointKey}, " +
-                                 $"Identifier: {identifier}, " +
-                                 $"Count: {counter.Count}, " +
-                                 $"Limit: {limit}, " +
-                                 $"Remaining: {remainingSeconds}s");
+
+                // 详细记录限流信息
+                _logger.LogWarning($"🚫 限流触发: {requestInfo.Method} {requestInfo.Path}");
+                _logger.LogWarning($"   客户端IP: {requestInfo.ClientIp}");
+                _logger.LogWarning($"   标识符: {identifier}");
+                _logger.LogWarning($"   规则: {endpointKey}");
+                _logger.LogWarning($"   当前计数: {counter.Count}/{limit}");
+                _logger.LogWarning($"   窗口开始: {counter.FirstRequestTime:HH:mm:ss}");
+                _logger.LogWarning($"   剩余时间: {remainingSeconds}秒");
+                _logger.LogWarning($"   用户代理: {context.Request.Headers["User-Agent"]}");
+
                 return false;
             }
             else
             {
+                // 正常计数增加
+                var oldCount = counter.Count;
                 counter.Count++;
+
+                _logger.LogDebug($"增加限流计数: {cacheKey}, {oldCount} -> {counter.Count}/{limit}");
             }
 
             counter.LastRequestTime = now;
-            _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period));
+            _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period + 1));
 
             return true;
         }
@@ -163,26 +222,75 @@ namespace OASystem.API.Middlewares
             }
         }
 
+        private async Task ReturnRateLimitedResponse(HttpContext context, RequestInfo requestInfo, string ruleInfo)
+        {
+            context.Response.StatusCode = 429;
+            context.Response.ContentType = "application/json; charset=utf-8";
+
+            // 添加响应头信息
+            context.Response.Headers.Add("X-RateLimit-Limit", "5");
+            context.Response.Headers.Add("X-RateLimit-Remaining", "0");
+            context.Response.Headers.Add("X-RateLimit-Reset", DateTime.UtcNow.AddSeconds(1).ToString("R"));
+            context.Response.Headers.Add("X-RateLimit-Rule", ruleInfo);
+
+            var response = new JsonView()
+            {
+                Code = 429,
+                Msg = $"请求过于频繁({ruleInfo}),请稍后再试",
+                Count = 0,
+                Data = new
+                {
+                    Path = requestInfo.Path,
+                    Method = requestInfo.Method,
+                    ClientIp = requestInfo.ClientIp,
+                    Rule = ruleInfo
+                }
+            };
+
+            var json = System.Text.Json.JsonSerializer.Serialize(response, _jsonOptions);
+            await context.Response.WriteAsync(json);
+
+            // 额外记录一次限流响应日志
+            _logger.LogInformation($"📤 限流响应已发送: {requestInfo.Method} {requestInfo.Path} -> 429");
+        }
+
+        private EndpointRateLimit FindMatchingConfig(string path, string method)
+        {
+            if (_config.Endpoints == null || !_config.Endpoints.Any())
+                return null;
+
+            foreach (var configItem in _config.Endpoints)
+            {
+                if (IsPathMatch(path, configItem.Path) &&
+                    IsMethodMatch(method, configItem.Method))
+                {
+                    return configItem;
+                }
+            }
+
+            return null;
+        }
         private bool IsPathMatch(string requestPath, string configPath)
         {
             if (configPath == "*") return true;
 
-            configPath = configPath.ToLower();
+            var normalizedPath = requestPath.ToLower();
+            var normalizedConfigPath = configPath.ToLower();
 
             // 精确匹配
-            if (requestPath.Equals(configPath, StringComparison.OrdinalIgnoreCase))
+            if (normalizedPath.Equals(normalizedConfigPath))
                 return true;
 
             // 前缀匹配(以/结尾表示前缀匹配)
-            if (configPath.EndsWith("/") && requestPath.StartsWith(configPath))
+            if (normalizedConfigPath.EndsWith("/") && normalizedPath.StartsWith(normalizedConfigPath))
                 return true;
 
             // 通配符匹配(简单的*通配符)
-            if (configPath.Contains("*"))
+            if (normalizedConfigPath.Contains("*"))
             {
-                var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath)
+                var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(normalizedConfigPath)
                     .Replace("\\*", ".*") + "$";
-                return System.Text.RegularExpressions.Regex.IsMatch(requestPath, pattern);
+                return System.Text.RegularExpressions.Regex.IsMatch(normalizedPath, pattern);
             }
 
             return false;
@@ -201,7 +309,7 @@ namespace OASystem.API.Middlewares
             var skipPaths = new[]
             {
                 "/health",
-                //"/swagger",
+                "/swagger",
                 "/favicon.ico",
                 "/robots.txt",
                 "/.well-known"
@@ -210,79 +318,213 @@ namespace OASystem.API.Middlewares
             return skipPaths.Any(p => path.StartsWith(p, StringComparison.OrdinalIgnoreCase));
         }
 
-        private async Task ReturnRateLimitedResponse(HttpContext context, EndpointRateLimit config = null)
+        private string GetClientIp(HttpContext context)
         {
-            context.Response.StatusCode = 429; // Too Many Requests
-            context.Response.ContentType = "application/json";
+            // 方法1:获取真实IP地址(处理反向代理)
+            string ipAddress = string.Empty;
 
-            var message = config != null
-                ? $"接口访问过于频繁,请{config.Period}秒后再试"
-                : "请求过于频繁,请稍后再试";
+            // 优先从 X-Forwarded-For 获取(如果有反向代理如Nginx)
+            var forwardedFor = context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
+            if (!string.IsNullOrEmpty(forwardedFor))
+            {
+                // X-Forwarded-For 可能是多个IP(client, proxy1, proxy2)
+                ipAddress = forwardedFor.Split(',')[0].Trim();
+
+                _logger.LogDebug($"从 X-Forwarded-For 获取IP: {ipAddress}");
 
-            var response = new
+                // 验证IP格式
+                if (IsValidIpAddress(ipAddress))
+                {
+                    return ipAddress;
+                }
+            }
+
+            // 从 X-Real-IP 获取
+            var realIp = context.Request.Headers["X-Real-IP"].FirstOrDefault();
+            if (!string.IsNullOrEmpty(realIp))
             {
-                Code = 429,
-                Msg = message
-            };
+                ipAddress = realIp.Trim();
+                _logger.LogDebug($"从 X-Real-IP 获取IP: {ipAddress}");
+
+                if (IsValidIpAddress(ipAddress))
+                {
+                    return ipAddress;
+                }
+            }
 
-            await context.Response.WriteAsJsonAsync(response);
+            // 从 CF-Connecting-IP 获取(Cloudflare)
+            var cfConnectingIp = context.Request.Headers["CF-Connecting-IP"].FirstOrDefault();
+            if (!string.IsNullOrEmpty(cfConnectingIp))
+            {
+                ipAddress = cfConnectingIp.Trim();
+                _logger.LogDebug($"从 CF-Connecting-IP 获取IP: {ipAddress}");
+
+                if (IsValidIpAddress(ipAddress))
+                {
+                    return ipAddress;
+                }
+            }
+
+            // 从 X-Original-For 获取
+            var originalFor = context.Request.Headers["X-Original-For"].FirstOrDefault();
+            if (!string.IsNullOrEmpty(originalFor))
+            {
+                ipAddress = originalFor.Trim();
+                _logger.LogDebug($"从 X-Original-For 获取IP: {ipAddress}");
+
+                if (IsValidIpAddress(ipAddress))
+                {
+                    return ipAddress;
+                }
+            }
+
+            // 最后使用 RemoteIpAddress
+            ipAddress = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
+
+            // 处理IPv6映射的IPv4地址
+            if (ipAddress.Contains("::ffff:"))
+            {
+                ipAddress = ipAddress.Replace("::ffff:", "");
+            }
+
+            // 处理IPv6本地地址
+            if (ipAddress == "::1" || ipAddress == "0:0:0:0:0:0:0:1")
+            {
+                ipAddress = "127.0.0.1";
+            }
+
+            _logger.LogDebug($"从 RemoteIpAddress 获取IP: {ipAddress}");
+
+            return ipAddress;
         }
-    }
 
-    /// <summary>
-    /// 限流计数器
-    /// </summary>
-    public class RateLimitCounter
-    {
-        public int Count { get; set; }
-        public DateTime FirstRequestTime { get; set; }
-        public DateTime LastRequestTime { get; set; }
-    }
+        private bool IsValidIpAddress(string ip)
+        {
+            if (string.IsNullOrWhiteSpace(ip))
+                return false;
 
-    #region 配置模型
+            if (ip.Equals("unknown", StringComparison.OrdinalIgnoreCase))
+                return false;
 
-    /// <summary>
-    /// 限流策略枚举
-    /// </summary>
-    public enum RateLimitPolicy
-    {
-        /// <summary>
-        /// 按IP地址限流
-        /// </summary>
-        IP = 0,
+            // 简单验证IP格式
+            if (System.Net.IPAddress.TryParse(ip, out var ipAddress))
+            {
+                // 排除私有IP地址(如果需要)
+                return !IsPrivateIp(ipAddress);
+                //return true;
+            }
 
-        /// <summary>
-        /// 按用户限流(需要用户登录)
-        /// </summary>
-        User = 1,
+            return false;
+        }
 
-        /// <summary>
-        /// 全局限流(所有用户共享限制)
-        /// </summary>
-        Global = 2,
+        private bool IsPrivateIp(System.Net.IPAddress ipAddress)
+        {
+            // 检查是否为内网IP
+            if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork) // IPv4
+            {
+                var bytes = ipAddress.GetAddressBytes();
+
+                // 10.0.0.0/8
+                if (bytes[0] == 10)
+                    return true;
+
+                // 172.16.0.0/12
+                if (bytes[0] == 172 && bytes[1] >= 16 && bytes[1] <= 31)
+                    return true;
+
+                // 192.168.0.0/16
+                if (bytes[0] == 192 && bytes[1] == 168)
+                    return true;
+
+                // 127.0.0.0/8
+                if (bytes[0] == 127)
+                    return true;
+            }
+            else if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6) // IPv6
+            {
+                // IPv6 本地地址
+                if (ipAddress.IsIPv6LinkLocal || ipAddress.IsIPv6SiteLocal ||
+                    ipAddress.IsIPv6Multicast || ipAddress.Equals(System.Net.IPAddress.IPv6Loopback))
+                    return true;
+            }
+
+            return false;
+        }
+
+        // 添加 RequestInfo 辅助类
+        private class RequestInfo
+        {
+            public string Path { get; set; }
+            public string Method { get; set; }
+            public string ClientIp { get; set; }
+            public DateTime Timestamp { get; set; }
+        }
+
+        public class RateLimitCounter
+        {
+            public int Count { get; set; }
+            public DateTime FirstRequestTime { get; set; }
+            public DateTime LastRequestTime { get; set; }
+        }
+        
+        #region 配置模型
 
         /// <summary>
-        /// 按客户端ID限流
+        /// 限流策略枚举
         /// </summary>
-        Client = 3
-    }
+        public enum RateLimitPolicy
+        {
+            /// <summary>
+            /// 按IP地址限流
+            /// </summary>
+            IP = 0,
+
+            /// <summary>
+            /// 按用户限流(需要用户登录)
+            /// </summary>
+            User = 1,
+
+            /// <summary>
+            /// 全局限流(所有用户共享限制)
+            /// </summary>
+            Global = 2,
+
+            /// <summary>
+            /// 按客户端ID限流
+            /// </summary>
+            Client = 3
+        }
 
-    public class RateLimitConfig
-    {
-        public bool Enabled { get; set; } = true;
-        public int DefaultLimit { get; set; } = 10;
-        public int DefaultPeriod { get; set; } = 1;
+        public class RateLimitConfig
+        {
+            public bool Enabled { get; set; } = true;
+            public int DefaultLimit { get; set; } = 5;      // 默认5次/秒
+            public int DefaultPeriod { get; set; } = 1;     // 默认时间窗口1秒
 
-        public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
-    }
+            [JsonPropertyName("Endpoints")]
+            public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
+        }
 
-    public class EndpointRateLimit
-    {
-        public string Path { get; set; }
-        public string Method { get; set; } = "*";
-        public int Limit { get; set; }
-        public int Period { get; set; } // 秒
-        public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
+        public class EndpointRateLimit
+        {
+            [JsonPropertyName("Path")]
+            public string Path { get; set; } = "*";
+
+            [JsonPropertyName("Method")]
+            public string Method { get; set; } = "*";
+
+            [JsonPropertyName("Limit")]
+            public int Limit { get; set; } = 5;
+
+            [JsonPropertyName("Period")]
+            public int Period { get; set; } = 1;
+
+            [JsonPropertyName("Policy")]
+            public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
+
+            [System.Text.Json.Serialization.JsonIgnore]
+            public string EndpointKey => $"{Method.ToUpper()}:{Path.ToLower()}";
+        }
+        #endregion
     }
-    #endregion
 }

+ 11 - 0
OASystem/OASystem.Api/Program.cs

@@ -21,6 +21,7 @@ using Quartz.Spi;
 using QuzrtzJob.Factory;
 using Serilog.Events;
 using System.IO.Compression;
+using static OASystem.API.Middlewares.RateLimitMiddleware;
 
 Console.Title = $"FMGJ OASystem Server";
 var builder = WebApplication.CreateBuilder(args);
@@ -81,6 +82,16 @@ builder.Services.AddControllers()
 
 builder.Services.TryAddSingleton<IHttpContextAccessor, HttpContextAccessor>();
 
+#region 添加限流中间件服务注册
+
+// 添加内存缓存(限流需要)
+builder.Services.AddMemoryCache();
+
+// 配置限流设置
+builder.Services.Configure<RateLimitConfig>(
+    builder.Configuration.GetSection("RateLimiting"));
+#endregion
+
 #region Gzip
 
 builder.Services.AddResponseCompression(options =>

+ 4 - 4
OASystem/OASystem.Api/appsettings.json

@@ -519,7 +519,7 @@
   //限流配置
   "RateLimiting": {
     "Enabled": true,
-    "DefaultLimit": 10,
+    "DefaultLimit": 5,
     "DefaultPeriod": 1,
     "Endpoints": [
       // 所有api接口统一限制:5次/秒
@@ -528,8 +528,9 @@
         "Method": "*",
         "Limit": 5,
         "Period": 1,
-        "Policy": "Global"
-      },
+        "Policy": "IP" //// 0=IP, 1=User, 2=Global, 3=Client
+      }
+      //,
       // 特定接口单独设置更严格的限制
       //{
       //  "Path": "/api/auth/login",
@@ -568,6 +569,5 @@
       //}
     ]
   },
-
   "AllowedHosts": "*"
 }

+ 3 - 3
OASystem/OASystem.Domain/Dtos/Task/GetTaskDefaultDto.cs

@@ -178,9 +178,9 @@ namespace OASystem.Domain.Dtos.Task
 
         public int TaskId { get; set; }
 
-        /// <summary>
-        /// 审核状态 0未审核 1通过 -1未通过
-        /// </summary>
+        ///// <summary>
+        ///// 审核状态 0未审核 1通过 -1未通过
+        ///// </summary>
         //public int IsApproved { get; set; }
     }