Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

properly close connections #128

Merged
merged 11 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions libp2p/connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,28 @@ proc newInvalidVarintException*(): ref InvalidVarintException =
proc newInvalidVarintSizeException*(): ref InvalidVarintSizeException =
newException(InvalidVarintSizeException, "Wrong varint size")

proc init*[T: Connection](self: var T, stream: LPStream) =
proc bindStreamClose(conn: Connection) {.async.} =
# bind stream's close event to connection's close
# to ensure correct close propagation
if not isNil(conn.stream.closeEvent):
await conn.stream.closeEvent.wait()
trace "wrapped stream closed, about to close conn", closed = this.isClosed,
peer = if not isNil(this.peerInfo):
this.peerInfo.id else: ""
if not conn.isClosed:
trace "wrapped stream closed, closing conn", closed = this.isClosed,
peer = if not isNil(this.peerInfo):
this.peerInfo.id else: ""
asyncCheck conn.close()

proc init*[T: Connection](self: var T, stream: LPStream): T =
## create a new Connection for the specified async reader/writer
new self
self.stream = stream
self.closeEvent = newAsyncEvent()
asyncCheck self.bindStreamClose()

# bind stream's close event to connection's close
# to ensure correct close propagation
let this = self
if not isNil(self.stream.closeEvent):
self.stream.closeEvent.wait().
addCallback do (udata: pointer):
if not this.closed:
trace "wrapped stream closed, closing conn"
asyncCheck this.close()
return self

proc newConnection*(stream: LPStream): Connection =
## create a new Connection for the specified async reader/writer
Expand Down Expand Up @@ -108,13 +115,23 @@ method closed*(s: Connection): bool =
result = s.stream.closed

method close*(s: Connection) {.async, gcsafe.} =
trace "closing connection"
trace "about to close connection", closed = s.closed,
peer = if not isNil(s.peerInfo):
s.peerInfo.id else: ""

if not s.closed:
if not isNil(s.stream) and not s.stream.closed:
trace "closing child stream", closed = s.closed,
peer = if not isNil(s.peerInfo):
s.peerInfo.id else: ""
await s.stream.close()

s.closeEvent.fire()
s.isClosed = true
trace "connection closed", closed = s.closed

trace "connection closed", closed = s.closed,
peer = if not isNil(s.peerInfo):
s.peerInfo.id else: ""

proc readLp*(s: Connection): Future[seq[byte]] {.async, gcsafe.} =
## read lenght prefixed msg
Expand Down
6 changes: 3 additions & 3 deletions libp2p/crypto/chacha20poly1305.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const
ChaChaPolyKeySize = 32
ChaChaPolyNonceSize = 12
ChaChaPolyTagSize = 16

type
ChaChaPoly* = object
ChaChaPolyKey* = array[ChaChaPolyKeySize, byte]
Expand All @@ -46,7 +46,7 @@ proc intoChaChaPolyNonce*(s: openarray[byte]): ChaChaPolyNonce =
proc intoChaChaPolyTag*(s: openarray[byte]): ChaChaPolyTag =
assert s.len == ChaChaPolyTagSize
copyMem(addr result[0], unsafeaddr s[0], ChaChaPolyTagSize)

