| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530 |
- using IdentityModel.OidcClient;
- using Microsoft.Extensions.Caching.Memory;
- using Microsoft.Extensions.Options;
- using System.Text.Json;
- using System.Text.Json.Serialization;
namespace OASystem.API.Middlewares
{
    /// <summary>
    /// Middleware that applies per-endpoint, fixed-window request rate limits.
    /// </summary>
    /// <remarks>
    /// The first entry of <see cref="RateLimitConfig.Endpoints"/> whose path and method match the
    /// request decides the limit; when none matches, <see cref="RateLimitConfig.DefaultLimit"/>
    /// requests per <see cref="RateLimitConfig.DefaultPeriod"/> seconds are allowed per client IP.
    /// Counters are stored in the injected <c>IMemoryCache</c>, i.e. in this process only.
    /// Requests over the limit receive HTTP 429 with a <c>JsonView</c> JSON body and are not
    /// passed to the next middleware.
    /// </remarks>
    public class RateLimitMiddleware
    {
        // Request paths that bypass rate limiting entirely. Matched as case-insensitive
        // prefixes, so e.g. "/health" also covers "/healthz" and "/health/ready".
        private static readonly string[] SkipPathPrefixes =
        {
            "/health",
            "/swagger",
            "/favicon.ico",
            "/robots.txt",
            "/.well-known"
        };

        // Proxy headers consulted, in this order, when reporting the client IP.
        private static readonly string[] ClientIpHeaders =
        {
            "X-Forwarded-For",
            "X-Real-IP",
            "CF-Connecting-IP",
            "X-Original-For"
        };

        private readonly RequestDelegate _next;
        private readonly IMemoryCache _cache;
        private readonly RateLimitConfig _config;
        private readonly ILogger<RateLimitMiddleware> _logger;
        private readonly JsonSerializerOptions _jsonOptions;

        // Serializes the read-modify-write of cached counters. This middleware instance serves
        // concurrent requests, so an unsynchronized "Count++" could lose increments and let
        // bursts exceed the configured limit. The critical section only touches memory.
        private readonly object _counterLock = new object();

        /// <summary>
        /// Creates the middleware and logs the effective configuration.
        /// </summary>
        /// <param name="next">Next delegate in the request pipeline.</param>
        /// <param name="cache">Cache holding the per-client counters.</param>
        /// <param name="config">Rate-limit options; built-in defaults are used when absent.</param>
        /// <param name="logger">Logger for configuration, decisions and rejections.</param>
        public RateLimitMiddleware(
            RequestDelegate next,
            IMemoryCache cache,
            IOptions<RateLimitConfig> config,
            ILogger<RateLimitMiddleware> logger)
        {
            _next = next ?? throw new ArgumentNullException(nameof(next));
            _cache = cache ?? throw new ArgumentNullException(nameof(cache));
            _config = config?.Value ?? new RateLimitConfig();
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));

            _jsonOptions = new JsonSerializerOptions
            {
                // Write Chinese text verbatim instead of as \uXXXX escapes.
                Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
                PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
                WriteIndented = false
            };

