using IdentityModel.OidcClient;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;
using System.Text.Json;
using System.Text.Json.Serialization;

namespace OASystem.API.Middlewares
{
    /// <summary>
    /// In-memory, fixed-window rate-limiting middleware.
    /// Each request is matched against <see cref="RateLimitConfig.Endpoints"/> (first match wins);
    /// when nothing matches, the default per-IP rule is applied. Requests over the limit receive HTTP 429.
    /// </summary>
    /// <remarks>
    /// Counters are stored in <see cref="IMemoryCache"/>, so limits are tracked per process.
    /// </remarks>
    public class RateLimitMiddleware
    {
        /// <summary>Path prefixes that are never rate limited (case-insensitive).</summary>
        private static readonly string[] SkipPaths =
        {
            "/health", "/swagger", "/favicon.ico", "/robots.txt", "/.well-known"
        };

        /// <summary>Headers consulted, in order, by <see cref="GetClientIp"/>.</summary>
        private static readonly string[] ForwardedIpHeaders =
        {
            "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP", "X-Original-For"
        };

        private readonly RequestDelegate _next;
        private readonly IMemoryCache _cache;
        private readonly RateLimitConfig _config;
        private readonly ILogger<RateLimitMiddleware> _logger;
        private readonly JsonSerializerOptions _jsonOptions;

        // Guards the read-modify-write of counters. Conventional middleware is a singleton, so without
        // this, concurrent requests could lose increments or overwrite a counter another request just created.
        private readonly object _counterLock = new object();

        /// <summary>Outcome of updating a counter in <see cref="CheckLimit"/>.</summary>
        private enum CounterUpdate
        {
            Created,
            Reset,
            Incremented,
            Limited
        }

        /// <param name="next">The next middleware in the pipeline.</param>
        /// <param name="cache">Cache that stores the per-rule, per-identifier counters.</param>
        /// <param name="config">Bound rate-limit options; built-in defaults are used when absent.</param>
        /// <param name="logger">Logger for configuration, diagnostics and limit hits.</param>
        public RateLimitMiddleware(
            RequestDelegate next,
            IMemoryCache cache,
            IOptions<RateLimitConfig> config,
            ILogger<RateLimitMiddleware> logger)
        {
            _next = next ?? throw new ArgumentNullException(nameof(next));
            _cache = cache ?? throw new ArgumentNullException(nameof(cache));
            _config = config?.Value ?? new RateLimitConfig();
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));

            _jsonOptions = new JsonSerializerOptions
            {
                // Emit Chinese characters as-is instead of \uXXXX escapes.
                Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
                PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
                WriteIndented = false
            };

