| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288 |
- // Middlewares/RateLimitMiddleware.cs
- using AspNetCoreRateLimit;
- using Microsoft.AspNetCore.Http;
- using Microsoft.Extensions.Caching.Memory;
- using Microsoft.Extensions.Options;
- using System;
- using System.Linq;
- using System.Threading.Tasks;
- namespace OASystem.API.Middlewares
- {
/// <summary>
/// ASP.NET Core middleware that throttles requests using fixed-window counters
/// stored in <see cref="IMemoryCache"/>.
/// </summary>
/// <remarks>
/// For every request the middleware looks for the first entry in
/// <see cref="RateLimitConfig.Endpoints"/> whose path and method match; when none matches,
/// the default limit/period from <see cref="RateLimitConfig"/> is applied per client IP.
/// Requests over the limit are answered with HTTP 429 and a small JSON body.
/// Counters live in the local memory cache only, so limits are per process.
/// </remarks>
public class RateLimitMiddleware
{
    // Path prefixes that are never throttled (health checks, browser/crawler probes).
    private static readonly string[] SkipPathPrefixes =
    {
        "/health",
        "/favicon.ico",
        "/robots.txt",
        "/.well-known"
    };

    private readonly RequestDelegate _next;
    private readonly IMemoryCache _cache;
    private readonly RateLimitConfig _config;
    private readonly ILogger<RateLimitMiddleware> _logger;

    // Guards the read-modify-write of cached counters. Without it, concurrent requests for the
    // same key can lose increments and slip past the configured limit.
    private readonly object _counterLock = new object();

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">Next delegate in the request pipeline.</param>
    /// <param name="cache">Memory cache that stores the per-key counters.</param>
    /// <param name="config">Rate limit options; the value is read once at construction.</param>
    /// <param name="logger">Logger used to report throttled requests.</param>
    public RateLimitMiddleware(
        RequestDelegate next,
        IMemoryCache cache,
        IOptions<RateLimitConfig> config,
        ILogger<RateLimitMiddleware> logger)
    {
        _next = next;
        _cache = cache;
        _config = config.Value;
        _logger = logger;
    }

    /// <summary>
    /// Applies the matching rate limit rule to the request and either forwards it to the
    /// next middleware or short-circuits with a 429 response.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        if (!_config.Enabled)
        {
            await _next(context);
            return;
        }

        // Invariant casing: culture-sensitive ToLower/ToUpper can change ASCII letters
        // under some cultures (e.g. Turkish 'I'), which would break rule matching.
        var path = context.Request.Path.ToString().ToLowerInvariant();
        var method = context.Request.Method.ToUpperInvariant();

        // Skip paths that must never be throttled.
        if (ShouldSkipRateLimit(path))
        {
            await _next(context);
            return;
        }

        // First matching rule wins; rule order in configuration is significant.
        var endpointConfig = _config.Endpoints?.FirstOrDefault(rule =>
            rule != null &&
            IsPathMatch(path, rule.Path) &&
            IsMethodMatch(method, rule.Method));

        bool allowed;
        if (endpointConfig == null)
        {
            // No specific rule: fall back to the default per-IP limit.
            allowed = CheckLimit(context, "default", _config.DefaultLimit, _config.DefaultPeriod, RateLimitPolicy.IP);
        }
        else
        {
            var endpointKey = $"{endpointConfig.Method}:{endpointConfig.Path}";
            allowed = CheckLimit(context, endpointKey, endpointConfig.Limit, endpointConfig.Period, endpointConfig.Policy);
        }

        if (!allowed)
        {
            // endpointConfig is null for the default rule, which selects the generic message.
            await ReturnRateLimitedResponse(context, endpointConfig);
            return;
        }

        await _next(context);
    }

    /// <summary>
    /// Records the request against its fixed-window counter and reports whether it is allowed.
    /// </summary>
    /// <param name="context">The current HTTP context (used to derive the client identifier).</param>
    /// <param name="endpointKey">Key of the rule being applied; part of the cache key.</param>
    /// <param name="limit">Maximum number of requests per window.</param>
    /// <param name="period">Window length in seconds.</param>
    /// <param name="policy">How the client identifier is derived.</param>
    /// <returns><c>true</c> if the request may proceed; <c>false</c> if it must be rejected.</returns>
    private bool CheckLimit(HttpContext context, string endpointKey, int limit, int period, RateLimitPolicy policy)
    {
        // A non-positive period cannot describe a window, and the memory cache rejects a
        // non-positive relative expiration with ArgumentOutOfRangeException. Fail open and
        // surface the misconfiguration instead of answering every request with a 500.
        if (period <= 0)
        {
            _logger.LogWarning(
                "Rate limit rule {Endpoint} has a non-positive period ({Period}s); request allowed without throttling",
                endpointKey,
                period);
            return true;
        }

        var identifier = GetIdentifier(context, policy);
        var cacheKey = $"ratelimit:{endpointKey}:{identifier}";
        var window = TimeSpan.FromSeconds(period);

        int observedCount;
        int remainingSeconds;

        lock (_counterLock)
        {
            var now = DateTime.UtcNow;

            if (!_cache.TryGetValue<RateLimitCounter>(cacheKey, out var counter))
            {
                // First request for this key: open a new window.
                var fresh = new RateLimitCounter
                {
                    Count = 1,
                    FirstRequestTime = now,
                    LastRequestTime = now
                };
                _cache.Set(cacheKey, fresh, window);
                return true;
            }

            var windowExpired = counter.FirstRequestTime < now - window;
            if (windowExpired || counter.Count < limit)
            {
                if (windowExpired)
                {
                    // The first request fell out of the window: start counting again.
                    counter.Count = 1;
                    counter.FirstRequestTime = now;
                }
                else
                {
                    counter.Count++;
                }

                counter.LastRequestTime = now;
                _cache.Set(cacheKey, counter, window);
                return true;
            }

            // Over the limit: capture values for logging and leave the counter untouched.
            observedCount = counter.Count;
            remainingSeconds = (int)(counter.FirstRequestTime + window - now).TotalSeconds;
        }

        // Logged outside the lock so a slow log sink cannot stall other requests.
        _logger.LogWarning(
            "限流触发 - Endpoint: {Endpoint}, Identifier: {Identifier}, Count: {Count}, Limit: {Limit}, Remaining: {Remaining}s",
            endpointKey,
            identifier,
            observedCount,
            limit,
            remainingSeconds);
        return false;
    }

    /// <summary>
    /// Builds the identifier that partitions counters for the given policy.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="policy">The partitioning policy.</param>
    /// <returns>A prefixed identifier such as <c>ip_…</c>, <c>user_…</c>, <c>client_…</c> or <c>global</c>.</returns>
    private string GetIdentifier(HttpContext context, RateLimitPolicy policy)
    {
        switch (policy)
        {
            case RateLimitPolicy.User:
                if (context.User?.Identity?.IsAuthenticated == true)
                {
                    var userId = context.User.FindFirst("sub")?.Value
                        ?? context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value
                        ?? context.User.Identity.Name;

                    // Without any usable id all such users would share one "user_" bucket,
                    // so fall through to the IP-based identifier instead.
                    if (!string.IsNullOrEmpty(userId))
                    {
                        return $"user_{userId}";
                    }
                }

                // Not signed in (or no id claim): fall back to the client IP.
                return $"ip_{GetClientIp(context)}";

            case RateLimitPolicy.Global:
                return "global";

            case RateLimitPolicy.Client:
                // NOTE(review): these headers are supplied by the caller, so a client can rotate
                // them to obtain fresh buckets; confirm they are validated upstream.
                var clientId = context.Request.Headers["X-Client-Id"].FirstOrDefault()
                    ?? context.Request.Headers["X-API-Key"].FirstOrDefault();
                return $"client_{(string.IsNullOrWhiteSpace(clientId) ? "unknown" : clientId)}";

            case RateLimitPolicy.IP:
            default:
                return $"ip_{GetClientIp(context)}";
        }
    }

    /// <summary>
    /// Returns the remote IP address as text, or <c>"unknown"</c> when the connection has none.
    /// IPv4-mapped IPv6 addresses (<c>::ffff:a.b.c.d</c>) are reported as plain IPv4.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    private static string GetClientIp(HttpContext context)
    {
        var address = context.Connection.RemoteIpAddress;
        if (address == null)
        {
            return "unknown";
        }

        // Use the address API rather than text replacement: "::ffff:" can legitimately occur
        // inside a real IPv6 address (e.g. 2001:db8::ffff:1) and must not be stripped there.
        if (address.IsIPv4MappedToIPv6)
        {
            address = address.MapToIPv4();
        }

        return address.ToString();
    }

    /// <summary>
    /// Tests a request path against a configured path pattern (case-insensitive).
    /// </summary>
    /// <param name="requestPath">The request path.</param>
    /// <param name="configPath">
    /// <c>*</c> (match all), an exact path, a prefix ending in <c>/</c>, or a pattern containing
    /// <c>*</c> wildcards.
    /// </param>
    /// <returns><c>true</c> when the path matches the pattern.</returns>
    private static bool IsPathMatch(string requestPath, string configPath)
    {
        // A rule without a path (the property has no default) matches nothing.
        if (configPath == null)
        {
            return false;
        }

        if (configPath == "*")
        {
            return true;
        }

        // Exact match.
        if (requestPath.Equals(configPath, StringComparison.OrdinalIgnoreCase))
        {
            return true;
        }

        // Prefix match: a configured path ending in "/" covers everything below it.
        if (configPath.EndsWith("/", StringComparison.Ordinal) &&
            requestPath.StartsWith(configPath, StringComparison.OrdinalIgnoreCase))
        {
            return true;
        }

        // Wildcard match: each "*" stands for any run of characters.
        if (configPath.Contains("*"))
        {
            var pattern = "^" + System.Text.RegularExpressions.Regex.Escape(configPath)
                .Replace("\\*", ".*") + "$";
            return System.Text.RegularExpressions.Regex.IsMatch(
                requestPath,
                pattern,
                System.Text.RegularExpressions.RegexOptions.IgnoreCase |
                System.Text.RegularExpressions.RegexOptions.CultureInvariant);
        }

        return false;
    }

    /// <summary>
    /// Tests an HTTP method against a configured method; <c>*</c> matches any method.
    /// </summary>
    /// <param name="requestMethod">The request's HTTP method.</param>
    /// <param name="configMethod">The configured method or <c>*</c>.</param>
    private static bool IsMethodMatch(string requestMethod, string configMethod)
    {
        if (configMethod == "*")
        {
            return true;
        }

        return requestMethod.Equals(configMethod, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Reports whether the path starts with one of the prefixes that are exempt from throttling.
    /// </summary>
    /// <param name="path">The request path.</param>
    private static bool ShouldSkipRateLimit(string path)
    {
        return SkipPathPrefixes.Any(prefix => path.StartsWith(prefix, StringComparison.OrdinalIgnoreCase));
    }

    /// <summary>
    /// Writes the HTTP 429 response with a JSON body carrying <c>Code</c> and <c>Msg</c>.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="config">
    /// The endpoint rule that was exceeded, or <c>null</c> when the default rule applied
    /// (a generic message is used in that case).
    /// </param>
    private async Task ReturnRateLimitedResponse(HttpContext context, EndpointRateLimit config = null)
    {
        context.Response.StatusCode = StatusCodes.Status429TooManyRequests;
        context.Response.ContentType = "application/json";

        // Tell well-behaved clients how long to back off; uses the same period the message states.
        var retryAfterSeconds = config?.Period ?? _config.DefaultPeriod;
        if (retryAfterSeconds > 0)
        {
            context.Response.Headers["Retry-After"] =
                retryAfterSeconds.ToString(System.Globalization.CultureInfo.InvariantCulture);
        }

        var message = config != null
            ? $"接口访问过于频繁,请{config.Period}秒后再试"
            : "请求过于频繁,请稍后再试";

        var response = new
        {
            Code = 429,
            Msg = message
        };

        await context.Response.WriteAsJsonAsync(response);
    }
}
/// <summary>
/// Mutable per-key request counter for one rate limit window.
/// </summary>
public class RateLimitCounter
{
    /// <summary>Number of requests counted in the current window.</summary>
    public int Count { get; set; }

    /// <summary>Timestamp of the request that opened the current window.</summary>
    public DateTime FirstRequestTime { get; set; }

    /// <summary>Timestamp of the most recent counted request.</summary>
    public DateTime LastRequestTime { get; set; }
}
- #region 配置模型
/// <summary>
/// Selects how requests are grouped into rate limit buckets.
/// </summary>
public enum RateLimitPolicy
{
    /// <summary>
    /// Limit per client IP address.
    /// </summary>
    IP = 0,

    /// <summary>
    /// Limit per user (requires a signed-in user).
    /// </summary>
    User = 1,

    /// <summary>
    /// Global limit (all callers share one bucket).
    /// </summary>
    Global = 2,

    /// <summary>
    /// Limit per client ID.
    /// </summary>
    Client = 3
}
/// <summary>
/// Options for the rate limit middleware.
/// </summary>
public class RateLimitConfig
{
    /// <summary>Turns rate limiting on or off; on by default.</summary>
    public bool Enabled { get; set; } = true;

    /// <summary>Request limit used when no endpoint rule matches; 10 by default.</summary>
    public int DefaultLimit { get; set; } = 10;

    /// <summary>Window length, in seconds, used with <see cref="DefaultLimit"/>; 1 by default.</summary>
    public int DefaultPeriod { get; set; } = 1;

    /// <summary>Endpoint-specific rules; empty by default.</summary>
    public List<EndpointRateLimit> Endpoints { get; set; } = new List<EndpointRateLimit>();
}
/// <summary>
/// Rate limit rule for one endpoint (or a group of endpoints).
/// </summary>
public class EndpointRateLimit
{
    /// <summary>Path pattern this rule applies to; no default value.</summary>
    public string Path { get; set; }

    /// <summary>HTTP method this rule applies to; <c>*</c> (the default) means any method.</summary>
    public string Method { get; set; } = "*";

    /// <summary>Maximum number of requests allowed per period.</summary>
    public int Limit { get; set; }

    /// <summary>Length of the limiting period, in seconds.</summary>
    public int Period { get; set; }

    /// <summary>How requests are grouped into buckets; per IP by default.</summary>
    public RateLimitPolicy Policy { get; set; } = RateLimitPolicy.IP;
}
- #endregion
- }
|