            // Log the effective configuration once, to help diagnose binding problems.
            LogConfiguration();
        }

        /// <summary>
        /// Logs the effective configuration (at most the first five endpoint rules).
        /// </summary>
        private void LogConfiguration()
        {
            const int maxEndpointsToLog = 5;

            _logger.LogInformation("=== 限流中间件初始化 ===");
            _logger.LogInformation("启用状态: {Enabled}", _config.Enabled);
            _logger.LogInformation("默认限制: {Limit}次/{Period}秒", _config.DefaultLimit, _config.DefaultPeriod);

            var endpoints = _config.Endpoints;
            if (endpoints != null && endpoints.Count > 0)
            {
                _logger.LogInformation("配置端点数量: {Count}", endpoints.Count);
                foreach (var endpoint in endpoints.Take(maxEndpointsToLog))
                {
                    _logger.LogInformation(
                        " {Method} {Path} -> {Limit}次/{Period}秒",
                        endpoint?.Method,
                        endpoint?.Path,
                        endpoint?.Limit,
                        endpoint?.Period);
                }

                if (endpoints.Count > maxEndpointsToLog)
                {
                    _logger.LogInformation(" ... 还有{Remaining}个配置", endpoints.Count - maxEndpointsToLog);
                }
            }
            else
            {
                _logger.LogWarning("未配置具体端点,将使用默认规则");
            }

            _logger.LogInformation("========================");
        }

        /// <summary>
        /// Checks the request against the matching rule and either forwards it to the next
        /// middleware or short-circuits with HTTP 429.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task InvokeAsync(HttpContext context)
        {
            // Request details used for debug logging and the 429 response body.
            var requestInfo = new RequestInfo
            {
                Path = context.Request.Path,
                Method = context.Request.Method,
                ClientIp = GetClientIp(context),
                Timestamp = DateTime.UtcNow
            };
            _logger.LogDebug("收到请求: {Method} {Path} from {ClientIp}", requestInfo.Method, requestInfo.Path, requestInfo.ClientIp);

            if (!_config.Enabled)
            {
                _logger.LogDebug("限流功能已禁用,跳过检查");
                await _next(context);
                return;
            }

            var path = context.Request.Path.ToString();
            var method = context.Request.Method.ToUpperInvariant();

            if (ShouldSkipRateLimit(path))
            {
                _logger.LogDebug("跳过限流检查: {Path}", path);
                await _next(context);
                return;
            }

            int limit;
            string ruleInfo;
            bool allowed;
            DateTime resetTimeUtc;

            var endpointConfig = FindMatchingConfig(path, method);
            if (endpointConfig == null)
            {
                // No endpoint rule matched: apply the default limit per client IP.
                _logger.LogDebug("使用默认限流配置: {Method} {Path}", method, path);
                limit = _config.DefaultLimit;
                ruleInfo = "默认规则";
                allowed = CheckLimit(context, "default", limit, _config.DefaultPeriod, RateLimitPolicy.IP, requestInfo, out resetTimeUtc);
            }
            else
            {
                var endpointKey = endpointConfig.EndpointKey;
                _logger.LogDebug("匹配限流配置: {EndpointKey} -> {Limit}次/{Period}秒", endpointKey, endpointConfig.Limit, endpointConfig.Period);
                limit = endpointConfig.Limit;
                ruleInfo = $"{method} {path}";
                allowed = CheckLimit(context, endpointKey, limit, endpointConfig.Period, endpointConfig.Policy, requestInfo, out resetTimeUtc);
            }

            if (!allowed)
            {
                await ReturnRateLimitedResponse(context, requestInfo, ruleInfo, limit, resetTimeUtc);
                return;
            }

            _logger.LogDebug("请求通过限流检查: {Method} {Path}", method, path);
            await _next(context);
        }

        /// <summary>
        /// Counts the request against a fixed window of <paramref name="period"/> seconds and
        /// decides whether it is allowed. The window opens with the client's first request and
        /// is restarted by the first request after it has elapsed.
        /// </summary>
        /// <param name="context">Current request; used to derive the client identifier.</param>
        /// <param name="endpointKey">Rule key; keeps the counters of different rules apart.</param>
        /// <param name="limit">Maximum number of requests allowed per window.</param>
        /// <param name="period">Window length in seconds.</param>
        /// <param name="policy">How clients are grouped (IP, user, client id, global).</param>
        /// <param name="requestInfo">Request details used in the rejection log entry.</param>
        /// <param name="resetTimeUtc">Receives the UTC time at which the current window ends.</param>
        /// <returns>
        /// <c>true</c> if the request is allowed and was counted; <c>false</c> if the window's
        /// budget is already used up (the rejected request is not counted).
        /// </returns>
        private bool CheckLimit(
            HttpContext context,
            string endpointKey,
            int limit,
            int period,
            RateLimitPolicy policy,
            RequestInfo requestInfo,
            out DateTime resetTimeUtc)
        {
            var identifier = GetIdentifier(context, policy);
            var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
            var now = DateTime.UtcNow;
            var windowStart = now.AddSeconds(-period);

            // Keep the entry one second longer than the window. Clamped so a non-positive period
            // cannot yield a non-positive relative expiration, which the memory cache rejects
            // with an exception.
            var entryLifetime = TimeSpan.FromSeconds(Math.Max(period, 0) + 1.0);

            CounterOutcome outcome;
            int previousCount;
            int currentCount;
            DateTime windowOpenedAt;

            lock (_counterLock)
            {
                if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter) || counter == null)
                {
                    // First request for this rule/client (or the previous entry expired).
                    counter = new RateLimitCounter
                    {
                        Count = 1,
                        FirstRequestTime = now,
                        LastRequestTime = now
                    };
                    previousCount = 0;
                    outcome = CounterOutcome.Created;
                }
                else if (counter.FirstRequestTime < windowStart)
                {
                    // The window has elapsed: this request opens a new one.
                    previousCount = counter.Count;
                    counter.Count = 1;
                    counter.FirstRequestTime = now;
                    counter.LastRequestTime = now;
                    outcome = CounterOutcome.Reset;
                }
                else if (counter.Count >= limit)
                {
                    previousCount = counter.Count;
                    outcome = CounterOutcome.Rejected;
                }
                else
                {
                    previousCount = counter.Count;
                    counter.Count++;
                    counter.LastRequestTime = now;
                    outcome = CounterOutcome.Incremented;
                }

                currentCount = counter.Count;
                windowOpenedAt = counter.FirstRequestTime;

                // Rejected requests are not counted and do not extend the entry's lifetime.
                if (outcome != CounterOutcome.Rejected)
                {
                    _cache.Set(cacheKey, counter, entryLifetime);
                }
            }

            resetTimeUtc = windowOpenedAt.AddSeconds(period);

            // Logging happens outside the lock so slow log sinks cannot stall other requests.
            switch (outcome)
            {
                case CounterOutcome.Created:
                    _logger.LogDebug("新建限流计数器: {CacheKey}, 计数: 1/{Limit}", cacheKey, limit);
                    return true;

                case CounterOutcome.Reset:
                    _logger.LogDebug("重置限流计数器: {CacheKey}, 旧计数: {OldCount}, 新计数: 1/{Limit}", cacheKey, previousCount, limit);
                    return true;

                case CounterOutcome.Incremented:
                    _logger.LogDebug("增加限流计数: {CacheKey}, {OldCount} -> {NewCount}/{Limit}", cacheKey, previousCount, currentCount, limit);
                    return true;

                default:
                    {
                        var remainingSeconds = Math.Max(0, (int)Math.Ceiling((resetTimeUtc - now).TotalSeconds));

                        // One structured entry per rejection (rather than one entry per field)
                        // keeps the fields together and limits log volume under load.
                        _logger.LogWarning(
                            "🚫 限流触发: {Method} {Path} | 客户端IP: {ClientIp} | 标识符: {Identifier} | 规则: {Rule} | 当前计数: {Count}/{Limit} | 窗口开始: {WindowStart:HH:mm:ss} | 剩余时间: {RemainingSeconds}秒 | 用户代理: {UserAgent}",
                            requestInfo.Method,
                            requestInfo.Path,
                            requestInfo.ClientIp,
                            identifier,
                            endpointKey,
                            currentCount,
                            limit,
                            windowOpenedAt,
                            remainingSeconds,
                            context.Request.Headers["User-Agent"].ToString());
                        return false;
                    }
            }
        }

        /// <summary>
        /// Returns the bucket identifier for the request under <paramref name="policy"/>.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="policy">Grouping policy of the matched rule.</param>
        /// <returns>An identifier prefixed with "user_", "client_", "ip_", or the literal "global".</returns>
        private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
        {
            switch (policy)
            {
                case RateLimitPolicy.User:
                    {
                        var userId = GetUserId(context);
                        // Anonymous users (or principals without any id claim) fall back to their IP.
                        return string.IsNullOrEmpty(userId) ? GetRemoteIpIdentifier(context) : $"user_{userId}";
                    }

                case RateLimitPolicy.Global:
                    return "global";

                case RateLimitPolicy.Client:
                    {
                        // These values come straight from request headers and are not validated here.
                        var clientId = FirstNonEmptyHeader(context, "X-Client-Id")
                            ?? FirstNonEmptyHeader(context, "X-API-Key");
                        return $"client_{clientId ?? "unknown"}";
                    }

                default:
                    return GetRemoteIpIdentifier(context);
            }
        }

        /// <summary>
        /// Returns the authenticated user's id ("sub", then NameIdentifier, then Identity.Name),
        /// or <c>null</c> for anonymous requests.
        /// </summary>
        private static string GetUserId(HttpContext context)
        {
            var user = context.User;
            if (user?.Identity?.IsAuthenticated != true)
            {
                return null;
            }

            return user.FindFirst("sub")?.Value
                ?? user.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                ?? user.Identity.Name;
        }

        /// <summary>
        /// Returns the first value of a request header, or <c>null</c> when it is missing or empty.
        /// </summary>
        private static string FirstNonEmptyHeader(HttpContext context, string headerName)
        {
            var value = context.Request.Headers[headerName].FirstOrDefault();
            return string.IsNullOrEmpty(value) ? null : value;
        }

        /// <summary>
        /// Builds the "ip_..." identifier from the connection's remote address. Forwarded headers
        /// are not consulted here because the client can set them freely. Behind a reverse proxy
        /// this is the proxy's address unless forwarded-headers handling has already updated it
        /// earlier in the pipeline.
        /// </summary>
        private static string GetRemoteIpIdentifier(HttpContext context)
        {
            var remoteIp = context.Connection.RemoteIpAddress;
            if (remoteIp == null)
            {
                return "ip_unknown";
            }

            // "::ffff:a.b.c.d" and "a.b.c.d" are the same IPv4 client and share one counter.
            if (remoteIp.IsIPv4MappedToIPv6)
            {
                remoteIp = remoteIp.MapToIPv4();
            }

            return $"ip_{remoteIp}";
        }

        /// <summary>
        /// Writes the HTTP 429 response: rate-limit headers plus a <c>JsonView</c> JSON body.
        /// </summary>
        /// <param name="context">Current HTTP context; the response must not have started yet.</param>
        /// <param name="requestInfo">Request details echoed in the body and the log.</param>
        /// <param name="ruleInfo">Human-readable description of the exceeded rule.</param>
        /// <param name="limit">Limit of the exceeded rule (reported as X-RateLimit-Limit).</param>
        /// <param name="resetTimeUtc">End of the current window (X-RateLimit-Reset / Retry-After).</param>
        private async Task ReturnRateLimitedResponse(
            HttpContext context,
            RequestInfo requestInfo,
            string ruleInfo,
            int limit,
            DateTime resetTimeUtc)
        {
            var retryAfterSeconds = Math.Max(1, (int)Math.Ceiling((resetTimeUtc - DateTime.UtcNow).TotalSeconds));

            context.Response.StatusCode = 429;
            context.Response.ContentType = "application/json; charset=utf-8";

            // Use the indexer: Headers.Add throws if the header is already present.
            var headers = context.Response.Headers;
            headers["X-RateLimit-Limit"] = limit.ToString(System.Globalization.CultureInfo.InvariantCulture);
            headers["X-RateLimit-Remaining"] = "0";
            headers["X-RateLimit-Reset"] = resetTimeUtc.ToString("R", System.Globalization.CultureInfo.InvariantCulture);
            headers["Retry-After"] = retryAfterSeconds.ToString(System.Globalization.CultureInfo.InvariantCulture);
            // Header values must be ASCII; the raw rule text may contain Chinese or a non-ASCII path.
            headers["X-RateLimit-Rule"] = ToAsciiHeaderValue(ruleInfo);

            var response = new JsonView
            {
                Code = 429,
                Msg = $"请求过于频繁({ruleInfo}),请稍后再试",
                Count = 0,
                Data = new
                {
                    Path = requestInfo.Path,
                    Method = requestInfo.Method,
                    ClientIp = requestInfo.ClientIp,
                    Rule = ruleInfo
                }
            };

            var json = System.Text.Json.JsonSerializer.Serialize(response, _jsonOptions);
            await context.Response.WriteAsync(json, context.RequestAborted);

            _logger.LogInformation("📤 限流响应已发送: {Method} {Path} -> 429", requestInfo.Method, requestInfo.Path);
        }

        /// <summary>
        /// Makes a string safe to use as an HTTP response header value. Printable ASCII is kept
        /// as-is; every other character (non-ASCII text, control characters) is percent-encoded
        /// as UTF-8.
        /// </summary>
        /// <param name="value">Arbitrary text; <c>null</c> yields an empty string.</param>
        private static string ToAsciiHeaderValue(string value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return string.Empty;
            }

            var builder = new System.Text.StringBuilder(value.Length);
            Span<byte> utf8 = stackalloc byte[4];
            // EnumerateRunes keeps surrogate pairs together (lone surrogates become U+FFFD).
            foreach (var rune in value.EnumerateRunes())
            {
                if (rune.Value >= 0x20 && rune.Value < 0x7F)
                {
                    builder.Append((char)rune.Value);
                    continue;
                }

                var byteCount = rune.EncodeToUtf8(utf8);
                for (var i = 0; i < byteCount; i++)
                {
                    builder.Append('%').Append(utf8[i].ToString("X2", System.Globalization.CultureInfo.InvariantCulture));
                }
            }

            return builder.ToString();
        }

        /// <summary>
        /// Returns the first configured endpoint rule matching the path and method, or <c>null</c>.
        /// </summary>
        private EndpointRateLimit FindMatchingConfig(string path, string method)
        {
            return _config.Endpoints?.FirstOrDefault(rule =>
                rule != null
                && IsPathMatch(path, rule.Path)
                && IsMethodMatch(method, rule.Method));
        }

        /// <summary>
        /// Case-insensitive path matching. "*" matches every path. A value ending in "/" matches
        /// as a prefix. A "*" inside the value matches any character sequence (including "/").
        /// Anything else must match exactly.
        /// </summary>
        private bool IsPathMatch(string requestPath, string configPath)
        {
            if (configPath == null)
            {
                return false;
            }

            if (configPath == "*")
            {
                return true;
            }

            // Exact match.
            if (string.Equals(requestPath, configPath, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            // Prefix match: "/api/user/" matches "/api/user/list".
            if (configPath.EndsWith("/", StringComparison.Ordinal)
                && requestPath.StartsWith(configPath, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            // Wildcard match: every "*" becomes ".*"; all other characters are matched literally.
            if (configPath.Contains('*'))
            {
                var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath).Replace("\\*", ".*") + "$";
                return System.Text.RegularExpressions.Regex.IsMatch(
                    requestPath,
                    pattern,
                    System.Text.RegularExpressions.RegexOptions.IgnoreCase
                    | System.Text.RegularExpressions.RegexOptions.CultureInvariant);
            }

            return false;
        }

        /// <summary>
        /// "*" matches every method; otherwise methods are compared case-insensitively.
        /// </summary>
        private bool IsMethodMatch(string requestMethod, string configMethod)
        {
            return configMethod == "*"
                || string.Equals(requestMethod, configMethod, StringComparison.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Returns true for health checks, Swagger and well-known static paths, which are never limited.
        /// </summary>
        private bool ShouldSkipRateLimit(string path) =>
            SkipPathPrefixes.Any(prefix => path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase));

        /// <summary>
        /// Best-effort client IP for logging and the 429 body (it is not the rate-limit key).
        /// </summary>
        /// <remarks>
        /// The first proxy header (see <see cref="ClientIpHeaders"/>) holding a parseable, non-private
        /// address wins. For X-Forwarded-For ("client, proxy1, proxy2") only the first entry is used.
        /// These headers are client-controllable. Otherwise the connection's remote address is used,
        /// with IPv4-mapped IPv6 unwrapped and "::1" reported as "127.0.0.1".
        /// </remarks>
        private string GetClientIp(HttpContext context)
        {
            foreach (var headerName in ClientIpHeaders)
            {
                var headerValue = context.Request.Headers[headerName].FirstOrDefault();
                if (string.IsNullOrEmpty(headerValue))
                {
                    continue;
                }

                var candidate = headerName == "X-Forwarded-For"
                    ? headerValue.Split(',')[0].Trim()
                    : headerValue.Trim();
                _logger.LogDebug("从 {HeaderName} 获取IP: {IpAddress}", headerName, candidate);

                if (IsValidIpAddress(candidate))
                {
                    return candidate;
                }
            }

            var remoteIp = context.Connection.RemoteIpAddress;
            string ipAddress;
            if (remoteIp == null)
            {
                ipAddress = "unknown";
            }
            else
            {
                if (remoteIp.IsIPv4MappedToIPv6)
                {
                    remoteIp = remoteIp.MapToIPv4();
                }

                ipAddress = System.Net.IPAddress.IPv6Loopback.Equals(remoteIp) ? "127.0.0.1" : remoteIp.ToString();
            }

            _logger.LogDebug("从 RemoteIpAddress 获取IP: {IpAddress}", ipAddress);
            return ipAddress;
        }

        /// <summary>
        /// True when <paramref name="ip"/> parses as an IP address that is not private/loopback/link-local.
        /// </summary>
        private bool IsValidIpAddress(string ip)
        {
            if (string.IsNullOrWhiteSpace(ip))
            {
                return false;
            }

            // Explicitly reject the "unknown" placeholder.
            if (ip.Equals("unknown", StringComparison.OrdinalIgnoreCase))
            {
                return false;
            }

            return System.Net.IPAddress.TryParse(ip, out var address) && !IsPrivateIp(address);
        }

        /// <summary>
        /// True for loopback, private-range, link-local, unique-local and multicast (IPv6) addresses.
        /// </summary>
        private bool IsPrivateIp(System.Net.IPAddress ipAddress)
        {
            // Treat "::ffff:a.b.c.d" like the IPv4 address it wraps.
            if (ipAddress.IsIPv4MappedToIPv6)
            {
                ipAddress = ipAddress.MapToIPv4();
            }

            var bytes = ipAddress.GetAddressBytes();

            if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork)
            {
                return bytes[0] == 10                                        // 10.0.0.0/8
                    || (bytes[0] == 172 && bytes[1] >= 16 && bytes[1] <= 31) // 172.16.0.0/12
                    || (bytes[0] == 192 && bytes[1] == 168)                  // 192.168.0.0/16
                    || bytes[0] == 127                                       // 127.0.0.0/8 loopback
                    || (bytes[0] == 169 && bytes[1] == 254);                 // 169.254.0.0/16 link-local
            }

            if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6)
            {
                return ipAddress.IsIPv6LinkLocal
                    || ipAddress.IsIPv6SiteLocal
                    || ipAddress.IsIPv6Multicast
                    || (bytes[0] & 0xFE) == 0xFC                             // fc00::/7 unique local
                    || System.Net.IPAddress.IPv6Loopback.Equals(ipAddress);
            }

            return false;
        }

        /// <summary>
        /// Per-request details captured for logging and the 429 response body.
        /// </summary>
        private class RequestInfo
        {
            public string Path { get; set; }
            public string Method { get; set; }
            public string ClientIp { get; set; }
            public DateTime Timestamp { get; set; }
        }

        /// <summary>
        /// What a single CheckLimit call did to the counter (used for logging).
        /// </summary>
        private enum CounterOutcome
        {
            Created,
            Reset,
            Incremented,
            Rejected
        }

        /// <summary>
        /// Fixed-window counter cached per rule and client identifier.
        /// </summary>
        public class RateLimitCounter
        {
            /// <summary>Requests counted in the current window.</summary>
            public int Count { get; set; }

            /// <summary>UTC time of the request that opened the current window.</summary>
            public DateTime FirstRequestTime { get; set; }

            /// <summary>UTC time of the most recent counted (allowed) request.</summary>
            public DateTime LastRequestTime { get; set; }
        }

        #region Configuration models

        /// <summary>
        /// How requests are grouped into rate-limit buckets.
        /// </summary>
        public enum RateLimitPolicy
        {
            /// <summary>
            /// Per client IP address (the connection's remote address).
            /// </summary>
            IP = 0,

            /// <summary>
            /// Per authenticated user; anonymous requests fall back to the IP bucket.
            /// </summary>
            User = 1,

            /// <summary>
            /// One bucket shared by all callers.
            /// </summary>
            Global = 2,

            /// <summary>
            /// Per client id taken from the X-Client-Id or X-API-Key request header.
            /// </summary>
            Client = 3
        }

        /// <summary>
        /// Options for <see cref="RateLimitMiddleware"/>.
        /// </summary>
        public class RateLimitConfig
        {
            /// <summary>When false, every request is passed through unchecked.</summary>
            public bool Enabled { get; set; } = true;

            /// <summary>Requests allowed per window when no endpoint rule matches (default 5).</summary>
            public int DefaultLimit { get; set; } = 5;

            /// <summary>Window length in seconds for the default rule (default 1).</summary>
            public int DefaultPeriod { get; set; } = 1;

            /// <summary>Endpoint rules, evaluated in order; the first match wins.</summary>
            [JsonPropertyName("Endpoints")]
            public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
        }

        /// <summary>
        /// Rate-limit rule for requests matching a path pattern and an HTTP method.
        /// </summary>
        public class EndpointRateLimit
        {
            /// <summary>
            /// Path pattern, compared case-insensitively. "*" matches every path. A value ending
            /// in "/" is a prefix. A "*" inside the value is a wildcard. Anything else must match exactly.
            /// </summary>
            [JsonPropertyName("Path")]
            public string Path { get; set; } = "*";

            /// <summary>HTTP method (case-insensitive), or "*" for any method.</summary>
            [JsonPropertyName("Method")]
            public string Method { get; set; } = "*";

            /// <summary>Requests allowed per window.</summary>
            [JsonPropertyName("Limit")]
            public int Limit { get; set; } = 5;

            /// <summary>Window length in seconds.</summary>
            [JsonPropertyName("Period")]
            public int Period { get; set; } = 1;

            /// <summary>How requests are grouped into buckets for this rule.</summary>
            [JsonPropertyName("Policy")]
            public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;

            /// <summary>Normalized "METHOD:path" key that identifies this rule's counters.</summary>
            [System.Text.Json.Serialization.JsonIgnore]
            public string EndpointKey => $"{Method?.ToUpperInvariant()}:{Path?.ToLowerInvariant()}";
        }

        #endregion
    }
}
|