diff --git a/examples/autobahn_client.nim b/examples/autobahn_client.nim index 8d6ff03787..0a417e4228 100644 --- a/examples/autobahn_client.nim +++ b/examples/autobahn_client.nim @@ -9,8 +9,10 @@ import std/[strutils], - pkg/[chronos, chronicles, stew/byteutils], - ../websock/[websock, types, frame, extensions/compression/deflate] + chronos, + chronicles, + stew/byteutils, + ../websock/[websock, types, extensions/compression/deflate] const clientFlags = {NoVerifyHost, NoVerifyServerName} diff --git a/examples/client.nim b/examples/client.nim index 308bb9a8cd..019e538c74 100644 --- a/examples/client.nim +++ b/examples/client.nim @@ -7,10 +7,10 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import pkg/[ +import chronos, chronicles, - stew/byteutils] + stew/byteutils import ../websock/websock diff --git a/examples/server.nim b/examples/server.nim index e7bc467ee7..9131322131 100644 --- a/examples/server.nim +++ b/examples/server.nim @@ -7,13 +7,14 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/uri -import pkg/[chronos, - chronicles, - httputils] +import + std/uri, + chronos, + chronicles, + ../websock/[websock, extensions/compression/deflate] -import ../websock/[websock, extensions/compression/deflate] -import ../tests/keys +when defined tls: + import ../tests/keys proc handle(request: HttpRequest) {.async.} = trace "Handling request:", uri = request.uri.path diff --git a/scripts/ws.sh b/scripts/ws.sh index f8b6686a10..eb29d52cd8 100644 --- a/scripts/ws.sh +++ b/scripts/ws.sh @@ -9,6 +9,7 @@ # prevent issue https://github.com/status-im/nimbus-eth1/issues/3661 set -e +trap "trap - SIGTERM && pkill -P $$" SIGINT SIGTERM EXIT # script arguments [[ $# -ne 1 ]] && { echo "Usage: $0 NIM_VERSION"; } @@ -17,22 +18,23 @@ NIM_VERSION="$1" cd "$(dirname "${BASH_SOURCE[0]}")"/.. REPO_DIR="${PWD}" +CFG="server" +REPORT_DIR="autobahn/reports/$CFG-$NIM_VERSION" +mkdir -p autobahn/reports/$CFG nim c -d:release examples/server examples/server & -server=$! -mkdir -p autobahn/reports - -docker run \ +docker run --rm \ -v ${REPO_DIR}/autobahn:/config \ -v ${REPO_DIR}/autobahn/reports:/reports \ --network=host \ --name fuzzingclient \ crossbario/autobahn-testsuite wstest --mode fuzzingclient --spec /config/fuzzingclient.json -kill $server +mv autobahn/reports/$CFG "$REPORT_DIR" -mv autobahn/reports/server autobahn/reports/server-${NIM_VERSION} +echo "* [Nim-${NIM_VERSION} $CFG summary report]($CFG-${NIM_VERSION}/index.html)" > "$REPORT_DIR.txt" -echo "* [Nim-${NIM_VERSION} ws server summary report](server-${NIM_VERSION}/index.html)" > "autobahn/reports/server-${NIM_VERSION}.txt" +# squash to single line and look for errors +(cat $REPORT_DIR/index.json | tr '\n' '!' | sed "s|\},\!|\n|g" | tr '!' ' ' | tr -s ' ' | grep -v -e '"behavior": "OK"' -e '"behavior": "NON-STRICT"' -e '"behavior": "INFORMATIONAL"' && exit 1) || true diff --git a/scripts/wsc.sh b/scripts/wsc.sh index 5b4da3f735..f201e93015 100644 --- a/scripts/wsc.sh +++ b/scripts/wsc.sh @@ -9,6 +9,7 @@ # prevent issue https://github.com/status-im/nimbus-eth1/issues/3661 set -e +trap "trap - SIGTERM && pkill -P $$" SIGINT SIGTERM EXIT # script arguments [[ $# -ne 1 ]] && { echo "Usage: $0 NIM_VERSION"; } @@ -17,21 +18,25 @@ NIM_VERSION="$1" cd "$(dirname "${BASH_SOURCE[0]}")"/.. REPO_DIR="${PWD}" +CFG="client" +REPORT_DIR="autobahn/reports/$CFG-$NIM_VERSION" +mkdir -p autobahn/reports/$CFG -mkdir -p autobahn/reports - -docker run -d \ +docker run -d --rm \ -v ${REPO_DIR}/autobahn:/config \ -v ${REPO_DIR}/autobahn/reports:/reports \ --network=host \ --name fuzzingserver \ crossbario/autobahn-testsuite wstest --webport=0 --mode fuzzingserver --spec /config/fuzzingserver.json +trap "docker kill fuzzingserver" SIGINT SIGTERM EXIT + nim c -d:release examples/autobahn_client examples/autobahn_client -docker kill fuzzingserver +mv autobahn/reports/$CFG $REPORT_DIR -mv autobahn/reports/client autobahn/reports/client-${NIM_VERSION} +echo "* [Nim-${NIM_VERSION} $CFG summary report]($CFG-${NIM_VERSION}/index.html)" > "$REPORT_DIR.txt" -echo "* [Nim-${NIM_VERSION} ws client summary report](client-${NIM_VERSION}/index.html)" > "autobahn/reports/client-${NIM_VERSION}.txt" +# squash to single line and look for errors +(cat $REPORT_DIR/index.json | tr '\n' '!' | sed "s|\},\!|\n|g" | tr '!' ' ' | tr -s ' ' | grep -v -e '"behavior": "OK"' -e '"behavior": "NON-STRICT"' -e '"behavior": "INFORMATIONAL"' && exit 1) || true diff --git a/scripts/wss.sh b/scripts/wss.sh index cd51fe69e0..f28d5ea105 100644 --- a/scripts/wss.sh +++ b/scripts/wss.sh @@ -8,8 +8,8 @@ # prevent issue https://github.com/status-im/nimbus-eth1/issues/3661 - set -e +trap "trap - SIGTERM && pkill -P $$" SIGINT SIGTERM EXIT # script arguments [[ $# -ne 1 ]] && { echo "Usage: $0 NIM_VERSION"; } @@ -18,22 +18,23 @@ NIM_VERSION="$1" cd "$(dirname "${BASH_SOURCE[0]}")"/.. REPO_DIR="${PWD}" +CFG="server_tls" +REPORT_DIR="autobahn/reports/$CFG-$NIM_VERSION" +mkdir -p autobahn/reports/$CFG nim c -d:tls -d:release -o:examples/tls_server examples/server.nim examples/tls_server & -server=$! - -mkdir -p autobahn/reports -docker run \ +docker run --rm \ -v ${REPO_DIR}/autobahn:/config \ -v ${REPO_DIR}/autobahn/reports:/reports \ --network=host \ --name fuzzingclient_tls \ crossbario/autobahn-testsuite wstest --mode fuzzingclient --spec /config/fuzzingclient_tls.json -kill $server +mv autobahn/reports/$CFG "$REPORT_DIR" -mv autobahn/reports/server_tls autobahn/reports/server_tls-${NIM_VERSION} +echo "* [Nim-${NIM_VERSION} $CFG summary report]($CFG-${NIM_VERSION}/index.html)" > "$REPORT_DIR.txt" -echo "* [Nim-${NIM_VERSION} wss server summary report](server_tls-${NIM_VERSION}/index.html)" > "autobahn/reports/server_tls-${NIM_VERSION}.txt" +# squash to single line and look for errors +(cat $REPORT_DIR/index.json | tr '\n' '!' | sed "s|\},\!|\n|g" | tr '!' ' ' | tr -s ' ' | grep -v -e '"behavior": "OK"' -e '"behavior": "NON-STRICT"' -e '"behavior": "INFORMATIONAL"' && exit 1) || true diff --git a/scripts/wssc.sh b/scripts/wssc.sh index 8174438825..8fcc7aa1e2 100644 --- a/scripts/wssc.sh +++ b/scripts/wssc.sh @@ -9,6 +9,7 @@ # prevent issue https://github.com/status-im/nimbus-eth1/issues/3661 set -e +trap "trap - SIGTERM && pkill -P $$" SIGINT SIGTERM EXIT # script arguments [[ $# -ne 1 ]] && { echo "Usage: $0 NIM_VERSION"; } @@ -17,21 +18,25 @@ NIM_VERSION="$1" cd "$(dirname "${BASH_SOURCE[0]}")"/.. REPO_DIR="${PWD}" +CFG="client_tls" +REPORT_DIR="autobahn/reports/$CFG-$NIM_VERSION" +mkdir -p autobahn/reports/$CFG -mkdir -p autobahn/reports - -docker run -d \ +docker run -d --rm \ -v ${REPO_DIR}/autobahn:/config \ -v ${REPO_DIR}/autobahn/reports:/reports \ --network=host \ --name fuzzingserver_tls \ crossbario/autobahn-testsuite wstest --webport=0 --mode fuzzingserver --spec /config/fuzzingserver_tls.json +trap "docker kill fuzzingserver_tls" SIGINT SIGTERM EXIT + nim c -d:tls -d:release -o:examples/autobahn_tlsclient examples/autobahn_client examples/autobahn_tlsclient -docker kill fuzzingserver_tls +mv autobahn/reports/$CFG $REPORT_DIR -mv autobahn/reports/client_tls autobahn/reports/client_tls-${NIM_VERSION} +echo "* [Nim-${NIM_VERSION} $CFG summary report]($CFG-${NIM_VERSION}/index.html)" > "$REPORT_DIR.txt" -echo "* [Nim-${NIM_VERSION} wss client summary report](client_tls-${NIM_VERSION}/index.html)" > "autobahn/reports/client_tls-${NIM_VERSION}.txt" +# squash to single line and look for errors +(cat $REPORT_DIR/index.json | tr '\n' '!' | sed "s|\},\!|\n|g" | tr '!' ' ' | tr -s ' ' | grep -v -e '"behavior": "OK"' -e '"behavior": "NON-STRICT"' -e '"behavior": "INFORMATIONAL"' && exit 1) || true diff --git a/tests/extensions/base64ext.nim b/tests/extensions/base64ext.nim index 26319c95ae..f0478bf23d 100644 --- a/tests/extensions/base64ext.nim +++ b/tests/extensions/base64ext.nim @@ -8,12 +8,11 @@ ## those terms. import - pkg/[stew/base64, - chronos, - chronicles, - results], - ../../websock/types, - ../../websock/frame + stew/base64, + chronos, + chronicles, + results, + ../../websock/[frame, types] type Base64Ext = ref object of Ext @@ -23,7 +22,9 @@ type const extID = "base64" -method decode(ext: Base64Ext, frame: Frame): Future[Frame] {.async.} = +method decode( + ext: Base64Ext, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: return frame @@ -48,10 +49,13 @@ method decode(ext: Base64Ext, frame: Frame): Future[Frame] {.async.} = # bug in Base64.Decode when accepts seq[byte] let instr = cast[string](data) - if ext.padding: - frame.data = Base64Pad.decode(instr) - else: - frame.data = Base64.decode(instr) + try: + if ext.padding: + frame.data = Base64Pad.decode(instr) + else: + frame.data = Base64.decode(instr) + except Base64Error: + raise newException(WSExtError, "invalid data") trace "Base64Ext decode", input=frame.length, output=frame.data.len @@ -62,7 +66,9 @@ method decode(ext: Base64Ext, frame: Frame): Future[Frame] {.async.} = return frame -method encode(ext: Base64Ext, frame: Frame): Future[Frame] {.async.} = +method encode( + ext: Base64Ext, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, WebSocketError]).} = if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: return frame diff --git a/tests/extensions/hexext.nim b/tests/extensions/hexext.nim index 1855c29312..b675638e54 100644 --- a/tests/extensions/hexext.nim +++ b/tests/extensions/hexext.nim @@ -7,13 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import - pkg/[results, - stew/byteutils, - chronos, - chronicles], - ../../websock/types, - ../../websock/frame +import results, stew/byteutils, chronos, chronicles, ../../websock/[frame, types] type HexExt = ref object of Ext @@ -22,7 +16,9 @@ type const extID = "hex" -method decode(ext: HexExt, frame: Frame): Future[Frame] {.async.} = +method decode( + ext: HexExt, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: return frame @@ -45,7 +41,11 @@ method decode(ext: HexExt, frame: Frame): Future[Frame] {.async.} = if data.len > ext.session.frameSize: raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size") - frame.data = hexToSeqByte(cast[string](data)) + frame.data = try: + hexToSeqByte(cast[string](data)) + except ValueError: + raise newException(WSExtError, "invalid data") + trace "HexExt decode", input=frame.length, output=frame.data.len frame.length = frame.data.len.uint64 @@ -55,7 +55,9 @@ method decode(ext: HexExt, frame: Frame): Future[Frame] {.async.} = return frame -method encode(ext: HexExt, frame: Frame): Future[Frame] {.async.} = +method encode( + ext: HexExt, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, WebSocketError]).} = if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: return frame diff --git a/tests/extensions/testcompression.nim b/tests/extensions/testcompression.nim index 89b01e119a..31c13984e9 100644 --- a/tests/extensions/testcompression.nim +++ b/tests/extensions/testcompression.nim @@ -7,10 +7,12 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/[os, strutils] -import pkg/[chronos/unittest2/asynctests, stew/io2] -import ../../websock/websock -import ../../websock/extensions/compression/deflate +import + std/[os, strutils], + chronos/unittest2/asynctests, + stew/io2, + ../../websock/websock, + ../../websock/extensions/compression/deflate const dataFolder = currentSourcePath.rsplit(os.DirSep, 1)[0] / "data" diff --git a/tests/extensions/testextflow.nim b/tests/extensions/testextflow.nim index c11b1f188e..15282d92b8 100644 --- a/tests/extensions/testextflow.nim +++ b/tests/extensions/testextflow.nim @@ -7,9 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/strutils -import pkg/[chronos, stew/byteutils] -import pkg/asynctest/unittest2 +import std/strutils, chronos, stew/byteutils, asynctest/unittest import ../../ws/ws @@ -28,13 +26,13 @@ proc new*( name: "HelperExtension") method decode*( - self: HelperExtension, - frame: Frame): Future[Frame] {.async.} = + self: HelperExtension, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = return await self.handler(self, frame) method encode*( - self: HelperExtension, - frame: Frame): Future[Frame] {.async.} = + self: HelperExtension, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, WebSocketError]).} = return await self.handler(self, frame) const TestString = "Hello" diff --git a/tests/extensions/testexts.nim b/tests/extensions/testexts.nim index 6fd9661a0f..2b1a5c355a 100644 --- a/tests/extensions/testexts.nim +++ b/tests/extensions/testexts.nim @@ -7,8 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import pkg/[chronos/unittest2/asynctests, stew/byteutils] -import ./base64ext, ./hexext +import chronos/unittest2/asynctests, stew/byteutils, ./[base64ext, hexext] import ../../websock/websock, ../helpers suite "multiple extensions flow": diff --git a/tests/helpers.nim b/tests/helpers.nim index 8ca07ce943..1308483c59 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -9,17 +9,16 @@ {.push raises: [].} -import std/[strutils, random] -import pkg/[ +import + std/[strutils, random], chronos, chronos/streams/tlsstream, httputils, - chronicles] - -import ../websock/websock + chronicles, + ../websock/websock import ./keys -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} const WSPath* = when defined secure: "/wss" else: "/ws" @@ -50,7 +49,7 @@ proc createServer*( {.raises: [].} = try: let server = when defined secure: - TlsHttpServer.create( + HttpServer.create( address = address, tlsPrivateKey = tlsPrivateKey, tlsCertificate = tlsCertificate, diff --git a/tests/testframes.nim b/tests/testframes.nim index cece3ec386..7e672989f4 100644 --- a/tests/testframes.nim +++ b/tests/testframes.nim @@ -8,7 +8,7 @@ ## those terms. import - pkg/chronos/unittest2/asynctests + chronos/unittest2/asynctests include ../websock/frame @@ -16,7 +16,7 @@ include ../websock/frame suite "Test data frames": setup: - var maskKey {.used.} : array[4, char] + var maskKey {.used.} : array[4, byte] asyncTest "# 7bit length text": check (await Frame( @@ -246,14 +246,14 @@ suite "Test data frames": opcode: Opcode.Text, mask: true, data: toBytes("hi there"), - maskKey: ['\xCF', '\xD8', '\x05', 'e'] + maskKey: [byte 0xCF, 0xD8, 0x05, ord 'e'] ).encode()) check data == toBytes("\129\136\207\216\5e\167\177%\17\167\189w\0") suite "Test control frames": setup: - var maskKey {.used.} : array[4, char] + var maskKey {.used.} : array[4, byte] asyncTest "Close": check (await Frame( diff --git a/tests/testhooks.nim b/tests/testhooks.nim index 940e260f18..d8301cedc2 100644 --- a/tests/testhooks.nim +++ b/tests/testhooks.nim @@ -7,14 +7,11 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import pkg/[ +import httputils, chronos/unittest2/asynctests, - ] - -import ../websock/websock - -import ./helpers + ../websock/websock, + ./helpers let address = initTAddress("127.0.0.1:8888") diff --git a/tests/testutf8.nim b/tests/testutf8.nim index b4d81c8ea0..a2c834a52d 100644 --- a/tests/testutf8.nim +++ b/tests/testutf8.nim @@ -9,11 +9,9 @@ import std/[strutils], - pkg/[ - stew/byteutils, - chronos/unittest2/asynctests, - chronicles - ], + stew/byteutils, + chronos/unittest2/asynctests, + chronicles, ../websock/[websock, utf8dfa] suite "UTF-8 DFA validator": diff --git a/tests/testwebsockets.nim b/tests/testwebsockets.nim index fa37269cee..31660cfccc 100644 --- a/tests/testwebsockets.nim +++ b/tests/testwebsockets.nim @@ -7,17 +7,13 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/[ - random, - sequtils, - strutils] -import pkg/[ +import + std/[random, sequtils, strutils], httputils, chronos/unittest2/asynctests, chronicles, - stew/byteutils] - -import ../websock/websock + stew/byteutils, + ../websock/websock import ./helpers diff --git a/websock.nimble b/websock.nimble index dc90bb1814..0c5c03beb9 100644 --- a/websock.nimble +++ b/websock.nimble @@ -30,7 +30,7 @@ task test, "run tests": nimFlags = envNimFlags & " --verbosity:0 --hints:off --hint:Name:on " & "--styleCheck:usages --styleCheck:error" & - " -d:chronosStrictException --mm:refc" + " --mm:refc" # dont't need to run it, only want to test if it is compileable exec "nim c -c " & nimFlags & " -d:chronicles_log_level=TRACE -d:chronicles_sinks:json --styleCheck:usages --styleCheck:hint ./tests/all_tests" diff --git a/websock/extensions/compression/deflate.nim b/websock/extensions/compression/deflate.nim index 15af55c1e6..9e5b47af0b 100644 --- a/websock/extensions/compression/deflate.nim +++ b/websock/extensions/compression/deflate.nim @@ -7,14 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import - std/[strutils], - pkg/[results, - chronos, - chronicles, - zlib], - ../../types, - ../../frame +import std/[strutils], results, chronos, chronicles, zlib, ../../[frame, types] logScope: topics = "websock deflate" @@ -53,7 +46,17 @@ const ExtDeflateThreshold* = 1024 ExtDeflateDecompressLimit* = 10 shl 20 # 10mb -proc destroyExt(ext: DeflateExt) = +when not declared(newSeqUninit): # nim 2.2+ + template newSeqUninit[T: byte](len: int): seq[byte] = + newSeqUninitialized[byte](len) + +when not declared(setLenUninit): + template setLenUninit(src: var seq[byte], newlen: int) = + var tmp = newSeqUninit[byte](newlen) + copyMem(addr tmp[0], addr src[0], src.len) + src = move(tmp) + +proc destroyExt(ext: DeflateExt) {.nimcall.} = if ext.compCtxState != ContextState.Invalid: # zlib.deflateEnd somehow return DATA_ERROR # when compression succeed some cases. @@ -182,65 +185,71 @@ proc compressInit(ext: DeflateExt) = ext.compCtxState = ContextState.Initialized proc compress(zs: var ZStream, data: openArray[byte]): seq[byte] = - var buf: array[0xFFFF, byte] - # these casting is needed to prevent compilation # error with CLANG zs.next_in = cast[ptr uint8](data[0].unsafeAddr) zs.avail_in = data.len.cuint + result = newSeqUninit[byte](deflateBound(zs, data.len.culong).int + 10) + var added = 0 while true: - zs.next_out = cast[ptr uint8](buf[0].addr) - zs.avail_out = buf.len.cuint + let avail = result.len - added + zs.next_out = cast[ptr uint8](result[added].addr) + zs.avail_out = avail.cuint let r = zs.deflate(Z_SYNC_FLUSH) - let outSize = buf.len - zs.avail_out.int - result.add toOpenArray(buf, 0, outSize-1) + added += avail - zs.avail_out.int if r == Z_STREAM_END: break elif r == Z_OK: - # need more input or more output available + # Because we use `Z_SYNC_FLUSH`, we may need more than `deflateBound` bytes if zs.avail_in > 0 or zs.avail_out == 0: + result.setLenUninit(result.len + 128) continue else: break else: raise newException(WSExtError, "compression error " & $r) + result.setLen(added) proc decompress(zs: var ZStream, limit: int, data: openArray[byte]): seq[byte] = - var buf: array[0xFFFF, byte] - # these casting is needed to prevent compilation # error with CLANG zs.next_in = cast[ptr uint8](data[0].unsafeAddr) zs.avail_in = data.len.cuint + result = newSeqUninit[byte](min(max(data.len * 2, 65636), limit)) + + var added = 0 while true: - zs.next_out = cast[ptr uint8](buf[0].addr) - zs.avail_out = buf.len.cuint + let avail = result.len - added + zs.next_out = cast[ptr uint8](result[added].addr) + zs.avail_out = avail.cuint let r = zs.inflate(Z_NO_FLUSH) - let outSize = buf.len - zs.avail_out.int - result.add toOpenArray(buf, 0, outSize-1) - - if result.len > limit: - raise newException(WSExtError, "decompression exceeds allowed limit") + added += avail - zs.avail_out.int if r == Z_STREAM_END: break elif r == Z_OK: # need more input or more output available if zs.avail_in > 0 or zs.avail_out == 0: + if result.len == limit: + raise newException(WSExtError, "decompression exceeds allowed limit") + + result.setLenUninit(min(result.len + result.len div 2, limit)) continue else: break else: raise newException(WSExtError, "decompression error " & $r) - return result + result.setLen(added) -method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = +method decode( + ext: DeflateExt, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: # only data frames can be decompressed return frame @@ -262,16 +271,16 @@ method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = # even though the frame.data.len == 0, the stream needs # to be closed with trailing bytes if it's a final frame - var data: seq[byte] - var buf: array[0xFFFF, byte] + if frame.length > ext.session.frameSize.uint64: + raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size") - while data.len < frame.length.int: - let len = min(frame.length.int - data.len, buf.len) - let read = await frame.read(ext.session.stream.reader, addr buf[0], len) - data.add toOpenArray(buf, 0, read - 1) + var + data = newSeqUninit[byte](frame.length.int) + offset = 0 - if data.len > ext.session.frameSize: - raise newException(WSPayloadTooLarge, "payload exceeds allowed max frame size") + while offset < data.len: + offset += + await frame.read(ext.session.stream.reader, addr data[offset], data.len - offset) if frame.fin: data.add TrailingBytes @@ -294,7 +303,9 @@ method decode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = return frame -method encode(ext: DeflateExt, frame: Frame): Future[Frame] {.async.} = +method encode( + ext: DeflateExt, frame: Frame +): Future[Frame] {.async: (raises: [CancelledError, WebSocketError]).} = if frame.opcode notin {Opcode.Text, Opcode.Binary, Opcode.Cont}: # only data frames can be compressed return frame diff --git a/websock/extensions/extutils.nim b/websock/extensions/extutils.nim index f9eba08590..c90720a4e2 100644 --- a/websock/extensions/extutils.nim +++ b/websock/extensions/extutils.nim @@ -9,7 +9,7 @@ import std/strutils, - pkg/httputils, + httputils, ../types type diff --git a/websock/frame.nim b/websock/frame.nim index 9ba561d375..737bc216d5 100644 --- a/websock/frame.nim +++ b/websock/frame.nim @@ -7,16 +7,14 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} -import pkg/[ +import chronos, chronicles, results, - stew/byteutils, - stew/endians2] - -import ./types + stew/[byteutils, endians2, objects], + ./types logScope: topics = "websock ws-frame" @@ -51,16 +49,14 @@ proc mask*( ## for i in 0 ..< data.len: - data[i] = (data[i].uint8 xor maskKey[(offset + i) mod 4].uint8) + data[i] = (data[i] xor maskKey[(offset + i) mod 4]) template remainder*(frame: Frame): uint64 = frame.length - frame.consumed proc read*( - frame: Frame, - reader: AsyncStreamReader, - pbytes: pointer, - nbytes: int): Future[int] {.async.} = + frame: Frame, reader: AsyncStreamReader, pbytes: pointer, nbytes: int +): Future[int] {.async: (raises: [CancelledError, AsyncStreamError]).} = # read data from buffered payload if available # e.g. data processed by extensions @@ -70,7 +66,10 @@ proc read*( copyMem(pbytes, addr frame.data[frame.offset], readLen) frame.offset += readLen - var pbuf = cast[ptr UncheckedArray[byte]](pbytes) + if frame.offset == frame.data.len: + frame.data.reset() + + let pbuf = cast[ptr UncheckedArray[byte]](pbytes) if readLen < nbytes: let len = min(nbytes - readLen, frame.remainder.int - readLen) readLen += await reader.readOnce(addr pbuf[readLen], len) @@ -86,15 +85,17 @@ proc read*( return readLen proc encode*( - frame: Frame, - extensions: seq[Ext] = @[]): Future[seq[byte]] {.async.} = + frame: Frame, extensions: seq[Ext] = @[] +): Future[seq[byte]] {. + async: (raises: [CancelledError, AsyncStreamError, WebSocketError]) +.} = ## Encodes a frame into a string buffer. ## See https://tools.ietf.org/html/rfc6455#section-5.2 var f = frame if extensions.len > 0: - for e in extensions: - f = await e.encode(f) + for ext in extensions: + f = await ext.encode(f) var ret: seq[byte] var b0 = (f.opcode.uint8 and 0x0f) # 0th byte: opcodes and flags. @@ -152,58 +153,49 @@ proc encode*( return ret proc decode*( - _: typedesc[Frame], - reader: AsyncStreamReader, - masked: bool, - extensions: seq[Ext] = @[]): Future[Frame] {.async.} = + _: typedesc[Frame], + reader: AsyncStreamReader, + masked: bool, + extensions: seq[Ext] = @[], +): Future[Frame] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Read and Decode incoming header ## - var header = newSeq[byte](2) + var header {.noinit.}: array[2, byte] trace "Reading new frame" await reader.readExactly(addr header[0], 2) - if header.len != 2: - trace "Invalid websocket header length" - raise newException(WSMalformedHeaderError, - "Invalid websocket header length") - let b0 = header[0].uint8 - let b1 = header[1].uint8 + let b0 = header[0] + let b1 = header[1] var frame = Frame() # Read the flags and fin from the header. - var hf = cast[HeaderFlags](b0 shr 4) + let hf = cast[HeaderFlags](b0 shr 4) frame.fin = HeaderFlag.fin in hf frame.rsv1 = HeaderFlag.rsv1 in hf frame.rsv2 = HeaderFlag.rsv2 in hf frame.rsv3 = HeaderFlag.rsv3 in hf - let opcode = (b0 and 0x0f) - if opcode > ord(Opcode.Pong): + if not checkedEnumAssign(frame.opcode, b0 and 0x0f): raise newException(WSOpcodeMismatchError, "Wrong opcode!") - frame.opcode = (opcode).Opcode - # Payload length can be 7 bits, 7+16 bits, or 7+64 bits. - var finalLen: uint64 = 0 - - let headerLen = uint(b1 and 0x7f) - if headerLen == 0x7e: - # Length must be 7+16 bits. - var length = newSeq[byte](2) - await reader.readExactly(addr length[0], 2) - finalLen = uint16.fromBytesBE(length) - elif headerLen == 0x7f: - # Length must be 7+64 bits. - var length = newSeq[byte](8) - await reader.readExactly(addr length[0], 8) - finalLen = uint64.fromBytesBE(length) - else: - # Length must be 7 bits. - finalLen = headerLen - - frame.length = finalLen + let headerLen = b1 and 0x7f + frame.length = + if headerLen == 0x7e: + # Length must be 7+16 bits. + var length {.noinit.}: array[2, byte] + await reader.readExactly(addr length[0], length.len) + uint64(uint16.fromBytesBE(length)) + elif headerLen == 0x7f: + # Length must be 7+64 bits. + var length {.noinit.}: array[8, byte] + await reader.readExactly(addr length[0], length.len) + uint64.fromBytesBE(length) + else: + # Length must be 7 bits. + uint64(headerLen) if frame.length > WSMaxMessageSize: raise newException(WSPayloadLengthError, "Frame too big: " & $frame.length) @@ -213,15 +205,11 @@ proc decode*( if masked == frame.mask: # Server sends unmasked but accepts only masked. # Client sends masked but accepts only unmasked. - raise newException(WSMaskMismatchError, - "Socket mask mismatch") + raise newException(WSMaskMismatchError, "Socket mask mismatch") - var maskKey = newSeq[byte](4) if frame.mask: # Read the mask. - await reader.readExactly(addr maskKey[0], 4) - for i in 0.. 0: for i in countdown(extensions.high, extensions.low): diff --git a/websock/http.nim b/websock/http.nim index d076a2016a..79ad6bd5df 100644 --- a/websock/http.nim +++ b/websock/http.nim @@ -7,12 +7,11 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import std/uri -import pkg/[ +import + std/uri, chronos/apps/http/httptable, chronos/streams/tlsstream, - httputils] - -import ./http/client, ./http/server, ./http/common + httputils, + ./http/[client, common, server] export uri, httputils, client, server, httptable, tlsstream, common diff --git a/websock/http/client.nim b/websock/http/client.nim index e4fc61a5cd..a56eafcff9 100644 --- a/websock/http/client.nim +++ b/websock/http/client.nim @@ -7,16 +7,9 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} -import std/[uri, strutils] -import pkg/[ - chronos, - chronicles, - httputils, - stew/byteutils] - -import ./common +import std/[uri, strutils], chronos, chronicles, httputils, stew/byteutils, ./common logScope: topics = "websock http-client" @@ -36,29 +29,11 @@ type minVersion*: TLSVersion maxVersion*: TLSVersion -proc close*(client: HttpClient): Future[void] = +proc closeWait*(client: HttpClient): Future[void] {.async: (raises: [], raw: true).} = client.stream.closeWait() -proc readResponse(stream: AsyncStreamReader): Future[HttpResponseHeader] {.async.} = - var buffer = newSeq[byte](MaxHttpHeadersSize) - try: - let - hlenfut = stream.readUntil( - addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep) - ores = await withTimeout(hlenfut, HttpHeadersTimeout) - - if not ores: - raise newException(HttpError, - "Timeout expired while receiving headers") - - let hlen = hlenfut.read() - buffer.setLen(hlen) - - return buffer.parseResponse() - except CatchableError as exc: - trace "Exception reading headers", exc = exc.msg - buffer.setLen(0) - raise exc +proc close*(client: HttpClient): Future[void] {.deprecated: "closeWait".} = + client.closeWait() proc generateHeaders( requestUrl: Uri, @@ -85,11 +60,13 @@ proc generateHeaders( return headersData proc request*( - client: HttpClient, - url: string | Uri, - httpMethod = MethodGet, - headers: HttpTables, - body: seq[byte] = @[]): Future[HttpResponse] {.async.} = + client: HttpClient, + url: string | Uri, + httpMethod = MethodGet, + headers: HttpTables, +): Future[HttpResponse] {. + async: (raises: [CancelledError, AsyncStreamError, HttpError]) +.} = ## Helper that actually makes the request. ## Does not handle redirects. ## @@ -104,30 +81,31 @@ proc request*( url let headerString = generateHeaders(requestUrl, httpMethod, client.version, headers) - await client.stream.writer.write(headerString) - let response = await client.stream.reader.readResponse() - let headers = - block: - var res = HttpTable.init() - for key, value in response.headers(): - res.add(key, value) - res + + let + header = + try: + await client.stream.reader.readHttpHeader().wait(HttpHeadersTimeout) + except AsyncTimeoutError: + raise newException(HttpError, "Timeout expired while receiving headers") + response = header.parseResponse() return HttpResponse( - headers: headers, + headers: response.toHttpTable(), stream: client.stream, code: response.code, reason: response.reason()) proc connect*( - T: typedesc[HttpClient | TlsHttpClient], - address: TransportAddress, - version = HttpVersion11, - tlsFlags: set[TLSFlags] = {}, - tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12, - hostName = ""): Future[T] {.async.} = + T: typedesc[HttpClient | TlsHttpClient], + address: TransportAddress, + version = HttpVersion11, + tlsFlags: set[TLSFlags] = {}, + tlsMinVersion = TLSVersion.TLS12, + tlsMaxVersion = TLSVersion.TLS12, + hostName = "", +): Future[T] {.async: (raises: [CancelledError, AsyncStreamError, TransportError]).} = let transp = await connect(address) let client = T( @@ -163,25 +141,23 @@ proc connect*( return client proc connect*( - T: typedesc[HttpClient | TlsHttpClient], - host: string, - version = HttpVersion11, - tlsFlags: set[TLSFlags] = {}, - tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12, - hostName = ""): Future[T] - {.async: (raises: [CatchableError, HttpError]).} = - + T: typedesc[HttpClient | TlsHttpClient], + host: string, + version = HttpVersion11, + tlsFlags: set[TLSFlags] = {}, + tlsMinVersion = TLSVersion.TLS12, + tlsMaxVersion = TLSVersion.TLS12, + hostName = "", +): Future[T] {. + async: (raises: [CancelledError, AsyncStreamError, HttpError, TransportError]) +.} = let wantedHostName = if hostName.len > 0: hostName else: host.split(":")[0] - template used(x: typed) = - # silence unused warning - discard - let addrs = resolveTAddress(host) + var lastException: ref TransportError for a in addrs: try: let conn = await T.connect( @@ -194,8 +170,11 @@ proc connect*( return conn except TransportError as exc: - used(exc) trace "Error connecting to address", address = $a, exc = exc.msg + lastException = exc + + if lastException != nil: + raise lastException raise newException(HttpError, "Unable to connect to host on any address!") diff --git a/websock/http/common.nim b/websock/http/common.nim index ba21332174..c81c8015e9 100644 --- a/websock/http/common.nim +++ b/websock/http/common.nim @@ -7,18 +7,16 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} -import std/[uri] -import pkg/[ +import + std/[uri], chronos, httputils, stew/byteutils, - chronicles] - -import pkg/[ + chronicles, chronos/apps/http/httptable, - chronos/streams/tlsstream] + chronos/streams/tlsstream export httputils, httptable, tlsstream, uri @@ -29,6 +27,7 @@ const MaxHttpHeadersSize* = 8192 # maximum size of HTTP headers in octets MaxHttpRequestSize* = 128 * 1024 # maximum size of HTTP body in octets HttpHeadersTimeout* = 120.seconds # timeout for receiving headers (120 sec) + HttpErrorTimeout* = 2.seconds # How long we wait for error sending to complete HeaderSep* = @[byte('\c'), byte('\L'), byte('\c'), byte('\L')] CRLF* = "\r\n" @@ -53,37 +52,59 @@ type HttpError* = object of CatchableError HttpHeaderError* = HttpError -proc closeTransp*(transp: StreamTransport) {.async.} = +when not declared(newSeqUninit): # nim 2.2+ + template newSeqUninit[T: byte](len: int): seq[byte] = + newSeqUninitialized[byte](len) + +proc add(v: var seq[byte], data: string) = + v.add data.toOpenArrayByte(0, data.high()) + +proc closeTransp*(transp: StreamTransport) {.async, deprecated.} = if not transp.closed(): await transp.closeWait() -proc closeStream*(stream: AsyncStreamRW) {.async.} = +proc closeStream*(stream: AsyncStreamRW) {.async, deprecated.} = if not stream.closed(): await stream.closeWait() -proc closeWait*(stream: AsyncStream) {.async.} = - await allFutures( - stream.reader.closeStream(), - stream.writer.closeStream(), - stream.reader.tsource.closeTransp()) +proc closeWait*(stream: AsyncStream) {.async: (raises: []).} = + await noCancel allFutures(stream.reader.closeWait(), stream.writer.closeWait()) + await stream.reader.tsource.closeWait() proc close*(stream: AsyncStream) = stream.reader.close() stream.writer.close() stream.reader.tsource.close() +proc readHttpHeader*( + stream: AsyncStreamReader +): Future[seq[byte]] {.async: (raises: [CancelledError, AsyncStreamError]).} = + var buffer = newSeqUninit[byte](MaxHttpHeadersSize) + let hlen = await stream.readUntil(addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep) + buffer.setLen(hlen) + buffer + +func toHttpTable*(header: HttpRequestHeader | HttpResponseHeader): HttpTable = + var res = HttpTable.init() + for key, value in header.headers(): + res.add(key, value) + res + proc sendResponse*( - request: HttpRequest, - code: HttpCode, - headers: HttpTables = HttpTable.init(), - data: seq[byte] = @[], - version = HttpVersion11, - content = "") {.async.} = + request: HttpRequest, + code: HttpCode, + headers: HttpTables = HttpTable.init(), + data: openArray[byte] = @[], + version = HttpVersion11, + content = "", +) {.async: (raises: [CancelledError, AsyncStreamError], raw: true).} = ## Send response ## var headers = headers - var response: string = $version + var response = newSeqOfCap[byte](1024 + data.len) + + response.add($version) response.add(" ") response.add($code) response.add(CRLF) @@ -105,33 +126,42 @@ proc sendResponse*( response.add(CRLF) response.add(CRLF) - await request.stream.writer.write( - response.toBytes() & data) + response.add(data) + + request.stream.writer.write(response) proc sendResponse*( - request: HttpRequest, - code: HttpCode, - headers: HttpTables = HttpTable.init(), - data: string, - version = HttpVersion11, - content = ""): Future[void] = - request.sendResponse(code, headers, data.toBytes(), version, content) + request: HttpRequest, + code: HttpCode, + headers: HttpTables = HttpTable.init(), + data: string, + version = HttpVersion11, + content = "", +): Future[void] {.async: (raises: [CancelledError, AsyncStreamError], raw: true).} = + request.sendResponse( + code, headers, data.toOpenArrayByte(0, data.high()), version, content + ) proc sendError*( - stream: AsyncStreamWriter, - code: HttpCode, - version = HttpVersion11) {.async.} = - let content = $code - var response: string = $version + stream: AsyncStreamWriter, code: HttpCode, version = HttpVersion11 +) {.async: (raises: [CancelledError]).} = + var response = newSeqOfCap[byte](1024) + response.add($version) response.add(" ") - response.add(content & CRLF) + response.add($code) response.add(CRLF) + response.add(CRLF) + response.add($code) - await stream.write( - response.toBytes() & content.toBytes()) + try: + # When sending errors, don't waste too much time on it.. + discard await stream.write(response).withTimeout(HttpErrorTimeout) + except AsyncStreamError: + # Ignore errors while sending error responses to not swallow the original + # error that caused us to want to send an error + discard proc sendError*( - request: HttpRequest, - code: HttpCode, - version = HttpVersion11): Future[void] = + request: HttpRequest, code: HttpCode, version = HttpVersion11 +): Future[void] {.async: (raises: [CancelledError], raw: true).} = request.stream.writer.sendError(code, version) diff --git a/websock/http/server.nim b/websock/http/server.nim index e662127e6b..e3a8c78f42 100644 --- a/websock/http/server.nim +++ b/websock/http/server.nim @@ -7,32 +7,24 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} -import std/uri -import pkg/[ - chronos, - chronicles, - httputils] +import chronos, chronicles, httputils, ./common when isLogFormatUsed(json): import json_serialization/std/net as jsnet export jsnet -import ./common - logScope: topics = "websock http-server" type - HttpAsyncCallback* = proc (request: HttpRequest): - Future[void] {.closure, gcsafe, raises: [].} + HttpAsyncCallback* = proc(request: HttpRequest) {.async.} HttpServer* = ref object of StreamServer handler*: HttpAsyncCallback - handshakeTimeout*: Duration headersTimeout*: Duration - case secure*: bool: + case secure*: bool of true: tlsFlags*: set[TLSFlags] tlsPrivateKey*: TLSPrivateKey @@ -42,255 +34,194 @@ type else: discard - TlsHttpServer* = HttpServer + TlsHttpServer* {.deprecated.} = HttpServer template used(x: typed) = # silence unused warning discard -proc validateRequest( - stream: AsyncStreamWriter, - header: HttpRequestHeader): Future[ReqStatus] {.async.} = - ## Validate Request - ## - - if header.meth notin {MethodGet}: - trace "GET method is only allowed", address = stream.tsource.remoteAddress() - await stream.sendError(Http405, version = header.version) - return ReqStatus.Error - - var hlen = header.contentLength() - if hlen < 0 or hlen > MaxHttpRequestSize: - trace "Invalid header length", address = stream.tsource.remoteAddress() - await stream.sendError(Http413, version = header.version) - return ReqStatus.Error - - return ReqStatus.Success - -proc parseRequest( - server: HttpServer, - stream: AsyncStream): Future[HttpRequest] {.async.} = +proc readHttpRequest( + stream: AsyncStream, headersTimeout: Duration +): Future[HttpRequest] {. + async: (raises: [CancelledError, AsyncStreamError, HttpError]) +.} = ## Process transport data to the HTTP server ## + when chronicles.enabledLogLevel == LogLevel.TRACE: + let remoteAddr = + stream.reader.tsource.remoteAddress2().valueOr(default(TransportAddress)) + + trace "Received connection", remoteAddr + + let + requestData = + try: + await stream.reader.readHttpHeader().wait(headersTimeout) + except AsyncTimeoutError: + trace "Timeout expired while receiving headers", remoteAddr + await stream.writer.sendError(Http408, version = HttpVersion11) + raise newException(HttpError, "Didn't read headers in time!") + + request = requestData.parseRequest() + + if request.failed(): + # Header could not be parsed + trace "Malformed header received", remoteAddr + await stream.writer.sendError(Http400, version = HttpVersion11) + raise newException(HttpError, "Malformed header received") + + if request.meth != MethodGet: + trace "GET method is only allowed", remoteAddr + await stream.writer.sendError(Http405, version = request.version) + raise newException(HttpError, $Http405) + + let hlen = request.contentLength() + if hlen < 0 or hlen > MaxHttpRequestSize: + trace "Invalid header length", remoteAddr + await stream.writer.sendError(Http413, version = request.version) + raise newException(HttpError, $Http413) - var buffer = newSeq[byte](MaxHttpHeadersSize) - let remoteAddr {.used.} = stream.reader.tsource.remoteAddress() - trace "Received connection", address = $remoteAddr - try: - let hlenfut = stream.reader.readUntil( - addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep) - let ores = await withTimeout(hlenfut, server.headersTimeout) - if not ores: - # Timeout - trace "Timeout expired while receiving headers", address = $remoteAddr - await stream.writer.sendError(Http408, version = HttpVersion11) - raise newException(HttpError, "Didn't read headers in time!") - - let hlen = hlenfut.read() - buffer.setLen(hlen) - let requestData = buffer.parseRequest() - if requestData.failed(): - # Header could not be parsed - trace "Malformed header received", address = $remoteAddr - await stream.writer.sendError(Http400, version = HttpVersion11) - raise newException(HttpError, "Malformed header received") - - var vres = await stream.writer.validateRequest(requestData) - let hdrs = - block: - var res = HttpTable.init() - for key, value in requestData.headers(): - res.add(key, value) - res - - if vres == ReqStatus.ErrorFailure: - trace "Remote peer disconnected", address = $remoteAddr - raise newException(HttpError, "Remote peer disconnected") - - trace "Received valid HTTP request", address = $remoteAddr - return HttpRequest( - headers: hdrs, - stream: stream, - uri: requestData.uri().parseUri()) - except TransportLimitError: - # size of headers exceeds `MaxHttpHeadersSize` - trace "maximum size of headers limit reached", address = $remoteAddr - await stream.writer.sendError(Http413, version = HttpVersion11) - except TransportIncompleteError: - # remote peer disconnected - trace "Remote peer disconnected", address = $remoteAddr - except TransportOsError as exc: - used(exc) - trace "Problems with networking", address = $remoteAddr, error = exc.msg - -proc handleConnCb( - server: StreamServer, - transp: StreamTransport) {.gcsafe, async: (raises: []).} = - var stream: AsyncStream - try: - stream = AsyncStream( - reader: newAsyncStreamReader(transp), - writer: newAsyncStreamWriter(transp)) - - let httpServer = HttpServer(server) - let request = await httpServer.parseRequest(stream) + trace "Received valid HTTP request", address = $remoteAddr + HttpRequest( + headers: request.toHttpTable(), stream: stream, uri: request.uri().parseUri() + ) - await httpServer.handler(request) - except CatchableError as exc: - used(exc) - debug "Exception in HttpHandler", exc = exc.msg - finally: +proc openAsyncStream( + server: HttpServer, transp: StreamTransport +): Result[AsyncStream, string] = + if server.secure: try: - await stream.closeWait() + let tlsStream = newTLSServerAsyncStream( + newAsyncStreamReader(transp), + newAsyncStreamWriter(transp), + server.tlsPrivateKey, + server.tlsCertificate, + minVersion = server.minVersion, + maxVersion = server.maxVersion, + flags = server.tlsFlags, + ) + + ok AsyncStream(reader: tlsStream.reader, writer: tlsStream.writer) except CatchableError as exc: - used(exc) - debug "Exception in HttpHandler closewait", exc = exc.msg - -proc handleTlsConnCb( - server: StreamServer, - transp: StreamTransport) {.gcsafe, async: (raises: []).} = - - let tlsHttpServer = TlsHttpServer(server) - var tlsStream: TLSAsyncStream - - try: - tlsStream = newTLSServerAsyncStream( - newAsyncStreamReader(transp), - newAsyncStreamWriter(transp), - tlsHttpServer.tlsPrivateKey, - tlsHttpServer.tlsCertificate, - minVersion = tlsHttpServer.minVersion, - maxVersion = tlsHttpServer.maxVersion, - flags = tlsHttpServer.tlsFlags) - except CatchableError as exc: - used(exc) - debug "Exception when initialize TLS stream", exc = exc.msg - return + err exc.msg + else: + ok AsyncStream( + reader: newAsyncStreamReader(transp), writer: newAsyncStreamWriter(transp) + ) - let stream = AsyncStream( - reader: tlsStream.reader, - writer: tlsStream.writer) +proc handleConnCb( + server: StreamServer, transp: StreamTransport +) {.async: (raises: []).} = + let + server = HttpServer(server) + stream = server.openAsyncStream(transp).valueOr: + debug "Failed to open streams", err = error + await transp.closeWait() + return try: - let httpServer = HttpServer(server) - let request = await httpServer.parseRequest(stream) + let request = await stream.readHttpRequest(server.headersTimeout) - await httpServer.handler(request) + await server.handler(request) except CatchableError as exc: used(exc) - debug "Exception in HttpsHandler", exc = exc.msg + debug "Exception in HttpHandler", exc = exc.msg finally: - try: - await stream.closeWait() - except CatchableError as exc: - used(exc) - debug "Exception in HttpsHandler closewait", exc = exc.msg + await stream.closeWait() +# TODO async raises not implemented for accept because it breaks libp2p (1.13.0 +# at the time of writing) proc accept*(server: HttpServer): Future[HttpRequest] {.async.} = if not isNil(server.handler): - raise newException(HttpError, - "Callback already registered - cannot mix callback and accepts styles!") + raise newException( + HttpError, "Callback already registered - cannot mix callback and accepts styles!" + ) trace "Awaiting new request" - let transp = await StreamServer(server).accept() - let stream = if server.secure: - let tlsStream = newTLSServerAsyncStream( - newAsyncStreamReader(transp), - newAsyncStreamWriter(transp), - server.tlsPrivateKey, - server.tlsCertificate, - minVersion = server.minVersion, - maxVersion = server.maxVersion, - flags = server.tlsFlags) - - AsyncStream( - reader: tlsStream.reader, - writer: tlsStream.writer) - else: - AsyncStream( - reader: newAsyncStreamReader(transp), - writer: newAsyncStreamWriter(transp)) + let + transp = await StreamServer(server).accept() + stream = server.openAsyncStream(transp).valueOr: + await transp.closeWait() + raise (ref HttpError)(msg: error) trace "Got new request", isTls = server.secure try: - let - parseFut = server.parseRequest(stream) - if await withTimeout(parseFut, server.handshakeTimeout): - return parseFut.read() - raise newException(HttpError, "Timed out parsing request") - except CatchableError as exc: - # Can't hold up the accept loop - stream.close() + await stream.readHttpRequest(server.headersTimeout) + except CancelledError as exc: + await stream.closeWait() + raise exc + except AsyncStreamError as exc: + await stream.closeWait() + raise exc + except HttpError as exc: + await stream.closeWait() raise exc - proc create*( - _: typedesc[HttpServer], - address: TransportAddress | string, - handler: HttpAsyncCallback = nil, - flags: set[ServerFlags] = {}, - headersTimeout = HttpHeadersTimeout, - handshakeTimeout = 0.seconds - ): HttpServer - {.raises: [CatchableError].} = # TODO: remove CatchableError + _: typedesc[HttpServer], + address: TransportAddress | string, + handler: HttpAsyncCallback = nil, + flags: set[ServerFlags] = {}, + headersTimeout = HttpHeadersTimeout, +): HttpServer {.raises: [TransportOsError].} = ## Make a new HTTP Server ## - var server = HttpServer( - handler: handler, - headersTimeout: headersTimeout, - handshakeTimeout: - if handshakeTimeout == 0.seconds: - # default to headersTimeout * 1.05 - headersTimeout + (headersTimeout div 20) - else: handshakeTimeout, - ) - let localAddress = when address is string: initTAddress(address) else: address + var server = HttpServer(handler: handler, headersTimeout: headersTimeout) + server = HttpServer( - createStreamServer( - localAddress, - handleConnCb, - flags, - child = StreamServer(server))) + createStreamServer(localAddress, handleConnCb, flags, child = StreamServer(server)) + ) trace "Created HTTP Server", host = $server.localAddress() - return server + server proc create*( - _: typedesc[TlsHttpServer], - address: TransportAddress | string, - tlsPrivateKey: TLSPrivateKey, - tlsCertificate: TLSCertificate, - handler: HttpAsyncCallback = nil, - flags: set[ServerFlags] = {}, - tlsFlags: set[TLSFlags] = {}, - tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12, - headersTimeout = HttpHeadersTimeout, - handshakeTimeout = 0.seconds - ): TlsHttpServer - {.raises: [CatchableError].} = # TODO: remove CatchableError - - var server = TlsHttpServer( + _: typedesc[HttpServer], + address: TransportAddress | string, + handler: HttpAsyncCallback = nil, + flags: set[ServerFlags] = {}, + headersTimeout = HttpHeadersTimeout, + handshakeTimeout: Duration, +): HttpServer {. + raises: [TransportOsError], + deprecated: "Use headersTimeout instead of handshakeTimeout" +.} = + let headersTimeout = + if handshakeTimeout > 0.seconds: + min(handshakeTimeout, headersTimeout) + else: + headersTimeout + HttpServer.create(address, handler, flags, headersTimeout) + +proc create*( + _: typedesc[HttpServer], + address: TransportAddress | string, + tlsPrivateKey: TLSPrivateKey, + tlsCertificate: TLSCertificate, + handler: HttpAsyncCallback = nil, + flags: set[ServerFlags] = {}, + tlsFlags: set[TLSFlags] = {}, + tlsMinVersion = TLSVersion.TLS12, + tlsMaxVersion = TLSVersion.TLS12, + headersTimeout = HttpHeadersTimeout, +): HttpServer {.raises: [TransportOsError].} = + var server = HttpServer( headersTimeout: headersTimeout, - handshakeTimeout: - if handshakeTimeout == 0.seconds: - # default to headersTimeout * 1.05 - headersTimeout + (headersTimeout div 20) - else: handshakeTimeout, secure: true, handler: handler, tlsPrivateKey: tlsPrivateKey, tlsCertificate: tlsCertificate, minVersion: tlsMinVersion, - maxVersion: tlsMaxVersion) + maxVersion: tlsMaxVersion, + ) let localAddress = when address is string: @@ -298,13 +229,43 @@ proc create*( else: address - server = TlsHttpServer( - createStreamServer( - localAddress, - handleTlsConnCb, - flags, - child = StreamServer(server))) + server = HttpServer( + createStreamServer(localAddress, handleConnCb, flags, child = StreamServer(server)) + ) trace "Created TLS HTTP Server", host = $server.localAddress() - return server + server + +proc create*( + _: typedesc[HttpServer], + address: TransportAddress | string, + tlsPrivateKey: TLSPrivateKey, + tlsCertificate: TLSCertificate, + handler: HttpAsyncCallback = nil, + flags: set[ServerFlags] = {}, + tlsFlags: set[TLSFlags] = {}, + tlsMinVersion = TLSVersion.TLS12, + tlsMaxVersion = TLSVersion.TLS12, + headersTimeout = HttpHeadersTimeout, + handshakeTimeout: Duration, +): HttpServer {. + raises: [TransportOsError], + deprecated: "Use headersTimeout instead of handshakeTimeout" +.} = + let headersTimeout = + if handshakeTimeout > 0.seconds: + min(handshakeTimeout, headersTimeout) + else: + headersTimeout + + HttpServer.create( + address, tlsPrivateKey, tlsCertificate, handler, flags, tlsFlags, tlsMinVersion, + tlsMaxVersion, headersTimeout, + ) + +proc handshakeTimeout*(s: HttpServer): Duration {.deprecated: "headersTimeout".} = + s.headersTimeout + +proc `handshakeTimeout=`*(s: HttpServer, v: Duration) {.deprecated: "headersTimeout".} = + s.headersTimeout = v diff --git a/websock/session.nim b/websock/session.nim index e93381248e..73897ea54b 100644 --- a/websock/session.nim +++ b/websock/session.nim @@ -7,13 +7,15 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} -import std/strformat -import pkg/[chronos, chronicles, stew/byteutils, stew/endians2] -import ./types, ./frame, ./utf8dfa, ./http - -import pkg/chronos/streams/asyncstream +import + std/strformat, + chronos, + chronicles, + stew/byteutils, + stew/endians2, + ./[frame, types, utf8dfa, http] logScope: topics = "websock ws-session" @@ -27,43 +29,41 @@ proc prepareCloseBody(code: StatusCodes, reason: string): seq[byte] = if ord(code) > 999: result = @(ord(code).uint16.toBytesBE()) & result -proc writeMessage(ws: WSSession, - data: seq[byte] = @[], - opcode: Opcode, - maskKey: MaskKey, - extensions: seq[Ext]) {.async.} = - - if opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary}: - warn "Attempting to send a data frame with an invalid opcode!" - raise newException(WSInvalidOpcodeError, - &"Attempting to send a data frame with an invalid opcode {opcode}!") - +proc writeMessage( + ws: WSSession, + data: seq[byte], + opcode: Opcode, + maskKey: MaskKey, + extensions: seq[Ext], +) {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = let maxSize = ws.frameSize - var i = 0 + var sent = 0 while ws.readyState notin {ReadyState.Closing, ReadyState.Closed}: - let canSend = min(data.len - i, maxSize) - let frame = Frame( - fin: if (canSend + i >= data.len): true else: false, + let + canSend = min(data.len - sent, maxSize) + # fragments have to be `Continuation` frames + opcode = if sent > 0: Opcode.Cont else: opcode + frame = Frame( + fin: if (canSend + sent >= data.len): true else: false, rsv1: false, rsv2: false, rsv3: false, - opcode: if i > 0: Opcode.Cont else: opcode, # fragments have to be `Continuation` frames + opcode: opcode, mask: ws.masked, - data: data[i ..< canSend + i], - maskKey: maskKey) + data: data[sent ..< canSend + sent], + maskKey: maskKey, + ) + encoded = await frame.encode(extensions) - let encoded = await frame.encode(extensions) await ws.stream.writer.write(encoded) - i += canSend - if i >= data.len: + sent += canSend + if sent >= data.len: break proc writeControl( - ws: WSSession, - data: seq[byte] = @[], - opcode: Opcode, - maskKey: MaskKey) {.async.} = + ws: WSSession, data: seq[byte], opcode: Opcode, maskKey: MaskKey +) {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Send a frame applying the supplied ## extensions ## @@ -73,12 +73,8 @@ proc writeControl( dataSize = data.len masked = ws.masked - if opcode in {Opcode.Text, Opcode.Cont, Opcode.Binary}: - warn "Attempting to send a control frame with an invalid opcode!" - raise newException(WSInvalidOpcodeError, - &"Attempting to send a control frame with an invalid opcode {opcode}!") - - let frame = Frame( + let + frame = Frame( fin: true, rsv1: false, rsv2: false, @@ -86,21 +82,16 @@ proc writeControl( opcode: opcode, mask: ws.masked, data: data, - maskKey: maskKey) - - let encoded = await frame.encode() + maskKey: maskKey, + ) + encoded = await frame.encode() await ws.stream.writer.write(encoded) trace "Wrote control frame" -func isControl(opcode: Opcode): bool = - opcode notin {Opcode.Text, Opcode.Cont, Opcode.Binary} - -proc nonCancellableSend( - ws: WSSession, - data: seq[byte] = @[], - opcode: Opcode): Future[void] - {.async.} = +proc doSend( + ws: WSSession, data: seq[byte], opcode: Opcode +): Future[void] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Send a frame ## @@ -124,54 +115,42 @@ proc nonCancellableSend( else: default(MaskKey) - if opcode.isControl: - await ws.writeControl(data, opcode, maskKey) - else: - await ws.writeMessage(data, opcode, maskKey, ws.extensions) - -proc doSend( - ws: WSSession, - data: seq[byte] = @[], - opcode: Opcode - ): Future[void] = - let - retFut = newFuture[void]("doSend") - sendFut = ws.nonCancellableSend(data, opcode) - - proc handleSend {.async.} = - try: - await sendFut - retFut.complete() - except CatchableError as exc: - retFut.fail(exc) + let writeFut = + case opcode + of ControlOpcodes: + ws.writeControl(data, opcode, maskKey) + of MessageOpcodes: + ws.writeMessage(data, opcode, maskKey, ws.extensions) + await writeFut - asyncSpawn handleSend() - retFut - -proc sendLoop(ws: WSSession) {.gcsafe, async.} = +proc sendLoop(ws: WSSession) {.async: (raises: []).} = while ws.sendQueue.len > 0: let task = ws.sendQueue.popFirst() if task.fut.cancelled: continue try: - await ws.doSend(task.data, task.opcode) + await noCancel ws.doSend(task.data, task.opcode) task.fut.complete() - except CatchableError as exc: + except AsyncStreamError as exc: + task.fut.fail(exc) + except WebSocketError as exc: task.fut.fail(exc) proc send*( - ws: WSSession, - data: seq[byte] = @[], - opcode: Opcode): Future[void] = - if opcode.isControl: + ws: WSSession, data: seq[byte] = @[], opcode: Opcode +): Future[void] {. + async: (raises: [CancelledError, AsyncStreamError, WebSocketError], raw: true) +.} = + + if opcode in ControlOpcodes: # Control frames (see Section 5.5) MAY be injected in the middle of # a fragmented message. Control frames themselves MUST NOT be # fragmented. # See RFC 6455 Section 5.4 Fragmentation return ws.doSend(data, opcode) - let fut = newFuture[void]("send") + let fut = WSSendFuture.init("send") ws.sendQueue.addLast (data: data, opcode: opcode, fut: fut) @@ -181,14 +160,15 @@ proc send*( fut proc send*( - ws: WSSession, - data: string): Future[void] = + ws: WSSession, data: string +): Future[void] {. + async: (raises: [CancelledError, AsyncStreamError, WebSocketError], raw: true) +.} = send(ws, data.toBytes(), Opcode.Text) -proc handleClose*( - ws: WSSession, - frame: Frame, - payload: seq[byte] = @[]) {.async.} = +proc handleClose( + ws: WSSession, frame: Frame, payload: seq[byte] +) {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Handle close sequence ## @@ -215,11 +195,12 @@ proc handleClose*( raise newException(WSPayloadLengthError, "Invalid close frame with payload length 1!") else: - try: - code = StatusCodes(uint16.fromBytesBE(payload[0..<2])) - except RangeDefect: - raise newException(WSInvalidCloseCodeError, - "Status code out of range!") + let code = block: + let v = uint16.fromBytesBE(payload.toOpenArray(0, 1)) + if v > StatusCodes.high().uint16: + raise newException(WSInvalidCloseCodeError, + "Status code out of range!") + cast[StatusCodes](v) if code in StatusNotUsed or code in StatusReservedProtocol: @@ -233,7 +214,7 @@ proc handleClose*( &"Can't use reserved status code: {code}") # remaining payload bytes are reason for closing - reason = string.fromBytes(payload[2..payload.high]) + reason = string.fromBytes(payload.toOpenArray(0, payload.high)) if not ws.binary and validateUTF8(reason) == false: raise newException(WSInvalidUTF8, @@ -241,11 +222,7 @@ proc handleClose*( trace "Handling close message", code = ord(code), reason if not isNil(ws.onClose): - try: - (code, reason) = ws.onClose(code, reason) - except CatchableError as exc: - used(exc) - trace "Exception in Close callback, this is most likely a bug", exc = exc.msg + (code, reason) = ws.onClose(code, reason) else: code = StatusFulfilled reason = "" @@ -270,7 +247,9 @@ proc handleClose*( await sleepAsync(10.millis) await ws.stream.closeWait() -proc handleControl*(ws: WSSession, frame: Frame) {.async.} = +proc handleControl( + ws: WSSession, frame: Frame +) {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Handle control frames ## @@ -293,48 +272,37 @@ proc handleControl*(ws: WSSession, frame: Frame) {.async.} = var payload = newSeq[byte](frame.length.int) if frame.length > 0: - payload.setLen(frame.length.int) # Read control frame payload. - await ws.stream.reader.readExactly(addr payload[0], frame.length.int) + await ws.stream.reader.readExactly(addr payload[0], payload.len) + if frame.mask: - mask( - payload.toOpenArray(0, payload.high), - frame.maskKey) + mask(payload.toOpenArray(0, payload.high), frame.maskKey) # Process control frame payload. case frame.opcode: of Opcode.Ping: if not isNil(ws.onPing): - try: - ws.onPing(payload) - except CatchableError as exc: - used(exc) - trace "Exception in Ping callback, this is most likely a bug", exc = exc.msg + ws.onPing(payload) # send pong to remote await ws.send(payload, Opcode.Pong) of Opcode.Pong: if not isNil(ws.onPong): - try: - ws.onPong(payload) - except CatchableError as exc: - used(exc) - trace "Exception in Pong callback, this is most likely a bug", exc = exc.msg + ws.onPong(payload) of Opcode.Close: await ws.handleClose(frame, payload) else: raise newException(WSInvalidOpcodeError, "Invalid control opcode!") -{.warning[HoleEnumConv]:off.} - -proc readFrame*(ws: WSSession, extensions: seq[Ext] = @[]): Future[Frame] {.async.} = +proc readFrame*( + ws: WSSession, extensions: seq[Ext] = @[] +): Future[Frame] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Gets a frame from the WebSocket. ## See https://tools.ietf.org/html/rfc6455#section-5.2 ## while ws.readyState != ReadyState.Closed: - let frame = await Frame.decode( - ws.stream.reader, ws.masked, extensions) + let frame = await Frame.decode(ws.stream.reader, ws.masked, extensions) logScope: opcode = frame.opcode @@ -350,18 +318,18 @@ proc readFrame*(ws: WSSession, extensions: seq[Ext] = @[]): Future[Frame] {.asyn continue return frame - -{.warning[HoleEnumConv]:on.} + nil proc ping*( - ws: WSSession, - data: seq[byte] = @[]): Future[void] = + ws: WSSession, data: seq[byte] = @[] +) {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError], raw: true).} = ws.send(data, opcode = Opcode.Ping) proc recv*( - ws: WSSession, - data: pointer | ptr byte | ref seq[byte], # nim bug: pointer doesn't match ptr byte? - size: int): Future[int] {.async.} = + ws: WSSession, + data: pointer | ptr byte | ref seq[byte], + size: int, +): Future[int] {.async: (raises: [CancelledError, AsyncStreamError, WebSocketError]).} = ## Attempts to read up to ``size`` bytes ## ## If ``size`` is less than the data in @@ -413,14 +381,16 @@ proc recv*( ws.binary = ws.frame.opcode == Opcode.Binary # set binary flag trace "Setting binary flag" - let len = min(ws.frame.remainder.int, size - consumed) - if len > 0: + while ws.frame.remainder > 0 and consumed < size: + let len = min(ws.frame.remainder.int, size - consumed) trace "Reading bytes from frame stream", len - when data is ref seq[byte]: - data[].setLen(consumed + len) - let read = await ws.frame.read(ws.stream.reader, addr data[][consumed], len) - else: - let read = await ws.frame.read(ws.stream.reader, addr pbuffer[consumed], len) + let pbuf = + when data is ref seq[byte]: + data[].setLen(consumed + len) + addr data[][consumed] + else: + addr pbuffer[consumed] + let read = await ws.frame.read(ws.stream.reader, pbuf, len) if read <= 0: trace "Didn't read any bytes, stopping" raise newException(WSClosedError, "WebSocket is closed!") @@ -440,23 +410,33 @@ proc recv*( # read next frame ws.frame = await ws.readFrame(ws.extensions) - except CatchableError as exc: + except CancelledError as exc: + # TODO should all these exceptions be handled the same?? + trace "Exception reading frames", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + + raise exc + except AsyncStreamError as exc: + trace "Exception reading frames", exc = exc.msg + ws.readyState = ReadyState.Closed + await ws.stream.closeWait() + + raise exc + except WebSocketError as exc: trace "Exception reading frames", exc = exc.msg ws.readyState = ReadyState.Closed await ws.stream.closeWait() raise exc - finally: - if not isNil(ws.frame) and - (ws.frame.fin and ws.frame.remainder <= 0): - trace "Last frame in message and no more bytes left to read, reseting current frame" - ws.frame = nil return consumed proc recvMsg*( - ws: WSSession, - size = WSMaxMessageSize): Future[seq[byte]] {.async.} = + ws: WSSession, size = WSMaxMessageSize +): Future[seq[byte]] {. + async: (raises: [CancelledError, AsyncStreamError, WebSocketError]) +.} = ## Attempt to read a full message up to max `size` ## bytes in `frameSize` chunks. ## @@ -468,41 +448,25 @@ proc recvMsg*( ## ## In all other cases it awaits a full message. ## - try: - var res: seq[byte] - while ws.readyState != ReadyState.Closed: - var buf = new(seq[byte]) - let read {.used.} = await ws.recv(buf, min(size, ws.frameSize)) - - if res.len + buf[].len > size: - raise newException(WSMaxMessageSizeError, "Max message size exceeded") - - trace "Read message", size = read - res.add(buf[]) + var buf = new(seq[byte]) - # no more frames - if isNil(ws.frame): - break - - # read the entire message, exit - if ws.frame.fin and ws.frame.remainder <= 0: - trace "Read full message, breaking!" - break + # Read up to `size` bytes or until `fin`, whichever comes first + discard await ws.recv(buf, size - buf[].len) - if ws.readyState == ReadyState.Closed: - # avoid reporting incomplete message - raise newException(WSClosedError, "WebSocket is closed!") + if ws.readyState == ReadyState.Closed: + raise newException(WSClosedError, "WebSocket is closed!") - if not ws.binary and validateUTF8(res.toOpenArray(0, res.high)) == false: - raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected") + if not isNil(ws.frame): + # If `ws.frame` is not nil, it means we reached `size` bytes without + # receiving a `fin` + await ws.stream.closeWait() + raise newException(WSMaxMessageSizeError, "Max message size exceeded") - return res - except CatchableError as exc: - trace "Exception reading message", exc = exc.msg - ws.readyState = ReadyState.Closed + if not ws.binary and not validateUTF8(buf[].toOpenArray(0, buf[].high)): await ws.stream.closeWait() + raise newException(WSInvalidUTF8, "Invalid UTF8 sequence detected") - raise exc + return move(buf[]) proc recv*( ws: WSSession, @@ -511,9 +475,8 @@ proc recv*( ws.recvMsg(size) proc close*( - ws: WSSession, - code = StatusFulfilled, - reason: string = "") {.async.} = + ws: WSSession, code = StatusFulfilled, reason: string = "" +) {.async: (raises: [CancelledError]).} = ## Close the Socket, sends close packet. ## diff --git a/websock/types.nim b/websock/types.nim index 2589f5cb9f..46248357f5 100644 --- a/websock/types.nim +++ b/websock/types.nim @@ -7,21 +7,21 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} +{.push raises: [], gcsafe.} -import std/deques -import pkg/[chronos, - chronos/streams/tlsstream, - chronos/apps/http/httptable, - bearssl/rand, - httputils, - results] +import + std/deques, + chronos, + chronos/streams/tlsstream, + chronos/apps/http/httptable, + bearssl/rand, + httputils, + results export deques, rand const SHA1DigestSize* = 20 - WSHeaderSize* = 12 WSDefaultVersion* = 13 WSDefaultFrameSize* = 1 shl 20 # 1mb WSMaxMessageSize* = 20 shl 20 # 20mb @@ -43,8 +43,6 @@ type Close = 0x8 ## Denotes a connection close. Ping = 0x9 ## Denotes a ping. Pong = 0xa ## Denotes a pong. - # B-F are reserved for further control frames. - Reserved = 0xf HeaderFlag* {.pure, size: sizeof(uint8).} = enum rsv3 @@ -54,7 +52,7 @@ type HeaderFlags* = set[HeaderFlag] - MaskKey* = array[4, char] + MaskKey* = array[4, byte] WebSecKey* = array[16, byte] Frame* = ref object @@ -96,6 +94,8 @@ type onPong*: ControlCb onClose*: CloseCb + WSSendFuture* = Future[void].Raising([CancelledError, AsyncStreamError, WebSocketError]) + WSSession* = ref object of WebSocket stream*: AsyncStream frame*: Frame @@ -108,8 +108,7 @@ type # negotiated that can interpret the interleaving. # See RFC 6455 Section 5.4 Fragmentation sendLoop*: Future[void] - sendQueue*: Deque[ - tuple[data: seq[byte], opcode: Opcode, fut: Future[void]]] + sendQueue*: Deque[tuple[data: seq[byte], opcode: Opcode, fut: WSSendFuture]] Ext* = ref object of RootObj name*: string @@ -142,18 +141,17 @@ type # 3. server reply with response header # 4. client verify response header from server Hook* = ref object of RootObj - append*: proc(ctx: Hook, - headers: var HttpTable): Result[void, string] - {.gcsafe, raises: [].} - verify*: proc(ctx: Hook, - headers: HttpTable): Future[Result[void, string]] - {.gcsafe, async: (raises: []).} + append*: proc(ctx: Hook, headers: var HttpTable): Result[void, string] {. + gcsafe, raises: [] + .} + verify*: proc(ctx: Hook, headers: HttpTable): Future[Result[void, string]] {. + async: (raises: []) + .} WebSocketError* = object of CatchableError WSMalformedHeaderError* = object of WebSocketError WSFailedUpgradeError* = object of WebSocketError WSVersionError* = object of WebSocketError - WSProtoMismatchError* = object of WebSocketError WSMaskMismatchError* = object of WebSocketError WSHandshakeError* = object of WebSocketError WSOpcodeMismatchError* = object of WebSocketError @@ -161,9 +159,7 @@ type WSWrongUriSchemeError* = object of WebSocketError WSMaxMessageSizeError* = object of WebSocketError WSClosedError* = object of WebSocketError - WSSendError* = object of WebSocketError WSPayloadTooLarge* = object of WebSocketError - WSReservedOpcodeError* = object of WebSocketError WSFragmentedControlFrameError* = object of WebSocketError WSInvalidCloseCodeError* = object of WebSocketError WSPayloadLengthError* = object of WebSocketError @@ -191,6 +187,10 @@ const StatusLibsCodes* = (StatusCodes(3000)..StatusCodes(3999)) # 3000-3999 reserved for libs StatusAppsCodes* = (StatusCodes(4000)..StatusCodes(4999)) # 4000-4999 reserved for apps +const + ControlOpcodes* = {Opcode.Close, Opcode.Ping, Opcode.Pong} + MessageOpcodes* = {Opcode.Cont, Opcode.Text, Opcode.Binary} + proc `<=`*(a, b: StatusCodes): bool = a.uint16 <= b.uint16 proc `>=`*(a, b: StatusCodes): bool = a.uint16 >= b.uint16 proc `<`*(a, b: StatusCodes): bool = a.uint16 < b.uint16 @@ -205,14 +205,20 @@ proc `$`*(a: StatusCodes): string = $(a.int) proc `name=`*(self: Ext, name: string) = raiseAssert "Can't change extensions name!" -method decode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} = +method decode*( + self: Ext, frame: Frame +): Future[Frame] {. + base, async: (raises: [CancelledError, AsyncStreamError, WebSocketError]) +.} = raiseAssert "Not implemented!" -method encode*(self: Ext, frame: Frame): Future[Frame] {.base, async.} = +method encode*( + self: Ext, frame: Frame +): Future[Frame] {.base, async: (raises: [CancelledError, WebSocketError]).} = raiseAssert "Not implemented!" method toHttpOptions*(self: Ext): string {.base, gcsafe.} = raiseAssert "Not implemented!" -func random*(T: typedesc[MaskKey|WebSecKey], rng: var HmacDrbgContext): T = +func random*(T: typedesc[MaskKey | WebSecKey], rng: var HmacDrbgContext): T = rng.generate(result) diff --git a/websock/websock.nim b/websock/websock.nim index 8bd4436949..270840b8f5 100644 --- a/websock/websock.nim +++ b/websock/websock.nim @@ -7,26 +7,17 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -{.push gcsafe, raises: [].} - -import std/[tables, - strutils, - strformat, - sequtils, - uri] - -import pkg/[chronos, - chronos/apps/http/httptable, - chronos/streams/asyncstream, - chronos/streams/tlsstream, - chronicles, - httputils, - stew/byteutils, - stew/base64, - stew/base10, - nimcrypto/sha] - -import ./frame, ./session, /types, ./http, ./extensions/extutils +{.push raises: [], gcsafe.} + +import + std/[tables, strutils, strformat, sequtils, uri], + chronos, + chronicles, + httputils, + stew/[base64, base10, byteutils], + nimcrypto/sha, + ./[frame, session, types, http], + ./extensions/extutils export session, frame, types, http, httptable @@ -38,9 +29,6 @@ type protocols: seq[string] factories: seq[ExtFactory] -func toException(e: cstring): ref WebSocketError = - (ref WebSocketError)(msg: $e) - func contains(extensions: openArray[Ext], extName: string): bool = for ext in extensions: if ext.name == extName: @@ -100,30 +88,39 @@ proc selectExt(isServer: bool, response proc connect*( - _: type WebSocket, - host: string | TransportAddress, - path: string, - hostName: string = "", # override used when the hostname has been externally resolved - protocols: seq[string] = @[], - factories: seq[ExtFactory] = @[], - hooks: seq[Hook] = @[], - secure = false, - flags: set[TLSFlags] = {}, - version = WSDefaultVersion, - frameSize = WSDefaultFrameSize, - onPing: ControlCb = nil, - onPong: ControlCb = nil, - onClose: CloseCb = nil, - rng = HmacDrbgContext.new()): Future[WSSession] {.async.} = - + _: type WebSocket, + host: string | TransportAddress, + path: string, + hostName: string = "", + # override used when the hostname has been externally resolved + protocols: seq[string] = @[], + factories: seq[ExtFactory] = @[], + hooks: seq[Hook] = @[], + secure = false, + flags: set[TLSFlags] = {}, + version = WSDefaultVersion, + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil, + rng = HmacDrbgContext.new(), +): Future[WSSession] {. + async: ( + raises: + [CancelledError, AsyncStreamError, HttpError, TransportError, WebSocketError] + ) +.} = let key = Base64Pad.encode(WebSecKey.random(rng[])) hostname = if hostName.len > 0: hostName else: $host - let client = if secure: - await TlsHttpClient.connect(host, tlsFlags = flags, hostName = hostname) - else: - await HttpClient.connect(host) + var + connected = false + client = + if secure: + await TlsHttpClient.connect(host, tlsFlags = flags, hostName = hostname) + else: + await HttpClient.connect(host) let headerData = [ ("Connection", "Upgrade"), @@ -131,7 +128,8 @@ proc connect*( ("Cache-Control", "no-cache"), ("Sec-WebSocket-Version", $version), ("Sec-WebSocket-Key", key), - ("Host", hostname)] + ("Host", hostname), + ] var headers = HttpTable.init(headerData) if protocols.len > 0: @@ -148,72 +146,74 @@ proc connect*( for hp in hooks: if hp.append == nil: continue - let res = hp.append(hp, headers) - if res.isErr: - raise newException(WSHookError, - "Header plugin execution failed: " & res.error) + hp.append(hp, headers).isOkOr: + raise newException(WSHookError, "Header plugin execution failed: " & error) - let response = try: - await client.request(path, headers = headers) - except CatchableError as exc: - trace "Websocket failed during handshake", exc = exc.msg - await client.close() - raise exc + try: + let response = await client.request(path, headers = headers) - if response.code != Http101.toInt(): - raise newException(WSFailedUpgradeError, + if response.code != Http101.toInt(): + raise newException(WSFailedUpgradeError, &"Server did not reply with a websocket upgrade: " & &"Header code: {response.code} Header reason: {response.reason} " & &"Address: {client.address}") - let proto = response.headers.getString("Sec-WebSocket-Protocol") - if proto.len > 0 and protocols.len > 0: - if proto notin protocols: - raise newException(WSFailedUpgradeError, - &"Invalid protocol returned {proto}!") - - for hp in hooks: - if hp.verify == nil: continue - let res = await hp.verify(hp, response.headers) - if res.isErr: - raise newException(WSHookError, - "Header verification failed: " & res.error) - - var extensions: seq[Ext] - let exts = response.headers.getList("Sec-WebSocket-Extensions") - discard selectExt(false, extensions, factories, exts) - - # Client data should be masked. - let session = WSSession( - stream: client.stream, - readyState: ReadyState.Open, - masked: true, - extensions: system.move(extensions), - rng: rng, - frameSize: frameSize, - onPing: onPing, - onPong: onPong, - onClose: onClose) - - for ext in session.extensions: - ext.session = session - - return session + let proto = response.headers.getString("Sec-WebSocket-Protocol") + if proto.len > 0 and protocols.len > 0: + if proto notin protocols: + raise newException(WSFailedUpgradeError, &"Invalid protocol returned {proto}!") + + for hp in hooks: + if hp.verify == nil: continue + let res = await hp.verify(hp, response.headers) + if res.isErr: + raise newException(WSHookError, "Header verification failed: " & res.error) + + var extensions: seq[Ext] + let exts = response.headers.getList("Sec-WebSocket-Extensions") + discard selectExt(false, extensions, factories, exts) + + # Client data should be masked. + let session = WSSession( + stream: move(client.stream), + readyState: ReadyState.Open, + masked: true, + extensions: move(extensions), + rng: rng, + frameSize: frameSize, + onPing: onPing, + onPong: onPong, + onClose: onClose, + ) + + for ext in session.extensions: + ext.session = session + + connected = true + session + finally: + if not connected: + await client.closeWait() proc connect*( - _: type WebSocket, - uri: Uri, - protocols: seq[string] = @[], - factories: seq[ExtFactory] = @[], - hooks: seq[Hook] = @[], - flags: set[TLSFlags] = {}, - version = WSDefaultVersion, - frameSize = WSDefaultFrameSize, - onPing: ControlCb = nil, - onPong: ControlCb = nil, - onClose: CloseCb = nil, - rng = HmacDrbgContext.new()): Future[WSSession] - {.raises: [WSWrongUriSchemeError].} = + _: type WebSocket, + uri: Uri, + protocols: seq[string] = @[], + factories: seq[ExtFactory] = @[], + hooks: seq[Hook] = @[], + flags: set[TLSFlags] = {}, + version = WSDefaultVersion, + frameSize = WSDefaultFrameSize, + onPing: ControlCb = nil, + onPong: ControlCb = nil, + onClose: CloseCb = nil, + rng = HmacDrbgContext.new(), +): Future[WSSession] {. + async: ( + raises: + [CancelledError, AsyncStreamError, HttpError, TransportError, WebSocketError] + ) +.} = ## Create a new websockets client ## using a Uri ## @@ -229,7 +229,7 @@ proc connect*( if uri.port.len <= 0: uri.port = if secure: "443" else: "80" - return WebSocket.connect( + await WebSocket.connect( host = uri.hostname & ":" & uri.port, path = uri.path, hostName = uri.hostname, @@ -243,31 +243,25 @@ proc connect*( onPing = onPing, onPong = onPong, onClose = onClose, - rng = rng) + rng = rng, + ) proc handleRequest*( - ws: WSServer, - request: HttpRequest, - version: uint = WSDefaultVersion, - hooks: seq[Hook] = @[]): Future[WSSession] - {. - async: - (raises: [ - CancelledError, - CatchableError, - WSHandshakeError, - WSProtoMismatchError]) - .} = + ws: WSServer, + request: HttpRequest, + version: uint = WSDefaultVersion, + hooks: seq[Hook] = @[], +): Future[WSSession] {. + async: (raises: [CancelledError, WebSocketError]) +.} = ## Creates a new socket from a request. ## if not request.headers.contains("Sec-WebSocket-Version"): raise newException(WSHandshakeError, "Missing version header") - ws.version = Base10.decode( - uint, - request.headers.getString("Sec-WebSocket-Version")) - .tryGet() # this method throws + ws.version = Base10.decode(uint, request.headers.getString("Sec-WebSocket-Version")).valueOr: + raise (ref WebSocketError)(msg: $error) if ws.version != version: await request.stream.writer.sendError(Http426)