Skip to content

Commit

Permalink
Update caddywaf.go
Browse files Browse the repository at this point in the history
Minor improvements.
  • Loading branch information
fabriziosalmi authored Jan 6, 2025
1 parent 73338c6 commit 51878f3
Showing 1 changed file with 120 additions and 23 deletions.
143 changes: 120 additions & 23 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,9 @@ func (m *Middleware) Provision(ctx caddy.Context) error {

// Validate the GeoIP database path
if !fileExists(geoIPPath) {
m.logger.Error("GeoIP database does not exist or is not readable",
zap.String("path", geoIPPath),
)
return fmt.Errorf("GeoIP database does not exist or is not readable: %s", geoIPPath)
}

Expand All @@ -843,12 +846,13 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
)
reader, err := maxminddb.Open(geoIPPath)
if err != nil {
m.logger.Error("Failed to load GeoIP database", zap.String("path", geoIPPath), zap.Error(err))
return fmt.Errorf("failed to load GeoIP database: %w", err) // Wrap the error
m.logger.Error("Failed to load GeoIP database",
zap.String("path", geoIPPath),
zap.Error(err),
)
return fmt.Errorf("failed to load GeoIP database: %w", err)
}

// REMOVE defer reader.Close() HERE

// Share the GeoIP database between CountryBlock and CountryWhitelist
if m.CountryBlock.Enabled {
m.CountryBlock.geoIP = reader
Expand All @@ -869,7 +873,11 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
zap.String("file", file),
)
if err := m.loadRulesFromFile(file); err != nil {
return fmt.Errorf("failed to load rules from %s: %w", file, err) // Wrap the error
m.logger.Error("Failed to load rules from file",
zap.String("file", file),
zap.Error(err),
)
return fmt.Errorf("failed to load rules from %s: %w", file, err)
}
}

