|
@@ -0,0 +1,288 @@
|
|
|
|
|
+// Middlewares/RateLimitMiddleware.cs
|
|
|
|
|
+using AspNetCoreRateLimit;
|
|
|
|
|
+using Microsoft.AspNetCore.Http;
|
|
|
|
|
+using Microsoft.Extensions.Caching.Memory;
|
|
|
|
|
+using Microsoft.Extensions.Options;
|
|
|
|
|
+using System;
|
|
|
|
|
+using System.Linq;
|
|
|
|
|
+using System.Threading.Tasks;
|
|
|
|
|
+
|
|
|
|
|
+namespace OASystem.API.Middlewares
|
|
|
|
|
+{
|
|
|
|
|
+ public class RateLimitMiddleware
|
|
|
|
|
+ {
|
|
|
|
|
+ private readonly RequestDelegate _next;
|
|
|
|
|
+ private readonly IMemoryCache _cache;
|
|
|
|
|
+ private readonly RateLimitConfig _config;
|
|
|
|
|
+ private readonly ILogger<RateLimitMiddleware> _logger;
|
|
|
|
|
+
|
|
|
|
|
+ public RateLimitMiddleware(
|
|
|
|
|
+ RequestDelegate next,
|
|
|
|
|
+ IMemoryCache cache,
|
|
|
|
|
+ IOptions<RateLimitConfig> config,
|
|
|
|
|
+ ILogger<RateLimitMiddleware> logger)
|
|
|
|
|
+ {
|
|
|
|
|
+ _next = next;
|
|
|
|
|
+ _cache = cache;
|
|
|
|
|
+ _config = config.Value;
|
|
|
|
|
+ _logger = logger;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ public async Task InvokeAsync(HttpContext context)
|
|
|
|
|
+ {
|
|
|
|
|
+ if (!_config.Enabled)
|
|
|
|
|
+ {
|
|
|
|
|
+ await _next(context);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ var path = context.Request.Path.ToString().ToLower();
|
|
|
|
|
+ var method = context.Request.Method.ToUpper();
|
|
|
|
|
+
|
|
|
|
|
+ // 跳过不需要限流的路径
|
|
|
|
|
+ if (ShouldSkipRateLimit(path))
|
|
|
|
|
+ {
|
|
|
|
|
+ await _next(context);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 查找匹配的配置
|
|
|
|
|
+ EndpointRateLimit endpointConfig = null;
|
|
|
|
|
+
|
|
|
|
|
+ foreach (var configItem in _config.Endpoints)
|
|
|
|
|
+ {
|
|
|
|
|
+ if (IsPathMatch(path, configItem.Path) &&
|
|
|
|
|
+ IsMethodMatch(method, configItem.Method))
|
|
|
|
|
+ {
|
|
|
|
|
+ endpointConfig = configItem;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (endpointConfig == null)
|
|
|
|
|
+ {
|
|
|
|
|
+ // 使用默认配置
|
|
|
|
|
+ if (!CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP))
|
|
|
|
|
+ {
|
|
|
|
|
+ await ReturnRateLimitedResponse(context);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ else
|
|
|
|
|
+ {
|
|
|
|
|
+ var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
|
|
|
|
|
+ if (!CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy))
|
|
|
|
|
+ {
|
|
|
|
|
+ await ReturnRateLimitedResponse(context, endpointConfig);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ await _next(context);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period, RateLimitPolicy policy)
|
|
|
|
|
+ {
|
|
|
|
|
+ var identifier = GetIdentifier(context, policy);
|
|
|
|
|
+ var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
|
|
|
|
|
+ var now = DateTime.UtcNow;
|
|
|
|
|
+ var windowStart = now.AddSeconds(-period);
|
|
|
|
|
+
|
|
|
|
|
+ if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter))
|
|
|
|
|
+ {
|
|
|
|
|
+ counter = new RateLimitCounter
|
|
|
|
|
+ {
|
|
|
|
|
+ Count = 1,
|
|
|
|
|
+ FirstRequestTime = now,
|
|
|
|
|
+ LastRequestTime = now
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period));
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // 如果第一个请求已经超过时间窗口,重置计数器
|
|
|
|
|
+ if (counter.FirstRequestTime < windowStart)
|
|
|
|
|
+ {
|
|
|
|
|
+ counter.Count = 1;
|
|
|
|
|
+ counter.FirstRequestTime = now;
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (counter.Count >= limit)
|
|
|
|
|
+ {
|
|
|
|
|
+ var remainingSeconds = (int)(counter.FirstRequestTime.AddSeconds(period) - now).TotalSeconds;
|
|
|
|
|
+ _logger.LogWarning($"限流触发 - Endpoint: {endpointKey}, " +
|
|
|
|
|
+ $"Identifier: {identifier}, " +
|
|
|
|
|
+ $"Count: {counter.Count}, " +
|
|
|
|
|
+ $"Limit: {limit}, " +
|
|
|
|
|
+ $"Remaining: {remainingSeconds}s");
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ else
|
|
|
|
|
+ {
|
|
|
|
|
+ counter.Count++;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ counter.LastRequestTime = now;
|
|
|
|
|
+ _cache.Set(cacheKey, counter, TimeSpan.FromSeconds(period));
|
|
|
|
|
+
|
|
|
|
|
+ return true;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
|
|
|
|
|
+ {
|
|
|
|
|
+ switch (policy)
|
|
|
|
|
+ {
|
|
|
|
|
+ case RateLimitPolicy.User:
|
|
|
|
|
+ if (context.User?.Identity?.IsAuthenticated == true)
|
|
|
|
|
+ {
|
|
|
|
|
+ var userId = context.User.FindFirst("sub")?.Value
|
|
|
|
|
+ ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
|
|
|
|
|
+ ?? context.User.Identity.Name;
|
|
|
|
|
+ return $"user_{userId}";
|
|
|
|
|
+ }
|
|
|
|
|
+ // 如果未登录,回退到IP
|
|
|
|
|
+ return $"ip_{context.Connection.RemoteIpAddress}";
|
|
|
|
|
+
|
|
|
|
|
+ case RateLimitPolicy.Global:
|
|
|
|
|
+ return "global";
|
|
|
|
|
+
|
|
|
|
|
+ case RateLimitPolicy.Client:
|
|
|
|
|
+ var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault()
|
|
|
|
|
+ ?? context.Request.Headers["X-API-Key"].FirstOrDefault();
|
|
|
|
|
+ return $"client_{clientId ?? "unknown"}";
|
|
|
|
|
+
|
|
|
|
|
+ case RateLimitPolicy.IP:
|
|
|
|
|
+ default:
|
|
|
|
|
+ var ipAddress = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
|
|
|
|
|
+ // 处理IPv6映射的IPv4地址
|
|
|
|
|
+ if (ipAddress.Contains("::ffff:"))
|
|
|
|
|
+ {
|
|
|
|
|
+ ipAddress = ipAddress.Replace("::ffff:", "");
|
|
|
|
|
+ }
|
|
|
|
|
+ return $"ip_{ipAddress}";
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private bool IsPathMatch(string requestPath, string configPath)
|
|
|
|
|
+ {
|
|
|
|
|
+ if (configPath == "*") return true;
|
|
|
|
|
+
|
|
|
|
|
+ configPath = configPath.ToLower();
|
|
|
|
|
+
|
|
|
|
|
+ // 精确匹配
|
|
|
|
|
+ if (requestPath.Equals(configPath, StringComparison.OrdinalIgnoreCase))
|
|
|
|
|
+ return true;
|
|
|
|
|
+
|
|
|
|
|
+ // 前缀匹配(以/结尾表示前缀匹配)
|
|
|
|
|
+ if (configPath.EndsWith("/") && requestPath.StartsWith(configPath))
|
|
|
|
|
+ return true;
|
|
|
|
|
+
|
|
|
|
|
+ // 通配符匹配(简单的*通配符)
|
|
|
|
|
+ if (configPath.Contains("*"))
|
|
|
|
|
+ {
|
|
|
|
|
+ var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath)
|
|
|
|
|
+ .Replace("\\*", ".*") + "$";
|
|
|
|
|
+ return System.Text.RegularExpressions.Regex.IsMatch(requestPath, pattern);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private bool IsMethodMatch(string requestMethod, string configMethod)
|
|
|
|
|
+ {
|
|
|
|
|
+ if (configMethod == "*") return true;
|
|
|
|
|
+
|
|
|
|
|
+ return requestMethod.Equals(configMethod, StringComparison.OrdinalIgnoreCase);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private bool ShouldSkipRateLimit(string path)
|
|
|
|
|
+ {
|
|
|
|
|
+ // 跳过健康检查、swagger等
|
|
|
|
|
+ var skipPaths = new[]
|
|
|
|
|
+ {
|
|
|
|
|
+ "/health",
|
|
|
|
|
+ //"/swagger",
|
|
|
|
|
+ "/favicon.ico",
|
|
|
|
|
+ "/robots.txt",
|
|
|
|
|
+ "/.well-known"
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ return skipPaths.Any(p => path.StartsWith(p, StringComparison.OrdinalIgnoreCase));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private async Task ReturnRateLimitedResponse(HttpContext context, EndpointRateLimit config = null)
|
|
|
|
|
+ {
|
|
|
|
|
+ context.Response.StatusCode = 429; // Too Many Requests
|
|
|
|
|
+ context.Response.ContentType = "application/json";
|
|
|
|
|
+
|
|
|
|
|
+ var message = config != null
|
|
|
|
|
+ ? $"接口访问过于频繁,请{config.Period}秒后再试"
|
|
|
|
|
+ : "请求过于频繁,请稍后再试";
|
|
|
|
|
+
|
|
|
|
|
+ var response = new
|
|
|
|
|
+ {
|
|
|
|
|
+ Code = 429,
|
|
|
|
|
+ Msg = message
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ await context.Response.WriteAsJsonAsync(response);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// <summary>
|
|
|
|
|
+ /// 限流计数器
|
|
|
|
|
+ /// </summary>
|
|
|
|
|
+ public class RateLimitCounter
|
|
|
|
|
+ {
|
|
|
|
|
+ public int Count { get; set; }
|
|
|
|
|
+ public DateTime FirstRequestTime { get; set; }
|
|
|
|
|
+ public DateTime LastRequestTime { get; set; }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #region 配置模型
|
|
|
|
|
+
|
|
|
|
|
+ /// <summary>
|
|
|
|
|
+ /// 限流策略枚举
|
|
|
|
|
+ /// </summary>
|
|
|
|
|
+ public enum RateLimitPolicy
|
|
|
|
|
+ {
|
|
|
|
|
+ /// <summary>
|
|
|
|
|
+ /// 按IP地址限流
|
|
|
|
|
+ /// </summary>
|
|
|
|
|
+ IP = 0,
|
|
|
|
|
+
|
|
|
|
|
+ /// <summary>
|
|
|
|
|
+ /// 按用户限流(需要用户登录)
|
|
|
|
|
+ /// </summary>
|
|
|
|
|
+ User = 1,
|
|
|
|
|
+
|
|
|
|
|
+ /// <summary>
|
|
|
|
|
+ /// 全局限流(所有用户共享限制)
|
|
|
|
|
+ /// </summary>
|
|
|
|
|
+ Global = 2,
|
|
|
|
|
+
|
|
|
|
|
+ /// <summary>
|
|
|
|
|
+ /// 按客户端ID限流
|
|
|
|
|
+ /// </summary>
|
|
|
|
|
+ Client = 3
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ public class RateLimitConfig
|
|
|
|
|
+ {
|
|
|
|
|
+ public bool Enabled { get; set; } = true;
|
|
|
|
|
+ public int DefaultLimit { get; set; } = 10;
|
|
|
|
|
+ public int DefaultPeriod { get; set; } = 1;
|
|
|
|
|
+
|
|
|
|
|
+ public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ public class EndpointRateLimit
|
|
|
|
|
+ {
|
|
|
|
|
+ public string Path { get; set; }
|
|
|
|
|
+ public string Method { get; set; } = "*";
|
|
|
|
|
+ public int Limit { get; set; }
|
|
|
|
|
+ public int Period { get; set; } // 秒
|
|
|
|
|
+ public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
|
|
|
|
|
+ }
|
|
|
|
|
+ #endregion
|
|
|
|
|
+}
|