Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEATURE: [okx] support kline subscriptions #1519

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions pkg/exchange/okex/kline_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,15 @@ func (s *KLineStream) handleConnect() {

subs = append(subs, sub)
}
if len(subs) == 0 {
return
}

log.Infof("subscribing channels: %+v", subs)
err := s.Conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})
subscribe(s.Conn, subs)
}

if err != nil {
log.WithError(err).Error("subscribe error")
func (s *KLineStream) Connect(ctx context.Context) error {
if len(s.StandardStream.Subscriptions) == 0 {
log.Info("no subscriptions in kline")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warnf?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the kline stream is not always subscribed, i print the log with info

return nil
}
return s.StandardStream.Connect(ctx)
}

func (s *KLineStream) handleKLineEvent(k KLineEvent) {
Expand All @@ -85,3 +81,14 @@ func (s *KLineStream) dispatchEvent(e interface{}) {
s.EmitKLineEvent(*et)
}
}

func (s *KLineStream) Unsubscribe() {
// errors are handled in the syncSubscriptions, so they are skipped here.
if len(s.StandardStream.Subscriptions) != 0 {
_ = syncSubscriptions(s.StandardStream.Conn, s.StandardStream.Subscriptions, WsEventTypeUnsubscribe)
}
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
// clear the subscriptions
return []types.Subscription{}, nil
})
}
64 changes: 46 additions & 18 deletions pkg/exchange/okex/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package okex
import (
"context"
"fmt"
"golang.org/x/time/rate"
"strconv"
"time"

"github.com/gorilla/websocket"
"golang.org/x/time/rate"

"github.com/c9s/bbgo/pkg/exchange/okex/okexapi"
"github.com/c9s/bbgo/pkg/exchange/retry"
"github.com/c9s/bbgo/pkg/types"
Expand Down Expand Up @@ -67,18 +69,20 @@ func NewStream(client *okexapi.RestClient, balanceProvider types.ExchangeAccount
stream.OnOrderTradesEvent(stream.handleOrderDetailsEvent)
stream.OnConnect(stream.handleConnect)
stream.OnAuth(stream.subscribePrivateChannels(stream.emitBalanceSnapshot))
stream.kLineStream.OnKLineClosed(stream.EmitKLineClosed)
stream.kLineStream.OnKLine(stream.EmitKLine)

return stream
}

func (s *Stream) syncSubscriptions(opType WsEventType) error {
func syncSubscriptions(conn *websocket.Conn, subscriptions []types.Subscription, opType WsEventType) error {
if opType != WsEventTypeUnsubscribe && opType != WsEventTypeSubscribe {
return fmt.Errorf("unexpected subscription type: %v", opType)
}

logger := log.WithField("opType", opType)
var topics []WebsocketSubscription
for _, subscription := range s.Subscriptions {
for _, subscription := range subscriptions {
topic, err := convertSubscription(subscription)
if err != nil {
logger.WithError(err).Errorf("convert error, subscription: %+v", subscription)
Expand All @@ -89,7 +93,7 @@ func (s *Stream) syncSubscriptions(opType WsEventType) error {
}

logger.Infof("%s channels: %+v", opType, topics)
if err := s.Conn.WriteJSON(WebsocketOp{
if err := conn.WriteJSON(WebsocketOp{
Op: opType,
Args: topics,
}); err != nil {
Expand All @@ -102,11 +106,47 @@ func (s *Stream) syncSubscriptions(opType WsEventType) error {

func (s *Stream) Unsubscribe() {
// errors are handled in the syncSubscriptions, so they are skipped here.
_ = s.syncSubscriptions(WsEventTypeUnsubscribe)
_ = syncSubscriptions(s.StandardStream.Conn, s.StandardStream.Subscriptions, WsEventTypeUnsubscribe)
s.Resubscribe(func(old []types.Subscription) (new []types.Subscription, err error) {
// clear the subscriptions
return []types.Subscription{}, nil
})

s.kLineStream.Unsubscribe()
}

func (s *Stream) Connect(ctx context.Context) error {
if err := s.StandardStream.Connect(ctx); err != nil {
return err
}
if err := s.kLineStream.Connect(ctx); err != nil {
return err
}
return nil
}

func (s *Stream) Subscribe(channel types.Channel, symbol string, options types.SubscribeOptions) {
if channel == types.KLineChannel {
s.kLineStream.Subscribe(channel, symbol, options)
} else {
s.StandardStream.Subscribe(channel, symbol, options)
}
}

func subscribe(conn *websocket.Conn, subs []WebsocketSubscription) {
if len(subs) == 0 {
return
}

log.Infof("subscribing channels: %+v", subs)
err := conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})

if err != nil {
log.WithError(err).Error("subscribe error")
}
}

func (s *Stream) handleConnect() {
Expand All @@ -121,19 +161,7 @@ func (s *Stream) handleConnect() {

subs = append(subs, sub)
}
if len(subs) == 0 {
return
}

log.Infof("subscribing channels: %+v", subs)
err := s.Conn.WriteJSON(WebsocketOp{
Op: "subscribe",
Args: subs,
})

if err != nil {
log.WithError(err).Error("subscribe error")
}
subscribe(s.StandardStream.Conn, subs)
} else {
// login as private channel
// sign example:
Expand Down
28 changes: 28 additions & 0 deletions pkg/exchange/okex/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,34 @@ func TestStream(t *testing.T) {
<-c
})

t.Run("book && kline test", func(t *testing.T) {
s.Subscribe(types.BookChannel, "BTCUSDT", types.SubscribeOptions{
Depth: types.DepthLevel400,
})
s.Subscribe(types.KLineChannel, "BTCUSDT", types.SubscribeOptions{
Interval: types.Interval1m,
})
s.SetPublicOnly()
err := s.Connect(context.Background())
assert.NoError(t, err)

s.OnBookSnapshot(func(book types.SliceOrderBook) {
t.Log("got snapshot", book)
})
s.OnBookUpdate(func(book types.SliceOrderBook) {
t.Log("got update", book)
})
s.OnKLine(func(kline types.KLine) {
t.Log("kline", kline)
})
s.OnKLineClosed(func(kline types.KLine) {
t.Log("kline closed", kline)
})

c := make(chan struct{})
<-c
})

t.Run("market trade test", func(t *testing.T) {
s.Subscribe(types.MarketTradeChannel, "BTCUSDT", types.SubscribeOptions{})
s.SetPublicOnly()
Expand Down