Explorar el Código

api 限流 50%

Lyyyi hace 3 días
padre
commit
d2fdf9129c

+ 288 - 0
OASystem/OASystem.Api/Middlewares/RateLimitMiddleware.cs

@@ -0,0 +1,288 @@
+// Middlewares/RateLimitMiddleware.cs
+using AspNetCoreRateLimit;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.Caching.Memory;
+using Microsoft.Extensions.Options;
+using System;
+using System.Linq;
+using System.Threading.Tasks;
+
+namespace OASystem.API.Middlewares
+{
+    public class RateLimitMiddleware
+    {
+        private readonly RequestDelegate _next;
+        private readonly IMemoryCache _cache;
+        private readonly RateLimitConfig _config;
+        private readonly ILogger<RateLimitMiddleware> _logger;
+
+        public RateLimitMiddleware(
+            RequestDelegate next,
+            IMemoryCache cache,
+            IOptions<RateLimitConfig> config,
+            ILogger<RateLimitMiddleware> logger)
+        {
+            _next = next;
+            _cache = cache;
+            _config = config.Value;
+            _logger = logger;
+        }
+
+        public async Task InvokeAsync(HttpContext context)
+        {
+            if (!_config.Enabled)
+            {
+                await _next(context);
+                return;
+            }
+
+            var path = context.Request.Path.ToString().ToLower();
+            var method = context.Request.Method.ToUpper();
+
+            // 跳过不需要限流的路径
+            if (ShouldSkipRateLimit(path))
+            {
+                await _next(context);
+                return;
+            }
+
+            // 查找匹配的配置
+            EndpointRateLimit endpointConfig = null;
+
+            foreach (var configItem in _config.Endpoints)
+            {
+                if (IsPathMatch(path, configItem.Path) &&
+                    IsMethodMatch(method, configItem.Method))
+                {
+                    endpointConfig = configItem;
+                    break;
+                }
+            }
+
+            if (endpointConfig == null)
+            {
+                // 使用默认配置
+                if (!CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP))
+                {
+                    await ReturnRateLimitedResponse(context);
+                    return;
+                }
+            }
+            else
+            {
+                var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
+                if (!CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy))
+                {
+                    await ReturnRateLimitedResponse(context, endpointConfig);
+                    return;
+                }
+            }
+
+            await _next(context);
+        }
+
+        private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period, RateLimitPolicy policy)
+        {
+            var identifier = GetIdentifier(context, policy);
+            var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
+            var now = DateTime.UtcNow;
+            var windowStart = now.AddSeconds(-period);
+
+            if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter))
+            {
+                counter = new RateLimitCounter
+                {
+                    Count = 1,
+                    FirstRequestTime = now,
+                    LastRequestTime = now
+                };
+
+                _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period));
+                return true;
+            }
+
+            // 如果第一个请求已经超过时间窗口,重置计数器
+            if (counter.FirstRequestTime < windowStart)
+            {
+                counter.Count = 1;
+                counter.FirstRequestTime = now;
+            }
+            else if (counter.Count >= limit)
+            {
+                var remainingSeconds = (int)(counter.FirstRequestTime.AddSeconds(period) - now).TotalSeconds;
+                _logger.LogWarning($"限流触发 - Endpoint: {endpointKey}, " +
+                                 $"Identifier: {identifier}, " +
+                                 $"Count: {counter.Count}, " +
+                                 $"Limit: {limit}, " +
+                                 $"Remaining: {remainingSeconds}s");
+                return false;
+            }
+            else
+            {
+                counter.Count++;
+            }
+
+            counter.LastRequestTime = now;
+            _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period));
+
+            return true;
+        }
+
+        private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
+        {
+            switch (policy)
+            {
+                case RateLimitPolicy.User:
+                    if (context.User?.Identity?.IsAuthenticated == true)
+                    {
+                        var userId = context.User.FindFirst("sub")?.Value
+                                  ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
+                                  ?? context.User.Identity.Name;
+                        return $"user_{userId}";
+                    }
+                    // 如果未登录,回退到IP
+                    return $"ip_{context.Connection.RemoteIpAddress}";
+
+                case RateLimitPolicy.Global:
+                    return "global";
+
+                case RateLimitPolicy.Client:
+                    var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault()
+                                ?? context.Request.Headers["X-API-Key"].FirstOrDefault();
+                    return $"client_{clientId ?? "unknown"}";
+
+                case RateLimitPolicy.IP:
+                default:
+                    var ipAddress = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
+                    // 处理IPv6映射的IPv4地址
+                    if (ipAddress.Contains("::ffff:"))
+                    {
+                        ipAddress = ipAddress.Replace("::ffff:", "");
+                    }
+                    return $"ip_{ipAddress}";
+            }
+        }
+
+        private bool IsPathMatch(string requestPath, string configPath)
+        {
+            if (configPath == "*") return true;
+
+            configPath = configPath.ToLower();
+
+            // 精确匹配
+            if (requestPath.Equals(configPath, StringComparison.OrdinalIgnoreCase))
+                return true;
+
+            // 前缀匹配(以/结尾表示前缀匹配)
+            if (configPath.EndsWith("/") && requestPath.StartsWith(configPath))
+                return true;
+
+            // 通配符匹配(简单的*通配符)
+            if (configPath.Contains("*"))
+            {
+                var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath)
+                    .Replace("\\*", ".*") + "$";
+                return System.Text.RegularExpressions.Regex.IsMatch(requestPath, pattern);
+            }
+
+            return false;
+        }
+
+        private bool IsMethodMatch(string requestMethod, string configMethod)
+        {
+            if (configMethod == "*") return true;
+
+            return requestMethod.Equals(configMethod, StringComparison.OrdinalIgnoreCase);
+        }
+
+        private bool ShouldSkipRateLimit(string path)
+        {
+            // 跳过健康检查、swagger等
+            var skipPaths = new[]
+            {
+                "/health",
+                //"/swagger",
+                "/favicon.ico",
+                "/robots.txt",
+                "/.well-known"
+            };
+
+            return skipPaths.Any(p => path.StartsWith(p, StringComparison.OrdinalIgnoreCase));
+        }
+
+        private async Task ReturnRateLimitedResponse(HttpContext context, EndpointRateLimit config = null)
+        {
+            context.Response.StatusCode = 429; // Too Many Requests
+            context.Response.ContentType = "application/json";
+
+            var message = config != null
+                ? $"接口访问过于频繁,请{config.Period}秒后再试"
+                : "请求过于频繁,请稍后再试";
+
+            var response = new
+            {
+                Code = 429,
+                Msg = message
+            };
+
+            await context.Response.WriteAsJsonAsync(response);
+        }
+    }
+
+    /// <summary>
+    /// 限流计数器
+    /// </summary>
+    public class RateLimitCounter
+    {
+        public int Count { get; set; }
+        public DateTime FirstRequestTime { get; set; }
+        public DateTime LastRequestTime { get; set; }
+    }
+
+    #region 配置模型
+
+    /// <summary>
+    /// 限流策略枚举
+    /// </summary>
+    public enum RateLimitPolicy
+    {
+        /// <summary>
+        /// 按IP地址限流
+        /// </summary>
+        IP = 0,
+
+        /// <summary>
+        /// 按用户限流(需要用户登录)
+        /// </summary>
+        User = 1,
+
+        /// <summary>
+        /// 全局限流(所有用户共享限制)
+        /// </summary>
+        Global = 2,
+
+        /// <summary>
+        /// 按客户端ID限流
+        /// </summary>
+        Client = 3
+    }
+
+    public class RateLimitConfig
+    {
+        public bool Enabled { get; set; } = true;
+        public int DefaultLimit { get; set; } = 10;
+        public int DefaultPeriod { get; set; } = 1;
+
+        public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
+    }
+
+    public class EndpointRateLimit
+    {
+        public string Path { get; set; }
+        public string Method { get; set; } = "*";
+        public int Limit { get; set; }
+        public int Period { get; set; } // 秒
+        public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
+    }
+    #endregion
+}

