Skip to content

Commit

Permalink
feat: Add Prometheus metric to track exec querys per transaction
Browse files Browse the repository at this point in the history
* Added two counter to track exec and query statements seperately.
* UTs added for the same.
  • Loading branch information
yvardhineni committed Jun 18, 2024
1 parent 1bff9ed commit 612f262
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 22 deletions.
28 changes: 22 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ type managedConn struct {
killed bool
mu sync.RWMutex

execQueryCounter int
execStmtsCounter int // count the number of exec calls in a transaction
queryStmtsCounter int // count the number of query calls in a transaction
}

// BeginTx calls the underlying BeginTx method unless the supervising context
Expand Down Expand Up @@ -78,6 +79,7 @@ func (c *managedConn) Exec(query string, args []driver.Value) (driver.Result, er
if !ok {
return nil, driver.ErrSkip
}
c.incExecStmtsCounter() //increment the exec counter to keep track of the number of exec calls
return conn.Exec(query, args)
}

Expand All @@ -86,7 +88,7 @@ func (c *managedConn) ExecContext(ctx context.Context, query string, args []driv
if !ok {
return nil, driver.ErrSkip
}
c.incExecQueryCounter() //increment the exec counter to keep track of the number of exec calls
c.incExecStmtsCounter() //increment the exec counter to keep track of the number of exec calls
return conn.ExecContext(ctx, query, args)
}

Expand All @@ -103,6 +105,7 @@ func (c *managedConn) Query(query string, args []driver.Value) (driver.Rows, err
if !ok {
return nil, driver.ErrSkip
}
c.incQueryStmtsCounter() //increment the query counter to keep track of the number of query calls
return conn.Query(query, args)
}

Expand All @@ -111,6 +114,7 @@ func (c *managedConn) QueryContext(ctx context.Context, query string, args []dri
if !ok {
return nil, driver.ErrSkip
}
c.incQueryStmtsCounter() //increment the query counter to keep track of the number of query calls
return conn.QueryContext(ctx, query, args)
}

Expand Down Expand Up @@ -193,14 +197,26 @@ func (c *managedConn) GetKill() bool {
return c.killed
}