            // Log the effective configuration once at startup to aid debugging.
            LogConfiguration();
        }

        /// <summary>Logs the effective configuration (showing at most the first five endpoint rules).</summary>
        private void LogConfiguration()
        {
            const int maxEndpointsToLog = 5;

            _logger.LogInformation("=== 限流中间件初始化 ===");
            _logger.LogInformation("启用状态: {Enabled}", _config.Enabled);
            _logger.LogInformation("默认限制: {DefaultLimit}次/{DefaultPeriod}秒", _config.DefaultLimit, _config.DefaultPeriod);

            var endpoints = _config.Endpoints;
            if (endpoints != null && endpoints.Count > 0)
            {
                _logger.LogInformation("配置端点数量: {EndpointCount}", endpoints.Count);
                foreach (var endpoint in endpoints.Take(maxEndpointsToLog))
                {
                    _logger.LogInformation(" {Method} {Path} -> {Limit}次/{Period}秒",
                        endpoint.Method, endpoint.Path, endpoint.Limit, endpoint.Period);
                }

                if (endpoints.Count > maxEndpointsToLog)
                {
                    _logger.LogInformation(" ... 还有{RemainingCount}个配置", endpoints.Count - maxEndpointsToLog);
                }
            }
            else
            {
                _logger.LogWarning("未配置具体端点,将使用默认规则");
            }

            _logger.LogInformation("========================");
        }

        /// <summary>
        /// Applies the matching rate-limit rule. The request is either forwarded to the next middleware
        /// or short-circuited with a 429 response.
        /// </summary>
        public async Task InvokeAsync(HttpContext context)
        {
            var requestInfo = new RequestInfo
            {
                Path = context.Request.Path,
                Method = context.Request.Method,
                ClientIp = GetClientIp(context),
                Timestamp = DateTime.UtcNow
            };
            _logger.LogDebug("收到请求: {Method} {Path} from {ClientIp}", requestInfo.Method, requestInfo.Path, requestInfo.ClientIp);

            if (!_config.Enabled)
            {
                _logger.LogDebug("限流功能已禁用,跳过检查");
                await _next(context);
                return;
            }

            var path = context.Request.Path.ToString();
            var method = context.Request.Method.ToUpperInvariant();

            if (ShouldSkipRateLimit(path))
            {
                _logger.LogDebug("跳过限流检查: {Path}", path);
                await _next(context);
                return;
            }

            string ruleKey;
            string ruleInfo;
            int limit;
            int period;
            RateLimitPolicy policy;

            var endpointConfig = FindMatchingConfig(path, method);
            if (endpointConfig == null)
            {
                _logger.LogDebug("使用默认限流配置: {Method} {Path}", method, path);
                ruleKey = "default";
                ruleInfo = "默认规则";
                limit = _config.DefaultLimit;
                period = _config.DefaultPeriod;
                policy = RateLimitPolicy.IP;
            }
            else
            {
                ruleKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
                _logger.LogDebug("匹配限流配置: {EndpointKey} -> {Limit}次/{Period}秒", ruleKey, endpointConfig.Limit, endpointConfig.Period);
                ruleInfo = $"{method} {path}";
                limit = endpointConfig.Limit;
                period = endpointConfig.Period;
                policy = endpointConfig.Policy;
            }

            if (!CheckLimit(context, ruleKey, limit, period, policy, requestInfo, out var resetTimeUtc))
            {
                await ReturnRateLimitedResponse(context, requestInfo, ruleInfo, limit, resetTimeUtc);
                return;
            }

            _logger.LogDebug("请求通过限流检查: {Method} {Path}", method, path);
            await _next(context);
        }

        /// <summary>
        /// Counts the request against a fixed window of <paramref name="period"/> seconds that starts at
        /// the first request of the window.
        /// </summary>
        /// <param name="resetTimeUtc">UTC time at which the current window ends.</param>
        /// <returns><c>true</c> if the request is allowed; <c>false</c> if the limit is already reached.</returns>
        private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period,
            RateLimitPolicy policy, RequestInfo requestInfo, out DateTime resetTimeUtc)
        {
            // A non-positive period would yield a non-positive cache lifetime, which IMemoryCache.Set
            // rejects with an exception; treat it as the minimum 1-second window instead.
            var windowSeconds = period > 0 ? period : 1;
            var identifier = GetIdentifier(context, policy);
            var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
            var now = DateTime.UtcNow;
            var windowStart = now.AddSeconds(-windowSeconds);
            var cacheLifetime = TimeSpan.FromSeconds(windowSeconds + 1);

            CounterUpdate outcome;
            int previousCount;
            int currentCount;
            DateTime firstRequestTime;

            lock (_counterLock)
            {
                if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter) || counter == null)
                {
                    // First request for this rule/identifier.
                    counter = new RateLimitCounter { Count = 1, FirstRequestTime = now, LastRequestTime = now };
                    previousCount = 0;
                    outcome = CounterUpdate.Created;
                }
                else if (counter.FirstRequestTime < windowStart)
                {
                    // The window has elapsed: start a new one with this request.
                    previousCount = counter.Count;
                    counter.Count = 1;
                    counter.FirstRequestTime = now;
                    counter.LastRequestTime = now;
                    outcome = CounterUpdate.Reset;
                }
                else if (counter.Count >= limit)
                {
                    previousCount = counter.Count;
                    outcome = CounterUpdate.Limited;
                }
                else
                {
                    previousCount = counter.Count;
                    counter.Count++;
                    counter.LastRequestTime = now;
                    outcome = CounterUpdate.Incremented;
                }

                // Rejected requests do not touch the counter or its expiration.
                if (outcome != CounterUpdate.Limited)
                {
                    _cache.Set(cacheKey, counter, cacheLifetime);
                }

                currentCount = counter.Count;
                firstRequestTime = counter.FirstRequestTime;
            }

            resetTimeUtc = firstRequestTime.AddSeconds(windowSeconds);

            // Logging happens outside the lock so slow log providers cannot serialize requests.
            switch (outcome)
            {
                case CounterUpdate.Created:
                    _logger.LogDebug("新建限流计数器: {CacheKey}, 计数: 1/{Limit}", cacheKey, limit);
                    return true;

                case CounterUpdate.Reset:
                    _logger.LogDebug("重置限流计数器: {CacheKey}, 旧计数: {OldCount}, 新计数: 1/{Limit}", cacheKey, previousCount, limit);
                    return true;

                case CounterUpdate.Incremented:
                    _logger.LogDebug("增加限流计数: {CacheKey}, {OldCount} -> {NewCount}/{Limit}", cacheKey, previousCount, currentCount, limit);
                    return true;

                default:
                {
                    var remainingSeconds = Math.Max(0, (int)Math.Ceiling((resetTimeUtc - now).TotalSeconds));

                    _logger.LogWarning("🚫 限流触发: {Method} {Path}", requestInfo.Method, requestInfo.Path);
                    _logger.LogWarning(" 客户端IP: {ClientIp}", requestInfo.ClientIp);
                    _logger.LogWarning(" 标识符: {Identifier}", identifier);
                    _logger.LogWarning(" 规则: {Rule}", endpointKey);
                    _logger.LogWarning(" 当前计数: {Count}/{Limit}", currentCount, limit);
                    _logger.LogWarning(" 窗口开始: {WindowStart:HH:mm:ss}", firstRequestTime);
                    _logger.LogWarning(" 剩余时间: {RemainingSeconds}秒", remainingSeconds);
                    _logger.LogWarning(" 用户代理: {UserAgent}", context.Request.Headers["User-Agent"].ToString());
                    return false;
                }
            }
        }

        /// <summary>Builds the counter identity (the "who") for the given policy.</summary>
        private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
        {
            switch (policy)
            {
                case RateLimitPolicy.User:
                    if (context.User?.Identity?.IsAuthenticated == true)
                    {
                        var userId = context.User.FindFirst("sub")?.Value
                            ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                            ?? context.User.Identity.Name;

                        // Without any id every such user would share one "user_" bucket.
                        if (!string.IsNullOrEmpty(userId))
                        {
                            return $"user_{userId}";
                        }
                    }

                    // Not signed in (or no usable id): fall back to per-IP limiting.
                    return $"ip_{FormatIpAddress(context.Connection.RemoteIpAddress)}";

                case RateLimitPolicy.Global:
                    return "global";

                case RateLimitPolicy.Client:
                    var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault()
                        ?? context.Request.Headers["X-API-Key"].FirstOrDefault();
                    return $"client_{clientId ?? "unknown"}";

                case RateLimitPolicy.IP:
                default:
                    // NOTE(review): behind a reverse proxy RemoteIpAddress is the proxy's address unless
                    // ForwardedHeadersMiddleware is configured, which would put all clients in one bucket.
                    return $"ip_{FormatIpAddress(context.Connection.RemoteIpAddress)}";
            }
        }

        /// <summary>
        /// Formats an address, unwrapping IPv4-mapped IPv6 (e.g. ::ffff:1.2.3.4 → 1.2.3.4);
        /// returns "unknown" for null.
        /// </summary>
        private static string FormatIpAddress(System.Net.IPAddress address)
        {
            if (address == null)
            {
                return "unknown";
            }

            return address.IsIPv4MappedToIPv6 ? address.MapToIPv4().ToString() : address.ToString();
        }

        /// <summary>Writes the 429 JSON response together with rate-limit headers.</summary>
        /// <param name="limit">Limit of the rule that fired (reported in X-RateLimit-Limit).</param>
        /// <param name="resetTimeUtc">End of the current window (X-RateLimit-Reset / Retry-After).</param>
        private async Task ReturnRateLimitedResponse(HttpContext context, RequestInfo requestInfo, string ruleInfo,
            int limit, DateTime resetTimeUtc)
        {
            var retryAfterSeconds = Math.Max(1, (int)Math.Ceiling((resetTimeUtc - DateTime.UtcNow).TotalSeconds));

            context.Response.StatusCode = 429;
            context.Response.ContentType = "application/json; charset=utf-8";

            // Indexer assignment rather than Headers.Add, which throws when the header already exists.
            var headers = context.Response.Headers;
            headers["Retry-After"] = retryAfterSeconds.ToString(System.Globalization.CultureInfo.InvariantCulture);
            headers["X-RateLimit-Limit"] = limit.ToString(System.Globalization.CultureInfo.InvariantCulture);
            headers["X-RateLimit-Remaining"] = "0";
            headers["X-RateLimit-Reset"] = resetTimeUtc.ToString("R");
            // ruleInfo may be non-ASCII (e.g. "默认规则"), which Kestrel rejects in header values by default.
            headers["X-RateLimit-Rule"] = ToAsciiHeaderValue(ruleInfo);

            var response = new JsonView()
            {
                Code = 429,
                Msg = $"请求过于频繁(IP:{requestInfo.ClientIp}),请稍后再试!",
                Count = 0,
                Data = new
                {
                    Path = requestInfo.Path,
                    Method = requestInfo.Method,
                    ClientIp = requestInfo.ClientIp,
                    Rule = ruleInfo
                }
            };

            var json = System.Text.Json.JsonSerializer.Serialize(response, _jsonOptions);
            await context.Response.WriteAsync(json, context.RequestAborted);

            _logger.LogInformation("📤 限流响应已发送: {Method} {Path} -> 429", requestInfo.Method, requestInfo.Path);
        }

        /// <summary>
        /// Returns <paramref name="value"/> unchanged when it is printable ASCII; otherwise returns it
        /// percent-encoded so it is safe to place in an HTTP header.
        /// </summary>
        private static string ToAsciiHeaderValue(string value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return string.Empty;
            }

            foreach (var ch in value)
            {
                if (ch < 0x20 || ch > 0x7E)
                {
                    return Uri.EscapeDataString(value);
                }
            }

            return value;
        }

        /// <summary>Returns the first configured endpoint rule matching path and method, or null.</summary>
        private EndpointRateLimit FindMatchingConfig(string path, string method)
        {
            var endpoints = _config.Endpoints;
            if (endpoints == null || endpoints.Count == 0)
            {
                return null;
            }

            foreach (var configItem in endpoints)
            {
                if (configItem != null && IsPathMatch(path, configItem.Path) && IsMethodMatch(method, configItem.Method))
                {
                    return configItem;
                }
            }

            return null;
        }

        /// <summary>
        /// Case-insensitive path match. Supports "*" (any path), exact match, prefix match (config path
        /// ending in "/"), and simple "*" wildcards.
        /// </summary>
        private bool IsPathMatch(string requestPath, string configPath)
        {
            if (configPath == null)
            {
                return false;
            }

            if (configPath == "*")
            {
                return true;
            }

            // Ordinal comparisons avoid culture-specific casing (e.g. the Turkish dotless i).
            if (string.Equals(requestPath, configPath, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            if (configPath.EndsWith("/", StringComparison.Ordinal)
                && requestPath.StartsWith(configPath, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            if (configPath.Contains('*'))
            {
                var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath).Replace("\\*", ".*") + "$";
                return System.Text.RegularExpressions.Regex.IsMatch(
                    requestPath,
                    pattern,
                    System.Text.RegularExpressions.RegexOptions.IgnoreCase | System.Text.RegularExpressions.RegexOptions.CultureInvariant);
            }

            return false;
        }

        /// <summary>Case-insensitive HTTP method match; "*" matches any method.</summary>
        private bool IsMethodMatch(string requestMethod, string configMethod)
        {
            if (configMethod == "*")
            {
                return true;
            }

            return string.Equals(requestMethod, configMethod, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>True for health checks, swagger and similar infrastructure paths.</summary>
        private bool ShouldSkipRateLimit(string path)
        {
            return SkipPaths.Any(p => path.StartsWith(p, StringComparison.OrdinalIgnoreCase));
        }

        /// <summary>
        /// Best-effort client IP for logging and the 429 message. The forwarding headers listed in
        /// <see cref="ForwardedIpHeaders"/> are tried in order (public addresses only), then RemoteIpAddress.
        /// </summary>
        /// <remarks>
        /// NOTE(review): these headers are client-controlled unless a trusted proxy overwrites them, so the
        /// result must not be used for security decisions. Rate-limit keys use RemoteIpAddress instead.
        /// </remarks>
        private string GetClientIp(HttpContext context)
        {
            foreach (var headerName in ForwardedIpHeaders)
            {
                var headerValue = context.Request.Headers[headerName].FirstOrDefault();
                if (string.IsNullOrEmpty(headerValue))
                {
                    continue;
                }

                // X-Forwarded-For may carry a chain "client, proxy1, proxy2"; the first entry is the originator.
                var candidate = headerName == "X-Forwarded-For"
                    ? headerValue.Split(',')[0].Trim()
                    : headerValue.Trim();
                _logger.LogDebug("从 {HeaderName} 获取IP: {ClientIp}", headerName, candidate);

                if (IsValidIpAddress(candidate))
                {
                    return candidate;
                }
            }

            var ipAddress = FormatIpAddress(context.Connection.RemoteIpAddress);

            // Report the IPv6 loopback as its IPv4 equivalent.
            if (ipAddress == "::1")
            {
                ipAddress = "127.0.0.1";
            }

            _logger.LogDebug("从 RemoteIpAddress 获取IP: {ClientIp}", ipAddress);
            return ipAddress;
        }

        /// <summary>True when <paramref name="ip"/> parses as an IP address and is not private/loopback.</summary>
        private bool IsValidIpAddress(string ip)
        {
            if (string.IsNullOrWhiteSpace(ip) || ip.Equals("unknown", StringComparison.OrdinalIgnoreCase))
            {
                return false;
            }

            // Private addresses are rejected so that an internal proxy hop is not reported as the client.
            return System.Net.IPAddress.TryParse(ip, out var ipAddress) && !IsPrivateIp(ipAddress);
        }

        /// <summary>
        /// True for loopback, RFC 1918, IPv4 link-local, IPv6 link/site-local, multicast and unique-local
        /// (fc00::/7) addresses. IPv4-mapped IPv6 addresses are classified by their IPv4 form.
        /// </summary>
        private static bool IsPrivateIp(System.Net.IPAddress ipAddress)
        {
            if (ipAddress.IsIPv4MappedToIPv6)
            {
                ipAddress = ipAddress.MapToIPv4();
            }

            if (System.Net.IPAddress.IsLoopback(ipAddress))
            {
                return true;
            }

            if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork)
            {
                var bytes = ipAddress.GetAddressBytes();
                return bytes[0] == 10                                        // 10.0.0.0/8
                    || (bytes[0] == 172 && bytes[1] >= 16 && bytes[1] <= 31) // 172.16.0.0/12
                    || (bytes[0] == 192 && bytes[1] == 168)                  // 192.168.0.0/16
                    || bytes[0] == 127                                       // 127.0.0.0/8
                    || (bytes[0] == 169 && bytes[1] == 254);                 // 169.254.0.0/16 link-local
            }

            if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6)
            {
                var firstByte = ipAddress.GetAddressBytes()[0];
                return ipAddress.IsIPv6LinkLocal
                    || ipAddress.IsIPv6SiteLocal
                    || ipAddress.IsIPv6Multicast
                    || (firstByte & 0xFE) == 0xFC;                          // fc00::/7 unique-local
            }

            return false;
        }

        /// <summary>Snapshot of request details used for logging and the 429 body.</summary>
        private class RequestInfo
        {
            public string Path { get; set; }
            public string Method { get; set; }
            public string ClientIp { get; set; }
            public DateTime Timestamp { get; set; }
        }

        /// <summary>Per-rule, per-identifier counter stored in the memory cache.</summary>
        public class RateLimitCounter
        {
            /// <summary>Requests counted in the current window.</summary>
            public int Count { get; set; }

            /// <summary>UTC time of the first request of the current window.</summary>
            public DateTime FirstRequestTime { get; set; }

            /// <summary>UTC time of the most recent counted request.</summary>
            public DateTime LastRequestTime { get; set; }
        }

        #region 配置模型

        /// <summary>
        /// Rate-limit policy: determines what a counter is keyed on.
        /// </summary>
        public enum RateLimitPolicy
        {
            /// <summary>Limit per client IP address.</summary>
            IP = 0,

            /// <summary>Limit per signed-in user; falls back to IP for anonymous requests.</summary>
            User = 1,

            /// <summary>Global limit shared by all callers.</summary>
            Global = 2,

            /// <summary>Limit per client id (X-Client-Id or X-API-Key header).</summary>
            Client = 3
        }

        /// <summary>Root rate-limit configuration.</summary>
        public class RateLimitConfig
        {
            /// <summary>Whether rate limiting is active.</summary>
            public bool Enabled { get; set; } = true;

            /// <summary>Requests allowed per window by the default rule (default 5).</summary>
            public int DefaultLimit { get; set; } = 5;

            /// <summary>Window length of the default rule, in seconds (default 1).</summary>
            public int DefaultPeriod { get; set; } = 1;

            /// <summary>Endpoint-specific rules; the first matching rule wins.</summary>
            [JsonPropertyName("Endpoints")]
            public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
        }

        /// <summary>A rate-limit rule for a path/method combination.</summary>
        public class EndpointRateLimit
        {
            /// <summary>"*", an exact path, a prefix ending in "/", or a pattern containing "*".</summary>
            [JsonPropertyName("Path")]
            public string Path { get; set; } = "*";

            /// <summary>HTTP method, or "*" for any.</summary>
            [JsonPropertyName("Method")]
            public string Method { get; set; } = "*";

            /// <summary>Requests allowed per window.</summary>
            [JsonPropertyName("Limit")]
            public int Limit { get; set; } = 5;

            /// <summary>Window length in seconds.</summary>
            [JsonPropertyName("Period")]
            public int Period { get; set; } = 1;

            /// <summary>What the counter is keyed on.</summary>
            [JsonPropertyName("Policy")]
            public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;

            /// <summary>Normalized "METHOD:path" key (null-safe, culture-invariant casing).</summary>
            [System.Text.Json.Serialization.JsonIgnore]
            public string EndpointKey => $"{Method?.ToUpperInvariant()}:{Path?.ToLowerInvariant()}";
        }

        #endregion
    }
}