// Middlewares/RateLimitMiddleware.cs
// NOTE(review): nothing in this file appears to use the AspNetCoreRateLimit package; the
// rate-limit types used below are declared in OASystem.API.Middlewares, which takes precedence
// over any same-named types the package may export. Kept to avoid breaking the build elsewhere.
using AspNetCoreRateLimit;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;

namespace OASystem.API.Middlewares
{
    /// <summary>
    /// Fixed-window request rate limiter backed by <see cref="IMemoryCache"/>.
    /// Endpoint rules from <see cref="RateLimitConfig.Endpoints"/> are tried in order and the first
    /// match applies; when none matches, the default per-IP limit is used. Requests over the limit
    /// are answered with HTTP 429, a JSON body and a Retry-After header.
    /// </summary>
    public class RateLimitMiddleware
    {
        // Path prefixes that are never rate limited (health checks, crawler/well-known files).
        private static readonly string[] SkipPaths =
        {
            "/health",
            //"/swagger",
            "/favicon.ico",
            "/robots.txt",
            "/.well-known"
        };

        // Serializes the read-modify-write of cached counters: a RateLimitCounter is a mutable
        // object shared by concurrent requests, and cache lookup + insert is not atomic by itself.
        private static readonly object CounterLock = new object();

        private readonly RequestDelegate _next;
        private readonly IMemoryCache _cache;
        private readonly RateLimitConfig _config;
        private readonly ILogger<RateLimitMiddleware> _logger;

        /// <summary>
        /// Creates the middleware.
        /// </summary>
        /// <param name="next">The next delegate in the request pipeline.</param>
        /// <param name="cache">Cache holding one <see cref="RateLimitCounter"/> per rule and caller.</param>
        /// <param name="config">Rate-limit settings; the value is read once here.</param>
        /// <param name="logger">Logger used to report rejected requests.</param>
        public RateLimitMiddleware(
            RequestDelegate next,
            IMemoryCache cache,
            IOptions<RateLimitConfig> config,
            ILogger<RateLimitMiddleware> logger)
        {
            _next = next;
            _cache = cache;
            _config = config.Value;
            _logger = logger;
        }

        /// <summary>
        /// Applies the matching rate-limit rule to the request. Short-circuits with HTTP 429 when the
        /// caller is over its limit; otherwise invokes the next middleware.
        /// </summary>
        public async Task InvokeAsync(HttpContext context)
        {
            if (!_config.Enabled)
            {
                await _next(context);
                return;
            }

            // Invariant casing: culture-sensitive ToLower/ToUpper mis-handle e.g. the Turkish 'I'.
            var path = context.Request.Path.ToString().ToLowerInvariant();
            var method = context.Request.Method.ToUpperInvariant();

            // Skip paths that do not need rate limiting.
            if (ShouldSkipRateLimit(path))
            {
                await _next(context);
                return;
            }

            // Find the matching rule; the first match in configuration order wins.
            var endpointConfig = (_config.Endpoints ?? Enumerable.Empty<EndpointRateLimit>())
                .FirstOrDefault(rule => rule != null
                    && IsPathMatch(path, rule.Path)
                    && IsMethodMatch(method, rule.Method));

            bool allowed;
            int retryAfterSeconds;
            if (endpointConfig == null)
            {
                // No rule matched: fall back to the default per-IP limit.
                allowed = CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod,
                    RateLimitPolicy.IP, out retryAfterSeconds);
            }
            else
            {
                var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
                allowed = CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period,
                    endpointConfig.Policy, out retryAfterSeconds);
            }

            if (!allowed)
            {
                await ReturnRateLimitedResponseAsync(context, endpointConfig, retryAfterSeconds);
                return;
            }