func (c *managedConn) incExecQueryCounter() {
func (c *managedConn) incExecStmtsCounter() {
c.mu.Lock()
defer c.mu.Unlock()
c.execQueryCounter++
c.execStmtsCounter++
}

func (c *managedConn) resetExecQueryCounter() {
func (c *managedConn) resetExecStmtsCounter() {
c.mu.Lock()
defer c.mu.Unlock()
c.execQueryCounter = 0
c.execStmtsCounter = 0
}

func (c *managedConn) incQueryStmtsCounter() {
c.mu.Lock()
defer c.mu.Unlock()
c.queryStmtsCounter++
}

func (c *managedConn) resetQueryStmtsCounter() {
c.mu.Lock()
defer c.mu.Unlock()
c.queryStmtsCounter = 0
}
202 changes: 201 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package hotload

import (
"context"
"database/sql/driver"
"io"
"strings"
"sync"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"sync"
"github.com/prometheus/client_golang/prometheus/testutil"
)

var _ = Describe("managedConn", func() {
Expand Down Expand Up @@ -34,3 +40,197 @@ var _ = Describe("managedConn", func() {
Consistently(readLockAcquired).Should(BeFalse())
})
})

/**** Mocks for Prometheus Metrics ****/

type mockDriverConn struct{}

type mockTx struct{}

func (mockTx) Commit() error {
return nil
}

func (mockTx) Rollback() error {
return nil
}

func (mockDriverConn) Prepare(query string) (driver.Stmt, error) {
return nil, nil
}

func (mockDriverConn) Begin() (driver.Tx, error) {
return mockTx{}, nil
}

func (mockDriverConn) Close() error {
return nil
}

func (mockDriverConn) IsValid() bool {
return true
}

func (mockDriverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return mockTx{}, nil
}

func (mockDriverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
return nil, nil
}

func (mockDriverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return nil, nil
}

func (mockDriverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return nil, nil
}

func (mockDriverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return nil, nil
}

/**** End Mocks for Prometheus Metrics ****/

var _ = Describe("PrometheusMetrics", func() {
const help = `
# HELP transaction_sql_stmts_total The number of sql stmts called in a transaction by statement type per grpc service and method
# TYPE transaction_sql_stmts_total summary
`

var service1Metrics = `
transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 3
transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 1
transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="query"} 3
transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="query"} 1
`

var service2Metrics = `
transaction_sql_stmts_total_sum{grpc_method="service_2",grpc_service="method_2",stmt="exec"} 4
transaction_sql_stmts_total_count{grpc_method="service_2",grpc_service="method_2",stmt="exec"} 1
transaction_sql_stmts_total_sum{grpc_method="service_2",grpc_service="method_2",stmt="query"} 4
transaction_sql_stmts_total_count{grpc_method="service_2",grpc_service="method_2",stmt="query"} 1
`

var service1RerunMetrics = `
transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 4
transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="exec"} 2
transaction_sql_stmts_total_sum{grpc_method="service_1",grpc_service="method_1",stmt="query"} 4
transaction_sql_stmts_total_count{grpc_method="service_1",grpc_service="method_1",stmt="query"} 2
`

var noMethodMetrics = `
transaction_sql_stmts_total_sum{grpc_method="",grpc_service="",stmt="exec"} 1
transaction_sql_stmts_total_count{grpc_method="",grpc_service="",stmt="exec"} 1
transaction_sql_stmts_total_sum{grpc_method="",grpc_service="",stmt="query"} 1
transaction_sql_stmts_total_count{grpc_method="",grpc_service="",stmt="query"} 1
`

It("Should emit the correct metrics", func() {
mc := newManagedConn(context.Background(), mockDriverConn{})

ctx := ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "service_1", "grpc_service": "method_1"})

// begin a transaction
tx, err := mc.BeginTx(ctx, driver.TxOptions{})
Expect(err).ShouldNot(HaveOccurred())

// exec a statement
mc.Exec("INSERT INTO table (column) VALUES (?)", []driver.Value{"value"})

// query a statement
mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"})
mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"})

// exec a statement with context
mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}})
mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}})

// query a statement with context
mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}})

// commit the transaction
err = tx.Commit()
Expect(err).ShouldNot(HaveOccurred())

// collect and compare metrics
err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+service1Metrics))
Expect(err).ShouldNot(HaveOccurred())

// reset the metrics
// new context
ctx = ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "service_2", "grpc_service": "method_2"})
// begin a transaction
tx, err = mc.BeginTx(ctx, driver.TxOptions{})
Expect(err).ShouldNot(HaveOccurred())

// exec a statement
mc.Exec("INSERT INTO table (column) VALUES (?)", []driver.Value{"value"})
mc.Exec("INSERT INTO table (column) VALUES (?)", []driver.Value{"value"})

// query a statement
mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"})
mc.Query("SELECT * FROM table WHERE column = ?", []driver.Value{"value"})

// exec a statement with context
mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}})
mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}})

// query a statement with context
mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}})
mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}})

// commit the transaction
err = tx.Commit()
Expect(err).ShouldNot(HaveOccurred())

// collect and compare metrics
err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+service1Metrics+service2Metrics))
Expect(err).ShouldNot(HaveOccurred())

// rerun with initial metrics
ctx = ContextWithExecLabels(context.Background(), map[string]string{"grpc_method": "service_1", "grpc_service": "method_1"})
// begin a transaction
tx, err = mc.BeginTx(ctx, driver.TxOptions{})
Expect(err).ShouldNot(HaveOccurred())

// exec a statement with context
mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}})

// query a statement with context
mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}})

// rollback the transaction
err = tx.Rollback()
Expect(err).ShouldNot(HaveOccurred())