# bearssl allows us to use optimized versions
# this is reconciled at runtime
# we do this in the global scope / module init
Expand Down Expand Up @@ -85,7 +85,7 @@ proc decrypt*(_: type[ChaChaPoly],
unsafeaddr aad[0]
else:
nil

ourPoly1305CtmulRun(
unsafeaddr key[0],
unsafeaddr nonce[0],
Expand Down
19 changes: 12 additions & 7 deletions libp2p/muxers/mplex/mplex.nim
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,12 @@ method handle*(m: Mplex) {.async, gcsafe.} =
trace "Exception occurred", exception = exc.msg
finally:
trace "stopping mplex main loop"
if not m.connection.closed():
await m.connection.close()
await m.close()

proc internalCleanup(m: Mplex, conn: Connection) {.async.} =
await conn.closeEvent.wait()
trace "connection closed, cleaning up mplex"
await m.close()

proc newMplex*(conn: Connection,
maxChanns: uint = MaxChannels): Mplex =
Expand All @@ -137,11 +141,7 @@ proc newMplex*(conn: Connection,
result.remote = initTable[uint64, LPChannel]()
result.local = initTable[uint64, LPChannel]()

let m = result
conn.closeEvent.wait()
.addCallback do (udata: pointer):
trace "connection closed, cleaning up mplex"
asyncCheck m.close()
asyncCheck result.internalCleanup(conn)

method newStream*(m: Mplex,
name: string = "",
Expand All @@ -154,5 +154,10 @@ method newStream*(m: Mplex,

method close*(m: Mplex) {.async, gcsafe.} =
trace "closing mplex muxer"
if not m.connection.closed():
await m.connection.close()

await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.reset())),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might help to merge this #125 but anyway can be the opposite way too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably more urgent since it fixes some pretty egregious mem leaks. I'd rather get this in first to see how it behaves in NBC.

allFutures(toSeq(m.local.values).mapIt(it.reset()))])
m.remote.clear()
m.local.clear()
16 changes: 9 additions & 7 deletions libp2p/protocols/pubsub/pubsub.nim
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,22 @@ method handleConn*(p: PubSub,
trace "pubsub peer handler ended, cleaning up"
await p.cleanUpHelper(peer)

proc internalClenaup(p: PubSub, conn: Connection) {.async.} =
# handle connection close
var peer = p.getPeer(conn.peerInfo, p.codec)
await conn.closeEvent.wait()
trace "connection closed, cleaning up peer", peer = conn.peerInfo.id

await p.cleanUpHelper(peer)

method subscribeToPeer*(p: PubSub,
conn: Connection) {.base, async.} =
var peer = p.getPeer(conn.peerInfo, p.codec)
trace "setting connection for peer", peerId = conn.peerInfo.id
if not peer.isConnected:
peer.conn = conn

# handle connection close
conn.closeEvent.wait()
.addCallback do (udata: pointer = nil):
trace "connection closed, cleaning up peer",
peer = conn.peerInfo.id

asyncCheck p.cleanUpHelper(peer)
asyncCheck p.internalClenaup(conn)

method unsubscribe*(p: PubSub,
topics: seq[TopicPair]) {.base, async.} =
Expand Down
31 changes: 16 additions & 15 deletions libp2p/protocols/secure/noise.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import ../../peer
import ../../peerinfo
import ../../protobuf/minprotobuf
import ../../utility
import ../../stream/lpstream
import secure,
../../crypto/[crypto, chacha20poly1305, curve25519, hkdf],
../../stream/bufferstream
Expand All @@ -26,7 +27,7 @@ logScope:
const
# https://godoc.org/github.com/libp2p/go-libp2p-noise#pkg-constants
NoiseCodec* = "/noise"

PayloadString = "noise-libp2p-static-key:"

ProtocolXXName = "Noise_XX_25519_ChaChaPoly_SHA256"
Expand All @@ -41,7 +42,7 @@ type
KeyPair = object
privateKey: Curve25519Key
publicKey: Curve25519Key

# https://noiseprotocol.org/noise.html#the-cipherstate-object
CipherState = object
k: ChaChaPolyKey
Expand All @@ -66,7 +67,7 @@ type
cs2: CipherState
remoteP2psecret: seq[byte]
rs: Curve25519Key

Noise* = ref object of Secure
localPrivateKey: PrivateKey
localPublicKey: PublicKey
Expand All @@ -89,7 +90,7 @@ type
proc genKeyPair(): KeyPair =
result.privateKey = Curve25519Key.random()
result.publicKey = result.privateKey.public()

proc hashProtocol(name: string): MDigest[256] =
# If protocol_name is less than or equal to HASHLEN bytes in length,
# sets h equal to protocol_name with zero bytes appended to make HASHLEN bytes.
Expand Down Expand Up @@ -195,7 +196,7 @@ proc split(ss: var SymmetricState): tuple[cs1, cs2: CipherState] =

proc init(_: type[HandshakeState]): HandshakeState =
result.ss = SymmetricState.init()

template write_e: untyped =
trace "noise write e"
# Sets e (which must be empty) to GENERATE_KEYPAIR(). Appends e.public_key to the buffer. Calls MixHash(e.public_key).
Expand Down Expand Up @@ -302,7 +303,7 @@ proc packNoisePayload(payload: openarray[byte]): seq[byte] =

if result.len > uint16.high.int:
raise newException(NoiseOversizedPayloadError, "Trying to send an unsupported oversized payload over Noise")

trace "packed noise payload", inSize = payload.len, outSize = result.len

proc unpackNoisePayload(payload: var seq[byte]) =
Expand All @@ -312,7 +313,7 @@ proc unpackNoisePayload(payload: var seq[byte]) =

if size > (payload.len - 2):
raise newException(NoiseOversizedPayloadError, "Received a wrong payload size")

payload = payload[2..^((payload.len - size) - 1)]

trace "unpacked noise payload", size = payload.len
Expand Down Expand Up @@ -362,7 +363,7 @@ proc handshakeXXOutbound(p: Noise, conn: Connection, p2pProof: ProtoBuffer): Fut
msg &= hs.ss.encryptAndHash(packed)

await conn.sendHSMessage(msg)

let (cs1, cs2) = hs.ss.split()
return HandshakeResult(cs1: cs1, cs2: cs2, remoteP2psecret: remoteP2psecret, rs: hs.rs)

Expand Down Expand Up @@ -426,9 +427,9 @@ method readMessage(sconn: NoiseConnection): Future[seq[byte]] {.async.} =
var plain = sconn.readCs.decryptWithAd([], cipher)
unpackNoisePayload(plain)
return plain
except AsyncStreamIncompleteError:
except LPStreamIncompleteError:
trace "Connection dropped while reading"
except AsyncStreamReadError:
except LPStreamReadError:
trace "Error reading from connection"

method writeMessage(sconn: NoiseConnection, message: seq[byte]): Future[void] {.async.} =
Expand Down Expand Up @@ -460,7 +461,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S
# https://github.com/libp2p/specs/tree/master/noise#libp2p-data-in-handshake-messages
let
signedPayload = p.localPrivateKey.sign(PayloadString.toBytes & p.noisePublicKey.getBytes)

var
libp2pProof = initProtoBuffer()

Expand Down Expand Up @@ -489,7 +490,7 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S
raise newException(NoiseHandshakeError, "Noise handshake signature verify failed.")
else:
trace "Remote signature verified"

if initiator and not isNil(conn.peerInfo):
let pid = PeerID.init(remotePubKey)
if not conn.peerInfo.peerId.validate():
Expand All @@ -508,10 +509,10 @@ method handshake*(p: Noise, conn: Connection, initiator: bool = false): Future[S
secure.readCs = handshakeRes.cs1
secure.writeCs = handshakeRes.cs2

debug "Noise handshake completed!"
trace "Noise handshake completed!"

return secure

method init*(p: Noise) {.gcsafe.} =
procCall Secure(p).init()
p.codec = NoiseCodec
Expand All @@ -523,7 +524,7 @@ method secure*(p: Noise, conn: Connection): Future[Connection] {.async, gcsafe.}
warn "securing connection failed", msg = exc.msg
if not conn.closed():
await conn.close()

proc newNoise*(privateKey: PrivateKey; outgoing: bool = true; commonPrologue: seq[byte] = @[]): Noise =
new result
result.outgoing = outgoing
Expand Down
6 changes: 3 additions & 3 deletions libp2p/protocols/secure/secio.nim
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ proc transactMessage(conn: Connection,
else:
trace "Received size of message exceed limits", conn = $conn,
length = length
except AsyncStreamIncompleteError:
except LPStreamIncompleteError:
trace "Connection dropped while reading", conn = $conn
except AsyncStreamReadError:
except LPStreamReadError:
trace "Error reading from connection", conn = $conn
except AsyncStreamWriteError:
except LPStreamWriteError:
trace "Could not write to connection", conn = $conn

method handshake*(s: Secio, conn: Connection, initiator: bool = false): Future[SecureConn] {.async.} =
Expand Down
21 changes: 11 additions & 10 deletions libp2p/protocols/secure/secure.nim
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import ../protocol,

type
Secure* = ref object of LPProtocol # base type for secure managers

SecureConn* = ref object of Connection

method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} =
Expand All @@ -32,8 +33,9 @@ method handshake(s: Secure,
initiator: bool = false): Future[SecureConn] {.async, base.} =
doAssert(false, "Not implemented!")

proc readLoop(sconn: SecureConn, stream: BufferStream) {.async.} =
proc readLoop(sconn: SecureConn, conn: Connection) {.async.} =
try:
let stream = BufferStream(conn.stream)
while not sconn.closed:
let msg = await sconn.readMessage()
if msg.len == 0:
Expand All @@ -44,24 +46,23 @@ proc readLoop(sconn: SecureConn, stream: BufferStream) {.async.} =
except CatchableError as exc:
trace "Exception occurred Secure.readLoop", exc = exc.msg
finally:
trace "closing conn", closed = conn.closed()
if not conn.closed:
await conn.close()

trace "closing sconn", closed = sconn.closed()
if not sconn.closed:
await sconn.close()
trace "ending Secure readLoop", isclosed = sconn.closed()
trace "ending Secure readLoop"

proc handleConn*(s: Secure, conn: Connection, initiator: bool = false): Future[Connection] {.async, gcsafe.} =
var sconn = await s.handshake(conn, initiator)
proc writeHandler(data: seq[byte]) {.async, gcsafe.} =
trace "sending encrypted bytes", bytes = data.shortLog
await sconn.writeMessage(data)

var stream = newBufferStream(writeHandler)
asyncCheck readLoop(sconn, stream)
result = newConnection(stream)
result.closeEvent.wait()
.addCallback do (udata: pointer):
trace "wrapped connection closed, closing upstream"
if not isNil(sconn) and not sconn.closed:
asyncCheck sconn.close()
result = newConnection(newBufferStream(writeHandler))
asyncCheck readLoop(sconn, result)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda wanna remove this too and track the future but I guess can be done in another PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to happen in a more focused refactor, we'll do that after we get the current implementation more stable


if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome:
result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get())
Expand Down
5 changes: 2 additions & 3 deletions libp2p/switch.nim
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =

# handle subsequent requests
await ms.handle(sconn)
await sconn.close()

if (await ms.select(conn)): # just handshake
# add the secure handlers
Expand Down Expand Up @@ -289,9 +290,7 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
except CatchableError as exc:
trace "Exception occurred in Switch.start", exc = exc.msg
finally:
if not isNil(conn) and not conn.closed:
await conn.close()

await conn.close()
await s.cleanupConn(conn)

var startFuts: seq[Future[void]]
Expand Down
Loading