Skip to content

Commit b8142bf

Browse files
authored
feat(os/gsession): add RegenerateId/MustRegenerateId support (#4012)
1 parent ba96894 commit b8142bf

File tree

2 files changed

+179
-17
lines changed

2 files changed

+179
-17
lines changed

os/gsession/gsession_session.go

+71-17
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func (s *Session) init() error {
4545
// Retrieve stored session data from storage.
4646
if s.manager.storage != nil {
4747
s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL())
48-
if err != nil && err != ErrorDisabled {
48+
if err != nil && !gerror.Is(err, ErrorDisabled) {
4949
intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err)
5050
return err
5151
}
@@ -59,7 +59,7 @@ func (s *Session) init() error {
5959
} else {
6060
// Use default session id creating function of storage.
6161
s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl)
62-
if err != nil && err != ErrorDisabled {
62+
if err != nil && !gerror.Is(err, ErrorDisabled) {
6363
intlog.Errorf(s.ctx, "create session id failed: %+v", err)
6464
return err
6565
}
@@ -89,12 +89,12 @@ func (s *Session) Close() error {
8989
size := s.data.Size()
9090
if s.dirty {
9191
err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl)
92-
if err != nil && err != ErrorDisabled {
92+
if err != nil && !gerror.Is(err, ErrorDisabled) {
9393
return err
9494
}
9595
} else if size > 0 {
9696
err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl)
97-
if err != nil && err != ErrorDisabled {
97+
if err != nil && !gerror.Is(err, ErrorDisabled) {
9898
return err
9999
}
100100
}
@@ -108,11 +108,10 @@ func (s *Session) Set(key string, value interface{}) (err error) {
108108
return err
109109
}
110110
if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil {
111-
if err == ErrorDisabled {
112-
s.data.Set(key, value)
113-
} else {
111+
if !gerror.Is(err, ErrorDisabled) {
114112
return err
115113
}
114+
s.data.Set(key, value)
116115
}
117116
s.dirty = true
118117
return nil
@@ -124,11 +123,10 @@ func (s *Session) SetMap(data map[string]interface{}) (err error) {
124123
return err
125124
}
126125
if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil {
127-
if err == ErrorDisabled {
128-
s.data.Sets(data)
129-
} else {
126+
if !gerror.Is(err, ErrorDisabled) {
130127
return err
131128
}
129+
s.data.Sets(data)
132130
}
133131
s.dirty = true
134132
return nil
@@ -144,11 +142,10 @@ func (s *Session) Remove(keys ...string) (err error) {
144142
}
145143
for _, key := range keys {
146144
if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil {
147-
if err == ErrorDisabled {
148-
s.data.Remove(key)
149-
} else {
145+
if !gerror.Is(err, ErrorDisabled) {
150146
return err
151147
}
148+
s.data.Remove(key)
152149
}
153150
}
154151
s.dirty = true
@@ -164,7 +161,7 @@ func (s *Session) RemoveAll() (err error) {
164161
return err
165162
}
166163
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
167-
if err != ErrorDisabled {
164+
if !gerror.Is(err, ErrorDisabled) {
168165
return err
169166
}
170167
}
@@ -215,7 +212,7 @@ func (s *Session) Data() (sessionData map[string]interface{}, err error) {
215212
return nil, err
216213
}
217214
sessionData, err = s.manager.storage.Data(s.ctx, s.id)
218-
if err != nil && err != ErrorDisabled {
215+
if err != nil && !gerror.Is(err, ErrorDisabled) {
219216
intlog.Errorf(s.ctx, `%+v`, err)
220217
}
221218
if sessionData != nil {
@@ -233,7 +230,7 @@ func (s *Session) Size() (size int, err error) {
233230
return 0, err
234231
}
235232
size, err = s.manager.storage.GetSize(s.ctx, s.id)
236-
if err != nil && err != ErrorDisabled {
233+
if err != nil && !gerror.Is(err, ErrorDisabled) {
237234
intlog.Errorf(s.ctx, `%+v`, err)
238235
}
239236
if size > 0 {
@@ -273,7 +270,7 @@ func (s *Session) Get(key string, def ...interface{}) (value *gvar.Var, err erro
273270
return nil, err
274271
}
275272
v, err := s.manager.storage.Get(s.ctx, s.id, key)
276-
if err != nil && err != ErrorDisabled {
273+
if err != nil && !gerror.Is(err, ErrorDisabled) {
277274
intlog.Errorf(s.ctx, `%+v`, err)
278275
return nil, err
279276
}
@@ -357,3 +354,60 @@ func (s *Session) MustRemove(keys ...string) {
357354
panic(err)
358355
}
359356
}
357+
358+
// RegenerateId regenerates a new session id for current session.
359+
// It keeps the session data and updates the session id with a new one.
360+
// This is commonly used to prevent session fixation attacks and increase security.
361+
//
362+
// The parameter `deleteOld` specifies whether to delete the old session data:
363+
// - If true: the old session data will be deleted immediately
364+
// - If false: the old session data will be kept and expire according to its TTL
365+
func (s *Session) RegenerateId(deleteOld bool) (newId string, err error) {
366+
if err = s.init(); err != nil {
367+
return "", err
368+
}
369+
370+
// Generate new session id
371+
if s.idFunc != nil {
372+
newId = s.idFunc(s.manager.ttl)
373+
} else {
374+
newId, err = s.manager.storage.New(s.ctx, s.manager.ttl)
375+
if err != nil && !gerror.Is(err, ErrorDisabled) {
376+
return "", err
377+
}
378+
if newId == "" {
379+
newId = NewSessionId()
380+
}
381+
}
382+
383+
// If using storage, need to copy data to new id
384+
if s.manager.storage != nil {
385+
if err = s.manager.storage.SetSession(s.ctx, newId, s.data, s.manager.ttl); err != nil {
386+
if !gerror.Is(err, ErrorDisabled) {
387+
return "", err
388+
}
389+
}
390+
// Delete old session data if requested
391+
if deleteOld {
392+
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
393+
if !gerror.Is(err, ErrorDisabled) {
394+
return "", err
395+
}
396+
}
397+
}
398+
}
399+
400+
// Update session id
401+
s.id = newId
402+
s.dirty = true
403+
return newId, nil
404+
}
405+
406+
// MustRegenerateId performs as function RegenerateId, but it panics if any error occurs.
407+
func (s *Session) MustRegenerateId(deleteOld bool) string {
408+
newId, err := s.RegenerateId(deleteOld)
409+
if err != nil {
410+
panic(err)
411+
}
412+
return newId
413+
}

os/gsession/gsession_z_unit_test.go

+108
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
package gsession
88

99
import (
10+
"context"
1011
"testing"
12+
"time"
1113

1214
"github.com/gogf/gf/v2/test/gtest"
1315
)
1416

17+
var ctx = context.TODO()
18+
1519
func Test_NewSessionId(t *testing.T) {
1620
gtest.C(t, func(t *gtest.T) {
1721
id1 := NewSessionId()
@@ -20,3 +24,107 @@ func Test_NewSessionId(t *testing.T) {
2024
t.Assert(len(id1), 32)
2125
})
2226
}
27+
28+
func Test_Session_RegenerateId(t *testing.T) {
29+
gtest.C(t, func(t *gtest.T) {
30+
// 1. Test with memory storage
31+
storage := NewStorageMemory()
32+
manager := New(time.Hour, storage)
33+
session := manager.New(ctx)
34+
35+
// Store some data
36+
err := session.Set("key1", "value1")
37+
t.AssertNil(err)
38+
err = session.Set("key2", "value2")
39+
t.AssertNil(err)
40+
41+
// Get original session id
42+
oldId := session.MustId()
43+
44+
// Test regenerate with deleteOld = true
45+
newId1, err := session.RegenerateId(true)
46+
t.AssertNil(err)
47+
t.AssertNE(oldId, newId1)
48+
49+
// Verify data is preserved
50+
v1 := session.MustGet("key1")
51+
t.Assert(v1.String(), "value1")
52+
v2 := session.MustGet("key2")
53+
t.Assert(v2.String(), "value2")
54+
55+
// Verify old session is deleted
56+
oldSession := manager.New(ctx)
57+
err = oldSession.SetId(oldId)
58+
t.AssertNil(err)
59+
v3 := oldSession.MustGet("key1")
60+
t.Assert(v3.IsNil(), true)
61+
62+
// Test regenerate with deleteOld = false
63+
currentId := newId1
64+
newId2, err := session.RegenerateId(false)
65+
t.AssertNil(err)
66+
t.AssertNE(currentId, newId2)
67+
68+
// Verify data is preserved in new session
69+
v4 := session.MustGet("key1")
70+
t.Assert(v4.String(), "value1")
71+
72+
// Create another session instance with the previous id
73+
prevSession := manager.New(ctx)
74+
err = prevSession.SetId(currentId)
75+
t.AssertNil(err)
76+
// Data should still be accessible in previous session
77+
v5 := prevSession.MustGet("key1")
78+
t.Assert(v5.String(), "value1")
79+
})
80+
81+
gtest.C(t, func(t *gtest.T) {
82+
// 2. Test with custom id function
83+
storage := NewStorageMemory()
84+
manager := New(time.Hour, storage)
85+
session := manager.New(ctx)
86+
87+
customId := "custom_session_id"
88+
err := session.SetIdFunc(func(ttl time.Duration) string {
89+
return customId
90+
})
91+
t.AssertNil(err)
92+
93+
newId, err := session.RegenerateId(true)
94+
t.AssertNil(err)
95+
t.Assert(newId, customId)
96+
})
97+
98+
gtest.C(t, func(t *gtest.T) {
99+
// 3. Test with disabled storage
100+
storage := &StorageBase{} // implements Storage interface but all methods return ErrorDisabled
101+
manager := New(time.Hour, storage)
102+
session := manager.New(ctx)
103+
104+
// Should still work even with disabled storage
105+
newId, err := session.RegenerateId(true)
106+
t.AssertNil(err)
107+
t.Assert(len(newId), 32)
108+
})
109+
}
110+
111+
// Test MustRegenerateId
112+
func Test_Session_MustRegenerateId(t *testing.T) {
113+
gtest.C(t, func(t *gtest.T) {
114+
storage := NewStorageMemory()
115+
manager := New(time.Hour, storage)
116+
session := manager.New(ctx)
117+
118+
// Normal case should not panic
119+
t.AssertNil(session.Set("key", "value"))
120+
newId := session.MustRegenerateId(true)
121+
t.Assert(len(newId), 32)
122+
123+
// Test with disabled storage (should not panic)
124+
storage2 := &StorageBase{}
125+
manager2 := New(time.Hour, storage2)
126+
session2 := manager2.New(ctx)
127+
newId2 := session2.MustRegenerateId(true)
128+
t.Assert(len(newId2), 32)
129+
})
130+
}

0 commit comments

Comments
 (0)