Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var doltCommand = cli.NewSubCommandHandler("doltgresql", "it's git for data", []
var globalArgParser = cli.CreateGlobalArgParser("doltgresql")

func init() {
server.DefaultProtocolListenerFunc = postgres.NewListenerWithConfig
server.DefaultProtocolListenerFunc = postgres.NewListener
sqlserver.ExternalDisableUsers = true
dfunctions.VersionString = Version
}
Expand Down
59 changes: 59 additions & 0 deletions postgres/connection/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package connection

import "net"

// Receive returns a Message from the given buffer, generally generated by the client in the main read loop of a
// connection.
func Receive(buffer []byte) (Message, bool, error) {
if len(buffer) == 0 {
return nil, false, nil
}
message, ok := allMessageHeaders[buffer[0]]
if !ok {
return nil, false, nil
}
outMessage, err := ReceiveInto(buffer, message)
return outMessage, true, err
}

// ReceiveInto writes the contents of the buffer into the given Message.
func ReceiveInto[T Message](buffer []byte, message T) (out T, err error) {
defaultMessage := message.DefaultMessage()
fields := defaultMessage.Copy().Fields
if err = decode(&decodeBuffer{buffer}, []FieldGroup{fields}, 1); err != nil {
return out, err
}
decodedMessage, err := message.Decode(MessageFormat{defaultMessage.Name, fields, defaultMessage.info, false})
if err != nil {
return out, err
}
return decodedMessage.(T), nil
}

// Send sends the given message over the connection.
func Send(conn net.Conn, message Message) error {
encodedMessage, err := message.Encode()
if err != nil {
return err
}
data, err := encode(encodedMessage)
if err != nil {
return err
}
_, err = conn.Write(data)
return err
}
14 changes: 7 additions & 7 deletions postgres/messages/message.go → postgres/connection/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package messages
package connection

import (
"fmt"
Expand All @@ -30,14 +30,14 @@ type MessageFormat struct {

// Message is a type that represents a PostgreSQL message.
type Message interface {
// encode returns a new MessageFormat containing any modified data contained within the object. This should NOT be
// Encode returns a new MessageFormat containing any modified data contained within the object. This should NOT be
// the default message.
encode() (MessageFormat, error)
// decode returns a new Message that represents the given MessageFormat. You should never return the default
Encode() (MessageFormat, error)
// Decode returns a new Message that represents the given MessageFormat. You should never return the default
// message, even if the message never varies from the default. Always make a copy, and then modify that copy.
decode(s MessageFormat) (Message, error)
// defaultMessage returns the default, unmodified message for this type.
defaultMessage() *MessageFormat
Decode(s MessageFormat) (Message, error)
// DefaultMessage returns the default, unmodified message for this type.
DefaultMessage() *MessageFormat
}

// messageFieldInfo contains information on a specific field within a messageInfo.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,58 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package messages
package connection

import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
)

// Receive returns a Message from the given buffer, generally generated by the client in the main read loop of a
// connection.
func Receive(buffer []byte) (Message, bool, error) {
if len(buffer) == 0 {
return nil, false, nil
}
message, ok := allMessageHeaders[buffer[0]]
if !ok {
return nil, false, nil
}
outMessage, err := ReceiveInto(buffer, message)
return outMessage, true, err
}

// ReceiveInto writes the contents of the buffer into the given Message.
func ReceiveInto[T Message](buffer []byte, message T) (out T, err error) {
defaultMessage := message.defaultMessage()
fields := defaultMessage.Copy().Fields
if err = decode(&decodeBuffer{buffer}, []FieldGroup{fields}, 1); err != nil {
return out, err
}
decodedMessage, err := message.decode(MessageFormat{defaultMessage.Name, fields, defaultMessage.info, false})
if err != nil {
return out, err
}
return decodedMessage.(T), nil
}

// Send sends the given message over the connection.
func Send(conn net.Conn, message Message) error {
encodedMessage, err := message.encode()
if err != nil {
return err
}
data, err := encode(encodedMessage)
if err != nil {
return err
}
_, err = conn.Write(data)
return err
}

// decodeBuffer just provides an easy way to reference the same buffer, so that decode can modify its length.
type decodeBuffer struct {
data []byte
Expand Down Expand Up @@ -241,7 +198,6 @@ func encode(ms MessageFormat) ([]byte, error) {
if data[bufferIdx] == 0 {
found = true
byteOffset += int32(bufferIdx)
//TODO: is this the correct place to put this? investigate/test
if field.Flags&ExcludeTerminator == 0 {
byteOffset += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package messages
package connection

// FieldType is the type of the field as defined by PostgreSQL.
type FieldType byte
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package messages
package connection

import "fmt"
import (
"fmt"

"github.com/dolthub/doltgresql/utils"
)

// allMessageHeaders contains any message headers that should be read within the main read loop of a connection.
var allMessageHeaders = make(map[byte]Message)
Expand All @@ -25,26 +29,26 @@ var allMessageNames = make(map[string]struct{})
// allMessageDefaults contains all of the default message pointers, to make sure that they're not accidentally being reused.
var allMessageDefaults = make(map[*MessageFormat]struct{})

// addMessageHeader adds the given Message's header. This also ensures that each header is unique. This should be
// AddMessageHeader adds the given Message's header. This also ensures that each header is unique. This should be
// called in an init() function.
func addMessageHeader(message Message) {
for _, field := range message.defaultMessage().Fields {
func AddMessageHeader(message Message) {
for _, field := range message.DefaultMessage().Fields {
if field.Flags&Header != 0 {
header := byte(field.Data.(int32))
if _, ok := allMessageHeaders[header]; ok {
panic(fmt.Errorf("Header already taken.\nMessageFormat:\n\n%s", message.defaultMessage().String()))
panic(fmt.Errorf("Header already taken.\nMessageFormat:\n\n%s", message.DefaultMessage().String()))
}
allMessageHeaders[header] = message
return
}
}
panic(fmt.Errorf("Header does not exist.\nMessageFormat:\n\n%s", message.defaultMessage().String()))
panic(fmt.Errorf("Header does not exist.\nMessageFormat:\n\n%s", message.DefaultMessage().String()))
}

// initializeDefaultMessage creates the internal structure of the default message, while ensuring that the structure of
// InitializeDefaultMessage creates the internal structure of the default message, while ensuring that the structure of
// the message is correct. This should be called in an init() function.
func initializeDefaultMessage(messageType Message) {
message := messageType.defaultMessage()
func InitializeDefaultMessage(messageType Message) {
message := messageType.DefaultMessage()
if _, ok := allMessageDefaults[message]; ok {
panic(fmt.Errorf("MessageFormat default was used in another message.\nMessageFormat:\n\n%s", message.String()))
}
Expand All @@ -69,7 +73,7 @@ func initializeDefaultMessage(messageType Message) {
Fields FieldGroup
}

ftStack := NewStack[FieldTraversal]()
ftStack := utils.NewStack[FieldTraversal]()
ftStack.Push(FieldTraversal{0, message.Fields})
for !ftStack.Empty() {
// If we're at the end of the loop for this stacked entry, then we pop it and move to the next
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package messages
package connection

import "fmt"

Expand Down
37 changes: 19 additions & 18 deletions postgres/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,29 @@ import (
"github.com/dolthub/vitess/go/mysql"
"github.com/dolthub/vitess/go/sqltypes"

"github.com/dolthub/doltgresql/postgres/connection"
"github.com/dolthub/doltgresql/postgres/messages"
)

var connectionIDCounter uint32

// TODO: doc
// Listener listens for connections to process PostgreSQL requests into Dolt requests.
type Listener struct {
listener net.Listener
cfg mysql.ListenerConfig
}

var _ server.ProtocolListener = (*Listener)(nil)

// TODO: doc
func NewListenerWithConfig(listenerCfg mysql.ListenerConfig) (server.ProtocolListener, error) {
// NewListener creates a new Listener.
func NewListener(listenerCfg mysql.ListenerConfig) (server.ProtocolListener, error) {
return &Listener{
listener: listenerCfg.Listener,
cfg: listenerCfg,
}, nil
}

// TODO: doc
// Accept handles incoming connections.
func (l *Listener) Accept() {
for {
conn, err := l.listener.Accept()
Expand All @@ -58,12 +59,12 @@ func (l *Listener) Accept() {
}
}

// TODO: doc
// Close stops the handling of incoming connections.
func (l *Listener) Close() {
_ = l.listener.Close()
}

// TODO: doc
// Addr returns the address that the listener is listening on.
func (l *Listener) Addr() net.Addr {
return l.listener.Addr()
}
Expand Down Expand Up @@ -93,7 +94,7 @@ func (l *Listener) HandleConnection(conn net.Conn) {
return
}

if err = messages.Send(conn, messages.SSLResponse{
if err = connection.Send(conn, messages.SSLResponse{
SupportsSSL: false,
}); err != nil {
fmt.Println(err)
Expand All @@ -107,41 +108,41 @@ func (l *Listener) HandleConnection(conn net.Conn) {
}
return
}
startupMessage, err := messages.ReceiveInto(buf, messages.StartupMessage{})
startupMessage, err := connection.ReceiveInto(buf, messages.StartupMessage{})
if err != nil {
fmt.Println(err)
return
}

if err = messages.Send(conn, messages.AuthenticationOk{}); err != nil {
if err = connection.Send(conn, messages.AuthenticationOk{}); err != nil {
fmt.Println(err)
return
}

if err = messages.Send(conn, messages.ParameterStatus{
if err = connection.Send(conn, messages.ParameterStatus{
Name: "server_version",
Value: "15.0",
}); err != nil {
fmt.Println(err)
return
}
if err = messages.Send(conn, messages.ParameterStatus{
if err = connection.Send(conn, messages.ParameterStatus{
Name: "client_encoding",
Value: "UTF8",
}); err != nil {
fmt.Println(err)
return
}

if err = messages.Send(conn, messages.BackendKeyData{
if err = connection.Send(conn, messages.BackendKeyData{
ProcessID: 1,
SecretKey: 0,
}); err != nil {
fmt.Println(err)
return
}

if err = messages.Send(conn, messages.ReadyForQuery{
if err = connection.Send(conn, messages.ReadyForQuery{
Indicator: messages.ReadyForQueryTransactionIndicator_Idle,
}); err != nil {
fmt.Println(err)
Expand All @@ -162,7 +163,7 @@ func (l *Listener) HandleConnection(conn net.Conn) {
return
}

message, ok, err := messages.Receive(buf)
message, ok, err := connection.Receive(buf)
if err != nil {
fmt.Println(err.Error())
return
Expand All @@ -187,14 +188,14 @@ func (l *Listener) query(conn net.Conn, mysqlConn *mysql.Conn, query string) {
}

if err := l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error {
if err := messages.Send(conn, messages.RowDescription{
if err := connection.Send(conn, messages.RowDescription{
Fields: res.Fields,
}); err != nil {
return err
}

for _, row := range res.Rows {
if err := messages.Send(conn, messages.DataRow{
if err := connection.Send(conn, messages.DataRow{
Values: row,
}); err != nil {
return err
Expand All @@ -212,12 +213,12 @@ func (l *Listener) query(conn net.Conn, mysqlConn *mysql.Conn, query string) {
return
}

if err := messages.Send(conn, commandComplete); err != nil {
if err := connection.Send(conn, commandComplete); err != nil {
fmt.Println(err)
return
}

if err := messages.Send(conn, messages.ReadyForQuery{
if err := connection.Send(conn, messages.ReadyForQuery{
Indicator: messages.ReadyForQueryTransactionIndicator_Idle,
}); err != nil {
fmt.Println(err)
Expand Down
Loading