Skip to content

Commit ebf9c43

Browse files
committed
indicator: separate TR + RMA and ATR = TR + RMA
1 parent a887eaf commit ebf9c43

File tree

4 files changed

+137
-43
lines changed

4 files changed

+137
-43
lines changed

pkg/indicator/atr2.go

+5-39
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,13 @@
11
package indicator
22

3-
import (
4-
"math"
5-
6-
"github.com/c9s/bbgo/pkg/types"
7-
)
8-
9-
// This ATRStream calculates the ATR first
103
type ATRStream struct {
114
// embedded struct
12-
Float64Series
13-
14-
// private states
15-
previousClose float64
5+
*RMAStream
166
}
177

18-
func ATR2(source KLineSubscription) *ATRStream {
19-
s := &ATRStream{
20-
Float64Series: NewFloat64Series(),
21-
}
22-
23-
source.AddSubscriber(func(k types.KLine) {
24-
s.calculateAndPush(k.High.Float64(), k.Low.Float64(), k.Close.Float64())
25-
})
8+
func ATR2(source KLineSubscription, window int) *ATRStream {
9+
s := &ATRStream{}
10+
tr := TR2(source)
11+
s.RMAStream = RMA2(tr, window, true)
2612
return s
2713
}
28-
29-
func (s *ATRStream) calculateAndPush(high, low, cls float64) {
30-
if s.previousClose == .0 {
31-
s.previousClose = cls
32-
return
33-
}
34-
35-
trueRange := high - low
36-
hc := math.Abs(high - s.previousClose)
37-
lc := math.Abs(low - s.previousClose)
38-
if trueRange < hc {
39-
trueRange = hc
40-
}
41-
if trueRange < lc {
42-
trueRange = lc
43-
}
44-
45-
s.previousClose = cls
46-
s.EmitUpdate(trueRange)
47-
}

pkg/indicator/atr2_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,16 @@ func Test_ATR2(t *testing.T) {
6565
stream := &types.StandardStream{}
6666

6767
kLines := KLines(stream)
68-
atr := ATR2(kLines)
69-
rma := RMA2(atr, tt.window, true)
68+
atr := ATR2(kLines, tt.window)
7069

7170
for _, k := range tt.kLines {
7271
stream.EmitKLineClosed(k)
7372
}
7473

75-
got := rma.Last()
74+
got := atr.Last()
7675
diff := math.Trunc((got-tt.want)*100) / 100
7776
if diff != 0 {
78-
t.Errorf("calculateATR2() = %v, want %v", got, tt.want)
77+
t.Errorf("ATR2() = %v, want %v", got, tt.want)
7978
}
8079
})
8180
}

pkg/indicator/tr2.go

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package indicator
2+
3+
import (
4+
"math"
5+
6+
"github.com/c9s/bbgo/pkg/types"
7+
)
8+
9+
// This TRStream calculates the ATR first
10+
type TRStream struct {
11+
// embedded struct
12+
Float64Series
13+
14+
// private states
15+
previousClose float64
16+
}
17+
18+
func TR2(source KLineSubscription) *TRStream {
19+
s := &TRStream{
20+
Float64Series: NewFloat64Series(),
21+
}
22+
23+
source.AddSubscriber(func(k types.KLine) {
24+
s.calculateAndPush(k.High.Float64(), k.Low.Float64(), k.Close.Float64())
25+
})
26+
return s
27+
}
28+
29+
func (s *TRStream) calculateAndPush(high, low, cls float64) {
30+
if s.previousClose == .0 {
31+
s.previousClose = cls
32+
return
33+
}
34+
35+
trueRange := high - low
36+
hc := math.Abs(high - s.previousClose)
37+
lc := math.Abs(low - s.previousClose)
38+
if trueRange < hc {
39+
trueRange = hc
40+
}
41+
if trueRange < lc {
42+
trueRange = lc
43+
}
44+
45+
s.previousClose = cls
46+
s.EmitUpdate(trueRange)
47+
}

