Skip to content

Commit

Permalink
fix: remove unsafe interface conversion (#953)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sugar-pack authored Nov 13, 2023
1 parent 0ff9fa8 commit dc753e7
Show file tree
Hide file tree
Showing 8 changed files with 352 additions and 13 deletions.
13 changes: 9 additions & 4 deletions authexternalbrowser.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@ func buildResponse(application string) bytes.Buffer {
// This opens a socket that listens on all available unicast
// and any anycast IP addresses locally. By specifying "0", we are
// able to bind to a free port.
func bindToPort() (net.Listener, error) {
func createLocalTCPListener() (*net.TCPListener, error) {
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
logger.Infof("unable to bind to a port on localhost, err: %v", err)
return nil, err
}
return l, nil

tcpListener, ok := l.(*net.TCPListener)
if !ok {
return nil, fmt.Errorf("failed to assert type as *net.TCPListener")
}

return tcpListener, nil
}

// Opens a browser window (or new tab) with the configured IDP Url.
Expand Down Expand Up @@ -213,7 +218,7 @@ func doAuthenticateByExternalBrowser(
user string,
password string,
) authenticateByExternalBrowserResult {
l, err := bindToPort()
l, err := createLocalTCPListener()
if err != nil {
return authenticateByExternalBrowserResult{nil, nil, err}
}
Expand Down
13 changes: 13 additions & 0 deletions authexternalbrowser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,16 @@ func TestAuthenticationTimeout(t *testing.T) {
t.Fatal("should have timed out")
}
}

func Test_createLocalTCPListener(t *testing.T) {
listener, err := createLocalTCPListener()
if err != nil {
t.Fatalf("createLocalTCPListener() failed: %v", err)
}
if listener == nil {
t.Fatal("createLocalTCPListener() returned nil listener")
}

// Close the listener after the test.
defer listener.Close()
}
6 changes: 5 additions & 1 deletion dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,5 +869,9 @@ func parsePrivateKeyFromFile(path string) (*rsa.PrivateKey, error) {
if err != nil {
return nil, err
}
return privateKey.(*rsa.PrivateKey), nil
pk, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type *rsa.PrivateKey, but got %T", privateKey)
}
return pk, nil
}
31 changes: 31 additions & 0 deletions dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package gosnowflake

import (
"crypto/ecdsa"
"crypto/elliptic"
cr "crypto/rand"
"crypto/rsa"
"crypto/x509"
Expand All @@ -15,6 +17,8 @@ import (
"strings"
"testing"
"time"

"github.com/aws/smithy-go/rand"
)

type tcParseDSN struct {
Expand Down Expand Up @@ -1358,6 +1362,33 @@ func TestParsePrivateKeyFromFileIncorrectData(t *testing.T) {
}
}

func TestParsePrivateKeyFromFileNotRSAPrivateKey(t *testing.T) {
// Generate an ECDSA private key for testing
ecdsaPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate ECDSA private key: %v", err)
}

ecdsaPrivateKeyBytes, err := x509.MarshalECPrivateKey(ecdsaPrivateKey)
if err != nil {
t.Fatalf("failed to marshal ECDSA private key: %v", err)
}
pemBlock := &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: ecdsaPrivateKeyBytes,
}
pemData := pem.EncodeToMemory(pemBlock)

// Write the PEM data to a temporary file
pemFile := createTmpFile("ecdsaKey.pem", pemData)

// Attempt to parse the private key
_, err = parsePrivateKeyFromFile(pemFile)
if err == nil {
t.Error("expected an error when trying to parse an ECDSA private key as RSA")
}
}

func TestParsePrivateKeyFromFile(t *testing.T) {
generatedKey, _ := rsa.GenerateKey(cr.Reader, 1024)
pemKey, _ := x509.MarshalPKCS8PrivateKey(generatedKey)
Expand Down
17 changes: 14 additions & 3 deletions gcs_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename strin
if err != nil {
return nil, err
}
accessToken := meta.client.(string)
accessToken, ok := meta.client.(string)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
}
gcsHeaders := map[string]string{
"Authorization": "Bearer " + accessToken,
}
Expand Down Expand Up @@ -145,7 +148,11 @@ func (util *snowflakeGcsClient) uploadFile(
if err != nil {
return err
}
accessToken = meta.client.(string)
var ok bool
accessToken, ok = meta.client.(string)
if !ok {
return fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
}
}

var contentEncoding string
Expand Down Expand Up @@ -271,7 +278,11 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
if err != nil {
return err
}
accessToken = meta.client.(string)
var ok bool
accessToken, ok = meta.client.(string)
if !ok {
return fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
}
if accessToken != "" {
gcsHeaders["Authorization"] = "Bearer " + accessToken
}
Expand Down
70 changes: 70 additions & 0 deletions gcs_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,40 @@ func TestGetFileHeaderEncryptionData(t *testing.T) {
}
}

