Skip to content

Commit

Permalink
feat: add WAF rule (#48)
Browse files Browse the repository at this point in the history
* feat: support waf rule in the backend

* feat: add default rules to waf

* feat: avoid editing the rule_cache from outside the object package

* Update proxy.go

---------

Co-authored-by: Gucheng <[email protected]>
  • Loading branch information
love98ooo and nomeguy authored Aug 2, 2024
1 parent 304cc79 commit dc05abe
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 138 deletions.
4 changes: 0 additions & 4 deletions controllers/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"strings"

"github.com/casbin/caswaf/object"
"github.com/casbin/caswaf/service"
"github.com/casbin/caswaf/util"
"github.com/hsluoyz/modsecurity-go/seclang/parser"
)
Expand Down Expand Up @@ -80,7 +79,6 @@ func (c *ApiController) AddRule() {
return
}
c.Data["json"] = wrapActionResponse(object.AddRule(&rule))
go service.UpdateWafs()
c.ServeJSON()
}

Expand All @@ -104,7 +102,6 @@ func (c *ApiController) UpdateRule() {

id := c.Input().Get("id")
c.Data["json"] = wrapActionResponse(object.UpdateRule(id, &rule))
go service.UpdateWafs()
c.ServeJSON()
}

Expand All @@ -121,7 +118,6 @@ func (c *ApiController) DeleteRule() {
}

c.Data["json"] = wrapActionResponse(object.DeleteRule(&rule))
go service.UpdateWafs()
c.ServeJSON()
}

Expand Down
36 changes: 17 additions & 19 deletions object/site.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/casbin/caswaf/run"
"github.com/casbin/caswaf/util"
"github.com/casdoor/casdoor-go-sdk/casdoorsdk"
"github.com/corazawaf/coraza/v3"
"github.com/xorm-io/core"
)

Expand All @@ -44,24 +43,23 @@ type Site struct {
UpdatedTime string `xorm:"varchar(100)" json:"updatedTime"`
DisplayName string `xorm:"varchar(100)" json:"displayName"`

Tag string `xorm:"varchar(100)" json:"tag"`
Domain string `xorm:"varchar(100)" json:"domain"`
OtherDomains []string `xorm:"varchar(500)" json:"otherDomains"`
NeedRedirect bool `json:"needRedirect"`
EnableWaf bool `json:"enableWaf"`
Rules []string `xorm:"varchar(500)" json:"wafRuleIds"`
Waf coraza.WAF `xorm:"-" json:"-"`
Challenges []string `xorm:"mediumtext" json:"challenges"`
Host string `xorm:"varchar(100)" json:"host"`
Port int `json:"port"`
Hosts []string `xorm:"varchar(1000)" json:"hosts"`
SslMode string `xorm:"varchar(100)" json:"sslMode"`
SslCert string `xorm:"-" json:"sslCert"`
PublicIp string `xorm:"varchar(100)" json:"publicIp"`
Node string `xorm:"varchar(100)" json:"node"`
IsSelf bool `json:"isSelf"`
Status string `xorm:"varchar(100)" json:"status"`
Nodes []*Node `xorm:"mediumtext" json:"nodes"`
Tag string `xorm:"varchar(100)" json:"tag"`
Domain string `xorm:"varchar(100)" json:"domain"`
OtherDomains []string `xorm:"varchar(500)" json:"otherDomains"`
NeedRedirect bool `json:"needRedirect"`
EnableWaf bool `json:"enableWaf"`
Rules []string `xorm:"varchar(500)" json:"wafRuleIds"`
Challenges []string `xorm:"mediumtext" json:"challenges"`
Host string `xorm:"varchar(100)" json:"host"`
Port int `json:"port"`
Hosts []string `xorm:"varchar(1000)" json:"hosts"`
SslMode string `xorm:"varchar(100)" json:"sslMode"`
SslCert string `xorm:"-" json:"sslCert"`
PublicIp string `xorm:"varchar(100)" json:"publicIp"`
Node string `xorm:"varchar(100)" json:"node"`
IsSelf bool `json:"isSelf"`
Status string `xorm:"varchar(100)" json:"status"`
Nodes []*Node `xorm:"mediumtext" json:"nodes"`

CasdoorApplication string `xorm:"varchar(100)" json:"casdoorApplication"`
ApplicationObj *casdoorsdk.Application `xorm:"-" json:"applicationObj"`
Expand Down
12 changes: 0 additions & 12 deletions object/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,3 @@ func GetCertByDomain(domain string) (*Cert, error) {

return nil, nil
}

func GetWafRulesByIds(ids []string) string {
var res string
for _, id := range ids {
if rule, ok := ruleMap[id]; ok {
for _, expression := range rule.Expressions {
res += expression.Value + "\n"
}
}
}
return res
}
29 changes: 16 additions & 13 deletions rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ import (
)

type Rule interface {
checkRule(expressions []*object.Expression, req *http.Request) (bool, string, error)
checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error)
}

