Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support IP rate limiting #59

Merged
merged 5 commits into from
Aug 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions controllers/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ func checkExpressions(expressions []*object.Expression, ruleType string) error {
return checkWafRule(values)
case "IP":
return checkIpRule(values)
case "IpRate":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Label is "IP Rate Limiting"

return checkIpRateRule(expressions)
}
return nil
}
Expand All @@ -157,3 +159,19 @@ func checkIpRule(ipLists []string) error {
}
return nil
}

func checkIpRateRule(expressions []*object.Expression) error {
if len(expressions) != 1 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this be zero?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not allowed to be zero.

return errors.New("IpRate rule should have only one expression")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to limit this in frontend too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has been limited in frontend.

}
expression := expressions[0]
_, err := util.ParseIntWithError(expression.Operator)
if err != nil {
return err
}
_, err = util.ParseIntWithError(expression.Value)
if err != nil {
return err
}
return nil
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ require (
github.com/xorm-io/core v0.7.4
github.com/xorm-io/xorm v1.1.6
golang.org/x/net v0.21.0
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
modernc.org/sqlite v1.11.2
)
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
4 changes: 4 additions & 0 deletions rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func CheckRules(ruleIds []string, r *http.Request) (string, string, error) {
ruleObj = &IpRule{}
case "WAF":
ruleObj = &WafRule{}
case "IpRate":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Label is "IP Rate Limiting"

ruleObj = &IpRateRule{
ruleName: rule.GetId(),
}
default:
return "", "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId())
}
Expand Down
128 changes: 128 additions & 0 deletions rule/rule_ip_rate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2024 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rule

import (
"net/http"
"sync"
"time"

"github.com/casbin/caswaf/object"
"github.com/casbin/caswaf/util"
"golang.org/x/time/rate"
)

type IpRateRule struct {
ruleName string
}

type IpRateLimiter struct {
ips map[string]*rate.Limiter
mu *sync.RWMutex
r rate.Limit
b int
}

var blackList = map[string]map[string]time.Time{}

var ipRateLimiters = map[string]*IpRateLimiter{}

// NewIpRateLimiter .
func NewIpRateLimiter(r rate.Limit, b int) *IpRateLimiter {
i := &IpRateLimiter{
ips: make(map[string]*rate.Limiter),
mu: &sync.RWMutex{},
r: r,
b: b,
}

return i
}

// AddIP creates a new rate limiter and adds it to the ips map,
// using the IP address as the key
func (i *IpRateLimiter) AddIP(ip string) *rate.Limiter {
i.mu.Lock()
defer i.mu.Unlock()

limiter := rate.NewLimiter(i.r, i.b)

i.ips[ip] = limiter

return limiter
}

// GetLimiter returns the rate limiter for the provided IP address if it exists.
// Otherwise, calls AddIP to add IP address to the map
func (i *IpRateLimiter) GetLimiter(ip string) *rate.Limiter {
i.mu.Lock()
limiter, exists := i.ips[ip]

if !exists {
i.mu.Unlock()
return i.AddIP(ip)
}

i.mu.Unlock()

return limiter
}

func (r *IpRateRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
expression := expressions[0] // IpRate rule should have only one expression
clientIp := util.GetClientIp(req)

// If the client IP is in the blacklist, check the block time
createAt, ok := blackList[r.ruleName][clientIp]
if ok {
blockTime := util.ParseInt(expression.Value)
if time.Now().Sub(createAt) < time.Duration(blockTime)*time.Second {
return true, "Block", "Rate limit exceeded", nil
} else {
delete(blackList, clientIp)
}
}

// If the client IP is not in the blacklist, check the rate limit
ipRateLimiter := ipRateLimiters[r.ruleName]
parseInt := util.ParseInt(expression.Operator)
if ipRateLimiter == nil {
ipRateLimiter = NewIpRateLimiter(rate.Limit(parseInt), parseInt)
ipRateLimiters[r.ruleName] = ipRateLimiter
}

// If the rate limit has changed, update the rate limiter
limiter := ipRateLimiter.GetLimiter(clientIp)
if ipRateLimiter.r != rate.Limit(parseInt) {
ipRateLimiter.r = rate.Limit(parseInt)
ipRateLimiter.b = parseInt
limiter.SetLimit(ipRateLimiter.r)
limiter.SetBurst(ipRateLimiter.b)
err := limiter.Wait(req.Context())
if err != nil {
return false, "", "", err
}
} else {
// If the rate limit is exceeded, add the client IP to the blacklist
allow := limiter.Allow()
if !allow {
blackList[r.ruleName] = map[string]time.Time{}
blackList[r.ruleName][clientIp] = time.Now()
return true, "Block", "Rate limit exceeded", nil
}
}

return false, "", "", nil
}
133 changes: 133 additions & 0 deletions rule/rule_ip_rate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package rule

