-
Notifications
You must be signed in to change notification settings - Fork 4
/
rxtx.go
504 lines (468 loc) · 14.9 KB
/
rxtx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
package mqtt
import (
"bytes"
"encoding/binary"
"errors"
"io"
)
// Rx implements a bare minimum MQTT v3.1.1 protocol transport layer handler.
// Packets are received by calling [Rx.ReadNextPacket] and setting the callback
// in Rx corresponding to the expected packet.
// Rx will perform basic validation of input data according to MQTT's specification.
// If there is an error after reading the first byte of a packet the transport is closed
// and a new transport must be set with [Rx.SetRxTransport].
// If OnRxError is set the underlying transport is not automatically closed and
// it becomes the callback's responsibility to close the transport.
//
// Not safe for concurrent use.
type Rx struct {
	// Transport from which packets are read.
	// Not exported since RxTx type might be composed of embedded Rx and Tx types in future. TBD.
	rxTrp io.ReadCloser
	// RxCallbacks holds the user callbacks invoked synchronously by ReadNextPacket.
	RxCallbacks RxCallbacks
	// User defined decoder for allocating packets.
	userDecoder Decoder
	// ScratchBuf is lazily allocated (1024 bytes) to exhaust Publish payloads when received and no
	// OnPub callback is set. A user-provided non-empty ScratchBuf is used as-is.
	ScratchBuf []byte
	// LastReceivedHeader contains the last correctly read header. It is reset to the
	// zero Header at the start of every ReadNextPacket call.
	LastReceivedHeader Header
	// packetLimitReader is stored in Rx to prevent a heap allocation in ReadNextPacket, since passing
	// a stack allocated LimitedReader into RxCallbacks.OnPub will escape unconditionally.
	packetLimitReader io.LimitedReader
}
// RxCallbacks groups all functionality executed on data receipt, both successful
// and unsuccessful.
type RxCallbacks struct {
	// Functions below can access the Header of the message via Rx.LastReceivedHeader.
	// All these functions are called synchronously and block Rx.ReadNextPacket.
	// A non-nil error returned by any of them is returned by ReadNextPacket and is
	// handled like a decoding error (passed to OnRxError, or the transport is closed).

	// OnConnect is called on a CONNECT packet receipt.
	OnConnect func(*Rx, *VariablesConnect) error // Receives pointer because of large struct!
	// OnConnack is called on a CONNACK packet receipt.
	OnConnack func(*Rx, VariablesConnack) error
	// OnPub is called on PUBLISH packet receipt. The [io.Reader] points to the transport's reader
	// and is limited to read the amount of payload bytes, i.e. RemainingLength minus the
	// size of the variable header. One may calculate the amount of bytes in the reader like so:
	//  qos := rx.LastReceivedHeader.Flags().QoS()
	//  payloadLen := int(rx.LastReceivedHeader.RemainingLength) - varPub.Size(qos)
	// If OnPub returns a nil error it must have read the payload completely, otherwise
	// ReadNextPacket returns an error. If OnPub is not set the payload is discarded.
	// It is important to note the reader `r` will be invalidated on the next incoming publish packet;
	// using r after this point will result in undefined behaviour.
	OnPub func(rx *Rx, varPub VariablesPublish, r io.Reader) error
	// OnOther is called with the decoded packet identifier on PUBACK, PUBREC, PUBREL,
	// PUBCOMP and UNSUBACK packets, and with a zero packetIdentifier on DISCONNECT,
	// PINGREQ and PINGRESP packets, which carry none. The packet type can be read from
	// rx.LastReceivedHeader.
	OnOther func(rx *Rx, packetIdentifier uint16) error
	// OnSub is called on a SUBSCRIBE packet receipt.
	OnSub func(*Rx, VariablesSubscribe) error
	// OnSuback is called on a SUBACK packet receipt.
	OnSuback func(*Rx, VariablesSuback) error
	// OnUnsub is called on an UNSUBSCRIBE packet receipt.
	OnUnsub func(*Rx, VariablesUnsubscribe) error
	// OnRxError is called when ReadNextPacket fails after reading at least one byte from
	// the transport, including when one of the callbacks above returns an error.
	// If it is set then it becomes the responsibility of the callback to close the transport.
	OnRxError func(*Rx, error)
}
// SetRxTransport replaces the transport rx reads packets from.
// Any previously set transport is left open; closing it is up to the caller.
func (rx *Rx) SetRxTransport(transport io.ReadCloser) { rx.rxTrp = transport }
// CloseRx closes the underlying transport and returns the error from its Close method.
// It returns an error instead of panicking if no transport has been set.
func (rx *Rx) CloseRx() error {
	if rx.rxTrp == nil {
		return errors.New("nil transport")
	}
	return rx.rxTrp.Close()
}
// rxErrHandler reports err to the OnRxError callback when one is set.
// Otherwise rx closes its own transport. The Close error is intentionally
// dropped because err is what gets returned to the caller.
func (rx *Rx) rxErrHandler(err error) {
	if onErr := rx.RxCallbacks.OnRxError; onErr != nil {
		onErr(rx, err)
		return
	}
	_ = rx.CloseRx()
}
// ReadNextPacket reads the next packet from the transport and dispatches it to the
// matching callback in rx.RxCallbacks. It returns the amount of bytes read from the
// transport along with any decoding or callback error.
//
// If it fails after reading a non-zero amount of bytes, the error is passed to
// OnRxError if set; otherwise the transport is closed and a new one must be set
// with [Rx.SetRxTransport]. Errors returned before any byte is read do not close
// the transport.
func (rx *Rx) ReadNextPacket() (int, error) {
	if rx.rxTrp == nil {
		return 0, errors.New("nil transport")
	}
	rx.LastReceivedHeader = Header{}
	hdr, n, err := DecodeHeader(rx.rxTrp)
	if err != nil {
		if n > 0 {
			rx.rxErrHandler(err)
		}
		return n, err
	}
	rx.LastReceivedHeader = hdr
	var (
		packetType       = hdr.Type()
		ngot             int
		packetIdentifier uint16
	)
	switch packetType {
	case PacketPublish:
		qos := hdr.Flags().QoS()
		var vp VariablesPublish
		vp, ngot, err = rx.userDecoder.DecodePublish(rx.rxTrp, qos)
		n += ngot
		if err != nil {
			break
		}
		payloadLen := int(hdr.RemainingLength) - ngot
		if payloadLen < 0 {
			// The variable header alone exceeds RemainingLength: malformed packet.
			// Reject it before handing a bogus payload reader to OnPub.
			err = ErrBadRemainingLen
			break
		}
		// The LimitedReader lives in rx to avoid a heap allocation per packet.
		rx.packetLimitReader = io.LimitedReader{R: rx.rxTrp, N: int64(payloadLen)}
		onPub := rx.RxCallbacks.OnPub
		if onPub != nil {
			err = onPub(rx, vp, &rx.packetLimitReader)
		} else {
			err = rx.exhaustReader(&rx.packetLimitReader)
		}
		if err == nil && rx.packetLimitReader.N != 0 {
			// Unread payload bytes would desynchronize the stream for the next packet.
			if onPub != nil {
				err = errors.New("expected OnPub to completely read payload")
			} else {
				// exhaustReader stops on EOF, so the transport ended mid-payload.
				err = io.ErrUnexpectedEOF
			}
		}
	case PacketConnack:
		if hdr.RemainingLength != 2 {
			err = ErrBadRemainingLen
			break
		}
		var vc VariablesConnack
		vc, ngot, err = decodeConnack(rx.rxTrp)
		n += ngot
		if err != nil {
			break
		}
		if rx.RxCallbacks.OnConnack != nil {
			err = rx.RxCallbacks.OnConnack(rx, vc)
		}
	case PacketConnect:
		// if hdr.RemainingLength != 0 { // TODO(soypat): What's the minimum RL for CONNECT?
		// 	err = ErrBadRemainingLen
		// 	break
		// }
		var vc VariablesConnect
		vc, ngot, err = rx.userDecoder.DecodeConnect(rx.rxTrp)
		n += ngot
		if err != nil {
			break
		}
		if rx.RxCallbacks.OnConnect != nil {
			err = rx.RxCallbacks.OnConnect(rx, &vc)
		}
	case PacketSuback:
		if hdr.RemainingLength < 2 {
			err = ErrBadRemainingLen
			break
		}
		var vsbck VariablesSuback
		vsbck, ngot, err = decodeSuback(rx.rxTrp, hdr.RemainingLength)
		n += ngot
		if err != nil {
			break
		}
		if rx.RxCallbacks.OnSuback != nil {
			err = rx.RxCallbacks.OnSuback(rx, vsbck)
		}
	case PacketSubscribe:
		var vsbck VariablesSubscribe
		vsbck, ngot, err = rx.userDecoder.DecodeSubscribe(rx.rxTrp, hdr.RemainingLength)
		n += ngot
		if err != nil {
			break
		}
		if rx.RxCallbacks.OnSub != nil {
			err = rx.RxCallbacks.OnSub(rx, vsbck)
		}
	case PacketUnsubscribe:
		var vunsub VariablesUnsubscribe
		vunsub, ngot, err = rx.userDecoder.DecodeUnsubscribe(rx.rxTrp, hdr.RemainingLength)
		n += ngot
		if err != nil {
			break
		}
		if rx.RxCallbacks.OnUnsub != nil {
			err = rx.RxCallbacks.OnUnsub(rx, vunsub)
		}
	case PacketPuback, PacketPubrec, PacketPubrel, PacketPubcomp, PacketUnsuback:
		if hdr.RemainingLength != 2 {
			err = ErrBadRemainingLen
			break
		}
		// Only PI, no payload.
		packetIdentifier, ngot, err = decodeUint16(rx.rxTrp)
		n += ngot
		if err != nil {
			break
		}
		if rx.RxCallbacks.OnOther != nil {
			err = rx.RxCallbacks.OnOther(rx, packetIdentifier)
		}
	case PacketDisconnect, PacketPingreq, PacketPingresp:
		if hdr.RemainingLength != 0 {
			err = ErrBadRemainingLen
			break
		}
		// No payload or variable header; packetIdentifier stays zero.
		if rx.RxCallbacks.OnOther != nil {
			err = rx.RxCallbacks.OnOther(rx, packetIdentifier)
		}
	default:
		// DecodeHeader is expected to reject unknown packet types. If that invariant is
		// ever broken, fail the read instead of letting a remote peer crash the program.
		err = errors.New("unexpected packet type")
	}
	if err != nil {
		rx.rxErrHandler(err)
	}
	return n, err
}
// RxTransport reports the transport rx currently reads from.
// The result is nil when no transport has been set.
func (rx *Rx) RxTransport() io.ReadCloser { return rx.rxTrp }
// ShallowCopy returns a new Rx that shares rx's transport and decoder.
// Callbacks, ScratchBuf and LastReceivedHeader are not carried over.
func (rx *Rx) ShallowCopy() *Rx {
	dup := new(Rx)
	dup.rxTrp = rx.rxTrp
	dup.userDecoder = rx.userDecoder
	return dup
}
// exhaustReader discards everything r yields until r returns an error, using
// rx.ScratchBuf as the read buffer (allocated with 1024 bytes if empty).
// Reaching io.EOF counts as success and yields nil; any other error is returned.
func (rx *Rx) exhaustReader(r io.Reader) error {
	if len(rx.ScratchBuf) == 0 {
		rx.ScratchBuf = make([]byte, 1024) // Allocated only once a payload must be dropped.
	}
	sink := rx.ScratchBuf
	for {
		if _, rerr := r.Read(sink); rerr != nil {
			if errors.Is(rerr, io.EOF) {
				return nil
			}
			return rerr
		}
	}
}
// Tx implements a bare minimum MQTT v3.1.1 protocol transport layer handler for transmitting packets.
// If writing a packet fails after at least one byte reached the transport, the transport
// is closed and a new transport must be set with [Tx.SetTxTransport].
// A Tx will not validate data before encoding (WriteSuback is the exception); that is up to
// the caller. Malformed packets are expected to be rejected by the receiver, which will close
// the connection. If OnTxError is set then the underlying transport is not closed and it
// becomes the responsibility of the callback to close the transport.
//
// NOTE(review): the internal buffer is reused by every Write* call without locking, so Tx is
// presumably not safe for concurrent use.
type Tx struct {
	// txTrp is the transport packets are written to.
	txTrp io.WriteCloser
	// TxCallbacks holds the user callbacks invoked on write success or failure.
	TxCallbacks TxCallbacks
	// buffer is reset and reused to encode a whole packet before writing it to txTrp.
	buffer bytes.Buffer
}
// TxCallbacks groups functionality executed on transmission success or failure
// of an MQTT packet.
type TxCallbacks struct {
	// OnTxError is called if writing a packet to the transport fails after at least one
	// byte was written; encoding errors are returned without invoking it. If it is set
	// then it becomes the responsibility of the callback to close Tx's transport.
	OnTxError func(*Tx, error)
	// OnSuccessfulTx is called after an MQTT packet is fully written to the underlying transport.
	OnSuccessfulTx func(*Tx)
}
// TxTransport reports the transport tx currently writes to.
// The result is nil when no transport has been set.
func (tx *Tx) TxTransport() io.WriteCloser { return tx.txTrp }
// SetTxTransport replaces the transport tx writes packets to.
// Any previously set transport is left open; closing it is up to the caller.
func (tx *Tx) SetTxTransport(transport io.WriteCloser) { tx.txTrp = transport }
// WriteConnect writes a CONNECT packet over the transport.
// The packet is fully encoded into tx's internal buffer before being written, so
// encoding errors are returned without touching the transport. If writing fails after
// some bytes reached the transport, the error is handed to OnTxError if set, otherwise
// the transport is closed. On success OnSuccessfulTx is called, if set.
func (tx *Tx) WriteConnect(varConn *VariablesConnect) error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	buffer := &tx.buffer
	buffer.Reset()
	h := newHeader(PacketConnect, 0, uint32(varConn.Size()))
	if _, err := h.Encode(buffer); err != nil {
		return err
	}
	if _, err := encodeConnect(buffer, varConn); err != nil {
		return err
	}
	n, err := buffer.WriteTo(tx.txTrp)
	if err != nil && n > 0 {
		tx.prepClose(err)
	} else if err == nil && tx.TxCallbacks.OnSuccessfulTx != nil {
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WriteConnack encodes a CONNACK packet into tx's buffer and writes it to the
// transport. Encoding errors are returned directly. A write error after partial
// output goes through the Tx error handling, and a full write triggers
// OnSuccessfulTx when it is set.
func (tx *Tx) WriteConnack(varConnack VariablesConnack) error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	buf := &tx.buffer
	buf.Reset()
	hdr := newHeader(PacketConnack, 0, uint32(varConnack.Size()))
	if _, encErr := hdr.Encode(buf); encErr != nil {
		return encErr
	}
	if _, encErr := encodeConnack(buf, varConnack); encErr != nil {
		return encErr
	}
	written, err := buf.WriteTo(tx.txTrp)
	switch {
	case err != nil && written > 0:
		tx.prepClose(err)
	case err == nil && tx.TxCallbacks.OnSuccessfulTx != nil:
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WritePublishPayload writes a PUBLISH packet with fixed header h, variable header
// varPub and the Application Message payload, which may be zero-length.
// The RemainingLength carried by h is ignored and recomputed from varPub and payload.
// The whole packet is assembled in tx's buffer before a single write to the transport.
func (tx *Tx) WritePublishPayload(h Header, varPub VariablesPublish, payload []byte) error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	buf := &tx.buffer
	buf.Reset()
	qos := h.Flags().QoS()
	h.RemainingLength = uint32(varPub.Size(qos) + len(payload))
	if _, encErr := h.Encode(buf); encErr != nil {
		return encErr
	}
	if _, encErr := encodePublish(buf, qos, varPub); encErr != nil {
		return encErr
	}
	if _, encErr := writeFull(buf, payload); encErr != nil {
		return encErr
	}
	written, err := buf.WriteTo(tx.txTrp)
	switch {
	case err != nil && written > 0:
		tx.prepClose(err)
	case err == nil && tx.TxCallbacks.OnSuccessfulTx != nil:
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WriteSubscribe writes a SUBSCRIBE packet over the transport. The fixed header
// uses PacketFlagsPubrelSubUnsub. The packet is encoded into tx's buffer first,
// so encoding errors never reach the transport.
func (tx *Tx) WriteSubscribe(varSub VariablesSubscribe) error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	buf := &tx.buffer
	buf.Reset()
	hdr := newHeader(PacketSubscribe, PacketFlagsPubrelSubUnsub, uint32(varSub.Size()))
	if _, encErr := hdr.Encode(buf); encErr != nil {
		return encErr
	}
	if _, encErr := encodeSubscribe(buf, varSub); encErr != nil {
		return encErr
	}
	written, err := buf.WriteTo(tx.txTrp)
	switch {
	case err != nil && written > 0:
		tx.prepClose(err)
	case err == nil && tx.TxCallbacks.OnSuccessfulTx != nil:
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WriteSuback writes a SUBACK packet over the transport.
// Unlike the other Write methods it first validates varSub with its Validate method
// and returns any validation error before anything is encoded or written.
// A write error after partial output is handed to OnTxError if set, otherwise the
// transport is closed. On success OnSuccessfulTx is called, if set.
func (tx *Tx) WriteSuback(varSub VariablesSuback) error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	if err := varSub.Validate(); err != nil {
		return err
	}
	buffer := &tx.buffer
	buffer.Reset()
	h := newHeader(PacketSuback, 0, uint32(varSub.Size()))
	if _, err := h.Encode(buffer); err != nil {
		return err
	}
	if _, err := encodeSuback(buffer, varSub); err != nil {
		return err
	}
	n, err := buffer.WriteTo(tx.txTrp)
	if err != nil && n > 0 {
		tx.prepClose(err)
	} else if err == nil && tx.TxCallbacks.OnSuccessfulTx != nil {
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WriteUnsubscribe writes an UNSUBSCRIBE packet over the transport. The fixed
// header uses PacketFlagsPubrelSubUnsub. The packet is encoded into tx's buffer
// first, so encoding errors never reach the transport.
func (tx *Tx) WriteUnsubscribe(varUnsub VariablesUnsubscribe) error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	buf := &tx.buffer
	buf.Reset()
	hdr := newHeader(PacketUnsubscribe, PacketFlagsPubrelSubUnsub, uint32(varUnsub.Size()))
	if _, encErr := hdr.Encode(buf); encErr != nil {
		return encErr
	}
	if _, encErr := encodeUnsubscribe(buf, varUnsub); encErr != nil {
		return encErr
	}
	written, err := buf.WriteTo(tx.txTrp)
	switch {
	case err != nil && written > 0:
		tx.prepClose(err)
	case err == nil && tx.TxCallbacks.OnSuccessfulTx != nil:
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WriteIdentified writes a PUBACK, PUBREC, PUBREL, PUBCOMP or UNSUBACK packet carrying
// the given non-zero packet identifier. It automatically sets the RemainingLength field to 2.
// It returns errGotZeroPI for a zero identifier and an error for any other packet type.
// The packet is written directly to the transport without using tx's internal buffer.
func (tx *Tx) WriteIdentified(packetType PacketType, packetIdentifier uint16) (err error) {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	if packetIdentifier == 0 {
		return errGotZeroPI
	}
	switch packetType {
	case PacketPuback, PacketPubrec, PacketPubrel, PacketPubcomp, PacketUnsuback:
	default:
		return errors.New("expected a packet type from PUBACK|PUBREC|PUBREL|PUBCOMP|UNSUBACK")
	}
	// PUBREL's fixed header must carry the reserved flags 0b0010 (MQTT v3.1.1 section 3.6.1).
	isPubrel := packetType == PacketPubrel
	var buf [5 + 2]byte // Up to 5 fixed header bytes plus the 2 byte packet identifier.
	n := newHeader(packetType, PacketFlags(b2u8(isPubrel)<<1), 2).Put(buf[:])
	binary.BigEndian.PutUint16(buf[n:], packetIdentifier)
	n, err = writeFull(tx.txTrp, buf[:n+2])
	if err != nil && n > 0 {
		tx.prepClose(err)
	} else if err == nil && tx.TxCallbacks.OnSuccessfulTx != nil {
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// WriteSimple sends one of the 2 octet packets that have neither a variable
// header nor a payload: DISCONNECT, PINGREQ or PINGRESP. Any other packet type
// yields an error. The header is encoded straight onto the transport, and any
// error from that step is returned.
func (tx *Tx) WriteSimple(packetType PacketType) (err error) {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	switch packetType {
	case PacketDisconnect, PacketPingreq, PacketPingresp:
	default:
		return errors.New("expected packet type from PINGREQ|PINGRESP|DISCONNECT")
	}
	written, err := newHeader(packetType, 0, 0).Encode(tx.txTrp)
	switch {
	case err != nil && written > 0:
		tx.prepClose(err)
	case err == nil && tx.TxCallbacks.OnSuccessfulTx != nil:
		tx.TxCallbacks.OnSuccessfulTx(tx)
	}
	return err
}
// CloseTx closes the underlying transport and returns the error from its Close method.
// It returns an error instead of panicking if no transport has been set.
func (tx *Tx) CloseTx() error {
	if tx.txTrp == nil {
		return errors.New("nil transport")
	}
	return tx.txTrp.Close()
}
// prepClose handles a failed write. err goes to OnTxError when one is set;
// otherwise tx closes its transport and drops the Close error, since err is
// what gets returned to the caller.
func (tx *Tx) prepClose(err error) {
	if onErr := tx.TxCallbacks.OnTxError; onErr != nil {
		onErr(tx, err)
		return
	}
	_ = tx.txTrp.Close()
}
// ShallowCopy returns a new Tx sharing tx's underlying transport.
// Callbacks and the internal encoding buffer are not copied over.
func (tx *Tx) ShallowCopy() *Tx {
	return &Tx{txTrp: tx.txTrp}
}