
Commit 4c29e6d

Merge pull request #378 from frecklebars/main
Added Mamba models implementation for forecasting tasks
2 parents 17a2179 + 094f7a5 commit 4c29e6d


17 files changed (+643, -4 lines)


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -157,3 +157,6 @@ data_loader_all.py
 /scripts/imputation/tmp/
 /utils/self_tools.py
 /scripts/exp_scripts/
+
+/checkpoints/
+/results/

exp/exp_basic.py

Lines changed: 3 additions & 1 deletion
@@ -2,7 +2,7 @@
 import torch
 from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
     Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
-    Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN
+    Koopa, TiDE, FreTS, TimeMixer, TSMixer, SegRNN, MambaSimple, Mamba


 class Exp_Basic(object):
@@ -28,6 +28,8 @@ def __init__(self, args):
             'Koopa': Koopa,
             'TiDE': TiDE,
             'FreTS': FreTS,
+            'MambaSimple': MambaSimple,
+            'Mamba': Mamba,
             'TimeMixer': TimeMixer,
             'TSMixer': TSMixer,
             'SegRNN': SegRNN
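
For orientation (not part of this diff): registering a module in model_dict is what makes it selectable through the --model argument. The registry is presumably consumed by the experiment classes roughly as sketched below; the exact _build_model code is not shown in this commit.

    # Sketch, assuming the usual Exp_Basic subclass pattern in this repository:
    # each registry value is a module whose Model class takes the argparse namespace.
    model = self.model_dict[self.args.model].Model(self.args).float()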

models/Mamba.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mamba_ssm import Mamba
+
+from layers.Embed import DataEmbedding
+
+class Model(nn.Module):
+
+    def __init__(self, configs):
+        super(Model, self).__init__()
+        self.task_name = configs.task_name
+        self.pred_len = configs.pred_len
+
+        self.d_inner = configs.d_model * configs.expand
+        self.dt_rank = math.ceil(configs.d_model / 16)  # TODO implement "auto"
+
+        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)
+
+        self.mamba = Mamba(
+            d_model = configs.d_model,
+            d_state = configs.d_ff,
+            d_conv = configs.d_conv,
+            expand = configs.expand,
+        )
+
+        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)
+
+    def forecast(self, x_enc, x_mark_enc):
+        mean_enc = x_enc.mean(1, keepdim=True).detach()
+        x_enc = x_enc - mean_enc
+        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+        x_enc = x_enc / std_enc
+
+        x = self.embedding(x_enc, x_mark_enc)
+        x = self.mamba(x)
+        x_out = self.out_layer(x)
+
+        x_out = x_out * std_enc + mean_enc
+        return x_out
+
+    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
+        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
+            x_out = self.forecast(x_enc, x_mark_enc)
+            return x_out[:, -self.pred_len:, :]
+
+        # other tasks not implemented
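
Two notes on the wrapper above (not part of the commit): it forwards configs.d_ff to Mamba as d_state, so the --d_ff value in the scripts below sets the SSM state size rather than a feed-forward width, and mamba_ssm must be installed separately (its fused kernels generally require a CUDA device). A minimal, illustrative instantiation sketch, with config values borrowed from the ETTh1 script below:

    # Illustrative usage sketch (not part of the commit); the SimpleNamespace fields
    # mirror run.py's argparse options, and the values are assumptions from the ETTh1 script.
    from types import SimpleNamespace
    import torch
    from models.Mamba import Model

    configs = SimpleNamespace(
        task_name='long_term_forecast', pred_len=96,
        d_model=128, d_ff=16,    # d_ff is forwarded to Mamba as d_state
        expand=2, d_conv=4,
        enc_in=7, c_out=7,
        embed='timeF', freq='h', dropout=0.1,
    )
    model = Model(configs).cuda()                # mamba_ssm expects a CUDA device
    x_enc = torch.randn(4, 96, 7).cuda()         # [batch, seq_len, enc_in]
    x_mark_enc = torch.randn(4, 96, 4).cuda()    # 4 time features for freq='h'
    out = model(x_enc, x_mark_enc, None, None)   # -> [4, pred_len, c_out]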

models/MambaSimple.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat, einsum
+
+from layers.Embed import DataEmbedding
+
+
+class Model(nn.Module):
+    """
+    Mamba, linear-time sequence modeling with selective state spaces O(L)
+    Paper link: https://arxiv.org/abs/2312.00752
+    Implementation reference: https://github.com/johnma2006/mamba-minimal/
+    """
+
+    def __init__(self, configs):
+        super(Model, self).__init__()
+        self.task_name = configs.task_name
+        self.pred_len = configs.pred_len
+
+        self.d_inner = configs.d_model * configs.expand
+        self.dt_rank = math.ceil(configs.d_model / 16)
+
+        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)
+
+        self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
+        self.norm = RMSNorm(configs.d_model)
+
+        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)
+
+    # def short_term_forecast(self, x_enc, x_mark_enc):
+    def forecast(self, x_enc, x_mark_enc):
+        mean_enc = x_enc.mean(1, keepdim=True).detach()
+        x_enc = x_enc - mean_enc
+        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
+        x_enc = x_enc / std_enc
+
+        x = self.embedding(x_enc, x_mark_enc)
+        for layer in self.layers:
+            x = layer(x)
+
+        x = self.norm(x)
+        x_out = self.out_layer(x)
+
+        x_out = x_out * std_enc + mean_enc
+        return x_out
+
+    # def long_term_forecast(self, x_enc, x_mark_enc):
+    #     x = self.embedding(x_enc, x_mark_enc)
+    #     for layer in self.layers:
+    #         x = layer(x)
+
+    #     x = self.norm(x)
+    #     x_out = self.out_layer(x)
+    #     return x_out
+
+    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
+        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
+            x_out = self.forecast(x_enc, x_mark_enc)
+            return x_out[:, -self.pred_len:, :]
+
+
+        # other tasks not implemented
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, configs, d_inner, dt_rank):
+        super(ResidualBlock, self).__init__()
+
+        self.mixer = MambaBlock(configs, d_inner, dt_rank)
+        self.norm = RMSNorm(configs.d_model)
+
+    def forward(self, x):
+        output = self.mixer(self.norm(x)) + x
+        return output
+
+class MambaBlock(nn.Module):
+    def __init__(self, configs, d_inner, dt_rank):
+        super(MambaBlock, self).__init__()
+        self.d_inner = d_inner
+        self.dt_rank = dt_rank
+
+        self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)
+
+        self.conv1d = nn.Conv1d(
+            in_channels = self.d_inner,
+            out_channels = self.d_inner,
+            bias = True,
+            kernel_size = configs.d_conv,
+            padding = configs.d_conv - 1,
+            groups = self.d_inner
+        )
+
+        # takes in x and outputs the input-specific delta, B, C
+        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)
+
+        # projects delta
+        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
+
+        A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(self.d_inner))
+
+        self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)
+
+    def forward(self, x):
+        """
+        Figure 3 in Section 3.4 in the paper
+        """
+        (b, l, d) = x.shape
+
+        x_and_res = self.in_proj(x)  # [B, L, 2 * d_inner]
+        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
+
+        x = rearrange(x, "b l d -> b d l")
+        x = self.conv1d(x)[:, :, :l]
+        x = rearrange(x, "b d l -> b l d")
+
+        x = F.silu(x)
+
+        y = self.ssm(x)
+        y = y * F.silu(res)
+
+        output = self.out_proj(y)
+        return output
+
+
+    def ssm(self, x):
+        """
+        Algorithm 2 in Section 3.2 in the paper
+        """
+
+        (d_in, n) = self.A_log.shape
+
+        A = -torch.exp(self.A_log.float())  # [d_in, n]
+        D = self.D.float()  # [d_in]
+
+        x_dbl = self.x_proj(x)  # [B, L, dt_rank + 2 * d_ff]
+        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: [B, L, dt_rank]; B, C: [B, L, n]
+        delta = F.softplus(self.dt_proj(delta))  # [B, L, d_in]
+        y = self.selective_scan(x, delta, A, B, C, D)
+
+        return y
+
+    def selective_scan(self, u, delta, A, B, C, D):
+        (b, l, d_in) = u.shape
+        n = A.shape[1]
+
+        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))  # A is discretized using zero-order hold (ZOH) discretization
+        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")  # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B"
+
+        # selective scan, sequential instead of parallel
+        x = torch.zeros((b, d_in, n), device=deltaA.device)
+        ys = []
+        for i in range(l):
+            x = deltaA[:, i] * x + deltaB_u[:, i]
+            y = einsum(x, C[:, i, :], "b d n, b n -> b d")
+            ys.append(y)
+
+        y = torch.stack(ys, dim=1)  # [B, L, d_in]
+        y = y + u * D
+
+        return y
+
+class RMSNorm(nn.Module):
+    def __init__(self, d_model, eps=1e-5):
+        super(RMSNorm, self).__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(d_model))
+
+    def forward(self, x):
+        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+        return output
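
As a reading aid (not part of the commit): selective_scan implements the discretized selective-SSM recurrence that its inline comments describe. Per time step $t$,

$$\bar{A}_t = \exp(\Delta_t A) \ \text{(zero-order hold)}, \qquad \overline{Bu}_t \approx \Delta_t \, B_t \, u_t \ \text{(simplified Euler)},$$
$$h_t = \bar{A}_t \odot h_{t-1} + \overline{Bu}_t, \qquad y_t = C_t \, h_t + D \odot u_t.$$

The loop unrolls this recurrence sequentially over the L time steps; that, rather than the fused selective-scan kernel used via mamba_ssm in models/Mamba.py, is the main practical difference between the two implementations.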

run.py

Lines changed: 12 additions & 3 deletions
@@ -51,6 +51,8 @@
 parser.add_argument('--anomaly_ratio', type=float, default=0.25, help='prior anomaly ratio (%)')

 # model define
+parser.add_argument('--expand', type=int, default=2, help='expansion factor for Mamba')
+parser.add_argument('--d_conv', type=int, default=4, help='conv kernel size for Mamba')
 parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
 parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
 parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
@@ -107,7 +109,10 @@
 parser.add_argument('--p_hidden_layers', type=int, default=2, help='number of hidden layers in projector')

 args = parser.parse_args()
-args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
+# args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
+args.use_gpu = True if torch.cuda.is_available() else False
+
+print(torch.cuda.is_available())

 if args.use_gpu and args.use_multi_gpu:
     args.devices = args.devices.replace(' ', '')
@@ -135,7 +140,7 @@
     for ii in range(args.itr):
         # setting record of experiments
         exp = Exp(args)  # set experiments
-        setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
+        setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_expand{}_dc{}_fc{}_eb{}_dt{}_{}_{}'.format(
             args.task_name,
             args.model_id,
             args.model,
@@ -149,6 +154,8 @@
             args.e_layers,
             args.d_layers,
             args.d_ff,
+            args.expand,
+            args.d_conv,
             args.factor,
             args.embed,
             args.distil,
@@ -162,7 +169,7 @@
         torch.cuda.empty_cache()
     else:
         ii = 0
-        setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_fc{}_eb{}_dt{}_{}_{}'.format(
+        setting = '{}_{}_{}_{}_ft{}_sl{}_ll{}_pl{}_dm{}_nh{}_el{}_dl{}_df{}_expand{}_dc{}_fc{}_eb{}_dt{}_{}_{}'.format(
             args.task_name,
             args.model_id,
             args.model,
@@ -176,6 +183,8 @@
         args.e_layers,
         args.d_layers,
         args.d_ff,
+        args.expand,
+        args.d_conv,
         args.factor,
         args.embed,
         args.distil,
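
For illustration only: filling the fields not shown in these hunks with this repo's usual defaults (the fourth field presumably args.data, plus n_heads=8, factor=1, embed=timeF, distil=True) and the ETTh1 script values below, the extended template would produce a setting string along the lines of

    long_term_forecast_ETTh1_96_96_MambaSimple_ETTh1_ftM_sl96_ll48_pl96_dm128_nh8_el2_dl1_df16_expand2_dc4_fc1_ebtimeF_dtTrue_Exp_0

so runs that differ only in --expand or --d_conv no longer share checkpoint and result folder names.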
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+model_name=Mamba
+
+for pred_len in 96 192 336 720
+# for pred_len in 336 720
+do
+
+python -u run.py \
+  --task_name long_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/electricity/ \
+  --data_path electricity.csv \
+  --model_id ECL_$pred_len'_'$pred_len \
+  --model $model_name \
+  --data custom \
+  --features M \
+  --seq_len $pred_len \
+  --label_len 48 \
+  --pred_len $pred_len \
+  --e_layers 2 \
+  --d_layers 1 \
+  --enc_in 321 \
+  --expand 2 \
+  --d_ff 16 \
+  --d_conv 4 \
+  --c_out 321 \
+  --d_model 128 \
+  --des 'Exp' \
+  --itr 1 \
+
+done
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+model_name=MambaSimple
+
+for pred_len in 96 192 336 720
+do
+
+python -u run.py \
+  --task_name long_term_forecast \
+  --is_training 1 \
+  --root_path ./dataset/ETT-small/ \
+  --data_path ETTh1.csv \
+  --model_id ETTh1_$pred_len'_'$pred_len \
+  --model $model_name \
+  --data ETTh1 \
+  --features M \
+  --seq_len $pred_len \
+  --label_len 48 \
+  --pred_len $pred_len \
+  --e_layers 2 \
+  --d_layers 1 \
+  --enc_in 7 \
+  --expand 2 \
+  --d_ff 16 \
+  --d_conv 4 \
+  --c_out 7 \
+  --d_model 128 \
+  --des 'Exp' \
+  --itr 1 \
+
+done
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+./scripts/long_term_forecast/ETT_script/Mamba_ETTh1.sh | tee mamba_ett.txt
+./scripts/long_term_forecast/ETT_script/Mamba_ETTh2.sh | tee mamba_ett.txt -a
+./scripts/long_term_forecast/ETT_script/Mamba_ETTm1.sh | tee mamba_ett.txt -a
+./scripts/long_term_forecast/ETT_script/Mamba_ETTm2.sh | tee mamba_ett.txt -a
