Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/admin/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type HandlersAdmin struct {
DB *gorm.DB
Users *users.UserManager
Tags *tags.TagManager
Envs *environments.Environment
Envs *environments.EnvManager
Nodes *nodes.NodeManager
Queries *queries.Queries
Carves *carves.Carves
Expand All @@ -55,7 +55,7 @@ func WithDB(db *gorm.DB) HandlersOption {
}
}

func WithEnvs(envs *environments.Environment) HandlersOption {
func WithEnvs(envs *environments.EnvManager) HandlersOption {
return func(h *HandlersAdmin) {
h.Envs = envs
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/admin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ var (
queriesmgr *queries.Queries
carvesmgr *carves.Carves
sessionsmgr *sessions.SessionManager
envs *environments.Environment
envs *environments.EnvManager
adminUsers *users.UserManager
tagsmgr *tags.TagManager
carvers3 *carves.CarverS3
Expand Down
4 changes: 2 additions & 2 deletions cmd/api/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type HandlersApi struct {
DB *gorm.DB
Users *users.UserManager
Tags *tags.TagManager
Envs *environments.Environment
Envs *environments.EnvManager
Nodes *nodes.NodeManager
Queries *queries.Queries
Carves *carves.Carves
Expand Down Expand Up @@ -56,7 +56,7 @@ func WithTags(tags *tags.TagManager) HandlersOption {
}
}

func WithEnvs(envs *environments.Environment) HandlersOption {
func WithEnvs(envs *environments.EnvManager) HandlersOption {
return func(h *HandlersApi) {
h.Envs = envs
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ var (
apiUsers *users.UserManager
tagsmgr *tags.TagManager
settingsmgr *settings.Settings
envs *environments.Environment
envs *environments.EnvManager
nodesmgr *nodes.NodeManager
queriesmgr *queries.Queries
filecarves *carves.Carves
Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ var (
filecarves *carves.Carves
adminUsers *users.UserManager
tagsmgr *tags.TagManager
envs *environments.Environment
envs *environments.EnvManager
db *backend.DBManager
osctrlAPI *OsctrlAPI
formats map[string]bool
Expand Down
7 changes: 4 additions & 3 deletions cmd/tls/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ var validPlatform = map[string]bool{

// HandlersTLS to keep all handlers for TLS
type HandlersTLS struct {
Envs *environments.Environment
Envs *environments.EnvManager
EnvsMap *environments.MapEnvironments
EnvCache *environments.EnvCache
Nodes *nodes.NodeManager
Tags *tags.TagManager
Queries *queries.Queries
Expand All @@ -71,7 +72,7 @@ type TLSResponse struct {
type Option func(*HandlersTLS)

// WithEnvs to pass value as option
func WithEnvs(envs *environments.Environment) Option {
func WithEnvs(envs *environments.EnvManager) Option {
return func(h *HandlersTLS) {
h.Envs = envs
}
Expand Down Expand Up @@ -167,7 +168,7 @@ func CreateHandlersTLS(opts ...Option) *HandlersTLS {
for _, opt := range opts {
opt(h)
}
// All these opt function need be refactored to reduce unnecessary complexity
h.EnvCache = environments.NewEnvCache(*h.Envs)
return h
}

Expand Down
12 changes: 7 additions & 5 deletions cmd/tls/handlers/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handlers

import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -34,7 +35,7 @@ func (h *HandlersTLS) EnrollHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Get environment
env, err := h.Envs.GetByUUID(envVar)
env, err := h.EnvCache.GetByUUID(context.TODO(), envVar)
if err != nil {
log.Err(err).Msg("error getting environment")
utils.HTTPResponse(w, "", http.StatusInternalServerError, []byte(""))
Expand Down Expand Up @@ -107,6 +108,7 @@ func (h *HandlersTLS) EnrollHandler(w http.ResponseWriter, r *http.Request) {

// ConfigHandler - Function to handle the configuration requests from osquery nodes
func (h *HandlersTLS) ConfigHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var response interface{}
// Retrieve environment variable
envVar := r.PathValue("env")
Expand All @@ -120,7 +122,7 @@ func (h *HandlersTLS) ConfigHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Get environment
env, err := h.Envs.GetByUUID(envVar)
env, err := h.EnvCache.GetByUUID(ctx, envVar)
if err != nil {
log.Err(err).Msg("error getting environment")
return
Expand Down Expand Up @@ -184,7 +186,7 @@ func (h *HandlersTLS) LogHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Get environment
env, err := h.Envs.GetByUUID(envVar)
env, err := h.EnvCache.GetByUUID(context.TODO(), envVar)
if err != nil {
log.Err(err).Msg("error getting environment")
utils.HTTPResponse(w, "", http.StatusInternalServerError, []byte(""))
Expand Down Expand Up @@ -272,7 +274,7 @@ func (h *HandlersTLS) QueryReadHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Get environment
env, err := h.Envs.GetByUUID(envVar)
env, err := h.EnvCache.GetByUUID(context.TODO(), envVar)
if err != nil {
log.Err(err).Msg("error getting environment")
utils.HTTPResponse(w, "", http.StatusInternalServerError, []byte(""))
Expand Down Expand Up @@ -350,7 +352,7 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request)
return
}
// Get environment
env, err := h.Envs.GetByUUID(envVar)
env, err := h.EnvCache.GetByUUID(context.TODO(), envVar)
if err != nil {
log.Err(err).Msg("error getting environment")
utils.HTTPResponse(w, "", http.StatusInternalServerError, []byte(""))
Expand Down
4 changes: 2 additions & 2 deletions cmd/tls/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var (
db *backend.DBManager
redis *cache.RedisManager
settingsmgr *settings.Settings
envs *environments.Environment
envs *environments.EnvManager
envsmap environments.MapEnvironments
settingsmap settings.MapSettings
nodesmgr *nodes.NodeManager
Expand Down Expand Up @@ -242,6 +242,7 @@ func osctrlService() {
log.Info().Msg("Metrics are enabled")
// Register Prometheus metrics
handlers.RegisterMetrics(prometheus.DefaultRegisterer)
cache.RegisterMetrics(prometheus.DefaultRegisterer)
// Creating a new prometheus service
prometheusServer := http.NewServeMux()
prometheusServer.Handle("/metrics", promhttp.Handler())
Expand All @@ -268,7 +269,6 @@ func osctrlService() {
handlers.WithWriteHandler(tlsWriter),
handlers.WithDebugHTTP(&flagParams.DebugHTTPValues),
)

// ///////////////////////// ALL CONTENT IS UNAUTHENTICATED FOR TLS
log.Info().Msg("Initializing router")
// Create router for TLS endpoint
Expand Down
194 changes: 194 additions & 0 deletions pkg/cache/in-memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package cache

import (
"context"
"sync"
"time"
)

// Item represents a cached item with expiration
type Item[T any] struct {
Value T
Expiration int64
}

// Cache interface defines methods that any cache implementation must provide
type Cache[T any] interface {
// Get retrieves an item from the cache by key
Get(ctx context.Context, key string) (T, bool)

// Set adds or updates an item in the cache with expiration
Set(ctx context.Context, key string, value T, duration time.Duration)

// Delete removes an item from the cache
Delete(ctx context.Context, key string)

// Clear removes all items from the cache
Clear(ctx context.Context)

// ItemCount returns the number of items in the cache
ItemCount() int

// Stop stops the cleanup goroutine
Stop()
}

// MemoryCacheOption is a function that configures a MemoryCache
type MemoryCacheOption[T any] func(*MemoryCache[T])

// WithCleanupInterval sets the interval for cleaning expired items
func WithCleanupInterval[T any](interval time.Duration) MemoryCacheOption[T] {
return func(mc *MemoryCache[T]) {
mc.cleanupInterval = interval
}
}

// WithName sets the name for the cache instance
func WithName[T any](name string) MemoryCacheOption[T] {
return func(mc *MemoryCache[T]) {
mc.name = name
}
}

// MemoryCache provides an in-memory implementation of the Cache interface
type MemoryCache[T any] struct {
items map[string]Item[T]
mutex sync.RWMutex
cleanupInterval time.Duration
stopCleanup chan struct{}
name string
}

// NewMemoryCache creates a new in-memory cache with the provided options
func NewMemoryCache[T any](opts ...MemoryCacheOption[T]) *MemoryCache[T] {
cache := &MemoryCache[T]{
items: make(map[string]Item[T]),
cleanupInterval: 5 * time.Minute,
stopCleanup: make(chan struct{}),
name: "default",
}

// Apply all provided options
for _, opt := range opts {
opt(cache)
}

CacheItems.WithLabelValues(cache.name).Set(0)

// Start cleanup routine to remove expired items
go cache.cleanupRoutine()

return cache
}

// Get retrieves an item from the cache
func (c *MemoryCache[T]) Get(ctx context.Context, key string) (T, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()

item, found := c.items[key]
if !found {
CacheMisses.WithLabelValues(c.name).Inc()
var zero T
return zero, false
}

// Check if item has expired
if item.Expiration > 0 && item.Expiration < time.Now().UnixNano() {
CacheMisses.WithLabelValues(c.name).Inc()
var zero T
return zero, false
}

CacheHits.WithLabelValues(c.name).Inc()
return item.Value, true
}

// Set adds an item to the cache with expiration
func (c *MemoryCache[T]) Set(ctx context.Context, key string, value T, duration time.Duration) {
var expiration int64

if duration > 0 {
expiration = time.Now().Add(duration).UnixNano()
}

c.mutex.Lock()
defer c.mutex.Unlock()

c.items[key] = Item[T]{
Value: value,
Expiration: expiration,
}

CacheItems.WithLabelValues(c.name).Set(float64(len(c.items)))
}

// Delete removes an item from the cache
func (c *MemoryCache[T]) Delete(ctx context.Context, key string) {
c.mutex.Lock()
defer c.mutex.Unlock()

delete(c.items, key)

CacheItems.WithLabelValues(c.name).Set(float64(len(c.items)))
}

// Clear removes all items from the cache
func (c *MemoryCache[T]) Clear(ctx context.Context) {
c.mutex.Lock()
defer c.mutex.Unlock()

c.items = make(map[string]Item[T])

CacheItems.WithLabelValues(c.name).Set(0)
}

// ItemCount returns the number of items in the cache
func (c *MemoryCache[T]) ItemCount() int {
c.mutex.RLock()
defer c.mutex.RUnlock()

return len(c.items)
}

// Stop stops the cleanup goroutine
func (c *MemoryCache[T]) Stop() {
close(c.stopCleanup)
}

// cleanupRoutine periodically cleans up expired items
func (c *MemoryCache[T]) cleanupRoutine() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
c.deleteExpired()
case <-c.stopCleanup:
return
}
}
}

// deleteExpired removes expired items from the cache
func (c *MemoryCache[T]) deleteExpired() {
now := time.Now().UnixNano()

c.mutex.Lock()
defer c.mutex.Unlock()

evictionCount := 0
for k, v := range c.items {
if v.Expiration > 0 && v.Expiration < now {
delete(c.items, k)
evictionCount++
}
}

if evictionCount > 0 {
// Update Prometheus metrics for evictions and item count
CacheEvictions.WithLabelValues(c.name).Add(float64(evictionCount))
CacheItems.WithLabelValues(c.name).Set(float64(len(c.items)))
}
}
Loading