diff --git a/go/vt/logutil/console_logger.go b/go/vt/logutil/console_logger.go index fe43985a8b6..b5a34244c28 100644 --- a/go/vt/logutil/console_logger.go +++ b/go/vt/logutil/console_logger.go @@ -6,9 +6,7 @@ import ( log "github.com/golang/glog" ) -// ConsoleLogger is a Logger that uses glog directly to log. -// We can't specify the depth of the stack trace, -// So we just find it and add it to the message. +// ConsoleLogger is a Logger that uses glog directly to log, at the right level. type ConsoleLogger struct{} // NewConsoleLogger returns a simple ConsoleLogger @@ -18,26 +16,17 @@ func NewConsoleLogger() ConsoleLogger { // Infof is part of the Logger interface func (cl ConsoleLogger) Infof(format string, v ...interface{}) { - file, line := fileAndLine(3) - vals := []interface{}{file, line} - vals = append(vals, v...) - log.Infof("%v:%v] "+format, vals...) + log.InfoDepth(2, fmt.Sprintf(format, v...)) } // Warningf is part of the Logger interface func (cl ConsoleLogger) Warningf(format string, v ...interface{}) { - file, line := fileAndLine(3) - vals := []interface{}{file, line} - vals = append(vals, v...) - log.Warningf("%v:%v] "+format, vals...) + log.WarningDepth(2, fmt.Sprintf(format, v...)) } // Errorf is part of the Logger interface func (cl ConsoleLogger) Errorf(format string, v ...interface{}) { - file, line := fileAndLine(3) - vals := []interface{}{file, line} - vals = append(vals, v...) - log.Errorf("%v:%v] "+format, vals...) 
+ log.ErrorDepth(2, fmt.Sprintf(format, v...)) } // Printf is part of the Logger interface diff --git a/go/vt/logutil/throttled.go b/go/vt/logutil/throttled.go index 1ea612f27f1..867071f7b88 100644 --- a/go/vt/logutil/throttled.go +++ b/go/vt/logutil/throttled.go @@ -1,6 +1,7 @@ package logutil import ( + "fmt" "sync" "time" @@ -29,12 +30,12 @@ func NewThrottledLogger(name string, maxInterval time.Duration) *ThrottledLogger } } -type logFunc func(string, ...interface{}) +type logFunc func(int, ...interface{}) var ( - infof = log.Infof - warningf = log.Warningf - errorf = log.Errorf + infoDepth = log.InfoDepth + warningDepth = log.WarningDepth + errorDepth = log.ErrorDepth ) func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) { @@ -45,7 +46,7 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) { logWaitTime := tl.maxInterval - (now.Sub(tl.lastlogTime)) if logWaitTime < 0 { tl.lastlogTime = now - logF(tl.name+":"+format, v...) + logF(2, fmt.Sprintf(tl.name+":"+format, v...)) return } // If this is the first message to be skipped, start a goroutine @@ -55,7 +56,9 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) { time.Sleep(d) tl.mu.Lock() defer tl.mu.Unlock() - logF("%v: skipped %v log messages", tl.name, tl.skippedCount) + // Because of the go func(), we lose the stack trace, + // so we just use the current line for this. + logF(0, fmt.Sprintf("%v: skipped %v log messages", tl.name, tl.skippedCount)) tl.skippedCount = 0 }(logWaitTime) } @@ -64,15 +67,15 @@ func (tl *ThrottledLogger) log(logF logFunc, format string, v ...interface{}) { // Infof logs an info if not throttled. func (tl *ThrottledLogger) Infof(format string, v ...interface{}) { - tl.log(infof, format, v...) + tl.log(infoDepth, format, v...) } // Warningf logs a warning if not throttled. func (tl *ThrottledLogger) Warningf(format string, v ...interface{}) { - tl.log(warningf, format, v...) 
+ tl.log(warningDepth, format, v...) } // Errorf logs an error if not throttled. func (tl *ThrottledLogger) Errorf(format string, v ...interface{}) { - tl.log(errorf, format, v...) + tl.log(errorDepth, format, v...) } diff --git a/go/vt/logutil/throttled_test.go b/go/vt/logutil/throttled_test.go index f6b447f2f51..c8e0813c085 100644 --- a/go/vt/logutil/throttled_test.go +++ b/go/vt/logutil/throttled_test.go @@ -9,8 +9,8 @@ import ( func TestThrottledLogger(t *testing.T) { // Install a fake log func for testing. log := make(chan string) - infof = func(format string, args ...interface{}) { - log <- fmt.Sprintf(format, args...) + infoDepth = func(depth int, args ...interface{}) { + log <- fmt.Sprint(args...) } interval := 100 * time.Millisecond tl := NewThrottledLogger("name", interval) diff --git a/go/vt/tabletserver/query_executor.go b/go/vt/tabletserver/query_executor.go index 414484ac86a..021faabc4d0 100644 --- a/go/vt/tabletserver/query_executor.go +++ b/go/vt/tabletserver/query_executor.go @@ -252,7 +252,7 @@ func (qre *QueryExecutor) fetchMulti(pkRows [][]sqltypes.Value, limit int64) (re qre.logStats.CacheAbsent = absent qre.logStats.CacheMisses = misses - qre.logStats.QuerySources |= QUERY_SOURCE_ROWCACHE + qre.logStats.QuerySources |= QuerySourceRowcache tableInfo.hits.Add(hits) tableInfo.absent.Add(absent) @@ -538,7 +538,7 @@ func (qre *QueryExecutor) qFetch(logStats *SQLQueryStats, parsedQuery *sqlparser q.Result, q.Err = qre.execSQLNoPanic(conn, sql, false) } } else { - logStats.QuerySources |= QUERY_SOURCE_CONSOLIDATOR + logStats.QuerySources |= QuerySourceConsolidator startTime := time.Now() q.Wait() waitStats.Record("Consolidations", startTime) diff --git a/go/vt/tabletserver/query_executor_test.go b/go/vt/tabletserver/query_executor_test.go index 183af9135ee..f02522be8a5 100644 --- a/go/vt/tabletserver/query_executor_test.go +++ b/go/vt/tabletserver/query_executor_test.go @@ -6,7 +6,6 @@ package tabletserver import ( "fmt" - "html/template" "math/rand" 
"reflect" "testing" @@ -24,27 +23,6 @@ import ( "golang.org/x/net/context" ) -type fakeCallInfo struct { - remoteAddr string - username string -} - -func (fci *fakeCallInfo) RemoteAddr() string { - return fci.remoteAddr -} - -func (fci *fakeCallInfo) Username() string { - return fci.username -} - -func (fci *fakeCallInfo) Text() string { - return "" -} - -func (fci *fakeCallInfo) HTML() template.HTML { - return template.HTML("") -} - func TestQueryExecutorPlanDDL(t *testing.T) { db := setUpQueryExecutorTest() query := "alter table test_table add zipcode int" diff --git a/go/vt/tabletserver/streamlogger.go b/go/vt/tabletserver/sqlquery_stats.go similarity index 92% rename from go/vt/tabletserver/streamlogger.go rename to go/vt/tabletserver/sqlquery_stats.go index 32c30667bdf..8394b05380f 100644 --- a/go/vt/tabletserver/streamlogger.go +++ b/go/vt/tabletserver/sqlquery_stats.go @@ -23,9 +23,12 @@ import ( var SqlQueryLogger = streamlog.New("SqlQuery", 50) const ( - QUERY_SOURCE_ROWCACHE = 1 << iota - QUERY_SOURCE_CONSOLIDATOR - QUERY_SOURCE_MYSQL + // QuerySourceRowcache means query result is found in rowcache. + QuerySourceRowcache = 1 << iota + // QuerySourceConsolidator means query result is found in consolidator. + QuerySourceConsolidator + // QuerySourceMySQL means query result is returned from MySQL. 
+ QuerySourceMySQL ) // SQLQueryStats records the stats for a single query @@ -68,7 +71,7 @@ func (stats *SQLQueryStats) Send() { // AddRewrittenSql adds a single sql statement to the rewritten list func (stats *SQLQueryStats) AddRewrittenSql(sql string, start time.Time) { - stats.QuerySources |= QUERY_SOURCE_MYSQL + stats.QuerySources |= QuerySourceMySQL stats.NumberOfQueries++ stats.rewrittenSqls = append(stats.rewrittenSqls, sql) stats.MysqlResponseTime += time.Now().Sub(start) @@ -140,15 +143,15 @@ func (stats *SQLQueryStats) FmtQuerySources() string { } sources := make([]string, 3) n := 0 - if stats.QuerySources&QUERY_SOURCE_MYSQL != 0 { + if stats.QuerySources&QuerySourceMySQL != 0 { sources[n] = "mysql" n++ } - if stats.QuerySources&QUERY_SOURCE_ROWCACHE != 0 { + if stats.QuerySources&QuerySourceRowcache != 0 { sources[n] = "rowcache" n++ } - if stats.QuerySources&QUERY_SOURCE_CONSOLIDATOR != 0 { + if stats.QuerySources&QuerySourceConsolidator != 0 { sources[n] = "consolidator" n++ } diff --git a/go/vt/tabletserver/sqlquery_stats_test.go b/go/vt/tabletserver/sqlquery_stats_test.go new file mode 100644 index 00000000000..c5cc81a4505 --- /dev/null +++ b/go/vt/tabletserver/sqlquery_stats_test.go @@ -0,0 +1,142 @@ +// Copyright 2015, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package tabletserver + +import ( + "fmt" + "net/url" + "strings" + "testing" + "time" + + "github.com/youtube/vitess/go/sqltypes" + "github.com/youtube/vitess/go/vt/callinfo" + "golang.org/x/net/context" +) + +func TestSqlQueryStats(t *testing.T) { + logStats := newSqlQueryStats("test", context.Background()) + logStats.AddRewrittenSql("sql1", time.Now()) + + if !strings.Contains(logStats.RewrittenSql(), "sql1") { + t.Fatalf("RewrittenSql should contains sql: sql1") + } + + if logStats.SizeOfResponse() != 0 { + t.Fatalf("there is no rows in log stats, estimated size should be 0 bytes") + } + + logStats.Rows = [][]sqltypes.Value{[]sqltypes.Value{sqltypes.MakeString([]byte("a"))}} + if logStats.SizeOfResponse() <= 0 { + t.Fatalf("log stats has some rows, should have positive response size") + } + + params := map[string][]string{"full": []string{}} + + logStats.Format(url.Values(params)) +} + +func TestSqlQueryStatsFormatBindVariables(t *testing.T) { + logStats := newSqlQueryStats("test", context.Background()) + logStats.BindVariables = make(map[string]interface{}) + logStats.BindVariables["key_1"] = "val_1" + logStats.BindVariables["key_2"] = 789 + + formattedStr := logStats.FmtBindVariables(true) + if !strings.Contains(formattedStr, "key_1") || + !strings.Contains(formattedStr, "val_1") { + t.Fatalf("bind variable 'key_1': 'val_1' is not formatted") + } + if !strings.Contains(formattedStr, "key_2") || + !strings.Contains(formattedStr, "789") { + t.Fatalf("bind variable 'key_2': '789' is not formatted") + } + + logStats.BindVariables["key_3"] = []byte("val_3") + formattedStr = logStats.FmtBindVariables(false) + if !strings.Contains(formattedStr, "key_1") { + t.Fatalf("bind variable 'key_1' is not formatted") + } + if !strings.Contains(formattedStr, "key_2") || + !strings.Contains(formattedStr, "789") { + t.Fatalf("bind variable 'key_2': '789' is not formatted") + } + if !strings.Contains(formattedStr, "key_3") { + t.Fatalf("bind variable 'key_3' is not formatted") 
+ } +} + +func TestSqlQueryStatsFormatQuerySources(t *testing.T) { + logStats := newSqlQueryStats("test", context.Background()) + if logStats.FmtQuerySources() != "none" { + t.Fatalf("should return none since log stats does not have any query source, but got: %s", logStats.FmtQuerySources()) + } + + logStats.QuerySources |= QuerySourceMySQL + if !strings.Contains(logStats.FmtQuerySources(), "mysql") { + t.Fatalf("'mysql' should be in formated query sources") + } + + logStats.QuerySources |= QuerySourceRowcache + if !strings.Contains(logStats.FmtQuerySources(), "rowcache") { + t.Fatalf("'rowcache' should be in formated query sources") + } + + logStats.QuerySources |= QuerySourceConsolidator + if !strings.Contains(logStats.FmtQuerySources(), "consolidator") { + t.Fatalf("'consolidator' should be in formated query sources") + } +} + +func TestSqlQueryStatsContextHTML(t *testing.T) { + html := "HtmlContext" + callInfo := &fakeCallInfo{ + html: html, + } + ctx := callinfo.NewContext(context.Background(), callInfo) + logStats := newSqlQueryStats("test", ctx) + if string(logStats.ContextHTML()) != html { + t.Fatalf("expect to get html: %s, but got: %s", html, string(logStats.ContextHTML())) + } +} + +func TestSqlQueryStatsErrorStr(t *testing.T) { + logStats := newSqlQueryStats("test", context.Background()) + if logStats.ErrorStr() != "" { + t.Fatalf("should not get error in stats, but got: %s", logStats.ErrorStr()) + } + errStr := "unknown error" + logStats.Error = fmt.Errorf(errStr) + if logStats.ErrorStr() != errStr { + t.Fatalf("expect to get error string: %s, but got: %s", errStr, logStats.ErrorStr()) + } +} + +func TestSqlQueryStatsRemoteAddrUsername(t *testing.T) { + logStats := newSqlQueryStats("test", context.Background()) + addr, user := logStats.RemoteAddrUsername() + if addr != "" { + t.Fatalf("remote addr should be empty") + } + if user != "" { + t.Fatalf("username should be empty") + } + + remoteAddr := "1.2.3.4" + username := "vt" + callInfo := 
&fakeCallInfo{ + remoteAddr: remoteAddr, + username: username, + } + ctx := callinfo.NewContext(context.Background(), callInfo) + logStats = newSqlQueryStats("test", ctx) + addr, user = logStats.RemoteAddrUsername() + if addr != remoteAddr { + t.Fatalf("expected to get remote addr: %s, but got: %s", remoteAddr, addr) + } + if user != username { + t.Fatalf("expected to get username: %s, but got: %s", username, user) + } +} diff --git a/go/vt/tabletserver/testutils_test.go b/go/vt/tabletserver/testutils_test.go new file mode 100644 index 00000000000..b8fdc107074 --- /dev/null +++ b/go/vt/tabletserver/testutils_test.go @@ -0,0 +1,30 @@ +// Copyright 2015, Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tabletserver + +import "html/template" + +type fakeCallInfo struct { + remoteAddr string + username string + text string + html string +} + +func (fci *fakeCallInfo) RemoteAddr() string { + return fci.remoteAddr +} + +func (fci *fakeCallInfo) Username() string { + return fci.username +} + +func (fci *fakeCallInfo) Text() string { + return fci.text +} + +func (fci *fakeCallInfo) HTML() template.HTML { + return template.HTML(fci.html) +} diff --git a/go/vt/wrangler/keyspace.go b/go/vt/wrangler/keyspace.go index 636ae3ff44f..51688314192 100644 --- a/go/vt/wrangler/keyspace.go +++ b/go/vt/wrangler/keyspace.go @@ -168,7 +168,7 @@ func (wr *Wrangler) MigrateServedTypes(ctx context.Context, keyspace, shard stri // rebuild the keyspace serving graph if there was no error if !rec.HasErrors() { - rec.RecordError(wr.RebuildKeyspaceGraph(ctx, keyspace, nil)) + rec.RecordError(wr.RebuildKeyspaceGraph(ctx, keyspace, cells)) } // Send a refresh to the tablets we just disabled, iff: diff --git a/py/vtdb/database_context.py b/py/vtdb/database_context.py index c46687a4d06..7de1b8be5f3 100644 --- a/py/vtdb/database_context.py +++ b/py/vtdb/database_context.py @@ -27,7 +27,6 @@ #TODO: 
verify that these values make sense. DEFAULT_CONNECTION_TIMEOUT = 5.0 -DEFAULT_QUERY_TIMEOUT = 15.0 __app_read_only_mode_method = lambda:False __vtgate_connect_method = vtgatev2.connect @@ -64,9 +63,8 @@ def __init__(self, vtgate_addrs=None, lag_tolerant_mode=False, master_access_dis self.vtgate_connection = None self.change_master_read_to_replica = False self._transaction_stack_depth = 0 - self.connection_timeout = DEFAULT_CONNECTION_TIMEOUT - self.query_timeout = DEFAULT_QUERY_TIMEOUT self.event_logger = vtdb_logger.get_logger() + self.connection_timeout = DEFAULT_CONNECTION_TIMEOUT self._tablet_type = None @property diff --git a/py/vtdb/db_object.py b/py/vtdb/db_object.py index e2847664073..bbc7629ca7d 100644 --- a/py/vtdb/db_object.py +++ b/py/vtdb/db_object.py @@ -50,6 +50,19 @@ def _is_iterable_container(x): return hasattr(x, '__iter__') +INSERT_KW = "insert" +UPDATE_KW = "update" +DELETE_KW = "delete" + + +def is_dml(sql): + first_kw = sql.split(' ')[0] + first_kw = first_kw.lower() + if first_kw == INSERT_KW or first_kw == UPDATE_KW or first_kw == DELETE_KW: + return True + return False + + def create_cursor_from_params(vtgate_conn, tablet_type, is_dml, table_class): """This method creates the cursor from the required params. @@ -110,6 +123,27 @@ def create_stream_cursor_from_cursor(original_cursor): return stream_cursor +def create_batch_cursor_from_cursor(original_cursor, writable=False): + """ + This method creates a batch cursor from a regular cursor. + + Args: + original_cursor: Cursor of VTGateCursor type + + Returns: + Returns BatchVTGateCursor that has same attributes as original_cursor. 
+ """ + if not isinstance(original_cursor, vtgate_cursor.VTGateCursor): + raise dbexceptions.ProgrammingError( + "Original cursor should be of VTGateCursor type.") + batch_cursor = vtgate_cursor.BatchVTGateCursor( + original_cursor._conn, original_cursor.keyspace, + original_cursor.tablet_type, + keyspace_ids=original_cursor.keyspace_ids, + writable=writable) + return batch_cursor + + def db_wrapper(method): """Decorator that is used to create the appropriate cursor for the table and call the database method with it. @@ -140,6 +174,80 @@ def db_class_method(*pargs, **kargs): return classmethod(db_wrapper(*pargs, **kargs)) +def execute_batch_read(cursor, query_list, bind_vars_list): + """Method for executing select queries in batch. + + Args: + cursor: original cursor - that is converted to read-only BatchVTGateCursor. + query_list: query_list. + bind_vars_list: bind variables list. + + Returns: + Result of the form [[q1row1, q1row2,...], [q2row1, ...],..] + + Raises: + dbexceptions.ProgrammingError when dmls are issued to read batch cursor. + """ + if not isinstance(cursor, vtgate_cursor.VTGateCursor): + raise dbexceptions.ProgrammingError( + "cursor is not of the type VTGateCursor.") + batch_cursor = create_batch_cursor_from_cursor(cursor) + for q, bv in zip(query_list, bind_vars_list): + if is_dml(q): + raise dbexceptions.ProgrammingError("Dml %s for read batch cursor." % q) + batch_cursor.execute(q, bv) + + batch_cursor.flush() + rowsets = batch_cursor.rowsets + result = [] + # rowset is of the type [(results, rowcount, lastrowid, fields),..] + for rowset in rowsets: + rowset_results = rowset[0] + fields = [f[0] for f in rowset[3]] + rows = [] + for row in rowset_results: + rows.append(sql_builder.DBRow(fields, row)) + result.append(rows) + return result + + +def execute_batch_write(cursor, query_list, bind_vars_list): + """Method for executing dml queries in batch. + + Args: + cursor: original cursor - that is converted to writable BatchVTGateCursor. 
+ query_list: query_list. + bind_vars_list: bind variables list. + + Returns: + Result of the form [{'rowcount':rowcount, 'lastrowid':lastrowid}, ...] + since for dmls those two values are valuable. + + Raises: + dbexceptions.ProgrammingError when non-dmls are issued to writable batch cursor. + """ + if not isinstance(cursor, vtgate_cursor.VTGateCursor): + raise dbexceptions.ProgrammingError( + "cursor is not of the type VTGateCursor.") + batch_cursor = create_batch_cursor_from_cursor(cursor, writable=True) + if batch_cursor.is_writable() and len(batch_cursor.keyspace_ids) != 1: + raise dbexceptions.ProgrammingError( + "writable batch execute can also execute on one keyspace_id.") + for q, bv in zip(query_list, bind_vars_list): + if not is_dml(q): + raise dbexceptions.ProgrammingError("query %s is not a dml" % q) + batch_cursor.execute(q, bv) + + batch_cursor.flush() + + rowsets = batch_cursor.rowsets + result = [] + # rowset is of the type [(results, rowcount, lastrowid, fields),..] + for rowset in rowsets: + result.append({'rowcount':rowset[1], 'lastrowid':rowset[2]}) + return result + + class DBObjectBase(object): """Base class for db classes. 
@@ -179,58 +287,81 @@ def create_vtgate_cursor(class_, vtgate_conn, tablet_type, is_dml, **cursor_karg @db_class_method def select_by_columns(class_, cursor, where_column_value_pairs, - columns_list = None,order_by=None, group_by=None, - limit=None, **kwargs): + columns_list=None, order_by=None, group_by=None, + limit=None): + if columns_list is None: + columns_list = class_.columns_list + query, bind_vars = class_.create_select_query(where_column_value_pairs, + columns_list=columns_list, + order_by=order_by, + group_by=group_by, + limit=limit) + + rowcount = cursor.execute(query, bind_vars) + rows = cursor.fetchall() + return [sql_builder.DBRow(columns_list, row) for row in rows] + + @classmethod + def create_insert_query(class_, **bind_vars): + return sql_builder.insert_query(class_.table_name, + class_.columns_list, + **bind_vars) + + @classmethod + def create_update_query(class_, where_column_value_pairs, + update_column_value_pairs): + return sql_builder.update_columns_query( + class_.table_name, where_column_value_pairs, + update_column_value_pairs=update_column_value_pairs) + + @classmethod + def create_delete_query(class_, where_column_value_pairs, limit=None): + return sql_builder.delete_by_columns_query(class_.table_name, + where_column_value_pairs, + limit=limit) + + @classmethod + def create_select_query(class_, where_column_value_pairs, columns_list=None, + order_by=None, group_by=None, limit=None): if class_.columns_list is None: raise dbexceptions.ProgrammingError("DB class should define columns_list") if columns_list is None: columns_list = class_.columns_list + query, bind_vars = sql_builder.select_by_columns_query(columns_list, class_.table_name, where_column_value_pairs, order_by=order_by, group_by=group_by, - limit=limit, - **kwargs) - - rowcount = cursor.execute(query, bind_vars) - rows = cursor.fetchall() - return [sql_builder.DBRow(columns_list, row) for row in rows] + limit=limit) + return query, bind_vars @db_class_method def insert(class_, 
cursor, **bind_vars): if class_.columns_list is None: raise dbexceptions.ProgrammingError("DB class should define columns_list") - query, bind_vars = sql_builder.insert_query(class_.table_name, - class_.columns_list, - **bind_vars) + query, bind_vars = class_.create_insert_query(**bind_vars) cursor.execute(query, bind_vars) return cursor.lastrowid @db_class_method def update_columns(class_, cursor, where_column_value_pairs, - **update_columns): - - query, bind_vars = sql_builder.update_columns_query( - class_.table_name, where_column_value_pairs, **update_columns) + update_column_value_pairs): + query, bind_vars = class_.create_update_query( + where_column_value_pairs, + update_column_value_pairs=update_column_value_pairs) return cursor.execute(query, bind_vars) @db_class_method - def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None, - **columns): - if not where_column_value_pairs: - where_column_value_pairs = columns.items() - where_column_value_pairs.sort() - + def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None): if not where_column_value_pairs: raise dbexceptions.ProgrammingError("deleting the whole table is not allowed") - query, bind_vars = sql_builder.delete_by_columns_query(class_.table_name, - where_column_value_pairs, - limit=limit) + query, bind_vars = class_.create_delete_query(where_column_value_pairs, + limit=limit) cursor.execute(query, bind_vars) if cursor.rowcount == 0: raise dbexceptions.DatabaseError("DB Row not found") @@ -238,20 +369,13 @@ def delete_by_columns(class_, cursor, where_column_value_pairs, limit=None, @db_class_method def select_by_columns_streaming(class_, cursor, where_column_value_pairs, - columns_list = None,order_by=None, group_by=None, - limit=None, fetch_size=100, **kwargs): - if class_.columns_list is None: - raise dbexceptions.ProgrammingError("DB class should define columns_list") - - if columns_list is None: - columns_list = class_.columns_list - query, bind_vars = 
sql_builder.select_by_columns_query(columns_list, - class_.table_name, - where_column_value_pairs, - order_by=order_by, - group_by=group_by, - limit=limit, - **kwargs) + columns_list=None, order_by=None, group_by=None, + limit=None, fetch_size=100): + query, bind_vars = class_.create_select_query(where_column_value_pairs, + columns_list=columns_list, + order_by=order_by, + group_by=group_by, + limit=limit) return class_._stream_fetch(cursor, query, bind_vars, fetch_size) @@ -273,6 +397,8 @@ def _stream_fetch(class_, cursor, query, bind_vars, fetch_size=100): break stream_cursor.close() + + @db_class_method def get_count(class_, cursor, column_value_pairs=None, **columns): if not column_value_pairs: diff --git a/py/vtdb/db_object_range_sharded.py b/py/vtdb/db_object_range_sharded.py index 8d297346ff5..4bfee97e305 100644 --- a/py/vtdb/db_object_range_sharded.py +++ b/py/vtdb/db_object_range_sharded.py @@ -214,7 +214,7 @@ def lookup_sharding_key_from_entity_id(class_, cursor_method, entity_id_column, @db_object.db_class_method def select_by_ids(class_, cursor, where_column_value_pairs, - columns_list = None,order_by=None, group_by=None, + columns_list=None, order_by=None, group_by=None, limit=None, **kwargs): """This method is used to perform in-clause queries. @@ -223,19 +223,15 @@ def select_by_ids(class_, cursor, where_column_value_pairs, column and the associated entity_keyspace_id_map is computed based on the routing used - sharding_key or entity_id_map. 
""" - - if class_.columns_list is None: - raise dbexceptions.ProgrammingError("DB class should define columns_list") - if columns_list is None: columns_list = class_.columns_list - query, bind_vars = sql_builder.select_by_columns_query(columns_list, - class_.table_name, - where_column_value_pairs, - order_by=order_by, - group_by=group_by, - limit=limit, - **kwargs) + + query, bind_vars = class_.create_select_query(where_column_value_pairs, + columns_list=columns_list, + order_by=order_by, + group_by=group_by, + limit=limit, + **kwargs) entity_col_name = None entity_id_keyspace_id_map = {} diff --git a/py/vtdb/sql_builder.py b/py/vtdb/sql_builder.py index 4246f8d0750..36a2a3c29f8 100644 --- a/py/vtdb/sql_builder.py +++ b/py/vtdb/sql_builder.py @@ -186,12 +186,7 @@ def update_bindvars(newvars): def select_by_columns_query(select_column_list, table_name, column_value_pairs=None, order_by=None, group_by=None, limit=None, for_update=False,client_aggregate=False, - vt_routing_info=None, **columns): - - # generate WHERE clause and bind variables - if not column_value_pairs: - column_value_pairs = columns.items() - column_value_pairs.sort() + vt_routing_info=None): if client_aggregate: clause_list = [select_clause(select_column_list, table_name, @@ -199,6 +194,7 @@ def select_by_columns_query(select_column_list, table_name, column_value_pairs=N else: clause_list = [select_clause(select_column_list, table_name)] + # generate WHERE clause and bind variables if column_value_pairs: where_clause, bind_vars = build_where_clause(column_value_pairs) # add vt routing info @@ -224,11 +220,10 @@ def select_by_columns_query(select_column_list, table_name, column_value_pairs=N return query, bind_vars def update_columns_query(table_name, where_column_value_pairs=None, - update_column_value_pairs=None, limit=None, - order_by=None, **update_columns): + update_column_value_pairs=None, limit=None, + order_by=None): if not update_column_value_pairs: - update_column_value_pairs = 
update_columns.items() - update_column_value_pairs.sort() + raise dbexceptions.ProgrammingError("No update values specified.") clause_list = [] bind_vals = {} diff --git a/py/vtdb/vtgate_cursor.py b/py/vtdb/vtgate_cursor.py index 2423bfbe36d..0624d6717d8 100644 --- a/py/vtdb/vtgate_cursor.py +++ b/py/vtdb/vtgate_cursor.py @@ -191,8 +191,16 @@ def next(self): class BatchVTGateCursor(VTGateCursor): + """Batch Cursor for VTGate. + + This cursor allows 'n' queries to be executed against + 'm' keyspace_ids. For writes though, it may be preferable + to only execute against one keyspace_id. + This only supports keyspace_ids right now since that is what + the underlying vtgate server supports. + """ def __init__(self, connection, keyspace, tablet_type, keyspace_ids=None, - keyranges=None, writable=False): + writable=False): # rowset is [(results, rowcount, lastrowid, fields),] self.rowsets = None self.query_list = [] diff --git a/test/client_test.py b/test/client_test.py index 152fcd5a98f..226965a1dd1 100644 --- a/test/client_test.py +++ b/test/client_test.py @@ -23,6 +23,7 @@ from clientlib_tests import db_class_lookup from vtdb import database_context +from vtdb import db_object from vtdb import keyrange from vtdb import keyrange_constants from vtdb import keyspace @@ -205,6 +206,7 @@ def populate_table(): cursor.execute('insert into vt_unsharded (id, msg) values (%s, %s)' % (str(x), 'msg'), {}) cursor.commit() + class TestUnshardedTable(unittest.TestCase): def setUp(self): @@ -233,9 +235,10 @@ def test_update_and_read(self): id_val = self.all_ids[0] where_column_value_pairs = [('id', id_val)] with database_context.WriteTransaction(self.dc) as context: + update_cols = [('msg', "test update"),] db_class_unsharded.VtUnsharded.update_columns(context.get_cursor(), where_column_value_pairs, - msg="test update") + update_column_value_pairs=update_cols) with database_context.ReadFromMaster(self.dc) as context: rows = 
db_class_unsharded.VtUnsharded.select_by_id(context.get_cursor(), id_val) @@ -277,7 +280,6 @@ def test_max_id(self): self.assertEqual(max_id, expected, "wrong max value fetched; expected %d got %d" % (expected, max_id)) - class TestRangeSharded(unittest.TestCase): def populate_tables(self): self.user_id_list = [] @@ -484,9 +486,10 @@ def update_columns(self): where_column_value_pairs = [('id', user_id),] entity_id_map = {'id': user_id} new_username = 'new_user%s' % user_id + update_cols = [('username', new_username),] db_class_sharded.VtUser.update_columns(context.get_cursor(entity_id_map=entity_id_map), where_column_value_pairs, - username=new_username) + update_column_value_pairs=update_cols) # verify the updated value. where_column_value_pairs = [('id', user_id),] rows = db_class_sharded.VtUser.select_by_columns( @@ -501,10 +504,10 @@ def update_columns(self): m = hashlib.md5() m.update(new_email) email_hash = m.digest() + update_cols = [('email', new_email), ('email_hash', email_hash)] db_class_sharded.VtUserEmail.update_columns(context.get_cursor(entity_id_map={'user_id':user_id}), where_column_value_pairs, - email=new_email, - email_hash=email_hash) + update_column_value_pairs=update_cols) # verify the updated value. 
with database_context.ReadFromMaster(self.dc) as context: @@ -569,6 +572,78 @@ def test_max_id(self): expected = max(self.user_id_list) self.assertEqual(max_id, expected, "wrong max value fetched; expected %d got %d" % (expected, max_id)) + def test_batch_read(self): + query_list = [] + bv_list = [] + user_id_list = [self.user_id_list[0], self.user_id_list[1]] + where_column_value_pairs = (('id', user_id_list),) + entity_id_map = dict(where_column_value_pairs) + q, bv = db_class_sharded.VtUser.create_select_query(where_column_value_pairs) + query_list.append(q) + bv_list.append(bv) + where_column_value_pairs = (('user_id', user_id_list),) + q, bv = db_class_sharded.VtUserEmail.create_select_query(where_column_value_pairs) + query_list.append(q) + bv_list.append(bv) + with database_context.ReadFromMaster(self.dc) as context: + cursor = context.get_cursor(entity_id_map=entity_id_map)(db_class_sharded.VtUser) + results = db_object.execute_batch_read( + cursor, query_list, bv_list) + self.assertEqual(len(results), len(query_list)) + res_ids = [row.id for row in results[0]] + res_user_ids = [row.user_id for row in results[1]] + self.assertEqual(res_ids, user_id_list) + self.assertEqual(res_user_ids, user_id_list) + + def test_batch_write(self): + # 1. Create DMLs using DB Classes. + query_list = [] + bv_list = [] + # Update VtUser table. + user_id = self.user_id_list[1] + where_column_value_pairs = (('id', user_id),) + entity_id_map = dict(where_column_value_pairs) + new_username = 'new_user%s' % user_id + update_cols = [('username', new_username),] + q, bv = db_class_sharded.VtUser.create_update_query( + where_column_value_pairs, update_column_value_pairs=update_cols) + query_list.append(q) + bv_list.append(bv) + # Update VtUserEmail table. 
+ where_column_value_pairs = [('user_id', user_id),] + new_email = 'new_user%s@google.com' % user_id + m = hashlib.md5() + m.update(new_email) + email_hash = m.digest() + update_cols = [('email', new_email), ('email_hash', email_hash)] + q, bv = db_class_sharded.VtUserEmail.create_update_query( + where_column_value_pairs, update_column_value_pairs=update_cols) + query_list.append(q) + bv_list.append(bv) + # Delete a VtSong row + where_column_value_pairs = [('user_id', user_id),] + q, bv = db_class_sharded.VtSong.create_delete_query(where_column_value_pairs) + query_list.append(q) + bv_list.append(bv) + with database_context.WriteTransaction(self.dc) as context: + # 2. Routing for query_list is done by associating + # the common entity_id to the cursor. + # NOTE: cursor creation needs binding to a particular db class, + # so we create a writable cursor using the common entity (user_id). + # This entity_id is used to derive the keyspace_id for routing the dmls. + entity_id_map = {'id': user_id} + cursor = context.get_cursor(entity_id_map=entity_id_map)(db_class_sharded.VtUser) + # 3. Execute the writable batch query. + results = db_object.execute_batch_write( + cursor, query_list, bv_list) + + # 4. Verify results + self.assertEqual(len(results), len(query_list)) + self.assertEqual(results[0]['rowcount'], 1, "VtUser update didn't update 1 row") + self.assertEqual(results[1]['rowcount'], 1, "VtUserEmail update didn't update 1 row") + self.assertEqual(results[2]['rowcount'], len(self.user_song_map[user_id]), + "VtSong deleted '%d' rows, expected '%d'" % (results[2]['rowcount'], len(self.user_song_map[user_id]))) + if __name__ == '__main__': utils.main()