import (
"net/http"
"testing"

"github.com/casbin/caswaf/object"
)

func TestIpRateRule_checkRule(t *testing.T) {
type fields struct {
ruleName string
}
type args struct {
args []struct {
expressions []*object.Expression
req *http.Request
}
}

tests := []struct {
name string
fields fields
args args
want []bool
want1 []string
want2 []string
wantErr []bool
}{
{
name: "Test 1",
fields: fields{
ruleName: "rule1",
},
args: args{
args: []struct {
expressions []*object.Expression
req *http.Request
}{
{
expressions: []*object.Expression{
{
Operator: "1",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
{
expressions: []*object.Expression{
{
Operator: "1",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
},
},
want: []bool{false, true},
want1: []string{"", "Block"},
want2: []string{"", "Rate limit exceeded"},
wantErr: []bool{false, false},
},
{
name: "Test 2",
fields: fields{
ruleName: "rule2",
},
args: args{
args: []struct {
expressions []*object.Expression
req *http.Request
}{
{
expressions: []*object.Expression{
{
Operator: "1",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
{
expressions: []*object.Expression{
{
Operator: "10",
Value: "1",
},
},
req: &http.Request{
RemoteAddr: "127.0.0.1",
},
},
},
},
want: []bool{false, false},
want1: []string{"", ""},
want2: []string{"", ""},
wantErr: []bool{false, false},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &IpRateRule{
ruleName: tt.fields.ruleName,
}
for i, arg := range tt.args.args {
got, got1, got2, err := r.checkRule(arg.expressions, arg.req)
if (err != nil) != tt.wantErr[i] {
t.Errorf("checkRule() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want[i] {
t.Errorf("checkRule() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1[i] {
t.Errorf("checkRule() got1 = %v, want %v", got1, tt.want1)
}
if got2 != tt.want2[i] {
t.Errorf("checkRule() got2 = %v, want %v", got2, tt.want2)
}
}
})
}
}
23 changes: 22 additions & 1 deletion web/src/RuleEditPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import i18next from "i18next";
import WafRuleTable from "./components/WafRuleTable";
import IpRuleTable from "./components/IpRuleTable";
import UaRuleTable from "./components/UaRuleTable";
import IpRateRuleTable from "./components/IpRateRuleTable";

const {Option} = Select;

Expand Down Expand Up @@ -57,6 +58,15 @@ class RuleEditPage extends React.Component {
});
}

updateRuleFieldInExpressions(index, key, value) {
const rule = Setting.deepCopy(this.state.rule);
rule.expressions[index][key] = value;
this.updateRuleField("expressions", rule.expressions);
this.setState({
rule: rule,
});
}

renderRule() {
return (
<Card size="small" title={
Expand Down Expand Up @@ -86,7 +96,7 @@ class RuleEditPage extends React.Component {
{value: "WAF", text: "WAF"},
{value: "IP", text: "IP"},
{value: "User-Agent", text: "User-Agent"},
// {value: "frequency", text: "Frequency"},
{value: "IpRate", text: "IP Rate"},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Label is "IP Rate Limiting"

// {value: "complex", text: "Complex"},
].map((item, index) => <Option key={index} value={item.value}>{item.text}</Option>)
}
Expand Down Expand Up @@ -131,6 +141,17 @@ class RuleEditPage extends React.Component {
/>
) : null
}
{
this.state.rule.type === "IpRate" ? (
<IpRateRuleTable
title={"IP Rate"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Label is "IP Rate Limiting"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i18n

table={this.state.rule.expressions}
ruleName={this.state.rule.name}
account={this.props.account}
onUpdateTable={(value) => {this.updateRuleField("expressions", value);}}
/>
) : null
}
</Col>
</Row>
{
Expand Down
Loading
Loading