            await _next(context);
        }

        /// <summary>
        /// Records one request against the fixed window for (<paramref name="endpointKey"/>, caller)
        /// and reports whether it is within <paramref name="limit"/>.
        /// </summary>
        /// <param name="context">Current request, used to identify the caller.</param>
        /// <param name="endpointKey">Rule key; callers of different rules are counted separately.</param>
        /// <param name="limit">Maximum requests allowed per window.</param>
        /// <param name="period">Window length in seconds; values below 1 are treated as 1.</param>
        /// <param name="policy">How the caller is identified.</param>
        /// <param name="retryAfterSeconds">
        /// When the request is rejected, whole seconds (at least 1) until the current window ends;
        /// otherwise 0.
        /// </param>
        /// <returns><c>true</c> if the request is allowed; <c>false</c> if it is over the limit.</returns>
        private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period,
            RateLimitPolicy policy, out int retryAfterSeconds)
        {
            retryAfterSeconds = 0;

            // IMemoryCache rejects a non-positive relative expiration, and a rule whose "Period" is
            // omitted binds to 0, so clamp to one second instead of failing every matching request.
            var window = TimeSpan.FromSeconds(Math.Max(1, period));

            var identifier = GetIdentifier(context, policy);
            var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
            var now = DateTime.UtcNow;
            var windowStart = now - window;

            int blockedCount;
            lock (CounterLock)
            {
                if (!_cache.TryGetValue(cacheKey, out RateLimitCounter counter) || counter == null)
                {
                    // First request for this key: open a new window.
                    _cache.Set(cacheKey,
                        new RateLimitCounter { Count = 1, FirstRequestTime = now, LastRequestTime = now },
                        window);
                    return true;
                }

                if (counter.FirstRequestTime >= windowStart && counter.Count >= limit)
                {
                    // Still inside the window and the quota is used up: reject without updating.
                    blockedCount = counter.Count;
                    var remaining = counter.FirstRequestTime + window - now;
                    retryAfterSeconds = Math.Max(1, (int)Math.Ceiling(remaining.TotalSeconds));
                }
                else
                {
                    if (counter.FirstRequestTime < windowStart)
                    {
                        // The window has elapsed: this request starts a new one.
                        counter.Count = 1;
                        counter.FirstRequestTime = now;
                    }
                    else
                    {
                        counter.Count++;
                    }

                    counter.LastRequestTime = now;
                    _cache.Set(cacheKey, counter, window);
                    return true;
                }
            }

            // Log outside the lock so the critical section stays short.
            _logger.LogWarning(
                "限流触发 - Endpoint: {Endpoint}, Identifier: {Identifier}, Count: {Count}, Limit: {Limit}, Remaining: {Remaining}s",
                endpointKey, identifier, blockedCount, limit, retryAfterSeconds);
            return false;
        }

        /// <summary>
        /// Builds the caller identifier used in the counter cache key for the given policy.
        /// </summary>
        private static string GetIdentifier(HttpContext context, RateLimitPolicy policy)
        {
            switch (policy)
            {
                case RateLimitPolicy.User:
                    if (context.User?.Identity?.IsAuthenticated == true)
                    {
                        var userId = context.User.FindFirst("sub")?.Value
                            ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                            ?? context.User.Identity.Name;

                        // Without a usable id every such user would share one "user_" bucket.
                        if (!string.IsNullOrEmpty(userId))
                        {
                            return $"user_{userId}";
                        }
                    }

                    // Not logged in (or no id claim): fall back to the client IP.
                    return $"ip_{GetClientIp(context)}";

                case RateLimitPolicy.Global:
                    return "global";

                case RateLimitPolicy.Client:
                    // NOTE(review): these headers are client-supplied, so a caller can change them per
                    // request; the X-API-Key value also ends up in the cache key and the warning log.
                    var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault()
                        ?? context.Request.Headers["X-API-Key"].FirstOrDefault();
                    return $"client_{clientId ?? "unknown"}";

                case RateLimitPolicy.IP:
                default:
                    return $"ip_{GetClientIp(context)}";
            }
        }

        /// <summary>
        /// Returns the remote IP as text, converting IPv4-mapped IPv6 addresses
        /// (e.g. <c>::ffff:1.2.3.4</c>) to plain IPv4, or <c>"unknown"</c> when unavailable.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this is the direct peer address; behind a reverse proxy it is presumably the
        /// proxy unless forwarded-headers handling runs earlier in the pipeline — verify deployment.
        /// </remarks>
        private static string GetClientIp(HttpContext context)
        {
            var address = context.Connection.RemoteIpAddress;
            if (address == null)
            {
                return "unknown";
            }

            // Use the IPAddress API rather than string replacement, which would also mangle
            // genuine IPv6 addresses that merely contain "::ffff:".
            if (address.IsIPv4MappedToIPv6)
            {
                address = address.MapToIPv4();
            }

            return address.ToString();
        }

        /// <summary>
        /// Matches a lower-cased request path against a rule path. Supported forms:
        /// <c>*</c> (everything), an exact path, a prefix ending in <c>/</c>,
        /// or a pattern where each <c>*</c> matches any run of characters.
        /// A null or empty rule path never matches.
        /// </summary>
        private static bool IsPathMatch(string requestPath, string configPath)
        {
            if (string.IsNullOrEmpty(configPath))
            {
                return false;
            }

            if (configPath == "*")
            {
                return true;
            }

            configPath = configPath.ToLowerInvariant();

            // Exact match.
            if (requestPath.Equals(configPath, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            // Prefix match (a rule path ending with '/' covers everything under it).
            if (configPath.EndsWith("/", StringComparison.Ordinal)
                && requestPath.StartsWith(configPath, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            // Simple '*' wildcard match.
            if (configPath.Contains('*'))
            {
                var pattern = "^" + Regex.Escape(configPath).Replace("\\*", ".*") + "$";
                return Regex.IsMatch(requestPath, pattern);
            }

            return false;
        }

        /// <summary>
        /// Case-insensitive HTTP method match; a rule method of <c>*</c> matches any method.
        /// </summary>
        private static bool IsMethodMatch(string requestMethod, string configMethod)
        {
            if (configMethod == "*")
            {
                return true;
            }

            return requestMethod.Equals(configMethod, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Returns <c>true</c> for paths that are never rate limited (see <see cref="SkipPaths"/>).
        /// </summary>
        private static bool ShouldSkipRateLimit(string path) =>
            SkipPaths.Any(prefix => path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase));

        /// <summary>
        /// Writes the HTTP 429 JSON response (<c>Code</c>, <c>Msg</c>).
        /// </summary>
        /// <param name="context">Current request.</param>
        /// <param name="config">The matched endpoint rule, or <c>null</c> when the default limit applied.</param>
        /// <param name="retryAfterSeconds">When positive, sent as the Retry-After header (seconds).</param>
        private static async Task ReturnRateLimitedResponseAsync(HttpContext context,
            EndpointRateLimit config = null, int retryAfterSeconds = 0)
        {
            context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
            context.Response.ContentType = "application/json";

            if (retryAfterSeconds > 0)
            {
                // RFC 6585: tell well-behaved clients how long to back off.
                context.Response.Headers["Retry-After"] = retryAfterSeconds.ToString(CultureInfo.InvariantCulture);
            }

            var message = config != null
                ? $"接口访问过于频繁,请{config.Period}秒后再试"
                : "请求过于频繁,请稍后再试";

            var response = new
            {
                Code = 429,
                Msg = message
            };

            await context.Response.WriteAsJsonAsync(response);
        }
    }

    /// <summary>
    /// Request counter for one fixed rate-limit window.
    /// </summary>
    public class RateLimitCounter
    {
        /// <summary>Requests counted in the current window.</summary>
        public int Count { get; set; }

        /// <summary>UTC time at which the current window started.</summary>
        public DateTime FirstRequestTime { get; set; }

        /// <summary>UTC time of the most recent allowed request.</summary>
        public DateTime LastRequestTime { get; set; }
    }

    #region Configuration models

    /// <summary>
    /// How callers are identified when counting requests.
    /// </summary>
    public enum RateLimitPolicy
    {
        /// <summary>Limit per client IP address.</summary>
        IP = 0,

        /// <summary>Limit per logged-in user; anonymous requests fall back to the client IP.</summary>
        User = 1,

        /// <summary>Global limit: all callers share a single counter per rule.</summary>
        Global = 2,

        /// <summary>Limit per client id taken from the X-Client-Id or X-API-Key header.</summary>
        Client = 3
    }

    /// <summary>
    /// Root rate-limiting settings.
    /// </summary>
    public class RateLimitConfig
    {
        /// <summary>When <c>false</c>, the middleware passes every request through.</summary>
        public bool Enabled { get; set; } = true;

        /// <summary>Requests allowed per window when no endpoint rule matches (counted per IP).</summary>
        public int DefaultLimit { get; set; } = 10;

        /// <summary>Window length in seconds for the default limit.</summary>
        public int DefaultPeriod { get; set; } = 1;

        /// <summary>Endpoint rules, evaluated in order; the first match applies.</summary>
        public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
    }

    /// <summary>
    /// Rate-limit rule for a set of endpoints.
    /// </summary>
    public class EndpointRateLimit
    {
        /// <summary>
        /// Path to match: exact path, prefix ending in <c>/</c>, a pattern with <c>*</c> wildcards,
        /// or <c>*</c> for every path. A rule without a path never matches.
        /// </summary>
        public string Path { get; set; }

        /// <summary>HTTP method to match, or <c>*</c> for any method.</summary>
        public string Method { get; set; } = "*";

        /// <summary>Maximum requests allowed per window.</summary>
        public int Limit { get; set; }

        /// <summary>Window length in seconds; values below 1 are treated as 1.</summary>
        public int Period { get; set; }

        /// <summary>How callers are identified for this rule.</summary>
        public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
    }

    #endregion
}