RateLimitMiddleware.cs 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. // Middlewares/RateLimitMiddleware.cs
  2. using AspNetCoreRateLimit;
  3. using Microsoft.AspNetCore.Http;
  4. using Microsoft.Extensions.Caching.Memory;
  5. using Microsoft.Extensions.Options;
  6. using System;
  7. using System.Linq;
  8. using System.Threading.Tasks;
  9. namespace OASystem.API.Middlewares
  10. {
  /// <summary>
  /// ASP.NET Core middleware that applies a fixed-window request counter per key
  /// (the window for a key starts at its first request and lasts <c>Period</c> seconds).
  /// Each request is matched against <see cref="RateLimitConfig.Endpoints"/> (first match wins);
  /// unmatched requests fall back to <see cref="RateLimitConfig.DefaultLimit"/> /
  /// <see cref="RateLimitConfig.DefaultPeriod"/> keyed by client IP.
  /// Rejected requests receive HTTP 429 with a JSON body and a <c>Retry-After</c> header.
  /// </summary>
  /// <remarks>
  /// Counters live in <see cref="IMemoryCache"/>, i.e. in-process: when several instances run
  /// behind a load balancer, each instance counts independently.
  /// </remarks>
  public class RateLimitMiddleware
  {
      private readonly RequestDelegate _next;
      private readonly IMemoryCache _cache;
      private readonly RateLimitConfig _config;
      private readonly ILogger<RateLimitMiddleware> _logger;

      // Convention-based middleware (constructor taking RequestDelegate) is created once per
      // pipeline and shared by all concurrent requests, so the read-modify-write on a cached
      // counter must be serialized or concurrent requests lose increments.
      private readonly object _counterGate = new object();

      // Path prefixes (matched case-insensitively) that are never rate limited.
      // Hoisted to a static field so the array is not re-allocated on every request.
      private static readonly string[] SkipPathPrefixes =
      {
          "/health",
          // "/swagger" is intentionally not skipped.
          "/favicon.ico",
          "/robots.txt",
          "/.well-known"
      };

      /// <summary>Creates the middleware.</summary>
      /// <param name="next">The next delegate in the pipeline.</param>
      /// <param name="cache">Cache that stores the per-key <see cref="RateLimitCounter"/> instances.</param>
      /// <param name="config">Rate-limit options; the value is read once here.</param>
      /// <param name="logger">Logger used to report rejected requests.</param>
      public RateLimitMiddleware(
          RequestDelegate next,
          IMemoryCache cache,
          IOptions<RateLimitConfig> config,
          ILogger<RateLimitMiddleware> logger)
      {
          _next = next;
          _cache = cache;
          _config = config.Value;
          _logger = logger;
      }

      /// <summary>
      /// Checks the request against the matching rule and either forwards it to the next
      /// middleware or short-circuits with a 429 response.
      /// </summary>
      /// <param name="context">The current HTTP context.</param>
      public async Task InvokeAsync(HttpContext context)
      {
          if (!_config.Enabled)
          {
              await _next(context);
              return;
          }

          // Invariant casing: culture-sensitive ToLower/ToUpper can yield different strings
          // under some cultures (e.g. Turkish dotted/dotless i) and break path matching.
          var path = context.Request.Path.ToString().ToLowerInvariant();
          var method = context.Request.Method.ToUpperInvariant();

          if (ShouldSkipRateLimit(path))
          {
              await _next(context);
              return;
          }

          var endpointConfig = FindEndpointConfig(path, method);
          int retryAfterSeconds;

          if (endpointConfig == null)
          {
              // No endpoint-specific rule: apply the default per-IP limit (one shared "default" bucket).
              if (!CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP, out retryAfterSeconds))
              {
                  await ReturnRateLimitedResponse(context, null, retryAfterSeconds);
                  return;
              }
          }
          else
          {
              var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
              if (!CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy, out retryAfterSeconds))
              {
                  await ReturnRateLimitedResponse(context, endpointConfig, retryAfterSeconds);
                  return;
              }
          }

          await _next(context);
      }

      /// <summary>
      /// Returns the first configured endpoint rule matching the path and method, or null.
      /// Null entries and a null <see cref="RateLimitConfig.Endpoints"/> list are tolerated.
      /// </summary>
      private EndpointRateLimit FindEndpointConfig(string path, string method)
      {
          if (_config.Endpoints == null)
          {
              return null;
          }

          foreach (var candidate in _config.Endpoints)
          {
              if (candidate != null
                  && IsPathMatch(path, candidate.Path)
                  && IsMethodMatch(method, candidate.Method))
              {
                  return candidate;
              }
          }

          return null;
      }

      /// <summary>
      /// Counts the request in the caller's window and reports whether it is allowed.
      /// </summary>
      /// <param name="context">The current HTTP context (used to derive the caller identifier).</param>
      /// <param name="endpointKey">Logical bucket name; combined with the identifier to form the cache key.</param>
      /// <param name="limit">Maximum requests per window.</param>
      /// <param name="period">Window length in seconds. Non-positive values disable limiting for the rule.</param>
      /// <param name="policy">How callers are grouped into buckets.</param>
      /// <param name="retryAfterSeconds">When rejected, whole seconds (rounded up, at least 1) until the window closes; otherwise 0.</param>
      /// <returns><c>true</c> if the request is within the limit; <c>false</c> if it must be rejected.</returns>
      private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period,
          RateLimitPolicy policy, out int retryAfterSeconds)
      {
          retryAfterSeconds = 0;

          // A zero/negative window cannot hold any requests. Guarding here also avoids
          // IMemoryCache rejecting a non-positive expiration, which would otherwise turn a
          // missing Period (default 0) into an exception on every request.
          if (period <= 0)
          {
              return true;
          }

          var identifier = GetIdentifier(context, policy);
          var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
          var now = DateTime.UtcNow;
          int rejectedCount;

          lock (_counterGate)
          {
              if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter)
                  || counter == null
                  || counter.FirstRequestTime.AddSeconds(period) <= now)
              {
                  // No live window for this key: start one with this request as its first hit.
                  // The entry expires exactly when the window closes.
                  var windowEnd = now.AddSeconds(period);
                  _cache.Set(cacheKey, new RateLimitCounter
                  {
                      Count = 1,
                      FirstRequestTime = now,
                      LastRequestTime = now
                  }, new DateTimeOffset(windowEnd));
                  return true;
              }

              if (counter.Count < limit)
              {
                  // This is the instance held by the cache, so mutating it in place updates
                  // the cached state without another Set (and without extending its expiry).
                  counter.Count++;
                  counter.LastRequestTime = now;
                  return true;
              }

              rejectedCount = counter.Count;
              // Round up so a client is never told to retry before the window actually closes.
              var remaining = counter.FirstRequestTime.AddSeconds(period) - now;
              retryAfterSeconds = Math.Max(1, (int)Math.Ceiling(remaining.TotalSeconds));
          }

          // Logged outside the lock so slow log sinks do not block other requests.
          _logger.LogWarning(
              "限流触发 - Endpoint: {Endpoint}, Identifier: {Identifier}, Count: {Count}, Limit: {Limit}, Remaining: {Remaining}s",
              endpointKey, identifier, rejectedCount, limit, retryAfterSeconds);
          return false;
      }

      /// <summary>
      /// Derives the bucket identifier for the caller according to <paramref name="policy"/>.
      /// </summary>
      private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
      {
          switch (policy)
          {
              case RateLimitPolicy.User:
              {
                  if (context.User?.Identity?.IsAuthenticated == true)
                  {
                      var userId = context.User.FindFirst("sub")?.Value
                          ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                          ?? context.User.Identity.Name;
                      if (!string.IsNullOrEmpty(userId))
                      {
                          return $"user_{userId}";
                      }
                  }

                  // Anonymous callers, and authenticated principals without a usable id claim,
                  // fall back to per-IP limiting instead of all sharing one "user_" bucket.
                  return GetIpIdentifier(context);
              }

              case RateLimitPolicy.Global:
                  return "global";

              case RateLimitPolicy.Client:
              {
                  // NOTE(review): these headers are caller-supplied and not validated here, so a
                  // client can vary them per request to obtain a fresh bucket.
                  var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault();
                  if (!string.IsNullOrEmpty(clientId))
                  {
                      return $"client_{clientId}";
                  }

                  var apiKey = context.Request.Headers["X-API-Key"].FirstOrDefault();
                  if (!string.IsNullOrEmpty(apiKey))
                  {
                      // The identifier is written to logs; hash the key so the secret is never logged.
                      return $"client_key_{HashForKey(apiKey)}";
                  }

                  return "client_unknown";
              }

              case RateLimitPolicy.IP:
              default:
                  return GetIpIdentifier(context);
          }
      }

      /// <summary>
      /// Builds an "ip_…" identifier from the connection's remote address ("ip_unknown" when absent).
      /// NOTE(review): behind a reverse proxy this is the proxy's address unless forwarded-headers
      /// handling runs earlier in the pipeline — confirm the pipeline order.
      /// </summary>
      private static string GetIpIdentifier(HttpContext context)
      {
          var address = context.Connection.RemoteIpAddress;
          if (address == null)
          {
              return "ip_unknown";
          }

          // Collapse IPv4-mapped IPv6 (::ffff:a.b.c.d) to plain IPv4 so dual-stack sockets do not
          // split one client across two buckets. Checked structurally rather than by searching the
          // text for "::ffff:", which would also match (and corrupt) addresses like 2001:db8::ffff:1.
          if (address.IsIPv4MappedToIPv6)
          {
              address = address.MapToIPv4();
          }

          return $"ip_{address}";
      }

      /// <summary>Returns the uppercase hex SHA-256 digest of the UTF-8 bytes of <paramref name="value"/>.</summary>
      private static string HashForKey(string value)
      {
          var digest = System.Security.Cryptography.SHA256.HashData(System.Text.Encoding.UTF8.GetBytes(value));
          return Convert.ToHexString(digest);
      }

      /// <summary>
      /// Matches a (lower-cased) request path against a configured path rule:
      /// "*" matches everything; otherwise an exact match, a prefix match when the rule ends with "/",
      /// or a "*"-wildcard pattern, all case-insensitive. A null/empty rule never matches.
      /// </summary>
      private bool IsPathMatch(string requestPath, string configPath)
      {
          if (string.IsNullOrEmpty(configPath))
          {
              return false;
          }

          if (configPath == "*")
          {
              return true;
          }

          configPath = configPath.ToLowerInvariant();

          // Exact match
          if (requestPath.Equals(configPath, StringComparison.OrdinalIgnoreCase))
          {
              return true;
          }

          // Prefix match (a rule ending in "/" matches everything beneath it)
          if (configPath.EndsWith("/", StringComparison.Ordinal)
              && requestPath.StartsWith(configPath, StringComparison.OrdinalIgnoreCase))
          {
              return true;
          }

          // Simple wildcard match: each "*" matches any run of characters.
          if (configPath.Contains("*"))
          {
              var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath)
                  .Replace("\\*", ".*") + "$";
              return System.Text.RegularExpressions.Regex.IsMatch(
                  requestPath,
                  pattern,
                  System.Text.RegularExpressions.RegexOptions.IgnoreCase
                  | System.Text.RegularExpressions.RegexOptions.CultureInvariant);
          }

          return false;
      }

      /// <summary>True when the rule's method is "*" or equals the request method (case-insensitive).</summary>
      private bool IsMethodMatch(string requestMethod, string configMethod)
      {
          if (configMethod == "*")
          {
              return true;
          }

          return requestMethod.Equals(configMethod, StringComparison.OrdinalIgnoreCase);
      }

      /// <summary>True when the path starts with one of <see cref="SkipPathPrefixes"/> (health checks, favicon, etc.).</summary>
      private bool ShouldSkipRateLimit(string path)
      {
          return SkipPathPrefixes.Any(prefix => path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase));
      }

      /// <summary>
      /// Writes the 429 response: JSON body <c>{ Code, Msg }</c> and, when known, a Retry-After header.
      /// </summary>
      /// <param name="context">The current HTTP context.</param>
      /// <param name="config">The matched endpoint rule, or null when the default limit was hit.</param>
      /// <param name="retryAfterSeconds">Seconds until the window closes; the header is omitted when not positive.</param>
      private async Task ReturnRateLimitedResponse(HttpContext context, EndpointRateLimit config = null, int retryAfterSeconds = 0)
      {
          context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
          context.Response.ContentType = "application/json";

          if (retryAfterSeconds > 0)
          {
              // Retry-After as delta-seconds, formatted culture-invariantly.
              context.Response.Headers["Retry-After"] =
                  retryAfterSeconds.ToString(System.Globalization.CultureInfo.InvariantCulture);
          }

          var message = config != null
              ? $"接口访问过于频繁,请{config.Period}秒后再试"
              : "请求过于频繁,请稍后再试";

          var response = new
          {
              Code = 429,
              Msg = message
          };

          await context.Response.WriteAsJsonAsync(response);
      }
  }
  /// <summary>
  /// Mutable state of one rate-limit window for a single cache key
  /// (request count plus the timestamps of the first and latest counted requests).
  /// </summary>
  public class RateLimitCounter
  {
      /// <summary>How many requests have been counted in the current window.</summary>
      public int Count { get; set; }

      /// <summary>When the current window began, i.e. the time of its first request.</summary>
      public DateTime FirstRequestTime { get; set; }

      /// <summary>Time of the most recent request counted in the current window.</summary>
      public DateTime LastRequestTime { get; set; }
  }
  207. #region 配置模型
  /// <summary>
  /// Determines what a rate-limit bucket is keyed by.
  /// Numeric values are explicit so configuration that stores them as integers stays stable.
  /// </summary>
  public enum RateLimitPolicy
  {
      /// <summary>One bucket per client IP address (the default value).</summary>
      IP = 0,

      /// <summary>One bucket per signed-in user; requires an authenticated user, otherwise the caller is keyed by IP.</summary>
      User = 1,

      /// <summary>A single bucket shared by every caller.</summary>
      Global = 2,

      /// <summary>One bucket per client, identified by the X-Client-Id or X-API-Key request header.</summary>
      Client = 3
  }
  /// <summary>
  /// Top-level options for the rate-limit middleware (consumed via <c>IOptions&lt;RateLimitConfig&gt;</c>).
  /// Defaults: enabled, 10 requests per 1-second window, no endpoint-specific rules.
  /// </summary>
  public class RateLimitConfig
  {
      /// <summary>Initializes the options with their default values.</summary>
      public RateLimitConfig()
      {
          Enabled = true;
          DefaultLimit = 10;
          DefaultPeriod = 1;
          Endpoints = new List<EndpointRateLimit>();
      }

      /// <summary>Master switch; when false, requests pass through without being counted.</summary>
      public bool Enabled { get; set; }

      /// <summary>Requests allowed per window for requests that match no endpoint rule.</summary>
      public int DefaultLimit { get; set; }

      /// <summary>Window length, in seconds, for requests that match no endpoint rule.</summary>
      public int DefaultPeriod { get; set; }

      /// <summary>Endpoint-specific rules; the first rule matching path and method applies.</summary>
      public List<EndpointRateLimit> Endpoints { get; set; }
  }
  /// <summary>
  /// A rate-limit rule applied to requests whose path and HTTP method match it.
  /// </summary>
  public class EndpointRateLimit
  {
      /// <summary>Initializes the rule with its defaults: any method, per-IP policy.</summary>
      public EndpointRateLimit()
      {
          Method = "*";
          Policy = RateLimitPolicy.IP;
      }

      /// <summary>
      /// Path rule: "*" for any path, a prefix when it ends with "/", a pattern when it contains "*",
      /// otherwise an exact path (compared case-insensitively).
      /// NOTE(review): has no default, so it is null when omitted from configuration — confirm every rule sets it.
      /// </summary>
      public string Path { get; set; }

      /// <summary>HTTP method to match, or "*" (the default) for any method.</summary>
      public string Method { get; set; }

      /// <summary>Maximum number of requests allowed per window.</summary>
      public int Limit { get; set; }

      /// <summary>
      /// Window length in seconds.
      /// NOTE(review): defaults to 0 when omitted from configuration; a positive value is expected.
      /// </summary>
      public int Period { get; set; }

      /// <summary>How callers are grouped when counting requests (per-IP by default).</summary>
      public RateLimitPolicy Policy { get; set; }
  }
  245. #endregion
  246. }