diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f4a98f94..ac94b663 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,11 +5,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: 1.21.x - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Run linters uses: golangci/golangci-lint-action@v6 with: @@ -24,23 +24,26 @@ jobs: steps: - name: Install Go if: success() - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: go tests - run: go test -v -covermode=count -json ./... | tee test.json + run: go test -v -covermode=count -json ./... > test.json + - name: Print go test results + if: always() + run: cat test.json - name: annotate go tests if: always() - uses: guyarb/golang-test-annotations@v0.5.1 + uses: guyarb/golang-test-annotations@v0.8.0 with: test-results: test.json buf-lint-and-breaking-change-detection: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Setup uses: bufbuild/buf-setup-action@v1 with: diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index ca380b42..67f5b3c3 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Setup uses: bufbuild/buf-setup-action@v1 with: @@ -25,13 +25,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Install Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: 1.21.x - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Run linters - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: version: latest args: --timeout=3m @@ -44,15 +44,18 @@ jobs: steps: - name: Install Go if: success() - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: go tests - run: go test -v -covermode=count -json ./... | tee test.json + run: go test -v -covermode=count -json ./... > test.json + - name: Print go test results + if: always() + run: cat test.json - name: annotate go tests if: always() - uses: guyarb/golang-test-annotations@v0.5.1 + uses: guyarb/golang-test-annotations@v0.8.0 with: test-results: test.json diff --git a/go.mod b/go.mod index 724a56b9..d2a2bc58 100644 --- a/go.mod +++ b/go.mod @@ -74,6 +74,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20240408141607-282e7b5d6b74 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect diff --git a/go.sum b/go.sum index b34351be..930ac6fd 100644 --- a/go.sum +++ b/go.sum @@ -140,8 +140,9 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= diff --git a/pkg/uhttp/client.go b/pkg/uhttp/client.go index 6b800c58..2fcb9b21 100644 --- a/pkg/uhttp/client.go +++ b/pkg/uhttp/client.go @@ -2,8 +2,12 @@ package uhttp import ( "context" + "crypto/sha256" "crypto/tls" + "fmt" "net/http" + "sort" + "strings" "time" "go.uber.org/zap" @@ -72,3 +76,52 @@ func NewClient(ctx context.Context, options ...Option) (*http.Client, error) { httpClient.Transport = t return httpClient, nil } + +type icache interface { + Get(req *http.Request) (*http.Response, error) + Set(req *http.Request, value *http.Response) error +} + +// CreateCacheKey generates a cache key based on the request URL, query parameters, and headers. +func CreateCacheKey(req *http.Request) (string, error) { + if req == nil { + return "", fmt.Errorf("request is nil") + } + var sortedParams []string + // Normalize the URL path + path := strings.ToLower(req.URL.Path) + // Combine the path with sorted query parameters + queryParams := req.URL.Query() + for k, v := range queryParams { + for _, value := range v { + sortedParams = append(sortedParams, fmt.Sprintf("%s=%s", k, value)) + } + } + + sort.Strings(sortedParams) + queryString := strings.Join(sortedParams, "&") + // Include relevant headers in the cache key + var headerParts []string + for key, values := range req.Header { + for _, value := range values { + if key == "Accept" || key == "Content-Type" || key == "Cookie" || key == "Range" { + headerParts = append(headerParts, fmt.Sprintf("%s=%s", key, value)) + } + } + } + + sort.Strings(headerParts) + headersString := strings.Join(headerParts, "&") + // Create a unique string for the cache key + cacheString := fmt.Sprintf("%s?%s&headers=%s", path, queryString, headersString) + + // Hash the cache string to create a key + hash := sha256.New() + _, err := hash.Write([]byte(cacheString)) + if err != nil { + return "", err + } + + cacheKey := fmt.Sprintf("%x", hash.Sum(nil)) + return cacheKey, nil +} diff --git a/pkg/uhttp/dbcache.go b/pkg/uhttp/dbcache.go new file mode 100644 index 00000000..8635c47f --- /dev/null +++ b/pkg/uhttp/dbcache.go @@ -0,0 +1,577 @@ +package uhttp + +import ( + "bufio" + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httputil" + "os" + "path/filepath" + "time" + + "github.com/doug-martin/goqu/v9" + // NOTE: required to register the dialect for goqu. + // + // If you remove this import, goqu.Dialect("sqlite3") will + // return a copy of the default dialect, which is not what we want, + // and allocates a ton of memory. + _ "github.com/doug-martin/goqu/v9/dialect/sqlite3" + _ "github.com/glebarez/go-sqlite" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" +) + +type DBCache struct { + rawDb *sql.DB + db *goqu.Database + // Cleanup interval, close and remove db + waitDuration time.Duration + // Cache duration for removing expired keys + expirationTime time.Duration + // Database path + location string + // Enable statistics(hits, misses) + stats bool +} + +type CacheRow struct { + Key string + Value []byte + Expires time.Time + LastAccess time.Time + Url string +} + +type Stats struct { + // Hits is a number of successfully found keys + Hits int64 `json:"hits"` + // Misses is a number of not found keys + Misses int64 `json:"misses"` +} + +// SqliteError implement sqlite error code. +type SqliteError struct { + Code int `json:"Code,omitempty"` /* The error code returned by SQLite */ + ExtendedCode int `json:"ExtendedCode,omitempty"` /* The extended error code returned by SQLite */ + err string +} + +func (b *SqliteError) Error() string { + return b.err +} + +const ( + failStartTransaction = "Failed to start a transaction" + nilConnection = "Database connection is nil" + errQueryingTable = "Error querying cache table" + failRollback = "Failed to rollback transaction" + failInsert = "Failed to insert response data into cache table" + staticQuery = "INSERT INTO http_stats(key, %s) values(?, 1)" + failScanResponse = "Failed to scan rows for cached response" + cacheTTLThreshold = 60 + cacheTTLMultiplier int64 = 5 +) + +var defaultWaitDuration = cacheTTLThreshold * time.Second // Default Cleanup interval, 60 seconds + +const tableName = "http_cache" + +func NewDBCache(ctx context.Context, cfg CacheConfig) (*DBCache, error) { + var ( + err error + dc = &DBCache{ + waitDuration: defaultWaitDuration, // Default Cleanup interval, 60 seconds + stats: cfg.LogDebug, + } + ) + l := ctxzap.Extract(ctx) + dc, err = dc.load(ctx) + if err != nil { + l.Debug("Failed to open database", zap.Error(err)) + return nil, err + } + + // Create cache table and index + _, err = dc.db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS http_cache( + key TEXT PRIMARY KEY, + value BLOB, + expires INT, + lastAccess TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + url TEXT + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_cache_key ON http_cache (key); + CREATE INDEX IF NOT EXISTS expires ON http_cache (expires); + CREATE INDEX IF NOT EXISTS lastAccess ON http_cache (lastAccess); + CREATE TABLE IF NOT EXISTS http_stats( + id INTEGER PRIMARY KEY, + key TEXT, + hits INT DEFAULT 0, + misses INT DEFAULT 0 + ); + DELETE FROM http_cache; + DELETE FROM http_stats;`) + if err != nil { + l.Debug("Failed to create cache table in database", zap.Error(err)) + return nil, err + } + + if cfg.CacheTTL <= 0 { + l.Debug("Cache TTL is 0. Disabling cache.") + return nil, nil + } + + if cfg.CacheTTL > cacheTTLThreshold { + dc.waitDuration = time.Duration(cfg.CacheTTL*cacheTTLMultiplier) * time.Second // set as a fraction of the Cache TTL + } + + dc.expirationTime = time.Duration(cfg.CacheTTL) * time.Second // time for removing expired key + + go func(waitDuration, expirationTime time.Duration) { + ctxWithTimeout, cancel := context.WithTimeout( + ctx, + waitDuration, + ) + defer cancel() + // TODO: I think this should be wait duration + ticker := time.NewTicker(expirationTime) + defer ticker.Stop() + for { + select { + case <-ctxWithTimeout.Done(): + // ctx done, shutting down cache cleanup routine + ticker.Stop() + err := dc.cleanup(ctx) + if err != nil { + l.Debug("shutting down cache failed", zap.Error(err)) + } + return + case <-ticker.C: + err := dc.deleteExpired(ctx) + if err != nil { + l.Debug("Failed to delete expired cache entries", zap.Error(err)) + } + } + } + }(dc.waitDuration, dc.expirationTime) + + return dc, nil +} + +func (d *DBCache) load(ctx context.Context) (*DBCache, error) { + l := ctxzap.Extract(ctx) + cacheDir, err := os.UserCacheDir() + if err != nil { + l.Debug("Failed to read user cache directory", zap.Error(err)) + return nil, err + } + + file := filepath.Join(cacheDir, "lcache.db") + d.location = file + + rawDB, err := sql.Open("sqlite", file) + if err != nil { + return nil, err + } + + d.db = goqu.New("sqlite3", rawDB) + d.rawDb = rawDB + return d, nil +} + +func checkFileExists(filePath string) bool { + _, err := os.Stat(filePath) + return !errors.Is(err, os.ErrNotExist) +} + +func (d *DBCache) removeDB(ctx context.Context) error { + if !checkFileExists(d.location) { + return fmt.Errorf("file not found %s", d.location) + } + + err := d.close(ctx) + if err != nil { + return err + } + // TODO: close DB so no file handles exist and we can delete the file on windows + err = os.Remove(d.location) + if err != nil { + ctxzap.Extract(ctx).Debug("error removing database", zap.Error(err)) + return err + } + + return nil +} + +// Get returns cached response (if exists). +func (d *DBCache) Get(req *http.Request) (*http.Response, error) { + var ( + isFound bool = false + resp *http.Response + ) + if d.IsNilConnection() { + return nil, fmt.Errorf("%s", nilConnection) + } + + key, err := CreateCacheKey(req) + if err != nil { + return nil, err + } + ctx := req.Context() + + entry, err := d.pick(ctx, key) + if err == nil && len(entry) > 0 { + r := bufio.NewReader(bytes.NewReader(entry)) + resp, err = http.ReadResponse(r, nil) + if err != nil { + return nil, err + } + + isFound = true + } + + if d.stats { + if isFound { + err = d.hits(ctx, key) + if err != nil { + ctxzap.Extract(ctx).Debug("Failed to update cache hits", zap.Error(err)) + } + } + + err = d.misses(ctx, key) + if err != nil { + ctxzap.Extract(ctx).Debug("Failed to update cache misses", zap.Error(err)) + } + } + + return resp, nil +} + +// Set stores and save response in the db. +func (d *DBCache) Set(req *http.Request, value *http.Response) error { + key, err := CreateCacheKey(req) + if err != nil { + return err + } + var url string + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + cacheableResponse, err := httputil.DumpResponse(value, true) + if err != nil { + return err + } + + if value.Request != nil { + url = value.Request.URL.String() + } + + err = d.insert(req.Context(), + key, + cacheableResponse, + url, + ) + if err != nil { + return err + } + + return nil +} + +func (d *DBCache) cleanup(ctx context.Context) error { + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + l := ctxzap.Extract(ctx) + stats, err := d.getStats(ctx) + if err != nil { + l.Debug("error getting stats", zap.Error(err)) + return err + } + + l.Debug("summary and stats", zap.Any("stats", stats)) + err = d.close(ctx) + if err != nil { + l.Debug("error closing db", zap.Error(err)) + return err + } + + err = d.removeDB(ctx) + if err != nil { + l.Debug("error removing db", zap.Error(err)) + return err + } + + return nil +} + +// Insert data into the cache table. +func (d *DBCache) insert(ctx context.Context, key string, value any, url string) error { + var ( + bytes []byte + err error + ok bool + ) + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + l := ctxzap.Extract(ctx) + if bytes, ok = value.([]byte); !ok { + bytes, err = json.Marshal(value) + if err != nil { + l.Debug("Failed to marshal response data", zap.Error(err)) + return err + } + } + + tx, err := d.db.Begin() + if err != nil { + l.Debug(failStartTransaction, zap.Error(err)) + return err + } + + ds := goqu.Insert(tableName).Rows( + CacheRow{ + Key: key, + Value: bytes, + Expires: time.Now().Add(d.expirationTime), + Url: url, + }, + ) + ds = ds.OnConflict(goqu.DoUpdate("key", CacheRow{ + Value: bytes, + Expires: time.Now().Add(d.expirationTime), + Url: url, + })) + insertSQL, args, err := ds.ToSQL() + if err != nil { + l.Debug("Failed to create insert statement", zap.Error(err)) + return err + } + _, err = tx.ExecContext(ctx, insertSQL, args...) + if err != nil { + if errtx := tx.Rollback(); errtx != nil { + l.Debug(failRollback, zap.Error(errtx)) + } + + l.Debug(failInsert, zap.Error(err)) + return err + } + + err = tx.Commit() + if err != nil { + if errtx := tx.Rollback(); errtx != nil { + l.Debug(failRollback, zap.Error(errtx)) + } + + l.Debug(failInsert, zap.Error(err)) + return err + } + + return nil +} + +// IsNilConnection check if the database connection is nil. +func (d *DBCache) IsNilConnection() bool { + return d.db == nil +} + +// pick query for cached response. +func (d *DBCache) pick(ctx context.Context, key string) ([]byte, error) { + var data []byte + if d.IsNilConnection() { + return nil, fmt.Errorf("%s", nilConnection) + } + + l := ctxzap.Extract(ctx) + rows, err := d.db.QueryContext(ctx, "SELECT value FROM http_cache where key = ?", key) + if err != nil { + l.Debug(errQueryingTable, zap.Error(err)) + return nil, err + } + + defer rows.Close() + for rows.Next() { + err = rows.Scan(&data) + if err != nil { + l.Debug(failScanResponse, zap.Error(err)) + return nil, err + } + } + + return data, nil +} + +func (d *DBCache) remove(ctx context.Context) error { + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + l := ctxzap.Extract(ctx) + tx, err := d.db.Begin() + if err != nil { + l.Debug(failStartTransaction, zap.Error(err)) + return err + } + + _, err = d.db.ExecContext(ctx, "DELETE FROM http_cache WHERE expires < ?", time.Now().UnixNano()) + if err != nil { + if errtx := tx.Rollback(); errtx != nil { + l.Debug(failRollback, zap.Error(errtx)) + } + + l.Debug("Failed to delete cache key", zap.Error(err)) + return err + } + + err = tx.Commit() + if err != nil { + if errtx := tx.Rollback(); errtx != nil { + l.Debug(failRollback, zap.Error(errtx)) + } + + l.Debug("Failed to remove cache entry", zap.Error(err)) + return err + } + + return nil +} + +func (d *DBCache) close(ctx context.Context) error { + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + err := d.rawDb.Close() + if err != nil { + ctxzap.Extract(ctx).Debug("Failed to close database connection", zap.Error(err)) + return err + } + + return nil +} + +// Delete all expired items from the cache. +func (d *DBCache) deleteExpired(ctx context.Context) error { + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + l := ctxzap.Extract(ctx) + err := d.remove(ctx) + if err != nil { + l.Debug("error removing rows", zap.Error(err)) + } + + return nil +} + +func (d *DBCache) hits(ctx context.Context, key string) error { + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + strField := "hits" + err := d.update(ctx, strField, key) + if err != nil { + return err + } + + return nil +} + +func (d *DBCache) misses(ctx context.Context, key string) error { + if d.IsNilConnection() { + return fmt.Errorf("%s", nilConnection) + } + + strField := "misses" + err := d.update(ctx, strField, key) + if err != nil { + return err + } + + return nil +} + +func (d *DBCache) update(ctx context.Context, field, key string) error { + l := ctxzap.Extract(ctx) + tx, err := d.db.Begin() + if err != nil { + l.Debug(failStartTransaction, zap.Error(err)) + return err + } + + query, args := d.queryString(field) + _, err = d.db.ExecContext(ctx, fmt.Sprintf(query, args...), key) + if err != nil { + if errtx := tx.Rollback(); errtx != nil { + l.Debug(failRollback, zap.Error(errtx)) + } + + l.Debug("error updating "+field, zap.Error(err)) + return err + } + + err = tx.Commit() + if err != nil { + if errtx := tx.Rollback(); errtx != nil { + l.Debug(failRollback, zap.Error(errtx)) + } + + l.Debug("Failed to update "+field, zap.Error(err)) + return err + } + + return nil +} + +func (d *DBCache) queryString(field string) (string, []interface{}) { + return staticQuery, []interface{}{ + fmt.Sprint(field), + } +} + +func (d *DBCache) getStats(ctx context.Context) (Stats, error) { + var ( + hits = 0 + misses = 0 + ) + if d.IsNilConnection() { + return Stats{}, fmt.Errorf("%s", nilConnection) + } + + l := ctxzap.Extract(ctx) + rows, err := d.db.QueryContext(ctx, ` + SELECT + sum(hits) total_hits, + sum(misses) total_misses + FROM http_stats + `) + if err != nil { + l.Debug(errQueryingTable, zap.Error(err)) + return Stats{}, err + } + + defer rows.Close() + for rows.Next() { + err = rows.Scan(&hits, &misses) + if err != nil { + l.Debug(failScanResponse, zap.Error(err)) + return Stats{}, err + } + } + + return Stats{ + Hits: int64(hits), + Misses: int64(misses), + }, nil +} diff --git a/pkg/uhttp/dbcache_test.go b/pkg/uhttp/dbcache_test.go new file mode 100644 index 00000000..74d8a742 --- /dev/null +++ b/pkg/uhttp/dbcache_test.go @@ -0,0 +1,78 @@ +package uhttp + +import ( + "encoding/json" + "net/http" + "testing" + + _ "github.com/doug-martin/goqu/v9/dialect/sqlite3" + _ "github.com/glebarez/go-sqlite" + "github.com/stretchr/testify/require" +) + +var urlTest = "https://jsonplaceholder.typicode.com/posts/1/comments" + +func TestDBCacheGettersAndSetters(t *testing.T) { + cli := &http.Client{} + fc, err := getDBCacheForTesting() + require.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlTest, nil) + require.Nil(t, err) + require.NotNil(t, req) + + resp, err := cli.Do(req) + require.Nil(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + var ic icache = &DBCache{ + db: fc.db, + } + cKey, err := CreateCacheKey(resp.Request) + require.Nil(t, err) + require.NotEmpty(t, cKey) + + err = ic.Set(req, resp) + require.Nil(t, err) + + res, err := ic.Get(req) + require.Nil(t, err) + require.NotNil(t, res) + require.Equal(t, resp.StatusCode, res.StatusCode) + require.Equal(t, resp.ContentLength, res.ContentLength) + require.EqualValues(t, resp.Header, res.Header) + + err = ic.Set(req, resp) + require.Nil(t, err, "Setting same cache key again should not error") + + defer res.Body.Close() +} + +func TestDBCache(t *testing.T) { + fc, err := getDBCacheForTesting() + require.Nil(t, err) + + err = fc.insert(ctx, "urlTest", urlTest, "http://example.com") + require.Nil(t, err) + + res, err := fc.pick(ctx, "urlTest") + require.Nil(t, err) + require.NotNil(t, res) + + var val string + err = json.Unmarshal(res, &val) + require.Nil(t, err) + require.Equal(t, val, urlTest) +} + +func getDBCacheForTesting() (*DBCache, error) { + fc, err := NewDBCache(ctx, CacheConfig{ + CacheTTL: 3600, + }) + if err != nil { + return nil, err + } + + return fc, nil +} diff --git a/pkg/uhttp/gocache.go b/pkg/uhttp/gocache.go index c29d4296..826467ae 100644 --- a/pkg/uhttp/gocache.go +++ b/pkg/uhttp/gocache.go @@ -4,12 +4,11 @@ import ( "bufio" "bytes" "context" - "crypto/sha256" "fmt" "net/http" "net/http/httputil" - "sort" - "strings" + "os" + "strconv" "time" bigCache "github.com/allegro/bigcache/v3" @@ -17,16 +16,136 @@ import ( "go.uber.org/zap" ) +const ( + cacheTTLMaximum = 31536000 // 31536000 seconds = one year + cacheTTLDefault = 3600 // 3600 seconds = one hour + defaultCacheSize = 50 // MB +) + +type CacheConfig struct { + LogDebug bool + CacheTTL int64 // If 0, cache is disabled + CacheMaxSize int +} +type ContextKey struct{} + type GoCache struct { rootLibrary *bigCache.BigCache } -func NewGoCache(ctx context.Context, cfg CacheConfig) (GoCache, error) { +type NoopCache struct{} + +func NewNoopCache(ctx context.Context) *NoopCache { + return &NoopCache{} +} + +func (g *NoopCache) Get(req *http.Request) (*http.Response, error) { + return nil, nil +} + +func (n *NoopCache) Set(req *http.Request, value *http.Response) error { + return nil +} + +func (cc *CacheConfig) ToString() string { + return fmt.Sprintf("CacheTTL: %d, CacheMaxSize: %d, LogDebug: %t", cc.CacheTTL, cc.CacheMaxSize, cc.LogDebug) +} + +func DefaultCacheConfig() CacheConfig { + return CacheConfig{ + CacheTTL: cacheTTLDefault, + CacheMaxSize: defaultCacheSize, + LogDebug: false, + } +} + +func NewCacheConfigFromEnv() *CacheConfig { + config := DefaultCacheConfig() + + cacheMaxSize, err := strconv.ParseInt(os.Getenv("BATON_HTTP_CACHE_MAX_SIZE"), 10, 64) + if err == nil { + config.CacheMaxSize = int(cacheMaxSize) + } + + // read the `BATON_HTTP_CACHE_TTL` environment variable and return + // the value as a number of seconds between 0 and an arbitrary maximum. Note: + // this means that passing a value of `-1` will set the TTL to zero rather than + // infinity. + cacheTTL, err := strconv.ParseInt(os.Getenv("BATON_HTTP_CACHE_TTL"), 10, 64) + if err == nil { + config.CacheTTL = min(cacheTTLMaximum, max(0, cacheTTL)) + } + + return &config +} + +func NewCacheConfigFromCtx(ctx context.Context) (*CacheConfig, error) { + defaultConfig := DefaultCacheConfig() + if v := ctx.Value(ContextKey{}); v != nil { + ctxConfig, ok := v.(CacheConfig) + if !ok { + return nil, fmt.Errorf("error casting config values from context") + } + return &ctxConfig, nil + } + return &defaultConfig, nil +} + +func NewHttpCache(ctx context.Context, config *CacheConfig) (icache, error) { l := ctxzap.Extract(ctx) - if cfg.DisableCache { - l.Debug("http cache disabled") - return GoCache{}, nil + + var noopCache icache = &NoopCache{} + cache := noopCache + + if config == nil { + config = NewCacheConfigFromEnv() + } + + if config.CacheTTL <= 0 { + l.Debug("CacheTTL is <=0, disabling cache.", zap.Int64("CacheTTL", config.CacheTTL)) + return noopCache, nil + } + + disableCache, err := strconv.ParseBool(os.Getenv("BATON_DISABLE_HTTP_CACHE")) + if err != nil { + disableCache = false + } + if disableCache { + l.Debug("BATON_DISABLE_HTTP_CACHE set, disabling cache.") + return noopCache, nil } + + cacheBackend := os.Getenv("BATON_HTTP_CACHE_BACKEND") + if cacheBackend == "" { + l.Debug("defaulting to db-cache") + cacheBackend = "db" + } + + switch cacheBackend { + case "memory": + l.Debug("Using in-memory cache") + memCache, err := NewGoCache(ctx, *config) + if err != nil { + l.Error("error creating http cache (in-memory)", zap.Error(err)) + return nil, err + } + cache = memCache + case "db": + l.Debug("Using db cache") + dbCache, err := NewDBCache(ctx, *config) + if err != nil { + l.Error("error creating http cache (db-cache)", zap.Error(err)) + return nil, err + } + cache = dbCache + } + + return cache, nil +} + +func NewGoCache(ctx context.Context, cfg CacheConfig) (*GoCache, error) { + l := ctxzap.Extract(ctx) + gc := GoCache{} config := bigCache.DefaultConfig(time.Duration(cfg.CacheTTL) * time.Second) config.Verbose = cfg.LogDebug config.Shards = 4 @@ -34,7 +153,7 @@ func NewGoCache(ctx context.Context, cfg CacheConfig) (GoCache, error) { cache, err := bigCache.New(ctx, config) if err != nil { l.Error("http cache initialization error", zap.Error(err)) - return GoCache{}, err + return nil, err } l.Debug("http cache config", @@ -48,11 +167,9 @@ func NewGoCache(ctx context.Context, cfg CacheConfig) (GoCache, error) { zap.Bool("Verbose", config.Verbose), zap.Int("HardMaxCacheSize", config.HardMaxCacheSize), )) - gc := GoCache{ - rootLibrary: cache, - } + gc.rootLibrary = cache - return gc, nil + return &gc, nil } func (g *GoCache) Statistics() bigCache.Stats { @@ -63,54 +180,14 @@ func (g *GoCache) Statistics() bigCache.Stats { return g.rootLibrary.Stats() } -// CreateCacheKey generates a cache key based on the request URL, query parameters, and headers. -// The key is a SHA-256 hash of the normalized URL path, sorted query parameters, and relevant headers. -func CreateCacheKey(req *http.Request) (string, error) { - // Normalize the URL path - path := strings.ToLower(req.URL.Path) - - // Combine the path with sorted query parameters - queryParams := req.URL.Query() - var sortedParams []string - for k, v := range queryParams { - for _, value := range v { - sortedParams = append(sortedParams, fmt.Sprintf("%s=%s", k, value)) - } - } - - sort.Strings(sortedParams) - queryString := strings.Join(sortedParams, "&") - - // Include relevant headers in the cache key - var headerParts []string - for key, values := range req.Header { - for _, value := range values { - if key == "Accept" || key == "Authorization" || key == "Cookie" || key == "Range" { - headerParts = append(headerParts, fmt.Sprintf("%s=%s", key, value)) - } - } +func (g *GoCache) Get(req *http.Request) (*http.Response, error) { + if g.rootLibrary == nil { + return nil, nil } - sort.Strings(headerParts) - headersString := strings.Join(headerParts, "&") - - // Create a unique string for the cache key - cacheString := fmt.Sprintf("%s?%s&headers=%s", path, queryString, headersString) - - // Hash the cache string to create a key - hash := sha256.New() - _, err := hash.Write([]byte(cacheString)) + key, err := CreateCacheKey(req) if err != nil { - return "", err - } - - cacheKey := fmt.Sprintf("%x", hash.Sum(nil)) - return cacheKey, nil -} - -func (g *GoCache) Get(key string) (*http.Response, error) { - if g.rootLibrary == nil { - return nil, nil + return nil, err } entry, err := g.rootLibrary.Get(key) @@ -127,11 +204,16 @@ func (g *GoCache) Get(key string) (*http.Response, error) { return nil, nil } -func (g *GoCache) Set(key string, value *http.Response) error { +func (g *GoCache) Set(req *http.Request, value *http.Response) error { if g.rootLibrary == nil { return nil } + key, err := CreateCacheKey(req) + if err != nil { + return err + } + cacheableResponse, err := httputil.DumpResponse(value, true) if err != nil { return err diff --git a/pkg/uhttp/wrapper.go b/pkg/uhttp/wrapper.go index 7c393f1d..bd0b07fe 100644 --- a/pkg/uhttp/wrapper.go +++ b/pkg/uhttp/wrapper.go @@ -10,8 +10,6 @@ import ( "io" "net/http" "net/url" - "os" - "strconv" v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" "github.com/conductorone/baton-sdk/pkg/ratelimit" @@ -28,8 +26,6 @@ const ( applicationFormUrlencoded = "application/x-www-form-urlencoded" applicationVndApiJSON = "application/vnd.api+json" acceptHeader = "Accept" - cacheTTLMaximum = 31536000 // 31536000 seconds = one year - cacheTTLDefault = 3600 // 3600 seconds = one hour ) type WrapperResponse struct { @@ -47,18 +43,11 @@ type ( } BaseHttpClient struct { HttpClient *http.Client - baseHttpCache GoCache + baseHttpCache icache } DoOption func(resp *WrapperResponse) error RequestOption func() (io.ReadWriter, map[string]string, error) - ContextKey struct{} - CacheConfig struct { - LogDebug bool - CacheTTL int32 - CacheMaxSize int - DisableCache bool - } ) func NewBaseHttpClient(httpClient *http.Client) *BaseHttpClient { @@ -70,57 +59,19 @@ func NewBaseHttpClient(httpClient *http.Client) *BaseHttpClient { return client } -// getCacheTTL read the `BATON_HTTP_CACHE_TTL` environment variable and return -// the value as a number of seconds between 0 and an arbitrary maximum. Note: -// this means that passing a value of `-1` will set the TTL to zero rather than -// infinity. -func getCacheTTL() int32 { - cacheTTL, err := strconv.ParseInt(os.Getenv("BATON_HTTP_CACHE_TTL"), 10, 64) - if err != nil { - cacheTTL = cacheTTLDefault // seconds - } - - cacheTTL = min(cacheTTLMaximum, max(0, cacheTTL)) - - //nolint:gosec // No risk of overflow because we have a low maximum. - return int32(cacheTTL) -} - func NewBaseHttpClientWithContext(ctx context.Context, httpClient *http.Client) (*BaseHttpClient, error) { l := ctxzap.Extract(ctx) - disableCache, err := strconv.ParseBool(os.Getenv("BATON_DISABLE_HTTP_CACHE")) - if err != nil { - disableCache = false - } - cacheMaxSize, err := strconv.ParseInt(os.Getenv("BATON_HTTP_CACHE_MAX_SIZE"), 10, 64) - if err != nil { - cacheMaxSize = 128 // MB - } - var ( - config = CacheConfig{ - LogDebug: l.Level().Enabled(zap.DebugLevel), - CacheTTL: getCacheTTL(), // seconds - CacheMaxSize: int(cacheMaxSize), // MB - DisableCache: disableCache, - } - ok bool - ) - if v := ctx.Value(ContextKey{}); v != nil { - if config, ok = v.(CacheConfig); !ok { - return nil, fmt.Errorf("error casting config values from context") - } - } - cache, err := NewGoCache(ctx, config) + cache, err := NewHttpCache(ctx, nil) if err != nil { l.Error("error creating http cache", zap.Error(err)) - return nil, err } - - return &BaseHttpClient{ + cli := &BaseHttpClient{ HttpClient: httpClient, baseHttpCache: cache, - }, nil + } + + return cli, nil } // WithJSONResponse is a wrapper that marshals the returned response body into @@ -232,25 +183,19 @@ func WrapErrorsWithRateLimitInfo(preferredCode codes.Code, resp *http.Response, func (c *BaseHttpClient) Do(req *http.Request, options ...DoOption) (*http.Response, error) { var ( - cacheKey string - err error - resp *http.Response + err error + resp *http.Response ) l := ctxzap.Extract(req.Context()) if req.Method == http.MethodGet { - cacheKey, err = CreateCacheKey(req) - if err != nil { - return nil, err - } - - resp, err = c.baseHttpCache.Get(cacheKey) + resp, err = c.baseHttpCache.Get(req) if err != nil { return nil, err } if resp == nil { - l.Debug("http cache miss", zap.String("cacheKey", cacheKey), zap.String("url", req.URL.String())) + l.Debug("http cache miss", zap.String("url", req.URL.String())) } else { - l.Debug("http cache hit", zap.String("cacheKey", cacheKey), zap.String("url", req.URL.String())) + l.Debug("http cache hit", zap.String("url", req.URL.String())) } } @@ -318,9 +263,9 @@ func (c *BaseHttpClient) Do(req *http.Request, options ...DoOption) (*http.Respo } if req.Method == http.MethodGet && resp.StatusCode == http.StatusOK { - cacheErr := c.baseHttpCache.Set(cacheKey, resp) + cacheErr := c.baseHttpCache.Set(req, resp) if cacheErr != nil { - l.Warn("error setting cache", zap.String("cacheKey", cacheKey), zap.String("url", req.URL.String()), zap.Error(cacheErr)) + l.Warn("error setting cache", zap.String("url", req.URL.String()), zap.Error(cacheErr)) } } diff --git a/pkg/uhttp/wrapper_test.go b/pkg/uhttp/wrapper_test.go index 6386e072..f86705b0 100644 --- a/pkg/uhttp/wrapper_test.go +++ b/pkg/uhttp/wrapper_test.go @@ -432,8 +432,8 @@ func TestWrapperConfig(t *testing.T) { options: nil, cc: CacheConfig{ LogDebug: true, - CacheTTL: int32(1000), - CacheMaxSize: int(1024), + CacheTTL: 1800, + CacheMaxSize: 1024, }, expected: expected{ method: http.MethodGet, @@ -450,8 +450,8 @@ func TestWrapperConfig(t *testing.T) { options: []RequestOption{WithJSONBody(exampleBody), WithAcceptJSONHeader()}, cc: CacheConfig{ LogDebug: true, - CacheTTL: int32(2000), - CacheMaxSize: int(0), + CacheTTL: 600, + CacheMaxSize: 0, }, expected: expected{ method: http.MethodPost, diff --git a/vendor/modules.txt b/vendor/modules.txt index 91d98c36..aed856e5 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -251,6 +251,8 @@ github.com/magiconair/properties # github.com/mattn/go-isatty v0.0.20 ## explicit; go 1.15 github.com/mattn/go-isatty +# github.com/mattn/go-sqlite3 v1.14.22 +## explicit; go 1.19 # github.com/mitchellh/mapstructure v1.5.0 ## explicit; go 1.14 github.com/mitchellh/mapstructure