func TestGetFileHeaderEncryptionDataInterfaceConversionError(t *testing.T) {
mockEncDataResp := "{\"EncryptionMode\":\"FullBlob\",\"WrappedContentKey\": {\"KeyId\":\"symmKey1\",\"EncryptedKey\":\"testencryptedkey12345678910==\",\"Algorithm\":\"AES_CBC_256\"},\"EncryptionAgent\": {\"Protocol\":\"1.0\",\"EncryptionAlgorithm\":\"AES_CBC_256\"},\"ContentEncryptionIV\":\"testIVkey12345678910==\",\"KeyWrappingMetadata\":{\"EncryptionLibrary\":\"Java 5.3.0\"}}"
mockMatDesc := "{\"queryid\":\"01abc874-0406-1bf0-0000-53b10668e056\",\"smkid\":\"92019681909886\",\"key\":\"128\"}"
info := execResponseStageInfo{
Location: "gcs/teststage/users/34/",
LocationType: "GCS",
Creds: execResponseCredentials{
GcsAccessToken: "test-token-124456577",
},
}
meta := fileMetadata{
client: 1,
stageInfo: &info,
mockGcsClient: &clientMock{
DoFunc: func(req *http.Request) (*http.Response, error) {
return &http.Response{
Status: "200 OK",
StatusCode: 200,
Header: http.Header{
"X-Goog-Meta-Encryptiondata": []string{mockEncDataResp},
"Content-Length": []string{"4256"},
"X-Goog-Meta-Sfc-Digest": []string{"123456789abcdef"},
"X-Goog-Meta-Matdesc": []string{mockMatDesc},
},
}, nil
},
},
}
_, err := new(snowflakeGcsClient).getFileHeader(&meta, "file.txt")
if err == nil {
t.Error("should have raised an error")
}
}

func TestUploadFileToGcsNoStatus(t *testing.T) {
info := execResponseStageInfo{
Location: "gcs-blob/storage/users/456/",
Expand Down Expand Up @@ -961,3 +995,39 @@ func TestDownloadFileWithBadRequest(t *testing.T) {
renewPresignedURL, downloadMeta.resStatus)
}
}

func Test_snowflakeGcsClient_uploadFile(t *testing.T) {
info := execResponseStageInfo{
Location: "gcs/teststage/users/34/",
LocationType: "GCS",
Creds: execResponseCredentials{
GcsAccessToken: "test-token-124456577",
},
}
meta := fileMetadata{
client: 1,
stageInfo: &info,
}
err := new(snowflakeGcsClient).uploadFile("somedata", &meta, nil, 1, 1)
if err == nil {
t.Error("should have raised an error")
}
}

func Test_snowflakeGcsClient_nativeDownloadFile(t *testing.T) {
info := execResponseStageInfo{
Location: "gcs/teststage/users/34/",
LocationType: "GCS",
Creds: execResponseCredentials{
GcsAccessToken: "test-token-124456577",
},
}
meta := fileMetadata{
client: 1,
stageInfo: &info,
}
err := new(snowflakeGcsClient).nativeDownloadFile(&meta, "dummy data", 1)
if err == nil {
t.Error("should have raised an error")
}
}
39 changes: 34 additions & 5 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package gosnowflake
import (
"context"
"database/sql/driver"
"fmt"
)

// SnowflakeStmt represents the prepared statement in driver.
Expand Down Expand Up @@ -33,28 +34,56 @@ func (stmt *snowflakeStmt) NumInput() int {
func (stmt *snowflakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.ExecContext")
result, err := stmt.sc.ExecContext(ctx, stmt.query, args)
stmt.lastQueryID = result.(SnowflakeResult).GetQueryID()
if err != nil {
return nil, err
}
r, ok := result.(SnowflakeResult)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
}
stmt.lastQueryID = r.GetQueryID()
return result, err
}

func (stmt *snowflakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.QueryContext")
rows, err := stmt.sc.QueryContext(ctx, stmt.query, args)
stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID()
return rows, err
if err != nil {
return nil, err
}
r, ok := rows.(SnowflakeRows)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows)
}
stmt.lastQueryID = r.GetQueryID()
return rows, nil
}

func (stmt *snowflakeStmt) Exec(args []driver.Value) (driver.Result, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Exec")
result, err := stmt.sc.Exec(stmt.query, args)
stmt.lastQueryID = result.(SnowflakeResult).GetQueryID()
if err != nil {
return nil, err
}
r, ok := result.(SnowflakeResult)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeResult but got %T", result)
}
stmt.lastQueryID = r.GetQueryID()
return result, err
}

func (stmt *snowflakeStmt) Query(args []driver.Value) (driver.Rows, error) {
logger.WithContext(stmt.sc.ctx).Infoln("Stmt.Query")
rows, err := stmt.sc.Query(stmt.query, args)
stmt.lastQueryID = rows.(SnowflakeRows).GetQueryID()
if err != nil {
return nil, err
}
r, ok := rows.(SnowflakeRows)
if !ok {
return nil, fmt.Errorf("interface convertion. expected type SnowflakeRows but got %T", rows)
}
stmt.lastQueryID = r.GetQueryID()
return rows, err
}

Expand Down
Loading

0 comments on commit dc753e7

Please sign in to comment.