#ifndef RATE_LIMITER_H
#define RATE_LIMITER_H
#include "DataCollector.h"
#include "cmmlib/CmmBase.h"

#include <chrono>
#include <map>
#include <memory>
#include <string>
namespace RateLimitFramework {
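// Callback invoked when a request for `key` is rejected. RateLimiter
// currently always passes an empty `context` map.
class CMM_API RateLimitHandler {
public:
    virtual void handle(const std::string& key,
                        const std::map<std::string, std::string>& context) = 0;
    virtual ~RateLimitHandler() = default;
};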
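// Dispatches rate-limit checks to the algorithm configured for each
// StrategyType. Not thread-safe: callers must synchronize access.
class CMM_API RateLimiter {
private:
    // Per-strategy configuration, collaborators, and token-bucket state.
    struct Strategy {
        RateLimitStrategy config;
        std::shared_ptr<DataCollector> dataCollector;
        std::shared_ptr<RateLimitHandler> handler;
        // TOKEN_BUCKET only: time of the last refill and the current token
        // count (the bucket starts full, at config.threshold tokens).
        std::chrono::steady_clock::time_point lastExecutionTime;
        double tokens;

        Strategy(const RateLimitStrategy& cfg,
                 std::shared_ptr<DataCollector> dc,
                 std::shared_ptr<RateLimitHandler> rh)
            : config(cfg),
              dataCollector(dc),
              handler(rh),
              lastExecutionTime(std::chrono::steady_clock::now()),
              tokens(cfg.threshold) {}

        // Required because std::map::operator[] default-constructs values.
        Strategy()
            : config{},
              dataCollector(nullptr),
              handler(nullptr),
              lastExecutionTime(std::chrono::steady_clock::now()),
              tokens(0.0) {}
    };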
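    // One entry per StrategyType; addStrategy() overwrites an existing entry.
    std::map<StrategyType, Strategy> strategies;

public:
    // Registers (or replaces) the strategy for strategy.strategyType.
    void addStrategy(const RateLimitStrategy& strategy,
                     std::shared_ptr<DataCollector> dataCollector,
                     std::shared_ptr<RateLimitHandler> handler);

    // Returns true if the request for `key` is allowed. Strategy types that
    // were never registered are always allowed.
    bool checkLimit(StrategyType strategyType, const std::string& key);

private:
    bool checkSimpleCounter(const Strategy& strategy, const std::string& key);
    bool checkSlidingWindow(const Strategy& strategy, const std::string& key);
    bool checkTokenBucket(Strategy& strategy, const std::string& key);
};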
}
#endif
#include "cmmlib/ratelimit/RateLimit.h"

#include <algorithm>
#include <chrono>
#include <map>
#include <ratio>
#include <string>
using namespace std;
using namespace std::chrono;
namespace RateLimitFramework {
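// Registers the strategy under strategy.strategyType, replacing any existing
// entry (which also resets its token-bucket state).
void RateLimiter::addStrategy(const RateLimitStrategy& strategy,
                              shared_ptr<DataCollector> dataCollector,
                              shared_ptr<RateLimitHandler> handler) {
    strategies[strategy.strategyType] = Strategy(strategy, dataCollector, handler);
}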
bool RateLimiter::checkLimit(StrategyType strategyType, const std::string& key) {
    auto strategyIt = strategies.find(strategyType);
    if (strategyIt == strategies.end()) {
        return true;  // No strategy registered for this type: allow.
    }

    auto& strategy = strategyIt->second;
    // Every request is recorded before the check, including ones that end up
    // rejected. TOKEN_BUCKET does not need a collector, so it may be null.
    if (strategy.dataCollector) {
        strategy.dataCollector->record(key, strategy.config);
    }

    switch (strategy.config.algorithm) {
    case RateLimitAlgorithm::SIMPLE_COUNTER:
        return checkSimpleCounter(strategy, key);
    case RateLimitAlgorithm::SLIDING_WINDOW:
        return checkSlidingWindow(strategy, key);
    case RateLimitAlgorithm::TOKEN_BUCKET:
        return checkTokenBucket(strategy, key);
    }
    return true;  // Unhandled algorithm value: allow.
}
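// SIMPLE_COUNTER: reject once the collector's count for `key` exceeds
// config.threshold. How the count is computed is delegated entirely to
// DataCollector::getData.
bool RateLimiter::checkSimpleCounter(const Strategy& strategy, const std::string& key) {
    if (!strategy.dataCollector) {
        return true;  // No collector, so there is nothing to count against.
    }
    int count = strategy.dataCollector->getData(key, strategy.config);
    if (count > strategy.config.threshold) {
        if (strategy.handler) {
            std::map<std::string, std::string> context;
            strategy.handler->handle(key, context);
        }
        return false;
    }
    return true;
}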
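// SLIDING_WINDOW: same check as checkSimpleCounter. Any sliding-window
// behavior has to come from how DataCollector::getData interprets
// strategy.config.
bool RateLimiter::checkSlidingWindow(const Strategy& strategy, const std::string& key) {
    if (!strategy.dataCollector) {
        return true;  // No collector, so there is nothing to count against.
    }
    int count = strategy.dataCollector->getData(key, strategy.config);
    if (count > strategy.config.threshold) {
        if (strategy.handler) {
            std::map<std::string, std::string> context;
            strategy.handler->handle(key, context);
        }
        return false;
    }
    return true;
}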
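// TOKEN_BUCKET: the bucket holds at most config.threshold tokens and refills
// at config.threshold tokens per config.timeWindow; each request costs one.
bool RateLimiter::checkTokenBucket(Strategy& strategy, const std::string& key) {
    auto now = steady_clock::now();
    // Fractional milliseconds, so sub-millisecond gaps are not discarded when
    // lastExecutionTime is advanced to `now`.
    double elapsedMs =
        std::chrono::duration<double, std::milli>(now - strategy.lastExecutionTime).count();
    // Convert timeWindow to milliseconds whatever its declared unit, instead
    // of assuming .count() is already in milliseconds.
    double windowMs =
        std::chrono::duration<double, std::milli>(strategy.config.timeWindow).count();
    double capacity = static_cast<double>(strategy.config.threshold);

    if (windowMs > 0.0) {  // A non-positive window disables refilling.
        strategy.tokens += elapsedMs * (capacity / windowMs);
    }
    strategy.tokens = (std::min)(strategy.tokens, capacity);
    strategy.lastExecutionTime = now;

    if (strategy.tokens < 1.0) {
        if (strategy.handler) {
            std::map<std::string, std::string> context;
            strategy.handler->handle(key, context);
        }
        return false;
    }
    strategy.tokens -= 1.0;
    return true;
}

}  // namespace RateLimitFramework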