func CheckRules(wafRuleIds []string, r *http.Request) (bool, string, error) {
func CheckRules(wafRuleIds []string, r *http.Request) (string, string, error) {
rules := object.GetRulesByRuleIds(wafRuleIds)
for _, rule := range rules {
var ruleObj Rule
Expand All @@ -34,29 +34,32 @@ func CheckRules(wafRuleIds []string, r *http.Request) (bool, string, error) {
ruleObj = &UaRule{}
case "IP":
ruleObj = &IpRule{}
case "WAF":
ruleObj = &WafRule{}
default:
return false, "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId())
return "", "", fmt.Errorf("unknown rule type: %s for rule: %s", rule.Type, rule.GetId())
}

isHit, reason, err := ruleObj.checkRule(rule.Expressions, r)
isHit, action, reason, err := ruleObj.checkRule(rule.Expressions, r)
if err != nil {
return false, "", err
return "", "", err
}
if action == "" {
action = rule.Action
}

if isHit {
if rule.Action == "Block" {
if action == "Block" || action == "Drop" {
if rule.Reason != "" {
reason = rule.Reason
}

return false, reason, nil
} else if rule.Action == "Allow" {
return true, "", nil
return action, reason, nil
} else if action == "Allow" {
return action, reason, nil
} else {
return false, "", fmt.Errorf("unknown rule action: %s for rule: %s", rule.Action, rule.GetId())
return "", "", fmt.Errorf("unknown rule action: %s for rule: %s", action, rule.GetId())
}
}
}

return true, "", nil
return "", "", nil
}
22 changes: 11 additions & 11 deletions rule/rule_ip.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ import (

type IpRule struct{}

func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, error) {
func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
clientIp := util.GetClientIp(req)
netIp, err := parseIp(clientIp)
if err != nil {
return false, "", err
return false, "", "", err
}
for _, expression := range expressions {
reason := fmt.Sprintf("expression matched: \"%s %s %s\"", clientIp, expression.Operator, expression.Value)
Expand All @@ -39,40 +39,40 @@ func (r *IpRule) checkRule(expressions []*object.Expression, req *http.Request)
if strings.Contains(ip, "/") {
_, ipNet, err := net.ParseCIDR(ip)
if err != nil {
return false, "", err
return false, "", "", err
}

switch expression.Operator {
case "is in":
if ipNet.Contains(netIp) {
return true, reason, nil
return true, "", reason, nil
}
case "is not in":
if !ipNet.Contains(netIp) {
return true, reason, nil
return true, "", reason, nil
}
default:
return false, "", fmt.Errorf("unknown operator: %s", expression.Operator)
return false, "", "", fmt.Errorf("unknown operator: %s", expression.Operator)
}
} else if strings.ContainsAny(ip, ".:") {
switch expression.Operator {
case "is in":
if ip == clientIp {
return true, reason, nil
return true, "", reason, nil
}
case "is not in":
if ip != clientIp {
return true, reason, nil
return true, "", reason, nil
}
default:
return false, "", fmt.Errorf("unknown operator: %s", expression.Operator)
return false, "", "", fmt.Errorf("unknown operator: %s", expression.Operator)
}
} else {
return false, "", fmt.Errorf("unknown IP or CIDR format: %s", ip)
return false, "", "", fmt.Errorf("unknown IP or CIDR format: %s", ip)
}
}
}
return false, "", nil
return false, "", "", nil
}

func parseIp(ipStr string) (net.IP, error) {
Expand Down
16 changes: 8 additions & 8 deletions rule/rule_ua.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,39 @@ import (

type UaRule struct{}

func (r *UaRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, error) {
func (r *UaRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
userAgent := req.UserAgent()
for _, expression := range expressions {
ua := expression.Value
reason := fmt.Sprintf("expression matched: \"%s %s %s\"", userAgent, expression.Operator, expression.Value)
switch expression.Operator {
case "contains":
if strings.Contains(userAgent, ua) {
return true, reason, nil
return true, "", reason, nil
}
case "does not contain":
if !strings.Contains(userAgent, ua) {
return true, reason, nil
return true, "", reason, nil
}
case "equals":
if userAgent == ua {
return true, reason, nil
return true, "", reason, nil
}
case "does not equal":
if strings.Compare(userAgent, ua) != 0 {
return true, reason, nil
return true, "", reason, nil
}
case "match":
// regex match
isHit, err := regexp.MatchString(ua, userAgent)
if err != nil {
return false, "", err
return false, "", "", err
}
if isHit {
return true, reason, nil
return true, "", reason, nil
}
}
}

return false, "", nil
return false, "", "", nil
}
97 changes: 97 additions & 0 deletions rule/rule_waf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2024 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package rule

import (
"fmt"
"net/http"

"github.com/casbin/caswaf/conf"
"github.com/casbin/caswaf/object"
"github.com/corazawaf/coraza/v3"
"github.com/corazawaf/coraza/v3/types"
"github.com/hsluoyz/modsecurity-go/seclang/parser"
)

type WafRule struct{}

func (r *WafRule) checkRule(expressions []*object.Expression, req *http.Request) (bool, string, string, error) {
var ruleStr string
for _, expression := range expressions {
ruleStr += expression.Value
}
waf, err := coraza.NewWAF(
coraza.NewWAFConfig().
WithErrorCallback(logError).
WithDirectives(conf.WafConf).
WithDirectives(ruleStr),
)
if err != nil {
return false, "", "", fmt.Errorf("create WAF failed")
}
tx := waf.NewTransaction()
processRequest(tx, req)
matchedRules := tx.MatchedRules()
for _, matchedRule := range matchedRules {
rule := matchedRule.Rule()
directive, err := parser.NewSecLangScannerFromString(rule.Raw()).AllDirective()
if err != nil {
return false, "", "", err
}
for _, d := range directive {
ruleDirective := d.(*parser.RuleDirective)
for _, action := range ruleDirective.Actions.Action {
switch action.Tk {
case parser.TkActionBlock, parser.TkActionDeny:
return true, "Block", fmt.Sprintf("blocked by WAF rule: %d", rule.ID()), nil
case parser.TkActionAllow:
return true, "Allow", "", nil
case parser.TkActionDrop:
return true, "Drop", fmt.Sprintf("dropped by WAF rule: %d", rule.ID()), nil
default:
// skip other actions
continue
}
}
}
}
return false, "", "", nil
}

func processRequest(tx types.Transaction, req *http.Request) {
// Process URI and method
tx.ProcessURI(req.URL.String(), req.Method, req.Proto)

// Process request headers
for key, values := range req.Header {
for _, value := range values {
tx.AddRequestHeader(key, value)
}
}
tx.ProcessRequestHeaders()

// Process request body (if any)
if req.Body != nil {
_, err := tx.ProcessRequestBody()
if err != nil {
return
}
}
}

func logError(error types.MatchedRule) {
msg := error.ErrorLog()
fmt.Printf("[WAFlogError][%s] %s\n", error.Rule().Severity(), msg)
}
Loading

0 comments on commit dc05abe

Please sign in to comment.