Skip to content

Commit

Permalink
feat: implement snowpipe destination config validation (#5472)
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhar-rudder authored Feb 11, 2025
1 parent 4854220 commit a3810fc
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/rudderlabs/rudder-server/utils/timeutil"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
"github.com/rudderlabs/rudder-server/warehouse/validations"
)

var json = jsoniter.ConfigCompatibleWithStandardLibrary
Expand All @@ -57,6 +58,7 @@ func New(
now: timeutil.Now,
channelCache: sync.Map{},
polledImportInfoMap: make(map[string]*importInfo),
validator: validations.NewDestinationValidator(),
}

m.config.client.url = conf.GetString("SnowpipeStreaming.Client.URL", "http://localhost:9078")
Expand Down Expand Up @@ -138,6 +140,15 @@ func (m *Manager) retryableClient() *retryablehttp.Client {
return client
}

func (m *Manager) validateConfig(ctx context.Context, dest *backendconfig.DestinationT) error {
dest.Config["useKeyPairAuth"] = true // Since we are currently only supporting key pair auth
response := m.validator.Validate(ctx, dest)
if response.Success {
return nil
}
return errors.New(response.Error)
}

func (m *Manager) Now() time.Time {
return m.now()
}
Expand Down Expand Up @@ -176,6 +187,10 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
switch {
case errors.Is(err, errAuthz):
m.setBackOff(err)
validationError := m.validateConfig(ctx, asyncDest.Destination)
if validationError != nil {
err = fmt.Errorf("failed to validate snowpipe credentials: %s", validationError.Error())
}
return m.failedJobs(asyncDest, err.Error())
case errors.Is(err, errBackoff):
return m.failedJobs(asyncDest, err.Error())
Expand Down Expand Up @@ -225,6 +240,10 @@ func (m *Manager) Upload(asyncDest *common.AsyncDestinationStruct) common.AsyncU
if !isBackoffSet {
isBackoffSet = true
m.setBackOff(err)
validationError := m.validateConfig(ctx, asyncDest.Destination)
if validationError != nil && failedReason == "" {
failedReason = fmt.Sprintf("failed to validate snowpipe credentials: %s", validationError.Error())
}
}
case errors.Is(err, errBackoff):
shouldResetBackoff = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
"github.com/rudderlabs/rudder-server/warehouse/integrations/snowflake"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
"github.com/rudderlabs/rudder-server/warehouse/validations"
)

