import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from layers.Embed import DataEmbedding


class Model(nn.Module):
    """
    Mamba, linear-time sequence modeling with selective state spaces O(L)
    Paper link: https://arxiv.org/abs/2312.00752
    Implementation reference: https://github.com/johnma2006/mamba-minimal/
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

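        # widen the model for the SSM branch; dt_rank is the rank of the
        # low-rank projection that produces the input-dependent step size
        # delta (ceil(d_model / 16) is the Mamba paper's default)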
        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)

        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
        self.norm = RMSNorm(configs.d_model)

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    def forecast(self, x_enc, x_mark_enc):
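        # instance-normalize with detached statistics; mean/std are restored
        # on the output below so predictions return to the original scale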
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x_out = self.out_layer(x)

        x_out = x_out * std_enc + mean_enc
        return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]

        # other tasks not implemented


class ResidualBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(ResidualBlock, self).__init__()

        self.mixer = MambaBlock(configs, d_inner, dt_rank)
        self.norm = RMSNorm(configs.d_model)

    def forward(self, x):
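        # pre-norm residual: x + Mamba(RMSNorm(x))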
        output = self.mixer(self.norm(x)) + x
        return output


class MambaBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(MambaBlock, self).__init__()
        self.d_inner = d_inner
        self.dt_rank = dt_rank

        self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=configs.d_conv,
            padding=configs.d_conv - 1,
            groups=self.d_inner
        )
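        # padding = d_conv - 1 over-pads the sequence; forward() slices the
        # conv output back to length L, which keeps the convolution causal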

        # takes in x and outputs the input-specific delta, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

        # projects delta
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

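        # S4D-real initialization: each channel's A starts as -[1, 2, ..., n];
        # the log parameterization keeps A = -exp(A_log) negative in training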
        A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

    def forward(self, x):
        """
        Figure 3 in Section 3.4 of the paper
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # [B, L, 2 * d_inner]
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, "b d l -> b l d")

        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)

        output = self.out_proj(y)
        return output

    def ssm(self, x):
        """
        Algorithm 2 in Section 3.2 of the paper
        """
        (d_in, n) = self.A_log.shape

        A = -torch.exp(self.A_log.float())  # [d_in, n]
        D = self.D.float()  # [d_in]

        x_dbl = self.x_proj(x)  # [B, L, dt_rank + 2 * d_ff]
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: [B, L, dt_rank]; B, C: [B, L, n]
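        # softplus keeps the input-dependent step size delta strictly positive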
        delta = F.softplus(self.dt_proj(delta))  # [B, L, d_in]
        y = self.selective_scan(x, delta, A, B, C, D)

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # A is discretized with a zero-order hold (ZOH); B uses a simplified
        # Euler discretization instead of ZOH. From a discussion with the
        # authors: "A is the more important term and the performance doesn't
        # change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))  # [B, L, d_in, n]
        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")  # [B, L, d_in, n]

        # selective scan, sequential instead of parallel
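        # the loop implements the discretized state-space recurrence
        #   h_t = deltaA_t * h_{t-1} + deltaB_u_t   (per-channel state update)
        #   y_t = C_t . h_t                         (readout)
        # in O(L) sequential steps; the paper's hardware-aware kernel computes
        # the same recurrence with a parallel scan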
155
+ x = torch .zeros ((b , d_in , n ), device = deltaA .device )
156
+ ys = []
157
+ for i in range (l ):
158
+ x = deltaA [:, i ] * x + deltaB_u [:, i ]
159
+ y = einsum (x , C [:, i , :], "b d n, b n -> b d" )
160
+ ys .append (y )
161
+
162
+ y = torch .stack (ys , dim = 1 ) # [B, L, d_in]
163
+ y = y + u * D
164
+
165
+ return y


class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output
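

# A minimal smoke test, offered as a sketch: it assumes the Time-Series-Library
# layout (layers.Embed.DataEmbedding importable) and uses illustrative config
# values; the fields below are exactly the ones this file reads. With
# embed='timeF' and freq='h', DataEmbedding expects 4 time-feature columns.
if __name__ == "__main__":
    from types import SimpleNamespace

    configs = SimpleNamespace(
        task_name='long_term_forecast', pred_len=24, d_model=128, expand=2,
        enc_in=7, c_out=7, embed='timeF', freq='h', dropout=0.1,
        e_layers=2, d_conv=4, d_ff=16,
    )
    model = Model(configs)
    x_enc = torch.randn(2, 96, configs.enc_in)  # [B, L, enc_in]
    x_mark_enc = torch.randn(2, 96, 4)          # [B, L, time features]
    out = model(x_enc, x_mark_enc, None, None)
    print(out.shape)  # torch.Size([2, 24, 7])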