diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
deleted file mode 100644
index 1a207faa04..0000000000
--- a/.github/workflows/format.yml
+++ /dev/null
@@ -1,89 +0,0 @@
-name: Check/Format PR
-
-on:
- pull_request:
- branches: [ main ]
-
-concurrency:
- group: format-${{ github.event.pull_request.number || github.ref }}
- cancel-in-progress: true
-
-jobs:
- format:
- name: Format PR
- if: github.event.pull_request.head.repo.full_name == github.repository
- runs-on: ubuntu-22.04
- steps:
- - name: Setup Go 1.x
- uses: actions/setup-go@v3
- with:
- go-version: ^1.20
- - uses: actions/checkout@v3
- with:
- token: ${{ secrets.REPO_ACCESS_TOKEN || secrets.GITHUB_TOKEN }}
- ref: ${{ github.head_ref }}
- - name: Install goimports
- run: |
- go mod download golang.org/x/tools
- go install golang.org/x/tools/cmd/goimports
- - name: Format repo
- run: |
- ./format_repo.sh
- env:
- BRANCH_NAME: ${{ github.head_ref }}
- CHANGE_TARGET: ${{ github.base_ref }}
- - name: Changes detected
- id: detect-changes
- run: |
- changes=$(git status --porcelain)
- if [ ! -z "$changes" ]; then
- echo "has-changes=true" >> $GITHUB_OUTPUT
- fi
- - uses: EndBug/add-and-commit@v9.1.1
- if: ${{ steps.detect-changes.outputs.has-changes == 'true' }}
- with:
- message: "[ga-format-pr] Run ./format_repo.sh to fix formatting"
- add: "."
- cwd: "."
- verify:
- needs: format
- name: Verify format
- runs-on: ubuntu-22.04
- steps:
- - name: Setup Go 1.x
- uses: actions/setup-go@v3
- with:
- go-version: ^1.20
- id: go
- - uses: actions/checkout@v3
- with:
- ref: ${{ github.head_ref }}
- - name: Check all
- id: check_format
- run: |
- ./check_repo.sh
- env:
- BRANCH_NAME: ${{ github.head_ref }}
- CHANGE_TARGET: ${{ github.base_ref }}
- alt-verify:
- if: github.event.pull_request.head.repo.full_name != github.repository
- name: Verify format
- runs-on: ubuntu-22.04
- steps:
- - name: Setup Go 1.x
- uses: actions/setup-go@v3
- with:
- go-version: ^1.20
- id: go
- - uses: actions/checkout@v3
- - name: Check all
- id: check_format
- run: |
- ./check_repo.sh
- code=$(echo $?)
- if [ "$code" != 0 ]; then
- echo "Please run ./format_repo.sh to fix this pull request's formatting"
- fi
- env:
- BRANCH_NAME: ${{ github.head_ref }}
- CHANGE_TARGET: ${{ github.base_ref }}
diff --git a/fileno_check.go b/fileno_check.go
index 5a4b302667..0b9d9f075b 100644
--- a/fileno_check.go
+++ b/fileno_check.go
@@ -1,4 +1,4 @@
-// Copyright 2019 Dolthub, Inc.
+// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
diff --git a/fileno_check_darwin.go b/fileno_check_darwin.go
index 0223dbf2fa..78d9370c14 100644
--- a/fileno_check_darwin.go
+++ b/fileno_check_darwin.go
@@ -1,4 +1,4 @@
-// Copyright 2019 Dolthub, Inc.
+// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
diff --git a/fileno_check_linux.go b/fileno_check_linux.go
index 6c85f9d1df..1c06e4e285 100644
--- a/fileno_check_linux.go
+++ b/fileno_check_linux.go
@@ -1,4 +1,4 @@
-// Copyright 2019 Dolthub, Inc.
+// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
diff --git a/go.mod b/go.mod
index 41f9158374..cc19f2f2c2 100644
--- a/go.mod
+++ b/go.mod
@@ -3,10 +3,16 @@ module github.com/dolthub/doltgresql
go 1.20
require (
- github.com/dolthub/dolt/go v0.40.5-0.20230911151544-6c18c2cbfebe
- github.com/dolthub/vitess v0.0.0-20230823204737-4a21a94e90c3
+ github.com/dolthub/dolt/go v0.40.5-0.20230917024726-1ae3a49864ac
+ github.com/dolthub/go-mysql-server v0.17.1-0.20230916212652-86f1cdf0339c
+ github.com/dolthub/vitess v0.0.0-20230915082726-ef1b92774b14
+ github.com/fatih/color v1.13.0
+ github.com/jackc/pgx/v5 v5.4.3
github.com/shopspring/decimal v1.2.0
+ github.com/stretchr/testify v1.8.2
+ github.com/tidwall/gjson v1.14.4
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
+ golang.org/x/sys v0.10.0
)
require (
@@ -26,18 +32,17 @@ require (
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
github.com/denisbrodbeck/machineid v1.0.1 // indirect
- github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911151544-6c18c2cbfebe // indirect
+ github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230917024726-1ae3a49864ac // indirect
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 // indirect
github.com/dolthub/fslock v0.0.3 // indirect
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e // indirect
- github.com/dolthub/go-mysql-server v0.17.1-0.20230911134248-fb77f4090d9b // indirect
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 // indirect
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 // indirect
github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 // indirect
github.com/dolthub/swiss v0.1.0 // indirect
github.com/dustin/go-humanize v1.0.0 // indirect
- github.com/fatih/color v1.13.0 // indirect
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/go-kit/kit v0.10.0 // indirect
github.com/go-logr/logr v1.2.3 // indirect
@@ -56,6 +61,8 @@ require (
github.com/googleapis/gax-go/v2 v2.11.0 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.2 // indirect
+ github.com/jackc/pgpassfile v1.0.0 // indirect
+ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jpillora/backoff v1.0.0 // indirect
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d // indirect
@@ -72,20 +79,19 @@ require (
github.com/mitchellh/hashstructure v1.1.0 // indirect
github.com/pierrec/lz4/v4 v4.1.6 // indirect
github.com/pkg/errors v0.9.1 // indirect
- github.com/pkg/profile v1.5.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.13.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/sergi/go-diff v1.1.0 // indirect
github.com/silvasur/buzhash v0.0.0-20160816060738-9bdec3dec7c6 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect
github.com/tealeg/xlsx v1.0.5 // indirect
github.com/tetratelabs/wazero v1.1.0 // indirect
- github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
@@ -95,18 +101,15 @@ require (
github.com/zeebo/xxh3 v1.0.2 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/otel v1.7.0 // indirect
- go.opentelemetry.io/otel/exporters/jaeger v1.7.0 // indirect
- go.opentelemetry.io/otel/sdk v1.7.0 // indirect
go.opentelemetry.io/otel/trace v1.7.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
go.uber.org/zap v1.24.0 // indirect
golang.org/x/crypto v0.11.0 // indirect
- golang.org/x/mod v0.8.0 // indirect
+ golang.org/x/mod v0.9.0 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/oauth2 v0.8.0 // indirect
golang.org/x/sync v0.2.0 // indirect
- golang.org/x/sys v0.10.0 // indirect
golang.org/x/term v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect
golang.org/x/time v0.1.0 // indirect
@@ -122,4 +125,5 @@ require (
gopkg.in/square/go-jose.v2 v2.5.1 // indirect
gopkg.in/src-d/go-errors.v1 v1.0.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index bcc6a9ceac..7862b7877a 100644
--- a/go.sum
+++ b/go.sum
@@ -157,16 +157,14 @@ github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbj
github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8=
github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
-github.com/dolthub/dolt/go v0.40.5-0.20230911135842-43e82ba2510c h1:k0RWniTdbLoRpH9gCs0j0qC7kbTcUaH0S0etmEf8JMk=
-github.com/dolthub/dolt/go v0.40.5-0.20230911135842-43e82ba2510c/go.mod h1:eYa1zEjLxMs5ZvYEDYiqbpAVdJ2NfnUXSnZVwyy9fUQ=
github.com/dolthub/dolt/go v0.40.5-0.20230911151544-6c18c2cbfebe h1:Bo+kXeBK1CvrPKKfVydbXKbu5Mt8HM+GH8lZv2d+0A4=
github.com/dolthub/dolt/go v0.40.5-0.20230911151544-6c18c2cbfebe/go.mod h1:eYa1zEjLxMs5ZvYEDYiqbpAVdJ2NfnUXSnZVwyy9fUQ=
-github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20201005193433-3ee972b1d078 h1:nrkoh/RcgTq5EsWTcbSBF8KQghCtM+1dhyslghbBoj8=
-github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20201005193433-3ee972b1d078/go.mod h1:8Jdiq6CVg8HM4n9fF17sGgXUpFa98zDyscW0A7OQmuM=
-github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911135842-43e82ba2510c h1:XsYJr7EG3FD00FqYtllQyT0ycG2dGp6rXlXts+eTELA=
-github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911135842-43e82ba2510c/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU=
+github.com/dolthub/dolt/go v0.40.5-0.20230917024726-1ae3a49864ac h1:JE93q8eaYCujLQmpa3JibkqFLY1CXSfPdm6w351sHC8=
+github.com/dolthub/dolt/go v0.40.5-0.20230917024726-1ae3a49864ac/go.mod h1:VSSW98p69Lsd/dy3/aySvlJxmJW2g4yjvNWIXfddSEM=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911151544-6c18c2cbfebe h1:la317evV7J8bzeKfYdmpBo6/OIwsSO9QxrOPYeWoY6A=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911151544-6c18c2cbfebe/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU=
+github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230917024726-1ae3a49864ac h1:zQtagkisqzs3QWtGxOlpayQI01wL9pYgfu8Zn3asMdU=
+github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230917024726-1ae3a49864ac/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
@@ -175,6 +173,8 @@ github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
github.com/dolthub/go-mysql-server v0.17.1-0.20230911134248-fb77f4090d9b h1:4b+gOvFMSZ7YhiB8VPFWhyPYZcWBGALvaxT1FGLmP34=
github.com/dolthub/go-mysql-server v0.17.1-0.20230911134248-fb77f4090d9b/go.mod h1:vSQ47leaIPTtvSLKo89D1FdYdypU5OH6VBV63B2MS8Y=
+github.com/dolthub/go-mysql-server v0.17.1-0.20230916212652-86f1cdf0339c h1:jjJ7mV0X7LVAkNGgDBtKAnPlsZYoeHFF2vSPB7NH/r4=
+github.com/dolthub/go-mysql-server v0.17.1-0.20230916212652-86f1cdf0339c/go.mod h1:s4hz5TZpFkw8IdLVz58EKKm9cxvRd4jVw1Hs2o+tEXw=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ=
@@ -185,6 +185,8 @@ github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc=
github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ=
github.com/dolthub/vitess v0.0.0-20230823204737-4a21a94e90c3 h1:lY3oQbYNMSVjT02n6f2M2H0u4icF6lGbS/IpWr27ti8=
github.com/dolthub/vitess v0.0.0-20230823204737-4a21a94e90c3/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
+github.com/dolthub/vitess v0.0.0-20230915082726-ef1b92774b14 h1:w0MwDEgWWei13UqZeS6mlSPUQ3iMHMDlrU+MUKQtR4s=
+github.com/dolthub/vitess v0.0.0-20230915082726-ef1b92774b14/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
@@ -373,6 +375,12 @@ github.com/iancoleman/strcase v0.1.3/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5N
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo=
+github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
+github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
+github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
+github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jcmturner/gofork v0.0.0-20180107083740-2aebee971930/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
@@ -415,8 +423,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -515,8 +523,6 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA=
-github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug=
-github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18=
github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI=
github.com/pkg/sftp v1.13.0/go.mod h1:41g+FIPlQUTDCveupEmEA65IoiQFrtgCeDopC4ajGIM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -561,6 +567,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
+github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
+github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@@ -597,7 +605,6 @@ github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5J
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
@@ -610,6 +617,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
+github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/tealeg/xlsx v1.0.5 h1:+f8oFmvY8Gw1iUXzPk+kz+4GpbDZPK1FhPiQRd+ypgE=
github.com/tealeg/xlsx v1.0.5/go.mod h1:btRS8dz54TDnvKNosuAqxrM1QgN1udgk9O34bDCnORM=
github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ=
@@ -661,10 +669,6 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
go.opentelemetry.io/otel v1.7.0 h1:Z2lA3Tdch0iDcrhJXDIlC94XE+bxok1F9B+4Lz/lGsM=
go.opentelemetry.io/otel v1.7.0/go.mod h1:5BdUoMIz5WEs0vt0CUEMtSSaTSHBBVwrhnz7+nrD5xk=
-go.opentelemetry.io/otel/exporters/jaeger v1.7.0 h1:wXgjiRldljksZkZrldGVe6XrG9u3kYDyQmkZwmm5dI0=
-go.opentelemetry.io/otel/exporters/jaeger v1.7.0/go.mod h1:PwQAOqBgqbLQRKlj466DuD2qyMjbtcPpfPfj+AqbSBs=
-go.opentelemetry.io/otel/sdk v1.7.0 h1:4OmStpcKVOfvDOgCt7UriAPtKolwIhxpnSNI/yK+1B0=
-go.opentelemetry.io/otel/sdk v1.7.0/go.mod h1:uTEOTwaqIVuTGiJN7ii13Ibp75wJmYUDe374q6cZwUU=
go.opentelemetry.io/otel/trace v1.7.0 h1:O37Iogk1lEkMRXewVtZ1BBTVn5JEp8GrJvP92bJqC6o=
go.opentelemetry.io/otel/trace v1.7.0/go.mod h1:fzLSB9nqR2eXzxPXb2JW9IKE+ScyXA48yyE4TNvoHqU=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
@@ -738,8 +742,8 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
-golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
-golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
+golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -772,7 +776,6 @@ golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
-golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
@@ -848,7 +851,6 @@ golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20200620081246-981b61492c35/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200828194041-157a740278f4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -857,7 +859,6 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -1010,7 +1011,6 @@ google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfG
google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA=
-google.golang.org/genproto v0.0.0-20200622133129-d0ee0c36e670/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no=
diff --git a/main.go b/main.go
index a0bd6cd998..624bda67f7 100644
--- a/main.go
+++ b/main.go
@@ -20,12 +20,8 @@ import (
"encoding/binary"
"fmt"
"math/rand"
- "net/http"
_ "net/http/pprof"
"os"
- "os/exec"
- "strconv"
- "time"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/commands"
@@ -35,7 +31,6 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
- "github.com/dolthub/dolt/go/libraries/events"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
@@ -44,13 +39,7 @@ import (
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/fatih/color"
- "github.com/pkg/profile"
"github.com/tidwall/gjson"
- "go.opentelemetry.io/otel"
- "go.opentelemetry.io/otel/exporters/jaeger"
- "go.opentelemetry.io/otel/sdk/resource"
- tracesdk "go.opentelemetry.io/otel/sdk/trace"
- semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
"github.com/dolthub/doltgresql/postgres"
)
@@ -66,44 +55,26 @@ var doltCommand = cli.NewSubCommandHandler("doltgresql", "it's git for data", []
sqlserver.SqlServerCmd{VersionStr: Version},
})
var globalArgParser = cli.CreateGlobalArgParser("doltgresql")
-var globalDocs = cli.CommandDocsForCommandString("doltgresql", doc, globalArgParser)
-
-var globalSpecialMsg = `
-Dolt subcommands are in transition to using the flags listed below as global flags.
-Not all subcommands use these flags. If your command accepts these flags without error, then they are supported.
-`
func init() {
server.DefaultProtocolListenerFunc = postgres.NewListenerWithConfig
- sqlserver.DoltgreSQLDisableUsers = true
+ sqlserver.ExternalDisableUsers = true
dfunctions.VersionString = Version
}
-const pprofServerFlag = "--pprof-server"
const chdirFlag = "--chdir"
-const jaegerFlag = "--jaeger"
-const profFlag = "--prof"
-const csMetricsFlag = "--csmetrics"
const stdInFlag = "--stdin"
const stdOutFlag = "--stdout"
const stdErrFlag = "--stderr"
const stdOutAndErrFlag = "--out-and-err"
const ignoreLocksFlag = "--ignore-lock-file"
-const verboseEngineSetupFlag = "--verbose-engine-setup"
-
-const cpuProf = "cpu"
-const memProf = "mem"
-const blockingProf = "blocking"
-const traceProf = "trace"
-
-const featureVersionFlag = "--feature-version"
func main() {
- os.Exit(runMain())
+ os.Exit(RunMain(os.Args[1:]))
}
-func runMain() int {
- args := os.Args[1:]
+func RunMain(args []string) int {
+ ctx := context.Background()
// Inject the "sql-server" command
args = append([]string{"sql-server"}, args...)
// Enforce a default port of 5432
@@ -113,102 +84,15 @@ func runMain() int {
}
}
- start := time.Now()
-
- if len(args) == 0 {
- doltCommand.PrintUsage("dolt")
- return 1
- }
-
if os.Getenv("DOLT_VERBOSE_ASSERT_TABLE_FILES_CLOSED") == "" {
nbs.TableIndexGCFinalizerWithStackTrace = false
}
- csMetrics := false
ignoreLockFile := false
- verboseEngineSetup := false
if len(args) > 0 {
var doneDebugFlags bool
for !doneDebugFlags && len(args) > 0 {
switch args[0] {
- case profFlag:
- switch args[1] {
- case cpuProf:
- cli.Println("cpu profiling enabled.")
- defer profile.Start(profile.CPUProfile, profile.NoShutdownHook).Stop()
- case memProf:
- cli.Println("mem profiling enabled.")
- defer profile.Start(profile.MemProfile, profile.NoShutdownHook).Stop()
- case blockingProf:
- cli.Println("block profiling enabled")
- defer profile.Start(profile.BlockProfile, profile.NoShutdownHook).Stop()
- case traceProf:
- cli.Println("trace profiling enabled")
- defer profile.Start(profile.TraceProfile, profile.NoShutdownHook).Stop()
- default:
- panic("Unexpected prof flag: " + args[1])
- }
- args = args[2:]
-
- case pprofServerFlag:
- // serve the pprof endpoints setup in the init function run when "net/http/pprof" is imported
- go func() {
- cyanStar := color.CyanString("*")
- cli.Println(cyanStar, "Starting pprof server on port 6060.")
- cli.Println(cyanStar, "Go to", color.CyanString("http://localhost:6060/debug/pprof"), "in a browser to see supported endpoints.")
- cli.Println(cyanStar)
- cli.Println(cyanStar, "Known endpoints are:")
- cli.Println(cyanStar, " /allocs: A sampling of all past memory allocations")
- cli.Println(cyanStar, " /block: Stack traces that led to blocking on synchronization primitives")
- cli.Println(cyanStar, " /cmdline: The command line invocation of the current program")
- cli.Println(cyanStar, " /goroutine: Stack traces of all current goroutines")
- cli.Println(cyanStar, " /heap: A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.")
- cli.Println(cyanStar, " /mutex: Stack traces of holders of contended mutexes")
- cli.Println(cyanStar, " /profile: CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.")
- cli.Println(cyanStar, " /threadcreate: Stack traces that led to the creation of new OS threads")
- cli.Println(cyanStar, " /trace: A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.")
- cli.Println()
-
- err := http.ListenAndServe("0.0.0.0:6060", nil)
-
- if err != nil {
- cli.Println(color.YellowString("pprof server exited with error: %v", err))
- }
- }()
- args = args[1:]
-
- // Enable a global jaeger tracer for this run of Dolt,
- // emitting traces to a collector running at
- // localhost:14268. To visualize these traces, run:
- // docker run -d --name jaeger \
- // -e COLLECTOR_ZIPKIN_HTTP_PORT=9411 \
- // -p 5775:5775/udp \
- // -p 6831:6831/udp \
- // -p 6832:6832/udp \
- // -p 5778:5778 \
- // -p 16686:16686 \
- // -p 14268:14268 \
- // -p 14250:14250 \
- // -p 9411:9411 \
- // jaegertracing/all-in-one:1.21
- // and browse to http://localhost:16686
- case jaegerFlag:
- cli.Println("running with jaeger tracing reporting to localhost")
- exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint("http://localhost:14268/api/traces")))
- if err != nil {
- cli.Println(color.YellowString("could not create jaeger collector: %v", err))
- } else {
- tp := tracesdk.NewTracerProvider(
- tracesdk.WithBatcher(exp),
- tracesdk.WithResource(resource.NewWithAttributes(
- semconv.SchemaURL,
- semconv.ServiceNameKey.String("dolt"),
- )),
- )
- otel.SetTracerProvider(tp)
- defer tp.Shutdown(context.Background())
- args = args[1:]
- }
// Currently goland doesn't support running with a different working directory when using go modules.
// This is a hack that allows a different working directory to be set after the application starts using
// chdir=
. The syntax is not flexible and must match exactly this.
@@ -259,33 +143,10 @@ func runMain() int {
color.NoColor = true
args = args[2:]
- case csMetricsFlag:
- csMetrics = true
- args = args[1:]
-
case ignoreLocksFlag:
ignoreLockFile = true
args = args[1:]
- case featureVersionFlag:
- var err error
- if len(args) == 0 {
- err = fmt.Errorf("missing argument for the --feature-version flag")
- } else {
- if featureVersion, err := strconv.Atoi(args[1]); err == nil {
- doltdb.DoltFeatureVersion = doltdb.FeatureVersion(featureVersion)
- }
- }
- if err != nil {
- cli.PrintErrln(err.Error())
- return 1
- }
-
- args = args[2:]
-
- case verboseEngineSetupFlag:
- verboseEngineSetup = true
- args = args[1:]
default:
doneDebugFlags = true
}
@@ -299,24 +160,11 @@ func runMain() int {
warnIfMaxFilesTooLow()
- ctx := context.Background()
- if ok, exit := interceptSendMetrics(ctx, args); ok {
- return exit
- }
-
- _, usage := cli.HelpAndUsagePrinters(globalDocs)
-
var fs filesys.Filesys
fs = filesys.LocalFS
dEnv := env.Load(ctx, env.GetCurrentUserHomeDir, fs, doltdb.LocalDirDoltDB, Version)
dEnv.IgnoreLockFile = ignoreLockFile
- root, err := env.GetCurrentUserHomeDir()
- if err != nil {
- cli.PrintErrln(color.RedString("Failed to load the HOME directory: %v", err))
- return 1
- }
-
globalConfig, ok := dEnv.Config.GetConfig(env.GlobalConfig)
if !ok {
cli.PrintErrln(color.RedString("Failed to get global config"))
@@ -325,10 +173,8 @@ func runMain() int {
apr, remainingArgs, subcommandName, err := parseGlobalArgsAndSubCommandName(globalConfig, args)
if err == argparser.ErrHelp {
+ //TODO: display some help message
doltCommand.PrintUsage("dolt")
- cli.Println(globalSpecialMsg)
- usage()
-
return 0
} else if err != nil {
cli.PrintErrln(color.RedString("Failure to parse arguments: %v", err))
@@ -354,35 +200,7 @@ func runMain() int {
return 1
}
- emitter := events.NewFileEmitter(root, dbfactory.DoltDir)
-
- defer func() {
- ces := events.GlobalCollector.Close()
- // events.WriterEmitter{cli.CliOut}.LogEvents(Version, ces)
-
- metricsDisabled := dEnv.Config.GetStringOrDefault(env.MetricsDisabled, "false")
-
- disabled, err := strconv.ParseBool(metricsDisabled)
- if err != nil {
- // log.Print(err)
- return
- }
-
- if disabled {
- return
- }
-
- // write events
- _ = emitter.LogEvents(Version, ces)
-
- // flush events
- if err := processEventsDir(args, dEnv); err != nil {
- // log.Print(err)
- }
- }()
-
err = reconfigIfTempFileMoveFails(dEnv)
-
if err != nil {
cli.PrintErrln(color.RedString("Failed to setup the temporary directory. %v`", err))
return 1
@@ -443,7 +261,7 @@ func runMain() int {
return 1
}
- lateBind, err := buildLateBinder(ctx, cwdFS, dEnv, mrEnv, creds, apr, subcommandName, verboseEngineSetup)
+ lateBind, err := buildLateBinder(ctx, cwdFS, dEnv, mrEnv, creds, apr, subcommandName, false)
if err != nil {
cli.PrintErrln(color.RedString("%v", err))
@@ -471,12 +289,6 @@ func runMain() int {
}
}
- if csMetrics && dEnv.DoltDB != nil {
- metricsSummary := dEnv.DoltDB.CSMetricsSummary()
- cli.Println("Command took", time.Since(start).Seconds())
- cli.PrintErrln(metricsSummary)
- }
-
return res
}
@@ -571,17 +383,6 @@ func buildLateBinder(ctx context.Context, cwdFS filesys.Filesys, rootEnv *env.Do
return commands.BuildSqlEngineQueryist(ctx, cwdFS, mrEnv, creds, apr)
}
-// doc is currently used only when a `initCliContext` command is specified. This will include all commands in time,
-// otherwise you only see these docs if you specify a nonsense argument before the `sql` subcommand.
-var doc = cli.CommandDocumentationContent{
- ShortDesc: "Dolt is git for data",
- LongDesc: `Dolt comprises of multiple subcommands that allow users to import, export, update, and manipulate data with SQL.`,
-
- Synopsis: []string{
- "<--data-dir=> subcommand ",
- },
-}
-
func seedGlobalRand() {
bs := make([]byte, 8)
_, err := crand.Read(bs)
@@ -591,42 +392,6 @@ func seedGlobalRand() {
rand.Seed(int64(binary.LittleEndian.Uint64(bs)))
}
-// processEventsDir runs the dolt send-metrics command in a new process
-func processEventsDir(args []string, dEnv *env.DoltEnv) error {
- if len(args) > 0 {
- ignoreCommands := map[string]struct{}{
- commands.SendMetricsCommand: {},
- "init": {},
- "config": {},
- }
-
- _, ok := ignoreCommands[args[0]]
-
- if ok {
- return nil
- }
-
- cmd := exec.Command("dolt", commands.SendMetricsCommand)
-
- if err := cmd.Start(); err != nil {
- // log.Print(err)
- return err
- }
-
- return nil
- }
-
- return nil
-}
-
-func interceptSendMetrics(ctx context.Context, args []string) (bool, int) {
- if len(args) < 1 || args[0] != commands.SendMetricsCommand {
- return false, 0
- }
- dEnv := env.LoadWithoutDB(ctx, env.GetCurrentUserHomeDir, filesys.LocalFS, Version)
- return true, doltCommand.Exec(ctx, "dolt", args, dEnv, nil)
-}
-
// parseGlobalArgsAndSubCommandName parses the global arguments, including a profile if given or a default profile if exists. Also returns the subcommand name.
func parseGlobalArgsAndSubCommandName(globalConfig config.ReadWriteConfig, args []string) (apr *argparser.ArgParseResults, remaining []string, subcommandName string, err error) {
apr, remaining, err = globalArgParser.ParseGlobalArgs(args)
diff --git a/main_test.go b/main_test.go
new file mode 100644
index 0000000000..8e98678aff
--- /dev/null
+++ b/main_test.go
@@ -0,0 +1,59 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "testing"
+
+ "github.com/jackc/pgx/v5"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBasicConnection(t *testing.T) {
+ //TODO: fix me
+ return
+ //port := getEmptyPort(t)
+ //go RunMain([]string{fmt.Sprintf("--port=%d", port)})
+ //port := 5431
+ port := 5432
+
+ ctx := context.Background()
+ conn, err := pgx.Connect(ctx, fmt.Sprintf("postgres://postgres:password@localhost:%d/postgres", port))
+ require.NoError(t, err)
+ defer conn.Close(ctx)
+
+ func() {
+ //rows, err := conn.Query(ctx, "CREATE DATABASE testdb;")
+ rows, err := conn.Query(ctx, "SELECT * FROM test;")
+ require.NoError(t, err)
+ defer rows.Close()
+ for rows.Next() {
+ row, err := rows.Values()
+ require.NoError(t, err)
+ row = row
+ }
+ }()
+}
+
+func getEmptyPort(t *testing.T) int {
+ listener, err := net.Listen("tcp", ":0")
+ require.NoError(t, err)
+ port := listener.Addr().(*net.TCPAddr).Port
+ require.NoError(t, listener.Close())
+ return port
+}
diff --git a/postgres/listener.go b/postgres/listener.go
index f15ddf521b..5acf2ebb41 100644
--- a/postgres/listener.go
+++ b/postgres/listener.go
@@ -93,9 +93,9 @@ func (l *Listener) HandleConnection(conn net.Conn) {
return
}
- if _, err = conn.Write(messages.SSLResponse{
+ if err = messages.Send(conn, messages.SSLResponse{
SupportsSSL: false,
- }.Bytes()); err != nil {
+ }); err != nil {
fmt.Println(err)
return
}
@@ -107,43 +107,43 @@ func (l *Listener) HandleConnection(conn net.Conn) {
}
return
}
- startupMessage, err := messages.ReadStartupMessage(buf)
+ startupMessage, err := messages.ReceiveInto(buf, messages.StartupMessage{})
if err != nil {
fmt.Println(err)
return
}
- if _, err = conn.Write(messages.AuthenticationOk{}.Bytes()); err != nil {
+ if err = messages.Send(conn, messages.AuthenticationOk{}); err != nil {
fmt.Println(err)
return
}
- if _, err = conn.Write(messages.ParameterStatus{
+ if err = messages.Send(conn, messages.ParameterStatus{
Name: "server_version",
Value: "15.0",
- }.Bytes()); err != nil {
+ }); err != nil {
fmt.Println(err)
return
}
- if _, err = conn.Write(messages.ParameterStatus{
+ if err = messages.Send(conn, messages.ParameterStatus{
Name: "client_encoding",
Value: "UTF8",
- }.Bytes()); err != nil {
+ }); err != nil {
fmt.Println(err)
return
}
- if _, err = conn.Write(messages.BackendKeyData{
+ if err = messages.Send(conn, messages.BackendKeyData{
ProcessID: 1,
SecretKey: 0,
- }.Bytes()); err != nil {
+ }); err != nil {
fmt.Println(err)
return
}
- if _, err = conn.Write(messages.ReadyForQuery{
+ if err = messages.Send(conn, messages.ReadyForQuery{
Indicator: messages.ReadyForQueryTransactionIndicator_Idle,
- }.Bytes()); err != nil {
+ }); err != nil {
fmt.Println(err)
return
}
@@ -162,62 +162,65 @@ func (l *Listener) HandleConnection(conn net.Conn) {
return
}
- if messages.ReadTerminate(buf) {
- return
- }
- query, ok := messages.ReadQuery(buf)
- if !ok {
- fmt.Println("unknown message, terminating connection")
- return
- }
- commandCompleteTag, err := messages.QueryToCommandCompleteTag(query)
+ message, ok, err := messages.Receive(buf)
if err != nil {
fmt.Println(err.Error())
return
+ } else if !ok {
+ fmt.Println("unknown message format, terminating connection")
+ return
}
- var rowTotal int32
- if err = l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error {
- rowDescription, err := messages.NewRowDescription(res.Fields)
- if err != nil {
- return err
- }
- if _, err = conn.Write(rowDescription.Bytes()); err != nil {
- return err
- }
+ switch message := message.(type) {
+ case messages.Terminate:
+ return
+ case messages.Query:
+ l.query(conn, mysqlConn, message.String)
+ }
+ }
+}
- for _, row := range res.Rows {
- if _, err = conn.Write(messages.NewDataRow(row).Bytes()); err != nil {
- return err
- }
- }
+func (l *Listener) query(conn net.Conn, mysqlConn *mysql.Conn, query string) {
+ commandComplete := messages.CommandComplete{
+ Query: query,
+ Rows: 0,
+ }
- if commandCompleteTag == messages.CommandCompleteTag_INSERT ||
- commandCompleteTag == messages.CommandCompleteTag_UPDATE ||
- commandCompleteTag == messages.CommandCompleteTag_DELETE {
- rowTotal = int32(res.RowsAffected)
- } else {
- rowTotal += int32(len(res.Rows))
- }
- return nil
+ if err := l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error {
+ if err := messages.Send(conn, messages.RowDescription{
+ Fields: res.Fields,
}); err != nil {
- fmt.Println(err.Error())
- return
+ return err
}
- if _, err = conn.Write(messages.CommandComplete{
- Tag: commandCompleteTag,
- Rows: rowTotal,
- }.Bytes()); err != nil {
- fmt.Println(err)
- return
+ for _, row := range res.Rows {
+ if err := messages.Send(conn, messages.DataRow{
+ Values: row,
+ }); err != nil {
+ return err
+ }
}
- if _, err = conn.Write(messages.ReadyForQuery{
- Indicator: messages.ReadyForQueryTransactionIndicator_Idle,
- }.Bytes()); err != nil {
- fmt.Println(err)
- return
+ if commandComplete.IsIUD() {
+ commandComplete.Rows = int32(res.RowsAffected)
+ } else {
+ commandComplete.Rows += int32(len(res.Rows))
}
+ return nil
+ }); err != nil {
+ fmt.Println(err.Error())
+ return
+ }
+
+ if err := messages.Send(conn, commandComplete); err != nil {
+ fmt.Println(err)
+ return
+ }
+
+ if err := messages.Send(conn, messages.ReadyForQuery{
+ Indicator: messages.ReadyForQueryTransactionIndicator_Idle,
+ }); err != nil {
+ fmt.Println(err)
+ return
}
}
diff --git a/postgres/messages/authentication_ok.go b/postgres/messages/authentication_ok.go
index 148d6b8687..73631d3ff5 100644
--- a/postgres/messages/authentication_ok.go
+++ b/postgres/messages/authentication_ok.go
@@ -1,13 +1,65 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
// AuthenticationOk tells the client that authentication was successful.
type AuthenticationOk struct{}
-// Bytes returns AuthenticationOk as a byte slice, ready to be returned to the client.
-func (aok AuthenticationOk) Bytes() []byte {
- return []byte{
- 'R', // Message Type
- 0, 0, 0, 8, // Message Length
- 0, 0, 0, 0, // Padding
+func init() {
+ initializeDefaultMessage(AuthenticationOk{})
+}
+
+var authenticationOkDefault = Message{
+ Name: "AuthenticationOk",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('R'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(8),
+ },
+ {
+ Name: "Status",
+ Type: Int32,
+ Data: int32(0),
+ },
+ },
+}
+
+var _ MessageType = AuthenticationOk{}
+
+// encode implements the interface MessageType.
+func (m AuthenticationOk) encode() (Message, error) {
+ return m.defaultMessage().Copy(), nil
+}
+
+// decode implements the interface MessageType.
+func (m AuthenticationOk) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
+ return AuthenticationOk{}, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m AuthenticationOk) defaultMessage() *Message {
+ return &authenticationOkDefault
}
diff --git a/postgres/messages/backend_key_data.go b/postgres/messages/backend_key_data.go
index 4e800ec9ff..ff1ba7cfe2 100644
--- a/postgres/messages/backend_key_data.go
+++ b/postgres/messages/backend_key_data.go
@@ -1,6 +1,23 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
-import "bytes"
+func init() {
+ initializeDefaultMessage(BackendKeyData{})
+ addMessageHeader(BackendKeyData{})
+}
// BackendKeyData provides the client with information about the server.
type BackendKeyData struct {
@@ -8,12 +25,56 @@ type BackendKeyData struct {
SecretKey int32
}
-// Bytes returns BackendKeyData as a byte slice, ready to be returned to the client.
-func (bkd BackendKeyData) Bytes() []byte {
- buf := bytes.Buffer{}
- buf.WriteByte('K') // Message Type
- WriteNumber(&buf, int32(12)) // Message Length
- WriteNumber(&buf, bkd.ProcessID)
- WriteNumber(&buf, bkd.SecretKey)
- return buf.Bytes()
+var backendKeyDataDefault = Message{
+ Name: "BackendKeyData",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('K'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(12),
+ },
+ {
+ Name: "ProcessID",
+ Type: Int32,
+ Data: int32(0),
+ },
+ {
+ Name: "SecretKey",
+ Type: Int32,
+ Data: int32(0),
+ },
+ },
+}
+
+var _ MessageType = BackendKeyData{}
+
+// encode implements the interface MessageType.
+func (m BackendKeyData) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ outputMessage.Field("ProcessID").MustWrite(m.ProcessID)
+ outputMessage.Field("SecretKey").MustWrite(m.SecretKey)
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m BackendKeyData) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
+ }
+ return BackendKeyData{
+ ProcessID: s.Field("ProcessID").MustGet().(int32),
+ SecretKey: s.Field("SecretKey").MustGet().(int32),
+ }, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m BackendKeyData) defaultMessage() *Message {
+ return &backendKeyDataDefault
}
diff --git a/postgres/messages/command_complete.go b/postgres/messages/command_complete.go
index de6c339546..c4127e739f 100644
--- a/postgres/messages/command_complete.go
+++ b/postgres/messages/command_complete.go
@@ -1,76 +1,112 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
import (
- "bytes"
"fmt"
"strconv"
"strings"
)
-// CommandCompleteTag indicates which SQL command was completed.
-type CommandCompleteTag byte
-
-const (
- CommandCompleteTag_INSERT CommandCompleteTag = iota
- CommandCompleteTag_DELETE
- CommandCompleteTag_UPDATE
- CommandCompleteTag_MERGE
- CommandCompleteTag_SELECT
- CommandCompleteTag_MOVE
- CommandCompleteTag_FETCH
- CommandCompleteTag_COPY
-)
+func init() {
+ initializeDefaultMessage(CommandComplete{})
+}
// CommandComplete tells the client that the command has completed.
type CommandComplete struct {
- Tag CommandCompleteTag
- Rows int32
+ Query string
+ Rows int32
}
-// Bytes returns CommandComplete as a byte slice, ready to be returned to the client.
-func (cc CommandComplete) Bytes() []byte {
- buf := bytes.Buffer{}
- buf.WriteByte('C') // Message Type
- WriteNumber(&buf, int32(0)) // Message length, will be corrected later
- switch cc.Tag {
- case CommandCompleteTag_INSERT:
- buf.WriteString("INSERT 0 ")
- case CommandCompleteTag_DELETE:
- buf.WriteString("DELETE ")
- case CommandCompleteTag_UPDATE:
- buf.WriteString("UPDATE ")
- case CommandCompleteTag_MERGE:
- buf.WriteString("MERGE ")
- case CommandCompleteTag_SELECT:
- buf.WriteString("SELECT ")
- case CommandCompleteTag_MOVE:
- buf.WriteString("MOVE ")
- case CommandCompleteTag_FETCH:
- buf.WriteString("FETCH ")
- case CommandCompleteTag_COPY:
- buf.WriteString("COPY ")
+var commandCompleteDefault = Message{
+ Name: "CommandComplete",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('C'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ {
+ Name: "CommandTag",
+ Type: String,
+ Data: "",
+ },
+ },
+}
+
+var _ MessageType = CommandComplete{}
+
+// IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query.
+func (m CommandComplete) IsIUD() bool {
+ query := strings.TrimSpace(strings.ToLower(m.Query))
+ if strings.HasPrefix(query, "insert") ||
+ strings.HasPrefix(query, "update") ||
+ strings.HasPrefix(query, "delete") {
+ return true
+ } else {
+ return false
}
- buf.WriteString(strconv.Itoa(int(cc.Rows)))
- buf.WriteByte(0) // Trailing NULL character, denoting the end of the string
- return WriteLength(buf.Bytes())
}
-// QueryToCommandCompleteTag returns the appropriate command tag for the given query.
-func QueryToCommandCompleteTag(query string) (CommandCompleteTag, error) {
- query = strings.TrimSpace(strings.ToLower(query))
+// encode implements the interface MessageType.
+func (m CommandComplete) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ query := strings.TrimSpace(strings.ToLower(m.Query))
if strings.HasPrefix(query, "select") {
- return CommandCompleteTag_SELECT, nil
+ outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows))
} else if strings.HasPrefix(query, "insert") {
- return CommandCompleteTag_INSERT, nil
+ outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("INSERT 0 %d", m.Rows))
} else if strings.HasPrefix(query, "update") {
- return CommandCompleteTag_UPDATE, nil
+ outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("UPDATE %d", m.Rows))
} else if strings.HasPrefix(query, "delete") {
- return CommandCompleteTag_DELETE, nil
+ outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("DELETE %d", m.Rows))
} else if strings.HasPrefix(query, "create") {
- return CommandCompleteTag_SELECT, nil
+ outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows))
} else if strings.HasPrefix(query, "call") {
- return CommandCompleteTag_SELECT, nil
+ outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows))
} else {
- return 0, fmt.Errorf("unsupported query for now")
+ return Message{}, fmt.Errorf("unsupported query for now")
+ }
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m CommandComplete) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
+ query := strings.TrimSpace(s.Field("CommandTag").MustGet().(string))
+ tokens := strings.Split(query, " ")
+ rows, err := strconv.Atoi(tokens[len(tokens)-1])
+ if err != nil {
+ return nil, err
+ }
+ return CommandComplete{
+ Query: query,
+ Rows: int32(rows),
+ }, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m CommandComplete) defaultMessage() *Message {
+ return &commandCompleteDefault
}
diff --git a/postgres/messages/data_row.go b/postgres/messages/data_row.go
index 2ddddd9e00..13128929c8 100644
--- a/postgres/messages/data_row.go
+++ b/postgres/messages/data_row.go
@@ -1,103 +1,102 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
import (
- "bytes"
- "fmt"
- "strconv"
- "time"
-
"github.com/dolthub/vitess/go/sqltypes"
-
- "github.com/shopspring/decimal"
)
+func init() {
+ initializeDefaultMessage(DataRow{})
+}
+
// DataRow represents a row of data.
type DataRow struct {
- Values []DataRowValue
+ Values []sqltypes.Value
}
-// DataRowValue represents a column's value in a DataRow.
-type DataRowValue struct {
- Value any
+var dataRowDefault = Message{
+ Name: "DataRow",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('D'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ {
+ Name: "Columns",
+ Type: Int16,
+ Data: int32(0),
+ Children: [][]*Field{
+ {
+ {
+ Name: "ColumnLength",
+ Type: Int32,
+ Tags: ByteCount,
+ Data: int32(0),
+ },
+ {
+ Name: "ColumnData",
+ Type: ByteN,
+ Data: []byte{},
+ },
+ },
+ },
+ },
+ },
}
-// NewDataRow creates a new DataRow from the given rows.
-func NewDataRow(row []sqltypes.Value) DataRow {
- values := make([]DataRowValue, len(row))
- for i, value := range row {
- values[i] = DataRowValue{value.ToString()}
- }
- return DataRow{
- Values: values,
+var _ MessageType = DataRow{}
+
+// encode implements the interface MessageType.
+func (m DataRow) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ for i := 0; i < len(m.Values); i++ {
+ if m.Values[i].IsNull() {
+ outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(-1)
+ } else {
+ value := []byte(m.Values[i].ToString())
+ outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(len(value))
+ outputMessage.Field("Columns").Child("ColumnData", i).MustWrite(value)
+ }
}
+ return outputMessage, nil
}
-// Bytes returns DataRow as a byte slice, ready to be returned to the client.
-func (dr DataRow) Bytes() []byte {
- buf := bytes.Buffer{}
- buf.WriteByte('D') // Message Type
- WriteNumber(&buf, int32(0)) // Message length, will be corrected later
- WriteNumber(&buf, int16(len(dr.Values)))
- for _, drv := range dr.Values {
- drv.Bytes(&buf)
+// decode implements the interface MessageType.
+func (m DataRow) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
- return WriteLength(buf.Bytes())
+ columnCount := int(s.Field("Columns").MustGet().(int32))
+ for i := 0; i < columnCount; i++ {
+ //TODO: decode the message in here
+ }
+ return DataRow{
+ Values: nil,
+ }, nil
}
-// Bytes writes the value into the given buffer.
-func (drv DataRowValue) Bytes(buf *bytes.Buffer) {
- var dataBytes []byte
- switch val := drv.Value.(type) {
- case int:
- dataBytes = []byte(strconv.FormatInt(int64(val), 10))
- case int8:
- dataBytes = []byte(strconv.FormatInt(int64(val), 10))
- case int16:
- dataBytes = []byte(strconv.FormatInt(int64(val), 10))
- case int32:
- dataBytes = []byte(strconv.FormatInt(int64(val), 10))
- case int64:
- dataBytes = []byte(strconv.FormatInt(val, 10))
- case uint:
- dataBytes = []byte(strconv.FormatUint(uint64(val), 10))
- case uint8:
- dataBytes = []byte(strconv.FormatUint(uint64(val), 10))
- case uint16:
- dataBytes = []byte(strconv.FormatUint(uint64(val), 10))
- case uint32:
- dataBytes = []byte(strconv.FormatUint(uint64(val), 10))
- case uint64:
- dataBytes = []byte(strconv.FormatUint(val, 10))
- case float32:
- dataBytes = []byte(strconv.FormatFloat(float64(val), 'g', -1, 32))
- case float64:
- dataBytes = []byte(strconv.FormatFloat(val, 'g', -1, 64))
- case decimal.NullDecimal:
- if !val.Valid {
- WriteNumber(buf, int32(-1))
- return
- }
- dataBytes = []byte(val.Decimal.String())
- case decimal.Decimal:
- dataBytes = []byte(val.String())
- case []byte:
- dataBytes = val
- case string:
- dataBytes = []byte(val)
- case bool:
- if val {
- dataBytes = []byte("true")
- } else {
- dataBytes = []byte("false")
- }
- case time.Time:
- dataBytes = []byte(val.Format(time.RFC3339))
- case nil:
- WriteNumber(buf, int32(-1))
- return
- default:
- panic(fmt.Errorf("unknown DataRow value type: %T", val))
- }
- WriteNumber(buf, int32(len(dataBytes))) // This length only covers the value's size
- buf.Write(dataBytes)
+// defaultMessage implements the interface MessageType.
+func (m DataRow) defaultMessage() *Message {
+ return &dataRowDefault
}
diff --git a/postgres/messages/message.go b/postgres/messages/message.go
new file mode 100644
index 0000000000..6e78f45049
--- /dev/null
+++ b/postgres/messages/message.go
@@ -0,0 +1,88 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package messages
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Message is a message as defined by PostgreSQL.
+// https://www.postgresql.org/docs/15/protocol-message-formats.html
+type Message struct {
+ Name string
+ Fields []*Field
+ info *messageInfo
+ isDefault bool
+}
+
+// MessageType is a type that represents a PostgreSQL message.
+type MessageType interface {
+ // encode returns a new Message containing any modified data contained within the object. This should NOT be
+ // the default message.
+ encode() (Message, error)
+ // decode returns a new MessageType that represents the given Message. You should never return the default
+ // message, even if the message never varies from the default. Always make a copy, and then modify that copy.
+ decode(s Message) (MessageType, error)
+ // defaultMessage returns the default, unmodified message for this type.
+ defaultMessage() *Message
+}
+
+// messageFieldInfo contains information on a specific field within a messageInfo.
+type messageFieldInfo struct {
+ RelativeIndex int
+ Parent string
+ UsesByteCount bool // Only used by ByteN fields
+}
+
+// messageInfo contains all of the information that a message should keep track of. Used internally by messages.
+type messageInfo struct {
+ fieldInfo map[string]messageFieldInfo
+ appendNullByte bool
+ defaultMessage *Message
+}
+
+// Copy returns a copy of the Message, which is free to modify.
+func (m Message) Copy() Message {
+ newFields := make([]*Field, len(m.Fields))
+ for i, field := range m.Fields {
+ newFields[i] = field.Copy()
+ }
+ return Message{m.Name, newFields, m.info, false}
+}
+
+// String returns a printable version of the Message.
+func (m Message) String() string {
+ buffer := strings.Builder{}
+ buffer.WriteString(fmt.Sprintf("%s: {\n", m.Name))
+ buffer.WriteString("\n") //TODO: print this
+ buffer.WriteString("}")
+ return buffer.String()
+}
+
+// MatchesStructure returns an error if the given Message has a different structure than the calling Message.
+func (m Message) MatchesStructure(otherMessage Message) error {
+ //TODO: check this
+ return nil
+}
+
+// Field returns a MessageWriter for the calling Message, which makes it easier (and safer) to update the field whose
+// name was given.
+func (m Message) Field(name string) MessageWriter {
+ return MessageWriter{
+ message: m,
+ fieldQueue: []messageWriterChildPosition{{name, 0}},
+ }
+}
diff --git a/postgres/messages/message_decode_encode.go b/postgres/messages/message_decode_encode.go
new file mode 100644
index 0000000000..a26bddb365
--- /dev/null
+++ b/postgres/messages/message_decode_encode.go
@@ -0,0 +1,295 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package messages
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net"
+)
+
+// Receive returns a MessageType from the given buffer, generally generated by the client in the main read loop of a
+// connection.
+func Receive(buffer []byte) (MessageType, bool, error) {
+ if len(buffer) == 0 {
+ return nil, false, nil
+ }
+ message, ok := allMessageHeaders[buffer[0]]
+ if !ok {
+ return nil, false, nil
+ }
+ outMessage, err := ReceiveInto(buffer, message)
+ return outMessage, true, err
+}
+
+// ReceiveInto writes the contents of the buffer into the given MessageType.
+func ReceiveInto[T MessageType](buffer []byte, message T) (out T, err error) {
+ defaultMessage := message.defaultMessage()
+ fields := defaultMessage.Copy().Fields
+ if err = decode(&decodeBuffer{buffer}, [][]*Field{fields}, 1); err != nil {
+ return out, err
+ }
+ decodedMessage, err := message.decode(Message{defaultMessage.Name, fields, defaultMessage.info, false})
+ if err != nil {
+ return out, err
+ }
+ return decodedMessage.(T), nil
+}
+
+// Send sends the given message over the connection.
+func Send(conn net.Conn, message MessageType) error {
+ encodedMessage, err := message.encode()
+ if err != nil {
+ return err
+ }
+ data, err := encode(encodedMessage)
+ if err != nil {
+ return err
+ }
+ _, err = conn.Write(data)
+ return err
+}
+
+// decodeBuffer just provides an easy way to reference the same buffer, so that decode can modify its length.
+type decodeBuffer struct {
+ data []byte
+}
+
+// decode writes the contents of the buffer into the given fields. The iteration count determines how many times the
+// fields will be looped over.
+func decode(buffer *decodeBuffer, fields [][]*Field, iterations int32) error {
+ for iteration := int32(0); iteration < iterations; iteration++ {
+ for i, field := range fields[iteration] {
+ if len(buffer.data) == 0 {
+ return errors.New("buffer too small")
+ }
+ switch field.Type {
+ case Byte1, Int8:
+ field.Data = int32(buffer.data[0])
+ buffer.data = buffer.data[1:]
+ case ByteN:
+ if i > 0 && fields[iteration][i-1].Tags&ByteCount > 0 {
+ byteCount := fields[iteration][i-1].Data.(int32)
+ // -1 is a valid value for byte counts, which is used to signal a NULL value.
+ // We don't need to care about the assumption, so we can just treat it equivalent to zero.
+ if byteCount == -1 {
+ byteCount = 0
+ }
+ data := make([]byte, byteCount)
+ copy(data, buffer.data)
+ field.Data = data
+ buffer.data = buffer.data[byteCount:]
+ } else {
+ data := make([]byte, len(buffer.data))
+ copy(data, buffer.data)
+ field.Data = data
+ buffer.data = nil
+ }
+ case Int16:
+ field.Data = int32(binary.BigEndian.Uint16(buffer.data))
+ buffer.data = buffer.data[2:]
+ case Int32:
+ field.Data = int32(binary.BigEndian.Uint32(buffer.data))
+ buffer.data = buffer.data[4:]
+ case String:
+ found := false
+ for bufferIdx := range buffer.data {
+ if buffer.data[bufferIdx] == 0 {
+ field.Data = string(buffer.data[:bufferIdx])
+ buffer.data = buffer.data[bufferIdx:]
+ if field.Tags&ExcludeTerminator == 0 {
+ buffer.data = buffer.data[1:]
+ }
+ found = true
+ break
+ }
+ }
+ if !found {
+ return errors.New("terminating zero not found for string")
+ }
+ case Repeated:
+ // Track if we've decoded at least once, so that we only update the count if we've decoded something
+ decodedAtLeastOnce := false
+ originalChildren := field.Copy().Children[0]
+ for i := 1; len(buffer.data) > 0; i++ {
+ // If there is only a single byte left, then it may be the terminator, so we check.
+ // Otherwise, we'll assume that we should pass it to the child.
+ if len(buffer.data) == 1 && field.Tags&RepeatedTerminator > 0 {
+ if buffer.data[0] == 0 {
+ buffer.data = buffer.data[1:]
+ break
+ } else {
+ return fmt.Errorf("Expected terminator after Repeated type, found invalid byte: %d", buffer.data[0])
+ }
+ }
+ field.extend(i, originalChildren)
+ if err := decode(buffer, field.Children[len(field.Children)-1:], 1); err != nil {
+ return err
+ }
+ decodedAtLeastOnce = true
+ }
+ if decodedAtLeastOnce {
+ field.Data = int32(len(field.Children))
+ }
+ default:
+ panic("message type has not been defined")
+ }
+
+ if field.Tags&MessageLengthInclusive > 0 {
+ messageLength := field.Data.(int32)
+ switch field.Type {
+ case Byte1, Int8:
+ messageLength -= 1
+ case Int16:
+ messageLength -= 2
+ case Int32:
+ messageLength -= 4
+ }
+ buffer.data = buffer.data[:messageLength]
+ } else if field.Tags&MessageLengthExclusive > 0 {
+ buffer.data = buffer.data[:field.Data.(int32)]
+ }
+ if len(field.Children) > 0 && field.Type != Repeated {
+ count, ok := field.Data.(int32)
+ if !ok {
+ return errors.New("non-integer is being used as a count")
+ }
+ // Counts may be negative numbers, which have a special meaning depending on the message.
+ // In all such cases, they'll never have children, so we can just check for cases where it's > 0.
+ if count > 0 {
+ field.extend(int(count), field.Children[0])
+ if err := decode(buffer, field.Children, count); err != nil {
+ return err
+ }
+ }
+ }
+ }
+ }
+ return nil
+}
+
+// encode transforms the message into a byte slice, which may be sent to a connection.
+func encode(ms Message) ([]byte, error) {
+ buffer := bytes.Buffer{}
+ encodeLoop(&buffer, [][]*Field{ms.Fields}, 1)
+ if ms.info.appendNullByte {
+ buffer.WriteByte(0)
+ }
+ data := buffer.Bytes()
+
+ // Find and write the message length
+ byteOffset := int32(0)
+ for i, field := range ms.Fields {
+ if field.Tags&(MessageLengthInclusive|MessageLengthExclusive) > 0 {
+ typeLength := int32(0)
+ // Exclusive lengths must take their own type size into account and exclude them from the overall length
+ if field.Tags&MessageLengthExclusive > 0 {
+ switch field.Type {
+ case Byte1, Int8:
+ typeLength = 1
+ case Int16:
+ typeLength = 2
+ case Int32:
+ typeLength = 4
+ }
+ }
+ messageLength := int32(len(data)) - byteOffset - typeLength
+ switch field.Type {
+ case Byte1, Int8:
+ data[byteOffset] = byte(messageLength)
+ case Int16:
+ binary.BigEndian.PutUint16(data[byteOffset:], uint16(messageLength))
+ case Int32:
+ binary.BigEndian.PutUint32(data[byteOffset:], uint32(messageLength))
+ default:
+ panic("invalid type for message length")
+ }
+ break
+ }
+
+ // Advance the offset
+ switch field.Type {
+ case Byte1, Int8:
+ byteOffset += 1
+ case ByteN:
+ if i > 0 && ms.Fields[i-1].Tags&ByteCount > 0 {
+ byteOffset += ms.Fields[i-1].Data.(int32)
+ } else {
+ byteOffset = int32(len(data)) // Last field, so we can set it to the remaining data
+ }
+ case Int16:
+ byteOffset += 2
+ case Int32:
+ byteOffset += 4
+ case String:
+ found := false
+ for bufferIdx := range data[byteOffset:] {
+ if data[bufferIdx] == 0 {
+ found = true
+ byteOffset += int32(bufferIdx)
+ //TODO: is this the correct place to put this? investigate/test
+ if field.Tags&ExcludeTerminator == 0 {
+ byteOffset += 1
+ }
+ break
+ }
+ }
+ if !found {
+ panic("terminating zero not found for string")
+ }
+ case Repeated:
+ byteOffset = int32(len(data)) // Last field, so we can set it to the remaining data
+ default:
+ panic("message type has not been defined")
+ }
+ }
+ return data, nil
+}
+
+// encodeLoop is the inner recursive loop of encode, which writes the given fields into the buffer. The iteration
+// count determines how many times the fields are looped over.
+func encodeLoop(buffer *bytes.Buffer, fields [][]*Field, iterations int32) {
+ for iteration := int32(0); iteration < iterations; iteration++ {
+ for _, field := range fields[iteration] {
+ switch field.Type {
+ case Byte1:
+ _ = binary.Write(buffer, binary.BigEndian, byte(field.Data.(int32)))
+ case ByteN:
+ buffer.Write(field.Data.([]byte))
+ case Int8:
+ _ = binary.Write(buffer, binary.BigEndian, int8(field.Data.(int32)))
+ case Int16:
+ _ = binary.Write(buffer, binary.BigEndian, int16(field.Data.(int32)))
+ case Int32:
+ _ = binary.Write(buffer, binary.BigEndian, field.Data.(int32))
+ case String:
+ buffer.WriteString(field.Data.(string))
+ if field.Tags&ExcludeTerminator == 0 {
+ buffer.WriteByte(0)
+ }
+ case Repeated:
+ // We don't write anything for repeated fields, since they repeat their children until the end
+ default:
+ panic("message type has not been defined")
+ }
+
+ if len(field.Children) > 0 {
+ encodeLoop(buffer, field.Children, field.Data.(int32))
+ }
+ }
+ }
+}
diff --git a/postgres/messages/message_field.go b/postgres/messages/message_field.go
new file mode 100644
index 0000000000..a45938c1c7
--- /dev/null
+++ b/postgres/messages/message_field.go
@@ -0,0 +1,81 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package messages
+
+// FieldType is the type of the field as defined by PostgreSQL.
+type FieldType byte
+
+const (
+ Byte1 FieldType = iota // Byte1 is a single unsigned byte.
+ ByteN // ByteN is a variable number of bytes. Allowed on the last field, or when a ByteCount-tagged field precedes it.
+ Int8 // Int8 is a single signed byte.
+ Int16 // Int16 are two bytes.
+ Int32 // Int32 are four bytes.
+ String // String is a variable-length type, generally punctuated by a NULL terminator.
+ Repeated // Repeated is a parent type that states its children will be repeated until the end of the message.
+
+ Byte4 = Int32 //TODO: verify that this is correct, only used on one type
+)
+
+// FieldTags are special attributes that may be assigned to fields.
+type FieldTags int32
+
+const (
+ Header FieldTags = 1 << iota // Header is the header of the message.
+ MessageLengthInclusive // MessageLengthInclusive is the length of the message, including the count's size.
+ MessageLengthExclusive // MessageLengthExclusive is the length of the message, excluding the count's size.
+ ExcludeTerminator // ExcludeTerminator excludes the terminator for String types.
+ ByteCount // ByteCount signals that the following ByteN non-child field uses this field for its count.
+ RepeatedTerminator // RepeatedTerminator states that the Repeated type always ends with a NULL terminator.
+)
+
+// Field is a field within the PostgreSQL message.
+type Field struct {
+ Name string
+ Type FieldType
+ Tags FieldTags
+ Data any // Data may ONLY be one of the following: int32, string, []byte. Nil is not allowed.
+ Children [][]*Field
+}
+
+// Copy returns a copy of this field, which is free to modify.
+func (f *Field) Copy() *Field {
+ fieldCopy := *f
+ if len(f.Children) > 0 {
+ newChildren := make([][]*Field, len(f.Children))
+ for groupIndex, fieldGroup := range f.Children {
+ newFields := make([]*Field, len(fieldGroup))
+ for fieldIndex, field := range fieldGroup {
+ newFields[fieldIndex] = field.Copy()
+ }
+ newChildren[groupIndex] = newFields
+ }
+ fieldCopy.Children = newChildren
+ }
+ return &fieldCopy
+}
+
+// extend lengthens the children to the new length, using the given default children to fill each newly-created entry.
+// All new entries will be copied from the default, therefore they're free to modify. This modifies the calling Field
+// in-place.
+func (f *Field) extend(newLength int, defaultChildren []*Field) {
+ for currentIndex := len(f.Children); currentIndex < newLength; currentIndex++ {
+ newFields := make([]*Field, len(defaultChildren))
+ for i, field := range defaultChildren {
+ newFields[i] = field.Copy()
+ }
+ f.Children = append(f.Children, newFields)
+ }
+}
diff --git a/postgres/messages/message_initialization.go b/postgres/messages/message_initialization.go
new file mode 100644
index 0000000000..900fa98b72
--- /dev/null
+++ b/postgres/messages/message_initialization.go
@@ -0,0 +1,206 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package messages
+
+import "fmt"
+
+// allMessageHeaders contains any message headers that should be read within the main read loop of a connection.
+var allMessageHeaders = make(map[byte]MessageType)
+
+// allMessageNames contains the names of all messages, as they should all be unique.
+var allMessageNames = make(map[string]struct{})
+
+// allMessageDefaults contains all of the default message pointers, to make sure that they're not accidentally being reused.
+var allMessageDefaults = make(map[*Message]struct{})
+
+// addMessageHeader adds the given MessageType's header. This also ensures that each header is unique. This should be
+// called in an init() function.
+func addMessageHeader(message MessageType) {
+ for _, field := range message.defaultMessage().Fields {
+ if field.Tags&Header > 0 {
+ header := byte(field.Data.(int32))
+ if _, ok := allMessageHeaders[header]; ok {
+ panic(fmt.Errorf("Header already taken.\nMessage:\n\n%s", message.defaultMessage().String()))
+ }
+ allMessageHeaders[header] = message
+ return
+ }
+ }
+ panic(fmt.Errorf("Header does not exist.\nMessage:\n\n%s", message.defaultMessage().String()))
+}
+
+// initializeDefaultMessage creates the internal structure of the default message, while ensuring that the structure of
+// the message is correct. This should be called in an init() function.
+func initializeDefaultMessage(messageType MessageType) {
+ message := messageType.defaultMessage()
+ if _, ok := allMessageDefaults[message]; ok {
+ panic(fmt.Errorf("Message default was used in another message.\nMessage:\n\n%s", message.String()))
+ }
+ allMessageDefaults[message] = struct{}{}
+ if message.info != nil {
+ panic(fmt.Errorf("Message has already been initialized.\nMessage:\n\n%s", message.String()))
+ }
+ if _, ok := allMessageNames[message.Name]; ok {
+ panic(fmt.Errorf("Message has already been initialized with the same name.\nName: %s", message.Name))
+ }
+ allMessageNames[message.Name] = struct{}{}
+ message.info = &messageInfo{make(map[string]messageFieldInfo), false, message}
+ message.isDefault = true
+
+ allFieldNames := make(map[string]struct{}) // Verify that all field names are unique
+ headerFound := false // Only one header may exist in the message
+ messageLengthFound := false // Only one message length may exist in the message
+ endingByteNFound := false // If a ByteN has been found that does not have a preceding ByteCount
+ repeatedFoundHeight := 0 // The depth that a Repeated type has been found
+ type FieldTraversal struct {
+ Index int
+ Fields []*Field
+ }
+
+ ftStack := NewStack[FieldTraversal]()
+ ftStack.Push(FieldTraversal{0, message.Fields})
+ for !ftStack.Empty() {
+ // If we're at the end of the loop for this stacked entry, then we pop it and move to the next
+ if ftStack.Peek().Index >= len(ftStack.Peek().Fields) {
+ _ = ftStack.Pop()
+ continue
+ }
+ // Check if we've found a ByteN that is not preceded by a ByteCount-tagged field, as it should be the last
+ // field, and we're now looking at a field after it.
+ if endingByteNFound {
+ panic(fmt.Errorf("ByteN found that was not preceded by a field with the ByteCount tag.\nMessage:\n\n%s", message.String()))
+ }
+ // If the stack is larger than Repeated's height, then we're probably in Repeated's children.
+ // Otherwise, there are more non-child fields after the Repeated type.
+ if ftStack.Len() <= repeatedFoundHeight {
+ panic(fmt.Errorf("Repeated is not on the last field at its level\nMessage:\n\n%s", message.String()))
+ }
+ // Grab the field.
+ field := ftStack.Peek().Fields[ftStack.Peek().Index]
+ // Verify uniqueness and correctness of tags (if any)
+ if field.Tags&Header > 0 {
+ if headerFound {
+ panic(fmt.Errorf("Multiple headers in message.\nMessage:\n\n%s", message.String()))
+ }
+ headerFound = true
+ }
+ if field.Tags&(MessageLengthInclusive|MessageLengthExclusive) > 0 {
+ if messageLengthFound {
+ panic(fmt.Errorf("Multiple message lengths in message.\nMessage:\n\n%s", message.String()))
+ }
+ switch field.Type {
+ case Byte1, Int8, Int16, Int32:
+ default:
+ panic(fmt.Errorf("Message length tags are only allowed on integer types.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ messageLengthFound = true
+ }
+ if field.Tags&ByteCount > 0 {
+ switch field.Type {
+ case Byte1, Int8, Int16, Int32:
+ default:
+ panic(fmt.Errorf("ByteCount tag is only allowed on integer types.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ }
+ if field.Tags&ExcludeTerminator > 0 && field.Type != String {
+ panic(fmt.Errorf("ExcludeTerminator tag is only allowed on String fields.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ // Verify uniqueness of names (case-sensitive for maximum flexibility)
+ if len(field.Name) == 0 {
+ panic(fmt.Errorf("All fields must have a name.\nMessage:\n\n%s", message.String()))
+ }
+ if _, ok := allFieldNames[field.Name]; ok {
+ panic(fmt.Errorf("Multiple fields with the same name.\nMessage:\n\n%s", message.String()))
+ }
+ allFieldNames[field.Name] = struct{}{}
+ // Verify that ByteN is the last field, or is preceded by a field with the ByteCount tag
+ usesByteCount := false
+ if field.Type == ByteN {
+ // If the preceding field has the ByteCount tag, then ByteN does not have the ending-field-only restriction
+ if ftStack.Peek().Index > 0 && (ftStack.Peek().Fields[ftStack.Peek().Index-1].Tags&ByteCount > 0) {
+ usesByteCount = true
+ } else {
+ endingByteNFound = true
+ }
+ }
+ // Verify the type for each default value
+ switch field.Type {
+ case Byte1, Int8, Int16, Int32, Repeated:
+ if _, ok := field.Data.(int32); !ok {
+ panic(fmt.Errorf("Integer field types must set their Data to an int32 value.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ case ByteN:
+ if _, ok := field.Data.([]byte); !ok {
+ panic(fmt.Errorf("ByteN fields must set their Data to a []byte value.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ case String:
+ if _, ok := field.Data.(string); !ok {
+ panic(fmt.Errorf("String fields must set their Data to a string value.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ default:
+ panic("message type has not been defined")
+ }
+ // Verify that, for fields with children, the default count matches the default child count
+ if len(field.Children) > 0 {
+ count := int32(0)
+ switch field.Type {
+ case Byte1, Int8, Int16, Int32, Repeated:
+ count = field.Data.(int32)
+ default:
+ panic(fmt.Errorf("Only integer types may have children, as they determine the count.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ // A value of zero means that the child is only used as a prototype. A value of one means that the child is
+ // actually used as a default value. We do not allow declaring children with multiple default values.
+ if count != 0 && count != 1 {
+ panic(fmt.Errorf("Only integer types may have children, as they determine the count.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ if len(field.Children) > 1 {
+ panic(fmt.Errorf("Only a single child may be declared in the default message.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ }
+ // Repeated may only be on a single field. Children of a Repeated field cannot also have Repeated children.
+ if field.Type == Repeated {
+ if repeatedFoundHeight != 0 {
+ panic(fmt.Errorf("Multiple Repeated types declared.\nField: %s\nMessage:\n\n%s", field.Name, message.String()))
+ }
+ repeatedFoundHeight = ftStack.Len()
+ }
+ // RepeatedTerminator is only allowed on Repeated types, and therefore follows all of its restrictions automatically.
+ if field.Tags&RepeatedTerminator > 0 {
+ if field.Type != Repeated {
+ panic(fmt.Errorf("RepeatedTerminator may only be used on a Repeated type.\nMessage:\n\n%s", message.String()))
+ }
+ message.info.appendNullByte = true
+ }
+
+ // Write the field info into our message
+ parentName := ""
+ if ftStack.Len() > 1 {
+ parentName = ftStack.PeekDepth(1).Fields[ftStack.PeekDepth(1).Index-1].Name
+ }
+ message.info.fieldInfo[field.Name] = messageFieldInfo{
+ RelativeIndex: ftStack.Peek().Index,
+ Parent: parentName,
+ UsesByteCount: usesByteCount,
+ }
+
+ // Increment the index
+ ftStack.PeekReference().Index++
+ // If there are any children, then we throw them onto the stack
+ if len(field.Children) == 1 {
+ ftStack.Push(FieldTraversal{0, field.Children[0]})
+ }
+ }
+}
diff --git a/postgres/messages/message_writer.go b/postgres/messages/message_writer.go
new file mode 100644
index 0000000000..20ebf72b96
--- /dev/null
+++ b/postgres/messages/message_writer.go
@@ -0,0 +1,167 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package messages
+
+import "fmt"
+
+// MessageWriter is used to easily (and safely) interact with the contents of a Message.
+type MessageWriter struct {
+ message Message
+ fieldQueue []messageWriterChildPosition
+}
+
+// messageWriterChildPosition contains the name and position of a queued entry in a MessageWriter.
+type messageWriterChildPosition struct {
+ name string
+ position int
+}
+
+// Child returns a new MessageWriter pointing to the child of the field chain provided so far. As there may be multiple
+// children, the position determines which child will be referenced. If a child position is given that does not exist,
+// then a child at that position will be created. If the position is more than an increment, then this will also create
+// all children up to the position, by giving them their default values.
+func (mw MessageWriter) Child(name string, position int) MessageWriter {
+ fieldQueue := make([]messageWriterChildPosition, len(mw.fieldQueue)+1)
+ copy(fieldQueue, mw.fieldQueue)
+ fieldQueue[len(fieldQueue)-1] = messageWriterChildPosition{name, position}
+ return MessageWriter{
+ message: mw.message,
+ fieldQueue: fieldQueue,
+ }
+}
+
+// Write writes the given value to the field pointed to by the field chain provided. Only accepts values with the
+// following types: int/8/16/32/64, uint/8/16/32/64, string, []byte (use an empty slice instead of nil).
+func (mw MessageWriter) Write(value any) error {
+ if mw.message.isDefault {
+ return fmt.Errorf("Cannot write to the default message: %s", mw.message.Name)
+ }
+
+ var field *Field
+ var defaultField *Field
+ if fieldInfo, ok := mw.message.info.fieldInfo[mw.fieldQueue[0].name]; ok {
+ field = mw.message.Fields[fieldInfo.RelativeIndex]
+ defaultField = mw.message.info.defaultMessage.Fields[fieldInfo.RelativeIndex]
+ } else {
+ return fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, mw.fieldQueue[0].name)
+ }
+ fq := mw.fieldQueue[1:]
+ for len(fq) > 0 {
+ fieldInfo, ok := mw.message.info.fieldInfo[fq[0].name]
+ if !ok {
+ return fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, fq[0].name)
+ }
+ if fieldInfo.Parent != field.Name {
+ return fmt.Errorf(`In the message "%s", the field "%s"" is not a child of the field "%s"`,
+ mw.message.Name, fq[0].name, field.Name)
+ }
+ field.extend(fq[0].position+1, defaultField.Children[0]) // extend() takes the length, so add 1 to the position
+ field.Data = int32(len(field.Children)) // All types that have children are integer types
+ field = field.Children[fq[0].position][fieldInfo.RelativeIndex]
+ defaultField = defaultField.Children[0][fieldInfo.RelativeIndex]
+ // Remove the child from the queue
+ fq = fq[1:]
+ }
+
+ switch field.Type {
+ case Byte1, Int8, Int16, Int32, Repeated:
+ switch value := value.(type) {
+ case int:
+ field.Data = int32(value)
+ case int8:
+ field.Data = int32(value)
+ case int16:
+ field.Data = int32(value)
+ case int32:
+ field.Data = value
+ case int64:
+ field.Data = int32(value)
+ case uint:
+ field.Data = int32(value)
+ case uint8:
+ field.Data = int32(value)
+ case uint16:
+ field.Data = int32(value)
+ case uint32:
+ field.Data = int32(value)
+ case uint64:
+ field.Data = int32(value)
+ default:
+ return fmt.Errorf("Attempted to write an invalid value of type `%T` into the following integer field: %s", value, field.Name)
+ }
+ case ByteN:
+ switch value := value.(type) {
+ case []byte:
+ field.Data = value
+ default:
+ return fmt.Errorf("Attempted to write an invalid value of type `%T` into the following ByteN field: %s", value, field.Name)
+ }
+ case String:
+ switch value := value.(type) {
+ case string:
+ field.Data = value
+ default:
+ return fmt.Errorf("Attempted to write an invalid value of type `%T` into the following String field: %s", value, field.Name)
+ }
+ default:
+ panic("message type has not been defined")
+ }
+ return nil
+}
+
+// Get returns the value of the field pointed to by the field chain provided.
+func (mw MessageWriter) Get() (any, error) {
+ var field *Field
+ if fieldInfo, ok := mw.message.info.fieldInfo[mw.fieldQueue[0].name]; ok {
+ field = mw.message.Fields[fieldInfo.RelativeIndex]
+ } else {
+ return nil, fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, mw.fieldQueue[0].name)
+ }
+ fq := mw.fieldQueue[1:]
+ for len(fq) > 0 {
+ fieldInfo, ok := mw.message.info.fieldInfo[fq[0].name]
+ if !ok {
+ return nil, fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, fq[0].name)
+ }
+ if fieldInfo.Parent != field.Name {
+ return nil, fmt.Errorf(`In the message "%s", the field "%s" is not a child of the field "%s"`,
+ mw.message.Name, fq[0].name, field.Name)
+ }
+ if fq[0].position >= len(field.Children) {
+ return nil, fmt.Errorf("Index out of bounds.\nMessage: %s\nField: %s\nIndex: %d\nLength: %d",
+ mw.message.Name, field.Name, fq[0].position, len(field.Children))
+ }
+ field = field.Children[fq[0].position][fieldInfo.RelativeIndex]
+ // Remove the child from the queue
+ fq = fq[1:]
+ }
+ return field.Data, nil
+}
+
+// MustWrite is the same as Write, except that this panics on errors rather than returning them.
+func (mw MessageWriter) MustWrite(value any) {
+ if err := mw.Write(value); err != nil {
+ panic(err)
+ }
+}
+
+// MustGet is the same as Get, except that this panics on errors rather than returning them.
+func (mw MessageWriter) MustGet() any {
+ value, err := mw.Get()
+ if err != nil {
+ panic(err)
+ }
+ return value
+}
diff --git a/postgres/messages/parameter_status.go b/postgres/messages/parameter_status.go
index 03f076de48..b7dfbc0d65 100644
--- a/postgres/messages/parameter_status.go
+++ b/postgres/messages/parameter_status.go
@@ -1,6 +1,22 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
-import "bytes"
+func init() {
+ initializeDefaultMessage(ParameterStatus{})
+}
// ParameterStatus reports various parameters to the client.
type ParameterStatus struct {
@@ -8,14 +24,56 @@ type ParameterStatus struct {
Value string
}
-// Bytes returns ParameterStatus as a byte slice, ready to be returned to the client.
-func (ps ParameterStatus) Bytes() []byte {
- buf := bytes.Buffer{}
- buf.WriteByte('S') // Message Type
- WriteNumber(&buf, int32(0)) // Message length, will be corrected later
- buf.WriteString(ps.Name)
- buf.WriteByte(0) // Trailing NULL character, denoting the end of the string
- buf.WriteString(ps.Value)
- buf.WriteByte(0) // Trailing NULL character, denoting the end of the string
- return WriteLength(buf.Bytes())
+var parameterStatusDefault = Message{
+ Name: "ParameterStatus",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('S'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ {
+ Name: "Name",
+ Type: String,
+ Data: "",
+ },
+ {
+ Name: "Value",
+ Type: String,
+ Data: "",
+ },
+ },
+}
+
+var _ MessageType = ParameterStatus{}
+
+// encode implements the interface MessageType.
+func (m ParameterStatus) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ outputMessage.Field("Name").MustWrite(m.Name)
+ outputMessage.Field("Value").MustWrite(m.Value)
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m ParameterStatus) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
+ }
+ return ParameterStatus{
+ Name: s.Field("Name").MustGet().(string),
+ Value: s.Field("Value").MustGet().(string),
+ }, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m ParameterStatus) defaultMessage() *Message {
+ return ¶meterStatusDefault
}
diff --git a/postgres/messages/query.go b/postgres/messages/query.go
index 6d3238e31f..6a8e7b5e69 100644
--- a/postgres/messages/query.go
+++ b/postgres/messages/query.go
@@ -1,24 +1,72 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
-import (
- "encoding/binary"
-)
+func init() {
+ initializeDefaultMessage(Query{})
+ addMessageHeader(Query{})
+}
-// ReadQuery returns the query from the given buffer. Assumes that the buffer contains a serialized form of a Query
-// message.
-func ReadQuery(buf []byte) (string, bool) {
- if len(buf) < 5 {
- return "", false
- }
- if buf[0] != 'Q' {
- return "", false
- }
- queryLength := int32(binary.BigEndian.Uint32(buf[1:]))
- if queryLength <= 5 {
- // A query of length 5 or less is empty
- return "", true
+// Query contains a query given by the client.
+type Query struct {
+ String string
+}
+
+var queryDefault = Message{
+ Name: "Query",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('Q'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ {
+ Name: "String",
+ Type: String,
+ Data: "",
+ },
+ },
+}
+
+var _ MessageType = Query{}
+
+// encode implements the interface MessageType.
+func (m Query) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ outputMessage.Field("String").MustWrite(m.String)
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m Query) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
- // The length includes the length bytes, along with the NULL terminator. It does not include the message identifier
- // though, so it cancels out the NULL terminator and allows us to use the length as-is.
- return string(buf[5:queryLength]), true
+ return Query{
+ String: s.Field("String").MustGet().(string),
+ }, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m Query) defaultMessage() *Message {
+ return &queryDefault
}
diff --git a/postgres/messages/ready_for_query.go b/postgres/messages/ready_for_query.go
index 5fa9db449f..d5a4922168 100644
--- a/postgres/messages/ready_for_query.go
+++ b/postgres/messages/ready_for_query.go
@@ -1,12 +1,30 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
+func init() {
+ initializeDefaultMessage(ReadyForQuery{})
+}
+
// ReadyForQueryTransactionIndicator indicates the state of the transaction related to the query.
type ReadyForQueryTransactionIndicator byte
const (
- ReadyForQueryTransactionIndicator_Idle ReadyForQueryTransactionIndicator = iota
- ReadyForQueryTransactionIndicator_TransactionBlock
- ReadyForQueryTransactionIndicator_FailedTransactionBlock
+ ReadyForQueryTransactionIndicator_Idle ReadyForQueryTransactionIndicator = 'I'
+ ReadyForQueryTransactionIndicator_TransactionBlock ReadyForQueryTransactionIndicator = 'T'
+ ReadyForQueryTransactionIndicator_FailedTransactionBlock ReadyForQueryTransactionIndicator = 'E'
)
// ReadyForQuery tells the client that the server is ready for a new query cycle.
@@ -14,16 +32,49 @@ type ReadyForQuery struct {
Indicator ReadyForQueryTransactionIndicator
}
-// Bytes returns ReadyForQuery as a byte slice, ready to be returned to the client.
-func (rfq ReadyForQuery) Bytes() []byte {
- var indicator byte
- switch rfq.Indicator {
- case ReadyForQueryTransactionIndicator_Idle:
- indicator = 'I'
- case ReadyForQueryTransactionIndicator_TransactionBlock:
- indicator = 'T'
- case ReadyForQueryTransactionIndicator_FailedTransactionBlock:
- indicator = 'E'
+var readyForQueryDefault = Message{
+ Name: "ReadyForQuery",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('Z'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(5),
+ },
+ {
+ Name: "TransactionIndicator",
+ Type: Byte1,
+ Data: int32(0),
+ },
+ },
+}
+
+var _ MessageType = ReadyForQuery{}
+
+// encode implements the interface MessageType.
+func (m ReadyForQuery) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ outputMessage.Field("TransactionIndicator").MustWrite(byte(m.Indicator))
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m ReadyForQuery) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
- return []byte{'Z', 0, 0, 0, 5, indicator}
+ return ReadyForQuery{
+ Indicator: ReadyForQueryTransactionIndicator(s.Field("TransactionIndicator").MustGet().(int32)),
+ }, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m ReadyForQuery) defaultMessage() *Message {
+ return &readyForQueryDefault
}
diff --git a/postgres/messages/row_description.go b/postgres/messages/row_description.go
index c1545d84b8..ec290bc6ba 100644
--- a/postgres/messages/row_description.go
+++ b/postgres/messages/row_description.go
@@ -1,81 +1,147 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
import (
- "bytes"
"fmt"
+ "github.com/dolthub/go-mysql-server/sql"
+
"github.com/dolthub/vitess/go/vt/proto/query"
)
+func init() {
+ initializeDefaultMessage(RowDescription{})
+}
+
// RowDescription represents a RowDescription message intended for the client.
type RowDescription struct {
- Fields []RowDescriptionField
+ Fields []*query.Field
}
-// NewRowDescription creates a new RowDescription from the given fields.
-func NewRowDescription(fields []*query.Field) (RowDescription, error) {
- var err error
- rdFields := make([]RowDescriptionField, len(fields))
- for i, field := range fields {
- rdFields[i] = RowDescriptionField{
- TableObjectID: 0, // Unused for now
- ColumnAttributeNumber: 0, // Unused for now
- DataTypeModifier: 0, // Always -1 since we're supporting a narrow set of integers
- FormatCode: 0, // Always text for now
+var rowDescriptionDefault = Message{
+ Name: "RowDescription",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('T'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ {
+ Name: "Fields",
+ Type: Int16,
+ Data: int32(0),
+ Children: [][]*Field{
+ {
+ {
+ Name: "ColumnName",
+ Type: String,
+ Data: "",
+ },
+ {
+ Name: "TableObjectID",
+ Type: Int32,
+ Data: int32(0),
+ },
+ {
+ Name: "ColumnAttributeNumber",
+ Type: Int16,
+ Data: int32(0),
+ },
+ {
+ Name: "DataTypeObjectID",
+ Type: Int32,
+ Data: int32(0),
+ },
+ {
+ Name: "DataTypeSize",
+ Type: Int16,
+ Data: int32(0),
+ },
+ {
+ Name: "DataTypeModifier",
+ Type: Int32,
+ Data: int32(0),
+ },
+ {
+ Name: "FormatCode",
+ Type: Int16,
+ Data: int32(0),
+ },
+ },
+ },
+ },
+ },
+}
+
+var _ MessageType = RowDescription{}
+
+// encode implements the interface MessageType.
+func (m RowDescription) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ for i := 0; i < len(m.Fields); i++ {
+ field := m.Fields[i]
+ dataTypeObjectID, err := VitessFieldToDataTypeObjectID(field)
+ if err != nil {
+ return Message{}, err
}
- rdFields[i].Name = field.Name
- rdFields[i].DataTypeObjectID, err = VitessTypeToDataTypeObjectID(field.Type)
+ dataTypeSize, err := VitessFieldToDataTypeSize(field)
if err != nil {
- return RowDescription{}, err
+ return Message{}, err
}
- rdFields[i].DataTypeSize, err = VitessTypeToDataTypeSize(field.Type)
+ dataTypeModifier, err := VitessFieldToDataTypeModifier(field)
if err != nil {
- return RowDescription{}, err
+ return Message{}, err
}
+ outputMessage.Field("Fields").Child("ColumnName", i).MustWrite(field.Name)
+ outputMessage.Field("Fields").Child("DataTypeObjectID", i).MustWrite(dataTypeObjectID)
+ outputMessage.Field("Fields").Child("DataTypeSize", i).MustWrite(dataTypeSize)
+ outputMessage.Field("Fields").Child("DataTypeModifier", i).MustWrite(dataTypeModifier)
}
- return RowDescription{
- Fields: rdFields,
- }, nil
+ return outputMessage, nil
}
-// RowDescriptionField represents a field in RowDescription.
-type RowDescriptionField struct {
- Name string
- TableObjectID int32
- ColumnAttributeNumber int16
- DataTypeObjectID int32
- DataTypeSize int16
- DataTypeModifier int32
- FormatCode int16
-}
-
-// Bytes returns RowDescription as a byte slice, ready to be returned to the client.
-func (rd RowDescription) Bytes() []byte {
- buf := bytes.Buffer{}
- buf.WriteByte('T') // Message Type
- WriteNumber(&buf, int32(0)) // Message length, will be corrected later
- WriteNumber(&buf, int16(len(rd.Fields)))
- for _, rdf := range rd.Fields {
- rdf.Bytes(&buf)
+// decode implements the interface MessageType.
+func (m RowDescription) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
+ }
+ fieldCount := int(s.Field("Fields").MustGet().(int32))
+ for i := 0; i < fieldCount; i++ {
+ //TODO: decode the message in here
}
- return WriteLength(buf.Bytes())
+ return RowDescription{
+ Fields: nil,
+ }, nil
}
-// Bytes writes the field into the given buffer.
-func (rdf RowDescriptionField) Bytes(buf *bytes.Buffer) {
- buf.WriteString(rdf.Name)
- buf.WriteByte(0) // Trailing NULL character, denoting the end of the string
- WriteNumber(buf, rdf.TableObjectID)
- WriteNumber(buf, rdf.ColumnAttributeNumber)
- WriteNumber(buf, rdf.DataTypeObjectID)
- WriteNumber(buf, rdf.DataTypeSize)
- WriteNumber(buf, rdf.DataTypeModifier)
- WriteNumber(buf, rdf.FormatCode)
+// defaultMessage implements the interface MessageType.
+func (m RowDescription) defaultMessage() *Message {
+ return &rowDescriptionDefault
}
-// VitessTypeToDataTypeObjectID returns a type, as defined by Vitess, into a type as defined by Postgres.
-func VitessTypeToDataTypeObjectID(typ query.Type) (int32, error) {
- switch typ {
+// VitessFieldToDataTypeObjectID returns a type, as defined by Vitess, into a type as defined by Postgres.
+func VitessFieldToDataTypeObjectID(field *query.Field) (int32, error) {
+ switch field.Type {
case query.Type_INT8:
return 17, nil
case query.Type_INT16:
@@ -86,14 +152,20 @@ func VitessTypeToDataTypeObjectID(typ query.Type) (int32, error) {
return 23, nil
case query.Type_INT64:
return 20, nil
+ case query.Type_CHAR:
+ return 18, nil
+ case query.Type_VARCHAR:
+ return 1043, nil
+ case query.Type_TEXT:
+ return 25, nil
default:
return 0, fmt.Errorf("unsupported type returned from engine")
}
}
-// VitessTypeToDataTypeSize returns the type's size, as defined by Vitess, into the size as defined by Postgres.
-func VitessTypeToDataTypeSize(typ query.Type) (int16, error) {
- switch typ {
+// VitessFieldToDataTypeSize returns the type's size, as defined by Vitess, into the size as defined by Postgres.
+func VitessFieldToDataTypeSize(field *query.Field) (int16, error) {
+ switch field.Type {
case query.Type_INT8:
return 1, nil
case query.Type_INT16:
@@ -104,6 +176,38 @@ func VitessTypeToDataTypeSize(typ query.Type) (int16, error) {
return 4, nil
case query.Type_INT64:
return 8, nil
+ case query.Type_CHAR:
+ return -1, nil
+ case query.Type_VARCHAR:
+ return -1, nil
+ case query.Type_TEXT:
+ return -1, nil
+ default:
+ return 0, fmt.Errorf("unsupported type returned from engine")
+ }
+}
+
+// VitessFieldToDataTypeModifier returns the field's data type modifier as defined by Postgres.
+func VitessFieldToDataTypeModifier(field *query.Field) (int32, error) {
+ switch field.Type {
+ case query.Type_INT8:
+ return -1, nil
+ case query.Type_INT16:
+ return -1, nil
+ case query.Type_INT24:
+ return -1, nil
+ case query.Type_INT32:
+ return -1, nil
+ case query.Type_INT64:
+ return -1, nil
+ case query.Type_CHAR:
+ // PostgreSQL adds 4 to the length for an unknown reason
+ return int32(int64(field.ColumnLength)/sql.CharacterSetID(field.Charset).MaxLength()) + 4, nil
+ case query.Type_VARCHAR:
+ // PostgreSQL adds 4 to the length for an unknown reason
+ return int32(int64(field.ColumnLength)/sql.CharacterSetID(field.Charset).MaxLength()) + 4, nil
+ case query.Type_TEXT:
+ return -1, nil
default:
return 0, fmt.Errorf("unsupported type returned from engine")
}
diff --git a/postgres/messages/ssl_response.go b/postgres/messages/ssl_response.go
index 7ee6af9676..a8eac2f45e 100644
--- a/postgres/messages/ssl_response.go
+++ b/postgres/messages/ssl_response.go
@@ -1,15 +1,74 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
+import "fmt"
+
+func init() {
+ initializeDefaultMessage(SSLResponse{})
+}
+
// SSLResponse tells the client whether SSL is supported.
type SSLResponse struct {
SupportsSSL bool
}
-// Bytes returns SSLResponse as a byte slice, ready to be returned to the client.
-func (sslr SSLResponse) Bytes() []byte {
- if sslr.SupportsSSL {
- return []byte{'Y'}
+var sslResponseDefault = Message{
+ Name: "SSLResponse",
+ Fields: []*Field{
+ {
+ Name: "Supported",
+ Type: Byte1,
+ Data: int32(0),
+ },
+ },
+}
+
+var _ MessageType = SSLResponse{}
+
+// encode implements the interface MessageType.
+func (m SSLResponse) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ if m.SupportsSSL {
+ outputMessage.Field("Supported").MustWrite('Y')
+ } else {
+ outputMessage.Field("Supported").MustWrite('N')
+ }
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m SSLResponse) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
+ }
+ var supported bool
+ supportedInt := s.Field("Supported").MustGet().(int32)
+ if supportedInt == 'Y' {
+ supported = true
+ } else if supportedInt == 'N' {
+ supported = false
} else {
- return []byte{'N'}
+ return nil, fmt.Errorf("Unexpected supported value in SSLResponse message: %d", supportedInt)
}
+ return SSLResponse{
+ SupportsSSL: supported,
+ }, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m SSLResponse) defaultMessage() *Message {
+ return &sslResponseDefault
}
diff --git a/postgres/messages/startup_message.go b/postgres/messages/startup_message.go
index 57a05593e2..b1696a03d5 100644
--- a/postgres/messages/startup_message.go
+++ b/postgres/messages/startup_message.go
@@ -1,52 +1,107 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
-import (
- "encoding/binary"
- "fmt"
-)
+func init() {
+ initializeDefaultMessage(StartupMessage{})
+}
// StartupMessage is returned by the client upon connecting to the server, providing details about the client.
type StartupMessage struct {
- ProtocolMajorVersion int16
- ProtocolMinorVersion int16
+ ProtocolMajorVersion int
+ ProtocolMinorVersion int
Parameters map[string]string
}
-// ReadStartupMessage returns the StartupMessage from the buffer.
-func ReadStartupMessage(buf []byte) (StartupMessage, error) {
- if len(buf) < 4 {
- return StartupMessage{}, fmt.Errorf("invalid StartupMessage")
+var startupMessageDefault = Message{
+ Name: "StartupMessage",
+ Fields: []*Field{
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ { // The docs specify a single Int32 field, but the upper and lower bits are different values, so this just splits them
+ Name: "ProtocolMajorVersion",
+ Type: Int16,
+ Data: int32(0),
+ },
+ {
+ Name: "ProtocolMinorVersion",
+ Type: Int16,
+ Data: int32(0),
+ },
+ {
+ Name: "Parameters",
+ Type: Repeated,
+ Tags: RepeatedTerminator,
+ Data: int32(0),
+ Children: [][]*Field{
+ {
+ {
+ Name: "ParameterName",
+ Type: String,
+ Data: "",
+ },
+ {
+ Name: "ParameterValue",
+ Type: String,
+ Data: "",
+ },
+ },
+ },
+ },
+ },
+}
+
+var _ MessageType = StartupMessage{}
+
+// encode implements the interface MessageType.
+func (m StartupMessage) encode() (Message, error) {
+ outputMessage := m.defaultMessage().Copy()
+ outputMessage.Field("ProtocolMajorVersion").MustWrite(m.ProtocolMajorVersion)
+ outputMessage.Field("ProtocolMinorVersion").MustWrite(m.ProtocolMinorVersion)
+ index := 0
+ for name, value := range m.Parameters {
+ outputMessage.Field("Parameters").Child("ParameterName", index).MustWrite(name)
+ outputMessage.Field("Parameters").Child("ParameterValue", index).MustWrite(value)
+ index++
+ }
+ return outputMessage, nil
+}
+
+// decode implements the interface MessageType.
+func (m StartupMessage) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
- messageLength := int32(binary.BigEndian.Uint32(buf))
- protocolMajorVersion := int16(binary.BigEndian.Uint16(buf[4:]))
- protocolMinorVersion := int16(binary.BigEndian.Uint16(buf[6:]))
- // Set the buffer to the stated length and skip the length and version
- buf = buf[8:messageLength]
parameters := make(map[string]string)
- for len(buf) > 0 {
- var name string
- var value string
- for i, b := range buf {
- if b == 0 {
- name = string(buf[:i])
- buf = buf[i+1:]
- break
- }
- }
- for i, b := range buf {
- if b == 0 {
- value = string(buf[:i])
- buf = buf[i+1:]
- break
- }
- }
- if len(name) > 0 && len(value) > 0 {
- parameters[name] = value
- }
+ count := int(s.Field("Parameters").MustGet().(int32))
+ for i := 0; i < count; i++ {
+ parameters[s.Field("Parameters").Child("ParameterName", i).MustGet().(string)] =
+ s.Field("Parameters").Child("ParameterValue", i).MustGet().(string)
}
return StartupMessage{
- ProtocolMajorVersion: protocolMajorVersion,
- ProtocolMinorVersion: protocolMinorVersion,
+ ProtocolMajorVersion: int(s.Field("ProtocolMajorVersion").MustGet().(int32)),
+ ProtocolMinorVersion: int(s.Field("ProtocolMinorVersion").MustGet().(int32)),
Parameters: parameters,
}, nil
}
+
+// defaultMessage implements the interface MessageType.
+func (m StartupMessage) defaultMessage() *Message {
+ return &startupMessageDefault
+}
diff --git a/postgres/messages/terminate.go b/postgres/messages/terminate.go
index d55b11e44f..6966961162 100644
--- a/postgres/messages/terminate.go
+++ b/postgres/messages/terminate.go
@@ -1,9 +1,61 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
-// ReadTerminate returns whether the buffer represents a Terminate message.
-func ReadTerminate(buf []byte) bool {
- if len(buf) < 5 {
- return false
+func init() {
+ initializeDefaultMessage(Terminate{})
+ addMessageHeader(Terminate{})
+}
+
+// Terminate tells the server to close the connection.
+type Terminate struct{}
+
+var terminateDefault = Message{
+ Name: "Terminate",
+ Fields: []*Field{
+ {
+ Name: "Header",
+ Type: Byte1,
+ Tags: Header,
+ Data: int32('X'),
+ },
+ {
+ Name: "MessageLength",
+ Type: Int32,
+ Tags: MessageLengthInclusive,
+ Data: int32(0),
+ },
+ },
+}
+
+var _ MessageType = Terminate{}
+
+// encode implements the interface MessageType.
+func (m Terminate) encode() (Message, error) {
+ return terminateDefault.Copy(), nil
+}
+
+// decode implements the interface MessageType.
+func (m Terminate) decode(s Message) (MessageType, error) {
+ if err := s.MatchesStructure(*m.defaultMessage()); err != nil {
+ return nil, err
}
- return buf[0] == 'X' && buf[4] == 4
+ return Terminate{}, nil
+}
+
+// defaultMessage implements the interface MessageType.
+func (m Terminate) defaultMessage() *Message {
+ return &terminateDefault
}
diff --git a/postgres/messages/utils.go b/postgres/messages/utils.go
index 71372a5cfd..7f15acfedd 100644
--- a/postgres/messages/utils.go
+++ b/postgres/messages/utils.go
@@ -1,3 +1,17 @@
+// Copyright 2023 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package messages
import (
@@ -7,6 +21,8 @@ import (
"golang.org/x/exp/constraints"
)
+//TODO: delete these Write functions
+
// WriteLength writes the length of the message into the byte slice. Modifies the given byte slice, while also
// returning the same slice. Assumes that the first byte is the message identifier, while the next 4 bytes are
// the length.
@@ -22,3 +38,62 @@ func WriteLength(b []byte) []byte {
func WriteNumber[T constraints.Integer | constraints.Float](buf *bytes.Buffer, num T) {
_ = binary.Write(buf, binary.BigEndian, num)
}
+
+// Stack is a generic stack.
+type Stack[T any] struct {
+ values []T
+}
+
+// NewStack creates a new, empty stack.
+func NewStack[T any]() *Stack[T] {
+ return &Stack[T]{}
+}
+
+// Len returns the size of the stack.
+func (s *Stack[T]) Len() int {
+ return len(s.values)
+}
+
+// Peek returns the top value on the stack without removing it.
+func (s *Stack[T]) Peek() (value T) {
+ if len(s.values) == 0 {
+ return
+ }
+ return s.values[len(s.values)-1]
+}
+
+// PeekDepth returns the n-th value from the top. PeekDepth(0) is equivalent to the standard Peek().
+func (s *Stack[T]) PeekDepth(depth int) (value T) {
+ if len(s.values) <= depth {
+ return
+ }
+ return s.values[len(s.values)-(1+depth)]
+}
+
+// PeekReference returns a reference to the top value on the stack without removing it.
+func (s *Stack[T]) PeekReference() *T {
+ if len(s.values) == 0 {
+ return nil
+ }
+ return &s.values[len(s.values)-1]
+}
+
+// Pop returns the top value on the stack while also removing it from the stack.
+func (s *Stack[T]) Pop() (value T) {
+ if len(s.values) == 0 {
+ return
+ }
+ value = s.values[len(s.values)-1]
+ s.values = s.values[:len(s.values)-1]
+ return
+}
+
+// Push adds the given value to the stack.
+func (s *Stack[T]) Push(value T) {
+ s.values = append(s.values, value)
+}
+
+// Empty returns whether the stack is empty.
+func (s *Stack[T]) Empty() bool {
+ return len(s.values) == 0
+}
diff --git a/system_checks.go b/system_checks.go
index 99fb1a7137..ce8fb87ff3 100644
--- a/system_checks.go
+++ b/system_checks.go
@@ -1,4 +1,4 @@
-// Copyright 2020 Dolthub, Inc.
+// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.