diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml deleted file mode 100644 index 1a207faa04..0000000000 --- a/.github/workflows/format.yml +++ /dev/null @@ -1,89 +0,0 @@ -name: Check/Format PR - -on: - pull_request: - branches: [ main ] - -concurrency: - group: format-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - format: - name: Format PR - if: github.event.pull_request.head.repo.full_name == github.repository - runs-on: ubuntu-22.04 - steps: - - name: Setup Go 1.x - uses: actions/setup-go@v3 - with: - go-version: ^1.20 - - uses: actions/checkout@v3 - with: - token: ${{ secrets.REPO_ACCESS_TOKEN || secrets.GITHUB_TOKEN }} - ref: ${{ github.head_ref }} - - name: Install goimports - run: | - go mod download golang.org/x/tools - go install golang.org/x/tools/cmd/goimports - - name: Format repo - run: | - ./format_repo.sh - env: - BRANCH_NAME: ${{ github.head_ref }} - CHANGE_TARGET: ${{ github.base_ref }} - - name: Changes detected - id: detect-changes - run: | - changes=$(git status --porcelain) - if [ ! -z "$changes" ]; then - echo "has-changes=true" >> $GITHUB_OUTPUT - fi - - uses: EndBug/add-and-commit@v9.1.1 - if: ${{ steps.detect-changes.outputs.has-changes == 'true' }} - with: - message: "[ga-format-pr] Run ./format_repo.sh to fix formatting" - add: "." - cwd: "." - verify: - needs: format - name: Verify format - runs-on: ubuntu-22.04 - steps: - - name: Setup Go 1.x - uses: actions/setup-go@v3 - with: - go-version: ^1.20 - id: go - - uses: actions/checkout@v3 - with: - ref: ${{ github.head_ref }} - - name: Check all - id: check_format - run: | - ./check_repo.sh - env: - BRANCH_NAME: ${{ github.head_ref }} - CHANGE_TARGET: ${{ github.base_ref }} - alt-verify: - if: github.event.pull_request.head.repo.full_name != github.repository - name: Verify format - runs-on: ubuntu-22.04 - steps: - - name: Setup Go 1.x - uses: actions/setup-go@v3 - with: - go-version: ^1.20 - id: go - - uses: actions/checkout@v3 - - name: Check all - id: check_format - run: | - ./check_repo.sh - code=$(echo $?) - if [ "$code" != 0 ]; then - echo "Please run ./format_repo.sh to fix this pull request's formatting" - fi - env: - BRANCH_NAME: ${{ github.head_ref }} - CHANGE_TARGET: ${{ github.base_ref }} diff --git a/fileno_check.go b/fileno_check.go index 5a4b302667..0b9d9f075b 100644 --- a/fileno_check.go +++ b/fileno_check.go @@ -1,4 +1,4 @@ -// Copyright 2019 Dolthub, Inc. +// Copyright 2023 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/fileno_check_darwin.go b/fileno_check_darwin.go index 0223dbf2fa..78d9370c14 100644 --- a/fileno_check_darwin.go +++ b/fileno_check_darwin.go @@ -1,4 +1,4 @@ -// Copyright 2019 Dolthub, Inc. +// Copyright 2023 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/fileno_check_linux.go b/fileno_check_linux.go index 6c85f9d1df..1c06e4e285 100644 --- a/fileno_check_linux.go +++ b/fileno_check_linux.go @@ -1,4 +1,4 @@ -// Copyright 2019 Dolthub, Inc. +// Copyright 2023 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/go.mod b/go.mod index 41f9158374..cc19f2f2c2 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,16 @@ module github.com/dolthub/doltgresql go 1.20 require ( - github.com/dolthub/dolt/go v0.40.5-0.20230911151544-6c18c2cbfebe - github.com/dolthub/vitess v0.0.0-20230823204737-4a21a94e90c3 + github.com/dolthub/dolt/go v0.40.5-0.20230917024726-1ae3a49864ac + github.com/dolthub/go-mysql-server v0.17.1-0.20230916212652-86f1cdf0339c + github.com/dolthub/vitess v0.0.0-20230915082726-ef1b92774b14 + github.com/fatih/color v1.13.0 + github.com/jackc/pgx/v5 v5.4.3 github.com/shopspring/decimal v1.2.0 + github.com/stretchr/testify v1.8.2 + github.com/tidwall/gjson v1.14.4 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 + golang.org/x/sys v0.10.0 ) require ( @@ -26,18 +32,17 @@ require ( github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/denisbrodbeck/machineid v1.0.1 // indirect - github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911151544-6c18c2cbfebe // indirect + github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230917024726-1ae3a49864ac // indirect github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 // indirect github.com/dolthub/fslock v0.0.3 // indirect github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e // indirect - github.com/dolthub/go-mysql-server v0.17.1-0.20230911134248-fb77f4090d9b // indirect github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 // indirect github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 // indirect github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 // indirect github.com/dolthub/swiss v0.1.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect - github.com/fatih/color v1.13.0 // indirect github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/go-kit/kit v0.10.0 // indirect github.com/go-logr/logr v1.2.3 // indirect @@ -56,6 +61,8 @@ require ( github.com/googleapis/gax-go/v2 v2.11.0 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/hashicorp/golang-lru/v2 v2.0.2 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jpillora/backoff v1.0.0 // indirect github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d // indirect @@ -72,20 +79,19 @@ require ( github.com/mitchellh/hashstructure v1.1.0 // indirect github.com/pierrec/lz4/v4 v4.1.6 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pkg/profile v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.13.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/sergi/go-diff v1.1.0 // indirect github.com/silvasur/buzhash v0.0.0-20160816060738-9bdec3dec7c6 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect github.com/tealeg/xlsx v1.0.5 // indirect github.com/tetratelabs/wazero v1.1.0 // indirect - github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect @@ -95,18 +101,15 @@ require ( github.com/zeebo/xxh3 v1.0.2 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel v1.7.0 // indirect - go.opentelemetry.io/otel/exporters/jaeger v1.7.0 // indirect - go.opentelemetry.io/otel/sdk v1.7.0 // indirect go.opentelemetry.io/otel/trace v1.7.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.24.0 // indirect golang.org/x/crypto v0.11.0 // indirect - golang.org/x/mod v0.8.0 // indirect + golang.org/x/mod v0.9.0 // indirect golang.org/x/net v0.12.0 // indirect golang.org/x/oauth2 v0.8.0 // indirect golang.org/x/sync v0.2.0 // indirect - golang.org/x/sys v0.10.0 // indirect golang.org/x/term v0.10.0 // indirect golang.org/x/text v0.11.0 // indirect golang.org/x/time v0.1.0 // indirect @@ -122,4 +125,5 @@ require ( gopkg.in/square/go-jose.v2 v2.5.1 // indirect gopkg.in/src-d/go-errors.v1 v1.0.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bcc6a9ceac..7862b7877a 100644 --- a/go.sum +++ b/go.sum @@ -157,16 +157,14 @@ github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbj github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= -github.com/dolthub/dolt/go v0.40.5-0.20230911135842-43e82ba2510c h1:k0RWniTdbLoRpH9gCs0j0qC7kbTcUaH0S0etmEf8JMk= -github.com/dolthub/dolt/go v0.40.5-0.20230911135842-43e82ba2510c/go.mod h1:eYa1zEjLxMs5ZvYEDYiqbpAVdJ2NfnUXSnZVwyy9fUQ= github.com/dolthub/dolt/go v0.40.5-0.20230911151544-6c18c2cbfebe h1:Bo+kXeBK1CvrPKKfVydbXKbu5Mt8HM+GH8lZv2d+0A4= github.com/dolthub/dolt/go v0.40.5-0.20230911151544-6c18c2cbfebe/go.mod h1:eYa1zEjLxMs5ZvYEDYiqbpAVdJ2NfnUXSnZVwyy9fUQ= -github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20201005193433-3ee972b1d078 h1:nrkoh/RcgTq5EsWTcbSBF8KQghCtM+1dhyslghbBoj8= -github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20201005193433-3ee972b1d078/go.mod h1:8Jdiq6CVg8HM4n9fF17sGgXUpFa98zDyscW0A7OQmuM= -github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911135842-43e82ba2510c h1:XsYJr7EG3FD00FqYtllQyT0ycG2dGp6rXlXts+eTELA= -github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911135842-43e82ba2510c/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU= +github.com/dolthub/dolt/go v0.40.5-0.20230917024726-1ae3a49864ac h1:JE93q8eaYCujLQmpa3JibkqFLY1CXSfPdm6w351sHC8= +github.com/dolthub/dolt/go v0.40.5-0.20230917024726-1ae3a49864ac/go.mod h1:VSSW98p69Lsd/dy3/aySvlJxmJW2g4yjvNWIXfddSEM= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911151544-6c18c2cbfebe h1:la317evV7J8bzeKfYdmpBo6/OIwsSO9QxrOPYeWoY6A= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230911151544-6c18c2cbfebe/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU= +github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230917024726-1ae3a49864ac h1:zQtagkisqzs3QWtGxOlpayQI01wL9pYgfu8Zn3asMdU= +github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20230917024726-1ae3a49864ac/go.mod h1:Fi7KchJVfwMuPJkX4vJeAlNZkxCiVyhvVYfCgaSDlTU= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= @@ -175,6 +173,8 @@ github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= github.com/dolthub/go-mysql-server v0.17.1-0.20230911134248-fb77f4090d9b h1:4b+gOvFMSZ7YhiB8VPFWhyPYZcWBGALvaxT1FGLmP34= github.com/dolthub/go-mysql-server v0.17.1-0.20230911134248-fb77f4090d9b/go.mod h1:vSQ47leaIPTtvSLKo89D1FdYdypU5OH6VBV63B2MS8Y= +github.com/dolthub/go-mysql-server v0.17.1-0.20230916212652-86f1cdf0339c h1:jjJ7mV0X7LVAkNGgDBtKAnPlsZYoeHFF2vSPB7NH/r4= +github.com/dolthub/go-mysql-server v0.17.1-0.20230916212652-86f1cdf0339c/go.mod h1:s4hz5TZpFkw8IdLVz58EKKm9cxvRd4jVw1Hs2o+tEXw= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ= @@ -185,6 +185,8 @@ github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc= github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ= github.com/dolthub/vitess v0.0.0-20230823204737-4a21a94e90c3 h1:lY3oQbYNMSVjT02n6f2M2H0u4icF6lGbS/IpWr27ti8= github.com/dolthub/vitess v0.0.0-20230823204737-4a21a94e90c3/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20230915082726-ef1b92774b14 h1:w0MwDEgWWei13UqZeS6mlSPUQ3iMHMDlrU+MUKQtR4s= +github.com/dolthub/vitess v0.0.0-20230915082726-ef1b92774b14/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -373,6 +375,12 @@ github.com/iancoleman/strcase v0.1.3/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5N github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jcmturner/gofork v0.0.0-20180107083740-2aebee971930/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= @@ -415,8 +423,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxv github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -515,8 +523,6 @@ github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= -github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug= -github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pkg/sftp v1.13.0/go.mod h1:41g+FIPlQUTDCveupEmEA65IoiQFrtgCeDopC4ajGIM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -561,6 +567,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -597,7 +605,6 @@ github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5J github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= @@ -610,6 +617,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tealeg/xlsx v1.0.5 h1:+f8oFmvY8Gw1iUXzPk+kz+4GpbDZPK1FhPiQRd+ypgE= github.com/tealeg/xlsx v1.0.5/go.mod h1:btRS8dz54TDnvKNosuAqxrM1QgN1udgk9O34bDCnORM= github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ= @@ -661,10 +669,6 @@ go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/otel v1.7.0 h1:Z2lA3Tdch0iDcrhJXDIlC94XE+bxok1F9B+4Lz/lGsM= go.opentelemetry.io/otel v1.7.0/go.mod h1:5BdUoMIz5WEs0vt0CUEMtSSaTSHBBVwrhnz7+nrD5xk= -go.opentelemetry.io/otel/exporters/jaeger v1.7.0 h1:wXgjiRldljksZkZrldGVe6XrG9u3kYDyQmkZwmm5dI0= -go.opentelemetry.io/otel/exporters/jaeger v1.7.0/go.mod h1:PwQAOqBgqbLQRKlj466DuD2qyMjbtcPpfPfj+AqbSBs= -go.opentelemetry.io/otel/sdk v1.7.0 h1:4OmStpcKVOfvDOgCt7UriAPtKolwIhxpnSNI/yK+1B0= -go.opentelemetry.io/otel/sdk v1.7.0/go.mod h1:uTEOTwaqIVuTGiJN7ii13Ibp75wJmYUDe374q6cZwUU= go.opentelemetry.io/otel/trace v1.7.0 h1:O37Iogk1lEkMRXewVtZ1BBTVn5JEp8GrJvP92bJqC6o= go.opentelemetry.io/otel/trace v1.7.0/go.mod h1:fzLSB9nqR2eXzxPXb2JW9IKE+ScyXA48yyE4TNvoHqU= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= @@ -738,8 +742,8 @@ golang.org/x/mod v0.1.1-0.20191107180719-034126e5016b/go.mod h1:QqPTAvyqsEbceGzB golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -772,7 +776,6 @@ golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520182314-0ba52f642ac2/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= @@ -848,7 +851,6 @@ golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200620081246-981b61492c35/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200828194041-157a740278f4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -857,7 +859,6 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1010,7 +1011,6 @@ google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfG google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= -google.golang.org/genproto v0.0.0-20200622133129-d0ee0c36e670/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= diff --git a/main.go b/main.go index a0bd6cd998..624bda67f7 100644 --- a/main.go +++ b/main.go @@ -20,12 +20,8 @@ import ( "encoding/binary" "fmt" "math/rand" - "net/http" _ "net/http/pprof" "os" - "os/exec" - "strconv" - "time" "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands" @@ -35,7 +31,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dfunctions" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" - "github.com/dolthub/dolt/go/libraries/events" "github.com/dolthub/dolt/go/libraries/utils/argparser" "github.com/dolthub/dolt/go/libraries/utils/config" "github.com/dolthub/dolt/go/libraries/utils/filesys" @@ -44,13 +39,7 @@ import ( "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/fatih/color" - "github.com/pkg/profile" "github.com/tidwall/gjson" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/exporters/jaeger" - "go.opentelemetry.io/otel/sdk/resource" - tracesdk "go.opentelemetry.io/otel/sdk/trace" - semconv "go.opentelemetry.io/otel/semconv/v1.10.0" "github.com/dolthub/doltgresql/postgres" ) @@ -66,44 +55,26 @@ var doltCommand = cli.NewSubCommandHandler("doltgresql", "it's git for data", [] sqlserver.SqlServerCmd{VersionStr: Version}, }) var globalArgParser = cli.CreateGlobalArgParser("doltgresql") -var globalDocs = cli.CommandDocsForCommandString("doltgresql", doc, globalArgParser) - -var globalSpecialMsg = ` -Dolt subcommands are in transition to using the flags listed below as global flags. -Not all subcommands use these flags. If your command accepts these flags without error, then they are supported. -` func init() { server.DefaultProtocolListenerFunc = postgres.NewListenerWithConfig - sqlserver.DoltgreSQLDisableUsers = true + sqlserver.ExternalDisableUsers = true dfunctions.VersionString = Version } -const pprofServerFlag = "--pprof-server" const chdirFlag = "--chdir" -const jaegerFlag = "--jaeger" -const profFlag = "--prof" -const csMetricsFlag = "--csmetrics" const stdInFlag = "--stdin" const stdOutFlag = "--stdout" const stdErrFlag = "--stderr" const stdOutAndErrFlag = "--out-and-err" const ignoreLocksFlag = "--ignore-lock-file" -const verboseEngineSetupFlag = "--verbose-engine-setup" - -const cpuProf = "cpu" -const memProf = "mem" -const blockingProf = "blocking" -const traceProf = "trace" - -const featureVersionFlag = "--feature-version" func main() { - os.Exit(runMain()) + os.Exit(RunMain(os.Args[1:])) } -func runMain() int { - args := os.Args[1:] +func RunMain(args []string) int { + ctx := context.Background() // Inject the "sql-server" command args = append([]string{"sql-server"}, args...) // Enforce a default port of 5432 @@ -113,102 +84,15 @@ func runMain() int { } } - start := time.Now() - - if len(args) == 0 { - doltCommand.PrintUsage("dolt") - return 1 - } - if os.Getenv("DOLT_VERBOSE_ASSERT_TABLE_FILES_CLOSED") == "" { nbs.TableIndexGCFinalizerWithStackTrace = false } - csMetrics := false ignoreLockFile := false - verboseEngineSetup := false if len(args) > 0 { var doneDebugFlags bool for !doneDebugFlags && len(args) > 0 { switch args[0] { - case profFlag: - switch args[1] { - case cpuProf: - cli.Println("cpu profiling enabled.") - defer profile.Start(profile.CPUProfile, profile.NoShutdownHook).Stop() - case memProf: - cli.Println("mem profiling enabled.") - defer profile.Start(profile.MemProfile, profile.NoShutdownHook).Stop() - case blockingProf: - cli.Println("block profiling enabled") - defer profile.Start(profile.BlockProfile, profile.NoShutdownHook).Stop() - case traceProf: - cli.Println("trace profiling enabled") - defer profile.Start(profile.TraceProfile, profile.NoShutdownHook).Stop() - default: - panic("Unexpected prof flag: " + args[1]) - } - args = args[2:] - - case pprofServerFlag: - // serve the pprof endpoints setup in the init function run when "net/http/pprof" is imported - go func() { - cyanStar := color.CyanString("*") - cli.Println(cyanStar, "Starting pprof server on port 6060.") - cli.Println(cyanStar, "Go to", color.CyanString("http://localhost:6060/debug/pprof"), "in a browser to see supported endpoints.") - cli.Println(cyanStar) - cli.Println(cyanStar, "Known endpoints are:") - cli.Println(cyanStar, " /allocs: A sampling of all past memory allocations") - cli.Println(cyanStar, " /block: Stack traces that led to blocking on synchronization primitives") - cli.Println(cyanStar, " /cmdline: The command line invocation of the current program") - cli.Println(cyanStar, " /goroutine: Stack traces of all current goroutines") - cli.Println(cyanStar, " /heap: A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.") - cli.Println(cyanStar, " /mutex: Stack traces of holders of contended mutexes") - cli.Println(cyanStar, " /profile: CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.") - cli.Println(cyanStar, " /threadcreate: Stack traces that led to the creation of new OS threads") - cli.Println(cyanStar, " /trace: A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.") - cli.Println() - - err := http.ListenAndServe("0.0.0.0:6060", nil) - - if err != nil { - cli.Println(color.YellowString("pprof server exited with error: %v", err)) - } - }() - args = args[1:] - - // Enable a global jaeger tracer for this run of Dolt, - // emitting traces to a collector running at - // localhost:14268. To visualize these traces, run: - // docker run -d --name jaeger \ - // -e COLLECTOR_ZIPKIN_HTTP_PORT=9411 \ - // -p 5775:5775/udp \ - // -p 6831:6831/udp \ - // -p 6832:6832/udp \ - // -p 5778:5778 \ - // -p 16686:16686 \ - // -p 14268:14268 \ - // -p 14250:14250 \ - // -p 9411:9411 \ - // jaegertracing/all-in-one:1.21 - // and browse to http://localhost:16686 - case jaegerFlag: - cli.Println("running with jaeger tracing reporting to localhost") - exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint("http://localhost:14268/api/traces"))) - if err != nil { - cli.Println(color.YellowString("could not create jaeger collector: %v", err)) - } else { - tp := tracesdk.NewTracerProvider( - tracesdk.WithBatcher(exp), - tracesdk.WithResource(resource.NewWithAttributes( - semconv.SchemaURL, - semconv.ServiceNameKey.String("dolt"), - )), - ) - otel.SetTracerProvider(tp) - defer tp.Shutdown(context.Background()) - args = args[1:] - } // Currently goland doesn't support running with a different working directory when using go modules. // This is a hack that allows a different working directory to be set after the application starts using // chdir=. The syntax is not flexible and must match exactly this. @@ -259,33 +143,10 @@ func runMain() int { color.NoColor = true args = args[2:] - case csMetricsFlag: - csMetrics = true - args = args[1:] - case ignoreLocksFlag: ignoreLockFile = true args = args[1:] - case featureVersionFlag: - var err error - if len(args) == 0 { - err = fmt.Errorf("missing argument for the --feature-version flag") - } else { - if featureVersion, err := strconv.Atoi(args[1]); err == nil { - doltdb.DoltFeatureVersion = doltdb.FeatureVersion(featureVersion) - } - } - if err != nil { - cli.PrintErrln(err.Error()) - return 1 - } - - args = args[2:] - - case verboseEngineSetupFlag: - verboseEngineSetup = true - args = args[1:] default: doneDebugFlags = true } @@ -299,24 +160,11 @@ func runMain() int { warnIfMaxFilesTooLow() - ctx := context.Background() - if ok, exit := interceptSendMetrics(ctx, args); ok { - return exit - } - - _, usage := cli.HelpAndUsagePrinters(globalDocs) - var fs filesys.Filesys fs = filesys.LocalFS dEnv := env.Load(ctx, env.GetCurrentUserHomeDir, fs, doltdb.LocalDirDoltDB, Version) dEnv.IgnoreLockFile = ignoreLockFile - root, err := env.GetCurrentUserHomeDir() - if err != nil { - cli.PrintErrln(color.RedString("Failed to load the HOME directory: %v", err)) - return 1 - } - globalConfig, ok := dEnv.Config.GetConfig(env.GlobalConfig) if !ok { cli.PrintErrln(color.RedString("Failed to get global config")) @@ -325,10 +173,8 @@ func runMain() int { apr, remainingArgs, subcommandName, err := parseGlobalArgsAndSubCommandName(globalConfig, args) if err == argparser.ErrHelp { + //TODO: display some help message doltCommand.PrintUsage("dolt") - cli.Println(globalSpecialMsg) - usage() - return 0 } else if err != nil { cli.PrintErrln(color.RedString("Failure to parse arguments: %v", err)) @@ -354,35 +200,7 @@ func runMain() int { return 1 } - emitter := events.NewFileEmitter(root, dbfactory.DoltDir) - - defer func() { - ces := events.GlobalCollector.Close() - // events.WriterEmitter{cli.CliOut}.LogEvents(Version, ces) - - metricsDisabled := dEnv.Config.GetStringOrDefault(env.MetricsDisabled, "false") - - disabled, err := strconv.ParseBool(metricsDisabled) - if err != nil { - // log.Print(err) - return - } - - if disabled { - return - } - - // write events - _ = emitter.LogEvents(Version, ces) - - // flush events - if err := processEventsDir(args, dEnv); err != nil { - // log.Print(err) - } - }() - err = reconfigIfTempFileMoveFails(dEnv) - if err != nil { cli.PrintErrln(color.RedString("Failed to setup the temporary directory. %v`", err)) return 1 @@ -443,7 +261,7 @@ func runMain() int { return 1 } - lateBind, err := buildLateBinder(ctx, cwdFS, dEnv, mrEnv, creds, apr, subcommandName, verboseEngineSetup) + lateBind, err := buildLateBinder(ctx, cwdFS, dEnv, mrEnv, creds, apr, subcommandName, false) if err != nil { cli.PrintErrln(color.RedString("%v", err)) @@ -471,12 +289,6 @@ func runMain() int { } } - if csMetrics && dEnv.DoltDB != nil { - metricsSummary := dEnv.DoltDB.CSMetricsSummary() - cli.Println("Command took", time.Since(start).Seconds()) - cli.PrintErrln(metricsSummary) - } - return res } @@ -571,17 +383,6 @@ func buildLateBinder(ctx context.Context, cwdFS filesys.Filesys, rootEnv *env.Do return commands.BuildSqlEngineQueryist(ctx, cwdFS, mrEnv, creds, apr) } -// doc is currently used only when a `initCliContext` command is specified. This will include all commands in time, -// otherwise you only see these docs if you specify a nonsense argument before the `sql` subcommand. -var doc = cli.CommandDocumentationContent{ - ShortDesc: "Dolt is git for data", - LongDesc: `Dolt comprises of multiple subcommands that allow users to import, export, update, and manipulate data with SQL.`, - - Synopsis: []string{ - "<--data-dir=> subcommand ", - }, -} - func seedGlobalRand() { bs := make([]byte, 8) _, err := crand.Read(bs) @@ -591,42 +392,6 @@ func seedGlobalRand() { rand.Seed(int64(binary.LittleEndian.Uint64(bs))) } -// processEventsDir runs the dolt send-metrics command in a new process -func processEventsDir(args []string, dEnv *env.DoltEnv) error { - if len(args) > 0 { - ignoreCommands := map[string]struct{}{ - commands.SendMetricsCommand: {}, - "init": {}, - "config": {}, - } - - _, ok := ignoreCommands[args[0]] - - if ok { - return nil - } - - cmd := exec.Command("dolt", commands.SendMetricsCommand) - - if err := cmd.Start(); err != nil { - // log.Print(err) - return err - } - - return nil - } - - return nil -} - -func interceptSendMetrics(ctx context.Context, args []string) (bool, int) { - if len(args) < 1 || args[0] != commands.SendMetricsCommand { - return false, 0 - } - dEnv := env.LoadWithoutDB(ctx, env.GetCurrentUserHomeDir, filesys.LocalFS, Version) - return true, doltCommand.Exec(ctx, "dolt", args, dEnv, nil) -} - // parseGlobalArgsAndSubCommandName parses the global arguments, including a profile if given or a default profile if exists. Also returns the subcommand name. func parseGlobalArgsAndSubCommandName(globalConfig config.ReadWriteConfig, args []string) (apr *argparser.ArgParseResults, remaining []string, subcommandName string, err error) { apr, remaining, err = globalArgParser.ParseGlobalArgs(args) diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000000..8e98678aff --- /dev/null +++ b/main_test.go @@ -0,0 +1,59 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "net" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestBasicConnection(t *testing.T) { + //TODO: fix me + return + //port := getEmptyPort(t) + //go RunMain([]string{fmt.Sprintf("--port=%d", port)}) + //port := 5431 + port := 5432 + + ctx := context.Background() + conn, err := pgx.Connect(ctx, fmt.Sprintf("postgres://postgres:password@localhost:%d/postgres", port)) + require.NoError(t, err) + defer conn.Close(ctx) + + func() { + //rows, err := conn.Query(ctx, "CREATE DATABASE testdb;") + rows, err := conn.Query(ctx, "SELECT * FROM test;") + require.NoError(t, err) + defer rows.Close() + for rows.Next() { + row, err := rows.Values() + require.NoError(t, err) + row = row + } + }() +} + +func getEmptyPort(t *testing.T) int { + listener, err := net.Listen("tcp", ":0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + require.NoError(t, listener.Close()) + return port +} diff --git a/postgres/listener.go b/postgres/listener.go index f15ddf521b..5acf2ebb41 100644 --- a/postgres/listener.go +++ b/postgres/listener.go @@ -93,9 +93,9 @@ func (l *Listener) HandleConnection(conn net.Conn) { return } - if _, err = conn.Write(messages.SSLResponse{ + if err = messages.Send(conn, messages.SSLResponse{ SupportsSSL: false, - }.Bytes()); err != nil { + }); err != nil { fmt.Println(err) return } @@ -107,43 +107,43 @@ func (l *Listener) HandleConnection(conn net.Conn) { } return } - startupMessage, err := messages.ReadStartupMessage(buf) + startupMessage, err := messages.ReceiveInto(buf, messages.StartupMessage{}) if err != nil { fmt.Println(err) return } - if _, err = conn.Write(messages.AuthenticationOk{}.Bytes()); err != nil { + if err = messages.Send(conn, messages.AuthenticationOk{}); err != nil { fmt.Println(err) return } - if _, err = conn.Write(messages.ParameterStatus{ + if err = messages.Send(conn, messages.ParameterStatus{ Name: "server_version", Value: "15.0", - }.Bytes()); err != nil { + }); err != nil { fmt.Println(err) return } - if _, err = conn.Write(messages.ParameterStatus{ + if err = messages.Send(conn, messages.ParameterStatus{ Name: "client_encoding", Value: "UTF8", - }.Bytes()); err != nil { + }); err != nil { fmt.Println(err) return } - if _, err = conn.Write(messages.BackendKeyData{ + if err = messages.Send(conn, messages.BackendKeyData{ ProcessID: 1, SecretKey: 0, - }.Bytes()); err != nil { + }); err != nil { fmt.Println(err) return } - if _, err = conn.Write(messages.ReadyForQuery{ + if err = messages.Send(conn, messages.ReadyForQuery{ Indicator: messages.ReadyForQueryTransactionIndicator_Idle, - }.Bytes()); err != nil { + }); err != nil { fmt.Println(err) return } @@ -162,62 +162,65 @@ func (l *Listener) HandleConnection(conn net.Conn) { return } - if messages.ReadTerminate(buf) { - return - } - query, ok := messages.ReadQuery(buf) - if !ok { - fmt.Println("unknown message, terminating connection") - return - } - commandCompleteTag, err := messages.QueryToCommandCompleteTag(query) + message, ok, err := messages.Receive(buf) if err != nil { fmt.Println(err.Error()) return + } else if !ok { + fmt.Println("unknown message format, terminating connection") + return } - var rowTotal int32 - if err = l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error { - rowDescription, err := messages.NewRowDescription(res.Fields) - if err != nil { - return err - } - if _, err = conn.Write(rowDescription.Bytes()); err != nil { - return err - } + switch message := message.(type) { + case messages.Terminate: + return + case messages.Query: + l.query(conn, mysqlConn, message.String) + } + } +} - for _, row := range res.Rows { - if _, err = conn.Write(messages.NewDataRow(row).Bytes()); err != nil { - return err - } - } +func (l *Listener) query(conn net.Conn, mysqlConn *mysql.Conn, query string) { + commandComplete := messages.CommandComplete{ + Query: query, + Rows: 0, + } - if commandCompleteTag == messages.CommandCompleteTag_INSERT || - commandCompleteTag == messages.CommandCompleteTag_UPDATE || - commandCompleteTag == messages.CommandCompleteTag_DELETE { - rowTotal = int32(res.RowsAffected) - } else { - rowTotal += int32(len(res.Rows)) - } - return nil + if err := l.cfg.Handler.ComQuery(mysqlConn, query, func(res *sqltypes.Result, more bool) error { + if err := messages.Send(conn, messages.RowDescription{ + Fields: res.Fields, }); err != nil { - fmt.Println(err.Error()) - return + return err } - if _, err = conn.Write(messages.CommandComplete{ - Tag: commandCompleteTag, - Rows: rowTotal, - }.Bytes()); err != nil { - fmt.Println(err) - return + for _, row := range res.Rows { + if err := messages.Send(conn, messages.DataRow{ + Values: row, + }); err != nil { + return err + } } - if _, err = conn.Write(messages.ReadyForQuery{ - Indicator: messages.ReadyForQueryTransactionIndicator_Idle, - }.Bytes()); err != nil { - fmt.Println(err) - return + if commandComplete.IsIUD() { + commandComplete.Rows = int32(res.RowsAffected) + } else { + commandComplete.Rows += int32(len(res.Rows)) } + return nil + }); err != nil { + fmt.Println(err.Error()) + return + } + + if err := messages.Send(conn, commandComplete); err != nil { + fmt.Println(err) + return + } + + if err := messages.Send(conn, messages.ReadyForQuery{ + Indicator: messages.ReadyForQueryTransactionIndicator_Idle, + }); err != nil { + fmt.Println(err) + return } } diff --git a/postgres/messages/authentication_cleartext_password.go b/postgres/messages/authentication_cleartext_password.go new file mode 100644 index 0000000000..07a2eeb824 --- /dev/null +++ b/postgres/messages/authentication_cleartext_password.go @@ -0,0 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationCleartextPassword{}) +} + +// AuthenticationCleartextPassword represents a PostgreSQL message. +type AuthenticationCleartextPassword struct{} + +var authenticationCleartextPasswordDefault = MessageFormat{ + Name: "AuthenticationCleartextPassword", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "Status", + Type: Int32, + Data: int32(3), + }, + }, +} + +var _ Message = AuthenticationCleartextPassword{} + +// encode implements the interface Message. +func (m AuthenticationCleartextPassword) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m AuthenticationCleartextPassword) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationCleartextPassword{}, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationCleartextPassword) defaultMessage() *MessageFormat { + return &authenticationCleartextPasswordDefault +} diff --git a/postgres/messages/authentication_gss.go b/postgres/messages/authentication_gss.go new file mode 100644 index 0000000000..223315eb34 --- /dev/null +++ b/postgres/messages/authentication_gss.go @@ -0,0 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationGSS{}) +} + +// AuthenticationGSS represents a PostgreSQL message. +type AuthenticationGSS struct{} + +var authenticationGSSDefault = MessageFormat{ + Name: "AuthenticationGSS", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "Status", + Type: Int32, + Data: int32(7), + }, + }, +} + +var _ Message = AuthenticationGSS{} + +// encode implements the interface Message. +func (m AuthenticationGSS) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m AuthenticationGSS) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationGSS{}, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationGSS) defaultMessage() *MessageFormat { + return &authenticationGSSDefault +} diff --git a/postgres/messages/authentication_gss_continue.go b/postgres/messages/authentication_gss_continue.go new file mode 100644 index 0000000000..aec0e229d5 --- /dev/null +++ b/postgres/messages/authentication_gss_continue.go @@ -0,0 +1,76 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationGSSContinue{}) +} + +// AuthenticationGSSContinue represents a PostgreSQL message. +type AuthenticationGSSContinue struct { + Data []byte +} + +var authenticationGSSContinueDefault = MessageFormat{ + Name: "AuthenticationGSSContinue", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Status", + Type: Int32, + Data: int32(8), + }, + { + Name: "AuthenticationData", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = AuthenticationGSSContinue{} + +// encode implements the interface Message. +func (m AuthenticationGSSContinue) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("AuthenticationData").MustWrite(m.Data) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m AuthenticationGSSContinue) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationGSSContinue{ + Data: s.Field("AuthenticationData").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationGSSContinue) defaultMessage() *MessageFormat { + return &authenticationGSSContinueDefault +} diff --git a/postgres/messages/authentication_kerberos_v5.go b/postgres/messages/authentication_kerberos_v5.go new file mode 100644 index 0000000000..5df34282e8 --- /dev/null +++ b/postgres/messages/authentication_kerberos_v5.go @@ -0,0 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationKerberosV5{}) +} + +// AuthenticationKerberosV5 represents a PostgreSQL message. +type AuthenticationKerberosV5 struct{} + +var authenticationKerberosV5Default = MessageFormat{ + Name: "AuthenticationKerberosV5", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "Status", + Type: Int32, + Data: int32(2), + }, + }, +} + +var _ Message = AuthenticationKerberosV5{} + +// encode implements the interface Message. +func (m AuthenticationKerberosV5) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m AuthenticationKerberosV5) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationKerberosV5{}, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationKerberosV5) defaultMessage() *MessageFormat { + return &authenticationKerberosV5Default +} diff --git a/postgres/messages/authentication_md5_password.go b/postgres/messages/authentication_md5_password.go new file mode 100644 index 0000000000..44da6c286c --- /dev/null +++ b/postgres/messages/authentication_md5_password.go @@ -0,0 +1,76 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationMD5Password{}) +} + +// AuthenticationMD5Password represents a PostgreSQL message. +type AuthenticationMD5Password struct { + Salt int32 +} + +var authenticationMD5PasswordDefault = MessageFormat{ + Name: "AuthenticationMD5Password", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(12), + }, + { + Name: "Status", + Type: Int32, + Data: int32(5), + }, + { + Name: "Salt", + Type: Byte4, + Data: int32(0), + }, + }, +} + +var _ Message = AuthenticationMD5Password{} + +// encode implements the interface Message. +func (m AuthenticationMD5Password) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Salt").MustWrite(m.Salt) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m AuthenticationMD5Password) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationMD5Password{ + Salt: s.Field("Salt").MustGet().(int32), + }, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationMD5Password) defaultMessage() *MessageFormat { + return &authenticationMD5PasswordDefault +} diff --git a/postgres/messages/authentication_ok.go b/postgres/messages/authentication_ok.go index 148d6b8687..5bfd289098 100644 --- a/postgres/messages/authentication_ok.go +++ b/postgres/messages/authentication_ok.go @@ -1,13 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages // AuthenticationOk tells the client that authentication was successful. type AuthenticationOk struct{} -// Bytes returns AuthenticationOk as a byte slice, ready to be returned to the client. -func (aok AuthenticationOk) Bytes() []byte { - return []byte{ - 'R', // Message Type - 0, 0, 0, 8, // Message Length - 0, 0, 0, 0, // Padding +func init() { + initializeDefaultMessage(AuthenticationOk{}) +} + +var authenticationOkDefault = MessageFormat{ + Name: "AuthenticationOk", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "Status", + Type: Int32, + Data: int32(0), + }, + }, +} + +var _ Message = AuthenticationOk{} + +// encode implements the interface Message. +func (m AuthenticationOk) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m AuthenticationOk) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } + return AuthenticationOk{}, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationOk) defaultMessage() *MessageFormat { + return &authenticationOkDefault } diff --git a/postgres/messages/authentication_sasl.go b/postgres/messages/authentication_sasl.go new file mode 100644 index 0000000000..78577b565b --- /dev/null +++ b/postgres/messages/authentication_sasl.go @@ -0,0 +1,93 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationSASL{}) +} + +// AuthenticationSASL represents a PostgreSQL message. +type AuthenticationSASL struct { + Mechanisms []string +} + +var authenticationSASLDefault = MessageFormat{ + Name: "AuthenticationSASL", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Status", + Type: Int32, + Data: int32(10), + }, + { + Name: "Mechanisms", + Type: Repeated, + Flags: RepeatedTerminator, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "Mechanism", + Type: String, + Data: "", + }, + }, + }, + }, + }, +} + +var _ Message = AuthenticationSASL{} + +// encode implements the interface Message. +func (m AuthenticationSASL) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + for i, mechanism := range m.Mechanisms { + outputMessage.Field("Mechanisms").Child("Mechanism", i).MustWrite(mechanism) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m AuthenticationSASL) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + count := int(s.Field("Mechanisms").MustGet().(int32)) + mechanisms := make([]string, count) + for i := 0; i < count; i++ { + mechanisms[i] = s.Field("Mechanisms").Child("Mechanism", i).MustGet().(string) + } + return AuthenticationSASL{ + Mechanisms: mechanisms, + }, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationSASL) defaultMessage() *MessageFormat { + return &authenticationSASLDefault +} diff --git a/postgres/messages/authentication_sasl_continue.go b/postgres/messages/authentication_sasl_continue.go new file mode 100644 index 0000000000..6d129ad127 --- /dev/null +++ b/postgres/messages/authentication_sasl_continue.go @@ -0,0 +1,76 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationSASLContinue{}) +} + +// AuthenticationSASLContinue represents a PostgreSQL message. +type AuthenticationSASLContinue struct { + Data []byte +} + +var authenticationSASLContinueDefault = MessageFormat{ + Name: "AuthenticationSASLContinue", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Status", + Type: Int32, + Data: int32(11), + }, + { + Name: "SASLData", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = AuthenticationSASLContinue{} + +// encode implements the interface Message. +func (m AuthenticationSASLContinue) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("SASLData").MustWrite(m.Data) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m AuthenticationSASLContinue) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationSASLContinue{ + Data: s.Field("SASLData").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationSASLContinue) defaultMessage() *MessageFormat { + return &authenticationSASLContinueDefault +} diff --git a/postgres/messages/authentication_sasl_final.go b/postgres/messages/authentication_sasl_final.go new file mode 100644 index 0000000000..be6c19b122 --- /dev/null +++ b/postgres/messages/authentication_sasl_final.go @@ -0,0 +1,76 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationSASLFinal{}) +} + +// AuthenticationSASLFinal represents a PostgreSQL message. +type AuthenticationSASLFinal struct { + AdditionalData []byte +} + +var authenticationSASLFinalDefault = MessageFormat{ + Name: "AuthenticationSASLFinal", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Status", + Type: Int32, + Data: int32(12), + }, + { + Name: "AdditionalData", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = AuthenticationSASLFinal{} + +// encode implements the interface Message. +func (m AuthenticationSASLFinal) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("AdditionalData").MustWrite(m.AdditionalData) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m AuthenticationSASLFinal) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationSASLFinal{ + AdditionalData: s.Field("AdditionalData").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationSASLFinal) defaultMessage() *MessageFormat { + return &authenticationSASLFinalDefault +} diff --git a/postgres/messages/authentication_scm_credential.go b/postgres/messages/authentication_scm_credential.go new file mode 100644 index 0000000000..fa93f144b1 --- /dev/null +++ b/postgres/messages/authentication_scm_credential.go @@ -0,0 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationSCMCredential{}) +} + +// AuthenticationSCMCredential represents a PostgreSQL message. +type AuthenticationSCMCredential struct{} + +var authenticationSCMCredentialDefault = MessageFormat{ + Name: "AuthenticationSCMCredential", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "Status", + Type: Int32, + Data: int32(6), + }, + }, +} + +var _ Message = AuthenticationSCMCredential{} + +// encode implements the interface Message. +func (m AuthenticationSCMCredential) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m AuthenticationSCMCredential) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationSCMCredential{}, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationSCMCredential) defaultMessage() *MessageFormat { + return &authenticationSCMCredentialDefault +} diff --git a/postgres/messages/authentication_sspi.go b/postgres/messages/authentication_sspi.go new file mode 100644 index 0000000000..ea2f20eb02 --- /dev/null +++ b/postgres/messages/authentication_sspi.go @@ -0,0 +1,65 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(AuthenticationSSPI{}) +} + +// AuthenticationSSPI represents a PostgreSQL message. +type AuthenticationSSPI struct{} + +var authenticationSSPIDefault = MessageFormat{ + Name: "AuthenticationSSPI", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('R'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "Status", + Type: Int32, + Data: int32(9), + }, + }, +} + +var _ Message = AuthenticationSSPI{} + +// encode implements the interface Message. +func (m AuthenticationSSPI) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m AuthenticationSSPI) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return AuthenticationSSPI{}, nil +} + +// defaultMessage implements the interface Message. +func (m AuthenticationSSPI) defaultMessage() *MessageFormat { + return &authenticationSSPIDefault +} diff --git a/postgres/messages/backend_key_data.go b/postgres/messages/backend_key_data.go index 4e800ec9ff..458809545c 100644 --- a/postgres/messages/backend_key_data.go +++ b/postgres/messages/backend_key_data.go @@ -1,6 +1,23 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages -import "bytes" +func init() { + initializeDefaultMessage(BackendKeyData{}) + addMessageHeader(BackendKeyData{}) +} // BackendKeyData provides the client with information about the server. type BackendKeyData struct { @@ -8,12 +25,56 @@ type BackendKeyData struct { SecretKey int32 } -// Bytes returns BackendKeyData as a byte slice, ready to be returned to the client. -func (bkd BackendKeyData) Bytes() []byte { - buf := bytes.Buffer{} - buf.WriteByte('K') // Message Type - WriteNumber(&buf, int32(12)) // Message Length - WriteNumber(&buf, bkd.ProcessID) - WriteNumber(&buf, bkd.SecretKey) - return buf.Bytes() +var backendKeyDataDefault = MessageFormat{ + Name: "BackendKeyData", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('K'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(12), + }, + { + Name: "ProcessID", + Type: Int32, + Data: int32(0), + }, + { + Name: "SecretKey", + Type: Int32, + Data: int32(0), + }, + }, +} + +var _ Message = BackendKeyData{} + +// encode implements the interface Message. +func (m BackendKeyData) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("ProcessID").MustWrite(m.ProcessID) + outputMessage.Field("SecretKey").MustWrite(m.SecretKey) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m BackendKeyData) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return BackendKeyData{ + ProcessID: s.Field("ProcessID").MustGet().(int32), + SecretKey: s.Field("SecretKey").MustGet().(int32), + }, nil +} + +// defaultMessage implements the interface Message. +func (m BackendKeyData) defaultMessage() *MessageFormat { + return &backendKeyDataDefault } diff --git a/postgres/messages/bind.go b/postgres/messages/bind.go new file mode 100644 index 0000000000..67c1fd8bde --- /dev/null +++ b/postgres/messages/bind.go @@ -0,0 +1,184 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(Bind{}) + addMessageHeader(Bind{}) +} + +// Bind represents a PostgreSQL message. +type Bind struct { + DestinationPortal string + SourcePreparedStatement string + ParameterFormatCodes []int32 + ParameterValues []BindParameterValue + ResultFormatCodes []int32 +} + +// BindParameterValue are parameter values for the Bind message. +type BindParameterValue struct { + Data []byte + IsNull bool +} + +var bindDefault = MessageFormat{ + Name: "Bind", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('B'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "DestinationPortal", + Type: String, + Data: "", + }, + { + Name: "SourcePreparedStatement", + Type: String, + Data: "", + }, + { + Name: "ParameterFormatCodes", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ParameterFormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + { + Name: "ParameterValues", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ParameterLength", + Type: Int32, + Flags: ByteCount, + Data: int32(0), + }, + { + Name: "ParameterValue", + Type: ByteN, + Data: []byte{}, + }, + }, + }, + }, + { + Name: "ResultFormatCodes", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ResultFormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = Bind{} + +// encode implements the interface Message. +func (m Bind) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("DestinationPortal").MustWrite(m.DestinationPortal) + outputMessage.Field("SourcePreparedStatement").MustWrite(m.SourcePreparedStatement) + for i, pFormatCode := range m.ParameterFormatCodes { + outputMessage.Field("ParameterFormatCodes").Child("ParameterFormatCode", i).MustWrite(pFormatCode) + } + for i, paramValue := range m.ParameterValues { + if paramValue.IsNull { + outputMessage.Field("ParameterValues").Child("ParameterLength", i).MustWrite(-1) + } else { + outputMessage.Field("ParameterValues").Child("ParameterLength", i).MustWrite(len(paramValue.Data)) + outputMessage.Field("ParameterValues").Child("ParameterValue", i).MustWrite(paramValue.Data) + } + } + for i, rFormatCode := range m.ResultFormatCodes { + outputMessage.Field("ResultFormatCodes").Child("ResultFormatCode", i).MustWrite(rFormatCode) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m Bind) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + + // Get the parameter format codes + parameterFormatCodesCount := int(s.Field("ParameterFormatCodes").MustGet().(int32)) + parameterFormatCodes := make([]int32, parameterFormatCodesCount) + for i := 0; i < parameterFormatCodesCount; i++ { + parameterFormatCodes[i] = s.Field("ParameterFormatCodes").Child("ParameterFormatCode", i).MustGet().(int32) + } + // Get the parameter values + parameterValuesCount := int(s.Field("ParameterValues").MustGet().(int32)) + parameterValues := make([]BindParameterValue, parameterValuesCount) + for i := 0; i < parameterValuesCount; i++ { + paramLength := s.Field("ParameterValues").Child("ParameterLength", i).MustGet().(int32) + if paramLength == -1 { + parameterValues[i] = BindParameterValue{ + IsNull: true, + } + } else { + parameterValues[i] = BindParameterValue{ + Data: s.Field("ParameterValues").Child("ParameterValue", i).MustGet().([]byte), + IsNull: false, + } + } + } + // Get the result format codes + resultFormatCodesCount := int(s.Field("ResultFormatCodes").MustGet().(int32)) + resultFormatCodes := make([]int32, resultFormatCodesCount) + for i := 0; i < resultFormatCodesCount; i++ { + resultFormatCodes[i] = s.Field("ResultFormatCodes").Child("ResultFormatCode", i).MustGet().(int32) + } + + return Bind{ + DestinationPortal: s.Field("DestinationPortal").MustGet().(string), + SourcePreparedStatement: s.Field("SourcePreparedStatement").MustGet().(string), + ParameterFormatCodes: parameterFormatCodes, + ParameterValues: parameterValues, + ResultFormatCodes: resultFormatCodes, + }, nil +} + +// defaultMessage implements the interface Message. +func (m Bind) defaultMessage() *MessageFormat { + return &bindDefault +} diff --git a/postgres/messages/bind_complete.go b/postgres/messages/bind_complete.go new file mode 100644 index 0000000000..8959a9542e --- /dev/null +++ b/postgres/messages/bind_complete.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(BindComplete{}) + addMessageHeader(BindComplete{}) +} + +// BindComplete represents a PostgreSQL message. +type BindComplete struct{} + +var bindCompleteDefault = MessageFormat{ + Name: "BindComplete", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('2'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = BindComplete{} + +// encode implements the interface Message. +func (m BindComplete) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m BindComplete) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return BindComplete{}, nil +} + +// defaultMessage implements the interface Message. +func (m BindComplete) defaultMessage() *MessageFormat { + return &bindCompleteDefault +} diff --git a/postgres/messages/cancel_request.go b/postgres/messages/cancel_request.go new file mode 100644 index 0000000000..3cd8169b5c --- /dev/null +++ b/postgres/messages/cancel_request.go @@ -0,0 +1,78 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(CancelRequest{}) +} + +// CancelRequest represents a PostgreSQL message. +type CancelRequest struct { + ProcessID int32 + SecretKey int32 +} + +var cancelRequestDefault = MessageFormat{ + Name: "CancelRequest", + Fields: FieldGroup{ + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "RequestCode", + Type: Int32, + Data: int32(80877102), + }, + { + Name: "ProcessID", + Type: Int32, + Data: int32(0), + }, + { + Name: "SecretKey", + Type: Int32, + Data: int32(0), + }, + }, +} + +var _ Message = CancelRequest{} + +// encode implements the interface Message. +func (m CancelRequest) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("ProcessID").MustWrite(m.ProcessID) + outputMessage.Field("SecretKey").MustWrite(m.SecretKey) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CancelRequest) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return CancelRequest{ + ProcessID: s.Field("ProcessID").MustGet().(int32), + SecretKey: s.Field("SecretKey").MustGet().(int32), + }, nil +} + +// defaultMessage implements the interface Message. +func (m CancelRequest) defaultMessage() *MessageFormat { + return &cancelRequestDefault +} diff --git a/postgres/messages/close.go b/postgres/messages/close.go new file mode 100644 index 0000000000..0e156c9d89 --- /dev/null +++ b/postgres/messages/close.go @@ -0,0 +1,95 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +func init() { + initializeDefaultMessage(Close{}) + addMessageHeader(Close{}) +} + +// Close represents a PostgreSQL message. +type Close struct { + ClosingPreparedStatement bool // ClosingPreparedStatement: If true, closing a prepared statement. If false, closing a portal. + Target string // Target is the name of whatever we are closing. +} + +var closeDefault = MessageFormat{ + Name: "Close", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('C'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ClosingTarget", + Type: Byte1, + Data: int32(0), + }, + { + Name: "TargetName", + Type: String, + Data: "", + }, + }, +} + +var _ Message = Close{} + +// encode implements the interface Message. +func (m Close) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.ClosingPreparedStatement { + outputMessage.Field("ClosingTarget").MustWrite('S') + } else { + outputMessage.Field("ClosingTarget").MustWrite('P') + } + outputMessage.Field("TargetName").MustWrite(m.Target) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m Close) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + closingTarget := s.Field("ClosingTarget").MustGet().(int32) + var closingPreparedStatement bool + if closingTarget == 'S' { + closingPreparedStatement = true + } else if closingTarget == 'P' { + closingPreparedStatement = false + } else { + return nil, fmt.Errorf("Unknown closing target in Close message: %d", closingTarget) + } + return Close{ + ClosingPreparedStatement: closingPreparedStatement, + Target: s.Field("TargetName").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m Close) defaultMessage() *MessageFormat { + return &closeDefault +} diff --git a/postgres/messages/close_complete.go b/postgres/messages/close_complete.go new file mode 100644 index 0000000000..0097d1daa2 --- /dev/null +++ b/postgres/messages/close_complete.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(CloseComplete{}) + addMessageHeader(CloseComplete{}) +} + +// CloseComplete represents a PostgreSQL message. +type CloseComplete struct{} + +var closeCompleteDefault = MessageFormat{ + Name: "CloseComplete", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('3'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = CloseComplete{} + +// encode implements the interface Message. +func (m CloseComplete) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m CloseComplete) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return CloseComplete{}, nil +} + +// defaultMessage implements the interface Message. +func (m CloseComplete) defaultMessage() *MessageFormat { + return &closeCompleteDefault +} diff --git a/postgres/messages/command_complete.go b/postgres/messages/command_complete.go index de6c339546..c086f17dcd 100644 --- a/postgres/messages/command_complete.go +++ b/postgres/messages/command_complete.go @@ -1,76 +1,112 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages import ( - "bytes" "fmt" "strconv" "strings" ) -// CommandCompleteTag indicates which SQL command was completed. -type CommandCompleteTag byte - -const ( - CommandCompleteTag_INSERT CommandCompleteTag = iota - CommandCompleteTag_DELETE - CommandCompleteTag_UPDATE - CommandCompleteTag_MERGE - CommandCompleteTag_SELECT - CommandCompleteTag_MOVE - CommandCompleteTag_FETCH - CommandCompleteTag_COPY -) +func init() { + initializeDefaultMessage(CommandComplete{}) +} // CommandComplete tells the client that the command has completed. type CommandComplete struct { - Tag CommandCompleteTag - Rows int32 + Query string + Rows int32 } -// Bytes returns CommandComplete as a byte slice, ready to be returned to the client. -func (cc CommandComplete) Bytes() []byte { - buf := bytes.Buffer{} - buf.WriteByte('C') // Message Type - WriteNumber(&buf, int32(0)) // Message length, will be corrected later - switch cc.Tag { - case CommandCompleteTag_INSERT: - buf.WriteString("INSERT 0 ") - case CommandCompleteTag_DELETE: - buf.WriteString("DELETE ") - case CommandCompleteTag_UPDATE: - buf.WriteString("UPDATE ") - case CommandCompleteTag_MERGE: - buf.WriteString("MERGE ") - case CommandCompleteTag_SELECT: - buf.WriteString("SELECT ") - case CommandCompleteTag_MOVE: - buf.WriteString("MOVE ") - case CommandCompleteTag_FETCH: - buf.WriteString("FETCH ") - case CommandCompleteTag_COPY: - buf.WriteString("COPY ") +var commandCompleteDefault = MessageFormat{ + Name: "CommandComplete", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('C'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "CommandTag", + Type: String, + Data: "", + }, + }, +} + +var _ Message = CommandComplete{} + +// IsIUD returns whether the query is either an INSERT, UPDATE, or DELETE query. +func (m CommandComplete) IsIUD() bool { + query := strings.TrimSpace(strings.ToLower(m.Query)) + if strings.HasPrefix(query, "insert") || + strings.HasPrefix(query, "update") || + strings.HasPrefix(query, "delete") { + return true + } else { + return false } - buf.WriteString(strconv.Itoa(int(cc.Rows))) - buf.WriteByte(0) // Trailing NULL character, denoting the end of the string - return WriteLength(buf.Bytes()) } -// QueryToCommandCompleteTag returns the appropriate command tag for the given query. -func QueryToCommandCompleteTag(query string) (CommandCompleteTag, error) { - query = strings.TrimSpace(strings.ToLower(query)) +// encode implements the interface Message. +func (m CommandComplete) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + query := strings.TrimSpace(strings.ToLower(m.Query)) if strings.HasPrefix(query, "select") { - return CommandCompleteTag_SELECT, nil + outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows)) } else if strings.HasPrefix(query, "insert") { - return CommandCompleteTag_INSERT, nil + outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("INSERT 0 %d", m.Rows)) } else if strings.HasPrefix(query, "update") { - return CommandCompleteTag_UPDATE, nil + outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("UPDATE %d", m.Rows)) } else if strings.HasPrefix(query, "delete") { - return CommandCompleteTag_DELETE, nil + outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("DELETE %d", m.Rows)) } else if strings.HasPrefix(query, "create") { - return CommandCompleteTag_SELECT, nil + outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows)) } else if strings.HasPrefix(query, "call") { - return CommandCompleteTag_SELECT, nil + outputMessage.Field("CommandTag").MustWrite(fmt.Sprintf("SELECT %d", m.Rows)) } else { - return 0, fmt.Errorf("unsupported query for now") + return MessageFormat{}, fmt.Errorf("unsupported query for now") + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CommandComplete) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } + query := strings.TrimSpace(s.Field("CommandTag").MustGet().(string)) + tokens := strings.Split(query, " ") + rows, err := strconv.Atoi(tokens[len(tokens)-1]) + if err != nil { + return nil, err + } + return CommandComplete{ + Query: query, + Rows: int32(rows), + }, nil +} + +// defaultMessage implements the interface Message. +func (m CommandComplete) defaultMessage() *MessageFormat { + return &commandCompleteDefault } diff --git a/postgres/messages/copy_both_response.go b/postgres/messages/copy_both_response.go new file mode 100644 index 0000000000..58a23a2d30 --- /dev/null +++ b/postgres/messages/copy_both_response.go @@ -0,0 +1,110 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +func init() { + initializeDefaultMessage(CopyBothResponse{}) +} + +// CopyBothResponse represents a PostgreSQL message. +type CopyBothResponse struct { + IsTextual bool // IsTextual states whether the copy is textual or binary. + FormatCodes []int32 +} + +var copyBothResponseDefault = MessageFormat{ + Name: "CopyBothResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('W'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ResponseType", + Type: Int8, + Data: int32(0), + }, + { + Name: "Columns", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "FormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = CopyBothResponse{} + +// encode implements the interface Message. +func (m CopyBothResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.IsTextual { + outputMessage.Field("ResponseType").MustWrite(0) + } else { + outputMessage.Field("ResponseType").MustWrite(1) + } + for i, formatCode := range m.FormatCodes { + outputMessage.Field("Columns").Child("FormatCode", i).MustWrite(formatCode) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CopyBothResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + var isTextual bool + responseType := s.Field("ResponseType").MustGet().(int32) + if responseType == 0 { + isTextual = true + } else if responseType == 1 { + isTextual = false + } else { + return nil, fmt.Errorf("Unknown response type in the CopyBothResponse message: %d", responseType) + } + count := int(s.Field("Columns").MustGet().(int32)) + formatCodes := make([]int32, count) + for i := 0; i < count; i++ { + formatCodes[i] = s.Field("Columns").Child("FormatCode", i).MustGet().(int32) + } + return CopyBothResponse{ + IsTextual: isTextual, + FormatCodes: formatCodes, + }, nil +} + +// defaultMessage implements the interface Message. +func (m CopyBothResponse) defaultMessage() *MessageFormat { + return ©BothResponseDefault +} diff --git a/postgres/messages/copy_data.go b/postgres/messages/copy_data.go new file mode 100644 index 0000000000..5e33daccb9 --- /dev/null +++ b/postgres/messages/copy_data.go @@ -0,0 +1,72 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(CopyData{}) + addMessageHeader(CopyData{}) +} + +// CopyData represents a PostgreSQL message. +type CopyData struct { + Data []byte +} + +var copyDataDefault = MessageFormat{ + Name: "CopyData", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('d'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Data", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = CopyData{} + +// encode implements the interface Message. +func (m CopyData) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Data").MustWrite(m.Data) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CopyData) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return CopyData{ + Data: s.Field("Data").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m CopyData) defaultMessage() *MessageFormat { + return ©DataDefault +} diff --git a/postgres/messages/copy_done.go b/postgres/messages/copy_done.go new file mode 100644 index 0000000000..8e93802c42 --- /dev/null +++ b/postgres/messages/copy_done.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(CopyDone{}) + addMessageHeader(CopyDone{}) +} + +// CopyDone represents a PostgreSQL message. +type CopyDone struct{} + +var copyDoneDefault = MessageFormat{ + Name: "CopyDone", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('c'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = CopyDone{} + +// encode implements the interface Message. +func (m CopyDone) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m CopyDone) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return CopyDone{}, nil +} + +// defaultMessage implements the interface Message. +func (m CopyDone) defaultMessage() *MessageFormat { + return ©DoneDefault +} diff --git a/postgres/messages/copy_fail.go b/postgres/messages/copy_fail.go new file mode 100644 index 0000000000..55759af610 --- /dev/null +++ b/postgres/messages/copy_fail.go @@ -0,0 +1,72 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(CopyFail{}) + addMessageHeader(CopyFail{}) +} + +// CopyFail represents a PostgreSQL message. +type CopyFail struct { + ErrorMessage string +} + +var copyFailDefault = MessageFormat{ + Name: "CopyFail", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('f'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ErrorMessage", + Type: String, + Data: "", + }, + }, +} + +var _ Message = CopyFail{} + +// encode implements the interface Message. +func (m CopyFail) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("ErrorMessage").MustWrite(m.ErrorMessage) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CopyFail) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return CopyFail{ + ErrorMessage: s.Field("ErrorMessage").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m CopyFail) defaultMessage() *MessageFormat { + return ©FailDefault +} diff --git a/postgres/messages/copy_in_response.go b/postgres/messages/copy_in_response.go new file mode 100644 index 0000000000..2aea1c229f --- /dev/null +++ b/postgres/messages/copy_in_response.go @@ -0,0 +1,110 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +func init() { + initializeDefaultMessage(CopyInResponse{}) +} + +// CopyInResponse represents a PostgreSQL message. +type CopyInResponse struct { + IsTextual bool // IsTextual states whether the copy is textual or binary. + FormatCodes []int32 +} + +var copyInResponseDefault = MessageFormat{ + Name: "CopyInResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('G'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ResponseType", + Type: Int8, + Data: int32(0), + }, + { + Name: "Columns", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "FormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = CopyInResponse{} + +// encode implements the interface Message. +func (m CopyInResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.IsTextual { + outputMessage.Field("ResponseType").MustWrite(0) + } else { + outputMessage.Field("ResponseType").MustWrite(1) + } + for i, formatCode := range m.FormatCodes { + outputMessage.Field("Columns").Child("FormatCode", i).MustWrite(formatCode) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CopyInResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + var isTextual bool + responseType := s.Field("ResponseType").MustGet().(int32) + if responseType == 0 { + isTextual = true + } else if responseType == 1 { + isTextual = false + } else { + return nil, fmt.Errorf("Unknown response type in the CopyInResponse message: %d", responseType) + } + count := int(s.Field("Columns").MustGet().(int32)) + formatCodes := make([]int32, count) + for i := 0; i < count; i++ { + formatCodes[i] = s.Field("Columns").Child("FormatCode", i).MustGet().(int32) + } + return CopyInResponse{ + IsTextual: isTextual, + FormatCodes: formatCodes, + }, nil +} + +// defaultMessage implements the interface Message. +func (m CopyInResponse) defaultMessage() *MessageFormat { + return ©InResponseDefault +} diff --git a/postgres/messages/copy_out_response.go b/postgres/messages/copy_out_response.go new file mode 100644 index 0000000000..39e65b3c37 --- /dev/null +++ b/postgres/messages/copy_out_response.go @@ -0,0 +1,110 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +func init() { + initializeDefaultMessage(CopyOutResponse{}) +} + +// CopyOutResponse represents a PostgreSQL message. +type CopyOutResponse struct { + IsTextual bool // IsTextual states whether the copy is textual or binary. + FormatCodes []int32 +} + +var copyOutResponseDefault = MessageFormat{ + Name: "CopyOutResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('H'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ResponseType", + Type: Int8, + Data: int32(0), + }, + { + Name: "Columns", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "FormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = CopyOutResponse{} + +// encode implements the interface Message. +func (m CopyOutResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.IsTextual { + outputMessage.Field("ResponseType").MustWrite(0) + } else { + outputMessage.Field("ResponseType").MustWrite(1) + } + for i, formatCode := range m.FormatCodes { + outputMessage.Field("Columns").Child("FormatCode", i).MustWrite(formatCode) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m CopyOutResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + var isTextual bool + responseType := s.Field("ResponseType").MustGet().(int32) + if responseType == 0 { + isTextual = true + } else if responseType == 1 { + isTextual = false + } else { + return nil, fmt.Errorf("Unknown response type in the CopyOutResponse message: %d", responseType) + } + count := int(s.Field("Columns").MustGet().(int32)) + formatCodes := make([]int32, count) + for i := 0; i < count; i++ { + formatCodes[i] = s.Field("Columns").Child("FormatCode", i).MustGet().(int32) + } + return CopyOutResponse{ + IsTextual: isTextual, + FormatCodes: formatCodes, + }, nil +} + +// defaultMessage implements the interface Message. +func (m CopyOutResponse) defaultMessage() *MessageFormat { + return ©OutResponseDefault +} diff --git a/postgres/messages/data_row.go b/postgres/messages/data_row.go index 2ddddd9e00..76e84c0a2e 100644 --- a/postgres/messages/data_row.go +++ b/postgres/messages/data_row.go @@ -1,103 +1,102 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages import ( - "bytes" - "fmt" - "strconv" - "time" - "github.com/dolthub/vitess/go/sqltypes" - - "github.com/shopspring/decimal" ) +func init() { + initializeDefaultMessage(DataRow{}) +} + // DataRow represents a row of data. type DataRow struct { - Values []DataRowValue + Values []sqltypes.Value } -// DataRowValue represents a column's value in a DataRow. -type DataRowValue struct { - Value any +var dataRowDefault = MessageFormat{ + Name: "DataRow", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('D'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Columns", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ColumnLength", + Type: Int32, + Flags: ByteCount, + Data: int32(0), + }, + { + Name: "ColumnData", + Type: ByteN, + Data: []byte{}, + }, + }, + }, + }, + }, } -// NewDataRow creates a new DataRow from the given rows. -func NewDataRow(row []sqltypes.Value) DataRow { - values := make([]DataRowValue, len(row)) - for i, value := range row { - values[i] = DataRowValue{value.ToString()} - } - return DataRow{ - Values: values, +var _ Message = DataRow{} + +// encode implements the interface Message. +func (m DataRow) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + for i := 0; i < len(m.Values); i++ { + if m.Values[i].IsNull() { + outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(-1) + } else { + value := []byte(m.Values[i].ToString()) + outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(len(value)) + outputMessage.Field("Columns").Child("ColumnData", i).MustWrite(value) + } } + return outputMessage, nil } -// Bytes returns DataRow as a byte slice, ready to be returned to the client. -func (dr DataRow) Bytes() []byte { - buf := bytes.Buffer{} - buf.WriteByte('D') // Message Type - WriteNumber(&buf, int32(0)) // Message length, will be corrected later - WriteNumber(&buf, int16(len(dr.Values))) - for _, drv := range dr.Values { - drv.Bytes(&buf) +// decode implements the interface Message. +func (m DataRow) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } - return WriteLength(buf.Bytes()) + columnCount := int(s.Field("Columns").MustGet().(int32)) + for i := 0; i < columnCount; i++ { + //TODO: decode the message in here + } + return DataRow{ + Values: nil, + }, nil } -// Bytes writes the value into the given buffer. -func (drv DataRowValue) Bytes(buf *bytes.Buffer) { - var dataBytes []byte - switch val := drv.Value.(type) { - case int: - dataBytes = []byte(strconv.FormatInt(int64(val), 10)) - case int8: - dataBytes = []byte(strconv.FormatInt(int64(val), 10)) - case int16: - dataBytes = []byte(strconv.FormatInt(int64(val), 10)) - case int32: - dataBytes = []byte(strconv.FormatInt(int64(val), 10)) - case int64: - dataBytes = []byte(strconv.FormatInt(val, 10)) - case uint: - dataBytes = []byte(strconv.FormatUint(uint64(val), 10)) - case uint8: - dataBytes = []byte(strconv.FormatUint(uint64(val), 10)) - case uint16: - dataBytes = []byte(strconv.FormatUint(uint64(val), 10)) - case uint32: - dataBytes = []byte(strconv.FormatUint(uint64(val), 10)) - case uint64: - dataBytes = []byte(strconv.FormatUint(val, 10)) - case float32: - dataBytes = []byte(strconv.FormatFloat(float64(val), 'g', -1, 32)) - case float64: - dataBytes = []byte(strconv.FormatFloat(val, 'g', -1, 64)) - case decimal.NullDecimal: - if !val.Valid { - WriteNumber(buf, int32(-1)) - return - } - dataBytes = []byte(val.Decimal.String()) - case decimal.Decimal: - dataBytes = []byte(val.String()) - case []byte: - dataBytes = val - case string: - dataBytes = []byte(val) - case bool: - if val { - dataBytes = []byte("true") - } else { - dataBytes = []byte("false") - } - case time.Time: - dataBytes = []byte(val.Format(time.RFC3339)) - case nil: - WriteNumber(buf, int32(-1)) - return - default: - panic(fmt.Errorf("unknown DataRow value type: %T", val)) - } - WriteNumber(buf, int32(len(dataBytes))) // This length only covers the value's size - buf.Write(dataBytes) +// defaultMessage implements the interface Message. +func (m DataRow) defaultMessage() *MessageFormat { + return &dataRowDefault } diff --git a/postgres/messages/describe.go b/postgres/messages/describe.go new file mode 100644 index 0000000000..ef7f54f403 --- /dev/null +++ b/postgres/messages/describe.go @@ -0,0 +1,95 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +func init() { + initializeDefaultMessage(Describe{}) + addMessageHeader(Describe{}) +} + +// Describe represents a PostgreSQL message. +type Describe struct { + IsPrepared bool // IsPrepared states whether we're describing a prepared statement or a portal. + Target string +} + +var describeDefault = MessageFormat{ + Name: "Describe", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('D'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "DescribingTarget", + Type: Byte1, + Data: int32(0), + }, + { + Name: "TargetName", + Type: String, + Data: "", + }, + }, +} + +var _ Message = Describe{} + +// encode implements the interface Message. +func (m Describe) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.IsPrepared { + outputMessage.Field("DescribingTarget").MustWrite('S') + } else { + outputMessage.Field("DescribingTarget").MustWrite('P') + } + outputMessage.Field("TargetName").MustWrite(m.Target) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m Describe) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + describingTarget := s.Field("DescribingTarget").MustGet().(int32) + var isPrepared bool + if describingTarget == 'S' { + isPrepared = true + } else if describingTarget == 'P' { + isPrepared = false + } else { + return nil, fmt.Errorf("Unknown describing target in Describe message: %d", describingTarget) + } + return Describe{ + IsPrepared: isPrepared, + Target: s.Field("TargetName").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m Describe) defaultMessage() *MessageFormat { + return &describeDefault +} diff --git a/postgres/messages/empty_query_response.go b/postgres/messages/empty_query_response.go new file mode 100644 index 0000000000..43748f1016 --- /dev/null +++ b/postgres/messages/empty_query_response.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(EmptyQueryResponse{}) + addMessageHeader(EmptyQueryResponse{}) +} + +// EmptyQueryResponse represents a PostgreSQL message. +type EmptyQueryResponse struct{} + +var emptyQueryResponseDefault = MessageFormat{ + Name: "EmptyQueryResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('I'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = EmptyQueryResponse{} + +// encode implements the interface Message. +func (m EmptyQueryResponse) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m EmptyQueryResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return EmptyQueryResponse{}, nil +} + +// defaultMessage implements the interface Message. +func (m EmptyQueryResponse) defaultMessage() *MessageFormat { + return &emptyQueryResponseDefault +} diff --git a/postgres/messages/error_response.go b/postgres/messages/error_response.go new file mode 100644 index 0000000000..426ab237fb --- /dev/null +++ b/postgres/messages/error_response.go @@ -0,0 +1,103 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(ErrorResponse{}) +} + +// ErrorResponse represents a PostgreSQL message. +type ErrorResponse struct { + Fields []ErrorResponseField +} + +// ErrorResponseField are the fields to an ErrorResponse message. +type ErrorResponseField struct { + Code int32 + Value string +} + +var errorResponseDefault = MessageFormat{ + Name: "ErrorResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('E'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Fields", + Type: Repeated, + Flags: RepeatedTerminator, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "Code", + Type: Byte1, + Data: int32(0), + }, + { + Name: "Value", + Type: String, + Data: "", + }, + }, + }, + }, + }, +} + +var _ Message = ErrorResponse{} + +// encode implements the interface Message. +func (m ErrorResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + for i, field := range m.Fields { + outputMessage.Field("Fields").Child("Code", i).MustWrite(field.Code) + outputMessage.Field("Fields").Child("Value", i).MustWrite(field.Value) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m ErrorResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + count := int(s.Field("Fields").MustGet().(int32)) + fields := make([]ErrorResponseField, count) + for i := 0; i < count; i++ { + fields[i] = ErrorResponseField{ + Code: s.Field("Fields").Child("Code", i).MustGet().(int32), + Value: s.Field("Fields").Child("Value", i).MustGet().(string), + } + } + return ErrorResponse{ + Fields: fields, + }, nil +} + +// defaultMessage implements the interface Message. +func (m ErrorResponse) defaultMessage() *MessageFormat { + return &errorResponseDefault +} diff --git a/postgres/messages/execute.go b/postgres/messages/execute.go new file mode 100644 index 0000000000..d75569da1d --- /dev/null +++ b/postgres/messages/execute.go @@ -0,0 +1,80 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(Execute{}) + addMessageHeader(Execute{}) +} + +// Execute represents a PostgreSQL message. +type Execute struct { + Portal string + RowMax int32 +} + +var executeDefault = MessageFormat{ + Name: "Execute", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('E'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Portal", + Type: String, + Data: "", + }, + { + Name: "RowMax", + Type: Int32, + Data: int32(0), + }, + }, +} + +var _ Message = Execute{} + +// encode implements the interface Message. +func (m Execute) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Portal").MustWrite(m.Portal) + outputMessage.Field("RowMax").MustWrite(m.RowMax) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m Execute) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return Execute{ + Portal: s.Field("Portal").MustGet().(string), + RowMax: s.Field("RowMax").MustGet().(int32), + }, nil +} + +// defaultMessage implements the interface Message. +func (m Execute) defaultMessage() *MessageFormat { + return &executeDefault +} diff --git a/postgres/messages/flush.go b/postgres/messages/flush.go new file mode 100644 index 0000000000..69484803ae --- /dev/null +++ b/postgres/messages/flush.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(Flush{}) + addMessageHeader(Flush{}) +} + +// Flush represents a PostgreSQL message. +type Flush struct{} + +var flushDefault = MessageFormat{ + Name: "Flush", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('H'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + }, +} + +var _ Message = Flush{} + +// encode implements the interface Message. +func (m Flush) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m Flush) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return Flush{}, nil +} + +// defaultMessage implements the interface Message. +func (m Flush) defaultMessage() *MessageFormat { + return &flushDefault +} diff --git a/postgres/messages/function_call.go b/postgres/messages/function_call.go new file mode 100644 index 0000000000..d3f1a500ae --- /dev/null +++ b/postgres/messages/function_call.go @@ -0,0 +1,159 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(FunctionCall{}) + addMessageHeader(FunctionCall{}) +} + +// FunctionCall represents a PostgreSQL message. +type FunctionCall struct { + ObjectID int32 + ArgumentFormatCodes []int32 + Arguments []FunctionCallArgument + ResultFormatCode int32 +} + +// FunctionCallArgument are arguments for the FunctionCall message. +type FunctionCallArgument struct { + Data []byte + IsNull bool +} + +var functionCallDefault = MessageFormat{ + Name: "FunctionCall", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('F'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ObjectID", + Type: Int32, + Data: int32(0), + }, + { + Name: "ArgumentFormatCodes", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ArgumentFormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + { + Name: "Arguments", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ArgumentLength", + Type: Int32, + Flags: ByteCount, + Data: int32(0), + }, + { + Name: "ArgumentValue", + Type: ByteN, + Data: []byte{}, + }, + }, + }, + }, + { + Name: "ResultFormatCode", + Type: Int16, + Data: int32(0), + }, + }, +} + +var _ Message = FunctionCall{} + +// encode implements the interface Message. +func (m FunctionCall) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("ObjectID").MustWrite(m.ObjectID) + for i, formatCode := range m.ArgumentFormatCodes { + outputMessage.Field("ArgumentFormatCodes").Child("ArgumentFormatCode", i).MustWrite(formatCode) + } + for i, argument := range m.Arguments { + if argument.IsNull { + outputMessage.Field("Arguments").Child("ArgumentLength", i).MustWrite(-1) + } else { + outputMessage.Field("Arguments").Child("ArgumentLength", i).MustWrite(len(argument.Data)) + outputMessage.Field("Arguments").Child("ArgumentValue", i).MustWrite(argument.Data) + } + } + outputMessage.Field("ResultFormatCode").MustWrite(m.ResultFormatCode) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m FunctionCall) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + + // Get the argument format codes + argumentFormatCodesCount := int(s.Field("ArgumentFormatCodes").MustGet().(int32)) + argumentFormatCodes := make([]int32, argumentFormatCodesCount) + for i := 0; i < argumentFormatCodesCount; i++ { + argumentFormatCodes[i] = s.Field("ArgumentFormatCodes").Child("ArgumentFormatCode", i).MustGet().(int32) + } + // Get the arguments + argumentsCount := int(s.Field("Arguments").MustGet().(int32)) + arguments := make([]FunctionCallArgument, argumentsCount) + for i := 0; i < argumentsCount; i++ { + paramLength := s.Field("Arguments").Child("ArgumentLength", i).MustGet().(int32) + if paramLength == -1 { + arguments[i] = FunctionCallArgument{ + IsNull: true, + } + } else { + arguments[i] = FunctionCallArgument{ + Data: s.Field("Arguments").Child("ArgumentValue", i).MustGet().([]byte), + IsNull: false, + } + } + } + + return FunctionCall{ + ObjectID: s.Field("ObjectID").MustGet().(int32), + ArgumentFormatCodes: argumentFormatCodes, + Arguments: arguments, + ResultFormatCode: s.Field("ResultFormatCode").MustGet().(int32), + }, nil +} + +// defaultMessage implements the interface Message. +func (m FunctionCall) defaultMessage() *MessageFormat { + return &functionCallDefault +} diff --git a/postgres/messages/function_call_response.go b/postgres/messages/function_call_response.go new file mode 100644 index 0000000000..21516170d5 --- /dev/null +++ b/postgres/messages/function_call_response.go @@ -0,0 +1,88 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(FunctionCallResponse{}) +} + +// FunctionCallResponse represents a PostgreSQL message. +type FunctionCallResponse struct { + IsResultNull bool + ResultValue []byte +} + +var functionCallResponseDefault = MessageFormat{ + Name: "FunctionCallResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('V'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ResultLength", + Type: Int32, + Flags: ByteCount, + Data: int32(0), + }, + { + Name: "ResultValue", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = FunctionCallResponse{} + +// encode implements the interface Message. +func (m FunctionCallResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.IsResultNull { + outputMessage.Field("ResultLength").MustWrite(-1) + } else { + if m.ResultValue == nil { + m.ResultValue = []byte{} + } + outputMessage.Field("ResultLength").MustWrite(-1) + outputMessage.Field("ResultValue").MustWrite(m.ResultValue) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m FunctionCallResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + isNull := s.Field("ResultLength").MustGet().(int32) == -1 + return FunctionCallResponse{ + IsResultNull: isNull, + ResultValue: s.Field("ResultValue").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m FunctionCallResponse) defaultMessage() *MessageFormat { + return &functionCallResponseDefault +} diff --git a/postgres/messages/gss_response.go b/postgres/messages/gss_response.go new file mode 100644 index 0000000000..104b56eef1 --- /dev/null +++ b/postgres/messages/gss_response.go @@ -0,0 +1,71 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(GSSResponse{}) +} + +// GSSResponse represents a PostgreSQL message. +type GSSResponse struct { + Data []byte +} + +var gSSResponseDefault = MessageFormat{ + Name: "GSSResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('p'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Data", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = GSSResponse{} + +// encode implements the interface Message. +func (m GSSResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Data").MustWrite(m.Data) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m GSSResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return GSSResponse{ + Data: s.Field("Data").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m GSSResponse) defaultMessage() *MessageFormat { + return &gSSResponseDefault +} diff --git a/postgres/messages/gssenc_request.go b/postgres/messages/gssenc_request.go new file mode 100644 index 0000000000..8e2c4e5ced --- /dev/null +++ b/postgres/messages/gssenc_request.go @@ -0,0 +1,59 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(GSSENCRequest{}) +} + +// GSSENCRequest represents a PostgreSQL message. +type GSSENCRequest struct{} + +var gSSENCRequestDefault = MessageFormat{ + Name: "GSSENCRequest", + Fields: FieldGroup{ + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "RequestCode", + Type: Int32, + Data: int32(80877104), + }, + }, +} + +var _ Message = GSSENCRequest{} + +// encode implements the interface Message. +func (m GSSENCRequest) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m GSSENCRequest) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return GSSENCRequest{}, nil +} + +// defaultMessage implements the interface Message. +func (m GSSENCRequest) defaultMessage() *MessageFormat { + return &gSSENCRequestDefault +} diff --git a/postgres/messages/message.go b/postgres/messages/message.go new file mode 100644 index 0000000000..f1b97688d3 --- /dev/null +++ b/postgres/messages/message.go @@ -0,0 +1,88 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import ( + "fmt" + "strings" +) + +// MessageFormat is the format of a message as defined by PostgreSQL. Contains the description and values. +// https://www.postgresql.org/docs/15/protocol-message-formats.html +type MessageFormat struct { + Name string + Fields FieldGroup + info *messageInfo + isDefault bool +} + +// Message is a type that represents a PostgreSQL message. +type Message interface { + // encode returns a new MessageFormat containing any modified data contained within the object. This should NOT be + // the default message. + encode() (MessageFormat, error) + // decode returns a new Message that represents the given MessageFormat. You should never return the default + // message, even if the message never varies from the default. Always make a copy, and then modify that copy. + decode(s MessageFormat) (Message, error) + // defaultMessage returns the default, unmodified message for this type. + defaultMessage() *MessageFormat +} + +// messageFieldInfo contains information on a specific field within a messageInfo. +type messageFieldInfo struct { + RelativeIndex int + Parent string + UsesByteCount bool // Only used by ByteN fields +} + +// messageInfo contains all of the information that a message should keep track of. Used internally by messages. +type messageInfo struct { + fieldInfo map[string]messageFieldInfo + appendNullByte bool + defaultMessage *MessageFormat +} + +// Copy returns a copy of the MessageFormat, which is free to modify. +func (m MessageFormat) Copy() MessageFormat { + newFields := make(FieldGroup, len(m.Fields)) + for i, field := range m.Fields { + newFields[i] = field.Copy() + } + return MessageFormat{m.Name, newFields, m.info, false} +} + +// String returns a printable version of the MessageFormat. +func (m MessageFormat) String() string { + buffer := strings.Builder{} + buffer.WriteString(fmt.Sprintf("%s: {\n", m.Name)) + buffer.WriteString("\n") //TODO: print this + buffer.WriteString("}") + return buffer.String() +} + +// MatchesStructure returns an error if the given MessageFormat has a different structure than the calling MessageFormat. +func (m MessageFormat) MatchesStructure(otherMessage MessageFormat) error { + //TODO: check this + return nil +} + +// Field returns a MessageWriter for the calling MessageFormat, which makes it easier (and safer) to update the field whose +// name was given. +func (m MessageFormat) Field(name string) MessageWriter { + return MessageWriter{ + message: m, + fieldQueue: []messageWriterChildPosition{{name, 0}}, + } +} diff --git a/postgres/messages/message_decode_encode.go b/postgres/messages/message_decode_encode.go new file mode 100644 index 0000000000..2f0ced99a4 --- /dev/null +++ b/postgres/messages/message_decode_encode.go @@ -0,0 +1,295 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "net" +) + +// Receive returns a Message from the given buffer, generally generated by the client in the main read loop of a +// connection. +func Receive(buffer []byte) (Message, bool, error) { + if len(buffer) == 0 { + return nil, false, nil + } + message, ok := allMessageHeaders[buffer[0]] + if !ok { + return nil, false, nil + } + outMessage, err := ReceiveInto(buffer, message) + return outMessage, true, err +} + +// ReceiveInto writes the contents of the buffer into the given Message. +func ReceiveInto[T Message](buffer []byte, message T) (out T, err error) { + defaultMessage := message.defaultMessage() + fields := defaultMessage.Copy().Fields + if err = decode(&decodeBuffer{buffer}, []FieldGroup{fields}, 1); err != nil { + return out, err + } + decodedMessage, err := message.decode(MessageFormat{defaultMessage.Name, fields, defaultMessage.info, false}) + if err != nil { + return out, err + } + return decodedMessage.(T), nil +} + +// Send sends the given message over the connection. +func Send(conn net.Conn, message Message) error { + encodedMessage, err := message.encode() + if err != nil { + return err + } + data, err := encode(encodedMessage) + if err != nil { + return err + } + _, err = conn.Write(data) + return err +} + +// decodeBuffer just provides an easy way to reference the same buffer, so that decode can modify its length. +type decodeBuffer struct { + data []byte +} + +// decode writes the contents of the buffer into the given fields. The iteration count determines how many times the +// fields will be looped over. +func decode(buffer *decodeBuffer, fields []FieldGroup, iterations int32) error { + for iteration := int32(0); iteration < iterations; iteration++ { + for i, field := range fields[iteration] { + if len(buffer.data) == 0 { + return errors.New("buffer too small") + } + switch field.Type { + case Byte1, Int8: + field.Data = int32(buffer.data[0]) + buffer.data = buffer.data[1:] + case ByteN: + if i > 0 && fields[iteration][i-1].Flags&ByteCount != 0 { + byteCount := fields[iteration][i-1].Data.(int32) + // -1 is a valid value for byte counts, which is used to signal a NULL value. + // We don't need to care about the assumption, so we can just treat it equivalent to zero. + if byteCount == -1 { + byteCount = 0 + } + data := make([]byte, byteCount) + copy(data, buffer.data) + field.Data = data + buffer.data = buffer.data[byteCount:] + } else { + data := make([]byte, len(buffer.data)) + copy(data, buffer.data) + field.Data = data + buffer.data = nil + } + case Int16: + field.Data = int32(binary.BigEndian.Uint16(buffer.data)) + buffer.data = buffer.data[2:] + case Int32: + field.Data = int32(binary.BigEndian.Uint32(buffer.data)) + buffer.data = buffer.data[4:] + case String: + found := false + for bufferIdx := range buffer.data { + if buffer.data[bufferIdx] == 0 { + field.Data = string(buffer.data[:bufferIdx]) + buffer.data = buffer.data[bufferIdx:] + if field.Flags&ExcludeTerminator == 0 { + buffer.data = buffer.data[1:] + } + found = true + break + } + } + if !found { + return errors.New("terminating zero not found for string") + } + case Repeated: + // Track if we've decoded at least once, so that we only update the count if we've decoded something + decodedAtLeastOnce := false + originalChildren := field.Copy().Children[0] + for i := 1; len(buffer.data) > 0; i++ { + // If there is only a single byte left, then it may be the terminator, so we check. + // Otherwise, we'll assume that we should pass it to the child. + if len(buffer.data) == 1 && field.Flags&RepeatedTerminator != 0 { + if buffer.data[0] == 0 { + buffer.data = buffer.data[1:] + break + } else { + return fmt.Errorf("Expected terminator after Repeated type, found invalid byte: %d", buffer.data[0]) + } + } + field.extend(i, originalChildren) + if err := decode(buffer, field.Children[len(field.Children)-1:], 1); err != nil { + return err + } + decodedAtLeastOnce = true + } + if decodedAtLeastOnce { + field.Data = int32(len(field.Children)) + } + default: + panic("message type has not been defined") + } + + if field.Flags&MessageLengthInclusive != 0 { + messageLength := field.Data.(int32) + switch field.Type { + case Byte1, Int8: + messageLength -= 1 + case Int16: + messageLength -= 2 + case Int32: + messageLength -= 4 + } + buffer.data = buffer.data[:messageLength] + } else if field.Flags&MessageLengthExclusive != 0 { + buffer.data = buffer.data[:field.Data.(int32)] + } + if len(field.Children) > 0 && field.Type != Repeated { + count, ok := field.Data.(int32) + if !ok { + return errors.New("non-integer is being used as a count") + } + // Counts may be negative numbers, which have a special meaning depending on the message. + // In all such cases, they'll never have children, so we can just check for cases where it's > 0. + if count > 0 { + field.extend(int(count), field.Children[0]) + if err := decode(buffer, field.Children, count); err != nil { + return err + } + } + } + } + } + return nil +} + +// encode transforms the message into a byte slice, which may be sent to a connection. +func encode(ms MessageFormat) ([]byte, error) { + buffer := bytes.Buffer{} + encodeLoop(&buffer, []FieldGroup{ms.Fields}, 1) + if ms.info.appendNullByte { + buffer.WriteByte(0) + } + data := buffer.Bytes() + + // Find and write the message length + byteOffset := int32(0) + for i, field := range ms.Fields { + if field.Flags&(MessageLengthInclusive|MessageLengthExclusive) != 0 { + typeLength := int32(0) + // Exclusive lengths must take their own type size into account and exclude them from the overall length + if field.Flags&MessageLengthExclusive != 0 { + switch field.Type { + case Byte1, Int8: + typeLength = 1 + case Int16: + typeLength = 2 + case Int32: + typeLength = 4 + } + } + messageLength := int32(len(data)) - byteOffset - typeLength + switch field.Type { + case Byte1, Int8: + data[byteOffset] = byte(messageLength) + case Int16: + binary.BigEndian.PutUint16(data[byteOffset:], uint16(messageLength)) + case Int32: + binary.BigEndian.PutUint32(data[byteOffset:], uint32(messageLength)) + default: + panic("invalid type for message length") + } + break + } + + // Advance the offset + switch field.Type { + case Byte1, Int8: + byteOffset += 1 + case ByteN: + if i > 0 && ms.Fields[i-1].Flags&ByteCount != 0 { + byteOffset += ms.Fields[i-1].Data.(int32) + } else { + byteOffset = int32(len(data)) // Last field, so we can set it to the remaining data + } + case Int16: + byteOffset += 2 + case Int32: + byteOffset += 4 + case String: + found := false + for bufferIdx := range data[byteOffset:] { + if data[bufferIdx] == 0 { + found = true + byteOffset += int32(bufferIdx) + //TODO: is this the correct place to put this? investigate/test + if field.Flags&ExcludeTerminator == 0 { + byteOffset += 1 + } + break + } + } + if !found { + return nil, errors.New("terminating zero not found for string") + } + case Repeated: + byteOffset = int32(len(data)) // Last field, so we can set it to the remaining data + default: + panic("message type has not been defined") + } + } + return data, nil +} + +// encodeLoop is the inner recursive loop of encode, which writes the given fields into the buffer. The iteration +// count determines how many times the fields are looped over. +func encodeLoop(buffer *bytes.Buffer, fields []FieldGroup, iterations int32) { + for iteration := int32(0); iteration < iterations; iteration++ { + for _, field := range fields[iteration] { + switch field.Type { + case Byte1: + _ = binary.Write(buffer, binary.BigEndian, byte(field.Data.(int32))) + case ByteN: + buffer.Write(field.Data.([]byte)) + case Int8: + _ = binary.Write(buffer, binary.BigEndian, int8(field.Data.(int32))) + case Int16: + _ = binary.Write(buffer, binary.BigEndian, int16(field.Data.(int32))) + case Int32: + _ = binary.Write(buffer, binary.BigEndian, field.Data.(int32)) + case String: + buffer.WriteString(field.Data.(string)) + if field.Flags&ExcludeTerminator == 0 { + buffer.WriteByte(0) + } + case Repeated: + // We don't write anything for repeated fields, since they repeat their children until the end + default: + panic("message type has not been defined") + } + + if len(field.Children) > 0 { + encodeLoop(buffer, field.Children, field.Data.(int32)) + } + } + } +} diff --git a/postgres/messages/message_field.go b/postgres/messages/message_field.go new file mode 100644 index 0000000000..7925656de9 --- /dev/null +++ b/postgres/messages/message_field.go @@ -0,0 +1,84 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +// FieldType is the type of the field as defined by PostgreSQL. +type FieldType byte + +const ( + Byte1 FieldType = iota // Byte1 is a single unsigned byte. + ByteN // ByteN is a variable number of bytes. Allowed on the last field, or when a ByteCount-tagged field precedes it. + Int8 // Int8 is a single signed byte. + Int16 // Int16 are two bytes. + Int32 // Int32 are four bytes. + String // String is a variable-length type, generally punctuated by a NULL terminator. + Repeated // Repeated is a parent type that states its children will be repeated until the end of the message. + + Byte4 = Int32 //TODO: verify that this is correct, only used on one type +) + +// FieldFlags are special attributes that may be assigned to fields. +type FieldFlags int32 + +const ( + Header FieldFlags = 1 << iota // Header is the header of the message. + MessageLengthInclusive // MessageLengthInclusive is the length of the message, including the count's size. + MessageLengthExclusive // MessageLengthExclusive is the length of the message, excluding the count's size. + ExcludeTerminator // ExcludeTerminator excludes the terminator for String types. + ByteCount // ByteCount signals that the following ByteN non-child field uses this field for its count. + RepeatedTerminator // RepeatedTerminator states that the Repeated type always ends with a NULL terminator. +) + +// FieldGroup is a slice of fields. Mainly used for organization, as []FieldGroup looks better than []FieldGroup. +type FieldGroup []*Field + +// Field is a field within the PostgreSQL message. +type Field struct { + Name string + Type FieldType + Flags FieldFlags + Data any // Data may ONLY be one of the following: int32, string, []byte. Nil is not allowed. + Children []FieldGroup +} + +// Copy returns a copy of this field, which is free to modify. +func (f *Field) Copy() *Field { + fieldCopy := *f + if len(f.Children) > 0 { + newChildren := make([]FieldGroup, len(f.Children)) + for groupIndex, fieldGroup := range f.Children { + newFields := make(FieldGroup, len(fieldGroup)) + for fieldIndex, field := range fieldGroup { + newFields[fieldIndex] = field.Copy() + } + newChildren[groupIndex] = newFields + } + fieldCopy.Children = newChildren + } + return &fieldCopy +} + +// extend lengthens the children to the new length, using the given default children to fill each newly-created entry. +// All new entries will be copied from the default, therefore they're free to modify. This modifies the calling Field +// in-place. +func (f *Field) extend(newLength int, defaultChildren FieldGroup) { + for currentIndex := len(f.Children); currentIndex < newLength; currentIndex++ { + newFields := make(FieldGroup, len(defaultChildren)) + for i, field := range defaultChildren { + newFields[i] = field.Copy() + } + f.Children = append(f.Children, newFields) + } +} diff --git a/postgres/messages/message_initialization.go b/postgres/messages/message_initialization.go new file mode 100644 index 0000000000..962fb6d044 --- /dev/null +++ b/postgres/messages/message_initialization.go @@ -0,0 +1,206 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +// allMessageHeaders contains any message headers that should be read within the main read loop of a connection. +var allMessageHeaders = make(map[byte]Message) + +// allMessageNames contains the names of all messages, as they should all be unique. +var allMessageNames = make(map[string]struct{}) + +// allMessageDefaults contains all of the default message pointers, to make sure that they're not accidentally being reused. +var allMessageDefaults = make(map[*MessageFormat]struct{}) + +// addMessageHeader adds the given Message's header. This also ensures that each header is unique. This should be +// called in an init() function. +func addMessageHeader(message Message) { + for _, field := range message.defaultMessage().Fields { + if field.Flags&Header != 0 { + header := byte(field.Data.(int32)) + if _, ok := allMessageHeaders[header]; ok { + panic(fmt.Errorf("Header already taken.\nMessageFormat:\n\n%s", message.defaultMessage().String())) + } + allMessageHeaders[header] = message + return + } + } + panic(fmt.Errorf("Header does not exist.\nMessageFormat:\n\n%s", message.defaultMessage().String())) +} + +// initializeDefaultMessage creates the internal structure of the default message, while ensuring that the structure of +// the message is correct. This should be called in an init() function. +func initializeDefaultMessage(messageType Message) { + message := messageType.defaultMessage() + if _, ok := allMessageDefaults[message]; ok { + panic(fmt.Errorf("MessageFormat default was used in another message.\nMessageFormat:\n\n%s", message.String())) + } + allMessageDefaults[message] = struct{}{} + if message.info != nil { + panic(fmt.Errorf("Message has already been initialized.\nMessage:\n\n%s", message.String())) + } + if _, ok := allMessageNames[message.Name]; ok { + panic(fmt.Errorf("Message has already been initialized with the same name.\nName: %s", message.Name)) + } + allMessageNames[message.Name] = struct{}{} + message.info = &messageInfo{make(map[string]messageFieldInfo), false, message} + message.isDefault = true + + allFieldNames := make(map[string]struct{}) // Verify that all field names are unique + headerFound := false // Only one header may exist in the message + messageLengthFound := false // Only one message length may exist in the message + endingByteNFound := false // If a ByteN has been found that does not have a preceding ByteCount + repeatedFoundHeight := 0 // The depth that a Repeated type has been found + type FieldTraversal struct { + Index int + Fields FieldGroup + } + + ftStack := NewStack[FieldTraversal]() + ftStack.Push(FieldTraversal{0, message.Fields}) + for !ftStack.Empty() { + // If we're at the end of the loop for this stacked entry, then we pop it and move to the next + if ftStack.Peek().Index >= len(ftStack.Peek().Fields) { + _ = ftStack.Pop() + continue + } + // Check if we've found a ByteN that is not preceded by a ByteCount-tagged field, as it should be the last + // field, and we're now looking at a field after it. + if endingByteNFound { + panic(fmt.Errorf("ByteN found that was not preceded by a field with the ByteCount tag.\nMessageFormat:\n\n%s", message.String())) + } + // If the stack is larger than Repeated's height, then we're probably in Repeated's children. + // Otherwise, there are more non-child fields after the Repeated type. + if ftStack.Len() <= repeatedFoundHeight { + panic(fmt.Errorf("Repeated is not on the last field at its level\nMessageFormat:\n\n%s", message.String())) + } + // Grab the field. + field := ftStack.Peek().Fields[ftStack.Peek().Index] + // Verify uniqueness and correctness of tags (if any) + if field.Flags&Header != 0 { + if headerFound { + panic(fmt.Errorf("Multiple headers in message.\nMessageFormat:\n\n%s", message.String())) + } + headerFound = true + } + if field.Flags&(MessageLengthInclusive|MessageLengthExclusive) != 0 { + if messageLengthFound { + panic(fmt.Errorf("Multiple message lengths in message.\nMessageFormat:\n\n%s", message.String())) + } + switch field.Type { + case Byte1, Int8, Int16, Int32: + default: + panic(fmt.Errorf("Message length tags are only allowed on integer types.\nField: %s\nMessage:\n\n%s", field.Name, message.String())) + } + messageLengthFound = true + } + if field.Flags&ByteCount != 0 { + switch field.Type { + case Byte1, Int8, Int16, Int32: + default: + panic(fmt.Errorf("ByteCount tag is only allowed on integer types.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + } + if field.Flags&ExcludeTerminator != 0 && field.Type != String { + panic(fmt.Errorf("ExcludeTerminator tag is only allowed on String fields.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + // Verify uniqueness of names (case-sensitive for maximum flexibility) + if len(field.Name) == 0 { + panic(fmt.Errorf("All fields must have a name.\nMessageFormat:\n\n%s", message.String())) + } + if _, ok := allFieldNames[field.Name]; ok { + panic(fmt.Errorf("Multiple fields with the same name.\nMessageFormat:\n\n%s", message.String())) + } + allFieldNames[field.Name] = struct{}{} + // Verify that ByteN is the last field, or is preceded by a field with the ByteCount tag + usesByteCount := false + if field.Type == ByteN { + // If the preceding field has the ByteCount tag, then ByteN does not have the ending-field-only restriction + if ftStack.Peek().Index > 0 && (ftStack.Peek().Fields[ftStack.Peek().Index-1].Flags&ByteCount != 0) { + usesByteCount = true + } else { + endingByteNFound = true + } + } + // Verify the type for each default value + switch field.Type { + case Byte1, Int8, Int16, Int32, Repeated: + if _, ok := field.Data.(int32); !ok { + panic(fmt.Errorf("Integer field types must set their Data to an int32 value.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + case ByteN: + if _, ok := field.Data.([]byte); !ok { + panic(fmt.Errorf("ByteN fields must set their Data to a []byte value.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + case String: + if _, ok := field.Data.(string); !ok { + panic(fmt.Errorf("String fields must set their Data to a string value.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + default: + panic("message type has not been defined") + } + // Verify that, for fields with children, the default count matches the default child count + if len(field.Children) > 0 { + count := int32(0) + switch field.Type { + case Byte1, Int8, Int16, Int32, Repeated: + count = field.Data.(int32) + default: + panic(fmt.Errorf("Only integer types may have children, as they determine the count.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + // A value of zero means that the child is only used as a prototype. A value of one means that the child is + // actually used as a default value. We do not allow declaring children with multiple default values. + if count != 0 && count != 1 { + panic(fmt.Errorf("Only integer types may have children, as they determine the count.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + if len(field.Children) > 1 { + panic(fmt.Errorf("Only a single child may be declared in the default message.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + } + // Repeated may only be on a single field. Children of a Repeated field cannot also have Repeated children. + if field.Type == Repeated { + if repeatedFoundHeight != 0 { + panic(fmt.Errorf("Multiple Repeated types declared.\nField: %s\nMessageFormat:\n\n%s", field.Name, message.String())) + } + repeatedFoundHeight = ftStack.Len() + } + // RepeatedTerminator is only allowed on Repeated types, and therefore follows all of its restrictions automatically. + if field.Flags&RepeatedTerminator != 0 { + if field.Type != Repeated { + panic(fmt.Errorf("RepeatedTerminator may only be used on a Repeated type.\nMessageFormat:\n\n%s", message.String())) + } + message.info.appendNullByte = true + } + + // Write the field info into our message + parentName := "" + if ftStack.Len() > 1 { + parentName = ftStack.PeekDepth(1).Fields[ftStack.PeekDepth(1).Index-1].Name + } + message.info.fieldInfo[field.Name] = messageFieldInfo{ + RelativeIndex: ftStack.Peek().Index, + Parent: parentName, + UsesByteCount: usesByteCount, + } + + // Increment the index + ftStack.PeekReference().Index++ + // If there are any children, then we throw them onto the stack + if len(field.Children) == 1 { + ftStack.Push(FieldTraversal{0, field.Children[0]}) + } + } +} diff --git a/postgres/messages/message_writer.go b/postgres/messages/message_writer.go new file mode 100644 index 0000000000..1a9e76b810 --- /dev/null +++ b/postgres/messages/message_writer.go @@ -0,0 +1,167 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +import "fmt" + +// MessageWriter is used to easily (and safely) interact with the contents of a MessageFormat. +type MessageWriter struct { + message MessageFormat + fieldQueue []messageWriterChildPosition +} + +// messageWriterChildPosition contains the name and position of a queued entry in a MessageWriter. +type messageWriterChildPosition struct { + name string + position int +} + +// Child returns a new MessageWriter pointing to the child of the field chain provided so far. As there may be multiple +// children, the position determines which child will be referenced. If a child position is given that does not exist, +// then a child at that position will be created. If the position is more than an increment, then this will also create +// all children up to the position, by giving them their default values. +func (mw MessageWriter) Child(name string, position int) MessageWriter { + fieldQueue := make([]messageWriterChildPosition, len(mw.fieldQueue)+1) + copy(fieldQueue, mw.fieldQueue) + fieldQueue[len(fieldQueue)-1] = messageWriterChildPosition{name, position} + return MessageWriter{ + message: mw.message, + fieldQueue: fieldQueue, + } +} + +// Write writes the given value to the field pointed to by the field chain provided. Only accepts values with the +// following types: int/8/16/32/64, uint/8/16/32/64, string, []byte (use an empty slice instead of nil). +func (mw MessageWriter) Write(value any) error { + if mw.message.isDefault { + return fmt.Errorf("Cannot write to the default message: %s", mw.message.Name) + } + + var field *Field + var defaultField *Field + if fieldInfo, ok := mw.message.info.fieldInfo[mw.fieldQueue[0].name]; ok { + field = mw.message.Fields[fieldInfo.RelativeIndex] + defaultField = mw.message.info.defaultMessage.Fields[fieldInfo.RelativeIndex] + } else { + return fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, mw.fieldQueue[0].name) + } + fq := mw.fieldQueue[1:] + for len(fq) > 0 { + fieldInfo, ok := mw.message.info.fieldInfo[fq[0].name] + if !ok { + return fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, fq[0].name) + } + if fieldInfo.Parent != field.Name { + return fmt.Errorf(`In the message "%s", the field "%s"" is not a child of the field "%s"`, + mw.message.Name, fq[0].name, field.Name) + } + field.extend(fq[0].position+1, defaultField.Children[0]) // extend() takes the length, so add 1 to the position + field.Data = int32(len(field.Children)) // All types that have children are integer types + field = field.Children[fq[0].position][fieldInfo.RelativeIndex] + defaultField = defaultField.Children[0][fieldInfo.RelativeIndex] + // Remove the child from the queue + fq = fq[1:] + } + + switch field.Type { + case Byte1, Int8, Int16, Int32, Repeated: + switch value := value.(type) { + case int: + field.Data = int32(value) + case int8: + field.Data = int32(value) + case int16: + field.Data = int32(value) + case int32: + field.Data = value + case int64: + field.Data = int32(value) + case uint: + field.Data = int32(value) + case uint8: + field.Data = int32(value) + case uint16: + field.Data = int32(value) + case uint32: + field.Data = int32(value) + case uint64: + field.Data = int32(value) + default: + return fmt.Errorf("Attempted to write an invalid value of type `%T` into the following integer field: %s", value, field.Name) + } + case ByteN: + switch value := value.(type) { + case []byte: + field.Data = value + default: + return fmt.Errorf("Attempted to write an invalid value of type `%T` into the following ByteN field: %s", value, field.Name) + } + case String: + switch value := value.(type) { + case string: + field.Data = value + default: + return fmt.Errorf("Attempted to write an invalid value of type `%T` into the following String field: %s", value, field.Name) + } + default: + panic("message type has not been defined") + } + return nil +} + +// Get returns the value of the field pointed to by the field chain provided. +func (mw MessageWriter) Get() (any, error) { + var field *Field + if fieldInfo, ok := mw.message.info.fieldInfo[mw.fieldQueue[0].name]; ok { + field = mw.message.Fields[fieldInfo.RelativeIndex] + } else { + return nil, fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, mw.fieldQueue[0].name) + } + fq := mw.fieldQueue[1:] + for len(fq) > 0 { + fieldInfo, ok := mw.message.info.fieldInfo[fq[0].name] + if !ok { + return nil, fmt.Errorf(`The message "%s" does not contain a field named "%s"`, mw.message.Name, fq[0].name) + } + if fieldInfo.Parent != field.Name { + return nil, fmt.Errorf(`In the message "%s", the field "%s" is not a child of the field "%s"`, + mw.message.Name, fq[0].name, field.Name) + } + if fq[0].position >= len(field.Children) { + return nil, fmt.Errorf("Index out of bounds.\nMessage: %s\nField: %s\nIndex: %d\nLength: %d", + mw.message.Name, field.Name, fq[0].position, len(field.Children)) + } + field = field.Children[fq[0].position][fieldInfo.RelativeIndex] + // Remove the child from the queue + fq = fq[1:] + } + return field.Data, nil +} + +// MustWrite is the same as Write, except that this panics on errors rather than returning them. +func (mw MessageWriter) MustWrite(value any) { + if err := mw.Write(value); err != nil { + panic(err) + } +} + +// MustGet is the same as Get, except that this panics on errors rather than returning them. +func (mw MessageWriter) MustGet() any { + value, err := mw.Get() + if err != nil { + panic(err) + } + return value +} diff --git a/postgres/messages/negotiate_protocol_version.go b/postgres/messages/negotiate_protocol_version.go new file mode 100644 index 0000000000..3257badcf4 --- /dev/null +++ b/postgres/messages/negotiate_protocol_version.go @@ -0,0 +1,95 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(NegotiateProtocolVersion{}) +} + +// NegotiateProtocolVersion represents a PostgreSQL message. +type NegotiateProtocolVersion struct { + NewestMinorProtocol int32 + UnrecognizedOptions []string +} + +var negotiateProtocolVersionDefault = MessageFormat{ + Name: "NegotiateProtocolVersion", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('v'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "NewestMinorProtocol", + Type: Int32, + Data: int32(0), + }, + { + Name: "UnrecognizedOptions", + Type: Int32, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "UnrecognizedOption", + Type: String, + Data: "", + }, + }, + }, + }, + }, +} + +var _ Message = NegotiateProtocolVersion{} + +// encode implements the interface Message. +func (m NegotiateProtocolVersion) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("NewestMinorProtocol").MustWrite(m.NewestMinorProtocol) + for i, option := range m.UnrecognizedOptions { + outputMessage.Field("UnrecognizedOptions").Child("UnrecognizedOption", i).MustWrite(option) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m NegotiateProtocolVersion) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + count := int(s.Field("UnrecognizedOptions").MustGet().(int32)) + unrecognizedOptions := make([]string, count) + for i := 0; i < count; i++ { + unrecognizedOptions[i] = s.Field("UnrecognizedOptions").Child("UnrecognizedOption", i).MustGet().(string) + } + return NegotiateProtocolVersion{ + NewestMinorProtocol: s.Field("NewestMinorProtocol").MustGet().(int32), + UnrecognizedOptions: unrecognizedOptions, + }, nil +} + +// defaultMessage implements the interface Message. +func (m NegotiateProtocolVersion) defaultMessage() *MessageFormat { + return &negotiateProtocolVersionDefault +} diff --git a/postgres/messages/no_data.go b/postgres/messages/no_data.go new file mode 100644 index 0000000000..aab20f4a64 --- /dev/null +++ b/postgres/messages/no_data.go @@ -0,0 +1,60 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(NoData{}) +} + +// NoData represents a PostgreSQL message. +type NoData struct{} + +var noDataDefault = MessageFormat{ + Name: "NoData", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('n'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = NoData{} + +// encode implements the interface Message. +func (m NoData) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m NoData) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return NoData{}, nil +} + +// defaultMessage implements the interface Message. +func (m NoData) defaultMessage() *MessageFormat { + return &noDataDefault +} diff --git a/postgres/messages/notice_response.go b/postgres/messages/notice_response.go new file mode 100644 index 0000000000..471d725820 --- /dev/null +++ b/postgres/messages/notice_response.go @@ -0,0 +1,103 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(NoticeResponse{}) +} + +// NoticeResponse represents a PostgreSQL message. +type NoticeResponse struct { + Fields []NoticeResponseField +} + +// NoticeResponseField are the fields to an NoticeResponse message. +type NoticeResponseField struct { + Code int32 + Value string +} + +var noticeResponseDefault = MessageFormat{ + Name: "NoticeResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('N'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Fields", + Type: Repeated, + Flags: RepeatedTerminator, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "Code", + Type: Byte1, + Data: int32(0), + }, + { + Name: "Value", + Type: String, + Data: "", + }, + }, + }, + }, + }, +} + +var _ Message = NoticeResponse{} + +// encode implements the interface Message. +func (m NoticeResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + for i, field := range m.Fields { + outputMessage.Field("Fields").Child("Code", i).MustWrite(field.Code) + outputMessage.Field("Fields").Child("Value", i).MustWrite(field.Value) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m NoticeResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + count := int(s.Field("Fields").MustGet().(int32)) + fields := make([]NoticeResponseField, count) + for i := 0; i < count; i++ { + fields[i] = NoticeResponseField{ + Code: s.Field("Fields").Child("Code", i).MustGet().(int32), + Value: s.Field("Fields").Child("Value", i).MustGet().(string), + } + } + return NoticeResponse{ + Fields: fields, + }, nil +} + +// defaultMessage implements the interface Message. +func (m NoticeResponse) defaultMessage() *MessageFormat { + return ¬iceResponseDefault +} diff --git a/postgres/messages/notification_response.go b/postgres/messages/notification_response.go new file mode 100644 index 0000000000..0299e2b92e --- /dev/null +++ b/postgres/messages/notification_response.go @@ -0,0 +1,87 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(NotificationResponse{}) +} + +// NotificationResponse represents a PostgreSQL message. +type NotificationResponse struct { + ProcessID int32 + Channel string + Payload string +} + +var notificationResponseDefault = MessageFormat{ + Name: "NotificationResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('A'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "ProcessID", + Type: Int32, + Data: int32(0), + }, + { + Name: "Channel", + Type: String, + Data: "", + }, + { + Name: "Payload", + Type: String, + Data: "", + }, + }, +} + +var _ Message = NotificationResponse{} + +// encode implements the interface Message. +func (m NotificationResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("ProcessID").MustWrite(m.ProcessID) + outputMessage.Field("Channel").MustWrite(m.Channel) + outputMessage.Field("Payload").MustWrite(m.Payload) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m NotificationResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return NotificationResponse{ + ProcessID: s.Field("ProcessID").MustGet().(int32), + Channel: s.Field("Channel").MustGet().(string), + Payload: s.Field("Payload").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m NotificationResponse) defaultMessage() *MessageFormat { + return ¬ificationResponseDefault +} diff --git a/postgres/messages/parameter_description.go b/postgres/messages/parameter_description.go new file mode 100644 index 0000000000..5fef3b95f0 --- /dev/null +++ b/postgres/messages/parameter_description.go @@ -0,0 +1,87 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(ParameterDescription{}) +} + +// ParameterDescription represents a PostgreSQL message. +type ParameterDescription struct { + ObjectIDs []int32 +} + +var parameterDescriptionDefault = MessageFormat{ + Name: "ParameterDescription", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('t'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Parameters", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ObjectID", + Type: Int32, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = ParameterDescription{} + +// encode implements the interface Message. +func (m ParameterDescription) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + for i, objectID := range m.ObjectIDs { + outputMessage.Field("Parameters").Child("ObjectID", i).MustWrite(objectID) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m ParameterDescription) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + count := int(s.Field("Parameters").MustGet().(int32)) + objectIDs := make([]int32, count) + for i := 0; i < count; i++ { + objectIDs[i] = s.Field("Parameters").Child("ObjectID", i).MustGet().(int32) + } + return ParameterDescription{ + ObjectIDs: objectIDs, + }, nil +} + +// defaultMessage implements the interface Message. +func (m ParameterDescription) defaultMessage() *MessageFormat { + return ¶meterDescriptionDefault +} diff --git a/postgres/messages/parameter_status.go b/postgres/messages/parameter_status.go index 03f076de48..0c4c961272 100644 --- a/postgres/messages/parameter_status.go +++ b/postgres/messages/parameter_status.go @@ -1,6 +1,22 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages -import "bytes" +func init() { + initializeDefaultMessage(ParameterStatus{}) +} // ParameterStatus reports various parameters to the client. type ParameterStatus struct { @@ -8,14 +24,56 @@ type ParameterStatus struct { Value string } -// Bytes returns ParameterStatus as a byte slice, ready to be returned to the client. -func (ps ParameterStatus) Bytes() []byte { - buf := bytes.Buffer{} - buf.WriteByte('S') // Message Type - WriteNumber(&buf, int32(0)) // Message length, will be corrected later - buf.WriteString(ps.Name) - buf.WriteByte(0) // Trailing NULL character, denoting the end of the string - buf.WriteString(ps.Value) - buf.WriteByte(0) // Trailing NULL character, denoting the end of the string - return WriteLength(buf.Bytes()) +var parameterStatusDefault = MessageFormat{ + Name: "ParameterStatus", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('S'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Name", + Type: String, + Data: "", + }, + { + Name: "Value", + Type: String, + Data: "", + }, + }, +} + +var _ Message = ParameterStatus{} + +// encode implements the interface Message. +func (m ParameterStatus) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Name").MustWrite(m.Name) + outputMessage.Field("Value").MustWrite(m.Value) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m ParameterStatus) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return ParameterStatus{ + Name: s.Field("Name").MustGet().(string), + Value: s.Field("Value").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m ParameterStatus) defaultMessage() *MessageFormat { + return ¶meterStatusDefault } diff --git a/postgres/messages/parse.go b/postgres/messages/parse.go new file mode 100644 index 0000000000..a18696a58d --- /dev/null +++ b/postgres/messages/parse.go @@ -0,0 +1,104 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(Parse{}) + addMessageHeader(Parse{}) +} + +// Parse represents a PostgreSQL message. +type Parse struct { + PreparedStatement string + Query string + ParameterObjectIDs []int32 +} + +var parseDefault = MessageFormat{ + Name: "Parse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('P'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "PreparedStatement", + Type: String, + Data: "", + }, + { + Name: "Query", + Type: String, + Data: "", + }, + { + Name: "Parameters", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ObjectID", + Type: Int32, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = Parse{} + +// encode implements the interface Message. +func (m Parse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("PreparedStatement").MustWrite(m.PreparedStatement) + outputMessage.Field("Query").MustWrite(m.Query) + for i, objectID := range m.ParameterObjectIDs { + outputMessage.Field("Parameters").Child("ObjectID", i).MustWrite(objectID) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m Parse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + count := int(s.Field("Parameters").MustGet().(int32)) + objectIDs := make([]int32, count) + for i := 0; i < count; i++ { + objectIDs[i] = s.Field("Parameters").Child("ObjectID", i).MustGet().(int32) + } + return Parse{ + PreparedStatement: s.Field("PreparedStatement").MustGet().(string), + Query: s.Field("Query").MustGet().(string), + ParameterObjectIDs: objectIDs, + }, nil +} + +// defaultMessage implements the interface Message. +func (m Parse) defaultMessage() *MessageFormat { + return &parseDefault +} diff --git a/postgres/messages/parse_complete.go b/postgres/messages/parse_complete.go new file mode 100644 index 0000000000..180d20de29 --- /dev/null +++ b/postgres/messages/parse_complete.go @@ -0,0 +1,60 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(ParseComplete{}) +} + +// ParseComplete represents a PostgreSQL message. +type ParseComplete struct{} + +var parseCompleteDefault = MessageFormat{ + Name: "ParseComplete", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('1'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = ParseComplete{} + +// encode implements the interface Message. +func (m ParseComplete) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m ParseComplete) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return ParseComplete{}, nil +} + +// defaultMessage implements the interface Message. +func (m ParseComplete) defaultMessage() *MessageFormat { + return &parseCompleteDefault +} diff --git a/postgres/messages/password_message.go b/postgres/messages/password_message.go new file mode 100644 index 0000000000..ec0d945915 --- /dev/null +++ b/postgres/messages/password_message.go @@ -0,0 +1,71 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(PasswordMessage{}) +} + +// PasswordMessage represents a PostgreSQL message. +type PasswordMessage struct { + Password string +} + +var passwordMessageDefault = MessageFormat{ + Name: "PasswordMessage", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('p'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Password", + Type: String, + Data: "", + }, + }, +} + +var _ Message = PasswordMessage{} + +// encode implements the interface Message. +func (m PasswordMessage) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Password").MustWrite(m.Password) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m PasswordMessage) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return PasswordMessage{ + Password: s.Field("Password").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m PasswordMessage) defaultMessage() *MessageFormat { + return &passwordMessageDefault +} diff --git a/postgres/messages/portal_suspended.go b/postgres/messages/portal_suspended.go new file mode 100644 index 0000000000..55b6501b1e --- /dev/null +++ b/postgres/messages/portal_suspended.go @@ -0,0 +1,63 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(PortalSuspended{}) +} + +// PortalSuspended represents a PostgreSQL message. +type PortalSuspended struct { + Integer int32 + String string +} + +var portalSuspendedDefault = MessageFormat{ + Name: "PortalSuspended", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('s'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = PortalSuspended{} + +// encode implements the interface Message. +func (m PortalSuspended) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m PortalSuspended) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return PortalSuspended{}, nil +} + +// defaultMessage implements the interface Message. +func (m PortalSuspended) defaultMessage() *MessageFormat { + return &portalSuspendedDefault +} diff --git a/postgres/messages/query.go b/postgres/messages/query.go index 6d3238e31f..847b411fb9 100644 --- a/postgres/messages/query.go +++ b/postgres/messages/query.go @@ -1,24 +1,72 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages -import ( - "encoding/binary" -) +func init() { + initializeDefaultMessage(Query{}) + addMessageHeader(Query{}) +} -// ReadQuery returns the query from the given buffer. Assumes that the buffer contains a serialized form of a Query -// message. -func ReadQuery(buf []byte) (string, bool) { - if len(buf) < 5 { - return "", false - } - if buf[0] != 'Q' { - return "", false - } - queryLength := int32(binary.BigEndian.Uint32(buf[1:])) - if queryLength <= 5 { - // A query of length 5 or less is empty - return "", true +// Query contains a query given by the client. +type Query struct { + String string +} + +var queryDefault = MessageFormat{ + Name: "Query", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('Q'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "String", + Type: String, + Data: "", + }, + }, +} + +var _ Message = Query{} + +// encode implements the interface Message. +func (m Query) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("String").MustWrite(m.String) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m Query) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } - // The length includes the length bytes, along with the NULL terminator. It does not include the message identifier - // though, so it cancels out the NULL terminator and allows us to use the length as-is. - return string(buf[5:queryLength]), true + return Query{ + String: s.Field("String").MustGet().(string), + }, nil +} + +// defaultMessage implements the interface Message. +func (m Query) defaultMessage() *MessageFormat { + return &queryDefault } diff --git a/postgres/messages/ready_for_query.go b/postgres/messages/ready_for_query.go index 5fa9db449f..96511f1770 100644 --- a/postgres/messages/ready_for_query.go +++ b/postgres/messages/ready_for_query.go @@ -1,12 +1,30 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages +func init() { + initializeDefaultMessage(ReadyForQuery{}) +} + // ReadyForQueryTransactionIndicator indicates the state of the transaction related to the query. type ReadyForQueryTransactionIndicator byte const ( - ReadyForQueryTransactionIndicator_Idle ReadyForQueryTransactionIndicator = iota - ReadyForQueryTransactionIndicator_TransactionBlock - ReadyForQueryTransactionIndicator_FailedTransactionBlock + ReadyForQueryTransactionIndicator_Idle ReadyForQueryTransactionIndicator = 'I' + ReadyForQueryTransactionIndicator_TransactionBlock ReadyForQueryTransactionIndicator = 'T' + ReadyForQueryTransactionIndicator_FailedTransactionBlock ReadyForQueryTransactionIndicator = 'E' ) // ReadyForQuery tells the client that the server is ready for a new query cycle. @@ -14,16 +32,49 @@ type ReadyForQuery struct { Indicator ReadyForQueryTransactionIndicator } -// Bytes returns ReadyForQuery as a byte slice, ready to be returned to the client. -func (rfq ReadyForQuery) Bytes() []byte { - var indicator byte - switch rfq.Indicator { - case ReadyForQueryTransactionIndicator_Idle: - indicator = 'I' - case ReadyForQueryTransactionIndicator_TransactionBlock: - indicator = 'T' - case ReadyForQueryTransactionIndicator_FailedTransactionBlock: - indicator = 'E' +var readyForQueryDefault = MessageFormat{ + Name: "ReadyForQuery", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('Z'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(5), + }, + { + Name: "TransactionIndicator", + Type: Byte1, + Data: int32(0), + }, + }, +} + +var _ Message = ReadyForQuery{} + +// encode implements the interface Message. +func (m ReadyForQuery) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("TransactionIndicator").MustWrite(byte(m.Indicator)) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m ReadyForQuery) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } - return []byte{'Z', 0, 0, 0, 5, indicator} + return ReadyForQuery{ + Indicator: ReadyForQueryTransactionIndicator(s.Field("TransactionIndicator").MustGet().(int32)), + }, nil +} + +// defaultMessage implements the interface Message. +func (m ReadyForQuery) defaultMessage() *MessageFormat { + return &readyForQueryDefault } diff --git a/postgres/messages/row_description.go b/postgres/messages/row_description.go index c1545d84b8..6b5473ead7 100644 --- a/postgres/messages/row_description.go +++ b/postgres/messages/row_description.go @@ -1,81 +1,147 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages import ( - "bytes" "fmt" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/vt/proto/query" ) +func init() { + initializeDefaultMessage(RowDescription{}) +} + // RowDescription represents a RowDescription message intended for the client. type RowDescription struct { - Fields []RowDescriptionField + Fields []*query.Field } -// NewRowDescription creates a new RowDescription from the given fields. -func NewRowDescription(fields []*query.Field) (RowDescription, error) { - var err error - rdFields := make([]RowDescriptionField, len(fields)) - for i, field := range fields { - rdFields[i] = RowDescriptionField{ - TableObjectID: 0, // Unused for now - ColumnAttributeNumber: 0, // Unused for now - DataTypeModifier: 0, // Always -1 since we're supporting a narrow set of integers - FormatCode: 0, // Always text for now +var rowDescriptionDefault = MessageFormat{ + Name: "RowDescription", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('T'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Fields", + Type: Int16, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ColumnName", + Type: String, + Data: "", + }, + { + Name: "TableObjectID", + Type: Int32, + Data: int32(0), + }, + { + Name: "ColumnAttributeNumber", + Type: Int16, + Data: int32(0), + }, + { + Name: "DataTypeObjectID", + Type: Int32, + Data: int32(0), + }, + { + Name: "DataTypeSize", + Type: Int16, + Data: int32(0), + }, + { + Name: "DataTypeModifier", + Type: Int32, + Data: int32(0), + }, + { + Name: "FormatCode", + Type: Int16, + Data: int32(0), + }, + }, + }, + }, + }, +} + +var _ Message = RowDescription{} + +// encode implements the interface Message. +func (m RowDescription) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + for i := 0; i < len(m.Fields); i++ { + field := m.Fields[i] + dataTypeObjectID, err := VitessFieldToDataTypeObjectID(field) + if err != nil { + return MessageFormat{}, err } - rdFields[i].Name = field.Name - rdFields[i].DataTypeObjectID, err = VitessTypeToDataTypeObjectID(field.Type) + dataTypeSize, err := VitessFieldToDataTypeSize(field) if err != nil { - return RowDescription{}, err + return MessageFormat{}, err } - rdFields[i].DataTypeSize, err = VitessTypeToDataTypeSize(field.Type) + dataTypeModifier, err := VitessFieldToDataTypeModifier(field) if err != nil { - return RowDescription{}, err + return MessageFormat{}, err } + outputMessage.Field("Fields").Child("ColumnName", i).MustWrite(field.Name) + outputMessage.Field("Fields").Child("DataTypeObjectID", i).MustWrite(dataTypeObjectID) + outputMessage.Field("Fields").Child("DataTypeSize", i).MustWrite(dataTypeSize) + outputMessage.Field("Fields").Child("DataTypeModifier", i).MustWrite(dataTypeModifier) } - return RowDescription{ - Fields: rdFields, - }, nil + return outputMessage, nil } -// RowDescriptionField represents a field in RowDescription. -type RowDescriptionField struct { - Name string - TableObjectID int32 - ColumnAttributeNumber int16 - DataTypeObjectID int32 - DataTypeSize int16 - DataTypeModifier int32 - FormatCode int16 -} - -// Bytes returns RowDescription as a byte slice, ready to be returned to the client. -func (rd RowDescription) Bytes() []byte { - buf := bytes.Buffer{} - buf.WriteByte('T') // Message Type - WriteNumber(&buf, int32(0)) // Message length, will be corrected later - WriteNumber(&buf, int16(len(rd.Fields))) - for _, rdf := range rd.Fields { - rdf.Bytes(&buf) +// decode implements the interface Message. +func (m RowDescription) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + fieldCount := int(s.Field("Fields").MustGet().(int32)) + for i := 0; i < fieldCount; i++ { + //TODO: decode the message in here } - return WriteLength(buf.Bytes()) + return RowDescription{ + Fields: nil, + }, nil } -// Bytes writes the field into the given buffer. -func (rdf RowDescriptionField) Bytes(buf *bytes.Buffer) { - buf.WriteString(rdf.Name) - buf.WriteByte(0) // Trailing NULL character, denoting the end of the string - WriteNumber(buf, rdf.TableObjectID) - WriteNumber(buf, rdf.ColumnAttributeNumber) - WriteNumber(buf, rdf.DataTypeObjectID) - WriteNumber(buf, rdf.DataTypeSize) - WriteNumber(buf, rdf.DataTypeModifier) - WriteNumber(buf, rdf.FormatCode) +// defaultMessage implements the interface Message. +func (m RowDescription) defaultMessage() *MessageFormat { + return &rowDescriptionDefault } -// VitessTypeToDataTypeObjectID returns a type, as defined by Vitess, into a type as defined by Postgres. -func VitessTypeToDataTypeObjectID(typ query.Type) (int32, error) { - switch typ { +// VitessFieldToDataTypeObjectID returns a type, as defined by Vitess, into a type as defined by Postgres. +func VitessFieldToDataTypeObjectID(field *query.Field) (int32, error) { + switch field.Type { case query.Type_INT8: return 17, nil case query.Type_INT16: @@ -86,14 +152,20 @@ func VitessTypeToDataTypeObjectID(typ query.Type) (int32, error) { return 23, nil case query.Type_INT64: return 20, nil + case query.Type_CHAR: + return 18, nil + case query.Type_VARCHAR: + return 1043, nil + case query.Type_TEXT: + return 25, nil default: return 0, fmt.Errorf("unsupported type returned from engine") } } -// VitessTypeToDataTypeSize returns the type's size, as defined by Vitess, into the size as defined by Postgres. -func VitessTypeToDataTypeSize(typ query.Type) (int16, error) { - switch typ { +// VitessFieldToDataTypeSize returns the type's size, as defined by Vitess, into the size as defined by Postgres. +func VitessFieldToDataTypeSize(field *query.Field) (int16, error) { + switch field.Type { case query.Type_INT8: return 1, nil case query.Type_INT16: @@ -104,6 +176,38 @@ func VitessTypeToDataTypeSize(typ query.Type) (int16, error) { return 4, nil case query.Type_INT64: return 8, nil + case query.Type_CHAR: + return -1, nil + case query.Type_VARCHAR: + return -1, nil + case query.Type_TEXT: + return -1, nil + default: + return 0, fmt.Errorf("unsupported type returned from engine") + } +} + +// VitessFieldToDataTypeModifier returns the field's data type modifier as defined by Postgres. +func VitessFieldToDataTypeModifier(field *query.Field) (int32, error) { + switch field.Type { + case query.Type_INT8: + return -1, nil + case query.Type_INT16: + return -1, nil + case query.Type_INT24: + return -1, nil + case query.Type_INT32: + return -1, nil + case query.Type_INT64: + return -1, nil + case query.Type_CHAR: + // PostgreSQL adds 4 to the length for an unknown reason + return int32(int64(field.ColumnLength)/sql.CharacterSetID(field.Charset).MaxLength()) + 4, nil + case query.Type_VARCHAR: + // PostgreSQL adds 4 to the length for an unknown reason + return int32(int64(field.ColumnLength)/sql.CharacterSetID(field.Charset).MaxLength()) + 4, nil + case query.Type_TEXT: + return -1, nil default: return 0, fmt.Errorf("unsupported type returned from engine") } diff --git a/postgres/messages/sasl_initial_response.go b/postgres/messages/sasl_initial_response.go new file mode 100644 index 0000000000..f4b7a98c0d --- /dev/null +++ b/postgres/messages/sasl_initial_response.go @@ -0,0 +1,92 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(SASLInitialResponse{}) +} + +// SASLInitialResponse represents a PostgreSQL message. +type SASLInitialResponse struct { + Name string + Response []byte +} + +var sASLInitialResponseDefault = MessageFormat{ + Name: "SASLInitialResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('p'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Name", + Type: String, + Data: "", + }, + { + Name: "ResponseLength", + Type: Int32, + Flags: ByteCount, + Data: int32(-1), + }, + { + Name: "ResponseData", + Type: String, + Data: "", + }, + }, +} + +var _ Message = SASLInitialResponse{} + +// encode implements the interface Message. +func (m SASLInitialResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Name").MustWrite(m.Name) + if len(m.Response) > 0 { + outputMessage.Field("ResponseLength").MustWrite(len(m.Response)) + outputMessage.Field("ResponseData").MustWrite(m.Response) + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m SASLInitialResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + var responseData []byte + if s.Field("ResponseLength").MustGet().(int32) > 0 { + responseData = s.Field("ResponseData").MustGet().([]byte) + } + return SASLInitialResponse{ + Name: s.Field("Name").MustGet().(string), + Response: responseData, + }, nil +} + +// defaultMessage implements the interface Message. +func (m SASLInitialResponse) defaultMessage() *MessageFormat { + return &sASLInitialResponseDefault +} diff --git a/postgres/messages/sasl_response.go b/postgres/messages/sasl_response.go new file mode 100644 index 0000000000..b40d5ec019 --- /dev/null +++ b/postgres/messages/sasl_response.go @@ -0,0 +1,71 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(SASLResponse{}) +} + +// SASLResponse represents a PostgreSQL message. +type SASLResponse struct { + Data []byte +} + +var sASLResponseDefault = MessageFormat{ + Name: "SASLResponse", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('p'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { + Name: "Data", + Type: ByteN, + Data: []byte{}, + }, + }, +} + +var _ Message = SASLResponse{} + +// encode implements the interface Message. +func (m SASLResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("Data").MustWrite(m.Data) + return outputMessage, nil +} + +// decode implements the interface Message. +func (m SASLResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return SASLResponse{ + Data: s.Field("Data").MustGet().([]byte), + }, nil +} + +// defaultMessage implements the interface Message. +func (m SASLResponse) defaultMessage() *MessageFormat { + return &sASLResponseDefault +} diff --git a/postgres/messages/ssl_request.go b/postgres/messages/ssl_request.go new file mode 100644 index 0000000000..aa9398ea4e --- /dev/null +++ b/postgres/messages/ssl_request.go @@ -0,0 +1,59 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(SSLRequest{}) +} + +// SSLRequest represents a PostgreSQL message. +type SSLRequest struct{} + +var sslRequestDefault = MessageFormat{ + Name: "SSLRequest", + Fields: FieldGroup{ + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(8), + }, + { + Name: "RequestCode", + Type: Int32, + Data: int32(80877103), + }, + }, +} + +var _ Message = SSLRequest{} + +// encode implements the interface Message. +func (m SSLRequest) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m SSLRequest) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return SSLRequest{}, nil +} + +// defaultMessage implements the interface Message. +func (m SSLRequest) defaultMessage() *MessageFormat { + return &sslRequestDefault +} diff --git a/postgres/messages/ssl_response.go b/postgres/messages/ssl_response.go index 7ee6af9676..b09667cf52 100644 --- a/postgres/messages/ssl_response.go +++ b/postgres/messages/ssl_response.go @@ -1,15 +1,74 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages +import "fmt" + +func init() { + initializeDefaultMessage(SSLResponse{}) +} + // SSLResponse tells the client whether SSL is supported. type SSLResponse struct { SupportsSSL bool } -// Bytes returns SSLResponse as a byte slice, ready to be returned to the client. -func (sslr SSLResponse) Bytes() []byte { - if sslr.SupportsSSL { - return []byte{'Y'} +var sslResponseDefault = MessageFormat{ + Name: "SSLResponse", + Fields: FieldGroup{ + { + Name: "Supported", + Type: Byte1, + Data: int32(0), + }, + }, +} + +var _ Message = SSLResponse{} + +// encode implements the interface Message. +func (m SSLResponse) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + if m.SupportsSSL { + outputMessage.Field("Supported").MustWrite('Y') + } else { + outputMessage.Field("Supported").MustWrite('N') + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m SSLResponse) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + var supported bool + supportedInt := s.Field("Supported").MustGet().(int32) + if supportedInt == 'Y' { + supported = true + } else if supportedInt == 'N' { + supported = false } else { - return []byte{'N'} + return nil, fmt.Errorf("Unexpected supported value in SSLResponse message: %d", supportedInt) } + return SSLResponse{ + SupportsSSL: supported, + }, nil +} + +// defaultMessage implements the interface Message. +func (m SSLResponse) defaultMessage() *MessageFormat { + return &sslResponseDefault } diff --git a/postgres/messages/startup_message.go b/postgres/messages/startup_message.go index 57a05593e2..db675be839 100644 --- a/postgres/messages/startup_message.go +++ b/postgres/messages/startup_message.go @@ -1,52 +1,107 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages -import ( - "encoding/binary" - "fmt" -) +func init() { + initializeDefaultMessage(StartupMessage{}) +} // StartupMessage is returned by the client upon connecting to the server, providing details about the client. type StartupMessage struct { - ProtocolMajorVersion int16 - ProtocolMinorVersion int16 + ProtocolMajorVersion int + ProtocolMinorVersion int Parameters map[string]string } -// ReadStartupMessage returns the StartupMessage from the buffer. -func ReadStartupMessage(buf []byte) (StartupMessage, error) { - if len(buf) < 4 { - return StartupMessage{}, fmt.Errorf("invalid StartupMessage") +var startupMessageDefault = MessageFormat{ + Name: "StartupMessage", + Fields: FieldGroup{ + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + { // The docs specify a single Int32 field, but the upper and lower bits are different values, so this just splits them + Name: "ProtocolMajorVersion", + Type: Int16, + Data: int32(0), + }, + { + Name: "ProtocolMinorVersion", + Type: Int16, + Data: int32(0), + }, + { + Name: "Parameters", + Type: Repeated, + Flags: RepeatedTerminator, + Data: int32(0), + Children: []FieldGroup{ + { + { + Name: "ParameterName", + Type: String, + Data: "", + }, + { + Name: "ParameterValue", + Type: String, + Data: "", + }, + }, + }, + }, + }, +} + +var _ Message = StartupMessage{} + +// encode implements the interface Message. +func (m StartupMessage) encode() (MessageFormat, error) { + outputMessage := m.defaultMessage().Copy() + outputMessage.Field("ProtocolMajorVersion").MustWrite(m.ProtocolMajorVersion) + outputMessage.Field("ProtocolMinorVersion").MustWrite(m.ProtocolMinorVersion) + index := 0 + for name, value := range m.Parameters { + outputMessage.Field("Parameters").Child("ParameterName", index).MustWrite(name) + outputMessage.Field("Parameters").Child("ParameterValue", index).MustWrite(value) + index++ + } + return outputMessage, nil +} + +// decode implements the interface Message. +func (m StartupMessage) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } - messageLength := int32(binary.BigEndian.Uint32(buf)) - protocolMajorVersion := int16(binary.BigEndian.Uint16(buf[4:])) - protocolMinorVersion := int16(binary.BigEndian.Uint16(buf[6:])) - // Set the buffer to the stated length and skip the length and version - buf = buf[8:messageLength] parameters := make(map[string]string) - for len(buf) > 0 { - var name string - var value string - for i, b := range buf { - if b == 0 { - name = string(buf[:i]) - buf = buf[i+1:] - break - } - } - for i, b := range buf { - if b == 0 { - value = string(buf[:i]) - buf = buf[i+1:] - break - } - } - if len(name) > 0 && len(value) > 0 { - parameters[name] = value - } + count := int(s.Field("Parameters").MustGet().(int32)) + for i := 0; i < count; i++ { + parameters[s.Field("Parameters").Child("ParameterName", i).MustGet().(string)] = + s.Field("Parameters").Child("ParameterValue", i).MustGet().(string) } return StartupMessage{ - ProtocolMajorVersion: protocolMajorVersion, - ProtocolMinorVersion: protocolMinorVersion, + ProtocolMajorVersion: int(s.Field("ProtocolMajorVersion").MustGet().(int32)), + ProtocolMinorVersion: int(s.Field("ProtocolMinorVersion").MustGet().(int32)), Parameters: parameters, }, nil } + +// defaultMessage implements the interface Message. +func (m StartupMessage) defaultMessage() *MessageFormat { + return &startupMessageDefault +} diff --git a/postgres/messages/sync.go b/postgres/messages/sync.go new file mode 100644 index 0000000000..70dee95c52 --- /dev/null +++ b/postgres/messages/sync.go @@ -0,0 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package messages + +func init() { + initializeDefaultMessage(Sync{}) + addMessageHeader(Sync{}) +} + +// Sync represents a PostgreSQL message. +type Sync struct{} + +var syncDefault = MessageFormat{ + Name: "Sync", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('S'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(4), + }, + }, +} + +var _ Message = Sync{} + +// encode implements the interface Message. +func (m Sync) encode() (MessageFormat, error) { + return m.defaultMessage().Copy(), nil +} + +// decode implements the interface Message. +func (m Sync) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err + } + return Sync{}, nil +} + +// defaultMessage implements the interface Message. +func (m Sync) defaultMessage() *MessageFormat { + return &syncDefault +} diff --git a/postgres/messages/terminate.go b/postgres/messages/terminate.go index d55b11e44f..b0b6074f25 100644 --- a/postgres/messages/terminate.go +++ b/postgres/messages/terminate.go @@ -1,9 +1,61 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages -// ReadTerminate returns whether the buffer represents a Terminate message. -func ReadTerminate(buf []byte) bool { - if len(buf) < 5 { - return false +func init() { + initializeDefaultMessage(Terminate{}) + addMessageHeader(Terminate{}) +} + +// Terminate tells the server to close the connection. +type Terminate struct{} + +var terminateDefault = MessageFormat{ + Name: "Terminate", + Fields: FieldGroup{ + { + Name: "Header", + Type: Byte1, + Flags: Header, + Data: int32('X'), + }, + { + Name: "MessageLength", + Type: Int32, + Flags: MessageLengthInclusive, + Data: int32(0), + }, + }, +} + +var _ Message = Terminate{} + +// encode implements the interface Message. +func (m Terminate) encode() (MessageFormat, error) { + return terminateDefault.Copy(), nil +} + +// decode implements the interface Message. +func (m Terminate) decode(s MessageFormat) (Message, error) { + if err := s.MatchesStructure(*m.defaultMessage()); err != nil { + return nil, err } - return buf[0] == 'X' && buf[4] == 4 + return Terminate{}, nil +} + +// defaultMessage implements the interface Message. +func (m Terminate) defaultMessage() *MessageFormat { + return &terminateDefault } diff --git a/postgres/messages/utils.go b/postgres/messages/utils.go index 71372a5cfd..7f15acfedd 100644 --- a/postgres/messages/utils.go +++ b/postgres/messages/utils.go @@ -1,3 +1,17 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package messages import ( @@ -7,6 +21,8 @@ import ( "golang.org/x/exp/constraints" ) +//TODO: delete these Write functions + // WriteLength writes the length of the message into the byte slice. Modifies the given byte slice, while also // returning the same slice. Assumes that the first byte is the message identifier, while the next 4 bytes are // the length. @@ -22,3 +38,62 @@ func WriteLength(b []byte) []byte { func WriteNumber[T constraints.Integer | constraints.Float](buf *bytes.Buffer, num T) { _ = binary.Write(buf, binary.BigEndian, num) } + +// Stack is a generic stack. +type Stack[T any] struct { + values []T +} + +// NewStack creates a new, empty stack. +func NewStack[T any]() *Stack[T] { + return &Stack[T]{} +} + +// Len returns the size of the stack. +func (s *Stack[T]) Len() int { + return len(s.values) +} + +// Peek returns the top value on the stack without removing it. +func (s *Stack[T]) Peek() (value T) { + if len(s.values) == 0 { + return + } + return s.values[len(s.values)-1] +} + +// PeekDepth returns the n-th value from the top. PeekDepth(0) is equivalent to the standard Peek(). +func (s *Stack[T]) PeekDepth(depth int) (value T) { + if len(s.values) <= depth { + return + } + return s.values[len(s.values)-(1+depth)] +} + +// PeekReference returns a reference to the top value on the stack without removing it. +func (s *Stack[T]) PeekReference() *T { + if len(s.values) == 0 { + return nil + } + return &s.values[len(s.values)-1] +} + +// Pop returns the top value on the stack while also removing it from the stack. +func (s *Stack[T]) Pop() (value T) { + if len(s.values) == 0 { + return + } + value = s.values[len(s.values)-1] + s.values = s.values[:len(s.values)-1] + return +} + +// Push adds the given value to the stack. +func (s *Stack[T]) Push(value T) { + s.values = append(s.values, value) +} + +// Empty returns whether the stack is empty. +func (s *Stack[T]) Empty() bool { + return len(s.values) == 0 +} diff --git a/system_checks.go b/system_checks.go index 99fb1a7137..ce8fb87ff3 100644 --- a/system_checks.go +++ b/system_checks.go @@ -1,4 +1,4 @@ -// Copyright 2020 Dolthub, Inc. +// Copyright 2023 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.