481448fa by astaxie

modify session module

change a log
1 parent 95c65de9
...@@ -28,21 +28,21 @@ Then in you web app init the global session manager ...@@ -28,21 +28,21 @@ Then in you web app init the global session manager
28 * Use **memory** as provider: 28 * Use **memory** as provider:
29 29
30 func init() { 30 func init() {
31 globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"") 31 globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
32 go globalSessions.GC() 32 go globalSessions.GC()
33 } 33 }
34 34
35 * Use **file** as provider, the last param is the path where you want file to be stored: 35 * Use **file** as provider, the last param is the path where you want file to be stored:
36 36
37 func init() { 37 func init() {
38 globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp") 38 globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
39 go globalSessions.GC() 39 go globalSessions.GC()
40 } 40 }
41 41
42 * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password: 42 * Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
43 43
44 func init() { 44 func init() {
45 globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie") 45 globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
46 go globalSessions.GC() 46 go globalSessions.GC()
47 } 47 }
48 48
...@@ -50,15 +50,24 @@ Then in you web app init the global session manager ...@@ -50,15 +50,24 @@ Then in you web app init the global session manager
50 50
51 func init() { 51 func init() {
52 globalSessions, _ = session.NewManager( 52 globalSessions, _ = session.NewManager(
53 "mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value") 53 "mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
54 go globalSessions.GC() 54 go globalSessions.GC()
55 } 55 }
56 56
57 * Use **Cookie** as provider:
58
59 func init() {
60 globalSessions, _ = session.NewManager(
61 "cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
62 go globalSessions.GC()
63 }
64
65
57 Finally in the handlerfunc you can use it like this 66 Finally in the handlerfunc you can use it like this
58 67
59 func login(w http.ResponseWriter, r *http.Request) { 68 func login(w http.ResponseWriter, r *http.Request) {
60 sess := globalSessions.SessionStart(w, r) 69 sess := globalSessions.SessionStart(w, r)
61 defer sess.SessionRelease() 70 defer sess.SessionRelease(w)
62 username := sess.Get("username") 71 username := sess.Get("username")
63 fmt.Println(username) 72 fmt.Println(username)
64 if r.Method == "GET" { 73 if r.Method == "GET" {
...@@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus ...@@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
78 87
79 Writing a provider is easy. You only need to define two struct types 88 Writing a provider is easy. You only need to define two struct types
80 (Session and Provider), which satisfy the interface definition. 89 (Session and Provider), which satisfy the interface definition.
81 Maybe you will find the **memory** provider as good example. 90 Maybe you will find the **memory** provider is a good example.
82 91
83 type SessionStore interface { 92 type SessionStore interface {
84 Set(key, value interface{}) error //set session value 93 Set(key, value interface{}) error //set session value
85 Get(key interface{}) interface{} //get session value 94 Get(key interface{}) interface{} //get session value
86 Delete(key interface{}) error //delete session value 95 Delete(key interface{}) error //delete session value
87 SessionID() string //back current sessionID 96 SessionID() string //back current sessionID
88 SessionRelease() // release the resource & save data to provider 97 SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
89 Flush() error //delete all data 98 Flush() error //delete all data
90 } 99 }
91 100
92 type Provider interface { 101 type Provider interface {
93 SessionInit(maxlifetime int64, savePath string) error 102 SessionInit(gclifetime int64, config string) error
94 SessionRead(sid string) (SessionStore, error) 103 SessionRead(sid string) (SessionStore, error)
95 SessionExist(sid string) bool 104 SessionExist(sid string) bool
96 SessionRegenerate(oldsid, sid string) (SessionStore, error) 105 SessionRegenerate(oldsid, sid string) (SessionStore, error)
......
1 package session
2
3 import (
4 "crypto/aes"
5 "crypto/cipher"
6 "encoding/json"
7 "net/http"
8 "net/url"
9 "sync"
10 )
11
12 var cookiepder = &CookieProvider{}
13
14 type CookieSessionStore struct {
15 sid string
16 values map[interface{}]interface{} //session data
17 lock sync.RWMutex
18 }
19
20 func (st *CookieSessionStore) Set(key, value interface{}) error {
21 st.lock.Lock()
22 defer st.lock.Unlock()
23 st.values[key] = value
24 return nil
25 }
26
27 func (st *CookieSessionStore) Get(key interface{}) interface{} {
28 st.lock.RLock()
29 defer st.lock.RUnlock()
30 if v, ok := st.values[key]; ok {
31 return v
32 } else {
33 return nil
34 }
35 return nil
36 }
37
38 func (st *CookieSessionStore) Delete(key interface{}) error {
39 st.lock.Lock()
40 defer st.lock.Unlock()
41 delete(st.values, key)
42 return nil
43 }
44
45 func (st *CookieSessionStore) Flush() error {
46 st.lock.Lock()
47 defer st.lock.Unlock()
48 st.values = make(map[interface{}]interface{})
49 return nil
50 }
51
52 func (st *CookieSessionStore) SessionID() string {
53 return st.sid
54 }
55
56 func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
57 str, err := encodeCookie(cookiepder.block,
58 cookiepder.config.SecurityKey,
59 cookiepder.config.SecurityName,
60 st.values)
61 if err != nil {
62 return
63 }
64 cookie := &http.Cookie{Name: cookiepder.config.CookieName,
65 Value: url.QueryEscape(str),
66 Path: "/",
67 HttpOnly: true,
68 Secure: cookiepder.config.Secure}
69 http.SetCookie(w, cookie)
70 return
71 }
72
73 type cookieConfig struct {
74 SecurityKey string `json:"securityKey"`
75 BlockKey string `json:"blockKey"`
76 SecurityName string `json:"securityName"`
77 CookieName string `json:"cookieName"`
78 Secure bool `json:"secure"`
79 Maxage int `json:"maxage"`
80 }
81
82 type CookieProvider struct {
83 maxlifetime int64
84 config *cookieConfig
85 block cipher.Block
86 }
87
88 func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
89 pder.config = &cookieConfig{}
90 err := json.Unmarshal([]byte(config), pder.config)
91 if err != nil {
92 return err
93 }
94 if pder.config.BlockKey == "" {
95 pder.config.BlockKey = string(generateRandomKey(16))
96 }
97 if pder.config.SecurityName == "" {
98 pder.config.SecurityName = string(generateRandomKey(20))
99 }
100 pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
101 if err != nil {
102 return err
103 }
104 return nil
105 }
106
107 func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
108 kv := make(map[interface{}]interface{})
109 kv, _ = decodeCookie(pder.block,
110 pder.config.SecurityKey,
111 pder.config.SecurityName,
112 sid, pder.maxlifetime)
113 rs := &CookieSessionStore{sid: sid, values: kv}
114 return rs, nil
115 }
116
117 func (pder *CookieProvider) SessionExist(sid string) bool {
118 return true
119 }
120
121 func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
122 return nil, nil
123 }
124
125 func (pder *CookieProvider) SessionDestroy(sid string) error {
126 return nil
127 }
128
129 func (pder *CookieProvider) SessionGC() {
130 return
131 }
132
133 func (pder *CookieProvider) SessionAll() int {
134 return 0
135 }
136
137 func (pder *CookieProvider) SessionUpdate(sid string) error {
138 return nil
139 }
140
141 func init() {
142 Register("cookie", cookiepder)
143 }
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
5 "fmt" 5 "fmt"
6 "io" 6 "io"
7 "io/ioutil" 7 "io/ioutil"
8 "net/http"
8 "os" 9 "os"
9 "path" 10 "path"
10 "path/filepath" 11 "path/filepath"
...@@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string { ...@@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string {
60 return fs.sid 61 return fs.sid
61 } 62 }
62 63
63 func (fs *FileSessionStore) SessionRelease() { 64 func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
64 defer fs.f.Close() 65 defer fs.f.Close()
65 b, err := encodeGob(fs.values) 66 b, err := encodeGob(fs.values)
66 if err != nil { 67 if err != nil {
......
1 package session
2
3 import (
4 "bytes"
5 "encoding/gob"
6 )
7
8 func init() {
9 gob.Register([]interface{}{})
10 gob.Register(map[int]interface{}{})
11 gob.Register(map[string]interface{}{})
12 gob.Register(map[interface{}]interface{}{})
13 gob.Register(map[string]string{})
14 gob.Register(map[int]string{})
15 gob.Register(map[int]int{})
16 gob.Register(map[int]int64{})
17 }
18
19 func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
20 buf := bytes.NewBuffer(nil)
21 enc := gob.NewEncoder(buf)
22 err := enc.Encode(obj)
23 if err != nil {
24 return []byte(""), err
25 }
26 return buf.Bytes(), nil
27 }
28
29 func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
30 buf := bytes.NewBuffer(encoded)
31 dec := gob.NewDecoder(buf)
32 var out map[interface{}]interface{}
33 err := dec.Decode(&out)
34 if err != nil {
35 return nil, err
36 }
37 return out, nil
38 }
...@@ -2,6 +2,7 @@ package session ...@@ -2,6 +2,7 @@ package session
2 2
3 import ( 3 import (
4 "container/list" 4 "container/list"
5 "net/http"
5 "sync" 6 "sync"
6 "time" 7 "time"
7 ) 8 )
...@@ -9,9 +10,9 @@ import ( ...@@ -9,9 +10,9 @@ import (
9 var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)} 10 var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
10 11
11 type MemSessionStore struct { 12 type MemSessionStore struct {
12 sid string //session id唯一标示 13 sid string //session id
13 timeAccessed time.Time //最后访问时间 14 timeAccessed time.Time //last access time
14 value map[interface{}]interface{} //session里面存储的值 15 value map[interface{}]interface{} //session store
15 lock sync.RWMutex 16 lock sync.RWMutex
16 } 17 }
17 18
...@@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string { ...@@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string {
51 return st.sid 52 return st.sid
52 } 53 }
53 54
54 func (st *MemSessionStore) SessionRelease() { 55 func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
55
56 } 56 }
57 57
58 type MemProvider struct { 58 type MemProvider struct {
......
1 package session
2
3 import (
4 "net/http"
5 "net/http/httptest"
6 "strings"
7 "testing"
8 )
9
10 func TestMem(t *testing.T) {
11 globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
12 go globalSessions.GC()
13 r, _ := http.NewRequest("GET", "/", nil)
14 w := httptest.NewRecorder()
15 sess := globalSessions.SessionStart(w, r)
16 defer sess.SessionRelease(w)
17 err := sess.Set("username", "astaxie")
18 if err != nil {
19 t.Fatal("set error,", err)
20 }
21 if username := sess.Get("username"); username != "astaxie" {
22 t.Fatal("get username error")
23 }
24 if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
25 t.Fatal("setcookie error")
26 } else {
27 parts := strings.Split(strings.TrimSpace(cookiestr), ";")
28 for k, v := range parts {
29 nameval := strings.Split(v, "=")
30 if k == 0 && nameval[0] != "gosessionid" {
31 t.Fatal("error")
32 }
33 }
34 }
35 }
...@@ -9,6 +9,7 @@ package session ...@@ -9,6 +9,7 @@ package session
9 9
10 import ( 10 import (
11 "database/sql" 11 "database/sql"
12 "net/http"
12 "sync" 13 "sync"
13 "time" 14 "time"
14 15
...@@ -60,7 +61,7 @@ func (st *MysqlSessionStore) SessionID() string { ...@@ -60,7 +61,7 @@ func (st *MysqlSessionStore) SessionID() string {
60 return st.sid 61 return st.sid
61 } 62 }
62 63
63 func (st *MysqlSessionStore) SessionRelease() { 64 func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
64 defer st.c.Close() 65 defer st.c.Close()
65 if len(st.values) > 0 { 66 if len(st.values) > 0 {
66 b, err := encodeGob(st.values) 67 b, err := encodeGob(st.values)
......
1 package session 1 package session
2 2
3 import ( 3 import (
4 "net/http"
4 "strconv" 5 "strconv"
5 "strings" 6 "strings"
6 "sync" 7 "sync"
...@@ -58,7 +59,7 @@ func (rs *RedisSessionStore) SessionID() string { ...@@ -58,7 +59,7 @@ func (rs *RedisSessionStore) SessionID() string {
58 return rs.sid 59 return rs.sid
59 } 60 }
60 61
61 func (rs *RedisSessionStore) SessionRelease() { 62 func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
62 defer rs.c.Close() 63 defer rs.c.Close()
63 if len(rs.values) > 0 { 64 if len(rs.values) > 0 {
64 b, err := encodeGob(rs.values) 65 b, err := encodeGob(rs.values)
......
1 package session 1 package session
2 2
3 import ( 3 import (
4 "crypto/aes"
5 "encoding/json"
4 "testing" 6 "testing"
5 ) 7 )
6 8
...@@ -26,3 +28,82 @@ func Test_gob(t *testing.T) { ...@@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
26 t.Error("decode int error") 28 t.Error("decode int error")
27 } 29 }
28 } 30 }
31
32 func TestGenerate(t *testing.T) {
33 str := generateRandomKey(20)
34 if len(str) != 20 {
35 t.Fatal("generate length is not equal to 20")
36 }
37 }
38
39 func TestCookieEncodeDecode(t *testing.T) {
40 hashKey := "testhashKey"
41 blockkey := generateRandomKey(16)
42 block, err := aes.NewCipher(blockkey)
43 if err != nil {
44 t.Fatal("NewCipher:", err)
45 }
46 securityName := string(generateRandomKey(20))
47 val := make(map[interface{}]interface{})
48 val["name"] = "astaxie"
49 val["gender"] = "male"
50 str, err := encodeCookie(block, hashKey, securityName, val)
51 if err != nil {
52 t.Fatal("encodeCookie:", err)
53 }
54 dst := make(map[interface{}]interface{})
55 dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
56 if err != nil {
57 t.Fatal("decodeCookie", err)
58 }
59 if dst["name"] != "astaxie" {
60 t.Fatal("dst get map error")
61 }
62 if dst["gender"] != "male" {
63 t.Fatal("dst get map error")
64 }
65 }
66
67 func TestParseConfig(t *testing.T) {
68 s := `{"cookieName":"gosessionid","gclifetime":3600}`
69 cf := new(managerConfig)
70 cf.EnableSetCookie = true
71 err := json.Unmarshal([]byte(s), cf)
72 if err != nil {
73 t.Fatal("parse json error,", err)
74 }
75 if cf.CookieName != "gosessionid" {
76 t.Fatal("parseconfig get cookiename error")
77 }
78 if cf.Gclifetime != 3600 {
79 t.Fatal("parseconfig get gclifetime error")
80 }
81
82 cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
83 cf2 := new(managerConfig)
84 cf2.EnableSetCookie = true
85 err = json.Unmarshal([]byte(cc), cf2)
86 if err != nil {
87 t.Fatal("parse json error,", err)
88 }
89 if cf2.CookieName != "gosessionid" {
90 t.Fatal("parseconfig get cookiename error")
91 }
92 if cf2.Gclifetime != 3600 {
93 t.Fatal("parseconfig get gclifetime error")
94 }
95 if cf2.EnableSetCookie != false {
96 t.Fatal("parseconfig get enableSetCookie error")
97 }
98 cconfig := new(cookieConfig)
99 err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
100 if err != nil {
101 t.Fatal("parse ProviderConfig err,", err)
102 }
103 if cconfig.CookieName != "gosessionid" {
104 t.Fatal("ProviderConfig get cookieName error")
105 }
106 if cconfig.SecurityKey != "beegocookiehashkey" {
107 t.Fatal("ProviderConfig get securityKey error")
108 }
109 }
......
1 package session
2
3 import (
4 "bytes"
5 "crypto/cipher"
6 "crypto/hmac"
7 "crypto/rand"
8 "crypto/sha1"
9 "crypto/subtle"
10 "encoding/base64"
11 "encoding/gob"
12 "errors"
13 "fmt"
14 "io"
15 "strconv"
16 "time"
17 )
18
19 func init() {
20 gob.Register([]interface{}{})
21 gob.Register(map[int]interface{}{})
22 gob.Register(map[string]interface{}{})
23 gob.Register(map[interface{}]interface{}{})
24 gob.Register(map[string]string{})
25 gob.Register(map[int]string{})
26 gob.Register(map[int]int{})
27 gob.Register(map[int]int64{})
28 }
29
30 func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
31 buf := bytes.NewBuffer(nil)
32 enc := gob.NewEncoder(buf)
33 err := enc.Encode(obj)
34 if err != nil {
35 return []byte(""), err
36 }
37 return buf.Bytes(), nil
38 }
39
40 func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
41 buf := bytes.NewBuffer(encoded)
42 dec := gob.NewDecoder(buf)
43 var out map[interface{}]interface{}
44 err := dec.Decode(&out)
45 if err != nil {
46 return nil, err
47 }
48 return out, nil
49 }
50
51 // generateRandomKey creates a random key with the given strength.
52 func generateRandomKey(strength int) []byte {
53 k := make([]byte, strength)
54 if _, err := io.ReadFull(rand.Reader, k); err != nil {
55 return nil
56 }
57 return k
58 }
59
60 // Encryption -----------------------------------------------------------------
61
62 // encrypt encrypts a value using the given block in counter mode.
63 //
64 // A random initialization vector (http://goo.gl/zF67k) with the length of the
65 // block size is prepended to the resulting ciphertext.
66 func encrypt(block cipher.Block, value []byte) ([]byte, error) {
67 iv := generateRandomKey(block.BlockSize())
68 if iv == nil {
69 return nil, errors.New("encrypt: failed to generate random iv")
70 }
71 // Encrypt it.
72 stream := cipher.NewCTR(block, iv)
73 stream.XORKeyStream(value, value)
74 // Return iv + ciphertext.
75 return append(iv, value...), nil
76 }
77
78 // decrypt decrypts a value using the given block in counter mode.
79 //
80 // The value to be decrypted must be prepended by a initialization vector
81 // (http://goo.gl/zF67k) with the length of the block size.
82 func decrypt(block cipher.Block, value []byte) ([]byte, error) {
83 size := block.BlockSize()
84 if len(value) > size {
85 // Extract iv.
86 iv := value[:size]
87 // Extract ciphertext.
88 value = value[size:]
89 // Decrypt it.
90 stream := cipher.NewCTR(block, iv)
91 stream.XORKeyStream(value, value)
92 return value, nil
93 }
94 return nil, errors.New("decrypt: the value could not be decrypted")
95 }
96
97 func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
98 var err error
99 var b []byte
100 // 1. encodeGob.
101 if b, err = encodeGob(value); err != nil {
102 return "", err
103 }
104 // 2. Encrypt (optional).
105 if b, err = encrypt(block, b); err != nil {
106 return "", err
107 }
108 b = encode(b)
109 // 3. Create MAC for "name|date|value". Extra pipe to be used later.
110 b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
111 h := hmac.New(sha1.New, []byte(hashKey))
112 h.Write(b)
113 sig := h.Sum(nil)
114 // Append mac, remove name.
115 b = append(b, sig...)[len(name)+1:]
116 // 4. Encode to base64.
117 b = encode(b)
118 // Done.
119 return string(b), nil
120 }
121
122 func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
123 // 1. Decode from base64.
124 b, err := decode([]byte(value))
125 if err != nil {
126 return nil, err
127 }
128 // 2. Verify MAC. Value is "date|value|mac".
129 parts := bytes.SplitN(b, []byte("|"), 3)
130 if len(parts) != 3 {
131 return nil, errors.New("Decode: invalid value %v")
132 }
133
134 b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
135 h := hmac.New(sha1.New, []byte(hashKey))
136 h.Write(b)
137 sig := h.Sum(nil)
138 if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
139 return nil, errors.New("Decode: the value is not valid")
140 }
141 // 3. Verify date ranges.
142 var t1 int64
143 if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
144 return nil, errors.New("Decode: invalid timestamp")
145 }
146 t2 := time.Now().UTC().Unix()
147 if t1 > t2 {
148 return nil, errors.New("Decode: timestamp is too new")
149 }
150 if t1 < t2-gcmaxlifetime {
151 return nil, errors.New("Decode: expired timestamp")
152 }
153 // 4. Decrypt (optional).
154 b, err = decode(parts[1])
155 if err != nil {
156 return nil, err
157 }
158 if b, err = decrypt(block, b); err != nil {
159 return nil, err
160 }
161 // 5. decodeGob.
162 if dst, err := decodeGob(b); err != nil {
163 return nil, err
164 } else {
165 return dst, nil
166 }
167 // Done.
168 return nil, nil
169 }
170
171 // Encoding -------------------------------------------------------------------
172
173 // encode encodes a value using base64.
174 func encode(value []byte) []byte {
175 encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
176 base64.URLEncoding.Encode(encoded, value)
177 return encoded
178 }
179
180 // decode decodes a cookie using base64.
181 func decode(value []byte) ([]byte, error) {
182 decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
183 b, err := base64.URLEncoding.Decode(decoded, value)
184 if err != nil {
185 return nil, err
186 }
187 return decoded[:b], nil
188 }
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
6 "crypto/rand" 6 "crypto/rand"
7 "crypto/sha1" 7 "crypto/sha1"
8 "encoding/hex" 8 "encoding/hex"
9 "encoding/json"
9 "fmt" 10 "fmt"
10 "io" 11 "io"
11 "net/http" 12 "net/http"
...@@ -18,12 +19,12 @@ type SessionStore interface { ...@@ -18,12 +19,12 @@ type SessionStore interface {
18 Get(key interface{}) interface{} //get session value 19 Get(key interface{}) interface{} //get session value
19 Delete(key interface{}) error //delete session value 20 Delete(key interface{}) error //delete session value
20 SessionID() string //back current sessionID 21 SessionID() string //back current sessionID
21 SessionRelease() // release the resource & save data to provider 22 SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
22 Flush() error //delete all data 23 Flush() error //delete all data
23 } 24 }
24 25
25 type Provider interface { 26 type Provider interface {
26 SessionInit(maxlifetime int64, savePath string) error 27 SessionInit(gclifetime int64, config string) error
27 SessionRead(sid string) (SessionStore, error) 28 SessionRead(sid string) (SessionStore, error)
28 SessionExist(sid string) bool 29 SessionExist(sid string) bool
29 SessionRegenerate(oldsid, sid string) (SessionStore, error) 30 SessionRegenerate(oldsid, sid string) (SessionStore, error)
...@@ -47,15 +48,21 @@ func Register(name string, provide Provider) { ...@@ -47,15 +48,21 @@ func Register(name string, provide Provider) {
47 provides[name] = provide 48 provides[name] = provide
48 } 49 }
49 50
51 type managerConfig struct {
52 CookieName string `json:"cookieName"`
53 EnableSetCookie bool `json:"enableSetCookie,omitempty"`
54 Gclifetime int64 `json:"gclifetime"`
55 Maxage int `json:"maxage"`
56 Secure bool `json:"secure"`
57 SessionIDHashFunc string `json:"sessionIDHashFunc"`
58 SessionIDHashKey string `json:"sessionIDHashKey"`
59 CookieLifeTime int64 `json:"cookieLifeTime"`
60 ProviderConfig string `json:"providerConfig"`
61 }
62
50 type Manager struct { 63 type Manager struct {
51 cookieName string //private cookiename
52 provider Provider 64 provider Provider
53 maxlifetime int64 65 config *managerConfig
54 hashfunc string //support md5 & sha1
55 hashkey string
56 maxage int //cookielifetime
57 secure bool
58 options []interface{}
59 } 66 }
60 67
61 //options 68 //options
...@@ -63,74 +70,49 @@ type Manager struct { ...@@ -63,74 +70,49 @@ type Manager struct {
63 //2. hashfunc default sha1 70 //2. hashfunc default sha1
64 //3. hashkey default beegosessionkey 71 //3. hashkey default beegosessionkey
65 //4. maxage default is none 72 //4. maxage default is none
66 func NewManager(provideName, cookieName string, maxlifetime int64, savePath string, options ...interface{}) (*Manager, error) { 73 func NewManager(provideName, config string) (*Manager, error) {
67 provider, ok := provides[provideName] 74 provider, ok := provides[provideName]
68 if !ok { 75 if !ok {
69 return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName) 76 return nil, fmt.Errorf("session: unknown provide %q (forgotten import?)", provideName)
70 } 77 }
71 provider.SessionInit(maxlifetime, savePath) 78 cf := new(managerConfig)
72 secure := false 79 cf.EnableSetCookie = true
73 if len(options) > 0 { 80 err := json.Unmarshal([]byte(config), cf)
74 secure = options[0].(bool) 81 if err != nil {
75 } 82 return nil, err
76 hashfunc := "sha1"
77 if len(options) > 1 {
78 hashfunc = options[1].(string)
79 }
80 hashkey := "beegosessionkey"
81 if len(options) > 2 {
82 hashkey = options[2].(string)
83 }
84 maxage := -1
85 if len(options) > 3 {
86 switch options[3].(type) {
87 case int:
88 if options[3].(int) > 0 {
89 maxage = options[3].(int)
90 } else if options[3].(int) < 0 {
91 maxage = 0
92 }
93 case int64:
94 if options[3].(int64) > 0 {
95 maxage = int(options[3].(int64))
96 } else if options[3].(int64) < 0 {
97 maxage = 0
98 }
99 case int32:
100 if options[3].(int32) > 0 {
101 maxage = int(options[3].(int32))
102 } else if options[3].(int32) < 0 {
103 maxage = 0
104 } 83 }
84 provider.SessionInit(cf.Gclifetime, cf.ProviderConfig)
85
86 if cf.SessionIDHashFunc == "" {
87 cf.SessionIDHashFunc = "sha1"
105 } 88 }
89 if cf.SessionIDHashKey == "" {
90 cf.SessionIDHashKey = string(generateRandomKey(16))
106 } 91 }
92
107 return &Manager{ 93 return &Manager{
108 provider: provider, 94 provider,
109 cookieName: cookieName, 95 cf,
110 maxlifetime: maxlifetime,
111 hashfunc: hashfunc,
112 hashkey: hashkey,
113 maxage: maxage,
114 secure: secure,
115 options: options,
116 }, nil 96 }, nil
117 } 97 }
118 98
119 //get Session 99 //get Session
120 func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) { 100 func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (session SessionStore) {
121 cookie, err := r.Cookie(manager.cookieName) 101 cookie, err := r.Cookie(manager.config.CookieName)
122 if err != nil || cookie.Value == "" { 102 if err != nil || cookie.Value == "" {
123 sid := manager.sessionId(r) 103 sid := manager.sessionId(r)
124 session, _ = manager.provider.SessionRead(sid) 104 session, _ = manager.provider.SessionRead(sid)
125 cookie = &http.Cookie{Name: manager.cookieName, 105 cookie = &http.Cookie{Name: manager.config.CookieName,
126 Value: url.QueryEscape(sid), 106 Value: url.QueryEscape(sid),
127 Path: "/", 107 Path: "/",
128 HttpOnly: true, 108 HttpOnly: true,
129 Secure: manager.secure} 109 Secure: manager.config.Secure}
130 if manager.maxage >= 0 { 110 if manager.config.Maxage >= 0 {
131 cookie.MaxAge = manager.maxage 111 cookie.MaxAge = manager.config.Maxage
132 } 112 }
113 if manager.config.EnableSetCookie {
133 http.SetCookie(w, cookie) 114 http.SetCookie(w, cookie)
115 }
134 r.AddCookie(cookie) 116 r.AddCookie(cookie)
135 } else { 117 } else {
136 sid, _ := url.QueryUnescape(cookie.Value) 118 sid, _ := url.QueryUnescape(cookie.Value)
...@@ -139,15 +121,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se ...@@ -139,15 +121,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
139 } else { 121 } else {
140 sid = manager.sessionId(r) 122 sid = manager.sessionId(r)
141 session, _ = manager.provider.SessionRead(sid) 123 session, _ = manager.provider.SessionRead(sid)
142 cookie = &http.Cookie{Name: manager.cookieName, 124 cookie = &http.Cookie{Name: manager.config.CookieName,
143 Value: url.QueryEscape(sid), 125 Value: url.QueryEscape(sid),
144 Path: "/", 126 Path: "/",
145 HttpOnly: true, 127 HttpOnly: true,
146 Secure: manager.secure} 128 Secure: manager.config.Secure}
147 if manager.maxage >= 0 { 129 if manager.config.Maxage >= 0 {
148 cookie.MaxAge = manager.maxage 130 cookie.MaxAge = manager.config.Maxage
149 } 131 }
132 if manager.config.EnableSetCookie {
150 http.SetCookie(w, cookie) 133 http.SetCookie(w, cookie)
134 }
151 r.AddCookie(cookie) 135 r.AddCookie(cookie)
152 } 136 }
153 } 137 }
...@@ -156,13 +140,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se ...@@ -156,13 +140,17 @@ func (manager *Manager) SessionStart(w http.ResponseWriter, r *http.Request) (se
156 140
157 //Destroy sessionid 141 //Destroy sessionid
158 func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) { 142 func (manager *Manager) SessionDestroy(w http.ResponseWriter, r *http.Request) {
159 cookie, err := r.Cookie(manager.cookieName) 143 cookie, err := r.Cookie(manager.config.CookieName)
160 if err != nil || cookie.Value == "" { 144 if err != nil || cookie.Value == "" {
161 return 145 return
162 } else { 146 } else {
163 manager.provider.SessionDestroy(cookie.Value) 147 manager.provider.SessionDestroy(cookie.Value)
164 expiration := time.Now() 148 expiration := time.Now()
165 cookie := http.Cookie{Name: manager.cookieName, Path: "/", HttpOnly: true, Expires: expiration, MaxAge: -1} 149 cookie := http.Cookie{Name: manager.config.CookieName,
150 Path: "/",
151 HttpOnly: true,
152 Expires: expiration,
153 MaxAge: -1}
166 http.SetCookie(w, &cookie) 154 http.SetCookie(w, &cookie)
167 } 155 }
168 } 156 }
...@@ -174,20 +162,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro ...@@ -174,20 +162,20 @@ func (manager *Manager) GetProvider(sid string) (sessions SessionStore, err erro
174 162
175 func (manager *Manager) GC() { 163 func (manager *Manager) GC() {
176 manager.provider.SessionGC() 164 manager.provider.SessionGC()
177 time.AfterFunc(time.Duration(manager.maxlifetime)*time.Second, func() { manager.GC() }) 165 time.AfterFunc(time.Duration(manager.config.Gclifetime)*time.Second, func() { manager.GC() })
178 } 166 }
179 167
180 func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) { 168 func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Request) (session SessionStore) {
181 sid := manager.sessionId(r) 169 sid := manager.sessionId(r)
182 cookie, err := r.Cookie(manager.cookieName) 170 cookie, err := r.Cookie(manager.config.CookieName)
183 if err != nil && cookie.Value == "" { 171 if err != nil && cookie.Value == "" {
184 //delete old cookie 172 //delete old cookie
185 session, _ = manager.provider.SessionRead(sid) 173 session, _ = manager.provider.SessionRead(sid)
186 cookie = &http.Cookie{Name: manager.cookieName, 174 cookie = &http.Cookie{Name: manager.config.CookieName,
187 Value: url.QueryEscape(sid), 175 Value: url.QueryEscape(sid),
188 Path: "/", 176 Path: "/",
189 HttpOnly: true, 177 HttpOnly: true,
190 Secure: manager.secure, 178 Secure: manager.config.Secure,
191 } 179 }
192 } else { 180 } else {
193 oldsid, _ := url.QueryUnescape(cookie.Value) 181 oldsid, _ := url.QueryUnescape(cookie.Value)
...@@ -196,8 +184,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque ...@@ -196,8 +184,8 @@ func (manager *Manager) SessionRegenerateId(w http.ResponseWriter, r *http.Reque
196 cookie.HttpOnly = true 184 cookie.HttpOnly = true
197 cookie.Path = "/" 185 cookie.Path = "/"
198 } 186 }
199 if manager.maxage >= 0 { 187 if manager.config.Maxage >= 0 {
200 cookie.MaxAge = manager.maxage 188 cookie.MaxAge = manager.config.Maxage
201 } 189 }
202 http.SetCookie(w, cookie) 190 http.SetCookie(w, cookie)
203 r.AddCookie(cookie) 191 r.AddCookie(cookie)
...@@ -209,12 +197,12 @@ func (manager *Manager) GetActiveSession() int { ...@@ -209,12 +197,12 @@ func (manager *Manager) GetActiveSession() int {
209 } 197 }
210 198
211 func (manager *Manager) SetHashFunc(hasfunc, hashkey string) { 199 func (manager *Manager) SetHashFunc(hasfunc, hashkey string) {
212 manager.hashfunc = hasfunc 200 manager.config.SessionIDHashFunc = hasfunc
213 manager.hashkey = hashkey 201 manager.config.SessionIDHashKey = hashkey
214 } 202 }
215 203
216 func (manager *Manager) SetSecure(secure bool) { 204 func (manager *Manager) SetSecure(secure bool) {
217 manager.secure = secure 205 manager.config.Secure = secure
218 } 206 }
219 207
220 //remote_addr cruunixnano randdata 208 //remote_addr cruunixnano randdata
...@@ -224,16 +212,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) { ...@@ -224,16 +212,16 @@ func (manager *Manager) sessionId(r *http.Request) (sid string) {
224 return "" 212 return ""
225 } 213 }
226 sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs) 214 sig := fmt.Sprintf("%s%d%s", r.RemoteAddr, time.Now().UnixNano(), bs)
227 if manager.hashfunc == "md5" { 215 if manager.config.SessionIDHashFunc == "md5" {
228 h := md5.New() 216 h := md5.New()
229 h.Write([]byte(sig)) 217 h.Write([]byte(sig))
230 sid = hex.EncodeToString(h.Sum(nil)) 218 sid = hex.EncodeToString(h.Sum(nil))
231 } else if manager.hashfunc == "sha1" { 219 } else if manager.config.SessionIDHashFunc == "sha1" {
232 h := hmac.New(sha1.New, []byte(manager.hashkey)) 220 h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
233 fmt.Fprintf(h, "%s", sig) 221 fmt.Fprintf(h, "%s", sig)
234 sid = hex.EncodeToString(h.Sum(nil)) 222 sid = hex.EncodeToString(h.Sum(nil))
235 } else { 223 } else {
236 h := hmac.New(sha1.New, []byte(manager.hashkey)) 224 h := hmac.New(sha1.New, []byte(manager.config.SessionIDHashKey))
237 fmt.Fprintf(h, "%s", sig) 225 fmt.Fprintf(h, "%s", sig)
238 sid = hex.EncodeToString(h.Sum(nil)) 226 sid = hex.EncodeToString(h.Sum(nil))
239 } 227 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!