Expand Down Expand Up @@ -949,7 +957,11 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
zap.String("file", m.IPBlacklistFile),
)
if err := m.loadIPBlacklistFromFile(m.IPBlacklistFile); err != nil {
return fmt.Errorf("failed to load IP blacklist from %s: %w", m.IPBlacklistFile, err) // Wrap the error
m.logger.Error("Failed to load IP blacklist from file",
zap.String("file", m.IPBlacklistFile),
zap.Error(err),
)
return fmt.Errorf("failed to load IP blacklist from %s: %w", m.IPBlacklistFile, err)
}
} else {
m.ipBlacklist = make(map[string]bool)
Expand All @@ -962,7 +974,11 @@ func (m *Middleware) Provision(ctx caddy.Context) error {
zap.String("file", m.DNSBlacklistFile),
)
if err := m.loadDNSBlacklistFromFile(m.DNSBlacklistFile); err != nil {
return fmt.Errorf("failed to load DNS blacklist from %s: %w", m.DNSBlacklistFile, err) // Wrap the error
m.logger.Error("Failed to load DNS blacklist from file",
zap.String("file", m.DNSBlacklistFile),
zap.Error(err),
)
return fmt.Errorf("failed to load DNS blacklist from %s: %w", m.DNSBlacklistFile, err)
}
} else {
m.dnsBlacklist = []string{}
Expand Down Expand Up @@ -992,23 +1008,30 @@ func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {

// Early return if the blacklist is empty
if len(m.ipBlacklist) == 0 {
m.logger.Debug("IP blacklist is empty, skipping check")
return false
}

// Extract and validate the IP from the remote address
ipStr := extractIP(remoteAddr)
if ipStr == "" {
m.logger.Warn("Failed to extract IP from remote address",
zap.String("remoteAddr", remoteAddr),
)
return false
}

ip := net.ParseIP(ipStr)
if ip == nil {
m.logger.Warn("Invalid IP address extracted",
zap.String("ipStr", ipStr),
)
return false
}

// Check if the IP is directly blacklisted
if m.ipBlacklist[ipStr] {
m.logger.Debug("IP is directly blacklisted",
m.logger.Info("IP is directly blacklisted",
zap.String("ip", ipStr),
)
return true
Expand All @@ -1033,14 +1056,17 @@ func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {

// Check if the IP falls within the CIDR range
if ipNet.Contains(ip) {
m.logger.Debug("IP falls within a blacklisted CIDR range",
m.logger.Info("IP falls within a blacklisted CIDR range",
zap.String("ip", ipStr),
zap.String("cidr", blacklistEntry),
)
return true
}
}

m.logger.Debug("IP is not blacklisted",
zap.String("ip", ipStr),
)
return false
}

Expand All @@ -1051,27 +1077,32 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {

// Early return if the blacklist is empty or nil
if len(m.dnsBlacklist) == 0 {
m.logger.Debug("DNS blacklist is empty, skipping check")
return false
}

// Normalize the host to lowercase and trim whitespace
host = strings.ToLower(strings.TrimSpace(host))
if host == "" {
m.logger.Warn("Empty host provided for DNS blacklist check")
return false
}

// Check if the host is an exact match to any blacklisted domain
for _, blacklistedDomain := range m.dnsBlacklist {
blacklistedDomain = strings.ToLower(strings.TrimSpace(blacklistedDomain)) // Normalize blacklisted domain as well
if host == blacklistedDomain {
m.logger.Debug("Host is blacklisted",
m.logger.Info("Host is blacklisted",
zap.String("host", host),
zap.String("blacklisted_domain", blacklistedDomain),
)
return true
}
}

m.logger.Debug("Host is not blacklisted",
zap.String("host", host),
)
return false
}

Expand Down Expand Up @@ -1181,13 +1212,21 @@ func validateRule(rule *Rule) error {
}

func (m *Middleware) loadIPBlacklistFromFile(path string) error {
// Acquire a write lock to protect shared state
m.mu.Lock()
defer m.mu.Unlock()

// Initialize the IP blacklist
m.ipBlacklist = make(map[string]bool)

// Log the attempt to load the IP blacklist file
m.logger.Debug("Loading IP blacklist from file",
zap.String("file", path),
)

// Attempt to read the file
content, err := os.ReadFile(path)
if err != nil {
// Log a warning and continue with an empty blacklist
m.logger.Warn("Failed to read IP blacklist file",
zap.String("file", path),
zap.Error(err),
Expand All @@ -1197,6 +1236,8 @@ func (m *Middleware) loadIPBlacklistFromFile(path string) error {

// Split the file content into lines
lines := strings.Split(string(content), "\n")
validEntries := 0

for i, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
Expand All @@ -1207,6 +1248,7 @@ func (m *Middleware) loadIPBlacklistFromFile(path string) error {
if _, _, err := net.ParseCIDR(line); err == nil {
// It's a valid CIDR range
m.ipBlacklist[line] = true
validEntries++
m.logger.Debug("Added CIDR range to blacklist",
zap.String("cidr", line),
)
Expand All @@ -1216,6 +1258,7 @@ func (m *Middleware) loadIPBlacklistFromFile(path string) error {
if ip := net.ParseIP(line); ip != nil {
// It's a valid IP address
m.ipBlacklist[line] = true
validEntries++
m.logger.Debug("Added IP to blacklist",
zap.String("ip", line),
)
Expand All @@ -1232,22 +1275,28 @@ func (m *Middleware) loadIPBlacklistFromFile(path string) error {

m.logger.Info("IP blacklist loaded successfully",
zap.String("file", path),
zap.Int("count", len(m.ipBlacklist)),
zap.Int("valid_entries", validEntries),
zap.Int("total_lines", len(lines)),
)
return nil
}

func (m *Middleware) loadDNSBlacklistFromFile(path string) error {
// Acquire a write lock to protect shared state
m.mu.Lock()
defer m.mu.Unlock()

// Initialize an empty DNS blacklist
m.dnsBlacklist = []string{}

// Log the attempt to load the DNS blacklist file
m.logger.Debug("Loading DNS blacklist from file", zap.String("file", path))
m.logger.Debug("Loading DNS blacklist from file",
zap.String("file", path),
)

// Attempt to read the file
content, err := os.ReadFile(path)
if err != nil {
// Log a warning and continue with an empty blacklist
m.logger.Warn("Failed to read DNS blacklist file",
zap.String("file", path),
zap.Error(err),
Expand All @@ -1274,7 +1323,8 @@ func (m *Middleware) loadDNSBlacklistFromFile(path string) error {
// Log the successful loading of the DNS blacklist
m.logger.Info("DNS blacklist loaded successfully",
zap.String("file", path),
zap.Int("entries_loaded", len(validEntries)),
zap.Int("valid_entries", len(validEntries)),
zap.Int("total_lines", len(lines)),
)

return nil
Expand All @@ -1285,22 +1335,46 @@ func (m *Middleware) ReloadConfig() error {
m.mu.Lock()
defer m.mu.Unlock()

// Log the start of the reload process
m.logger.Info("Reloading WAF configuration")

// Reload rules
if err := m.loadRulesFromFiles(); err != nil {
m.logger.Error("Failed to reload rules",
zap.Error(err),
)
return fmt.Errorf("failed to reload rules: %v", err)
}

// Reload IP blacklist
if err := m.loadIPBlacklistFromFile(m.IPBlacklistFile); err != nil {
return fmt.Errorf("failed to reload IP blacklist: %v", err)
if m.IPBlacklistFile != "" {
if err := m.loadIPBlacklistFromFile(m.IPBlacklistFile); err != nil {
m.logger.Error("Failed to reload IP blacklist",
zap.String("file", m.IPBlacklistFile),
zap.Error(err),
)
return fmt.Errorf("failed to reload IP blacklist: %v", err)
}
} else {
m.logger.Debug("No IP blacklist file specified, skipping reload")
}

// Reload DNS blacklist
if err := m.loadDNSBlacklistFromFile(m.DNSBlacklistFile); err != nil {
return fmt.Errorf("failed to reload DNS blacklist: %v", err)
if m.DNSBlacklistFile != "" {
if err := m.loadDNSBlacklistFromFile(m.DNSBlacklistFile); err != nil {
m.logger.Error("Failed to reload DNS blacklist",
zap.String("file", m.DNSBlacklistFile),
zap.Error(err),
)
return fmt.Errorf("failed to reload DNS blacklist: %v", err)
}
} else {
m.logger.Debug("No DNS blacklist file specified, skipping reload")
}

// Update shared state while holding the lock
// Log the successful completion of the reload process
m.logger.Info("WAF configuration reloaded successfully")

return nil
}

Expand Down Expand Up @@ -1355,16 +1429,33 @@ func (m *Middleware) loadRulesFromFiles() error {
}

func (m *Middleware) loadRulesFromFile(path string) error {
// Acquire a write lock to protect shared state
m.mu.Lock()
defer m.mu.Unlock()

// Log the attempt to load the rule file
m.logger.Debug("Loading rules from file",
zap.String("file", path),
)

// Read the rule file
content, err := os.ReadFile(path)
if err != nil {
m.logger.Error("Failed to read rule file",
zap.String("file", path),
zap.Error(err),
)
return fmt.Errorf("failed to read rule file: %s, error: %v", path, err)
}

// Unmarshal the JSON content into a slice of Rule structs
var rules []Rule
if err := json.Unmarshal(content, &rules); err != nil {
return fmt.Errorf("failed to unmarshal rules from file: %s, error: %v. Ensure the file contains valid JSON for a list of WAF rules.", path, err)
m.logger.Error("Failed to unmarshal rules from file",
zap.String("file", path),
zap.Error(err),
)
return fmt.Errorf("failed to unmarshal rules from file: %s, error: %v. Ensure the file contains valid JSON for a list of WAF rules", path, err)
}

var invalidRules []string // Track invalid rules for logging
Expand Down Expand Up @@ -1406,11 +1497,10 @@ func (m *Middleware) loadRulesFromFile(path string) error {
// Compile the regex pattern and log detailed errors if it fails
regex, err := regexp.Compile(rule.Pattern)
if err != nil {
// Log the exact error with context
m.logger.Error("Failed to compile regex pattern for rule",
zap.String("rule_id", rule.ID),
zap.String("pattern", rule.Pattern),
zap.Error(err), // Log the exact error from regexp.Compile
zap.Error(err),
)
invalidRules = append(invalidRules, fmt.Sprintf("Rule '%s': invalid regex pattern '%s'. Error: %v. Ensure the pattern is a valid regular expression.", rule.ID, rule.Pattern, err))
continue
Expand All @@ -1434,5 +1524,12 @@ func (m *Middleware) loadRulesFromFile(path string) error {
)
}

// Log the successful loading of the rule file
m.logger.Info("Rules loaded successfully",
zap.String("file", path),
zap.Int("total_rules", len(rules)),
zap.Int("invalid_rules", len(invalidRules)),
)

return nil
}

0 comments on commit 51878f3

Please sign in to comment.