02c2e162 by astaxie

Strengthens the session's function

1 parent 59a67720
1 package session 1 package session
2 2
3 import ( 3 import (
4 "errors"
5 "io"
4 "io/ioutil" 6 "io/ioutil"
5 "os" 7 "os"
6 "path" 8 "path"
...@@ -48,6 +50,14 @@ func (fs *FileSessionStore) Delete(key interface{}) error { ...@@ -48,6 +50,14 @@ func (fs *FileSessionStore) Delete(key interface{}) error {
48 return nil 50 return nil
49 } 51 }
50 52
53 func (fs *FileSessionStore) Flush() error {
54 fs.lock.Lock()
55 defer fs.lock.Unlock()
56 fs.values = make(map[interface{}]interface{})
57 fs.updatecontent()
58 return nil
59 }
60
51 func (fs *FileSessionStore) SessionID() string { 61 func (fs *FileSessionStore) SessionID() string {
52 return fs.sid 62 return fs.sid
53 } 63 }
...@@ -121,6 +131,55 @@ func (fp *FileProvider) SessionGC() { ...@@ -121,6 +131,55 @@ func (fp *FileProvider) SessionGC() {
121 filepath.Walk(fp.savePath, gcpath) 131 filepath.Walk(fp.savePath, gcpath)
122 } 132 }
123 133
134 func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
135 err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777)
136 if err != nil {
137 println(err.Error())
138 }
139 err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
140 if err != nil {
141 println(err.Error())
142 }
143 _, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
144 var newf *os.File
145 if err == nil {
146 return nil, errors.New("newsid exist")
147 } else if os.IsNotExist(err) {
148 newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
149 }
150
151 _, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid))
152 var f *os.File
153 if err == nil {
154 f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777)
155 io.Copy(newf, f)
156 } else if os.IsNotExist(err) {
157 newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
158 } else {
159 return nil, err
160 }
161 f.Close()
162 os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])))
163 os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
164 var kv map[interface{}]interface{}
165 b, err := ioutil.ReadAll(newf)
166 if err != nil {
167 return nil, err
168 }
169 if len(b) == 0 {
170 kv = make(map[interface{}]interface{})
171 } else {
172 kv, err = decodeGob(b)
173 if err != nil {
174 return nil, err
175 }
176 }
177
178 newf, err = os.OpenFile(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), os.O_WRONLY|os.O_CREATE, 0777)
179 ss := &FileSessionStore{f: newf, sid: sid, values: kv}
180 return ss, nil
181 }
182
124 func gcpath(path string, info os.FileInfo, err error) error { 183 func gcpath(path string, info os.FileInfo, err error) error {
125 if err != nil { 184 if err != nil {
126 return err 185 return err
......
...@@ -40,6 +40,13 @@ func (st *MemSessionStore) Delete(key interface{}) error { ...@@ -40,6 +40,13 @@ func (st *MemSessionStore) Delete(key interface{}) error {
40 return nil 40 return nil
41 } 41 }
42 42
43 func (st *MemSessionStore) Flush() error {
44 st.lock.Lock()
45 defer st.lock.Unlock()
46 st.value = make(map[interface{}]interface{})
47 return nil
48 }
49
43 func (st *MemSessionStore) SessionID() string { 50 func (st *MemSessionStore) SessionID() string {
44 return st.sid 51 return st.sid
45 } 52 }
...@@ -80,6 +87,29 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) { ...@@ -80,6 +87,29 @@ func (pder *MemProvider) SessionRead(sid string) (SessionStore, error) {
80 return nil, nil 87 return nil, nil
81 } 88 }
82 89
90 func (pder *MemProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
91 pder.lock.RLock()
92 if element, ok := pder.sessions[oldsid]; ok {
93 go pder.SessionUpdate(oldsid)
94 pder.lock.RUnlock()
95 pder.lock.Lock()
96 element.Value.(*MemSessionStore).sid = sid
97 pder.sessions[sid] = element
98 delete(pder.sessions, oldsid)
99 pder.lock.Unlock()
100 return element.Value.(*MemSessionStore), nil
101 } else {
102 pder.lock.RUnlock()
103 pder.lock.Lock()
104 newsess := &MemSessionStore{sid: sid, timeAccessed: time.Now(), value: make(map[interface{}]interface{})}
105 element := pder.list.PushBack(newsess)
106 pder.sessions[sid] = element
107 pder.lock.Unlock()
108 return newsess, nil
109 }
110 return nil, nil
111 }
112
83 func (pder *MemProvider) SessionDestroy(sid string) error { 113 func (pder *MemProvider) SessionDestroy(sid string) error {
84 pder.lock.Lock() 114 pder.lock.Lock()
85 defer pder.lock.Unlock() 115 defer pder.lock.Unlock()
......
...@@ -50,6 +50,14 @@ func (st *MysqlSessionStore) Delete(key interface{}) error { ...@@ -50,6 +50,14 @@ func (st *MysqlSessionStore) Delete(key interface{}) error {
50 return nil 50 return nil
51 } 51 }
52 52
53 func (st *MysqlSessionStore) Flush() error {
54 st.lock.Lock()
55 defer st.lock.Unlock()
56 st.values = make(map[interface{}]interface{})
57 st.updatemysql()
58 return nil
59 }
60
53 func (st *MysqlSessionStore) SessionID() string { 61 func (st *MysqlSessionStore) SessionID() string {
54 return st.sid 62 return st.sid
55 } 63 }
...@@ -108,6 +116,28 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) { ...@@ -108,6 +116,28 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
108 return rs, nil 116 return rs, nil
109 } 117 }
110 118
119 func (mp *MysqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
120 c := mp.connectInit()
121 row := c.QueryRow("select session_data from session where session_key=?", oldsid)
122 var sessiondata []byte
123 err := row.Scan(&sessiondata)
124 if err == sql.ErrNoRows {
125 c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", oldsid, "", time.Now().Unix())
126 }
127 c.Exec("update session set `session_key`=? where session_key=?", sid, oldsid)
128 var kv map[interface{}]interface{}
129 if len(sessiondata) == 0 {
130 kv = make(map[interface{}]interface{})
131 } else {
132 kv, err = decodeGob(sessiondata)
133 if err != nil {
134 return nil, err
135 }
136 }
137 rs := &MysqlSessionStore{c: c, sid: sid, values: kv}
138 return rs, nil
139 }
140
111 func (mp *MysqlProvider) SessionDestroy(sid string) error { 141 func (mp *MysqlProvider) SessionDestroy(sid string) error {
112 c := mp.connectInit() 142 c := mp.connectInit()
113 c.Exec("DELETE FROM session where session_key=?", sid) 143 c.Exec("DELETE FROM session where session_key=?", sid)
......
...@@ -35,6 +35,11 @@ func (rs *RedisSessionStore) Delete(key interface{}) error { ...@@ -35,6 +35,11 @@ func (rs *RedisSessionStore) Delete(key interface{}) error {
35 return err 35 return err
36 } 36 }
37 37
38 func (rs *RedisSessionStore) Flush() error {
39 _, err := rs.c.Do("DEL", rs.sid)
40 return err
41 }
42
38 func (rs *RedisSessionStore) SessionID() string { 43 func (rs *RedisSessionStore) SessionID() string {
39 return rs.sid 44 return rs.sid
40 } 45 }
...@@ -99,6 +104,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) { ...@@ -99,6 +104,16 @@ func (rp *RedisProvider) SessionRead(sid string) (SessionStore, error) {
99 return rs, nil 104 return rs, nil
100 } 105 }
101 106
107 func (rp *RedisProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
108 c := rp.connectInit()
109 if str, err := redis.String(c.Do("HGET", oldsid, oldsid)); err != nil || str == "" {
110 c.Do("HSET", oldsid, oldsid, rp.maxlifetime)
111 }
112 c.Do("RENAME", oldsid, sid)
113 rs := &RedisSessionStore{c: c, sid: sid}
114 return rs, nil
115 }
116
102 func (rp *RedisProvider) SessionDestroy(sid string) error { 117 func (rp *RedisProvider) SessionDestroy(sid string) error {
103 c := rp.connectInit() 118 c := rp.connectInit()
104 c.Do("DEL", sid) 119 c.Do("DEL", sid)
......
1 package session 1 package session
2 2
3 import ( 3 import (
4 "crypto/hmac"
5 "crypto/md5"
4 "crypto/rand" 6 "crypto/rand"
7 "crypto/sha1"
5 "encoding/base64" 8 "encoding/base64"
9 "encoding/hex"
6 "fmt" 10 "fmt"
7 "io" 11 "io"
8 "net/http" 12 "net/http"
...@@ -16,11 +20,13 @@ type SessionStore interface { ...@@ -16,11 +20,13 @@ type SessionStore interface {
16 Delete(key interface{}) error //delete session value 20 Delete(key interface{}) error //delete session value
17 SessionID() string //back current sessionID 21 SessionID() string //back current sessionID
18 SessionRelease() // release the resource 22 SessionRelease() // release the resource
23 Flush() error //delete all data
19 } 24 }
20 25
21 type Provider interface { 26 type Provider interface {
22 SessionInit(maxlifetime int64, savePath string) error 27 SessionInit(maxlifetime int64, savePath string) error
23 SessionRead(sid string) (SessionStore, error) 28 SessionRead(sid string) (SessionStore, error)
29 SessionRegenerate(oldsid, sid string) (SessionStore, error)
24 SessionDestroy(sid string) error 30 SessionDestroy(sid string) error
25 SessionGC() 31 SessionGC()
26 } 32 }
...@@ -44,40 +50,91 @@ type Manager struct { ...@@ -44,40 +50,91 @@ type Manager struct {
44 cookieName string //private cookiename 50 cookieName string //private cookiename
45 provider Provider 51 provider Provider
46 maxlifetime int64 52 maxlifetime int64
53 hashfunc string //support md5 & sha1
54 hashkey string
47 options []interface{} 55 options []interface{}
48 } 56 }
49 57
58 //options
59 //1. is https default false
60 //2. hashfunc default sha1
61 //3. hashkey default beegosessionkey
62 //4. maxage default is none
50 func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { 63 func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) {
51 provider, ok := provides[provideName] 64 provider, ok := provides[provideName]
52 if !ok { 65 if !ok {
53 return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) 66 return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
54 } 67 }
55 provider.SessionInit(maxlifetime, savePath) 68 provider.SessionInit(maxlifetime, savePath)
56 return &Manager{provider: provider, cookieName: cookieName, maxlifetime: maxlifetime, options: options}, nil 69 hashfunc := "sha1"
70 if len(options) > 1 {
71 hashfunc = options[1].(string)
72 }
73 hashkey := "beegosessionkey"
74 if len(options) > 2 {
75 hashkey = options[2].(string)
76 }
77 return &Manager{
78 provider: provider,
79 cookieName: cookieName,
80 maxlifetime: maxlifetime,
81 hashfunc: hashfunc,
82 hashkey: hashkey,
83 options: options,
84 }, nil
57 } 85 }
58 86
59 //get Session 87 //get Session
60 func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { 88 func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
61 cookie, err := r.Cookie(manager.cookieName) 89 cookie, err := r.Cookie(manager.cookieName)
90 maxage := -1
91 if len(manager.options) > 3 {
92 switch manager.options[3].(type) {
93 case int:
94 if manager.options[3].(int) > 0 {
95 maxage = manager.options[3].(int)
96 } else if manager.options[3].(int) < 0 {
97 maxage = 0
98 }
99 case int64:
100 if manager.options[3].(int64) > 0 {
101 maxage = int(manager.options[3].(int64))
102 } else if manager.options[3].(int64) < 0 {
103 maxage = 0
104 }
105 case int32:
106 if manager.options[3].(int32) > 0 {
107 maxage = int(manager.options[3].(int32))
108 } else if manager.options[3].(int32) < 0 {
109 maxage = 0
110 }
111 }
112 }
62 if err != nil || cookie.Value == "" { 113 if err != nil || cookie.Value == "" {
63 sid := manager.sessionId() 114 sid := manager.sessionId(r)
64 session, _ = manager.provider.SessionRead(sid) 115 session, _ = manager.provider.SessionRead(sid)
65 secure := false 116 secure := false
66 if len(manager.options) > 0 { 117 if len(manager.options) > 0 {
67 secure = manager.options[0].(bool) 118 secure = manager.options[0].(bool)
68 } 119 }
69 cookie := http.Cookie{Name: manager.cookieName, 120 cookie = &http.Cookie{Name: manager.cookieName,
70 Value: url.QueryEscape(sid), 121 Value: url.QueryEscape(sid),
71 Path: "/", 122 Path: "/",
72 HttpOnly: true, 123 HttpOnly: true,
73 Secure: secure} 124 Secure: secure}
125 if maxage >= 0 {
126 cookie.MaxAge = maxage
127 }
74 //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) 128 //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
75 http.SetCookie(w, &cookie) 129 http.SetCookie(w, cookie)
76 r.AddCookie(&cookie) 130 r.AddCookie(cookie)
77 } else { 131 } else {
78 //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second) 132 //cookie.Expires = time.Now().Add(time.Duration(manager.maxlifetime) * time.Second)
79 cookie.HttpOnly = true 133 cookie.HttpOnly = true
80 cookie.Path = "/" 134 cookie.Path = "/"
135 if maxage >= 0 {
136 cookie.MaxAge = maxage
137 }
81 http.SetCookie(w, cookie) 138 http.SetCookie(w, cookie)
82 sid, _ := url.QueryUnescape(cookie.Value) 139 sid, _ := url.QueryUnescape(cookie.Value)
83 session, _ = manager.provider.SessionRead(sid) 140 session, _ = manager.provider.SessionRead(sid)
...@@ -103,10 +160,81 @@ func (manager *Manager) GC() { ...@@ -103,10 +160,81 @@ func (manager *Manager) GC() {
103 time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) 160 time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() })
104 } 161 }
105 162
106 func (manager *Manager) sessionId() string { 163 func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
164 sid := manager.sessionId(r)
165 cookie, err := r.Cookie(manager.cookieName)
166 if err != nil && cookie.Value == "" {
167 //delete old cookie
168 session, _ = manager.provider.SessionRead(sid)
169 secure := false
170 if len(manager.options) > 0 {
171 secure = manager.options[0].(bool)
172 }
173 cookie = &http.Cookie{Name: manager.cookieName,
174 Value: url.QueryEscape(sid),
175 Path: "/",
176 HttpOnly: true,
177 Secure: secure,
178 }
179 } else {
180 oldsid, _ := url.QueryUnescape(cookie.Value)
181 session, _ = manager.provider.SessionRegenerate(oldsid, sid)
182 cookie.Value = url.QueryEscape(sid)
183 cookie.HttpOnly = true
184 cookie.Path = "/"
185 }
186 maxage := -1
187 if len(manager.options) > 3 {
188 switch manager.options[3].(type) {
189 case int:
190 if manager.options[3].(int) > 0 {
191 maxage = manager.options[3].(int)
192 } else if manager.options[3].(int) < 0 {
193 maxage = 0
194 }
195 case int64:
196 if manager.options[3].(int64) > 0 {
197 maxage = int(manager.options[3].(int64))
198 } else if manager.options[3].(int64) < 0 {
199 maxage = 0
200 }
201 case int32:
202 if manager.options[3].(int32) > 0 {
203 maxage = int(manager.options[3].(int32))
204 } else if manager.options[3].(int32) < 0 {
205 maxage = 0
206 }
207 }
208 }
209 if maxage >= 0 {
210 cookie.MaxAge = maxage
211 }
212 http.SetCookie(w, cookie)
213 r.AddCookie(cookie)
214 return
215 }
216
217 //remote_addr cruunixnano randdata
218
219 func (manager *Manager) sessionId(r *http.Request) (sid string) {
107 b := make([]byte, 24) 220 b := make([]byte, 24)
108 if _, err := io.ReadFull(rand.Reader, b); err != nil { 221 if _, err := io.ReadFull(rand.Reader, b); err != nil {
109 return "" 222 return ""
110 } 223 }
111 return base64.URLEncoding.EncodeToString(b) 224 bs := base64.URLEncoding.EncodeToString(b)
225 sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
226 if manager.hashfunc == "md5" {
227 h := md5.New()
228 h.Write([]byte(bs))
229 sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
230 } else if manager.hashfunc == "sha1" {
231 h := hmac.New(sha1.New, []byte(manager.hashkey))
232 fmt.Fprintf(h, "%s", sig)
233 sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
234 } else {
235 h := hmac.New(sha1.New, []byte(manager.hashkey))
236 fmt.Fprintf(h, "%s", sig)
237 sid = fmt.Sprintf("%s", hex.EncodeToString(h.Sum(nil)))
238 }
239 return
112 } 240 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!