Skip to content

Commit

Permalink
Add context to mse handshakes
Browse files Browse the repository at this point in the history
  • Loading branch information
anacrolix committed Aug 10, 2024
1 parent f63b38a commit b7b97a6
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 45 deletions.
7 changes: 4 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ func (cl *Client) initiateProtocolHandshakes(
if err != nil {
panic(err)
}
err = cl.initiateHandshakes(c, t)
err = cl.initiateHandshakes(ctx, c, t)
return
}

Expand Down Expand Up @@ -914,10 +914,11 @@ func (cl *Client) incomingPeerPort() int {
return cl.LocalPort()
}

func (cl *Client) initiateHandshakes(c *PeerConn, t *Torrent) (err error) {
func (cl *Client) initiateHandshakes(ctx context.Context, c *PeerConn, t *Torrent) (err error) {
if c.headerEncrypted {
var rw io.ReadWriter
rw, c.cryptoMethod, err = mse.InitiateHandshake(
rw, c.cryptoMethod, err = mse.InitiateHandshakeContext(
ctx,
struct {
io.Reader
io.Writer
Expand Down
3 changes: 2 additions & 1 deletion handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package torrent

import (
"bytes"
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -65,7 +66,7 @@ func handleEncryption(
}
}
headerEncrypted = true
ret, cryptoMethod, err = mse.ReceiveHandshake(rw, skeys, selector)
ret, cryptoMethod, err = mse.ReceiveHandshake(context.TODO(), rw, skeys, selector)
return
}

Expand Down
3 changes: 2 additions & 1 deletion mse/cmd/mse/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -62,7 +63,7 @@ func mainErr() error {
return fmt.Errorf("accepting: %w", err)
}
defer cn.Close()
rw, _, err := mse.ReceiveHandshake(cn, func(f func([]byte) bool) {
rw, _, err := mse.ReceiveHandshake(context.TODO(), cn, func(f func([]byte) bool) {
for _, sk := range args.Listen.SecretKeys {
f([]byte(sk))
}
Expand Down
58 changes: 58 additions & 0 deletions mse/ctxrw.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package mse

import (
"context"
g "github.com/anacrolix/generics"
"io"
)

type contextedReader struct {
ctx context.Context
r io.Reader
}

func (me contextedReader) Read(p []byte) (n int, err error) {
return contextedReadOrWrite(me.ctx, me.r.Read, p)
}

type contextedWriter struct {
ctx context.Context
w io.Writer
}

// This is problematic. If you return with a context error, a read or write is still pending, and
// could mess up the stream.
func contextedReadOrWrite(ctx context.Context, method func(b []byte) (int, error), b []byte) (_ int, err error) {
asyncCh := make(chan g.Result[int], 1)
go func() {
asyncCh <- g.ResultFromTuple(method(b))
}()
select {
case <-ctx.Done():
err = context.Cause(ctx)
return
case res := <-asyncCh:
return res.AsTuple()
}

}

func (me contextedWriter) Write(p []byte) (n int, err error) {
return contextedReadOrWrite(me.ctx, me.w.Write, p)
}

func contextedReadWriter(ctx context.Context, rw io.ReadWriter) io.ReadWriter {
return struct {
io.Reader
io.Writer
}{
contextedReader{
ctx: ctx,
r: rw,
},
contextedWriter{
ctx: ctx,
w: rw,
},
}
}
104 changes: 68 additions & 36 deletions mse/mse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package mse

import (
"bytes"
"context"
"crypto/rand"
"crypto/rc4"
"crypto/sha1"
Expand Down Expand Up @@ -32,7 +33,7 @@ type CryptoMethod uint32

var (
// Prime P according to the spec, and G, the generator.
p, g big.Int
p, specG big.Int
// The rand.Int max arg for use in newPadLen()
newPadLenMax big.Int
// For use in initer's hashes
Expand All @@ -50,7 +51,7 @@ var (

func init() {
p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
g.SetInt64(2)
specG.SetInt64(2)
newPadLenMax.SetInt64(maxPadLen + 1)
}

Expand Down Expand Up @@ -159,15 +160,15 @@ func paddedLeft(b []byte, _len int) []byte {
// Calculate, and send Y, our public key.
func (h *handshake) postY(x *big.Int) error {
var y big.Int
y.Exp(&g, x, &p)
y.Exp(&specG, x, &p)
return h.postWrite(paddedLeft(y.Bytes(), 96))
}

func (h *handshake) establishS() error {
x := newX()
h.postY(&x)
var b [96]byte
_, err := io.ReadFull(h.conn, b[:])
_, err := io.ReadFull(h.ctxConn, b[:])
if err != nil {
return fmt.Errorf("error reading Y: %w", err)
}
Expand All @@ -193,12 +194,14 @@ func newPadLen() int64 {

// Manages state for both initiating and receiving handshakes.
type handshake struct {
conn io.ReadWriter
s [96]byte
initer bool // Whether we're initiating or receiving.
skeys SecretKeyIter // Skeys we'll accept if receiving.
skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator.
conn io.ReadWriter
// The conn with Reads and Writes wrapped to the context given in handshake.Do.
ctxConn io.ReadWriter
s [96]byte
initer bool // Whether we're initiating or receiving.
skeys SecretKeyIter // Skeys we'll accept if receiving.
skey []byte // Skey we're initiating with.
ia []byte // Initial payload. Only used by the initiator.
// Return the bit for the crypto method the receiver wants to use.
chooseMethod CryptoSelector
// Sent to the receiver.
Expand Down Expand Up @@ -250,7 +253,7 @@ func (h *handshake) writer() {
b := h.writes[0]
h.writes = h.writes[1:]
h.writeMu.Unlock()
_, err := h.conn.Write(b)
_, err := h.ctxConn.Write(b)
if err != nil {
h.writeMu.Lock()
h.writeErr = err
Expand Down Expand Up @@ -357,7 +360,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
return newEncrypt(initer, h.s[:], h.skey)
}

func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
func (h *handshake) initerSteps(ctx context.Context) (ret io.ReadWriter, selected CryptoMethod, err error) {
h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
buf := &bytes.Buffer{}
Expand All @@ -380,29 +383,32 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err
// Read until the all zero VC. At this point we've only read the 96 byte
// public key, Y. There is potentially 512 byte padding, between us and
// the 8 byte verification constant.
err = readUntil(io.LimitReader(h.conn, 520), eVC[:])
err = readUntil(io.LimitReader(h.ctxConn, 520), eVC[:])
if err != nil {
if err == io.EOF {
err = errors.New("failed to synchronize on VC")
} else {
err = fmt.Errorf("error reading until VC: %s", err)
err = fmt.Errorf("error reading until VC: %w", err)
}
return
}
r := newCipherReader(bC, h.conn)
ctxReader := newCipherReader(bC, h.ctxConn)
var method CryptoMethod
err = unmarshal(r, &method, &padLen)
err = unmarshal(ctxReader, &method, &padLen)
if err != nil {
return
}
_, err = io.CopyN(io.Discard, r, int64(padLen))
_, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
if err != nil {
return
}
selected = method & h.cryptoProvides
switch selected {
case CryptoMethodRC4:
ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
ret = readWriter{
newCipherReader(bC, h.conn),
&cipherWriter{e, h.conn, nil},
}
case CryptoMethodPlaintext:
ret = h.conn
default:
Expand All @@ -413,17 +419,17 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err

var ErrNoSecretKeyMatch = errors.New("no skey matched")

func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
func (h *handshake) receiverSteps(ctx context.Context) (ret io.ReadWriter, chosen CryptoMethod, err error) {
// There is up to 512 bytes of padding, then the 20 byte hash.
err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
err = readUntil(io.LimitReader(h.ctxConn, 532), hash(req1, h.s[:]))
if err != nil {
if err == io.EOF {
err = errors.New("failed to synchronize on S hash")
}
return
}
var b [20]byte
_, err = io.ReadFull(h.conn, b[:])
_, err = io.ReadFull(h.ctxConn, b[:])
if err != nil {
return
}
Expand All @@ -447,28 +453,29 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
if err != nil {
return
}
r := newCipherReader(newEncrypt(true, h.s[:], h.skey), h.conn)
cipher := newEncrypt(true, h.s[:], h.skey)
ctxReader := newCipherReader(cipher, h.ctxConn)
var (
vc [8]byte
provides CryptoMethod
padLen uint16
)

err = unmarshal(r, vc[:], &provides, &padLen)
err = unmarshal(ctxReader, vc[:], &provides, &padLen)
if err != nil {
return
}
cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
chosen = h.chooseMethod(provides)
_, err = io.CopyN(io.Discard, r, int64(padLen))
_, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
if err != nil {
return
}
var lenIA uint16
unmarshal(r, &lenIA)
unmarshal(ctxReader, &lenIA)
if lenIA != 0 {
h.ia = make([]byte, lenIA)
unmarshal(r, h.ia)
unmarshal(ctxReader, h.ia)
}
buf := &bytes.Buffer{}
w := cipherWriter{h.newEncrypt(false), buf, nil}
Expand All @@ -484,7 +491,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
switch chosen {
case CryptoMethodRC4:
ret = readWriter{
io.MultiReader(bytes.NewReader(h.ia), r),
io.MultiReader(bytes.NewReader(h.ia), newCipherReader(cipher, h.conn)),
&cipherWriter{w.c, h.conn, nil},
}
case CryptoMethodPlaintext:
Expand All @@ -498,7 +505,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
return
}

func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
func (h *handshake) Do(ctx context.Context) (ret io.ReadWriter, method CryptoMethod, err error) {
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
Expand All @@ -520,27 +527,41 @@ func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
return
}
if h.initer {
ret, method, err = h.initerSteps()
ret, method, err = h.initerSteps(ctx)
} else {
ret, method, err = h.receiverSteps()
ret, method, err = h.receiverSteps(ctx)
}
return
}

func InitiateHandshake(
rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
rw io.ReadWriter,
skey, initialPayload []byte,
cryptoProvides CryptoMethod,
) (
ret io.ReadWriter, method CryptoMethod, err error,
) {
return InitiateHandshakeContext(context.TODO(), rw, skey, initialPayload, cryptoProvides)
}

func InitiateHandshakeContext(
ctx context.Context,
rw io.ReadWriter,
skey, initialPayload []byte,
cryptoProvides CryptoMethod,
) (
ret io.ReadWriter, method CryptoMethod, err error,
) {
h := handshake{
conn: rw,
ctxConn: contextedReadWriter(ctx, rw),
initer: true,
skey: skey,
ia: initialPayload,
cryptoProvides: cryptoProvides,
}
defer perf.ScopeTimerErr(&err)()
return h.Do()
return h.Do(ctx)
}

type HandshakeResult struct {
Expand All @@ -550,19 +571,30 @@ type HandshakeResult struct {
SecretKey []byte
}

func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
func ReceiveHandshake(
ctx context.Context,
rw io.ReadWriter,
skeys SecretKeyIter,
selectCrypto CryptoSelector,
) (io.ReadWriter, CryptoMethod, error) {
res := ReceiveHandshakeEx(ctx, rw, skeys, selectCrypto)
return res.ReadWriter, res.CryptoMethod, res.error
}

func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
func ReceiveHandshakeEx(
ctx context.Context,
rw io.ReadWriter,
skeys SecretKeyIter,
selectCrypto CryptoSelector,
) (ret HandshakeResult) {
h := handshake{
conn: rw,
ctxConn: contextedReadWriter(ctx, rw),
initer: false,
skeys: skeys,
chooseMethod: selectCrypto,
}
ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do()
ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do(ctx)
ret.SecretKey = h.skey
return
}
Expand Down
Loading

0 comments on commit b7b97a6

Please sign in to comment.