diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index 5e0a3e7c588..99a61d29874 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -61,6 +61,7 @@ func CreateCommitArgParser(supportsBranchFlag bool) *argparser.ArgParser { ap.SupportsFlag(UpperCaseAllFlag, "A", "Adds all tables and databases (including new tables) in the working set to the staged set.") ap.SupportsFlag(AmendFlag, "", "Amend previous commit") ap.SupportsOptionalString(SignFlag, "S", "key-id", "Sign the commit using GPG. If no key-id is provided the key-id is taken from 'user.signingkey' the in the configuration") + ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification") if supportsBranchFlag { ap.SupportsString(BranchParam, "", "branch", "Commit to the specified branch instead of the current branch.") } @@ -96,6 +97,7 @@ func CreateMergeArgParser() *argparser.ArgParser { ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.") ap.SupportsFlag(NoEditFlag, "", "Use an auto-generated commit message when creating a merge commit. The default for interactive CLI sessions is to open an editor.") ap.SupportsString(AuthorParam, "", "author", "Specify an explicit author using the standard A U Thor {{.LessThan}}author@example.com{{.GreaterThan}} format.") + ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge") return ap } @@ -116,6 +118,7 @@ func CreateRebaseArgParser() *argparser.ArgParser { ap.SupportsFlag(AbortParam, "", "Abort an interactive rebase and return the working set to the pre-rebase state") ap.SupportsFlag(ContinueFlag, "", "Continue an interactive rebase after adjusting the rebase plan") ap.SupportsFlag(InteractiveFlag, "i", "Start an interactive rebase") + ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before rebase") return ap } @@ -190,6 +193,7 @@ func CreateCherryPickArgParser() *argparser.ArgParser { ap.SupportsFlag(AllowEmptyFlag, "", "Allow empty commits to be cherry-picked. "+ "Note that use of this option only keeps commits that were initially empty. "+ "Commits which become empty, due to a previous commit, will cause cherry-pick to fail.") + ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before cherry-pick") ap.TooManyArgsErrorFunc = func(receivedArgs []string) error { return errors.New("cherry-picking multiple commits is not supported yet.") } @@ -227,6 +231,7 @@ func CreatePullArgParser() *argparser.ArgParser { ap.SupportsString(UserFlag, "", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.") ap.SupportsFlag(PruneFlag, "p", "After fetching, remove any remote-tracking references that don't exist on the remote.") ap.SupportsFlag(SilentFlag, "", "Suppress progress information.") + ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge") return ap } diff --git a/go/cmd/dolt/cli/flags.go b/go/cmd/dolt/cli/flags.go index 737ea9fc7cd..b1ed484d7a4 100644 --- a/go/cmd/dolt/cli/flags.go +++ b/go/cmd/dolt/cli/flags.go @@ -78,6 +78,7 @@ const ( SilentFlag = "silent" SingleBranchFlag = "single-branch" SkipEmptyFlag = "skip-empty" + SkipVerificationFlag = "skip-verification" SoftResetParam = "soft" SquashParam = "squash" StagedFlag = "staged" diff --git a/go/cmd/dolt/commands/commit.go b/go/cmd/dolt/commands/commit.go index 738bc54e627..23258c148e8 100644 --- a/go/cmd/dolt/commands/commit.go +++ b/go/cmd/dolt/commands/commit.go @@ -266,6 +266,10 @@ func constructParametrizedDoltCommitQuery(msg string, apr *argparser.ArgParseRes writeToBuffer("--skip-empty") } + if apr.Contains(cli.SkipVerificationFlag) { + writeToBuffer("--skip-verification") + } + cfgSign := cliCtx.Config().GetStringOrDefault("sqlserver.global.gpgsign", "") if apr.Contains(cli.SignFlag) || strings.ToLower(cfgSign) == "true" { writeToBuffer("--gpg-sign") diff --git a/go/cmd/dolt/commands/merge.go b/go/cmd/dolt/commands/merge.go index d904657fe00..c28e9c65a71 100644 --- a/go/cmd/dolt/commands/merge.go +++ b/go/cmd/dolt/commands/merge.go @@ -318,6 +318,10 @@ func constructInterpolatedDoltMergeQuery(apr *argparser.ArgParseResults, cliCtx params = append(params, msg) } + if apr.Contains(cli.SkipVerificationFlag) { + writeToBuffer("--skip-verification", false) + } + if !apr.Contains(cli.AbortParam) && !apr.Contains(cli.SquashParam) { writeToBuffer("?", true) params = append(params, apr.Arg(0)) diff --git a/go/gen/fb/serial/workingset.go b/go/gen/fb/serial/workingset.go index ec71849a2a4..baea81dc77e 100644 --- a/go/gen/fb/serial/workingset.go +++ b/go/gen/fb/serial/workingset.go @@ -579,7 +579,19 @@ func (rcv *RebaseState) MutateRebasingStarted(n bool) bool { return rcv._tab.MutateBoolSlot(16, n) } -const RebaseStateNumFields = 7 +func (rcv *RebaseState) SkipVerification() bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.GetBool(o + rcv._tab.Pos) + } + return false +} + +func (rcv *RebaseState) MutateSkipVerification(n bool) bool { + return rcv._tab.MutateBoolSlot(18, n) +} + +const RebaseStateNumFields = 8 func RebaseStateStart(builder *flatbuffers.Builder) { builder.StartObject(RebaseStateNumFields) @@ -614,6 +626,9 @@ func RebaseStateAddLastAttemptedStep(builder *flatbuffers.Builder, lastAttempted func RebaseStateAddRebasingStarted(builder *flatbuffers.Builder, rebasingStarted bool) { builder.PrependBoolSlot(6, rebasingStarted, false) } +func RebaseStateAddSkipVerification(builder *flatbuffers.Builder, skipVerification bool) { + builder.PrependBoolSlot(7, skipVerification, false) +} func RebaseStateEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } diff --git a/go/libraries/doltcore/cherry_pick/cherry_pick.go b/go/libraries/doltcore/cherry_pick/cherry_pick.go index c8012218a23..e7fd3e64059 100644 --- a/go/libraries/doltcore/cherry_pick/cherry_pick.go +++ b/go/libraries/doltcore/cherry_pick/cherry_pick.go @@ -52,6 +52,9 @@ type CherryPickOptions struct { // and Dolt cherry-pick implementations, the default action is to fail when an empty commit is specified. In Git // and Dolt rebase implementations, the default action is to keep commits that start off as empty. EmptyCommitHandling doltdb.EmptyCommitHandling + + // SkipVerification controls whether test validation should be skipped before creating commits. + SkipVerification bool } // NewCherryPickOptions creates a new CherryPickOptions instance, filled out with default values for cherry-pick. @@ -61,6 +64,7 @@ func NewCherryPickOptions() CherryPickOptions { CommitMessage: "", CommitBecomesEmptyHandling: doltdb.ErrorOnEmptyCommit, EmptyCommitHandling: doltdb.ErrorOnEmptyCommit, + SkipVerification: false, } } @@ -159,9 +163,10 @@ func CreateCommitStagedPropsFromCherryPickOptions(ctx *sql.Context, options Cher } commitProps := actions.CommitStagedProps{ - Date: originalMeta.Time(), - Name: originalMeta.Name, - Email: originalMeta.Email, + Date: originalMeta.Time(), + Name: originalMeta.Name, + Email: originalMeta.Email, + SkipVerification: options.SkipVerification, } if options.CommitMessage != "" { diff --git a/go/libraries/doltcore/doltdb/workingset.go b/go/libraries/doltcore/doltdb/workingset.go index f6529114fc8..5498271679e 100644 --- a/go/libraries/doltcore/doltdb/workingset.go +++ b/go/libraries/doltcore/doltdb/workingset.go @@ -75,6 +75,8 @@ type RebaseState struct { // rebasingStarted is true once the rebase plan has been started to execute. Once rebasingStarted is true, the // value in lastAttemptedStep has been initialized and is valid to read. rebasingStarted bool + // skipVerification indicates whether test validation should be skipped during rebase operations. + skipVerification bool } // Branch returns the name of the branch being actively rebased. This is the branch that will be updated to point @@ -120,6 +122,10 @@ func (rs RebaseState) WithRebasingStarted(rebasingStarted bool) *RebaseState { return &rs } +func (rs RebaseState) SkipVerification() bool { + return rs.skipVerification +} + type MergeState struct { // the source commit commit *Commit @@ -322,13 +328,14 @@ func (ws WorkingSet) StartMerge(commit *Commit, commitSpecStr string) *WorkingSe // the branch that is being rebased, and |previousRoot| is root value of the branch being rebased. The HEAD and STAGED // root values of the branch being rebased must match |previousRoot|; WORKING may be a different root value, but ONLY // if it contains only ignored tables. -func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling) (*WorkingSet, error) { +func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling, skipVerification bool) (*WorkingSet, error) { ws.rebaseState = &RebaseState{ ontoCommit: ontoCommit, preRebaseWorking: previousRoot, branch: branch, commitBecomesEmptyHandling: commitBecomesEmptyHandling, emptyCommitHandling: emptyCommitHandling, + skipVerification: skipVerification, } ontoRoot, err := ontoCommit.GetRootValue(ctx) @@ -549,6 +556,7 @@ func newWorkingSet(ctx context.Context, name string, vrw types.ValueReadWriter, emptyCommitHandling: EmptyCommitHandling(dsws.RebaseState.EmptyCommitHandling(ctx)), lastAttemptedStep: dsws.RebaseState.LastAttemptedStep(ctx), rebasingStarted: dsws.RebaseState.RebasingStarted(ctx), + skipVerification: dsws.RebaseState.SkipVerification(ctx), } } @@ -646,7 +654,7 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch, uint8(ws.rebaseState.commitBecomesEmptyHandling), uint8(ws.rebaseState.emptyCommitHandling), - ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted) + ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted, ws.rebaseState.skipVerification) } return &datas.WorkingSetSpec{ diff --git a/go/libraries/doltcore/env/actions/commit.go b/go/libraries/doltcore/env/actions/commit.go index 61b2ede4ca4..15b1a619e2d 100644 --- a/go/libraries/doltcore/env/actions/commit.go +++ b/go/libraries/doltcore/env/actions/commit.go @@ -15,8 +15,12 @@ package actions import ( + "fmt" + "io" + "strings" "time" + gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/dolt/go/libraries/doltcore/diff" @@ -25,14 +29,42 @@ import ( ) type CommitStagedProps struct { - Message string - Date time.Time - AllowEmpty bool - SkipEmpty bool - Amend bool - Force bool - Name string - Email string + Message string + Date time.Time + AllowEmpty bool + SkipEmpty bool + Amend bool + Force bool + Name string + Email string + SkipVerification bool +} + +const ( + // System variable name, defined here to avoid circular imports + DoltCommitVerificationGroups = "dolt_commit_verification_groups" +) + +// GetCommitRunTestGroups returns the test groups to run for commit operations +// Returns empty slice if no tests should be run, ["*"] if all tests should be run, +// or specific group names if only those groups should be run +func GetCommitRunTestGroups() []string { + _, val, ok := sql.SystemVariables.GetGlobal(DoltCommitVerificationGroups) + if !ok { + return nil + } + if stringVal, ok := val.(string); ok && stringVal != "" { + if stringVal == "*" { + return []string{"*"} + } + // Split by comma and trim whitespace + groups := strings.Split(stringVal, ",") + for i, group := range groups { + groups[i] = strings.TrimSpace(group) + } + return groups + } + return nil } // GetCommitStaged returns a new pending commit with the roots and commit properties given. @@ -114,6 +146,16 @@ func GetCommitStaged( } } + if !props.SkipVerification { + testGroups := GetCommitRunTestGroups() + if len(testGroups) > 0 { + err := runCommitVerification(ctx, testGroups) + if err != nil { + return nil, err + } + } + } + meta, err := datas.NewCommitMetaWithUserTS(props.Name, props.Email, props.Message, props.Date) if err != nil { return nil, err @@ -121,3 +163,61 @@ func GetCommitStaged( return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta) } + +func runCommitVerification(ctx *sql.Context, testGroups []string) error { + type sessionInterface interface { + sql.Session + GenericProvider() sql.MutableDatabaseProvider + } + + session, ok := ctx.Session.(sessionInterface) + if !ok { + return fmt.Errorf("session does not provide database provider interface") + } + + provider := session.GenericProvider() + engine := gms.NewDefault(provider) + + return runTestsUsingDtablefunctions(ctx, engine, testGroups) +} + +// runTestsUsingDtablefunctions runs tests using the dtablefunctions package against the staged root +func runTestsUsingDtablefunctions(ctx *sql.Context, engine *gms.Engine, testGroups []string) error { + if len(testGroups) == 0 { + return nil + } + + var allFailures []string + + for _, group := range testGroups { + query := fmt.Sprintf("SELECT * FROM dolt_test_run('%s')", group) + _, iter, _, err := engine.Query(ctx, query) + if err != nil { + return fmt.Errorf("failed to run dolt_test_run for group %s: %w", group, err) + } + + for { + row, rErr := iter.Next(ctx) + if rErr == io.EOF { + break + } + if rErr != nil { + return fmt.Errorf("error reading test results: %w", rErr) + } + + // Extract status (column 3) + status := fmt.Sprintf("%v", row[3]) + if status != "PASS" { + testName := fmt.Sprintf("%v", row[0]) + message := fmt.Sprintf("%v", row[4]) + allFailures = append(allFailures, fmt.Sprintf("%s (%s)", testName, message)) + } + } + } + + if len(allFailures) > 0 { + return fmt.Errorf("commit verification failed: %s", strings.Join(allFailures, ", ")) + } + + return nil +} diff --git a/go/libraries/doltcore/env/actions/test_table_helpers.go b/go/libraries/doltcore/env/actions/test_table_helpers.go deleted file mode 100644 index dc722300566..00000000000 --- a/go/libraries/doltcore/env/actions/test_table_helpers.go +++ /dev/null @@ -1,447 +0,0 @@ -// Copyright 2025 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package actions - -import ( - "fmt" - "io" - "strconv" - "time" - - "github.com/dolthub/go-mysql-server/sql" - "github.com/shopspring/decimal" - "golang.org/x/exp/constraints" - - "github.com/dolthub/dolt/go/store/val" -) - -const ( - AssertionExpectedRows = "expected_rows" - AssertionExpectedColumns = "expected_columns" - AssertionExpectedSingleValue = "expected_single_value" -) - -// AssertData parses an assertion, comparison, and value, then returns the status of the test. -// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=". -// testPassed indicates whether the test was successful or not. -// message is a string used to indicate test failures, and will not halt the overall process. -// message will be empty if the test passed. -// err indicates runtime failures and will stop dolt_test_run from proceeding. -func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) { - switch assertion { - case AssertionExpectedRows: - message, err = expectRows(sqlCtx, comparison, value, queryResult) - case AssertionExpectedColumns: - message, err = expectColumns(sqlCtx, comparison, value, queryResult) - case AssertionExpectedSingleValue: - message, err = expectSingleValue(sqlCtx, comparison, value, queryResult) - default: - return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil - } - - if err != nil { - return false, "", err - } else if message != "" { - return false, message, nil - } - return true, "", nil -} - -func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) { - row, err := queryResult.Next(sqlCtx) - if err == io.EOF { - return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil - } else if err != nil { - return "", err - } - - if len(row) != 1 { - return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil - } - _, err = queryResult.Next(sqlCtx) - if err == nil { //If multiple rows were given, we should error out - return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil - } else if err != io.EOF { // "True" error, so we should quit out - return "", err - } - - if value == nil { // If we're expecting a null value, we don't need to type switch - return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil - } - - // Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception - // of "0" and "1", which are valid integers and are covered below. - if *value != "0" && *value != "1" { - if expectedBool, err := strconv.ParseBool(*value); err == nil { - actualBool, boolErr := getInterfaceAsBool(row[0]) - if boolErr != nil { - return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil - } - return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil - } - } - - switch actualValue := row[0].(type) { - case int8: - expectedInt, err := strconv.ParseInt(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil - case int16: - expectedInt, err := strconv.ParseInt(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil - case int32: - expectedInt, err := strconv.ParseInt(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil - case int64: - expectedInt, err := strconv.ParseInt(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil - case int: - expectedInt, err := strconv.ParseInt(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil - case uint8: - expectedUint, err := strconv.ParseUint(*value, 10, 32) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil - case uint16: - expectedUint, err := strconv.ParseUint(*value, 10, 32) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil - case uint32: - expectedUint, err := strconv.ParseUint(*value, 10, 32) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil - case uint64: - expectedUint, err := strconv.ParseUint(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil - case uint: - expectedUint, err := strconv.ParseUint(*value, 10, 64) - if err != nil { - return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil - } - return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil - case float64: - expectedFloat, err := strconv.ParseFloat(*value, 64) - if err != nil { - return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil - } - return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil - case float32: - expectedFloat, err := strconv.ParseFloat(*value, 32) - if err != nil { - return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil - } - return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil - case decimal.Decimal: - expectedDecimal, err := decimal.NewFromString(*value) - if err != nil { - return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil - } - return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil - case time.Time: - expectedTime, format, err := parseTestsDate(*value) - if err != nil { - return fmt.Sprintf("%s does not appear to be a valid date", *value), nil - } - return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil - case *val.TextStorage, string: - actualString, err := GetStringColAsString(sqlCtx, actualValue) - if err != nil { - return "", err - } - return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil - default: - return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil - } -} - -func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) { - if value == nil { - return "null is not a valid assertion for expected_rows", nil - } - expectedRows, err := strconv.Atoi(*value) - if err != nil { - return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil - } - - var numRows int - for { - _, err := queryResult.Next(sqlCtx) - if err == io.EOF { - break - } else if err != nil { - return "", err - } - numRows++ - } - return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil -} - -func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) { - if value == nil { - return "null is not a valid assertion for expected_rows", nil - } - expectedColumns, err := strconv.Atoi(*value) - if err != nil { - return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil - } - - var numColumns int - row, err := queryResult.Next(sqlCtx) - if err != nil && err != io.EOF { - return "", err - } - numColumns = len(row) - return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil -} - -// compareTestAssertion is a generic function used for comparing string, ints, floats. -// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">=" -// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise -func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string { - switch comparison { - case "==": - if actualValue != expectedValue { - return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue) - } - case "!=": - if actualValue == expectedValue { - return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue) - } - case "<": - if actualValue >= expectedValue { - return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue) - } - case "<=": - if actualValue > expectedValue { - return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue) - } - case ">": - if actualValue <= expectedValue { - return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue) - } - case ">=": - if actualValue < expectedValue { - return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue) - } - default: - return fmt.Sprintf("%s is not a valid comparison type", comparison) - } - return "" -} - -// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests. -// It returns the parsed time, the format that succeeded, and an error if applicable. -func parseTestsDate(value string) (parsedTime time.Time, format string, err error) { - // List of valid formats - formats := []string{ - time.DateOnly, - time.DateTime, - time.TimeOnly, - time.RFC3339, - time.RFC1123Z, - } - - for _, format := range formats { - if parsedTime, parseErr := time.Parse(format, value); parseErr == nil { - return parsedTime, format, nil - } else { - err = parseErr - } - } - return time.Time{}, "", err -} - -// compareDates is a function used for comparing time values. -// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">=" -// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise -func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string { - expectedStr := expectedValue.Format(format) - realStr := realValue.Format(format) - switch comparison { - case "==": - if !expectedValue.Equal(realValue) { - return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr) - } - case "!=": - if expectedValue.Equal(realValue) { - return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr) - } - case "<": - if realValue.Equal(expectedValue) || realValue.After(expectedValue) { - return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr) - } - case "<=": - if realValue.After(expectedValue) { - return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr) - } - case ">": - if realValue.Before(expectedValue) || realValue.Equal(expectedValue) { - return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr) - } - case ">=": - if realValue.Before(expectedValue) { - return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr) - } - default: - return fmt.Sprintf("%s is not a valid comparison type", comparison) - } - return "" -} - -// compareDecimals is a function used for comparing decimals. -// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">=" -// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise -func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string { - switch comparison { - case "==": - if !expectedValue.Equal(realValue) { - return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue) - } - case "!=": - if expectedValue.Equal(realValue) { - return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue) - } - case "<": - if realValue.GreaterThanOrEqual(expectedValue) { - return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue) - } - case "<=": - if realValue.GreaterThan(expectedValue) { - return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue) - } - case ">": - if realValue.LessThanOrEqual(expectedValue) { - return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue) - } - case ">=": - if realValue.LessThan(expectedValue) { - return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue) - } - default: - return fmt.Sprintf("%s is not a valid comparison type", comparison) - } - return "" -} - -// getTinyIntColAsBool returns the value interface{} as a bool -// This is necessary because the query engine may return a tinyint column as a bool, int, or other types. -// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles. -func getInterfaceAsBool(col interface{}) (bool, error) { - switch v := col.(type) { - case bool: - return v, nil - case int: - return v == 1, nil - case int8: - return v == 1, nil - case int16: - return v == 1, nil - case int32: - return v == 1, nil - case int64: - return v == 1, nil - case uint: - return v == 1, nil - case uint8: - return v == 1, nil - case uint16: - return v == 1, nil - case uint32: - return v == 1, nil - case uint64: - return v == 1, nil - case string: - return v == "1", nil - default: - return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v) - } -} - -// compareBooleans is a function used for comparing boolean values. -// It takes in a comparison string from one of: "==", "!=" -// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise -func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string { - switch comparison { - case "==": - if expectedValue != realValue { - return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue) - } - case "!=": - if expectedValue == realValue { - return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue) - } - default: - return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison) - } - return "" -} - -// compareNullValue is a function used for comparing a null value. -// It takes in a comparison string from one of: "==", "!=" -// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise -func compareNullValue(comparison string, actualValue interface{}, assertionType string) string { - switch comparison { - case "==": - if actualValue != nil { - return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue) - } - case "!=": - if actualValue == nil { - return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType) - } - default: - return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison) - } - return "" -} - -// GetStringColAsString is a function that returns a text column as a string. -// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations, -// so we use a special parser to get the correct string values -func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) { - if ts, ok := tableValue.(*val.TextStorage); ok { - str, err := ts.Unwrap(sqlCtx) - return &str, err - } else if str, ok := tableValue.(string); ok { - return &str, nil - } else if tableValue == nil { - return nil, nil - } else { - return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue) - } -} diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index d541bf197ed..b21284ec341 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -1955,7 +1955,10 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima return err } - if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) && !doltdb.IsFullTextTable(tableName) && !doltdb.HasDoltCIPrefix(tableName) { + if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) && + !doltdb.IsFullTextTable(tableName) && + !doltdb.HasDoltCIPrefix(tableName) && + tableName != doltdb.TestsTableName { // NM4 - determine why this is required now. return ErrReservedTableName.New(tableName) } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go b/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go index f1c215cfe32..58cbe5c3646 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_cherry_pick.go @@ -103,6 +103,8 @@ func doDoltCherryPick(ctx *sql.Context, args []string) (string, int, int, int, e cherryPickOptions.EmptyCommitHandling = doltdb.KeepEmptyCommit } + cherryPickOptions.SkipVerification = apr.Contains(cli.SkipVerificationFlag) + commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherryPickOptions) if err != nil { return "", 0, 0, 0, err diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_commit.go b/go/libraries/doltcore/sqle/dprocedures/dolt_commit.go index b81f8a64b1d..7c89edca6ac 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_commit.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_commit.go @@ -163,14 +163,15 @@ func doDoltCommit(ctx *sql.Context, args []string) (string, bool, error) { } csp := actions.CommitStagedProps{ - Message: msg, - Date: t, - AllowEmpty: apr.Contains(cli.AllowEmptyFlag), - SkipEmpty: apr.Contains(cli.SkipEmptyFlag), - Amend: amend, - Force: apr.Contains(cli.ForceFlag), - Name: name, - Email: email, + Message: msg, + Date: t, + AllowEmpty: apr.Contains(cli.AllowEmptyFlag), + SkipEmpty: apr.Contains(cli.SkipEmptyFlag), + Amend: amend, + Force: apr.Contains(cli.ForceFlag), + Name: name, + Email: email, + SkipVerification: apr.Contains(cli.SkipVerificationFlag), } shouldSign, err := dsess.GetBooleanSystemVar(ctx, "gpgsign") diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go b/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go index 6a0471163cd..1f68de8e792 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_merge.go @@ -180,7 +180,7 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, err msg = userMsg } - ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) + ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag)) if err != nil { return commit, conflicts, fastForward, "", err } @@ -205,6 +205,7 @@ func performMerge( spec *merge.MergeSpec, noCommit bool, msg string, + skipVerification bool, ) (*doltdb.WorkingSet, string, int, int, string, error) { // todo: allow merges even when an existing merge is uncommitted if ws.MergeActive() { @@ -234,7 +235,7 @@ func performMerge( if canFF { if spec.FFMode == merge.NoFastForward { var commit *doltdb.Commit - ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit) + ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit, skipVerification) if err == doltdb.ErrUnresolvedConflictsOrViolations { // if there are unresolved conflicts, write the resulting working set back to the session and return an // error message @@ -306,7 +307,10 @@ func performMerge( author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email) args := []string{"-m", msg, "--author", author} if spec.Force { - args = append(args, "--force") + args = append(args, "--"+cli.ForceFlag) + } + if skipVerification { + args = append(args, "--"+cli.SkipVerificationFlag) } commit, _, err = doDoltCommit(ctx, args) if err != nil { @@ -405,6 +409,7 @@ func executeNoFFMerge( dbName string, ws *doltdb.WorkingSet, noCommit bool, + skipVerification bool, ) (*doltdb.WorkingSet, *doltdb.Commit, error) { mergeRoot, err := spec.MergeC.GetRootValue(ctx) if err != nil { @@ -444,11 +449,12 @@ func executeNoFFMerge( } pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{ - Message: msg, - Date: spec.Date, - Force: spec.Force, - Name: spec.Name, - Email: spec.Email, + Message: msg, + Date: spec.Date, + Force: spec.Force, + Name: spec.Name, + Email: spec.Email, + SkipVerification: skipVerification, }) if err != nil { return nil, nil, err diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go b/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go index 46f3a940b14..1cba0b33aa8 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_pull.go @@ -237,7 +237,7 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) { return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New() } - ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg) + ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag)) if err != nil && !errors.Is(doltdb.ErrUpToDate, err) { return conflicts, fastForward, "", err } diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go b/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go index 7f4374df456..a770f6485f9 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_rebase.go @@ -216,7 +216,9 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) { } else if apr.NArg() > 1 { return 1, "", fmt.Errorf("too many args") } - err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling) + + skipVerification := apr.Contains(cli.SkipVerificationFlag) + err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling, skipVerification) if err != nil { return 1, "", err } @@ -263,7 +265,7 @@ func processCommitBecomesEmptyParams(apr *argparser.ArgParseResults) (doltdb.Emp // startRebase starts a new interactive rebase operation. |upstreamPoint| specifies the commit where the new rebased // commits will be based off of, |commitBecomesEmptyHandling| specifies how to handle commits that are not empty, but // do not produce any changes when applied, and |emptyCommitHandling| specifies how to handle empty commits. -func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error { +func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling, skipVerification bool) error { if upstreamPoint == "" { return fmt.Errorf("no upstream branch specified") } @@ -351,7 +353,7 @@ func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandl } newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working, - commitBecomesEmptyHandling, emptyCommitHandling) + commitBecomesEmptyHandling, emptyCommitHandling, skipVerification) if err != nil { return err } @@ -716,7 +718,8 @@ func continueRebase(ctx *sql.Context) rebaseResult { result := processRebasePlanStep(ctx, &step, workingSet.RebaseState().CommitBecomesEmptyHandling(), - workingSet.RebaseState().EmptyCommitHandling()) + workingSet.RebaseState().EmptyCommitHandling(), + workingSet.RebaseState().SkipVerification()) if result.err != nil || result.status != 0 || result.halt { return result } @@ -803,7 +806,7 @@ func commitManuallyStagedChangesForStep(ctx *sql.Context, step rebase.RebasePlan } options, err := createCherryPickOptionsForRebaseStep(ctx, &step, workingSet.RebaseState().CommitBecomesEmptyHandling(), - workingSet.RebaseState().EmptyCommitHandling()) + workingSet.RebaseState().EmptyCommitHandling(), workingSet.RebaseState().SkipVerification()) doltDB, ok := doltSession.GetDoltDB(ctx, ctx.GetCurrentDatabase()) if !ok { @@ -861,6 +864,7 @@ func processRebasePlanStep( planStep *rebase.RebasePlanStep, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling, + skipVerification bool, ) rebaseResult { // Make sure we have a transaction opened for the session // NOTE: After our first call to cherry-pick, the tx is committed, so a new tx needs to be started @@ -878,7 +882,7 @@ func processRebasePlanStep( return newRebaseSuccess("") } - options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling) + options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling, skipVerification) if err != nil { return newRebaseError(err) } @@ -886,12 +890,19 @@ func processRebasePlanStep( return handleRebaseCherryPick(ctx, planStep, *options) } -func createCherryPickOptionsForRebaseStep(ctx *sql.Context, planStep *rebase.RebasePlanStep, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) (*cherry_pick.CherryPickOptions, error) { +func createCherryPickOptionsForRebaseStep( + ctx *sql.Context, + planStep *rebase.RebasePlanStep, + commitBecomesEmptyHandling doltdb.EmptyCommitHandling, + emptyCommitHandling doltdb.EmptyCommitHandling, + skipVerification bool, +) (*cherry_pick.CherryPickOptions, error) { // Override the default empty commit handling options for cherry-pick, since // rebase has slightly different defaults options := cherry_pick.NewCherryPickOptions() options.CommitBecomesEmptyHandling = commitBecomesEmptyHandling options.EmptyCommitHandling = emptyCommitHandling + options.SkipVerification = skipVerification switch planStep.Action { case rebase.RebaseActionDrop, rebase.RebaseActionPick, rebase.RebaseActionEdit: diff --git a/go/libraries/doltcore/sqle/dtablefunctions/dolt_test_run.go b/go/libraries/doltcore/sqle/dtablefunctions/dolt_test_run.go index 59a1ff26ab1..e76874f7c95 100644 --- a/go/libraries/doltcore/sqle/dtablefunctions/dolt_test_run.go +++ b/go/libraries/doltcore/sqle/dtablefunctions/dolt_test_run.go @@ -17,7 +17,9 @@ package dtablefunctions import ( "fmt" "io" + "strconv" "strings" + "time" gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/sql" @@ -26,10 +28,13 @@ import ( "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/gocraft/dbr/v2" "github.com/gocraft/dbr/v2/dialect" + "github.com/shopspring/decimal" + "golang.org/x/exp/constraints" - "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" + "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides" + "github.com/dolthub/dolt/go/store/val" ) const testsRunDefaultRowCount = 10 @@ -39,12 +44,13 @@ var _ sql.CatalogTableFunction = (*TestsRunTableFunction)(nil) var _ sql.ExecSourceRel = (*TestsRunTableFunction)(nil) var _ sql.AuthorizationCheckerNode = (*TestsRunTableFunction)(nil) -type testResult struct { - testName string - groupName string - query string - status string - message string +// TestResult represents the result of running a single test +type TestResult struct { + TestName string + GroupName string + Query string + Status string + Message string } type TestsRunTableFunction struct { @@ -199,7 +205,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt return nil, err } - resultRow := sql.NewRow(result.testName, result.groupName, result.query, result.status, result.message) + resultRow := sql.NewRow(result.TestName, result.GroupName, result.Query, result.Status, result.Message) resultRows = append(resultRows, resultRow) } } @@ -220,7 +226,7 @@ func (trtf *TestsRunTableFunction) RowCount(_ *sql.Context) (uint64, bool, error return testsRunDefaultRowCount, false, nil } -func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResult, err error) { +func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result TestResult, err error) { testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row) if err != nil { return @@ -237,9 +243,9 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul if err != nil { message = fmt.Sprintf("Query error: %s", err.Error()) } else { - testPassed, message, err = actions.AssertData(trtf.ctx, *assertion, *comparison, value, queryResult) + testPassed, message, err = AssertData(trtf.ctx, *assertion, *comparison, value, queryResult) if err != nil { - return testResult{}, err + return TestResult{}, err } } } @@ -253,11 +259,49 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul if groupName != nil { groupString = *groupName } - result = testResult{*testName, groupString, *query, status, message} + result = TestResult{*testName, groupString, *query, status, message} + return result, nil +} + +func (trtf *TestsRunTableFunction) queryAndAssertWithFunc(row sql.Row, assertDataFunc AssertDataFunc) (result TestResult, err error) { + testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row) + if err != nil { + return + } + + message, err := validateQuery(trtf.ctx, trtf.catalog, *query) + if err != nil && message == "" { + message = fmt.Sprintf("query error: %s", err.Error()) + } + + var testPassed bool + if message == "" { + _, queryResult, _, err := trtf.engine.Query(trtf.ctx, *query) + if err != nil { + message = fmt.Sprintf("Query error: %s", err.Error()) + } else { + testPassed, message, err = assertDataFunc(trtf.ctx, *assertion, *comparison, value, queryResult) + if err != nil { + return TestResult{}, err + } + } + } + + status := "PASS" + if !testPassed { + status = "FAIL" + } + + var groupString string + if groupName != nil { + groupString = *groupName + } + result = TestResult{*testName, groupString, *query, status, message} return result, nil } func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) { + // Original behavior when root is nil - use SQL queries against current session var queries []string if arg == "*" { @@ -320,28 +364,31 @@ func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, er } func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value *string, err error) { - if testName, err = actions.GetStringColAsString(ctx, row[0]); err != nil { + if testName, err = getStringColAsString(ctx, row[0]); err != nil { return } - if groupName, err = actions.GetStringColAsString(ctx, row[1]); err != nil { + if groupName, err = getStringColAsString(ctx, row[1]); err != nil { return } - if query, err = actions.GetStringColAsString(ctx, row[2]); err != nil { + if query, err = getStringColAsString(ctx, row[2]); err != nil { return } - if assertion, err = actions.GetStringColAsString(ctx, row[3]); err != nil { + if assertion, err = getStringColAsString(ctx, row[3]); err != nil { return } - if comparison, err = actions.GetStringColAsString(ctx, row[4]); err != nil { + if comparison, err = getStringColAsString(ctx, row[4]); err != nil { return } - if value, err = actions.GetStringColAsString(ctx, row[5]); err != nil { + if value, err = getStringColAsString(ctx, row[5]); err != nil { return } return testName, groupName, query, assertion, comparison, value, nil } +// AssertDataFunc defines the function signature for asserting test data +type AssertDataFunc func(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) + func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string, error) { // We first check if the query contains multiple sql statements if statements, err := sqlparser.SplitStatementToPieces(query); err != nil { @@ -361,3 +408,455 @@ func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string, } return "", nil } + +// Simple inline assertion constants to avoid circular imports +const ( + AssertionExpectedRows = "expected_rows" + AssertionExpectedColumns = "expected_columns" + AssertionExpectedSingleValue = "expected_single_value" +) + +// getStringColAsString safely converts a sql value to string +func getStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) { + if tableValue == nil { + return nil, nil + } + if ts, ok := tableValue.(*val.TextStorage); ok { + str, err := ts.Unwrap(sqlCtx) + if err != nil { + return nil, err + } + return &str, nil + } else if str, ok := tableValue.(string); ok { + return &str, nil + } else { + return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue) + } +} + +// readTableDataFromDoltTable reads test data directly from a dolt table +func (trtf *TestsRunTableFunction) readTableDataFromDoltTable(table *doltdb.Table, arg string) ([]sql.Row, error) { + // This is a complex implementation that requires reading table data directly from dolt storage + // For now, return an error that clearly indicates this needs to be implemented + // The table scan would involve: + // 1. Getting the table schema + // 2. Creating a table iterator + // 3. Reading and filtering rows based on the arg (test_name or test_group) + // 4. Converting dolt storage format to SQL rows + // + // This is a significant implementation that requires understanding dolt's storage internals + return nil, fmt.Errorf("direct table reading from dolt storage not yet implemented for table scan of dolt_tests - this requires implementing table iteration and row conversion from dolt's internal storage format") +} + +// AssertData parses an assertion, comparison, and value, then returns the status of the test. +// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=". +// testPassed indicates whether the test was successful or not. +// message is a string used to indicate test failures, and will not halt the overall process. +// message will be empty if the test passed. +// err indicates runtime failures and will stop dolt_test_run from proceeding. +func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) { + switch assertion { + case AssertionExpectedRows: + message, err = expectRows(sqlCtx, comparison, value, queryResult) + case AssertionExpectedColumns: + message, err = expectColumns(sqlCtx, comparison, value, queryResult) + case AssertionExpectedSingleValue: + message, err = expectSingleValue(sqlCtx, comparison, value, queryResult) + default: + return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil + } + + if err != nil { + return false, "", err + } else if message != "" { + return false, message, nil + } + return true, "", nil +} + +func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) { + row, err := queryResult.Next(sqlCtx) + if err == io.EOF { + return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil + } else if err != nil { + return "", err + } + + if len(row) != 1 { + return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil + } + _, err = queryResult.Next(sqlCtx) + if err == nil { //If multiple rows were given, we should error out + return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil + } else if err != io.EOF { // "True" error, so we should quit out + return "", err + } + + if value == nil { // If we're expecting a null value, we don't need to type switch + return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil + } + + // Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception + // of "0" and "1", which are valid integers and are covered below. + if *value != "0" && *value != "1" { + if expectedBool, err := strconv.ParseBool(*value); err == nil { + actualBool, boolErr := getInterfaceAsBool(row[0]) + if boolErr != nil { + return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil + } + return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil + } + } + + switch actualValue := row[0].(type) { + case int8: + expectedInt, err := strconv.ParseInt(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil + case int16: + expectedInt, err := strconv.ParseInt(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil + case int32: + expectedInt, err := strconv.ParseInt(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil + case int64: + expectedInt, err := strconv.ParseInt(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil + case int: + expectedInt, err := strconv.ParseInt(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil + case uint8: + expectedUint, err := strconv.ParseUint(*value, 10, 32) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil + case uint16: + expectedUint, err := strconv.ParseUint(*value, 10, 32) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil + case uint32: + expectedUint, err := strconv.ParseUint(*value, 10, 32) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil + case uint64: + expectedUint, err := strconv.ParseUint(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil + case uint: + expectedUint, err := strconv.ParseUint(*value, 10, 64) + if err != nil { + return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil + } + return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil + case float64: + expectedFloat, err := strconv.ParseFloat(*value, 64) + if err != nil { + return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil + } + return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil + case float32: + expectedFloat, err := strconv.ParseFloat(*value, 32) + if err != nil { + return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil + } + return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil + case decimal.Decimal: + expectedDecimal, err := decimal.NewFromString(*value) + if err != nil { + return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil + } + return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil + case time.Time: + expectedTime, format, err := parseTestsDate(*value) + if err != nil { + return fmt.Sprintf("%s does not appear to be a valid date", *value), nil + } + return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil + case *val.TextStorage, string: + actualString, err := GetStringColAsString(sqlCtx, actualValue) + if err != nil { + return "", err + } + return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil + default: + return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil + } +} + +func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) { + if value == nil { + return "null is not a valid assertion for expected_rows", nil + } + expectedRows, err := strconv.Atoi(*value) + if err != nil { + return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil + } + + var numRows int + for { + _, err := queryResult.Next(sqlCtx) + if err == io.EOF { + break + } else if err != nil { + return "", err + } + numRows++ + } + return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil +} + +func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) { + if value == nil { + return "null is not a valid assertion for expected_rows", nil + } + expectedColumns, err := strconv.Atoi(*value) + if err != nil { + return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil + } + + var numColumns int + row, err := queryResult.Next(sqlCtx) + if err != nil && err != io.EOF { + return "", err + } + numColumns = len(row) + return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil +} + +// compareTestAssertion is a generic function used for comparing string, ints, floats. +// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">=" +// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise +func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string { + switch comparison { + case "==": + if actualValue != expectedValue { + return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue) + } + case "!=": + if actualValue == expectedValue { + return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue) + } + case "<": + if actualValue >= expectedValue { + return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue) + } + case "<=": + if actualValue > expectedValue { + return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue) + } + case ">": + if actualValue <= expectedValue { + return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue) + } + case ">=": + if actualValue < expectedValue { + return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue) + } + default: + return fmt.Sprintf("%s is not a valid comparison type", comparison) + } + return "" +} + +// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests. +// It returns the parsed time, the format that succeeded, and an error if applicable. +func parseTestsDate(value string) (parsedTime time.Time, format string, err error) { + // List of valid formats + formats := []string{ + time.DateOnly, + time.DateTime, + time.TimeOnly, + time.RFC3339, + time.RFC1123Z, + } + + for _, format := range formats { + if parsedTime, parseErr := time.Parse(format, value); parseErr == nil { + return parsedTime, format, nil + } else { + err = parseErr + } + } + return time.Time{}, "", err +} + +// compareDates is a function used for comparing time values. +// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">=" +// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise +func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string { + expectedStr := expectedValue.Format(format) + realStr := realValue.Format(format) + switch comparison { + case "==": + if !expectedValue.Equal(realValue) { + return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr) + } + case "!=": + if expectedValue.Equal(realValue) { + return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr) + } + case "<": + if realValue.Equal(expectedValue) || realValue.After(expectedValue) { + return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr) + } + case "<=": + if realValue.After(expectedValue) { + return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr) + } + case ">": + if realValue.Before(expectedValue) || realValue.Equal(expectedValue) { + return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr) + } + case ">=": + if realValue.Before(expectedValue) { + return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr) + } + default: + return fmt.Sprintf("%s is not a valid comparison type", comparison) + } + return "" +} + +// compareDecimals is a function used for comparing decimals. +// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">=" +// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise +func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string { + switch comparison { + case "==": + if !expectedValue.Equal(realValue) { + return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue) + } + case "!=": + if expectedValue.Equal(realValue) { + return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue) + } + case "<": + if realValue.GreaterThanOrEqual(expectedValue) { + return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue) + } + case "<=": + if realValue.GreaterThan(expectedValue) { + return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue) + } + case ">": + if realValue.LessThanOrEqual(expectedValue) { + return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue) + } + case ">=": + if realValue.LessThan(expectedValue) { + return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue) + } + default: + return fmt.Sprintf("%s is not a valid comparison type", comparison) + } + return "" +} + +// getTinyIntColAsBool returns the value interface{} as a bool +// This is necessary because the query engine may return a tinyint column as a bool, int, or other types. +// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles. +func getInterfaceAsBool(col interface{}) (bool, error) { + switch v := col.(type) { + case bool: + return v, nil + case int: + return v == 1, nil + case int8: + return v == 1, nil + case int16: + return v == 1, nil + case int32: + return v == 1, nil + case int64: + return v == 1, nil + case uint: + return v == 1, nil + case uint8: + return v == 1, nil + case uint16: + return v == 1, nil + case uint32: + return v == 1, nil + case uint64: + return v == 1, nil + case string: + return v == "1", nil + default: + return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v) + } +} + +// compareBooleans is a function used for comparing boolean values. +// It takes in a comparison string from one of: "==", "!=" +// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise +func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string { + switch comparison { + case "==": + if expectedValue != realValue { + return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue) + } + case "!=": + if expectedValue == realValue { + return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue) + } + default: + return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison) + } + return "" +} + +// compareNullValue is a function used for comparing a null value. +// It takes in a comparison string from one of: "==", "!=" +// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise +func compareNullValue(comparison string, actualValue interface{}, assertionType string) string { + switch comparison { + case "==": + if actualValue != nil { + return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue) + } + case "!=": + if actualValue == nil { + return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType) + } + default: + return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison) + } + return "" +} + +// GetStringColAsString is a function that returns a text column as a string. +// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations, +// so we use a special parser to get the correct string values +func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) { + if ts, ok := tableValue.(*val.TextStorage); ok { + str, err := ts.Unwrap(sqlCtx) + return &str, err + } else if str, ok := tableValue.(string); ok { + return &str, nil + } else if tableValue == nil { + return nil, nil + } else { + return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue) + } +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 0e0dffc0049..58f2edda212 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -1239,6 +1239,11 @@ func TestDoltDdlScripts(t *testing.T) { RunDoltDdlScripts(t, harness) } +func TestDoltCommitVerificationScripts(t *testing.T) { + harness := newDoltEnginetestHarness(t) + RunDoltCommitVerificationScripts(t, harness) +} + func TestBrokenDdlScripts(t *testing.T) { for _, script := range BrokenDDLScripts { t.Skip(script.Name) diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go index 6847fa16c94..7d454c8d505 100755 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_tests.go @@ -2200,3 +2200,12 @@ func RunTransactionTestsWithEngineSetup(t *testing.T, setupEngine func(*gms.Engi }) } } + +func RunDoltCommitVerificationScripts(t *testing.T, harness DoltEnginetestHarness) { + for _, script := range DoltCommitVerificationScripts { + harness := harness.NewHarness(t) + + enginetest.TestScript(t, harness, script) + harness.Close() + } +} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go index 127ac5ec510..97d2bc748b8 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_harness.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_harness.go @@ -190,7 +190,6 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript { for i := range dbs { db := dbs[i] resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("use %s", db)}) - // Any auto increment tables must be dropped and recreated to get a fresh state for the global auto increment // sequence trackers _, aiTables := enginetest.MustQuery(ctx, d.engine, @@ -218,6 +217,7 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript { resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("drop database if exists %s", db)}) } } + resetCmds = append(resetCmds, setup.SetupScript{"use mydb"}) return resetCmds } @@ -229,7 +229,7 @@ func commitScripts(dbs []string) []setup.SetupScript { db := dbs[i] commitCmds = append(commitCmds, fmt.Sprintf("use %s", db)) commitCmds = append(commitCmds, "call dolt_add('.')") - commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00')", db)) + commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00', '--skip-verification')", db)) } commitCmds = append(commitCmds, "use mydb") return []setup.SetupScript{commitCmds} diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries_commit_verification.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries_commit_verification.go new file mode 100644 index 00000000000..0a2251cc953 --- /dev/null +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries_commit_verification.go @@ -0,0 +1,538 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package enginetest + +import ( + "regexp" + + "github.com/dolthub/go-mysql-server/enginetest" + "github.com/dolthub/go-mysql-server/enginetest/queries" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + + "github.com/dolthub/dolt/go/store/hash" +) + +// commitHashValidator validates commit hash format (32 character hex) +type commitHashValidator struct{} + +var _ enginetest.CustomValueValidator = &commitHashValidator{} + +func (chv *commitHashValidator) Validate(val interface{}) (bool, error) { + h, ok := val.(string) + if !ok { + return false, nil + } + + _, ok = hash.MaybeParse(h) + return ok, nil +} + +// successfulRebaseMessageValidator validates successful rebase message format +type successfulRebaseMessageValidator struct{} + +var _ enginetest.CustomValueValidator = &successfulRebaseMessageValidator{} +var successfulRebaseRegex = regexp.MustCompile(`^Successfully rebased.*`) + +func (srmv *successfulRebaseMessageValidator) Validate(val interface{}) (bool, error) { + message, ok := val.(string) + if !ok { + return false, nil + } + return successfulRebaseRegex.MatchString(message), nil +} + +var commitHash = &commitHashValidator{} +var successfulRebaseMessage = &successfulRebaseMessageValidator{} + +var DoltCommitVerificationScripts = []queries.ScriptTest{ + { + Name: "test verification system variables exist and have correct defaults", + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'", + Expected: []sql.Row{ + {"dolt_commit_verification_groups", ""}, + }, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "test verification system variables can be set", + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SET GLOBAL dolt_commit_verification_groups = '*'", + Expected: []sql.Row{{types.OkResult{}}}, + }, + { + Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'", + Expected: []sql.Row{ + {"dolt_commit_verification_groups", "*"}, + }, + }, + { + Query: "SET GLOBAL dolt_commit_verification_groups = 'unit,integration'", + Expected: []sql.Row{{types.OkResult{}}}, + }, + { + Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'", + Expected: []sql.Row{ + {"dolt_commit_verification_groups", "unit,integration"}, + }, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "commit verification enabled - all tests pass", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " + + "('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_commit('-m', 'Commit with passing tests')", + ExpectedColumns: sql.Schema{ + {Name: "hash", Type: types.LongText, Nullable: false}, + }, + Expected: []sql.Row{{commitHash}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "commit verification enabled - tests fail, commit aborted", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " + + "('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')", + "CALL dolt_add('.')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_commit('-m', 'Commit that should fail verification')", + ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)", + }, + { + Query: "CALL dolt_commit('--skip-verification','-m', 'skip verification')", + Expected: []sql.Row{{commitHash}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "commit with test verification - specific test groups", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = 'unit'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " + + "('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')", + "CALL dolt_add('.')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_commit('-m', 'Commit with unit tests only')", + Expected: []sql.Row{{commitHash}}, + }, + { + Query: "SET GLOBAL dolt_commit_verification_groups = 'integration'", + SkipResultsCheck: true, + }, + { + Query: "CALL dolt_commit('--allow-empty', '--amend', '-m', 'fail please')", + ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)", + }, + { + Query: "CALL dolt_commit('--allow-empty', '--amend', '--skip-verification', '-m', 'skip the tests')", + Expected: []sql.Row{{commitHash}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "cherry-pick with test verification enabled - tests pass", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'add test')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'", + "CALL dolt_add('.')", + "call dolt_commit_hash_out(@commit_1_hash,'--skip-verification', '-m', 'Add Bob and update test')", + "INSERT INTO users VALUES (3, 'Charlie', 'chuck@exampl.com')", + "CALL dolt_add('.')", + "call dolt_commit_hash_out(@commit_2_hash,'--skip-verification', '-m', 'Add Charlie')", + "CALL dolt_checkout('main')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_cherry_pick(@commit_1_hash)", + Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}}, + }, + { + Query: "CALL dolt_cherry_pick(@commit_2_hash)", + ExpectedErrStr: "commit verification failed: test_user_count_update (Assertion failed: expected_single_value equal to 2, got 3)", + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "cherry-pick with test verification enabled - tests fail, aborted", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Initial commit')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "CALL dolt_add('.')", + "call dolt_commit_hash_out(@commit_hash,'--skip-verification', '-m', 'Add Bob but dont update test')", + "CALL dolt_checkout('main')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_cherry_pick(@commit_hash)", + ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 1, got 2)", + }, + { + Query: "CALL dolt_cherry_pick('--skip-verification', @commit_hash)", + Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}}, + }, + { + Query: "select * from dolt_test_run('*')", + Expected: []sql.Row{ + {"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 1, got 2"}, + }, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "rebase with test verification enabled - tests pass", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Initial commit')", + "DELETE FROM users where id = 1", + "INSERT INTO users VALUES (1, 'Zed', 'zed@example.com')", + "CALL dolt_commit('-am', 'drop Alice, add Zed')", // tests still pass here. + "CALL dolt_checkout('-b', 'feature', 'HEAD~1')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Add Bob and update test')", + "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')", + "UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Add Charlie, update test')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_rebase('main')", + Expected: []sql.Row{{int64(0), successfulRebaseMessage}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + + { + Name: "rebase with test verification enabled - tests fail, aborted", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Initial commit')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Add Bob but dont update test')", + "CALL dolt_checkout('main')", + "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')", // this will trip the existing test. + "CALL dolt_checkout('feature')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_rebase('main')", + ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 2, got 3)", + }, + { + Query: "CALL dolt_rebase('--abort')", + Expected: []sql.Row{{0, "Interactive rebase aborted"}}, + }, + { + Query: "CALL dolt_rebase('--skip-verification', 'main')", + Expected: []sql.Row{{int64(0), successfulRebaseMessage}}, + }, + { + Query: "select * from dolt_test_run('*')", + Expected: []sql.Row{ + {"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 2, got 3"}, + }, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "interactive rebase with --skip-verification flag should persist across continue operations", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Initial commit')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Bob but dont update test')", // This will cause test to fail + "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')", + "CALL dolt_checkout('main')", + "INSERT INTO users VALUES (4, 'David', 'david@example.com')", // Add a commit to main to create divergence + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add David on main')", + "CALL dolt_checkout('feature')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_rebase('--interactive', '--skip-verification', 'main')", + Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_feature; adjust the rebase plan in the dolt_rebase table, then continue rebasing by calling dolt_rebase('--continue')"}}, + }, + { + Query: "CALL dolt_rebase('--continue')", // This should NOT require --skip-verification flag but should still skip tests + Expected: []sql.Row{{int64(0), successfulRebaseMessage}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "test verification with no dolt_tests errors", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "CALL dolt_add('.')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_commit('-m', 'Commit without dolt_tests table')", + ExpectedErrStr: "failed to run dolt_test_run for group *: could not find tests for argument: *", + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "test verification with mixed test groups - only specified groups run", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = 'unit'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " + + "('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')", + "CALL dolt_add('.')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_commit('-m', 'Commit with unit tests only - should pass')", + Expected: []sql.Row{{commitHash}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "test verification error message includes test details", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_specific_failure', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')", + "CALL dolt_add('.')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_commit('-m', 'Commit with specific test failure')", + ExpectedErrStr: "commit verification failed: test_specific_failure (Assertion failed: expected_single_value equal to 999, got 2)", + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "merge with test verification enabled - tests pass", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('-m', 'Initial commit')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Bob\"', 'expected_single_value', '==', '1')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Bob')", + "CALL dolt_checkout('main')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_merge('feature')", + Expected: []sql.Row{{commitHash, int64(1), int64(0), "merge successful"}}, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "merge with test verification enabled - tests fail, merge aborted", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Bob')", + "CALL dolt_checkout('main')", + "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_merge('feature')", + ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 3)", + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, + { + Name: "merge with --skip-verification flag bypasses verification", + SetUpScript: []string{ + "SET GLOBAL dolt_commit_verification_groups = '*'", + "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))", + "INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')", + "INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " + + "('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')", + "CALL dolt_checkout('-b', 'feature')", + "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Bob')", + "CALL dolt_checkout('main')", + "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')", + "CALL dolt_add('.')", + "CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL dolt_merge('--skip-verification', 'feature')", + Expected: []sql.Row{{commitHash, int64(0), int64(0), "merge successful"}}, + }, + { + Query: "select * from dolt_test_run('*')", + Expected: []sql.Row{ + {"test_will_fail", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 999, got 3"}, + }, + }, + { // Test harness bleeds GLOBAL variable changes across tests, so reset after each test. + Query: "SET GLOBAL dolt_commit_verification_groups = ''", + SkipResultsCheck: true, + }, + }, + }, +} diff --git a/go/libraries/doltcore/sqle/system_variables.go b/go/libraries/doltcore/sqle/system_variables.go index 81d256f80a2..7162d030849 100644 --- a/go/libraries/doltcore/sqle/system_variables.go +++ b/go/libraries/doltcore/sqle/system_variables.go @@ -22,6 +22,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" _ "github.com/dolthub/go-mysql-server/sql/variables" + "github.com/dolthub/dolt/go/libraries/doltcore/env/actions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" ) @@ -292,6 +293,13 @@ var DoltSystemVariables = []sql.SystemVariable{ Type: types.NewSystemBoolType(dsess.AllowCICreation), Default: int8(0), }, + &sql.MysqlSystemVariable{ + Name: actions.DoltCommitVerificationGroups, + Dynamic: true, + Scope: sql.GetMysqlScope(sql.SystemVariableScope_Global), + Type: types.NewSystemStringType(actions.DoltCommitVerificationGroups), + Default: "", + }, } func AddDoltSystemVariables() { diff --git a/go/serial/workingset.fbs b/go/serial/workingset.fbs index 84b5ee75304..ce1d48fcc4e 100644 --- a/go/serial/workingset.fbs +++ b/go/serial/workingset.fbs @@ -67,6 +67,10 @@ table RebaseState { // The rebasing_started field indicates if execution of the rebase plan has been started or not. Once execution of the // plan has been started, the last_attempted_step field holds a reference to the most recent plan step attempted. rebasing_started:bool; + + // When set to true, the rebase process will skip performing commit + // verification if it would otherwise run. + skip_verification:bool; } // KEEP THIS IN SYNC WITH fileidentifiers.go diff --git a/go/store/datas/dataset.go b/go/store/datas/dataset.go index 75f94f16c55..0ac191b9a7e 100644 --- a/go/store/datas/dataset.go +++ b/go/store/datas/dataset.go @@ -169,6 +169,7 @@ type RebaseState struct { commitBecomesEmptyHandling uint8 emptyCommitHandling uint8 rebasingStarted bool + skipVerification bool } func (rs *RebaseState) PreRebaseWorkingAddr() hash.Hash { @@ -206,6 +207,10 @@ func (rs *RebaseState) EmptyCommitHandling(_ context.Context) uint8 { return rs.emptyCommitHandling } +func (rs *RebaseState) SkipVerification(_ context.Context) bool { + return rs.skipVerification +} + type MergeState struct { preMergeWorkingAddr *hash.Hash fromCommitAddr *hash.Hash @@ -457,6 +462,7 @@ func (h serialWorkingSetHead) HeadWorkingSet() (*WorkingSetHead, error) { rebaseState.EmptyCommitHandling(), rebaseState.LastAttemptedStep(), rebaseState.RebasingStarted(), + rebaseState.SkipVerification(), ) } diff --git a/go/store/datas/workingset.go b/go/store/datas/workingset.go index 05ec22dce32..f9e784ae2dd 100755 --- a/go/store/datas/workingset.go +++ b/go/store/datas/workingset.go @@ -196,6 +196,7 @@ func workingset_flatbuffer(working hash.Hash, staged *hash.Hash, mergeState *Mer serial.RebaseStateAddEmptyCommitHandling(builder, rebaseState.emptyCommitHandling) serial.RebaseStateAddLastAttemptedStep(builder, rebaseState.lastAttemptedStep) serial.RebaseStateAddRebasingStarted(builder, rebaseState.rebasingStarted) + serial.RebaseStateAddSkipVerification(builder, rebaseState.skipVerification) rebaseStateOffset = serial.RebaseStateEnd(builder) } @@ -264,7 +265,7 @@ func NewMergeState( } } -func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8, lastAttemptedStep float32, rebasingStarted bool) *RebaseState { +func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8, lastAttemptedStep float32, rebasingStarted bool, skipVerification bool) *RebaseState { return &RebaseState{ preRebaseWorkingAddr: &preRebaseWorkingRoot, ontoCommitAddr: &commitAddr, @@ -273,6 +274,7 @@ func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch emptyCommitHandling: emptyCommitHandling, lastAttemptedStep: lastAttemptedStep, rebasingStarted: rebasingStarted, + skipVerification: skipVerification, } } diff --git a/integration-tests/bats/commit_verification.bats b/integration-tests/bats/commit_verification.bats new file mode 100644 index 00000000000..7a0d9b7b861 --- /dev/null +++ b/integration-tests/bats/commit_verification.bats @@ -0,0 +1,253 @@ +#!/usr/bin/env bats +load $BATS_TEST_DIRNAME/helper/common.bash + +setup() { + setup_common + + dolt sql <