pkg/indicator/tr2_test.go

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package indicator
2+
3+
import (
4+
"encoding/json"
5+
"math"
6+
"testing"
7+
8+
"github.com/c9s/bbgo/pkg/fixedpoint"
9+
"github.com/c9s/bbgo/pkg/types"
10+
)
11+
12+
/*
13+
python
14+
15+
import pandas as pd
16+
import pandas_ta as ta
17+
18+
data = {
19+
"high": [40145.0, 40186.36, 40196.39, 40344.6, 40245.48, 40273.24, 40464.0, 40699.0, 40627.48, 40436.31, 40370.0, 40376.8, 40227.03, 40056.52, 39721.7, 39597.94, 39750.15, 39927.0, 40289.02, 40189.0],
20+
"low": [39870.71, 39834.98, 39866.31, 40108.31, 40016.09, 40094.66, 40105.0, 40196.48, 40154.99, 39800.0, 39959.21, 39922.98, 39940.02, 39632.0, 39261.39, 39254.63, 39473.91, 39555.51, 39819.0, 40006.84],
21+
"close": [40105.78, 39935.23, 40183.97, 40182.03, 40212.26, 40149.99, 40378.0, 40618.37, 40401.03, 39990.39, 40179.13, 40097.23, 40014.72, 39667.85, 39303.1, 39519.99,
22+
23+
39693.79, 39827.96, 40074.94, 40059.84]
24+
}
25+
26+
high = pd.Series(data['high'])
27+
low = pd.Series(data['low'])
28+
close = pd.Series(data['close'])
29+
result = ta.atr(high, low, close, length=14)
30+
print(result)
31+
*/
32+
func Test_TR_and_RMA(t *testing.T) {
33+
var bytes = []byte(`{
34+
"high": [40145.0, 40186.36, 40196.39, 40344.6, 40245.48, 40273.24, 40464.0, 40699.0, 40627.48, 40436.31, 40370.0, 40376.8, 40227.03, 40056.52, 39721.7, 39597.94, 39750.15, 39927.0, 40289.02, 40189.0],
35+
"low": [39870.71, 39834.98, 39866.31, 40108.31, 40016.09, 40094.66, 40105.0, 40196.48, 40154.99, 39800.0, 39959.21, 39922.98, 39940.02, 39632.0, 39261.39, 39254.63, 39473.91, 39555.51, 39819.0, 40006.84],
36+
"close": [40105.78, 39935.23, 40183.97, 40182.03, 40212.26, 40149.99, 40378.0, 40618.37, 40401.03, 39990.39, 40179.13, 40097.23, 40014.72, 39667.85, 39303.1, 39519.99, 39693.79, 39827.96, 40074.94, 40059.84]
37+
}`)
38+
39+
var buildKLines = func(bytes []byte) (kLines []types.KLine) {
40+
var prices map[string][]fixedpoint.Value
41+
_ = json.Unmarshal(bytes, &prices)
42+
for i, h := range prices["high"] {
43+
kLine := types.KLine{High: h, Low: prices["low"][i], Close: prices["close"][i]}
44+
kLines = append(kLines, kLine)
45+
}
46+
return kLines
47+
}
48+
49+
tests := []struct {
50+
name string
51+
kLines []types.KLine
52+
window int
53+
want float64
54+
}{
55+
{
56+
name: "test_binance_btcusdt_1h",
57+
kLines: buildKLines(bytes),
58+
window: 14,
59+
want: 367.913903,
60+
},
61+
}
62+
63+
for _, tt := range tests {
64+
t.Run(tt.name, func(t *testing.T) {
65+
stream := &types.StandardStream{}
66+
67+
kLines := KLines(stream)
68+
atr := TR2(kLines)
69+
rma := RMA2(atr, tt.window, true)
70+
71+
for _, k := range tt.kLines {
72+
stream.EmitKLineClosed(k)
73+
}
74+
75+
got := rma.Last()
76+
diff := math.Trunc((got-tt.want)*100) / 100
77+
if diff != 0 {
78+
t.Errorf("RMA(TR()) = %v, want %v", got, tt.want)
79+
}
80+
})
81+
}
82+
}

0 commit comments

Comments
 (0)