diff --git a/main.go b/main.go index 624bda67f7..5f63ebbd50 100644 --- a/main.go +++ b/main.go @@ -57,7 +57,7 @@ var doltCommand = cli.NewSubCommandHandler("doltgresql", "it's git for data", [] var globalArgParser = cli.CreateGlobalArgParser("doltgresql") func init() { - server.DefaultProtocolListenerFunc = postgres.NewListenerWithConfig + server.DefaultProtocolListenerFunc = postgres.NewListener sqlserver.ExternalDisableUsers = true dfunctions.VersionString = Version } diff --git a/postgres/connection/connection.go b/postgres/connection/connection.go new file mode 100644 index 0000000000..f653554585 --- /dev/null +++ b/postgres/connection/connection.go @@ -0,0 +1,59 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connection + +import "net" + +// Receive returns a Message from the given buffer, generally generated by the client in the main read loop of a +// connection. +func Receive(buffer []byte) (Message, bool, error) { + if len(buffer) == 0 { + return nil, false, nil + } + message, ok := allMessageHeaders[buffer[0]] + if !ok { + return nil, false, nil + } + outMessage, err := ReceiveInto(buffer, message) + return outMessage, true, err +} + +// ReceiveInto writes the contents of the buffer into the given Message. +func ReceiveInto[T Message](buffer []byte, message T) (out T, err error) { + defaultMessage := message.DefaultMessage() + fields := defaultMessage.Copy().Fields + if err = decode(&decodeBuffer{buffer}, []FieldGroup{fields}, 1); err != nil { + return out, err + } + decodedMessage, err := message.Decode(MessageFormat{defaultMessage.Name, fields, defaultMessage.info, false}) + if err != nil { + return out, err + } + return decodedMessage.(T), nil +} + +// Send sends the given message over the connection. +func Send(conn net.Conn, message Message) error { + encodedMessage, err := message.Encode() + if err != nil { + return err + } + data, err := encode(encodedMessage) + if err != nil { + return err + } + _, err = conn.Write(data) + return err +} diff --git a/postgres/messages/message.go b/postgres/connection/message.go similarity index 89% rename from postgres/messages/message.go rename to postgres/connection/message.go index f1b97688d3..b807dfe0eb 100644 --- a/postgres/messages/message.go +++ b/postgres/connection/message.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package messages +package connection import ( "fmt" @@ -30,14 +30,14 @@ type MessageFormat struct { // Message is a type that represents a PostgreSQL message. type Message interface { - // encode returns a new MessageFormat containing any modified data contained within the object. This should NOT be + // Encode returns a new MessageFormat containing any modified data contained within the object. This should NOT be // the default message. - encode() (MessageFormat, error) - // decode returns a new Message that represents the given MessageFormat. You should never return the default + Encode() (MessageFormat, error) + // Decode returns a new Message that represents the given MessageFormat. You should never return the default // message, even if the message never varies from the default. Always make a copy, and then modify that copy. - decode(s MessageFormat) (Message, error) - // defaultMessage returns the default, unmodified message for this type. - defaultMessage() *MessageFormat + Decode(s MessageFormat) (Message, error) + // DefaultMessage returns the default, unmodified message for this type. + DefaultMessage() *MessageFormat } // messageFieldInfo contains information on a specific field within a messageInfo. diff --git a/postgres/messages/message_decode_encode.go b/postgres/connection/message_decode_encode.go similarity index 85% rename from postgres/messages/message_decode_encode.go rename to postgres/connection/message_decode_encode.go index 2f0ced99a4..dec7abf66e 100644 --- a/postgres/messages/message_decode_encode.go +++ b/postgres/connection/message_decode_encode.go @@ -12,58 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -package messages +package connection import ( "bytes" "encoding/binary" "errors" "fmt" - "net" ) -// Receive returns a Message from the given buffer, generally generated by the client in the main read loop of a -// connection. -func Receive(buffer []byte) (Message, bool, error) { - if len(buffer) == 0 { - return nil, false, nil - } - message, ok := allMessageHeaders[buffer[0]] - if !ok { - return nil, false, nil - } - outMessage, err := ReceiveInto(buffer, message) - return outMessage, true, err -} - -// ReceiveInto writes the contents of the buffer into the given Message. -func ReceiveInto[T Message](buffer []byte, message T) (out T, err error) { - defaultMessage := message.defaultMessage() - fields := defaultMessage.Copy().Fields - if err = decode(&decodeBuffer{buffer}, []FieldGroup{fields}, 1); err != nil { - return out, err - } - decodedMessage, err := message.decode(MessageFormat{defaultMessage.Name, fields, defaultMessage.info, false}) - if err != nil { - return out, err - } - return decodedMessage.(T), nil -} - -// Send sends the given message over the connection. -func Send(conn net.Conn, message Message) error { - encodedMessage, err := message.encode() - if err != nil { - return err - } - data, err := encode(encodedMessage) - if err != nil { - return err - } - _, err = conn.Write(data) - return err -} - // decodeBuffer just provides an easy way to reference the same buffer, so that decode can modify its length. type decodeBuffer struct { data []byte @@ -241,7 +198,6 @@ func encode(ms MessageFormat) ([]byte, error) { if data[bufferIdx] == 0 { found = true byteOffset += int32(bufferIdx) - //TODO: is this the correct place to put this? investigate/test if field.Flags&ExcludeTerminator == 0 { byteOffset += 1 } diff --git a/postgres/messages/message_field.go b/postgres/connection/message_field.go similarity index 99% rename from postgres/messages/message_field.go rename to postgres/connection/message_field.go index 7925656de9..6f152f876c 100644 --- a/postgres/messages/message_field.go +++ b/postgres/connection/message_field.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package messages +package connection // FieldType is the type of the field as defined by PostgreSQL. type FieldType byte diff --git a/postgres/messages/message_initialization.go b/postgres/connection/message_initialization.go similarity index 94% rename from postgres/messages/message_initialization.go rename to postgres/connection/message_initialization.go index 962fb6d044..352f6dce91 100644 --- a/postgres/messages/message_initialization.go +++ b/postgres/connection/message_initialization.go @@ -12,9 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -package messages +package connection -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/utils" +) // allMessageHeaders contains any message headers that should be read within the main read loop of a connection. var allMessageHeaders = make(map[byte]Message) @@ -25,26 +29,26 @@ var allMessageNames = make(map[string]struct{}) // allMessageDefaults contains all of the default message pointers, to make sure that they're not accidentally being reused. var allMessageDefaults = make(map[*MessageFormat]struct{}) -// addMessageHeader adds the given Message's header. This also ensures that each header is unique. This should be +// AddMessageHeader adds the given Message's header. This also ensures that each header is unique. This should be // called in an init() function. -func addMessageHeader(message Message) { - for _, field := range message.defaultMessage().Fields { +func AddMessageHeader(message Message) { + for _, field := range message.DefaultMessage().Fields { if field.Flags&Header != 0 { header := byte(field.Data.(int32)) if _, ok := allMessageHeaders[header]; ok { - panic(fmt.Errorf("Header already taken.\nMessageFormat:\n\n%s", message.defaultMessage().String())) + panic(fmt.Errorf("Header already taken.\nMessageFormat:\n\n%s", message.DefaultMessage().String())) } allMessageHeaders[header] = message return } } - panic(fmt.Errorf("Header does not exist.\nMessageFormat:\n\n%s", message.defaultMessage().String())) + panic(fmt.Errorf("Header does not exist.\nMessageFormat:\n\n%s", message.DefaultMessage().String())) } -// initializeDefaultMessage creates the internal structure of the default message, while ensuring that the structure of +// InitializeDefaultMessage creates the internal structure of the default message, while ensuring that the structure of // the message is correct. This should be called in an init() function. -func initializeDefaultMessage(messageType Message) { - message := messageType.defaultMessage() +func InitializeDefaultMessage(messageType Message) { + message := messageType.DefaultMessage() if _, ok := allMessageDefaults[message]; ok { panic(fmt.Errorf("MessageFormat default was used in another message.\nMessageFormat:\n\n%s", message.String())) } @@ -69,7 +73,7 @@ func initializeDefaultMessage(messageType Message) { Fields FieldGroup } - ftStack := NewStack[FieldTraversal]() + ftStack := utils.NewStack[FieldTraversal]() ftStack.Push(FieldTraversal{0, message.Fields}) for !ftStack.Empty() { // If we're at the end of the loop for this stacked entry, then we pop it and move to the next diff --git a/postgres/messages/message_writer.go b/postgres/connection/message_writer.go similarity index 99% rename from postgres/messages/message_writer.go rename to postgres/connection/message_writer.go index 1a9e76b810..1df78d7e62 100644 --- a/postgres/messages/message_writer.go +++ b/postgres/connection/message_writer.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package messages +package connection import "fmt" diff --git a/postgres/listener.go b/postgres/listener.go index 5acf2ebb41..f031dbb47e 100644 --- a/postgres/listener.go +++ b/postgres/listener.go @@ -24,12 +24,13 @@ import ( "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/doltgresql/postgres/connection" "github.com/dolthub/doltgresql/postgres/messages" ) var connectionIDCounter uint32 -// TODO: doc +// Listener listens for connections to process PostgreSQL requests into Dolt requests. type Listener struct { listener net.Listener cfg mysql.ListenerConfig @@ -37,15 +38,15 @@ type Listener struct { var _ server.ProtocolListener = (*Listener)(nil) -// TODO: doc -func NewListenerWithConfig(listenerCfg mysql.ListenerConfig) (server.ProtocolListener, error) { +// NewListener creates a new Listener. +func NewListener(listenerCfg mysql.ListenerConfig) (server.ProtocolListener, error) { return &Listener{ listener: listenerCfg.Listener, cfg: listenerCfg, }, nil } -// TODO: doc +// Accept handles incoming connections. func (l *Listener) Accept() { for { conn, err := l.listener.Accept() @@ -58,12 +59,12 @@ func (l *Listener) Accept() { } } -// TODO: doc +// Close stops the handling of incoming connections. func (l *Listener) Close() { _ = l.listener.Close() } -// TODO: doc +// Addr returns the address that the listener is listening on. func (l *Listener) Addr() net.Addr { return l.listener.Addr() } @@ -93,7 +94,7 @@ func (l *Listener) HandleConnection(conn net.Conn) { return } - if err = messages.Send(conn, messages.SSLResponse{ + if err = connection.Send(conn, messages.SSLResponse{ SupportsSSL: false, }); err != nil { fmt.Println(err) @@ -107,25 +108,25 @@ func (l *Listener) HandleConnection(conn net.Conn) { } return } - startupMessage, err := messages.ReceiveInto(buf, messages.StartupMessage{}) + startupMessage, err := connection.ReceiveInto(buf, messages.StartupMessage{}) if err != nil { fmt.Println(err) return } - if err = messages.Send(conn, messages.AuthenticationOk{}); err != nil { + if err = connection.Send(conn, messages.AuthenticationOk{}); err != nil { fmt.Println(err) return } - if err = messages.Send(conn, messages.ParameterStatus{ + if err = connection.Send(conn, messages.ParameterStatus{ Name: "server_version", Value: "15.0", }); err != nil { fmt.Println(err) return } - if err = messages.Send(conn, messages.ParameterStatus{ + if err = connection.Send(conn, messages.ParameterStatus{ Name: "client_encoding", Value: "UTF8", }); err != nil { @@ -133,7 +134,7 @@ func (l *Listener) HandleConnection(conn net.Conn) { return } - if err = messages.Send(conn, messages.BackendKeyData{ + if err = connection.Send(conn, messages.BackendKeyData{ ProcessID: 1, SecretKey: 0, }); err != nil { @@ -141,7 +142,7 @@ func (l *Listener) HandleConnection(conn net.Conn) { return } - if err = messages.Send(conn, messages.ReadyForQuery{ + if err = connection.Send(conn, messages.ReadyForQuery{ Indicator: messages.ReadyForQueryTransactionIndicator_Idle, }); err != nil { fmt.Println(err) @@ -162,7 +163,7 @@ func (l *Listener) HandleConnection(conn net.Conn) { return } - message, ok, err := messages.Receive(buf) + message, ok, err := connection.Receive(buf) if err != nil { fmt.Println(err.Error()) return @@ -187,14 +188,14 @@ func (l *Listener) query(conn net.Conn, mysqlConn *mysql.Conn, query string) { } if err := l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error { - if err := messages.Send(conn, messages.RowDescription{ + if err := connection.Send(conn, messages.RowDescription{ Fields: res.Fields, }); err != nil { return err } for _, row := range res.Rows { - if err := messages.Send(conn, messages.DataRow{ + if err := connection.Send(conn, messages.DataRow{ Values: row, }); err != nil { return err @@ -212,12 +213,12 @@ func (l *Listener) query(conn net.Conn, mysqlConn *mysql.Conn, query string) { return } - if err := messages.Send(conn, commandComplete); err != nil { + if err := connection.Send(conn, commandComplete); err != nil { fmt.Println(err) return } - if err := messages.Send(conn, messages.ReadyForQuery{ + if err := connection.Send(conn, messages.ReadyForQuery{ Indicator: messages.ReadyForQueryTransactionIndicator_Idle, }); err != nil { fmt.Println(err) diff --git a/postgres/messages/authentication_cleartext_password.go b/postgres/messages/authentication_cleartext_password.go index 07a2eeb824..3d92cc968b 100644 --- a/postgres/messages/authentication_cleartext_password.go +++ b/postgres/messages/authentication_cleartext_password.go @@ -14,52 +14,54 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationCleartextPassword{}) + connection.InitializeDefaultMessage(AuthenticationCleartextPassword{}) } // AuthenticationCleartextPassword represents a PostgreSQL message. type AuthenticationCleartextPassword struct{} -var authenticationCleartextPasswordDefault = MessageFormat{ +var authenticationCleartextPasswordDefault = connection.MessageFormat{ Name: "AuthenticationCleartextPassword", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(3), }, }, } -var _ Message = AuthenticationCleartextPassword{} +var _ connection.Message = AuthenticationCleartextPassword{} -// encode implements the interface Message. -func (m AuthenticationCleartextPassword) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m AuthenticationCleartextPassword) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m AuthenticationCleartextPassword) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationCleartextPassword) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationCleartextPassword{}, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationCleartextPassword) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationCleartextPassword) DefaultMessage() *connection.MessageFormat { return &authenticationCleartextPasswordDefault } diff --git a/postgres/messages/authentication_gss.go b/postgres/messages/authentication_gss.go index 223315eb34..5dafe415cf 100644 --- a/postgres/messages/authentication_gss.go +++ b/postgres/messages/authentication_gss.go @@ -14,52 +14,54 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationGSS{}) + connection.InitializeDefaultMessage(AuthenticationGSS{}) } // AuthenticationGSS represents a PostgreSQL message. type AuthenticationGSS struct{} -var authenticationGSSDefault = MessageFormat{ +var authenticationGSSDefault = connection.MessageFormat{ Name: "AuthenticationGSS", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(7), }, }, } -var _ Message = AuthenticationGSS{} +var _ connection.Message = AuthenticationGSS{} -// encode implements the interface Message. -func (m AuthenticationGSS) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m AuthenticationGSS) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m AuthenticationGSS) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationGSS) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationGSS{}, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationGSS) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationGSS) DefaultMessage() *connection.MessageFormat { return &authenticationGSSDefault } diff --git a/postgres/messages/authentication_gss_continue.go b/postgres/messages/authentication_gss_continue.go index aec0e229d5..7817197ef4 100644 --- a/postgres/messages/authentication_gss_continue.go +++ b/postgres/messages/authentication_gss_continue.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationGSSContinue{}) + connection.InitializeDefaultMessage(AuthenticationGSSContinue{}) } // AuthenticationGSSContinue represents a PostgreSQL message. @@ -23,46 +25,46 @@ type AuthenticationGSSContinue struct { Data []byte } -var authenticationGSSContinueDefault = MessageFormat{ +var authenticationGSSContinueDefault = connection.MessageFormat{ Name: "AuthenticationGSSContinue", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(8), }, { Name: "AuthenticationData", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = AuthenticationGSSContinue{} +var _ connection.Message = AuthenticationGSSContinue{} -// encode implements the interface Message. -func (m AuthenticationGSSContinue) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m AuthenticationGSSContinue) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("AuthenticationData").MustWrite(m.Data) return outputMessage, nil } -// decode implements the interface Message. -func (m AuthenticationGSSContinue) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationGSSContinue) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationGSSContinue{ @@ -70,7 +72,7 @@ func (m AuthenticationGSSContinue) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationGSSContinue) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationGSSContinue) DefaultMessage() *connection.MessageFormat { return &authenticationGSSContinueDefault } diff --git a/postgres/messages/authentication_kerberos_v5.go b/postgres/messages/authentication_kerberos_v5.go index 5df34282e8..8994fb6e6c 100644 --- a/postgres/messages/authentication_kerberos_v5.go +++ b/postgres/messages/authentication_kerberos_v5.go @@ -14,52 +14,54 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationKerberosV5{}) + connection.InitializeDefaultMessage(AuthenticationKerberosV5{}) } // AuthenticationKerberosV5 represents a PostgreSQL message. type AuthenticationKerberosV5 struct{} -var authenticationKerberosV5Default = MessageFormat{ +var authenticationKerberosV5Default = connection.MessageFormat{ Name: "AuthenticationKerberosV5", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(2), }, }, } -var _ Message = AuthenticationKerberosV5{} +var _ connection.Message = AuthenticationKerberosV5{} -// encode implements the interface Message. -func (m AuthenticationKerberosV5) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m AuthenticationKerberosV5) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m AuthenticationKerberosV5) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationKerberosV5) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationKerberosV5{}, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationKerberosV5) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationKerberosV5) DefaultMessage() *connection.MessageFormat { return &authenticationKerberosV5Default } diff --git a/postgres/messages/authentication_md5_password.go b/postgres/messages/authentication_md5_password.go index 44da6c286c..c0d26ea2ef 100644 --- a/postgres/messages/authentication_md5_password.go +++ b/postgres/messages/authentication_md5_password.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationMD5Password{}) + connection.InitializeDefaultMessage(AuthenticationMD5Password{}) } // AuthenticationMD5Password represents a PostgreSQL message. @@ -23,46 +25,46 @@ type AuthenticationMD5Password struct { Salt int32 } -var authenticationMD5PasswordDefault = MessageFormat{ +var authenticationMD5PasswordDefault = connection.MessageFormat{ Name: "AuthenticationMD5Password", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(12), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(5), }, { Name: "Salt", - Type: Byte4, + Type: connection.Byte4, Data: int32(0), }, }, } -var _ Message = AuthenticationMD5Password{} +var _ connection.Message = AuthenticationMD5Password{} -// encode implements the interface Message. -func (m AuthenticationMD5Password) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m AuthenticationMD5Password) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Salt").MustWrite(m.Salt) return outputMessage, nil } -// decode implements the interface Message. -func (m AuthenticationMD5Password) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationMD5Password) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationMD5Password{ @@ -70,7 +72,7 @@ func (m AuthenticationMD5Password) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationMD5Password) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationMD5Password) DefaultMessage() *connection.MessageFormat { return &authenticationMD5PasswordDefault } diff --git a/postgres/messages/authentication_ok.go b/postgres/messages/authentication_ok.go index 5bfd289098..75665be5e7 100644 --- a/postgres/messages/authentication_ok.go +++ b/postgres/messages/authentication_ok.go @@ -14,52 +14,54 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + // AuthenticationOk tells the client that authentication was successful. type AuthenticationOk struct{} func init() { - initializeDefaultMessage(AuthenticationOk{}) + connection.InitializeDefaultMessage(AuthenticationOk{}) } -var authenticationOkDefault = MessageFormat{ +var authenticationOkDefault = connection.MessageFormat{ Name: "AuthenticationOk", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, }, } -var _ Message = AuthenticationOk{} +var _ connection.Message = AuthenticationOk{} -// encode implements the interface Message. -func (m AuthenticationOk) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m AuthenticationOk) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m AuthenticationOk) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationOk) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationOk{}, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationOk) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationOk) DefaultMessage() *connection.MessageFormat { return &authenticationOkDefault } diff --git a/postgres/messages/authentication_sasl.go b/postgres/messages/authentication_sasl.go index 78577b565b..2de6c8e9f9 100644 --- a/postgres/messages/authentication_sasl.go +++ b/postgres/messages/authentication_sasl.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationSASL{}) + connection.InitializeDefaultMessage(AuthenticationSASL{}) } // AuthenticationSASL represents a PostgreSQL message. @@ -23,36 +25,36 @@ type AuthenticationSASL struct { Mechanisms []string } -var authenticationSASLDefault = MessageFormat{ +var authenticationSASLDefault = connection.MessageFormat{ Name: "AuthenticationSASL", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(10), }, { Name: "Mechanisms", - Type: Repeated, - Flags: RepeatedTerminator, + Type: connection.Repeated, + Flags: connection.RepeatedTerminator, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "Mechanism", - Type: String, + Type: connection.String, Data: "", }, }, @@ -61,20 +63,20 @@ var authenticationSASLDefault = MessageFormat{ }, } -var _ Message = AuthenticationSASL{} +var _ connection.Message = AuthenticationSASL{} -// encode implements the interface Message. -func (m AuthenticationSASL) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m AuthenticationSASL) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() for i, mechanism := range m.Mechanisms { outputMessage.Field("Mechanisms").Child("Mechanism", i).MustWrite(mechanism) } return outputMessage, nil } -// decode implements the interface Message. -func (m AuthenticationSASL) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationSASL) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } count := int(s.Field("Mechanisms").MustGet().(int32)) @@ -87,7 +89,7 @@ func (m AuthenticationSASL) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationSASL) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationSASL) DefaultMessage() *connection.MessageFormat { return &authenticationSASLDefault } diff --git a/postgres/messages/authentication_sasl_continue.go b/postgres/messages/authentication_sasl_continue.go index 6d129ad127..571f5b08da 100644 --- a/postgres/messages/authentication_sasl_continue.go +++ b/postgres/messages/authentication_sasl_continue.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationSASLContinue{}) + connection.InitializeDefaultMessage(AuthenticationSASLContinue{}) } // AuthenticationSASLContinue represents a PostgreSQL message. @@ -23,46 +25,46 @@ type AuthenticationSASLContinue struct { Data []byte } -var authenticationSASLContinueDefault = MessageFormat{ +var authenticationSASLContinueDefault = connection.MessageFormat{ Name: "AuthenticationSASLContinue", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(11), }, { Name: "SASLData", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = AuthenticationSASLContinue{} +var _ connection.Message = AuthenticationSASLContinue{} -// encode implements the interface Message. -func (m AuthenticationSASLContinue) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m AuthenticationSASLContinue) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("SASLData").MustWrite(m.Data) return outputMessage, nil } -// decode implements the interface Message. -func (m AuthenticationSASLContinue) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationSASLContinue) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationSASLContinue{ @@ -70,7 +72,7 @@ func (m AuthenticationSASLContinue) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationSASLContinue) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationSASLContinue) DefaultMessage() *connection.MessageFormat { return &authenticationSASLContinueDefault } diff --git a/postgres/messages/authentication_sasl_final.go b/postgres/messages/authentication_sasl_final.go index be6c19b122..d9f1773906 100644 --- a/postgres/messages/authentication_sasl_final.go +++ b/postgres/messages/authentication_sasl_final.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationSASLFinal{}) + connection.InitializeDefaultMessage(AuthenticationSASLFinal{}) } // AuthenticationSASLFinal represents a PostgreSQL message. @@ -23,46 +25,46 @@ type AuthenticationSASLFinal struct { AdditionalData []byte } -var authenticationSASLFinalDefault = MessageFormat{ +var authenticationSASLFinalDefault = connection.MessageFormat{ Name: "AuthenticationSASLFinal", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(12), }, { Name: "AdditionalData", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = AuthenticationSASLFinal{} +var _ connection.Message = AuthenticationSASLFinal{} -// encode implements the interface Message. -func (m AuthenticationSASLFinal) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m AuthenticationSASLFinal) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("AdditionalData").MustWrite(m.AdditionalData) return outputMessage, nil } -// decode implements the interface Message. -func (m AuthenticationSASLFinal) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationSASLFinal) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationSASLFinal{ @@ -70,7 +72,7 @@ func (m AuthenticationSASLFinal) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationSASLFinal) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationSASLFinal) DefaultMessage() *connection.MessageFormat { return &authenticationSASLFinalDefault } diff --git a/postgres/messages/authentication_scm_credential.go b/postgres/messages/authentication_scm_credential.go index fa93f144b1..fab270f5ba 100644 --- a/postgres/messages/authentication_scm_credential.go +++ b/postgres/messages/authentication_scm_credential.go @@ -14,52 +14,54 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationSCMCredential{}) + connection.InitializeDefaultMessage(AuthenticationSCMCredential{}) } // AuthenticationSCMCredential represents a PostgreSQL message. type AuthenticationSCMCredential struct{} -var authenticationSCMCredentialDefault = MessageFormat{ +var authenticationSCMCredentialDefault = connection.MessageFormat{ Name: "AuthenticationSCMCredential", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(6), }, }, } -var _ Message = AuthenticationSCMCredential{} +var _ connection.Message = AuthenticationSCMCredential{} -// encode implements the interface Message. -func (m AuthenticationSCMCredential) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m AuthenticationSCMCredential) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m AuthenticationSCMCredential) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationSCMCredential) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationSCMCredential{}, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationSCMCredential) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationSCMCredential) DefaultMessage() *connection.MessageFormat { return &authenticationSCMCredentialDefault } diff --git a/postgres/messages/authentication_sspi.go b/postgres/messages/authentication_sspi.go index ea2f20eb02..9a9c231fb4 100644 --- a/postgres/messages/authentication_sspi.go +++ b/postgres/messages/authentication_sspi.go @@ -14,52 +14,54 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(AuthenticationSSPI{}) + connection.InitializeDefaultMessage(AuthenticationSSPI{}) } // AuthenticationSSPI represents a PostgreSQL message. type AuthenticationSSPI struct{} -var authenticationSSPIDefault = MessageFormat{ +var authenticationSSPIDefault = connection.MessageFormat{ Name: "AuthenticationSSPI", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('R'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "Status", - Type: Int32, + Type: connection.Int32, Data: int32(9), }, }, } -var _ Message = AuthenticationSSPI{} +var _ connection.Message = AuthenticationSSPI{} -// encode implements the interface Message. -func (m AuthenticationSSPI) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m AuthenticationSSPI) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m AuthenticationSSPI) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m AuthenticationSSPI) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return AuthenticationSSPI{}, nil } -// defaultMessage implements the interface Message. -func (m AuthenticationSSPI) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m AuthenticationSSPI) DefaultMessage() *connection.MessageFormat { return &authenticationSSPIDefault } diff --git a/postgres/messages/backend_key_data.go b/postgres/messages/backend_key_data.go index 458809545c..7b6cadf5fa 100644 --- a/postgres/messages/backend_key_data.go +++ b/postgres/messages/backend_key_data.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(BackendKeyData{}) - addMessageHeader(BackendKeyData{}) + connection.InitializeDefaultMessage(BackendKeyData{}) + connection.AddMessageHeader(BackendKeyData{}) } // BackendKeyData provides the client with information about the server. @@ -25,47 +27,47 @@ type BackendKeyData struct { SecretKey int32 } -var backendKeyDataDefault = MessageFormat{ +var backendKeyDataDefault = connection.MessageFormat{ Name: "BackendKeyData", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('K'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(12), }, { Name: "ProcessID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "SecretKey", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, }, } -var _ Message = BackendKeyData{} +var _ connection.Message = BackendKeyData{} -// encode implements the interface Message. -func (m BackendKeyData) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m BackendKeyData) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("ProcessID").MustWrite(m.ProcessID) outputMessage.Field("SecretKey").MustWrite(m.SecretKey) return outputMessage, nil } -// decode implements the interface Message. -func (m BackendKeyData) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m BackendKeyData) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return BackendKeyData{ @@ -74,7 +76,7 @@ func (m BackendKeyData) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m BackendKeyData) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m BackendKeyData) DefaultMessage() *connection.MessageFormat { return &backendKeyDataDefault } diff --git a/postgres/messages/bind.go b/postgres/messages/bind.go index 67c1fd8bde..b767521d84 100644 --- a/postgres/messages/bind.go +++ b/postgres/messages/bind.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Bind{}) - addMessageHeader(Bind{}) + connection.InitializeDefaultMessage(Bind{}) + connection.AddMessageHeader(Bind{}) } // Bind represents a PostgreSQL message. @@ -34,40 +36,40 @@ type BindParameterValue struct { IsNull bool } -var bindDefault = MessageFormat{ +var bindDefault = connection.MessageFormat{ Name: "Bind", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('B'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "DestinationPortal", - Type: String, + Type: connection.String, Data: "", }, { Name: "SourcePreparedStatement", - Type: String, + Type: connection.String, Data: "", }, { Name: "ParameterFormatCodes", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ParameterFormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -75,19 +77,19 @@ var bindDefault = MessageFormat{ }, { Name: "ParameterValues", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ParameterLength", - Type: Int32, - Flags: ByteCount, + Type: connection.Int32, + Flags: connection.ByteCount, Data: int32(0), }, { Name: "ParameterValue", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, @@ -95,13 +97,13 @@ var bindDefault = MessageFormat{ }, { Name: "ResultFormatCodes", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ResultFormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -110,11 +112,11 @@ var bindDefault = MessageFormat{ }, } -var _ Message = Bind{} +var _ connection.Message = Bind{} -// encode implements the interface Message. -func (m Bind) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m Bind) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("DestinationPortal").MustWrite(m.DestinationPortal) outputMessage.Field("SourcePreparedStatement").MustWrite(m.SourcePreparedStatement) for i, pFormatCode := range m.ParameterFormatCodes { @@ -134,9 +136,9 @@ func (m Bind) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m Bind) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Bind) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } @@ -178,7 +180,7 @@ func (m Bind) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m Bind) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Bind) DefaultMessage() *connection.MessageFormat { return &bindDefault } diff --git a/postgres/messages/bind_complete.go b/postgres/messages/bind_complete.go index 8959a9542e..038c73c435 100644 --- a/postgres/messages/bind_complete.go +++ b/postgres/messages/bind_complete.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(BindComplete{}) - addMessageHeader(BindComplete{}) + connection.InitializeDefaultMessage(BindComplete{}) + connection.AddMessageHeader(BindComplete{}) } // BindComplete represents a PostgreSQL message. type BindComplete struct{} -var bindCompleteDefault = MessageFormat{ +var bindCompleteDefault = connection.MessageFormat{ Name: "BindComplete", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('2'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = BindComplete{} +var _ connection.Message = BindComplete{} -// encode implements the interface Message. -func (m BindComplete) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m BindComplete) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m BindComplete) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m BindComplete) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return BindComplete{}, nil } -// defaultMessage implements the interface Message. -func (m BindComplete) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m BindComplete) DefaultMessage() *connection.MessageFormat { return &bindCompleteDefault } diff --git a/postgres/messages/cancel_request.go b/postgres/messages/cancel_request.go index 3cd8169b5c..82fe7e1483 100644 --- a/postgres/messages/cancel_request.go +++ b/postgres/messages/cancel_request.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(CancelRequest{}) + connection.InitializeDefaultMessage(CancelRequest{}) } // CancelRequest represents a PostgreSQL message. @@ -24,46 +26,46 @@ type CancelRequest struct { SecretKey int32 } -var cancelRequestDefault = MessageFormat{ +var cancelRequestDefault = connection.MessageFormat{ Name: "CancelRequest", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "RequestCode", - Type: Int32, + Type: connection.Int32, Data: int32(80877102), }, { Name: "ProcessID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "SecretKey", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, }, } -var _ Message = CancelRequest{} +var _ connection.Message = CancelRequest{} -// encode implements the interface Message. -func (m CancelRequest) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CancelRequest) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("ProcessID").MustWrite(m.ProcessID) outputMessage.Field("SecretKey").MustWrite(m.SecretKey) return outputMessage, nil } -// decode implements the interface Message. -func (m CancelRequest) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CancelRequest) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return CancelRequest{ @@ -72,7 +74,7 @@ func (m CancelRequest) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CancelRequest) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CancelRequest) DefaultMessage() *connection.MessageFormat { return &cancelRequestDefault } diff --git a/postgres/messages/close.go b/postgres/messages/close.go index 0e156c9d89..d48ecbbd4e 100644 --- a/postgres/messages/close.go +++ b/postgres/messages/close.go @@ -14,11 +14,15 @@ package messages -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/postgres/connection" +) func init() { - initializeDefaultMessage(Close{}) - addMessageHeader(Close{}) + connection.InitializeDefaultMessage(Close{}) + connection.AddMessageHeader(Close{}) } // Close represents a PostgreSQL message. @@ -27,39 +31,39 @@ type Close struct { Target string // Target is the name of whatever we are closing. } -var closeDefault = MessageFormat{ +var closeDefault = connection.MessageFormat{ Name: "Close", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('C'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ClosingTarget", - Type: Byte1, + Type: connection.Byte1, Data: int32(0), }, { Name: "TargetName", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = Close{} +var _ connection.Message = Close{} -// encode implements the interface Message. -func (m Close) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m Close) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.ClosingPreparedStatement { outputMessage.Field("ClosingTarget").MustWrite('S') } else { @@ -69,9 +73,9 @@ func (m Close) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m Close) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Close) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } closingTarget := s.Field("ClosingTarget").MustGet().(int32) @@ -89,7 +93,7 @@ func (m Close) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m Close) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Close) DefaultMessage() *connection.MessageFormat { return &closeDefault } diff --git a/postgres/messages/close_complete.go b/postgres/messages/close_complete.go index 0097d1daa2..039d09e71d 100644 --- a/postgres/messages/close_complete.go +++ b/postgres/messages/close_complete.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(CloseComplete{}) - addMessageHeader(CloseComplete{}) + connection.InitializeDefaultMessage(CloseComplete{}) + connection.AddMessageHeader(CloseComplete{}) } // CloseComplete represents a PostgreSQL message. type CloseComplete struct{} -var closeCompleteDefault = MessageFormat{ +var closeCompleteDefault = connection.MessageFormat{ Name: "CloseComplete", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('3'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = CloseComplete{} +var _ connection.Message = CloseComplete{} -// encode implements the interface Message. -func (m CloseComplete) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m CloseComplete) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m CloseComplete) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CloseComplete) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return CloseComplete{}, nil } -// defaultMessage implements the interface Message. -func (m CloseComplete) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CloseComplete) DefaultMessage() *connection.MessageFormat { return &closeCompleteDefault } diff --git a/postgres/messages/command_complete.go b/postgres/messages/command_complete.go index c086f17dcd..8d1e34d9dc 100644 --- a/postgres/messages/command_complete.go +++ b/postgres/messages/command_complete.go @@ -18,10 +18,12 @@ import ( "fmt" "strconv" "strings" + + "github.com/dolthub/doltgresql/postgres/connection" ) func init() { - initializeDefaultMessage(CommandComplete{}) + connection.InitializeDefaultMessage(CommandComplete{}) } // CommandComplete tells the client that the command has completed. @@ -30,30 +32,30 @@ type CommandComplete struct { Rows int32 } -var commandCompleteDefault = MessageFormat{ +var commandCompleteDefault = connection.MessageFormat{ Name: "CommandComplete", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('C'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "CommandTag", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = CommandComplete{} +var _ connection.Message = CommandComplete{} // IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query. func (m CommandComplete) IsIUD() bool { @@ -67,9 +69,9 @@ func (m CommandComplete) IsIUD() bool { } } -// encode implements the interface Message. -func (m CommandComplete) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CommandComplete) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() query := strings.TrimSpace(strings.ToLower(m.Query)) if strings.HasPrefix(query, "select") { outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows)) @@ -84,14 +86,14 @@ func (m CommandComplete) encode() (MessageFormat, error) { } else if strings.HasPrefix(query, "call") { outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows)) } else { - return MessageFormat{}, fmt.Errorf("unsupported query for now") + return connection.MessageFormat{}, fmt.Errorf("unsupported query for now") } return outputMessage, nil } -// decode implements the interface Message. -func (m CommandComplete) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CommandComplete) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } query := strings.TrimSpace(s.Field("CommandTag").MustGet().(string)) @@ -106,7 +108,7 @@ func (m CommandComplete) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CommandComplete) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CommandComplete) DefaultMessage() *connection.MessageFormat { return &commandCompleteDefault } diff --git a/postgres/messages/copy_both_response.go b/postgres/messages/copy_both_response.go index 58a23a2d30..7d9f4b0acb 100644 --- a/postgres/messages/copy_both_response.go +++ b/postgres/messages/copy_both_response.go @@ -14,10 +14,14 @@ package messages -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/postgres/connection" +) func init() { - initializeDefaultMessage(CopyBothResponse{}) + connection.InitializeDefaultMessage(CopyBothResponse{}) } // CopyBothResponse represents a PostgreSQL message. @@ -26,35 +30,35 @@ type CopyBothResponse struct { FormatCodes []int32 } -var copyBothResponseDefault = MessageFormat{ +var copyBothResponseDefault = connection.MessageFormat{ Name: "CopyBothResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('W'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ResponseType", - Type: Int8, + Type: connection.Int8, Data: int32(0), }, { Name: "Columns", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "FormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -63,11 +67,11 @@ var copyBothResponseDefault = MessageFormat{ }, } -var _ Message = CopyBothResponse{} +var _ connection.Message = CopyBothResponse{} -// encode implements the interface Message. -func (m CopyBothResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CopyBothResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.IsTextual { outputMessage.Field("ResponseType").MustWrite(0) } else { @@ -79,9 +83,9 @@ func (m CopyBothResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m CopyBothResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CopyBothResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } var isTextual bool @@ -104,7 +108,7 @@ func (m CopyBothResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CopyBothResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CopyBothResponse) DefaultMessage() *connection.MessageFormat { return ©BothResponseDefault } diff --git a/postgres/messages/copy_data.go b/postgres/messages/copy_data.go index 5e33daccb9..1526d72151 100644 --- a/postgres/messages/copy_data.go +++ b/postgres/messages/copy_data.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(CopyData{}) - addMessageHeader(CopyData{}) + connection.InitializeDefaultMessage(CopyData{}) + connection.AddMessageHeader(CopyData{}) } // CopyData represents a PostgreSQL message. @@ -24,41 +26,41 @@ type CopyData struct { Data []byte } -var copyDataDefault = MessageFormat{ +var copyDataDefault = connection.MessageFormat{ Name: "CopyData", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('d'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Data", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = CopyData{} +var _ connection.Message = CopyData{} -// encode implements the interface Message. -func (m CopyData) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CopyData) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Data").MustWrite(m.Data) return outputMessage, nil } -// decode implements the interface Message. -func (m CopyData) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CopyData) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return CopyData{ @@ -66,7 +68,7 @@ func (m CopyData) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CopyData) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CopyData) DefaultMessage() *connection.MessageFormat { return ©DataDefault } diff --git a/postgres/messages/copy_done.go b/postgres/messages/copy_done.go index 8e93802c42..1f3df6e758 100644 --- a/postgres/messages/copy_done.go +++ b/postgres/messages/copy_done.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(CopyDone{}) - addMessageHeader(CopyDone{}) + connection.InitializeDefaultMessage(CopyDone{}) + connection.AddMessageHeader(CopyDone{}) } // CopyDone represents a PostgreSQL message. type CopyDone struct{} -var copyDoneDefault = MessageFormat{ +var copyDoneDefault = connection.MessageFormat{ Name: "CopyDone", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('c'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = CopyDone{} +var _ connection.Message = CopyDone{} -// encode implements the interface Message. -func (m CopyDone) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m CopyDone) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m CopyDone) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CopyDone) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return CopyDone{}, nil } -// defaultMessage implements the interface Message. -func (m CopyDone) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CopyDone) DefaultMessage() *connection.MessageFormat { return ©DoneDefault } diff --git a/postgres/messages/copy_fail.go b/postgres/messages/copy_fail.go index 55759af610..774e23aaf6 100644 --- a/postgres/messages/copy_fail.go +++ b/postgres/messages/copy_fail.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(CopyFail{}) - addMessageHeader(CopyFail{}) + connection.InitializeDefaultMessage(CopyFail{}) + connection.AddMessageHeader(CopyFail{}) } // CopyFail represents a PostgreSQL message. @@ -24,41 +26,41 @@ type CopyFail struct { ErrorMessage string } -var copyFailDefault = MessageFormat{ +var copyFailDefault = connection.MessageFormat{ Name: "CopyFail", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('f'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ErrorMessage", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = CopyFail{} +var _ connection.Message = CopyFail{} -// encode implements the interface Message. -func (m CopyFail) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CopyFail) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("ErrorMessage").MustWrite(m.ErrorMessage) return outputMessage, nil } -// decode implements the interface Message. -func (m CopyFail) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CopyFail) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return CopyFail{ @@ -66,7 +68,7 @@ func (m CopyFail) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CopyFail) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CopyFail) DefaultMessage() *connection.MessageFormat { return ©FailDefault } diff --git a/postgres/messages/copy_in_response.go b/postgres/messages/copy_in_response.go index 2aea1c229f..7745f43147 100644 --- a/postgres/messages/copy_in_response.go +++ b/postgres/messages/copy_in_response.go @@ -14,10 +14,14 @@ package messages -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/postgres/connection" +) func init() { - initializeDefaultMessage(CopyInResponse{}) + connection.InitializeDefaultMessage(CopyInResponse{}) } // CopyInResponse represents a PostgreSQL message. @@ -26,35 +30,35 @@ type CopyInResponse struct { FormatCodes []int32 } -var copyInResponseDefault = MessageFormat{ +var copyInResponseDefault = connection.MessageFormat{ Name: "CopyInResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('G'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ResponseType", - Type: Int8, + Type: connection.Int8, Data: int32(0), }, { Name: "Columns", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "FormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -63,11 +67,11 @@ var copyInResponseDefault = MessageFormat{ }, } -var _ Message = CopyInResponse{} +var _ connection.Message = CopyInResponse{} -// encode implements the interface Message. -func (m CopyInResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CopyInResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.IsTextual { outputMessage.Field("ResponseType").MustWrite(0) } else { @@ -79,9 +83,9 @@ func (m CopyInResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m CopyInResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CopyInResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } var isTextual bool @@ -104,7 +108,7 @@ func (m CopyInResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CopyInResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CopyInResponse) DefaultMessage() *connection.MessageFormat { return ©InResponseDefault } diff --git a/postgres/messages/copy_out_response.go b/postgres/messages/copy_out_response.go index 39e65b3c37..1fa29ae47c 100644 --- a/postgres/messages/copy_out_response.go +++ b/postgres/messages/copy_out_response.go @@ -14,10 +14,14 @@ package messages -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/postgres/connection" +) func init() { - initializeDefaultMessage(CopyOutResponse{}) + connection.InitializeDefaultMessage(CopyOutResponse{}) } // CopyOutResponse represents a PostgreSQL message. @@ -26,35 +30,35 @@ type CopyOutResponse struct { FormatCodes []int32 } -var copyOutResponseDefault = MessageFormat{ +var copyOutResponseDefault = connection.MessageFormat{ Name: "CopyOutResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('H'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ResponseType", - Type: Int8, + Type: connection.Int8, Data: int32(0), }, { Name: "Columns", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "FormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -63,11 +67,11 @@ var copyOutResponseDefault = MessageFormat{ }, } -var _ Message = CopyOutResponse{} +var _ connection.Message = CopyOutResponse{} -// encode implements the interface Message. -func (m CopyOutResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m CopyOutResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.IsTextual { outputMessage.Field("ResponseType").MustWrite(0) } else { @@ -79,9 +83,9 @@ func (m CopyOutResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m CopyOutResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m CopyOutResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } var isTextual bool @@ -104,7 +108,7 @@ func (m CopyOutResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m CopyOutResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m CopyOutResponse) DefaultMessage() *connection.MessageFormat { return ©OutResponseDefault } diff --git a/postgres/messages/data_row.go b/postgres/messages/data_row.go index 76e84c0a2e..911a8adeaa 100644 --- a/postgres/messages/data_row.go +++ b/postgres/messages/data_row.go @@ -15,11 +15,15 @@ package messages import ( + "fmt" + "github.com/dolthub/vitess/go/sqltypes" + + "github.com/dolthub/doltgresql/postgres/connection" ) func init() { - initializeDefaultMessage(DataRow{}) + connection.InitializeDefaultMessage(DataRow{}) } // DataRow represents a row of data. @@ -27,36 +31,36 @@ type DataRow struct { Values []sqltypes.Value } -var dataRowDefault = MessageFormat{ +var dataRowDefault = connection.MessageFormat{ Name: "DataRow", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('D'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Columns", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ColumnLength", - Type: Int32, - Flags: ByteCount, + Type: connection.Int32, + Flags: connection.ByteCount, Data: int32(0), }, { Name: "ColumnData", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, @@ -65,11 +69,11 @@ var dataRowDefault = MessageFormat{ }, } -var _ Message = DataRow{} +var _ connection.Message = DataRow{} -// encode implements the interface Message. -func (m DataRow) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m DataRow) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() for i := 0; i < len(m.Values); i++ { if m.Values[i].IsNull() { outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(-1) @@ -82,21 +86,15 @@ func (m DataRow) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m DataRow) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m DataRow) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } - columnCount := int(s.Field("Columns").MustGet().(int32)) - for i := 0; i < columnCount; i++ { - //TODO: decode the message in here - } - return DataRow{ - Values: nil, - }, nil + return nil, fmt.Errorf("DataRow messages do not support decoding, as they're only sent from the server.") } -// defaultMessage implements the interface Message. -func (m DataRow) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m DataRow) DefaultMessage() *connection.MessageFormat { return &dataRowDefault } diff --git a/postgres/messages/describe.go b/postgres/messages/describe.go index ef7f54f403..e1a9cc1e89 100644 --- a/postgres/messages/describe.go +++ b/postgres/messages/describe.go @@ -14,11 +14,15 @@ package messages -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/postgres/connection" +) func init() { - initializeDefaultMessage(Describe{}) - addMessageHeader(Describe{}) + connection.InitializeDefaultMessage(Describe{}) + connection.AddMessageHeader(Describe{}) } // Describe represents a PostgreSQL message. @@ -27,39 +31,39 @@ type Describe struct { Target string } -var describeDefault = MessageFormat{ +var describeDefault = connection.MessageFormat{ Name: "Describe", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('D'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "DescribingTarget", - Type: Byte1, + Type: connection.Byte1, Data: int32(0), }, { Name: "TargetName", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = Describe{} +var _ connection.Message = Describe{} -// encode implements the interface Message. -func (m Describe) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m Describe) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.IsPrepared { outputMessage.Field("DescribingTarget").MustWrite('S') } else { @@ -69,9 +73,9 @@ func (m Describe) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m Describe) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Describe) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } describingTarget := s.Field("DescribingTarget").MustGet().(int32) @@ -89,7 +93,7 @@ func (m Describe) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m Describe) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Describe) DefaultMessage() *connection.MessageFormat { return &describeDefault } diff --git a/postgres/messages/empty_query_response.go b/postgres/messages/empty_query_response.go index 43748f1016..3d847e90d2 100644 --- a/postgres/messages/empty_query_response.go +++ b/postgres/messages/empty_query_response.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(EmptyQueryResponse{}) - addMessageHeader(EmptyQueryResponse{}) + connection.InitializeDefaultMessage(EmptyQueryResponse{}) + connection.AddMessageHeader(EmptyQueryResponse{}) } // EmptyQueryResponse represents a PostgreSQL message. type EmptyQueryResponse struct{} -var emptyQueryResponseDefault = MessageFormat{ +var emptyQueryResponseDefault = connection.MessageFormat{ Name: "EmptyQueryResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('I'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = EmptyQueryResponse{} +var _ connection.Message = EmptyQueryResponse{} -// encode implements the interface Message. -func (m EmptyQueryResponse) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m EmptyQueryResponse) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m EmptyQueryResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m EmptyQueryResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return EmptyQueryResponse{}, nil } -// defaultMessage implements the interface Message. -func (m EmptyQueryResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m EmptyQueryResponse) DefaultMessage() *connection.MessageFormat { return &emptyQueryResponseDefault } diff --git a/postgres/messages/error_response.go b/postgres/messages/error_response.go index 426ab237fb..cce3556fac 100644 --- a/postgres/messages/error_response.go +++ b/postgres/messages/error_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(ErrorResponse{}) + connection.InitializeDefaultMessage(ErrorResponse{}) } // ErrorResponse represents a PostgreSQL message. @@ -29,36 +31,36 @@ type ErrorResponseField struct { Value string } -var errorResponseDefault = MessageFormat{ +var errorResponseDefault = connection.MessageFormat{ Name: "ErrorResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('E'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Fields", - Type: Repeated, - Flags: RepeatedTerminator, + Type: connection.Repeated, + Flags: connection.RepeatedTerminator, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "Code", - Type: Byte1, + Type: connection.Byte1, Data: int32(0), }, { Name: "Value", - Type: String, + Type: connection.String, Data: "", }, }, @@ -67,11 +69,11 @@ var errorResponseDefault = MessageFormat{ }, } -var _ Message = ErrorResponse{} +var _ connection.Message = ErrorResponse{} -// encode implements the interface Message. -func (m ErrorResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m ErrorResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() for i, field := range m.Fields { outputMessage.Field("Fields").Child("Code", i).MustWrite(field.Code) outputMessage.Field("Fields").Child("Value", i).MustWrite(field.Value) @@ -79,9 +81,9 @@ func (m ErrorResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m ErrorResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m ErrorResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } count := int(s.Field("Fields").MustGet().(int32)) @@ -97,7 +99,7 @@ func (m ErrorResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m ErrorResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m ErrorResponse) DefaultMessage() *connection.MessageFormat { return &errorResponseDefault } diff --git a/postgres/messages/execute.go b/postgres/messages/execute.go index d75569da1d..5d2290db24 100644 --- a/postgres/messages/execute.go +++ b/postgres/messages/execute.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Execute{}) - addMessageHeader(Execute{}) + connection.InitializeDefaultMessage(Execute{}) + connection.AddMessageHeader(Execute{}) } // Execute represents a PostgreSQL message. @@ -25,47 +27,47 @@ type Execute struct { RowMax int32 } -var executeDefault = MessageFormat{ +var executeDefault = connection.MessageFormat{ Name: "Execute", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('E'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Portal", - Type: String, + Type: connection.String, Data: "", }, { Name: "RowMax", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, }, } -var _ Message = Execute{} +var _ connection.Message = Execute{} -// encode implements the interface Message. -func (m Execute) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m Execute) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Portal").MustWrite(m.Portal) outputMessage.Field("RowMax").MustWrite(m.RowMax) return outputMessage, nil } -// decode implements the interface Message. -func (m Execute) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Execute) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return Execute{ @@ -74,7 +76,7 @@ func (m Execute) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m Execute) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Execute) DefaultMessage() *connection.MessageFormat { return &executeDefault } diff --git a/postgres/messages/flush.go b/postgres/messages/flush.go index 69484803ae..208a2146d9 100644 --- a/postgres/messages/flush.go +++ b/postgres/messages/flush.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Flush{}) - addMessageHeader(Flush{}) + connection.InitializeDefaultMessage(Flush{}) + connection.AddMessageHeader(Flush{}) } // Flush represents a PostgreSQL message. type Flush struct{} -var flushDefault = MessageFormat{ +var flushDefault = connection.MessageFormat{ Name: "Flush", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('H'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, }, } -var _ Message = Flush{} +var _ connection.Message = Flush{} -// encode implements the interface Message. -func (m Flush) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m Flush) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m Flush) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Flush) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return Flush{}, nil } -// defaultMessage implements the interface Message. -func (m Flush) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Flush) DefaultMessage() *connection.MessageFormat { return &flushDefault } diff --git a/postgres/messages/function_call.go b/postgres/messages/function_call.go index d3f1a500ae..0df029c967 100644 --- a/postgres/messages/function_call.go +++ b/postgres/messages/function_call.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(FunctionCall{}) - addMessageHeader(FunctionCall{}) + connection.InitializeDefaultMessage(FunctionCall{}) + connection.AddMessageHeader(FunctionCall{}) } // FunctionCall represents a PostgreSQL message. @@ -33,35 +35,35 @@ type FunctionCallArgument struct { IsNull bool } -var functionCallDefault = MessageFormat{ +var functionCallDefault = connection.MessageFormat{ Name: "FunctionCall", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('F'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ObjectID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "ArgumentFormatCodes", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ArgumentFormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -69,19 +71,19 @@ var functionCallDefault = MessageFormat{ }, { Name: "Arguments", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ArgumentLength", - Type: Int32, - Flags: ByteCount, + Type: connection.Int32, + Flags: connection.ByteCount, Data: int32(0), }, { Name: "ArgumentValue", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, @@ -89,17 +91,17 @@ var functionCallDefault = MessageFormat{ }, { Name: "ResultFormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, } -var _ Message = FunctionCall{} +var _ connection.Message = FunctionCall{} -// encode implements the interface Message. -func (m FunctionCall) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m FunctionCall) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("ObjectID").MustWrite(m.ObjectID) for i, formatCode := range m.ArgumentFormatCodes { outputMessage.Field("ArgumentFormatCodes").Child("ArgumentFormatCode", i).MustWrite(formatCode) @@ -116,9 +118,9 @@ func (m FunctionCall) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m FunctionCall) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m FunctionCall) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } @@ -153,7 +155,7 @@ func (m FunctionCall) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m FunctionCall) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m FunctionCall) DefaultMessage() *connection.MessageFormat { return &functionCallDefault } diff --git a/postgres/messages/function_call_response.go b/postgres/messages/function_call_response.go index 21516170d5..e9eab0e1f4 100644 --- a/postgres/messages/function_call_response.go +++ b/postgres/messages/function_call_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(FunctionCallResponse{}) + connection.InitializeDefaultMessage(FunctionCallResponse{}) } // FunctionCallResponse represents a PostgreSQL message. @@ -24,40 +26,40 @@ type FunctionCallResponse struct { ResultValue []byte } -var functionCallResponseDefault = MessageFormat{ +var functionCallResponseDefault = connection.MessageFormat{ Name: "FunctionCallResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('V'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ResultLength", - Type: Int32, - Flags: ByteCount, + Type: connection.Int32, + Flags: connection.ByteCount, Data: int32(0), }, { Name: "ResultValue", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = FunctionCallResponse{} +var _ connection.Message = FunctionCallResponse{} -// encode implements the interface Message. -func (m FunctionCallResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m FunctionCallResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.IsResultNull { outputMessage.Field("ResultLength").MustWrite(-1) } else { @@ -70,9 +72,9 @@ func (m FunctionCallResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m FunctionCallResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m FunctionCallResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } isNull := s.Field("ResultLength").MustGet().(int32) == -1 @@ -82,7 +84,7 @@ func (m FunctionCallResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m FunctionCallResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m FunctionCallResponse) DefaultMessage() *connection.MessageFormat { return &functionCallResponseDefault } diff --git a/postgres/messages/gss_response.go b/postgres/messages/gss_response.go index 104b56eef1..de17a42b42 100644 --- a/postgres/messages/gss_response.go +++ b/postgres/messages/gss_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(GSSResponse{}) + connection.InitializeDefaultMessage(GSSResponse{}) } // GSSResponse represents a PostgreSQL message. @@ -23,41 +25,41 @@ type GSSResponse struct { Data []byte } -var gSSResponseDefault = MessageFormat{ +var gSSResponseDefault = connection.MessageFormat{ Name: "GSSResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('p'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Data", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = GSSResponse{} +var _ connection.Message = GSSResponse{} -// encode implements the interface Message. -func (m GSSResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m GSSResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Data").MustWrite(m.Data) return outputMessage, nil } -// decode implements the interface Message. -func (m GSSResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m GSSResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return GSSResponse{ @@ -65,7 +67,7 @@ func (m GSSResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m GSSResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m GSSResponse) DefaultMessage() *connection.MessageFormat { return &gSSResponseDefault } diff --git a/postgres/messages/gssenc_request.go b/postgres/messages/gssenc_request.go index 8e2c4e5ced..f7c1ea7df9 100644 --- a/postgres/messages/gssenc_request.go +++ b/postgres/messages/gssenc_request.go @@ -14,46 +14,48 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(GSSENCRequest{}) + connection.InitializeDefaultMessage(GSSENCRequest{}) } // GSSENCRequest represents a PostgreSQL message. type GSSENCRequest struct{} -var gSSENCRequestDefault = MessageFormat{ +var gSSENCRequestDefault = connection.MessageFormat{ Name: "GSSENCRequest", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "RequestCode", - Type: Int32, + Type: connection.Int32, Data: int32(80877104), }, }, } -var _ Message = GSSENCRequest{} +var _ connection.Message = GSSENCRequest{} -// encode implements the interface Message. -func (m GSSENCRequest) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m GSSENCRequest) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m GSSENCRequest) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m GSSENCRequest) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return GSSENCRequest{}, nil } -// defaultMessage implements the interface Message. -func (m GSSENCRequest) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m GSSENCRequest) DefaultMessage() *connection.MessageFormat { return &gSSENCRequestDefault } diff --git a/postgres/messages/negotiate_protocol_version.go b/postgres/messages/negotiate_protocol_version.go index 3257badcf4..dfd6271cd6 100644 --- a/postgres/messages/negotiate_protocol_version.go +++ b/postgres/messages/negotiate_protocol_version.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(NegotiateProtocolVersion{}) + connection.InitializeDefaultMessage(NegotiateProtocolVersion{}) } // NegotiateProtocolVersion represents a PostgreSQL message. @@ -24,35 +26,35 @@ type NegotiateProtocolVersion struct { UnrecognizedOptions []string } -var negotiateProtocolVersionDefault = MessageFormat{ +var negotiateProtocolVersionDefault = connection.MessageFormat{ Name: "NegotiateProtocolVersion", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('v'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "NewestMinorProtocol", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "UnrecognizedOptions", - Type: Int32, + Type: connection.Int32, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "UnrecognizedOption", - Type: String, + Type: connection.String, Data: "", }, }, @@ -61,11 +63,11 @@ var negotiateProtocolVersionDefault = MessageFormat{ }, } -var _ Message = NegotiateProtocolVersion{} +var _ connection.Message = NegotiateProtocolVersion{} -// encode implements the interface Message. -func (m NegotiateProtocolVersion) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m NegotiateProtocolVersion) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("NewestMinorProtocol").MustWrite(m.NewestMinorProtocol) for i, option := range m.UnrecognizedOptions { outputMessage.Field("UnrecognizedOptions").Child("UnrecognizedOption", i).MustWrite(option) @@ -73,9 +75,9 @@ func (m NegotiateProtocolVersion) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m NegotiateProtocolVersion) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m NegotiateProtocolVersion) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } count := int(s.Field("UnrecognizedOptions").MustGet().(int32)) @@ -89,7 +91,7 @@ func (m NegotiateProtocolVersion) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m NegotiateProtocolVersion) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m NegotiateProtocolVersion) DefaultMessage() *connection.MessageFormat { return &negotiateProtocolVersionDefault } diff --git a/postgres/messages/no_data.go b/postgres/messages/no_data.go index aab20f4a64..7dca9abbd6 100644 --- a/postgres/messages/no_data.go +++ b/postgres/messages/no_data.go @@ -14,47 +14,49 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(NoData{}) + connection.InitializeDefaultMessage(NoData{}) } // NoData represents a PostgreSQL message. type NoData struct{} -var noDataDefault = MessageFormat{ +var noDataDefault = connection.MessageFormat{ Name: "NoData", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('n'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = NoData{} +var _ connection.Message = NoData{} -// encode implements the interface Message. -func (m NoData) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m NoData) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m NoData) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m NoData) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return NoData{}, nil } -// defaultMessage implements the interface Message. -func (m NoData) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m NoData) DefaultMessage() *connection.MessageFormat { return &noDataDefault } diff --git a/postgres/messages/notice_response.go b/postgres/messages/notice_response.go index 471d725820..54425af793 100644 --- a/postgres/messages/notice_response.go +++ b/postgres/messages/notice_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(NoticeResponse{}) + connection.InitializeDefaultMessage(NoticeResponse{}) } // NoticeResponse represents a PostgreSQL message. @@ -29,36 +31,36 @@ type NoticeResponseField struct { Value string } -var noticeResponseDefault = MessageFormat{ +var noticeResponseDefault = connection.MessageFormat{ Name: "NoticeResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('N'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Fields", - Type: Repeated, - Flags: RepeatedTerminator, + Type: connection.Repeated, + Flags: connection.RepeatedTerminator, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "Code", - Type: Byte1, + Type: connection.Byte1, Data: int32(0), }, { Name: "Value", - Type: String, + Type: connection.String, Data: "", }, }, @@ -67,11 +69,11 @@ var noticeResponseDefault = MessageFormat{ }, } -var _ Message = NoticeResponse{} +var _ connection.Message = NoticeResponse{} -// encode implements the interface Message. -func (m NoticeResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m NoticeResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() for i, field := range m.Fields { outputMessage.Field("Fields").Child("Code", i).MustWrite(field.Code) outputMessage.Field("Fields").Child("Value", i).MustWrite(field.Value) @@ -79,9 +81,9 @@ func (m NoticeResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m NoticeResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m NoticeResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } count := int(s.Field("Fields").MustGet().(int32)) @@ -97,7 +99,7 @@ func (m NoticeResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m NoticeResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m NoticeResponse) DefaultMessage() *connection.MessageFormat { return ¬iceResponseDefault } diff --git a/postgres/messages/notification_response.go b/postgres/messages/notification_response.go index 0299e2b92e..271237116a 100644 --- a/postgres/messages/notification_response.go +++ b/postgres/messages/notification_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(NotificationResponse{}) + connection.InitializeDefaultMessage(NotificationResponse{}) } // NotificationResponse represents a PostgreSQL message. @@ -25,53 +27,53 @@ type NotificationResponse struct { Payload string } -var notificationResponseDefault = MessageFormat{ +var notificationResponseDefault = connection.MessageFormat{ Name: "NotificationResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('A'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "ProcessID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "Channel", - Type: String, + Type: connection.String, Data: "", }, { Name: "Payload", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = NotificationResponse{} +var _ connection.Message = NotificationResponse{} -// encode implements the interface Message. -func (m NotificationResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m NotificationResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("ProcessID").MustWrite(m.ProcessID) outputMessage.Field("Channel").MustWrite(m.Channel) outputMessage.Field("Payload").MustWrite(m.Payload) return outputMessage, nil } -// decode implements the interface Message. -func (m NotificationResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m NotificationResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return NotificationResponse{ @@ -81,7 +83,7 @@ func (m NotificationResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m NotificationResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m NotificationResponse) DefaultMessage() *connection.MessageFormat { return ¬ificationResponseDefault } diff --git a/postgres/messages/parameter_description.go b/postgres/messages/parameter_description.go index 5fef3b95f0..a4b3f99d81 100644 --- a/postgres/messages/parameter_description.go +++ b/postgres/messages/parameter_description.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(ParameterDescription{}) + connection.InitializeDefaultMessage(ParameterDescription{}) } // ParameterDescription represents a PostgreSQL message. @@ -23,30 +25,30 @@ type ParameterDescription struct { ObjectIDs []int32 } -var parameterDescriptionDefault = MessageFormat{ +var parameterDescriptionDefault = connection.MessageFormat{ Name: "ParameterDescription", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('t'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Parameters", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ObjectID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, }, @@ -55,20 +57,20 @@ var parameterDescriptionDefault = MessageFormat{ }, } -var _ Message = ParameterDescription{} +var _ connection.Message = ParameterDescription{} -// encode implements the interface Message. -func (m ParameterDescription) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m ParameterDescription) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() for i, objectID := range m.ObjectIDs { outputMessage.Field("Parameters").Child("ObjectID", i).MustWrite(objectID) } return outputMessage, nil } -// decode implements the interface Message. -func (m ParameterDescription) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m ParameterDescription) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } count := int(s.Field("Parameters").MustGet().(int32)) @@ -81,7 +83,7 @@ func (m ParameterDescription) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m ParameterDescription) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m ParameterDescription) DefaultMessage() *connection.MessageFormat { return ¶meterDescriptionDefault } diff --git a/postgres/messages/parameter_status.go b/postgres/messages/parameter_status.go index 0c4c961272..a6293cc42e 100644 --- a/postgres/messages/parameter_status.go +++ b/postgres/messages/parameter_status.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(ParameterStatus{}) + connection.InitializeDefaultMessage(ParameterStatus{}) } // ParameterStatus reports various parameters to the client. @@ -24,47 +26,47 @@ type ParameterStatus struct { Value string } -var parameterStatusDefault = MessageFormat{ +var parameterStatusDefault = connection.MessageFormat{ Name: "ParameterStatus", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('S'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Name", - Type: String, + Type: connection.String, Data: "", }, { Name: "Value", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = ParameterStatus{} +var _ connection.Message = ParameterStatus{} -// encode implements the interface Message. -func (m ParameterStatus) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m ParameterStatus) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Name").MustWrite(m.Name) outputMessage.Field("Value").MustWrite(m.Value) return outputMessage, nil } -// decode implements the interface Message. -func (m ParameterStatus) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m ParameterStatus) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return ParameterStatus{ @@ -73,7 +75,7 @@ func (m ParameterStatus) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m ParameterStatus) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m ParameterStatus) DefaultMessage() *connection.MessageFormat { return ¶meterStatusDefault } diff --git a/postgres/messages/parse.go b/postgres/messages/parse.go index a18696a58d..b1c430efef 100644 --- a/postgres/messages/parse.go +++ b/postgres/messages/parse.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Parse{}) - addMessageHeader(Parse{}) + connection.InitializeDefaultMessage(Parse{}) + connection.AddMessageHeader(Parse{}) } // Parse represents a PostgreSQL message. @@ -26,40 +28,40 @@ type Parse struct { ParameterObjectIDs []int32 } -var parseDefault = MessageFormat{ +var parseDefault = connection.MessageFormat{ Name: "Parse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('P'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "PreparedStatement", - Type: String, + Type: connection.String, Data: "", }, { Name: "Query", - Type: String, + Type: connection.String, Data: "", }, { Name: "Parameters", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ObjectID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, }, @@ -68,11 +70,11 @@ var parseDefault = MessageFormat{ }, } -var _ Message = Parse{} +var _ connection.Message = Parse{} -// encode implements the interface Message. -func (m Parse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m Parse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("PreparedStatement").MustWrite(m.PreparedStatement) outputMessage.Field("Query").MustWrite(m.Query) for i, objectID := range m.ParameterObjectIDs { @@ -81,9 +83,9 @@ func (m Parse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m Parse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Parse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } count := int(s.Field("Parameters").MustGet().(int32)) @@ -98,7 +100,7 @@ func (m Parse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m Parse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Parse) DefaultMessage() *connection.MessageFormat { return &parseDefault } diff --git a/postgres/messages/parse_complete.go b/postgres/messages/parse_complete.go index 180d20de29..47ea3587fc 100644 --- a/postgres/messages/parse_complete.go +++ b/postgres/messages/parse_complete.go @@ -14,47 +14,49 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(ParseComplete{}) + connection.InitializeDefaultMessage(ParseComplete{}) } // ParseComplete represents a PostgreSQL message. type ParseComplete struct{} -var parseCompleteDefault = MessageFormat{ +var parseCompleteDefault = connection.MessageFormat{ Name: "ParseComplete", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('1'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = ParseComplete{} +var _ connection.Message = ParseComplete{} -// encode implements the interface Message. -func (m ParseComplete) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m ParseComplete) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m ParseComplete) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m ParseComplete) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return ParseComplete{}, nil } -// defaultMessage implements the interface Message. -func (m ParseComplete) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m ParseComplete) DefaultMessage() *connection.MessageFormat { return &parseCompleteDefault } diff --git a/postgres/messages/password_message.go b/postgres/messages/password_message.go index ec0d945915..0c396d085c 100644 --- a/postgres/messages/password_message.go +++ b/postgres/messages/password_message.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(PasswordMessage{}) + connection.InitializeDefaultMessage(PasswordMessage{}) } // PasswordMessage represents a PostgreSQL message. @@ -23,41 +25,41 @@ type PasswordMessage struct { Password string } -var passwordMessageDefault = MessageFormat{ +var passwordMessageDefault = connection.MessageFormat{ Name: "PasswordMessage", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('p'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Password", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = PasswordMessage{} +var _ connection.Message = PasswordMessage{} -// encode implements the interface Message. -func (m PasswordMessage) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m PasswordMessage) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Password").MustWrite(m.Password) return outputMessage, nil } -// decode implements the interface Message. -func (m PasswordMessage) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m PasswordMessage) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return PasswordMessage{ @@ -65,7 +67,7 @@ func (m PasswordMessage) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m PasswordMessage) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m PasswordMessage) DefaultMessage() *connection.MessageFormat { return &passwordMessageDefault } diff --git a/postgres/messages/portal_suspended.go b/postgres/messages/portal_suspended.go index 55b6501b1e..7107a9b4ba 100644 --- a/postgres/messages/portal_suspended.go +++ b/postgres/messages/portal_suspended.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(PortalSuspended{}) + connection.InitializeDefaultMessage(PortalSuspended{}) } // PortalSuspended represents a PostgreSQL message. @@ -24,40 +26,40 @@ type PortalSuspended struct { String string } -var portalSuspendedDefault = MessageFormat{ +var portalSuspendedDefault = connection.MessageFormat{ Name: "PortalSuspended", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('s'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = PortalSuspended{} +var _ connection.Message = PortalSuspended{} -// encode implements the interface Message. -func (m PortalSuspended) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m PortalSuspended) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m PortalSuspended) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m PortalSuspended) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return PortalSuspended{}, nil } -// defaultMessage implements the interface Message. -func (m PortalSuspended) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m PortalSuspended) DefaultMessage() *connection.MessageFormat { return &portalSuspendedDefault } diff --git a/postgres/messages/query.go b/postgres/messages/query.go index 847b411fb9..4b41a44cc8 100644 --- a/postgres/messages/query.go +++ b/postgres/messages/query.go @@ -14,9 +14,11 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Query{}) - addMessageHeader(Query{}) + connection.InitializeDefaultMessage(Query{}) + connection.AddMessageHeader(Query{}) } // Query contains a query given by the client. @@ -24,41 +26,41 @@ type Query struct { String string } -var queryDefault = MessageFormat{ +var queryDefault = connection.MessageFormat{ Name: "Query", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('Q'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "String", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = Query{} +var _ connection.Message = Query{} -// encode implements the interface Message. -func (m Query) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m Query) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("String").MustWrite(m.String) return outputMessage, nil } -// decode implements the interface Message. -func (m Query) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Query) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return Query{ @@ -66,7 +68,7 @@ func (m Query) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m Query) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Query) DefaultMessage() *connection.MessageFormat { return &queryDefault } diff --git a/postgres/messages/ready_for_query.go b/postgres/messages/ready_for_query.go index 96511f1770..0824a9b287 100644 --- a/postgres/messages/ready_for_query.go +++ b/postgres/messages/ready_for_query.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(ReadyForQuery{}) + connection.InitializeDefaultMessage(ReadyForQuery{}) } // ReadyForQueryTransactionIndicator indicates the state of the transaction related to the query. @@ -32,41 +34,41 @@ type ReadyForQuery struct { Indicator ReadyForQueryTransactionIndicator } -var readyForQueryDefault = MessageFormat{ +var readyForQueryDefault = connection.MessageFormat{ Name: "ReadyForQuery", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('Z'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(5), }, { Name: "TransactionIndicator", - Type: Byte1, + Type: connection.Byte1, Data: int32(0), }, }, } -var _ Message = ReadyForQuery{} +var _ connection.Message = ReadyForQuery{} -// encode implements the interface Message. -func (m ReadyForQuery) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m ReadyForQuery) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("TransactionIndicator").MustWrite(byte(m.Indicator)) return outputMessage, nil } -// decode implements the interface Message. -func (m ReadyForQuery) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m ReadyForQuery) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return ReadyForQuery{ @@ -74,7 +76,7 @@ func (m ReadyForQuery) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m ReadyForQuery) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m ReadyForQuery) DefaultMessage() *connection.MessageFormat { return &readyForQueryDefault } diff --git a/postgres/messages/row_description.go b/postgres/messages/row_description.go index 6b5473ead7..f2d48f61bf 100644 --- a/postgres/messages/row_description.go +++ b/postgres/messages/row_description.go @@ -18,12 +18,13 @@ import ( "fmt" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/vitess/go/vt/proto/query" + + "github.com/dolthub/doltgresql/postgres/connection" ) func init() { - initializeDefaultMessage(RowDescription{}) + connection.InitializeDefaultMessage(RowDescription{}) } // RowDescription represents a RowDescription message intended for the client. @@ -31,60 +32,60 @@ type RowDescription struct { Fields []*query.Field } -var rowDescriptionDefault = MessageFormat{ +var rowDescriptionDefault = connection.MessageFormat{ Name: "RowDescription", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('T'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Fields", - Type: Int16, + Type: connection.Int16, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ColumnName", - Type: String, + Type: connection.String, Data: "", }, { Name: "TableObjectID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "ColumnAttributeNumber", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, { Name: "DataTypeObjectID", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "DataTypeSize", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, { Name: "DataTypeModifier", - Type: Int32, + Type: connection.Int32, Data: int32(0), }, { Name: "FormatCode", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, }, @@ -93,24 +94,24 @@ var rowDescriptionDefault = MessageFormat{ }, } -var _ Message = RowDescription{} +var _ connection.Message = RowDescription{} -// encode implements the interface Message. -func (m RowDescription) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m RowDescription) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() for i := 0; i < len(m.Fields); i++ { field := m.Fields[i] dataTypeObjectID, err := VitessFieldToDataTypeObjectID(field) if err != nil { - return MessageFormat{}, err + return connection.MessageFormat{}, err } dataTypeSize, err := VitessFieldToDataTypeSize(field) if err != nil { - return MessageFormat{}, err + return connection.MessageFormat{}, err } dataTypeModifier, err := VitessFieldToDataTypeModifier(field) if err != nil { - return MessageFormat{}, err + return connection.MessageFormat{}, err } outputMessage.Field("Fields").Child("ColumnName", i).MustWrite(field.Name) outputMessage.Field("Fields").Child("DataTypeObjectID", i).MustWrite(dataTypeObjectID) @@ -120,22 +121,16 @@ func (m RowDescription) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m RowDescription) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m RowDescription) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } - fieldCount := int(s.Field("Fields").MustGet().(int32)) - for i := 0; i < fieldCount; i++ { - //TODO: decode the message in here - } - return RowDescription{ - Fields: nil, - }, nil + return nil, fmt.Errorf("RowDescription messages do not support decoding, as they're only sent from the server.") } -// defaultMessage implements the interface Message. -func (m RowDescription) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m RowDescription) DefaultMessage() *connection.MessageFormat { return &rowDescriptionDefault } diff --git a/postgres/messages/sasl_initial_response.go b/postgres/messages/sasl_initial_response.go index f4b7a98c0d..59767fc115 100644 --- a/postgres/messages/sasl_initial_response.go +++ b/postgres/messages/sasl_initial_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(SASLInitialResponse{}) + connection.InitializeDefaultMessage(SASLInitialResponse{}) } // SASLInitialResponse represents a PostgreSQL message. @@ -24,45 +26,45 @@ type SASLInitialResponse struct { Response []byte } -var sASLInitialResponseDefault = MessageFormat{ +var sASLInitialResponseDefault = connection.MessageFormat{ Name: "SASLInitialResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('p'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Name", - Type: String, + Type: connection.String, Data: "", }, { Name: "ResponseLength", - Type: Int32, - Flags: ByteCount, + Type: connection.Int32, + Flags: connection.ByteCount, Data: int32(-1), }, { Name: "ResponseData", - Type: String, + Type: connection.String, Data: "", }, }, } -var _ Message = SASLInitialResponse{} +var _ connection.Message = SASLInitialResponse{} -// encode implements the interface Message. -func (m SASLInitialResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m SASLInitialResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Name").MustWrite(m.Name) if len(m.Response) > 0 { outputMessage.Field("ResponseLength").MustWrite(len(m.Response)) @@ -71,9 +73,9 @@ func (m SASLInitialResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m SASLInitialResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m SASLInitialResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } var responseData []byte @@ -86,7 +88,7 @@ func (m SASLInitialResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m SASLInitialResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m SASLInitialResponse) DefaultMessage() *connection.MessageFormat { return &sASLInitialResponseDefault } diff --git a/postgres/messages/sasl_response.go b/postgres/messages/sasl_response.go index b40d5ec019..b5bc3fa4cb 100644 --- a/postgres/messages/sasl_response.go +++ b/postgres/messages/sasl_response.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(SASLResponse{}) + connection.InitializeDefaultMessage(SASLResponse{}) } // SASLResponse represents a PostgreSQL message. @@ -23,41 +25,41 @@ type SASLResponse struct { Data []byte } -var sASLResponseDefault = MessageFormat{ +var sASLResponseDefault = connection.MessageFormat{ Name: "SASLResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('p'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { Name: "Data", - Type: ByteN, + Type: connection.ByteN, Data: []byte{}, }, }, } -var _ Message = SASLResponse{} +var _ connection.Message = SASLResponse{} -// encode implements the interface Message. -func (m SASLResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m SASLResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("Data").MustWrite(m.Data) return outputMessage, nil } -// decode implements the interface Message. -func (m SASLResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m SASLResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return SASLResponse{ @@ -65,7 +67,7 @@ func (m SASLResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m SASLResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m SASLResponse) DefaultMessage() *connection.MessageFormat { return &sASLResponseDefault } diff --git a/postgres/messages/ssl_request.go b/postgres/messages/ssl_request.go index aa9398ea4e..cad3f910ef 100644 --- a/postgres/messages/ssl_request.go +++ b/postgres/messages/ssl_request.go @@ -14,46 +14,48 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(SSLRequest{}) + connection.InitializeDefaultMessage(SSLRequest{}) } // SSLRequest represents a PostgreSQL message. type SSLRequest struct{} -var sslRequestDefault = MessageFormat{ +var sslRequestDefault = connection.MessageFormat{ Name: "SSLRequest", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(8), }, { Name: "RequestCode", - Type: Int32, + Type: connection.Int32, Data: int32(80877103), }, }, } -var _ Message = SSLRequest{} +var _ connection.Message = SSLRequest{} -// encode implements the interface Message. -func (m SSLRequest) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m SSLRequest) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m SSLRequest) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m SSLRequest) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return SSLRequest{}, nil } -// defaultMessage implements the interface Message. -func (m SSLRequest) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m SSLRequest) DefaultMessage() *connection.MessageFormat { return &sslRequestDefault } diff --git a/postgres/messages/ssl_response.go b/postgres/messages/ssl_response.go index b09667cf52..b754b8490b 100644 --- a/postgres/messages/ssl_response.go +++ b/postgres/messages/ssl_response.go @@ -14,10 +14,14 @@ package messages -import "fmt" +import ( + "fmt" + + "github.com/dolthub/doltgresql/postgres/connection" +) func init() { - initializeDefaultMessage(SSLResponse{}) + connection.InitializeDefaultMessage(SSLResponse{}) } // SSLResponse tells the client whether SSL is supported. @@ -25,22 +29,22 @@ type SSLResponse struct { SupportsSSL bool } -var sslResponseDefault = MessageFormat{ +var sslResponseDefault = connection.MessageFormat{ Name: "SSLResponse", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Supported", - Type: Byte1, + Type: connection.Byte1, Data: int32(0), }, }, } -var _ Message = SSLResponse{} +var _ connection.Message = SSLResponse{} -// encode implements the interface Message. -func (m SSLResponse) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m SSLResponse) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() if m.SupportsSSL { outputMessage.Field("Supported").MustWrite('Y') } else { @@ -49,9 +53,9 @@ func (m SSLResponse) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m SSLResponse) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m SSLResponse) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } var supported bool @@ -68,7 +72,7 @@ func (m SSLResponse) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m SSLResponse) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m SSLResponse) DefaultMessage() *connection.MessageFormat { return &sslResponseDefault } diff --git a/postgres/messages/startup_message.go b/postgres/messages/startup_message.go index db675be839..efc1bebf6c 100644 --- a/postgres/messages/startup_message.go +++ b/postgres/messages/startup_message.go @@ -14,8 +14,10 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(StartupMessage{}) + connection.InitializeDefaultMessage(StartupMessage{}) } // StartupMessage is returned by the client upon connecting to the server, providing details about the client. @@ -25,40 +27,40 @@ type StartupMessage struct { Parameters map[string]string } -var startupMessageDefault = MessageFormat{ +var startupMessageDefault = connection.MessageFormat{ Name: "StartupMessage", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, { // The docs specify a single Int32 field, but the upper and lower bits are different values, so this just splits them Name: "ProtocolMajorVersion", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, { Name: "ProtocolMinorVersion", - Type: Int16, + Type: connection.Int16, Data: int32(0), }, { Name: "Parameters", - Type: Repeated, - Flags: RepeatedTerminator, + Type: connection.Repeated, + Flags: connection.RepeatedTerminator, Data: int32(0), - Children: []FieldGroup{ + Children: []connection.FieldGroup{ { { Name: "ParameterName", - Type: String, + Type: connection.String, Data: "", }, { Name: "ParameterValue", - Type: String, + Type: connection.String, Data: "", }, }, @@ -67,11 +69,11 @@ var startupMessageDefault = MessageFormat{ }, } -var _ Message = StartupMessage{} +var _ connection.Message = StartupMessage{} -// encode implements the interface Message. -func (m StartupMessage) encode() (MessageFormat, error) { - outputMessage := m.defaultMessage().Copy() +// Encode implements the interface connection.Message. +func (m StartupMessage) Encode() (connection.MessageFormat, error) { + outputMessage := m.DefaultMessage().Copy() outputMessage.Field("ProtocolMajorVersion").MustWrite(m.ProtocolMajorVersion) outputMessage.Field("ProtocolMinorVersion").MustWrite(m.ProtocolMinorVersion) index := 0 @@ -83,9 +85,9 @@ func (m StartupMessage) encode() (MessageFormat, error) { return outputMessage, nil } -// decode implements the interface Message. -func (m StartupMessage) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m StartupMessage) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } parameters := make(map[string]string) @@ -101,7 +103,7 @@ func (m StartupMessage) decode(s MessageFormat) (Message, error) { }, nil } -// defaultMessage implements the interface Message. -func (m StartupMessage) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m StartupMessage) DefaultMessage() *connection.MessageFormat { return &startupMessageDefault } diff --git a/postgres/messages/sync.go b/postgres/messages/sync.go index 70dee95c52..51f1c7e1cb 100644 --- a/postgres/messages/sync.go +++ b/postgres/messages/sync.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Sync{}) - addMessageHeader(Sync{}) + connection.InitializeDefaultMessage(Sync{}) + connection.AddMessageHeader(Sync{}) } // Sync represents a PostgreSQL message. type Sync struct{} -var syncDefault = MessageFormat{ +var syncDefault = connection.MessageFormat{ Name: "Sync", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('S'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(4), }, }, } -var _ Message = Sync{} +var _ connection.Message = Sync{} -// encode implements the interface Message. -func (m Sync) encode() (MessageFormat, error) { - return m.defaultMessage().Copy(), nil +// Encode implements the interface connection.Message. +func (m Sync) Encode() (connection.MessageFormat, error) { + return m.DefaultMessage().Copy(), nil } -// decode implements the interface Message. -func (m Sync) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Sync) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return Sync{}, nil } -// defaultMessage implements the interface Message. -func (m Sync) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Sync) DefaultMessage() *connection.MessageFormat { return &syncDefault } diff --git a/postgres/messages/terminate.go b/postgres/messages/terminate.go index b0b6074f25..b218fe89d2 100644 --- a/postgres/messages/terminate.go +++ b/postgres/messages/terminate.go @@ -14,48 +14,50 @@ package messages +import "github.com/dolthub/doltgresql/postgres/connection" + func init() { - initializeDefaultMessage(Terminate{}) - addMessageHeader(Terminate{}) + connection.InitializeDefaultMessage(Terminate{}) + connection.AddMessageHeader(Terminate{}) } // Terminate tells the server to close the connection. type Terminate struct{} -var terminateDefault = MessageFormat{ +var terminateDefault = connection.MessageFormat{ Name: "Terminate", - Fields: FieldGroup{ + Fields: connection.FieldGroup{ { Name: "Header", - Type: Byte1, - Flags: Header, + Type: connection.Byte1, + Flags: connection.Header, Data: int32('X'), }, { Name: "MessageLength", - Type: Int32, - Flags: MessageLengthInclusive, + Type: connection.Int32, + Flags: connection.MessageLengthInclusive, Data: int32(0), }, }, } -var _ Message = Terminate{} +var _ connection.Message = Terminate{} -// encode implements the interface Message. -func (m Terminate) encode() (MessageFormat, error) { +// Encode implements the interface connection.Message. +func (m Terminate) Encode() (connection.MessageFormat, error) { return terminateDefault.Copy(), nil } -// decode implements the interface Message. -func (m Terminate) decode(s MessageFormat) (Message, error) { - if err := s.MatchesStructure(*m.defaultMessage()); err != nil { +// Decode implements the interface connection.Message. +func (m Terminate) Decode(s connection.MessageFormat) (connection.Message, error) { + if err := s.MatchesStructure(*m.DefaultMessage()); err != nil { return nil, err } return Terminate{}, nil } -// defaultMessage implements the interface Message. -func (m Terminate) defaultMessage() *MessageFormat { +// DefaultMessage implements the interface connection.Message. +func (m Terminate) DefaultMessage() *connection.MessageFormat { return &terminateDefault } diff --git a/postgres/messages/utils.go b/utils/stack.go similarity index 69% rename from postgres/messages/utils.go rename to utils/stack.go index 7f15acfedd..90f17a786e 100644 --- a/postgres/messages/utils.go +++ b/utils/stack.go @@ -12,32 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package messages - -import ( - "bytes" - "encoding/binary" - - "golang.org/x/exp/constraints" -) - -//TODO: delete these Write functions - -// WriteLength writes the length of the message into the byte slice. Modifies the given byte slice, while also -// returning the same slice. Assumes that the first byte is the message identifier, while the next 4 bytes are -// the length. -func WriteLength(b []byte) []byte { - // We never include the message identifier in the length. - // Technically, the length field is an int32, however we'll assume that our return values will be under 2GB for now. - length := uint32(len(b) - 1) - binary.BigEndian.PutUint32(b[1:], length) - return b -} - -// WriteNumber writes the given number to the buffer. -func WriteNumber[T constraints.Integer | constraints.Float](buf *bytes.Buffer, num T) { - _ = binary.Write(buf, binary.BigEndian, num) -} +package utils // Stack is a generic stack. type Stack[T any] struct {