// collect and compare metrics
err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+service1RerunMetrics+service2Metrics))
Expect(err).ShouldNot(HaveOccurred())

// non labeled context
ctx = context.Background()
// begin a transaction
tx, err = mc.BeginTx(ctx, driver.TxOptions{})
Expect(err).ShouldNot(HaveOccurred())

// exec query context
mc.ExecContext(ctx, "INSERT INTO table (column) VALUES (?)", []driver.NamedValue{{Value: "value"}})

// query a statement with context
mc.QueryContext(ctx, "SELECT * FROM table WHERE column = ?", []driver.NamedValue{{Value: "value"}})

// commit the transaction
err = tx.Commit()
Expect(err).ShouldNot(HaveOccurred())

// collect and compare metrics
err = testutil.CollectAndCompare(sqlStmtsSummary, strings.NewReader(help+noMethodMetrics+service1RerunMetrics+service2Metrics))
Expect(err).ShouldNot(HaveOccurred())
})
})

func CollectAndCompareMetrics(r io.Reader) error {
return testutil.CollectAndCompare(sqlStmtsSummary, r)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
Expand Down
7 changes: 4 additions & 3 deletions integrationtests/hotload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package integrationtests
import (
"database/sql"
"fmt"
"io/ioutil"
"log"
"time"

"github.com/infobloxopen/hotload"
_ "github.com/infobloxopen/hotload/fsnotify"
"github.com/lib/pq"
_ "github.com/lib/pq"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"io/ioutil"
"log"
"time"
)

const (
Expand Down
17 changes: 10 additions & 7 deletions prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@ import (
const (
GRPCMethodKey = "grpc_method"
GRPCServiceKey = "grpc_service"
StatementKey = "stmt" // either exec or query
ExecStatement = "exec"
QueryStatement = "query"
)

// execQuerySummary is a prometheus metric to keep track of the number of times
// exec query is called in a transaction
var execQuerySummary = prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "transaction_exec_query_total",
Help: "The number of times exec query is called in a transaction",
}, []string{GRPCServiceKey, GRPCMethodKey})
// sqlStmtsSummary is a prometheus metric to keep track of the number of times
// a sql statement is called in a transaction by statement type per grpc service
var sqlStmtsSummary = prometheus.NewSummaryVec(prometheus.SummaryOpts{
Name: "transaction_sql_stmts_total",
Help: "The number of sql stmts called in a transaction by statement type per grpc service and method",
}, []string{GRPCServiceKey, GRPCMethodKey, StatementKey})

func init() {
prometheus.MustRegister(execQuerySummary)
prometheus.MustRegister(sqlStmtsSummary)
}

// PromUnaryServerInterceptor returns a unary server interceptor that sets the
Expand Down
39 changes: 39 additions & 0 deletions prometheus_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package hotload

import (
"context"
"errors"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
)

var _ = Describe("PrometheusMetric", func() {
It("Should register a prometheus metric", func() {
// This test is a placeholder for a real test
err := prometheus.Register(sqlStmtsSummary)
Expect(err).Should(HaveOccurred())
Expect(errors.As(err, &prometheus.AlreadyRegisteredError{})).Should(BeTrue())
})
})

var _ = Describe("PromUnaryServerInterceptor", func() {
It("Should return a unary server interceptor", func() {
validationHandler := func(ctx context.Context, req interface{}) (interface{}, error) {
labels := GetExecLabelsFromContext(ctx)

Expect(labels).ShouldNot(BeNil())
Expect(labels[GRPCMethodKey]).Should(Equal("List"))
Expect(labels[GRPCServiceKey]).Should(Equal("infoblox.service.SampleService"))

return nil, nil
}

promUnaryServerInterceptor := PromUnaryServerInterceptor()
promUnaryServerInterceptor(context.Background(), struct{}{}, &grpc.UnaryServerInfo{
FullMethod: "/infoblox.service.SampleService/List",
}, validationHandler)
})
})
Loading

0 comments on commit 612f262

Please sign in to comment.