RateLimitMiddleware.cs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. using IdentityModel.OidcClient;
  2. using Microsoft.Extensions.Caching.Memory;
  3. using Microsoft.Extensions.Options;
  4. using System.Text.Json;
  5. using System.Text.Json.Serialization;
  6. namespace OASystem.API.Middlewares
  7. {
  8. public class RateLimitMiddleware
  9. {
  10. private readonly RequestDelegate _next;
  11. private readonly IMemoryCache _cache;
  12. private readonly RateLimitConfig _config;
  13. private readonly ILogger<RateLimitMiddleware> _logger;
  14. private readonly JsonSerializerOptions _jsonOptions;
  15. public RateLimitMiddleware(
  16. RequestDelegate next,
  17. IMemoryCache cache,
  18. IOptions<RateLimitConfig> config,
  19. ILogger<RateLimitMiddleware> logger)
  20. {
  21. _next = next;
  22. _cache = cache;
  23. _config = config?.Value ?? new RateLimitConfig(); // 获取配置值
  24. _logger = logger;
  25. // 配置JSON序列化选项
  26. _jsonOptions = new JsonSerializerOptions
  27. {
  28. Encoder = System.Text.Encodings.Web.JavaScriptEncoder.UnsafeRelaxedJsonEscaping, // 允许中文
  29. PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
  30. WriteIndented = false
  31. };
  32. // 日志输出配置信息,用于调试
  33. LogConfiguration();
  34. }
  35. private void LogConfiguration()
  36. {
  37. _logger.LogInformation("=== 限流中间件初始化 ===");
  38. _logger.LogInformation($"启用状态: {_config.Enabled}");
  39. _logger.LogInformation($"默认限制: {_config.DefaultLimit}次/{_config.DefaultPeriod}秒");
  40. if (_config.Endpoints != null && _config.Endpoints.Any())
  41. {
  42. _logger.LogInformation($"配置端点数量: {_config.Endpoints.Count}");
  43. foreach (var endpoint in _config.Endpoints.Take(5)) // 只显示前5个
  44. {
  45. _logger.LogInformation($" {endpoint.Method} {endpoint.Path} -> {endpoint.Limit}次/{endpoint.Period}秒");
  46. }
  47. if (_config.Endpoints.Count > 5)
  48. {
  49. _logger.LogInformation($" ... 还有{_config.Endpoints.Count - 5}个配置");
  50. }
  51. }
  52. else
  53. {
  54. _logger.LogWarning("未配置具体端点,将使用默认规则");
  55. }
  56. _logger.LogInformation("========================");
  57. }
  58. public async Task InvokeAsync(HttpContext context)
  59. {
  60. // 记录请求开始信息(调试用)
  61. var requestInfo = new RequestInfo
  62. {
  63. Path = context.Request.Path,
  64. Method = context.Request.Method,
  65. ClientIp = GetClientIp(context),
  66. Timestamp = DateTime.UtcNow
  67. };
  68. _logger.LogDebug($"收到请求: {requestInfo.Method} {requestInfo.Path} from {requestInfo.ClientIp}");
  69. if (!_config.Enabled)
  70. {
  71. _logger.LogDebug("限流功能已禁用,跳过检查");
  72. await _next(context);
  73. return;
  74. }
  75. var path = context.Request.Path.ToString();
  76. var method = context.Request.Method.ToUpper();
  77. // 跳过不需要限流的路径
  78. if (ShouldSkipRateLimit(path))
  79. {
  80. _logger.LogDebug($"跳过限流检查: {path}");
  81. await _next(context);
  82. return;
  83. }
  84. // 查找匹配的配置
  85. var endpointConfig = FindMatchingConfig(path, method);
  86. if (endpointConfig == null)
  87. {
  88. // 使用默认配置
  89. _logger.LogDebug($"使用默认限流配置: {method} {path}");
  90. if (!CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP, requestInfo))
  91. {
  92. await ReturnRateLimitedResponse(context, requestInfo, "默认规则");
  93. return;
  94. }
  95. }
  96. else
  97. {
  98. var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
  99. _logger.LogDebug($"匹配限流配置: {endpointKey} -> {endpointConfig.Limit}次/{endpointConfig.Period}秒");
  100. if (!CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy, requestInfo))
  101. {
  102. await ReturnRateLimitedResponse(context, requestInfo, $"{method} {path}");
  103. return;
  104. }
  105. }
  106. _logger.LogDebug($"请求通过限流检查: {method} {path}");
  107. await _next(context);
  108. }
  109. private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period, RateLimitPolicy policy, RequestInfo requestInfo)
  110. {
  111. var identifier = GetIdentifier(context, policy);
  112. var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
  113. var now = DateTime.UtcNow;
  114. var windowStart = now.AddSeconds(-period);
  115. if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter))
  116. {
  117. // 首次请求
  118. counter = new RateLimitCounter
  119. {
  120. Count = 1,
  121. FirstRequestTime = now,
  122. LastRequestTime = now
  123. };
  124. _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period + 1));
  125. _logger.LogDebug($"新建限流计数器: {cacheKey}, 计数: 1/{limit}");
  126. return true;
  127. }
  128. // 如果第一个请求已经超过时间窗口,重置计数器
  129. if (counter.FirstRequestTime < windowStart)
  130. {
  131. var oldCount = counter.Count;
  132. counter.Count = 1;
  133. counter.FirstRequestTime = now;
  134. _logger.LogDebug($"重置限流计数器: {cacheKey}, 旧计数: {oldCount}, 新计数: 1/{limit}");
  135. }
  136. else if (counter.Count >= limit)
  137. {
  138. // 触发限流!
  139. var remainingSeconds = (int)(counter.FirstRequestTime.AddSeconds(period) - now).TotalSeconds;
  140. // 详细记录限流信息
  141. _logger.LogWarning($"🚫 限流触发: {requestInfo.Method} {requestInfo.Path}");
  142. _logger.LogWarning($" 客户端IP: {requestInfo.ClientIp}");
  143. _logger.LogWarning($" 标识符: {identifier}");
  144. _logger.LogWarning($" 规则: {endpointKey}");
  145. _logger.LogWarning($" 当前计数: {counter.Count}/{limit}");
  146. _logger.LogWarning($" 窗口开始: {counter.FirstRequestTime:HH:mm:ss}");
  147. _logger.LogWarning($" 剩余时间: {remainingSeconds}秒");
  148. _logger.LogWarning($" 用户代理: {context.Request.Headers["User-Agent"]}");
  149. return false;
  150. }
  151. else
  152. {
  153. // 正常计数增加
  154. var oldCount = counter.Count;
  155. counter.Count++;
  156. _logger.LogDebug($"增加限流计数: {cacheKey}, {oldCount} -> {counter.Count}/{limit}");
  157. }
  158. counter.LastRequestTime = now;
  159. _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period + 1));
  160. return true;
  161. }
  162. private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
  163. {
  164. switch (policy)
  165. {
  166. case RateLimitPolicy.User:
  167. if (context.User?.Identity?.IsAuthenticated == true)
  168. {
  169. var userId = context.User.FindFirst("sub")?.Value
  170. ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
  171. ?? context.User.Identity.Name;
  172. return $"user_{userId}";
  173. }
  174. // 如果未登录,回退到IP
  175. return $"ip_{context.Connection.RemoteIpAddress}";
  176. case RateLimitPolicy.Global:
  177. return "global";
  178. case RateLimitPolicy.Client:
  179. var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault()
  180. ?? context.Request.Headers["X-API-Key"].FirstOrDefault();
  181. return $"client_{clientId ?? "unknown"}";
  182. case RateLimitPolicy.IP:
  183. default:
  184. var ipAddress = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
  185. // 处理IPv6映射的IPv4地址
  186. if (ipAddress.Contains("::ffff:"))
  187. {
  188. ipAddress = ipAddress.Replace("::ffff:", "");
  189. }
  190. return $"ip_{ipAddress}";
  191. }
  192. }
  193. private async Task ReturnRateLimitedResponse(HttpContext context, RequestInfo requestInfo, string ruleInfo)
  194. {
  195. context.Response.StatusCode = 429;
  196. context.Response.ContentType = "application/json; charset=utf-8";
  197. // 添加响应头信息
  198. context.Response.Headers.Add("X-RateLimit-Limit", "5");
  199. context.Response.Headers.Add("X-RateLimit-Remaining", "0");
  200. context.Response.Headers.Add("X-RateLimit-Reset", DateTime.UtcNow.AddSeconds(1).ToString("R"));
  201. context.Response.Headers.Add("X-RateLimit-Rule", ruleInfo);
  202. var response = new JsonView()
  203. {
  204. Code = 429,
  205. Msg = $"请求过于频繁({ruleInfo}),请稍后再试",
  206. Count = 0,
  207. Data = new
  208. {
  209. Path = requestInfo.Path,
  210. Method = requestInfo.Method,
  211. ClientIp = requestInfo.ClientIp,
  212. Rule = ruleInfo
  213. }
  214. };
  215. var json = System.Text.Json.JsonSerializer.Serialize(response, _jsonOptions);
  216. await context.Response.WriteAsync(json);
  217. // 额外记录一次限流响应日志
  218. _logger.LogInformation($"📤 限流响应已发送: {requestInfo.Method} {requestInfo.Path} -> 429");
  219. }
  220. private EndpointRateLimit FindMatchingConfig(string path, string method)
  221. {
  222. if (_config.Endpoints == null || !_config.Endpoints.Any())
  223. return null;
  224. foreach (var configItem in _config.Endpoints)
  225. {
  226. if (IsPathMatch(path, configItem.Path) &&
  227. IsMethodMatch(method, configItem.Method))
  228. {
  229. return configItem;
  230. }
  231. }
  232. return null;
  233. }
  234. private bool IsPathMatch(string requestPath, string configPath)
  235. {
  236. if (configPath == "*") return true;
  237. var normalizedPath = requestPath.ToLower();
  238. var normalizedConfigPath = configPath.ToLower();
  239. // 精确匹配
  240. if (normalizedPath.Equals(normalizedConfigPath))
  241. return true;
  242. // 前缀匹配(以/结尾表示前缀匹配)
  243. if (normalizedConfigPath.EndsWith("/") && normalizedPath.StartsWith(normalizedConfigPath))
  244. return true;
  245. // 通配符匹配(简单的*通配符)
  246. if (normalizedConfigPath.Contains("*"))
  247. {
  248. var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(normalizedConfigPath)
  249. .Replace("\\*", ".*") + "$";
  250. return System.Text.RegularExpressions.Regex.IsMatch(normalizedPath, pattern);
  251. }
  252. return false;
  253. }
  254. private bool IsMethodMatch(string requestMethod, string configMethod)
  255. {
  256. if (configMethod == "*") return true;
  257. return requestMethod.Equals(configMethod, StringComparison.OrdinalIgnoreCase);
  258. }
  259. private bool ShouldSkipRateLimit(string path)
  260. {
  261. // 跳过健康检查、swagger等
  262. var skipPaths = new[]
  263. {
  264. "/health",
  265. "/swagger",
  266. "/favicon.ico",
  267. "/robots.txt",
  268. "/.well-known"
  269. };
  270. return skipPaths.Any(p => path.StartsWith(p, StringComparison.OrdinalIgnoreCase));
  271. }
  272. private string GetClientIp(HttpContext context)
  273. {
  274. // 方法1:获取真实IP地址(处理反向代理)
  275. string ipAddress = string.Empty;
  276. // 优先从 X-Forwarded-For 获取(如果有反向代理如Nginx)
  277. var forwardedFor = context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
  278. if (!string.IsNullOrEmpty(forwardedFor))
  279. {
  280. // X-Forwarded-For 可能是多个IP(client, proxy1, proxy2)
  281. ipAddress = forwardedFor.Split(',')[0].Trim();
  282. _logger.LogDebug($"从 X-Forwarded-For 获取IP: {ipAddress}");
  283. // 验证IP格式
  284. if (IsValidIpAddress(ipAddress))
  285. {
  286. return ipAddress;
  287. }
  288. }
  289. // 从 X-Real-IP 获取
  290. var realIp = context.Request.Headers["X-Real-IP"].FirstOrDefault();
  291. if (!string.IsNullOrEmpty(realIp))
  292. {
  293. ipAddress = realIp.Trim();
  294. _logger.LogDebug($"从 X-Real-IP 获取IP: {ipAddress}");
  295. if (IsValidIpAddress(ipAddress))
  296. {
  297. return ipAddress;
  298. }
  299. }
  300. // 从 CF-Connecting-IP 获取(Cloudflare)
  301. var cfConnectingIp = context.Request.Headers["CF-Connecting-IP"].FirstOrDefault();
  302. if (!string.IsNullOrEmpty(cfConnectingIp))
  303. {
  304. ipAddress = cfConnectingIp.Trim();
  305. _logger.LogDebug($"从 CF-Connecting-IP 获取IP: {ipAddress}");
  306. if (IsValidIpAddress(ipAddress))
  307. {
  308. return ipAddress;
  309. }
  310. }
  311. // 从 X-Original-For 获取
  312. var originalFor = context.Request.Headers["X-Original-For"].FirstOrDefault();
  313. if (!string.IsNullOrEmpty(originalFor))
  314. {
  315. ipAddress = originalFor.Trim();
  316. _logger.LogDebug($"从 X-Original-For 获取IP: {ipAddress}");
  317. if (IsValidIpAddress(ipAddress))
  318. {
  319. return ipAddress;
  320. }
  321. }
  322. // 最后使用 RemoteIpAddress
  323. ipAddress = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
  324. // 处理IPv6映射的IPv4地址
  325. if (ipAddress.Contains("::ffff:"))
  326. {
  327. ipAddress = ipAddress.Replace("::ffff:", "");
  328. }
  329. // 处理IPv6本地地址
  330. if (ipAddress == "::1" || ipAddress == "0:0:0:0:0:0:0:1")
  331. {
  332. ipAddress = "127.0.0.1";
  333. }
  334. _logger.LogDebug($"从 RemoteIpAddress 获取IP: {ipAddress}");
  335. return ipAddress;
  336. }
  337. private bool IsValidIpAddress(string ip)
  338. {
  339. if (string.IsNullOrWhiteSpace(ip))
  340. return false;
  341. if (ip.Equals("unknown", StringComparison.OrdinalIgnoreCase))
  342. return false;
  343. // 简单验证IP格式
  344. if (System.Net.IPAddress.TryParse(ip, out var ipAddress))
  345. {
  346. // 排除私有IP地址(如果需要)
  347. return !IsPrivateIp(ipAddress);
  348. //return true;
  349. }
  350. return false;
  351. }
  352. private bool IsPrivateIp(System.Net.IPAddress ipAddress)
  353. {
  354. // 检查是否为内网IP
  355. if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetwork) // IPv4
  356. {
  357. var bytes = ipAddress.GetAddressBytes();
  358. // 10.0.0.0/8
  359. if (bytes[0] == 10)
  360. return true;
  361. // 172.16.0.0/12
  362. if (bytes[0] == 172 && bytes[1] >= 16 && bytes[1] <= 31)
  363. return true;
  364. // 192.168.0.0/16
  365. if (bytes[0] == 192 && bytes[1] == 168)
  366. return true;
  367. // 127.0.0.0/8
  368. if (bytes[0] == 127)
  369. return true;
  370. }
  371. else if (ipAddress.AddressFamily == System.Net.Sockets.AddressFamily.InterNetworkV6) // IPv6
  372. {
  373. // IPv6 本地地址
  374. if (ipAddress.IsIPv6LinkLocal || ipAddress.IsIPv6SiteLocal ||
  375. ipAddress.IsIPv6Multicast || ipAddress.Equals(System.Net.IPAddress.IPv6Loopback))
  376. return true;
  377. }
  378. return false;
  379. }
  380. // 添加 RequestInfo 辅助类
  381. private class RequestInfo
  382. {
  383. public string Path { get; set; }
  384. public string Method { get; set; }
  385. public string ClientIp { get; set; }
  386. public DateTime Timestamp { get; set; }
  387. }
  388. public class RateLimitCounter
  389. {
  390. public int Count { get; set; }
  391. public DateTime FirstRequestTime { get; set; }
  392. public DateTime LastRequestTime { get; set; }
  393. }
  394. #region 配置模型
  395. /// <summary>
  396. /// 限流策略枚举
  397. /// </summary>
  398. public enum RateLimitPolicy
  399. {
  400. /// <summary>
  401. /// 按IP地址限流
  402. /// </summary>
  403. IP = 0,
  404. /// <summary>
  405. /// 按用户限流(需要用户登录)
  406. /// </summary>
  407. User = 1,
  408. /// <summary>
  409. /// 全局限流(所有用户共享限制)
  410. /// </summary>
  411. Global = 2,
  412. /// <summary>
  413. /// 按客户端ID限流
  414. /// </summary>
  415. Client = 3
  416. }
  417. public class RateLimitConfig
  418. {
  419. public bool Enabled { get; set; } = true;
  420. public int DefaultLimit { get; set; } = 5; // 默认5次/秒
  421. public int DefaultPeriod { get; set; } = 1; // 默认时间窗口1秒
  422. [JsonPropertyName("Endpoints")]
  423. public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
  424. }
  425. public class EndpointRateLimit
  426. {
  427. [JsonPropertyName("Path")]
  428. public string Path { get; set; } = "*";
  429. [JsonPropertyName("Method")]
  430. public string Method { get; set; } = "*";
  431. [JsonPropertyName("Limit")]
  432. public int Limit { get; set; } = 5;
  433. [JsonPropertyName("Period")]
  434. public int Period { get; set; } = 1;
  435. [JsonPropertyName("Policy")]
  436. public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
  437. [System.Text.Json.Serialization.JsonIgnore]
  438. public string EndpointKey => $"{Method.ToUpper()}:{Path.ToLower()}";
  439. }
  440. #endregion
  441. }
  442. }