type mockAPI struct {
Expand Down Expand Up @@ -68,6 +69,22 @@ func (m *mockManager) CreateTable(context.Context, string, whutils.ModelTableSch
return nil
}

type mockValidator struct {
err error
}

func (m *mockValidator) Validate(_ context.Context, _ *backendconfig.DestinationT) *validations.DestinationValidationResponse {
if m.err != nil {
return &validations.DestinationValidationResponse{
Success: false,
Error: m.err.Error(),
}
}
return &validations.DestinationValidationResponse{
Success: true,
}
}

var (
usersChannelResponse = &model.ChannelResponse{
ChannelID: "test-users-channel",
Expand Down Expand Up @@ -104,6 +121,7 @@ func TestSnowpipeStreaming(t *testing.T) {
},
Config: make(map[string]interface{}),
}
validations.Init()

t.Run("Upload with invalid file path", func(t *testing.T) {
statsStore, err := memstats.New()
Expand Down Expand Up @@ -405,34 +423,99 @@ func TestSnowpipeStreaming(t *testing.T) {
require.False(t, sm.isInBackoff())
})

t.Run("Upload with discards table authorization error should mark the job as failed", func(t *testing.T) {
statsStore, err := memstats.New()
require.NoError(t, err)
t.Run("destination config validation", func(t *testing.T) {
testCases := []struct {
name string
validationError error
expectedFailedReason string
}{
{
name: "should return validation error",
validationError: fmt.Errorf("missing permissions to do xyz"),
expectedFailedReason: "missing permissions to do xyz",
},
{
name: "should not return any error",
validationError: nil,
expectedFailedReason: "failed to create schema",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
sm := New(config.New(), logger.NOP, stats.NOP, destination)
sm.channelCache.Store("RUDDER_DISCARDS", rudderDiscardsChannelResponse)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"USERS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sf := snowflake.New(config.New(), logger.NOP, stats.NOP)
mm := newMockManager(sf)
mm.createSchemaErr = fmt.Errorf("failed to create schema")
return mm, nil
}
sm.validator = &mockValidator{err: tc.validationError}
asyncDestStruct := &common.AsyncDestinationStruct{
Destination: destination,
FileName: "testdata/successful_user_records.txt",
}
output := sm.Upload(asyncDestStruct)
require.Equal(t, 2, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
require.Contains(t, output.FailedReason, tc.expectedFailedReason)
})
}
})

sm := New(config.New(), logger.NOP, statsStore, destination)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"RUDDER_DISCARDS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
t.Run("Upload with discards table authorization error should mark the job as failed", func(t *testing.T) {
testCases := []struct {
name string
validationError error
expectedFailedReason string
}{
{
name: "authorization error",
validationError: fmt.Errorf("authorization error"),
expectedFailedReason: "failed to validate snowpipe credentials: authorization error",
},
{
name: "other error",
validationError: nil,
expectedFailedReason: "failed to create schema",
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sf := snowflake.New(config.New(), logger.NOP, stats.NOP)
mm := newMockManager(sf)
mm.createSchemaErr = fmt.Errorf("failed to create schema")
return mm, nil
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
sm := New(config.New(), logger.NOP, stats.NOP, destination)
sm.api = &mockAPI{
createChannelOutputMap: map[string]func() (*model.ChannelResponse, error){
"RUDDER_DISCARDS": func() (*model.ChannelResponse, error) {
return &model.ChannelResponse{Code: internalapi.ErrSchemaDoesNotExistOrNotAuthorized}, nil
},
},
}
sm.managerCreator = func(_ context.Context, _ whutils.ModelWarehouse, _ *config.Config, _ logger.Logger, _ stats.Stats) (manager.Manager, error) {
sf := snowflake.New(config.New(), logger.NOP, stats.NOP)
mm := newMockManager(sf)
mm.createSchemaErr = fmt.Errorf("failed to create schema")
return mm, nil
}
sm.validator = &mockValidator{err: tc.validationError}
output := sm.Upload(&common.AsyncDestinationStruct{
ImportingJobIDs: []int64{1},
Destination: destination,
FileName: "testdata/successful_user_records.txt",
})
require.Equal(t, 1, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
require.Contains(t, output.FailedReason, tc.expectedFailedReason)
require.Empty(t, output.AbortReason)
require.Equal(t, true, sm.isInBackoff())
})
}
output := sm.Upload(&common.AsyncDestinationStruct{
ImportingJobIDs: []int64{1},
Destination: destination,
FileName: "testdata/successful_user_records.txt",
})
require.Equal(t, 1, output.FailedCount)
require.Equal(t, 0, output.AbortCount)
require.NotEmpty(t, output.FailedReason)
require.Empty(t, output.AbortReason)
require.Equal(t, true, sm.isInBackoff())
})

t.Run("Upload insert error for all events", func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
backendconfig "github.com/rudderlabs/rudder-server/backend-config"
"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
"github.com/rudderlabs/rudder-server/warehouse/validations"
)

type (
Expand All @@ -31,6 +32,7 @@ type (
api api
channelCache sync.Map
polledImportInfoMap map[string]*importInfo
validator validations.DestinationValidator

config struct {
client struct {
Expand Down
6 changes: 0 additions & 6 deletions warehouse/internal/model/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,3 @@ type Step struct {
type StepsResponse struct {
Steps []*Step `json:"steps"`
}

type DestinationValidationResponse struct {
Success bool `json:"success"`
Error string `json:"error"`
Steps []*Step `json:"steps"`
}
88 changes: 38 additions & 50 deletions warehouse/validations/steps.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"

"github.com/samber/lo"

backendconfig "github.com/rudderlabs/rudder-server/backend-config"
schemarepository "github.com/rudderlabs/rudder-server/warehouse/integrations/datalake/schema-repository"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"
Expand All @@ -15,63 +17,49 @@ func validateStepFunc(_ context.Context, destination *backendconfig.DestinationT
}

func StepsToValidate(dest *backendconfig.DestinationT) *model.StepsResponse {
var (
destType = dest.DestinationDefinition.Name
steps []*model.Step
)
destType := dest.DestinationDefinition.Name

if destType == warehouseutils.SnowpipeStreaming {
return &model.StepsResponse{
Steps: []*model.Step{
{ID: 1, Name: model.VerifyingConnections},
{ID: 2, Name: model.VerifyingCreateSchema},
{ID: 3, Name: model.VerifyingCreateAndAlterTable},
{ID: 4, Name: model.VerifyingFetchSchema},
},
}
}

steps = []*model.Step{{
ID: len(steps) + 1,
Name: model.VerifyingObjectStorage,
}}
steps := []*model.Step{
{ID: 1, Name: model.VerifyingObjectStorage},
}

appendSteps := func(newSteps ...string) {
for _, step := range newSteps {
steps = append(steps, &model.Step{ID: len(steps) + 1, Name: step})
}
}

switch destType {
case warehouseutils.GCSDatalake, warehouseutils.AzureDatalake:
// No additional steps
case warehouseutils.S3Datalake:
wh := createDummyWarehouse(dest)
if canUseGlue := schemarepository.UseGlue(&wh); !canUseGlue {
break
if schemarepository.UseGlue(lo.ToPtr(createDummyWarehouse(dest))) {
appendSteps(
model.VerifyingCreateSchema,
model.VerifyingCreateAndAlterTable,
model.VerifyingFetchSchema,
)
}

steps = append(steps,
&model.Step{
ID: len(steps) + 1,
Name: model.VerifyingCreateSchema,
},
&model.Step{
ID: len(steps) + 2,
Name: model.VerifyingCreateAndAlterTable,
},
&model.Step{
ID: len(steps) + 3,
Name: model.VerifyingFetchSchema,
},
)
default:
steps = append(steps,
&model.Step{
ID: len(steps) + 1,
Name: model.VerifyingConnections,
},
&model.Step{
ID: len(steps) + 2,
Name: model.VerifyingCreateSchema,
},
&model.Step{
ID: len(steps) + 3,
Name: model.VerifyingCreateAndAlterTable,
},
&model.Step{
ID: len(steps) + 4,
Name: model.VerifyingFetchSchema,
},
&model.Step{
ID: len(steps) + 5,
Name: model.VerifyingLoadTable,
},
appendSteps(
model.VerifyingConnections,
model.VerifyingCreateSchema,
model.VerifyingCreateAndAlterTable,
model.VerifyingFetchSchema,
model.VerifyingLoadTable,
)
}
return &model.StepsResponse{
Steps: steps,
}

return &model.StepsResponse{Steps: steps}
}
14 changes: 14 additions & 0 deletions warehouse/validations/steps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ func TestValidationSteps(t *testing.T) {
model.VerifyingLoadTable,
},
},
{
name: "Snowpipe",
dest: backendconfig.DestinationT{
DestinationDefinition: backendconfig.DestinationDefinitionT{
Name: warehouseutils.SnowpipeStreaming,
},
},
steps: []string{
model.VerifyingConnections,
model.VerifyingCreateSchema,
model.VerifyingCreateAndAlterTable,
model.VerifyingFetchSchema,
},
},
}

for _, tc := range testCases {
Expand Down
Loading

0 comments on commit a3810fc

Please sign in to comment.