+ 1 - 0
OASystem/OASystem.Api/OASystem.API.csproj

@@ -38,6 +38,7 @@
     <PackageReference Include="AspNetCore.HealthChecks.UI" Version="6.0.1" />
     <PackageReference Include="AspNetCore.HealthChecks.UI.Client" Version="6.0.2" />
     <PackageReference Include="AspNetCore.HealthChecks.UI.InMemory.Storage" Version="6.0.2" />
+    <PackageReference Include="AspNetCoreRateLimit" Version="5.0.0" />
     <PackageReference Include="Aspose.Cells" Version="23.4.0" />
     <PackageReference Include="Autofac" Version="6.4.0" />
     <PackageReference Include="Autofac.Extensions.DependencyInjection" Version="8.0.0" />

+ 3 - 0
OASystem/OASystem.Api/Program.cs

@@ -637,6 +637,9 @@ app.UseCors("Cors");  //Cors
 app.UseMiddleware<RecordAPIOperationMiddleware>();
 
 app.UseAuthentication(); // ÈÏÖ¤
+
+app.UseMiddleware<RateLimitMiddleware>();
+
 app.UseAuthorization();  // ÊÚȨ
 
 app.UseWhen(context =>

+ 56 - 2
OASystem/OASystem.Api/appsettings.json

@@ -512,8 +512,62 @@
     "MaxFileSizeMB": 50,
     "ExcelTemplatesPath": "Templates/Excel"
   },
-  "AllowedHosts": "*",
+  //"AllowedHosts": "*",
   "DecreasePaymentsSelectIsGm": [
     149
-  ]
+  ],
+  //限流配置
+  "RateLimiting": {
+    "Enabled": true,
+    "DefaultLimit": 10,
+    "DefaultPeriod": 1,
+    "Endpoints": [
+      // 所有api接口统一限制:5次/秒
+      {
+        "Path": "/api/*",
+        "Method": "*",
+        "Limit": 5,
+        "Period": 1,
+        "Policy": "Global"
+      },
+      // 特定接口单独设置更严格的限制
+      //{
+      //  "Path": "/api/auth/login",
+      //  "Method": "POST",
+      //  "Limit": 5,
+      //  "Period": 60,
+      //  "Policy": 0 // 0=IP, 1=User, 2=Global, 3=Client
+      //},
+      //{
+      //  "Path": "/api/auth/register",
+      //  "Method": "POST",
+      //  "Limit": 3,
+      //  "Period": 3600,
+      //  "Policy": 0
+      //},
+      //{
+      //  "Path": "/api/ocr/*",
+      //  "Method": "POST",
+      //  "Limit": 2,
+      //  "Period": 1,
+      //  "Policy": 0
+      //},
+      //{
+      //  "Path": "/api/financial/*",
+      //  "Method": "*",
+      //  "Limit": 30,
+      //  "Period": 60,
+      //  "Policy": 1 // 按用户限流
+      //},
+      //{
+      //  "Path": "/api/task/create",
+      //  "Method": "POST",
+      //  "Limit": 1,
+      //  "Period": 5,
+      //  "Policy": 0
+      //}
+    ]
+  },
+
+  "AllowedHosts": "*"
 }