6b5108ef by 傅小黑

Merge pull request #1 from fuxiaohei/develop

merge develop
2 parents ecfd11ad 828a3060
......@@ -35,9 +35,5 @@ More info [beego.me](http://beego.me)
beego is licensed under the Apache Licence, Version 2.0
(http://www.apache.org/licenses/LICENSE-2.0.html).
## Use case
- Displaying API documentation: [gowalker](https://github.com/Unknwon/gowalker)
- seocms: [seocms](https://github.com/chinakr/seocms)
- CMS: [toropress](https://github.com/insionng/toropress)
[![Clone in Koding](http://learn.koding.com/btn/clone_d.png)][koding]
[koding]: https://koding.com/Teamwork?import=https://github.com/astaxie/beego/archive/master.zip&c=git1
\ No newline at end of file
......
......@@ -118,6 +118,14 @@ func (app *App) AutoRouter(c ControllerInterface) *App {
return app
}
// AutoRouterWithPrefix adds beego-defined controller handler with prefix.
// if beego.AutoPrefix("/admin",&MainContorlller{}) and MainController has methods List and Page,
// visit the url /admin/main/list to exec List function or /admin/main/page to exec Page function.
func (app *App) AutoRouterWithPrefix(prefix string, c ControllerInterface) *App {
app.Handlers.AddAutoPrefix(prefix, c)
return app
}
// UrlFor creates a url with another registered controller handler with params.
// The endpoint is formed as path.controller.name to defined the controller method which will run.
// The values need key-pair data to assign into controller method.
......
......@@ -4,6 +4,7 @@ import (
"net/http"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/astaxie/beego/middleware"
......@@ -13,6 +14,76 @@ import (
// beego web framework version.
const VERSION = "1.0.1"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
type groupRouter struct {
pattern string
controller ControllerInterface
mappingMethods string
}
// RouterGroups which will store routers
type GroupRouters []groupRouter
// Get a new GroupRouters
func NewGroupRouters() GroupRouters {
return make([]groupRouter, 0)
}
// Add Router in the GroupRouters
// it is for plugin or module to register router
func (gr GroupRouters) AddRouter(pattern string, c ControllerInterface, mappingMethod ...string) {
var newRG groupRouter
if len(mappingMethod) > 0 {
newRG = groupRouter{
pattern,
c,
mappingMethod[0],
}
} else {
newRG = groupRouter{
pattern,
c,
"",
}
}
gr = append(gr, newRG)
}
func (gr GroupRouters) AddAuto(c ControllerInterface) {
newRG := groupRouter{
"",
c,
"",
}
gr = append(gr, newRG)
}
// AddGroupRouter with the prefix
// it will register the router in BeeApp
// the follow code is write in modules:
// GR:=NewGroupRouters()
// GR.AddRouter("/login",&UserController,"get:Login")
// GR.AddRouter("/logout",&UserController,"get:Logout")
// GR.AddRouter("/register",&UserController,"get:Reg")
// the follow code is write in app:
// import "github.com/beego/modules/auth"
// AddRouterGroup("/admin", auth.GR)
func AddGroupRouter(prefix string, groups GroupRouters) *App {
for _, v := range groups {
if v.pattern == "" {
BeeApp.AutoRouterWithPrefix(prefix, v.controller)
} else if v.mappingMethods != "" {
BeeApp.Router(prefix+v.pattern, v.controller, v.mappingMethods)
} else {
BeeApp.Router(prefix+v.pattern, v.controller)
}
}
return BeeApp
}
// Router adds a patterned controller handler to BeeApp.
// it's an alias method of App.Router.
func Router(rootpath string, c ControllerInterface, mappingMethods ...string) *App {
......@@ -36,6 +107,13 @@ func AutoRouter(c ControllerInterface) *App {
return BeeApp
}
// AutoPrefix adds controller handler to BeeApp with prefix.
// it's same to App.AutoRouterWithPrefix.
func AutoPrefix(prefix string, c ControllerInterface) *App {
BeeApp.AutoRouterWithPrefix(prefix, c)
return BeeApp
}
// ErrorHandler registers http.HandlerFunc to each http err code string.
// usage:
// beego.ErrorHandler("404",NotFound)
......@@ -87,6 +165,12 @@ func InsertFilter(pattern string, pos int, filter FilterFunc) *App {
return BeeApp
}
// The hookfunc will run in beego.Run()
// such as sessionInit, middlerware start, buildtemplate, admin start
func AddAPPStartHook(hf hookfunc) {
hooks = append(hooks, hf)
}
// Run beego application.
// it's alias of App.Run.
func Run() {
......@@ -99,18 +183,32 @@ func Run() {
}
}
//init mime
initMime()
// do hooks function
for _, hk := range hooks {
err := hk()
if err != nil {
panic(err)
}
}
if SessionOn {
GlobalSessions, _ = session.NewManager(SessionProvider,
SessionName,
SessionGCMaxLifetime,
SessionSavePath,
HttpTLS,
SessionHashFunc,
SessionHashKey,
SessionCookieLifeTime)
var err error
sessionConfig := AppConfig.String("sessionConfig")
if sessionConfig == "" {
sessionConfig = `{"cookieName":"` + SessionName + `",` +
`"gclifetime":` + strconv.FormatInt(SessionGCMaxLifetime, 10) + `,` +
`"providerConfig":"` + SessionSavePath + `",` +
`"secure":` + strconv.FormatBool(HttpTLS) + `,` +
`"sessionIDHashFunc":"` + SessionHashFunc + `",` +
`"sessionIDHashKey":"` + SessionHashKey + `",` +
`"enableSetCookie":` + strconv.FormatBool(SessionAutoSetCookie) + `,` +
`"cookieLifeTime":` + strconv.Itoa(SessionCookieLifeTime) + `}`
}
GlobalSessions, err = session.NewManager(SessionProvider,
sessionConfig)
if err != nil {
panic(err)
}
go GlobalSessions.GC()
}
......@@ -123,7 +221,7 @@ func Run() {
middleware.VERSION = VERSION
middleware.AppName = AppName
middleware.RegisterErrorHander()
middleware.RegisterErrorHandler()
if EnableAdmin {
go BeeAdminApp.Run()
......@@ -131,3 +229,9 @@ func Run() {
BeeApp.Run()
}
func init() {
hooks = make([]hookfunc, 0)
//init mime
AddAPPStartHook(initMime)
}
......
......@@ -43,7 +43,7 @@ interval means the gc time. The cache will check at each time interval, whether
## Memcache adapter
memory adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Memcache adapter use the vitess's [Memcache](http://code.google.com/p/vitess/go/memcache) client.
Configure like this:
......
......@@ -5,7 +5,7 @@ import (
"time"
)
func Test_cache(t *testing.T) {
func TestCache(t *testing.T) {
bm, err := NewCache("memory", `{"interval":20}`)
if err != nil {
t.Error("init err")
......@@ -51,3 +51,51 @@ func Test_cache(t *testing.T) {
t.Error("delete err")
}
}
func TestFileCache(t *testing.T) {
bm, err := NewCache("file", `{"CachePath":"/cache","FileSuffix":".bin","DirectoryLevel":2,"EmbedExpiry":0}`)
if err != nil {
t.Error("init err")
}
if err = bm.Put("astaxie", 1, 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
if err = bm.Incr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 2 {
t.Error("get err")
}
if err = bm.Decr("astaxie"); err != nil {
t.Error("Incr Error", err)
}
if v := bm.Get("astaxie"); v.(int) != 1 {
t.Error("get err")
}
bm.Delete("astaxie")
if bm.IsExist("astaxie") {
t.Error("delete err")
}
//test string
if err = bm.Put("astaxie", "author", 10); err != nil {
t.Error("set Error", err)
}
if !bm.IsExist("astaxie") {
t.Error("check err")
}
if v := bm.Get("astaxie"); v.(string) != "author" {
t.Error("get err")
}
}
......
......@@ -61,6 +61,7 @@ func (this *FileCache) StartAndGC(config string) error {
var cfg map[string]string
json.Unmarshal([]byte(config), &cfg)
//fmt.Println(cfg)
//fmt.Println(config)
if _, ok := cfg["CachePath"]; !ok {
cfg["CachePath"] = FileCachePath
}
......@@ -135,7 +136,7 @@ func (this *FileCache) Get(key string) interface{} {
return ""
}
var to FileCacheItem
Gob_decode([]byte(filedata), &to)
Gob_decode(filedata, &to)
if to.Expired < time.Now().Unix() {
return ""
}
......@@ -177,7 +178,7 @@ func (this *FileCache) Delete(key string) error {
func (this *FileCache) Incr(key string) error {
data := this.Get(key)
var incr int
fmt.Println(reflect.TypeOf(data).Name())
//fmt.Println(reflect.TypeOf(data).Name())
if reflect.TypeOf(data).Name() != "int" {
incr = 0
} else {
......@@ -210,8 +211,7 @@ func (this *FileCache) IsExist(key string) bool {
// Clean cached files.
// not implemented.
func (this *FileCache) ClearAll() error {
//this.CachePath .递归删除
//this.CachePath
return nil
}
......@@ -271,7 +271,7 @@ func Gob_encode(data interface{}) ([]byte, error) {
}
// Gob decodes file cache item.
func Gob_decode(data []byte, to interface{}) error {
func Gob_decode(data []byte, to *FileCacheItem) error {
buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf)
return dec.Decode(&to)
......
......@@ -21,7 +21,11 @@ func NewMemCache() *MemcacheCache {
// get value from memcache.
func (rc *MemcacheCache) Get(key string) interface{} {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
v, err := rc.c.Get(key)
if err != nil {
......@@ -39,7 +43,11 @@ func (rc *MemcacheCache) Get(key string) interface{} {
// put value to memcache. only support string.
func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
v, ok := val.(string)
if !ok {
......@@ -55,7 +63,11 @@ func (rc *MemcacheCache) Put(key string, val interface{}, timeout int64) error {
// delete value in memcache.
func (rc *MemcacheCache) Delete(key string) error {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Delete(key)
return err
......@@ -76,7 +88,11 @@ func (rc *MemcacheCache) Decr(key string) error {
// check value exists in memcache.
func (rc *MemcacheCache) IsExist(key string) bool {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := rc.c.Get(key)
if err != nil {
......@@ -93,7 +109,11 @@ func (rc *MemcacheCache) IsExist(key string) bool {
// clear all cached in memcache.
func (rc *MemcacheCache) ClearAll() error {
if rc.c == nil {
rc.c = rc.connectInit()
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
err := rc.c.FlushAll()
return err
......@@ -109,20 +129,21 @@ func (rc *MemcacheCache) StartAndGC(config string) error {
return errors.New("config has no conn key")
}
rc.conninfo = cf["conn"]
rc.c = rc.connectInit()
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return errors.New("dial tcp conn error")
}
return nil
}
// connect to memcache and keep the connection.
func (rc *MemcacheCache) connectInit() *memcache.Connection {
func (rc *MemcacheCache) connectInit() (*memcache.Connection, error) {
c, err := memcache.Connect(rc.conninfo)
if err != nil {
return nil
return nil, err
}
return c
return c, nil
}
func init() {
......
......@@ -3,6 +3,7 @@ package cache
import (
"encoding/json"
"errors"
"time"
"github.com/beego/redigo/redis"
)
......@@ -14,7 +15,7 @@ var (
// Redis cache adapter.
type RedisCache struct {
c redis.Conn
p *redis.Pool // redis connection pool
conninfo string
key string
}
......@@ -24,107 +25,62 @@ func NewRedisCache() *RedisCache {
return &RedisCache{key: DefaultKey}
}
// actually do the redis cmds
func (rc *RedisCache) do(commandName string, args ...interface{}) (reply interface{}, err error) {
c := rc.p.Get()
defer c.Close()
return c.Do(commandName, args...)
}
// Get cache from redis.
func (rc *RedisCache) Get(key string) interface{} {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return nil
}
}
v, err := rc.c.Do("HGET", rc.key, key)
v, err := rc.do("HGET", rc.key, key)
if err != nil {
return nil
}
return v
}
// put cache to redis.
// timeout is ignored.
func (rc *RedisCache) Put(key string, val interface{}, timeout int64) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HSET", rc.key, key, val)
_, err := rc.do("HSET", rc.key, key, val)
return err
}
// delete cache in redis.
func (rc *RedisCache) Delete(key string) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("HDEL", rc.key, key)
_, err := rc.do("HDEL", rc.key, key)
return err
}
// check cache exist in redis.
func (rc *RedisCache) IsExist(key string) bool {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return false
}
}
v, err := redis.Bool(rc.c.Do("HEXISTS", rc.key, key))
v, err := redis.Bool(rc.do("HEXISTS", rc.key, key))
if err != nil {
return false
}
return v
}
// increase counter in redis.
func (rc *RedisCache) Incr(key string) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, 1))
if err != nil {
return err
}
return nil
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, 1))
return err
}
// decrease counter in redis.
func (rc *RedisCache) Decr(key string) error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := redis.Bool(rc.c.Do("HINCRBY", rc.key, key, -1))
if err != nil {
return err
}
return nil
_, err := redis.Bool(rc.do("HINCRBY", rc.key, key, -1))
return err
}
// clean all cache in redis. delete this redis collection.
func (rc *RedisCache) ClearAll() error {
if rc.c == nil {
var err error
rc.c, err = rc.connectInit()
if err != nil {
return err
}
}
_, err := rc.c.Do("DEL", rc.key)
_, err := rc.do("DEL", rc.key)
return err
}
......@@ -135,32 +91,42 @@ func (rc *RedisCache) ClearAll() error {
func (rc *RedisCache) StartAndGC(config string) error {
var cf map[string]string
json.Unmarshal([]byte(config), &cf)
if _, ok := cf["key"]; !ok {
cf["key"] = DefaultKey
}
if _, ok := cf["conn"]; !ok {
return errors.New("config has no conn key")
}
rc.key = cf["key"]
rc.conninfo = cf["conn"]
var err error
rc.c, err = rc.connectInit()
if err != nil {
rc.connectInit()
c := rc.p.Get()
defer c.Close()
if err := c.Err(); err != nil {
return err
}
if rc.c == nil {
return errors.New("dial tcp conn error")
}
return nil
}
// connect to redis.
func (rc *RedisCache) connectInit() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo)
if err != nil {
return nil, err
func (rc *RedisCache) connectInit() {
// initialize a new pool
rc.p = &redis.Pool{
MaxIdle: 3,
IdleTimeout: 180 * time.Second,
Dial: func() (redis.Conn, error) {
c, err := redis.Dial("tcp", rc.conninfo)
if err != nil {
return nil, err
}
return c, nil
},
}
return c, nil
}
func init() {
......
......@@ -40,6 +40,7 @@ var (
SessionHashFunc string // session hash generation func.
SessionHashKey string // session hash salt string.
SessionCookieLifeTime int // the life time of session id in cookie.
SessionAutoSetCookie bool // auto setcookie
UseFcgi bool
MaxMemory int64
EnableGzip bool // flag of enable gzip
......@@ -96,6 +97,7 @@ func init() {
SessionHashFunc = "sha1"
SessionHashKey = "beegoserversessionkey"
SessionCookieLifeTime = 0 //set cookie default is the brower life
SessionAutoSetCookie = true
UseFcgi = false
......@@ -139,6 +141,7 @@ func init() {
func ParseConfig() (err error) {
AppConfig, err = config.NewConfig("ini", AppConfigPath)
if err != nil {
AppConfig = config.NewFakeConfig()
return err
} else {
HttpAddr = AppConfig.String("HttpAddr")
......
......@@ -6,8 +6,9 @@ import (
// ConfigContainer defines how to get and set value from configuration raw data.
type ConfigContainer interface {
Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Set(key, val string) error // support section::key type in given key when using ini type.
String(key string) string // support section::key type in key string when using ini and json type; Int,Int64,Bool,Float,DIY are same.
Strings(key string) []string //get string slice
Int(key string) (int, error)
Int64(key string) (int64, error)
Bool(key string) (bool, error)
......
package config
import (
"errors"
"strconv"
"strings"
)
type fakeConfigContainer struct {
data map[string]string
}
func (c *fakeConfigContainer) getData(key string) string {
key = strings.ToLower(key)
return c.data[key]
}
func (c *fakeConfigContainer) Set(key, val string) error {
key = strings.ToLower(key)
c.data[key] = val
return nil
}
func (c *fakeConfigContainer) String(key string) string {
return c.getData(key)
}
func (c *fakeConfigContainer) Strings(key string) []string {
return strings.Split(c.getData(key), ";")
}
func (c *fakeConfigContainer) Int(key string) (int, error) {
return strconv.Atoi(c.getData(key))
}
func (c *fakeConfigContainer) Int64(key string) (int64, error) {
return strconv.ParseInt(c.getData(key), 10, 64)
}
func (c *fakeConfigContainer) Bool(key string) (bool, error) {
return strconv.ParseBool(c.getData(key))
}
func (c *fakeConfigContainer) Float(key string) (float64, error) {
return strconv.ParseFloat(c.getData(key), 64)
}
func (c *fakeConfigContainer) DIY(key string) (interface{}, error) {
key = strings.ToLower(key)
if v, ok := c.data[key]; ok {
return v, nil
}
return nil, errors.New("key not find")
}
var _ ConfigContainer = new(fakeConfigContainer)
func NewFakeConfig() ConfigContainer {
return &fakeConfigContainer{
data: make(map[string]string),
}
}
......@@ -146,6 +146,11 @@ func (c *IniConfigContainer) String(key string) string {
return c.getdata(key)
}
// Strings returns the []string value for a given key.
func (c *IniConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
// if write to one section, the key need be "section::key".
// if the section is not existed, it panics.
......
......@@ -19,6 +19,7 @@ copyrequestbody = true
key1="asta"
key2 = "xie"
CaseInsensitive = true
peers = one;two;three
`
func TestIni(t *testing.T) {
......@@ -78,4 +79,11 @@ func TestIni(t *testing.T) {
if v, err := iniconf.Bool("demo::caseinsensitive"); err != nil || v != true {
t.Fatal("get demo.caseinsensitive error")
}
if data := iniconf.Strings("demo::peers"); len(data) != 3 {
t.Fatal("get strings error", data)
} else if data[0] != "one" {
t.Fatal("get first params error not equat to one")
}
}
......
......@@ -116,6 +116,11 @@ func (c *JsonConfigContainer) String(key string) string {
return ""
}
// Strings returns the []string value for a given key.
func (c *JsonConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *JsonConfigContainer) Set(key, val string) error {
c.Lock()
......
......@@ -5,6 +5,7 @@ import (
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"github.com/beego/x2j"
......@@ -72,6 +73,11 @@ func (c *XMLConfigContainer) String(key string) string {
return ""
}
// Strings returns the []string value for a given key.
func (c *XMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *XMLConfigContainer) Set(key, val string) error {
c.Lock()
......
......@@ -7,6 +7,7 @@ import (
"io/ioutil"
"log"
"os"
"strings"
"sync"
"github.com/beego/goyaml2"
......@@ -117,6 +118,11 @@ func (c *YAMLConfigContainer) String(key string) string {
return ""
}
// Strings returns the []string value for a given key.
func (c *YAMLConfigContainer) Strings(key string) []string {
return strings.Split(c.String(key), ";")
}
// WriteValue writes a new value for key.
func (c *YAMLConfigContainer) Set(key, val string) error {
c.Lock()
......
......@@ -3,7 +3,6 @@ package beego
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
......@@ -22,6 +21,7 @@ import (
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
var (
......@@ -140,7 +140,7 @@ func (c *Controller) RenderString() (string, error) {
return string(b), e
}
// RenderBytes returns the bytes of renderd tempate string. Do not send out response.
// RenderBytes returns the bytes of rendered template string. Do not send out response.
func (c *Controller) RenderBytes() ([]byte, error) {
//if the controller has set layout, then first get the tplname's content set the content to the layout
if c.Layout != "" {
......@@ -165,7 +165,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
if c.LayoutSections != nil {
for sectionName, sectionTpl := range c.LayoutSections {
if (sectionTpl == "") {
if sectionTpl == "" {
c.Data[sectionName] = ""
continue
}
......@@ -391,6 +391,7 @@ func (c *Controller) DelSession(name interface{}) {
// SessionRegenerateID regenerates session id for this session.
// the session data have no changes.
func (c *Controller) SessionRegenerateID() {
c.CruSession.SessionRelease(c.Ctx.ResponseWriter)
c.CruSession = GlobalSessions.SessionRegenerateId(c.Ctx.ResponseWriter, c.Ctx.Request)
c.Ctx.Input.CruSession = c.CruSession
}
......@@ -454,7 +455,7 @@ func (c *Controller) XsrfToken() string {
} else {
expire = int64(XSRFExpire)
}
token = getRandomString(15)
token = string(utils.RandomCreateBytes(15))
c.SetSecureCookie(XSRFKEY, "_xsrf", token, expire)
}
c._xsrf_token = token
......@@ -491,14 +492,3 @@ func (c *Controller) XsrfFormHtml() string {
func (c *Controller) GetControllerAndAction() (controllerName, actionName string) {
return c.controllerName, c.actionName
}
// getRandomString returns random string.
func getRandomString(n int) string {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
for i, b := range bytes {
bytes[i] = alphanum[b%byte(len(alphanum))]
}
return string(bytes)
}
......
package controllers
import (
"github.com/astaxie/beego"
"github.com/garyburd/go-websocket/websocket"
"io/ioutil"
"math/rand"
"net/http"
"time"
"github.com/astaxie/beego"
"github.com/gorilla/websocket"
)
const (
......
......@@ -28,6 +28,12 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
if router == mr.pattern {
return true, nil
}
//pattern /admin router /admin/ match
//pattern /admin/ router /admin don't match, because url will 301 in router
if n := len(router); n > 1 && router[n-1] == '/' && router[:n-2] == mr.pattern {
return true, nil
}
if mr.hasregex {
if !mr.regex.MatchString(router) {
return false, nil
......@@ -46,7 +52,7 @@ func (mr *FilterRouter) ValidRouter(router string) (bool, map[string]string) {
return false, nil
}
func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
func buildFilter(pattern string, filter FilterFunc) (*FilterRouter, error) {
mr := new(FilterRouter)
mr.params = make(map[int]string)
mr.filterFunc = filter
......@@ -54,7 +60,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j := 0
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
expr := "(.*)"
//a user may choose to override the default expression
// similar to expressjs: ‘/user/:id([0-9]+)’
if index := strings.Index(part, "("); index != -1 {
......@@ -77,7 +83,7 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
j++
}
if strings.HasPrefix(part, "*") {
expr := "(.+)"
expr := "(.*)"
if part == "*.*" {
mr.params[j] = ":path"
parts[i] = "([^.]+).([^.]+)"
......@@ -137,12 +143,11 @@ func buildFilter(pattern string, filter FilterFunc) *FilterRouter {
pattern = strings.Join(parts, "/")
regex, regexErr := regexp.Compile(pattern)
if regexErr != nil {
//TODO add error handling here to avoid panic
panic(regexErr)
return nil, regexErr
}
mr.regex = regex
mr.hasregex = true
}
mr.pattern = pattern
return mr
return mr, nil
}
......
......@@ -23,3 +23,32 @@ func TestFilter(t *testing.T) {
t.Errorf("user define func can't run")
}
}
var FilterAdminUser = func(ctx *context.Context) {
ctx.Output.Body([]byte("i am admin"))
}
// Filter pattern /admin/:all
// all url like /admin/ /admin/xie will all get filter
func TestPatternTwo(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/ can't run")
}
}
func TestPatternThree(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/astaxie", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddFilter("/admin/:all", "AfterStatic", FilterAdminUser)
handler.ServeHTTP(w, r)
if w.Body.String() != "i am admin" {
t.Errorf("filter /admin/astaxie can't run")
}
}
......
......@@ -7,6 +7,8 @@ import (
"net"
)
// ConnWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type ConnWriter struct {
lg *log.Logger
innerWriter io.WriteCloser
......@@ -17,12 +19,15 @@ type ConnWriter struct {
Level int `json:"level"`
}
// create new ConnWrite returning as LoggerInterface.
func NewConn() LoggerInterface {
conn := new(ConnWriter)
conn.Level = LevelTrace
return conn
}
// init connection writer with json config.
// json config only need key "level".
func (c *ConnWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil {
......@@ -31,6 +36,8 @@ func (c *ConnWriter) Init(jsonconfig string) error {
return nil
}
// write message in connection.
// if connection is down, try to re-connect.
func (c *ConnWriter) WriteMsg(msg string, level int) error {
if level < c.Level {
return nil
......@@ -49,10 +56,12 @@ func (c *ConnWriter) WriteMsg(msg string, level int) error {
return nil
}
// implementing method. empty.
func (c *ConnWriter) Flush() {
}
// destroy connection writer and close tcp listener.
func (c *ConnWriter) Destroy() {
if c.innerWriter == nil {
return
......
......@@ -6,11 +6,13 @@ import (
"os"
)
// ConsoleWriter implements LoggerInterface and writes messages to terminal.
type ConsoleWriter struct {
lg *log.Logger
Level int `json:"level"`
}
// create ConsoleWriter returning as LoggerInterface.
func NewConsole() LoggerInterface {
cw := new(ConsoleWriter)
cw.lg = log.New(os.Stdout, "", log.Ldate|log.Ltime)
......@@ -18,6 +20,8 @@ func NewConsole() LoggerInterface {
return cw
}
// init console logger.
// jsonconfig like '{"level":LevelTrace}'.
func (c *ConsoleWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), c)
if err != nil {
......@@ -26,6 +30,7 @@ func (c *ConsoleWriter) Init(jsonconfig string) error {
return nil
}
// write message in console.
func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
if level < c.Level {
return nil
......@@ -34,10 +39,12 @@ func (c *ConsoleWriter) WriteMsg(msg string, level int) error {
return nil
}
// implementing method. empty.
func (c *ConsoleWriter) Destroy() {
}
// implementing method. empty.
func (c *ConsoleWriter) Flush() {
}
......
......@@ -13,6 +13,8 @@ import (
"time"
)
// FileLogWriter implements LoggerInterface.
// It writes messages by lines limit, file size limit, or time frequency.
type FileLogWriter struct {
*log.Logger
mw *MuxWriter
......@@ -38,17 +40,20 @@ type FileLogWriter struct {
Level int `json:"level"`
}
// an *os.File writer with locker.
type MuxWriter struct {
sync.Mutex
fd *os.File
}
// write to os.File.
func (l *MuxWriter) Write(b []byte) (int, error) {
l.Lock()
defer l.Unlock()
return l.fd.Write(b)
}
// set os.File in writer.
func (l *MuxWriter) SetFd(fd *os.File) {
if l.fd != nil {
l.fd.Close()
......@@ -56,6 +61,7 @@ func (l *MuxWriter) SetFd(fd *os.File) {
l.fd = fd
}
// create a FileLogWriter returning as LoggerInterface.
func NewFileWriter() LoggerInterface {
w := &FileLogWriter{
Filename: "",
......@@ -73,15 +79,16 @@ func NewFileWriter() LoggerInterface {
return w
}
// jsonconfig like this
//{
// Init file logger with json config.
// jsonconfig like:
// {
// "filename":"logs/beego.log",
// "maxlines":10000,
// "maxsize":1<<30,
// "daily":true,
// "maxdays":15,
// "rotate":true
//}
// }
func (w *FileLogWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), w)
if err != nil {
......@@ -94,6 +101,7 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
return err
}
// start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) StartLogger() error {
fd, err := w.createLogFile()
if err != nil {
......@@ -122,6 +130,7 @@ func (w *FileLogWriter) docheck(size int) {
w.maxsize_cursize += size
}
// write logger message into file.
func (w *FileLogWriter) WriteMsg(msg string, level int) error {
if level < w.Level {
return nil
......@@ -158,6 +167,8 @@ func (w *FileLogWriter) initFd() error {
return nil
}
// DoRotate means it need to write file in new file.
// new file name like xx.log.2013-01-01.2
func (w *FileLogWriter) DoRotate() error {
_, err := os.Lstat(w.Filename)
if err == nil { // file exists
......@@ -211,10 +222,14 @@ func (w *FileLogWriter) deleteOldLog() {
})
}
// destroy file logger, close file writer.
func (w *FileLogWriter) Destroy() {
w.mw.fd.Close()
}
// flush file logger.
// there are no buffering messages in file logger in memory.
// flush file means sync file from disk.
func (w *FileLogWriter) Flush() {
w.mw.fd.Sync()
}
......
......@@ -6,6 +6,7 @@ import (
)
const (
// log message levels
LevelTrace = iota
LevelDebug
LevelInfo
......@@ -16,6 +17,7 @@ const (
type loggerType func() LoggerInterface
// LoggerInterface defines the behavior of a log provider.
type LoggerInterface interface {
Init(config string) error
WriteMsg(msg string, level int) error
......@@ -38,6 +40,8 @@ func Register(name string, log loggerType) {
adapters[name] = log
}
// BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers.
type BeeLogger struct {
lock sync.Mutex
level int
......@@ -50,7 +54,9 @@ type logMsg struct {
msg string
}
// config need to be correct JSON as string: {"interval":360}
// NewLogger returns a new BeeLogger.
// channellen means the number of messages in chan.
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger)
bl.msg = make(chan *logMsg, channellen)
......@@ -60,6 +66,8 @@ func NewLogger(channellen int64) *BeeLogger {
return bl
}
// SetLogger provides a given logger adapter into BeeLogger with config string.
// config need to be correct JSON as string: {"interval":360}.
func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
......@@ -73,6 +81,7 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
}
}
// remove a logger adapter in BeeLogger.
func (bl *BeeLogger) DelLogger(adaptername string) error {
bl.lock.Lock()
defer bl.lock.Unlock()
......@@ -96,10 +105,14 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
return nil
}
// set log message level.
// if message level (such as LevelTrace) is less than logger level (such as LevelWarn), ignore message.
func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
// start logger chan reading.
// when chan is full, write logs.
func (bl *BeeLogger) StartLogger() {
for {
select {
......@@ -111,43 +124,50 @@ func (bl *BeeLogger) StartLogger() {
}
}
// log trace level message.
func (bl *BeeLogger) Trace(format string, v ...interface{}) {
msg := fmt.Sprintf("[T] "+format, v...)
bl.writerMsg(LevelTrace, msg)
}
// log debug level message.
func (bl *BeeLogger) Debug(format string, v ...interface{}) {
msg := fmt.Sprintf("[D] "+format, v...)
bl.writerMsg(LevelDebug, msg)
}
// log info level message.
func (bl *BeeLogger) Info(format string, v ...interface{}) {
msg := fmt.Sprintf("[I] "+format, v...)
bl.writerMsg(LevelInfo, msg)
}
// log warn level message.
func (bl *BeeLogger) Warn(format string, v ...interface{}) {
msg := fmt.Sprintf("[W] "+format, v...)
bl.writerMsg(LevelWarn, msg)
}
// log error level message.
func (bl *BeeLogger) Error(format string, v ...interface{}) {
msg := fmt.Sprintf("[E] "+format, v...)
bl.writerMsg(LevelError, msg)
}
// log critical level message.
func (bl *BeeLogger) Critical(format string, v ...interface{}) {
msg := fmt.Sprintf("[C] "+format, v...)
bl.writerMsg(LevelCritical, msg)
}
//flush all chan data
// flush all chan data.
func (bl *BeeLogger) Flush() {
for _, l := range bl.outputs {
l.Flush()
}
}
// close logger, flush all chan data and destroy all adapters in BeeLogger.
func (bl *BeeLogger) Close() {
for {
if len(bl.msg) > 0 {
......
......@@ -12,7 +12,7 @@ const (
subjectPhrase = "Diagnostic message from server"
)
// smtpWriter is used to send emails via given SMTP-server.
// smtpWriter implements LoggerInterface and is used to send emails via given SMTP-server.
type SmtpWriter struct {
Username string `json:"Username"`
Password string `json:"password"`
......@@ -22,10 +22,21 @@ type SmtpWriter struct {
Level int `json:"level"`
}
// create smtp writer.
func NewSmtpWriter() LoggerInterface {
return &SmtpWriter{Level: LevelTrace}
}
// init smtp writer with json config.
// config like:
// {
// "Username":"example@gmail.com",
// "password:"password",
// "host":"smtp.gmail.com:465",
// "subject":"email title",
// "sendTos":["email1","email2"],
// "level":LevelError
// }
func (s *SmtpWriter) Init(jsonconfig string) error {
err := json.Unmarshal([]byte(jsonconfig), s)
if err != nil {
......@@ -34,6 +45,8 @@ func (s *SmtpWriter) Init(jsonconfig string) error {
return nil
}
// write message in smtp writer.
// it will send an email with subject and only this message.
func (s *SmtpWriter) WriteMsg(msg string, level int) error {
if level < s.Level {
return nil
......@@ -65,9 +78,12 @@ func (s *SmtpWriter) WriteMsg(msg string, level int) error {
return err
}
// implementing method. empty.
func (s *SmtpWriter) Flush() {
return
}
// implementing method. empty.
func (s *SmtpWriter) Destroy() {
return
}
......
......@@ -5,16 +5,17 @@ import (
"compress/flate"
"compress/gzip"
"errors"
//"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"sync"
"time"
)
var gmfim map[string]*MemFileInfo = make(map[string]*MemFileInfo)
var lock sync.RWMutex
// OpenMemZipFile returns MemFile object with a compressed static file.
// it's used for serve static file if gzip enable.
......@@ -32,12 +33,12 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
modtime := osfileinfo.ModTime()
fileSize := osfileinfo.Size()
lock.RLock()
cfi, ok := gmfim[zip+":"+path]
lock.RUnlock()
if ok && cfi.ModTime() == modtime && cfi.fileSize == fileSize {
//fmt.Printf("read %s file %s from cache\n", zip, path)
} else {
//fmt.Printf("NOT read %s file %s from cache\n", zip, path)
var content []byte
if zip == "gzip" {
//将文件内容压缩到zipbuf中
......@@ -81,8 +82,9 @@ func OpenMemZipFile(path string, zip string) (*MemFile, error) {
}
cfi = &MemFileInfo{osfileinfo, modtime, content, int64(len(content)), fileSize}
lock.Lock()
defer lock.Unlock()
gmfim[zip+":"+path] = cfi
//fmt.Printf("%s file %s to %d, cache it\n", zip, path, len(content))
}
return &MemFile{fi: cfi, offset: 0}, nil
}
......
......@@ -61,6 +61,7 @@ var tpl = `
</html>
`
// render default application error page with error and stack string.
func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack string) {
t, _ := template.New("beegoerrortemp").Parse(tpl)
data := make(map[string]string)
......@@ -71,6 +72,7 @@ func ShowErr(err interface{}, rw http.ResponseWriter, r *http.Request, Stack str
data["Stack"] = Stack
data["BeegoVersion"] = VERSION
data["GoVersion"] = runtime.Version()
rw.WriteHeader(500)
t.Execute(rw, data)
}
......@@ -174,18 +176,19 @@ var errtpl = `
</html>
`
// map of http handlers for each error string.
var ErrorMaps map[string]http.HandlerFunc
func init() {
ErrorMaps = make(map[string]http.HandlerFunc)
}
//404
// show 404 notfound error.
func NotFound(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Page Not Found"
data["Content"] = template.HTML("<br>The Page You have requested flown the coop." +
data["Content"] = template.HTML("<br>The page you have requested has flown the coop." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>The page has moved" +
......@@ -198,28 +201,28 @@ func NotFound(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data)
}
//401
// show 401 unauthorized error.
func Unauthorized(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Unauthorized"
data["Content"] = template.HTML("<br>The Page You have requested can't authorized." +
data["Content"] = template.HTML("<br>The page you have requested can't be authorized." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>Check the credentials that you supplied" +
"<br>Check the address for errors" +
"<br>The credentials you supplied are incorrect" +
"<br>There are errors in the website address" +
"</ul>")
data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusUnauthorized)
t.Execute(rw, data)
}
//403
// show 403 forbidden error.
func Forbidden(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Forbidden"
data["Content"] = template.HTML("<br>The Page You have requested forbidden." +
data["Content"] = template.HTML("<br>The page you have requested is forbidden." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br>Your address may be blocked" +
......@@ -231,12 +234,12 @@ func Forbidden(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data)
}
//503
// show 503 service unavailable error.
func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Service Unavailable"
data["Content"] = template.HTML("<br>The Page You have requested unavailable." +
data["Content"] = template.HTML("<br>The page you have requested is unavailable." +
"<br>Perhaps you are here because:" +
"<br><br><ul>" +
"<br><br>The page is overloaded" +
......@@ -247,30 +250,32 @@ func ServiceUnavailable(rw http.ResponseWriter, r *http.Request) {
t.Execute(rw, data)
}
//500
// show 500 internal server error.
func InternalServerError(rw http.ResponseWriter, r *http.Request) {
t, _ := template.New("beegoerrortemp").Parse(errtpl)
data := make(map[string]interface{})
data["Title"] = "Internal Server Error"
data["Content"] = template.HTML("<br>The Page You have requested has down now." +
data["Content"] = template.HTML("<br>The page you have requested is down right now." +
"<br><br><ul>" +
"<br>simply try again later" +
"<br>you should report the fault to the website administrator" +
"</ul>")
"<br>Please try again later and report the error to the website administrator" +
"<br></ul>")
data["BeegoVersion"] = VERSION
//rw.WriteHeader(http.StatusInternalServerError)
t.Execute(rw, data)
}
// show 500 internal error with simple text string.
func SimpleServerError(rw http.ResponseWriter, r *http.Request) {
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
// add http handler for given error string.
func Errorhandler(err string, h http.HandlerFunc) {
ErrorMaps[err] = h
}
func RegisterErrorHander() {
// register default error http handlers, 404,401,403,500 and 503.
func RegisterErrorHandler() {
if _, ok := ErrorMaps["404"]; !ok {
ErrorMaps["404"] = NotFound
}
......@@ -292,6 +297,8 @@ func RegisterErrorHander() {
}
}
// show error string as simple text message.
// if error string is empty, show 500 error as default.
func Exception(errcode string, w http.ResponseWriter, r *http.Request, msg string) {
if h, ok := ErrorMaps[errcode]; ok {
isint, err := strconv.Atoi(errcode)
......
......@@ -2,16 +2,19 @@ package middleware
import "fmt"
// http exceptions
type HTTPException struct {
StatusCode int // http status code 4xx, 5xx
Description string
}
// return http exception error string, e.g. "400 Bad Request".
func (e *HTTPException) Error() string {
// return `status description`, e.g. `400 Bad Request`
return fmt.Sprintf("%d %s", e.StatusCode, e.Description)
}
// map of http exceptions for each http status code int.
// defined 400,401,403,404,405,500,502,503 and 504 default.
var HTTPExceptionMaps map[int]HTTPException
func init() {
......
......@@ -544,8 +544,9 @@ var mimemaps map[string]string = map[string]string{
".mustache": "text/html",
}
func initMime() {
func initMime() error {
for k, v := range mimemaps {
mime.AddExtensionType(k, v)
}
return nil
}
......
......@@ -16,6 +16,7 @@ var (
commands = make(map[string]commander)
)
// print help.
func printHelp(errs ...string) {
content := `orm command usage:
......@@ -31,6 +32,7 @@ func printHelp(errs ...string) {
os.Exit(2)
}
// listen for orm command and then run it if command arguments passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
......@@ -58,6 +60,7 @@ func RunCommand() {
}
}
// sync database struct command interface.
type commandSyncDb struct {
al *alias
force bool
......@@ -66,6 +69,7 @@ type commandSyncDb struct {
rtOnError bool
}
// parse orm command line arguments.
func (d *commandSyncDb) Parse(args []string) {
var name string
......@@ -78,6 +82,7 @@ func (d *commandSyncDb) Parse(args []string) {
d.al = getDbAlias(name)
}
// run orm line command.
func (d *commandSyncDb) Run() error {
var drops []string
if d.force {
......@@ -208,10 +213,12 @@ func (d *commandSyncDb) Run() error {
return nil
}
// database creation commander interface implement.
type commandSqlAll struct {
al *alias
}
// parse orm command line arguments.
func (d *commandSqlAll) Parse(args []string) {
var name string
......@@ -222,6 +229,7 @@ func (d *commandSqlAll) Parse(args []string) {
d.al = getDbAlias(name)
}
// run orm line command.
func (d *commandSqlAll) Run() error {
sqls, indexes := getDbCreateSql(d.al)
var all []string
......@@ -243,6 +251,10 @@ func init() {
commands["sqlall"] = new(commandSqlAll)
}
// run syncdb command line.
// name means table's alias name. default is "default".
// force means run next sql if the current is error.
// verbose means show all info when running command or not.
func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap()
......
......@@ -12,6 +12,7 @@ type dbIndex struct {
Sql string
}
// create database drop sql.
func getDbDropSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
......@@ -26,6 +27,7 @@ func getDbDropSql(al *alias) (sqls []string) {
return sqls
}
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
......@@ -79,6 +81,7 @@ checkColumn:
return
}
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
......@@ -90,6 +93,7 @@ func getColumnAddQuery(al *alias, fi *fieldInfo) string {
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s", Q, fi.mi.table, Q, Q, fi.column, Q, typ)
}
// create database creation string.
func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
......
......@@ -9,27 +9,32 @@ import (
"time"
)
// database driver constant int.
type DriverType int
const (
_ DriverType = iota
DR_MySQL
DR_Sqlite
DR_Oracle
DR_Postgres
_ DriverType = iota // int enum type
DR_MySQL // mysql
DR_Sqlite // sqlite
DR_Oracle // oracle
DR_Postgres // pgsql
)
// database driver string.
type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
// get name of current driver
func (d driver) Name() string {
return string(d)
}
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver)
var (
......@@ -47,11 +52,13 @@ var (
}
)
// database alias cacher.
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
......@@ -62,6 +69,7 @@ func (ac *_dbCache) add(name string, al *alias) (added bool) {
return
}
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
......@@ -69,6 +77,7 @@ func (ac *_dbCache) get(name string) (al *alias, ok bool) {
return
}
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
......@@ -123,21 +132,18 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
switch al.Driver {
case DR_MySQL:
row := al.DB.QueryRow("SELECT @@session.time_zone")
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
if tz == "SYSTEM" {
tz = ""
row = al.DB.QueryRow("SELECT @@system_time_zone")
row.Scan(&tz)
t, err := time.Parse("MST", tz)
if err == nil {
al.TZ = t.Location()
if len(tz) >= 8 {
if tz[0] != '-' {
tz = "+" + tz
}
} else {
t, err := time.Parse("-07:00", tz)
t, err := time.Parse("-07:00:00", tz)
if err == nil {
al.TZ = t.Location()
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
......@@ -163,6 +169,8 @@ func RegisterDataBase(aliasName, driverName, dataSource string, params ...int) {
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
......
......@@ -4,6 +4,7 @@ import (
"fmt"
)
// mysql operators.
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
......@@ -21,6 +22,7 @@ var mysqlOperators = map[string]string{
"iendswith": "LIKE ?",
}
// mysql column field types.
var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
......@@ -41,29 +43,35 @@ var mysqlTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)",
}
// mysql dbBaser implementation.
type dbBaseMysql struct {
dbBase
}
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool {
row := db.QueryRow("SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
......@@ -72,6 +80,7 @@ func (d *dbBaseMysql) IndexExists(db dbQuerier, table string, name string) bool
return cnt > 0
}
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b
......
package orm
// oracle dbBaser
type dbBaseOracle struct {
dbBase
}
var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b
......
......@@ -5,6 +5,7 @@ import (
"strconv"
)
// postgresql operators.
var postgresOperators = map[string]string{
"exact": "= ?",
"iexact": "= UPPER(?)",
......@@ -20,6 +21,7 @@ var postgresOperators = map[string]string{
"iendswith": "LIKE UPPER(?)",
}
// postgresql column field types.
var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
......@@ -40,16 +42,19 @@ var postgresTypes = map[string]string{
"float64-decimal": "numeric(%d, %d)",
}
// postgresql dbBaser.
type dbBasePostgres struct {
dbBase
}
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator]
}
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
......@@ -59,6 +64,7 @@ func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string,
}
}
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false
}
......@@ -67,10 +73,13 @@ func (d *dbBasePostgres) MaxLimit() uint64 {
return 0
}
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query
num := 0
......@@ -97,6 +106,7 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
*query = string(data)
}
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
if mi.fields.pk.auto {
if query != nil {
......@@ -107,18 +117,22 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return
}
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes
}
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRow(query)
......@@ -127,6 +141,7 @@ func (d *dbBasePostgres) IndexExists(db dbQuerier, table string, name string) bo
return cnt > 0
}
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
......
......@@ -5,6 +5,7 @@ import (
"fmt"
)
// sqlite operators.
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
......@@ -20,6 +21,7 @@ var sqliteOperators = map[string]string{
"iendswith": "LIKE ? ESCAPE '\\'",
}
// sqlite column types.
var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY",
......@@ -40,38 +42,47 @@ var sqliteTypes = map[string]string{
"float64-decimal": "decimal",
}
// sqlite dbBaser.
type dbBaseSqlite struct {
dbBase
}
var _ dbBaser = new(dbBaseSqlite)
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator]
}
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes
}
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.Query(query)
......@@ -92,10 +103,12 @@ func (d *dbBaseSqlite) GetColumns(db dbQuerier, table string) (map[string][3]str
return columns, nil
}
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.Query(query)
......@@ -113,6 +126,7 @@ func (d *dbBaseSqlite) IndexExists(db dbQuerier, table string, name string) bool
return false
}
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
......
......@@ -6,6 +6,7 @@ import (
"time"
)
// table info struct.
type dbTable struct {
id int
index string
......@@ -18,13 +19,17 @@ type dbTable struct {
jtl *dbTable
}
// tables collection struct, contains some tables.
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
skipEnd bool
}
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
......@@ -41,6 +46,7 @@ func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name]
}
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false {
......@@ -53,11 +59,14 @@ func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool)
return t.tablesM[name], false
}
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
......@@ -78,6 +87,7 @@ func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []
return related
}
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
......@@ -111,7 +121,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
names = append(names, fi.name)
mmi = fi.relModelInfo
if fi.null {
if fi.null || t.skipEnd {
inner = false
}
......@@ -139,6 +149,7 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
}
}
// generate join string.
func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote()
......@@ -185,9 +196,12 @@ func (t *dbTables) getJoinSql() (join string) {
return
}
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi
)
......@@ -196,9 +210,22 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
inner := true
loopFor:
for i, ex := range exprs {
fi, ok := mmi.fields.GetByAny(ex)
var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok {
......@@ -216,44 +243,61 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
mmi = fi.reverseFieldInfo.mi
}
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (fi.mi.isThrough == false || num != i) {
if fi.null {
if fi.null || t.skipEnd {
inner = false
}
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
goto loopEnd
}
if num == i {
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
info = fi
}
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
if num != i {
continue
}
switch {
case fi.rel:
loopEnd:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
}
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
}
}
break loopFor
} else {
index = ""
name = ""
......@@ -267,6 +311,7 @@ func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return
}
// generate condition sql.
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
......@@ -331,6 +376,7 @@ func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return
}
// generate order sql.
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
......@@ -359,6 +405,7 @@ func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
return
}
// generate limit sql.
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
......@@ -381,6 +428,7 @@ func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits
return
}
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
......
......@@ -6,6 +6,7 @@ import (
"time"
)
// get table alias.
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
......@@ -15,6 +16,7 @@ func getDbAlias(name string) *alias {
return nil
}
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
......@@ -37,6 +39,7 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
return
}
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:
......
......@@ -41,6 +41,7 @@ var (
}
)
// model info collection
type _modelCache struct {
sync.RWMutex
orders []string
......@@ -49,6 +50,7 @@ type _modelCache struct {
done bool
}
// get all model info
func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache {
......@@ -57,6 +59,7 @@ func (mc *_modelCache) all() map[string]*modelInfo {
return m
}
// get orderd model info
func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
for _, table := range mc.orders {
......@@ -65,16 +68,19 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
return m
}
// get model info by table name
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table]
return
}
// get model info by field name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name]
return
}
// set model info to collection
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
......@@ -85,6 +91,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
return mii
}
// clean all model info.
func (mc *_modelCache) clean() {
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)
......
......@@ -8,6 +8,8 @@ import (
"strings"
)
// register models.
// prefix means table name prefix.
func registerModel(model interface{}, prefix string) {
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
......@@ -67,6 +69,7 @@ func registerModel(model interface{}, prefix string) {
modelCache.set(table, info)
}
// boostrap models
func bootStrap() {
if modelCache.done {
return
......@@ -281,6 +284,7 @@ end:
}
}
// register models
func RegisterModel(models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
......@@ -302,6 +306,8 @@ func RegisterModelWithPrefix(prefix string, models ...interface{}) {
}
}
// bootrap models.
// make all model parsed and can not add more models
func BootStrap() {
if modelCache.done {
return
......
......@@ -9,6 +9,7 @@ import (
var errSkipField = errors.New("skip field")
// field info collection
type fields struct {
pk *fieldInfo
columns map[string]*fieldInfo
......@@ -23,6 +24,7 @@ type fields struct {
dbcols []string
}
// add field info
func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi
......@@ -49,14 +51,17 @@ func (f *fields) Add(fi *fieldInfo) (added bool) {
return true
}
// get field info by name
func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name]
}
// get field info by column name
func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column]
}
// get field info by string, name is prior
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok {
return fi, ok
......@@ -70,6 +75,7 @@ func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
return nil, false
}
// create new field info collection
func newFields() *fields {
f := new(fields)
f.fields = make(map[string]*fieldInfo)
......@@ -79,6 +85,7 @@ func newFields() *fields {
return f
}
// single field info
type fieldInfo struct {
mi *modelInfo
fieldIndex int
......@@ -115,6 +122,7 @@ type fieldInfo struct {
onDelete string
}
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var (
tag string
......
......@@ -7,6 +7,7 @@ import (
"reflect"
)
// single model info
type modelInfo struct {
pkg string
name string
......@@ -20,6 +21,7 @@ type modelInfo struct {
isThrough bool
}
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) {
var (
err error
......@@ -79,6 +81,8 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
return
}
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo)
info.fields = newFields()
......
......@@ -7,10 +7,12 @@ import (
"time"
)
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
// get table name. method, or field name. auto snaked.
func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val)
fun := val.MethodByName("TableName")
......@@ -26,6 +28,7 @@ func getTableName(val reflect.Value) string {
return snakeString(ind.Type().Name())
}
// get table engine, mysiam or innodb.
func getTableEngine(val reflect.Value) string {
fun := val.MethodByName("TableEngine")
if fun.IsValid() {
......@@ -40,6 +43,7 @@ func getTableEngine(val reflect.Value) string {
return ""
}
// get table index from method.
func getTableIndex(val reflect.Value) [][]string {
fun := val.MethodByName("TableIndex")
if fun.IsValid() {
......@@ -56,6 +60,7 @@ func getTableIndex(val reflect.Value) [][]string {
return nil
}
// get table unique from method
func getTableUnique(val reflect.Value) [][]string {
fun := val.MethodByName("TableUnique")
if fun.IsValid() {
......@@ -72,6 +77,7 @@ func getTableUnique(val reflect.Value) [][]string {
return nil
}
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col)
column := col
......@@ -89,6 +95,7 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
return column
}
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
switch elm.Kind() {
......@@ -128,6 +135,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
return
}
// parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool)
tag := make(map[string]string)
......
......@@ -25,6 +25,7 @@ var (
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> no row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrArgs = errors.New("<Ormer> args error may be empty")
ErrNotImplement = errors.New("have not implement")
)
......@@ -39,11 +40,12 @@ type orm struct {
var _ Ormer = new(orm)
func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
// get model info and model reflect value
func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
val := reflect.ValueOf(md)
ind = reflect.Indirect(val)
typ := ind.Type()
if val.Kind() != reflect.Ptr {
if needPtr && val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
......@@ -53,6 +55,7 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
}
// get field info from model info by given field name
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name)
if !ok {
......@@ -61,8 +64,9 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
return fi
}
// read data to model
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return err
......@@ -70,26 +74,69 @@ func (o *orm) Read(md interface{}, cols ...string) error {
return nil
}
// insert model data to database
func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return id, err
}
if id > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
o.setPk(mi, ind, id)
return id, nil
}
// set auto pk field
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
// insert some models to database
func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
var cnt int64
sind := reflect.Indirect(reflect.ValueOf(mds))
switch sind.Kind() {
case reflect.Array, reflect.Slice:
if sind.Len() == 0 {
return cnt, ErrArgs
}
default:
return cnt, ErrArgs
}
if bulk <= 1 {
for i := 0; i < sind.Len(); i++ {
ind := sind.Index(i)
mi, _ := o.getMiInd(ind.Interface(), false)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
if err != nil {
return cnt, err
}
o.setPk(mi, ind, id)
cnt += 1
}
} else {
mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
}
return id, nil
return cnt, nil
}
// update model to database.
// cols set the columns those want to update.
func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
if err != nil {
return num, err
......@@ -97,26 +144,22 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
return num, nil
}
// delete model in database
func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
if err != nil {
return num, err
}
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
o.setPk(mi, ind, 0)
}
return num, nil
}
// create a models to models queryer
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
switch {
......@@ -129,6 +172,14 @@ func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
return newQueryM2M(md, o, mi, fi, ind)
}
// load related models to md model.
// args are limit, offset int and order string.
//
// example:
// orm.LoadRelated(post,"Tags")
// for _,tag := range post.Tags{...}
//
// make sure the relation is defined in model struct tags.
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name)
......@@ -190,14 +241,21 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
return nums, err
}
// return a QuerySeter for related models to md model.
// it can do all, update, delete in QuerySeter.
// example:
// qs := orm.QueryRelated(post,"Tag")
// qs.All(&[]*Tag{})
//
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ?
_, _, _, qs := o.queryRelated(md, name)
return qs
}
// get QuerySeter for related models to md model
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md)
mi, ind := o.getMiInd(md, true)
fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind)
......@@ -227,6 +285,7 @@ func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo,
return mi, fi, ind, qs
}
// get reverse relation QuerySeter
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelReverseOne, RelReverseMany:
......@@ -247,6 +306,7 @@ func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *queryS
return q
}
// get relation QuerySeter
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
......@@ -266,6 +326,9 @@ func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
return q
}
// return a QuerySeter for table operations.
// table name can be string or struct.
// e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := ""
if table, ok := ptrStructOrTableName.(string); ok {
......@@ -285,6 +348,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
return
}
// switch to another registered database driver by given name.
func (o *orm) Using(name string) error {
if o.isTx {
panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
......@@ -302,6 +366,7 @@ func (o *orm) Using(name string) error {
return nil
}
// begin transaction
func (o *orm) Begin() error {
if o.isTx {
return ErrTxHasBegan
......@@ -320,6 +385,7 @@ func (o *orm) Begin() error {
return nil
}
// commit transaction
func (o *orm) Commit() error {
if o.isTx == false {
return ErrTxDone
......@@ -334,6 +400,7 @@ func (o *orm) Commit() error {
return err
}
// rollback transaction
func (o *orm) Rollback() error {
if o.isTx == false {
return ErrTxDone
......@@ -348,14 +415,17 @@ func (o *orm) Rollback() error {
return err
}
// return a raw query seter for raw sql string.
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args)
}
// return current using database Driver
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
// create new orm
func NewOrm() Ormer {
BootStrap() // execute only once
......
......@@ -18,15 +18,19 @@ type condValue struct {
isCond bool
}
// condition struct.
// work for WHERE conditions.
type Condition struct {
params []condValue
}
// return new condition struct
func NewCondition() *Condition {
c := &Condition{}
return c
}
// add expression to condition
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.And> args cannot empty"))
......@@ -35,6 +39,7 @@ func (c Condition) And(expr string, args ...interface{}) *Condition {
return &c
}
// add NOT expression to condition
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.AndNot> args cannot empty"))
......@@ -43,6 +48,7 @@ func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a condition to current condition
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
......@@ -54,6 +60,7 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
// add OR expression to condition
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.Or> args cannot empty"))
......@@ -62,6 +69,7 @@ func (c Condition) Or(expr string, args ...interface{}) *Condition {
return &c
}
// add OR NOT expression to condition
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic(fmt.Errorf("<Condition.OrNot> args cannot empty"))
......@@ -70,6 +78,7 @@ func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
return &c
}
// combine a OR condition to current condition
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
......@@ -81,10 +90,12 @@ func (c *Condition) OrCond(cond *Condition) *Condition {
return c
}
// check the condition arguments are empty or not.
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
// clone a condition
func (c Condition) clone() *Condition {
return &c
}
......
......@@ -13,6 +13,7 @@ type Log struct {
*log.Logger
}
// set io.Writer to create a Logger.
func NewLog(out io.Writer) *Log {
d := new(Log)
d.Logger = log.New(out, "[ORM]", 1e9)
......@@ -40,6 +41,8 @@ func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error
DebugLog.Println(con)
}
// statement query logger struct.
// if dev mode, use stmtQueryLog, or use stmtQuerier.
type stmtQueryLog struct {
alias *alias
query string
......@@ -84,6 +87,8 @@ func newStmtQueryLog(alias *alias, stmt stmtQuerier, query string) stmtQuerier {
return d
}
// database query logger struct.
// if dev mode, use dbQueryLog, or use dbQuerier.
type dbQueryLog struct {
alias *alias
db dbQuerier
......
......@@ -5,6 +5,7 @@ import (
"reflect"
)
// an insert queryer struct
type insertSet struct {
mi *modelInfo
orm *orm
......@@ -14,6 +15,7 @@ type insertSet struct {
var _ Inserter = new(insertSet)
// insert model ignore it's registered or not.
func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
......@@ -44,6 +46,7 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
return id, nil
}
// close insert queryer statement
func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
......@@ -52,6 +55,7 @@ func (o *insertSet) Close() error {
return o.stmt.Close()
}
// create new insert queryer.
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
......
......@@ -4,6 +4,7 @@ import (
"reflect"
)
// model to model struct
type queryM2M struct {
md interface{}
mi *modelInfo
......@@ -12,6 +13,13 @@ type queryM2M struct {
ind reflect.Value
}
// add models to origin models when creating queryM2M.
// example:
// m2m := orm.QueryM2M(post,"Tag")
// m2m.Add(&Tag1{},&Tag2{})
// for _,tag := range post.Tags{}
//
// make sure the relation is defined in post model struct tag.
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
......@@ -44,7 +52,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
names := []string{mfi.column, rfi.column}
var nums int64
values := make([]interface{}, 0, len(models)*2)
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
......@@ -59,18 +68,14 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
}
}
values := []interface{}{v1, v2}
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
}
values = append(values, v1, v2)
nums += 1
}
return nums, nil
return dbase.InsertValue(orm.db, mi, true, names, values)
}
// remove models following the origin model relationship
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
......@@ -82,17 +87,20 @@ func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
return nums, nil
}
// check model is existed in relationship of origin model
func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
}
// clean all models in related of origin model
func (o *queryM2M) Clear() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
}
// count all related models of origin model
func (o *queryM2M) Count() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
......@@ -100,6 +108,7 @@ func (o *queryM2M) Count() (int64, error) {
var _ QueryM2Mer = new(queryM2M)
// create new M2M queryer.
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M)
qm2m.md = md
......
......@@ -18,6 +18,10 @@ const (
Col_Except
)
// ColValue do the field raw changes. e.g Nums = Nums + 10. usage:
// Params{
// "Nums": ColValue(Col_Add, 10),
// }
func ColValue(opt operator, value interface{}) interface{} {
switch opt {
case Col_Add, Col_Minus, Col_Multiply, Col_Except:
......@@ -34,6 +38,7 @@ func ColValue(opt operator, value interface{}) interface{} {
return val
}
// real query struct
type querySet struct {
mi *modelInfo
cond *Condition
......@@ -47,6 +52,7 @@ type querySet struct {
var _ QuerySeter = new(querySet)
// add condition expression to QuerySeter.
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
......@@ -55,6 +61,7 @@ func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
return &o
}
// add NOT condition to querySeter.
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
......@@ -63,10 +70,13 @@ func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
return &o
}
// set offset number
func (o *querySet) setOffset(num interface{}) {
o.offset = ToInt64(num)
}
// add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset.
func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
o.limit = ToInt64(limit)
if len(args) > 0 {
......@@ -75,16 +85,21 @@ func (o querySet) Limit(limit interface{}, args ...interface{}) QuerySeter {
return &o
}
// add OFFSET value
func (o querySet) Offset(offset interface{}) QuerySeter {
o.setOffset(offset)
return &o
}
// add ORDER expression.
// "column" means ASC, "-column" means DESC.
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
return &o
}
// set relation model to query together.
// it will query relation models and assign to parent model.
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string
if len(params) == 0 {
......@@ -105,36 +120,50 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
return &o
}
// set condition to QuerySeter.
func (o querySet) SetCond(cond *Condition) QuerySeter {
o.cond = cond
return &o
}
// return QuerySeter execution result number
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// check result empty or not after QuerySeter executed
func (o *querySet) Exist() bool {
cnt, _ := o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
return cnt > 0
}
// execute update with parameters
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values, o.orm.alias.TZ)
}
// execute delete
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
}
// return a insert queryer.
// it can be used in times.
// example:
// i,err := sq.PrepareInsert()
// i.Add(&user1{},&user2{})
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
}
// query all data and map to containers.
// cols means the columns when querying.
func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
}
// query one row data and map to containers.
// cols means the columns when querying.
func (o *querySet) One(container interface{}, cols ...string) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
if err != nil {
......@@ -149,18 +178,26 @@ func (o *querySet) One(container interface{}, cols ...string) error {
return nil
}
// query all data and map to []map[string]interface.
// expres means condition expression.
// it converts data to []map[column]value.
func (o *querySet) Values(results *[]Params, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to [][]interface
// it converts data to [][column_index]value
func (o *querySet) ValuesList(results *[]ParamsList, exprs ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, exprs, results, o.orm.alias.TZ)
}
// query all data and map to []interface.
// it's designed for one row record set, auto change to []value, not [][column]value.
func (o *querySet) ValuesFlat(result *ParamsList, expr string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{expr}, result, o.orm.alias.TZ)
}
// create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)
o.mi = mi
......
......@@ -4,10 +4,10 @@ import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
)
// raw sql string prepared statement
type rawPrepare struct {
rs *rawSet
stmt stmtQuerier
......@@ -45,6 +45,7 @@ func newRawPreparer(rs *rawSet) (RawPreparer, error) {
return o, nil
}
// raw query seter
type rawSet struct {
query string
args []interface{}
......@@ -53,11 +54,13 @@ type rawSet struct {
var _ RawSeter = new(rawSet)
// set args for every query
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args
return &o
}
// execute raw sql and return sql.Result
func (o *rawSet) Exec() (sql.Result, error) {
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
......@@ -66,6 +69,7 @@ func (o *rawSet) Exec() (sql.Result, error) {
return o.orm.db.Exec(query, args...)
}
// set field value to row container
func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
switch ind.Kind() {
case reflect.Bool:
......@@ -164,65 +168,12 @@ func (o *rawSet) setFieldValue(ind reflect.Value, value interface{}) {
}
}
func (o *rawSet) loopInitRefs(typ reflect.Type, refsPtr *[]interface{}, sIdxesPtr *[][]int) {
sIdxes := *sIdxesPtr
refs := *refsPtr
if typ.Kind() == reflect.Struct {
if typ.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
} else {
idxs := []int{}
outFor:
for idx := 0; idx < typ.NumField(); idx++ {
ctyp := typ.Field(idx)
tag := ctyp.Tag.Get(defaultStructTagName)
for _, v := range strings.Split(tag, defaultStructTagDelim) {
if v == "-" {
continue outFor
}
}
tp := ctyp.Type
if tp.Kind() == reflect.Ptr {
tp = tp.Elem()
}
if tp.String() == "time.Time" {
var ref interface{}
refs = append(refs, &ref)
} else if tp.Kind() != reflect.Struct {
var ref interface{}
refs = append(refs, &ref)
} else {
// skip other type
continue
}
idxs = append(idxs, idx)
}
sIdxes = append(sIdxes, idxs)
}
} else {
var ref interface{}
refs = append(refs, &ref)
sIdxes = append(sIdxes, []int{0})
}
*sIdxesPtr = sIdxes
*refsPtr = refs
}
func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
// set field value in loop for slice container
func (o *rawSet) loopSetRefs(refs []interface{}, sInds []reflect.Value, nIndsPtr *[]reflect.Value, eTyps []reflect.Type, init bool) {
nInds := *nIndsPtr
cur := 0
for i, idxs := range sIdxes {
for i := 0; i < len(sInds); i++ {
sInd := sInds[i]
eTyp := eTyps[i]
......@@ -258,32 +209,8 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
o.setFieldValue(ind, value)
}
cur++
} else {
hasValue := false
for _, idx := range idxs {
tind := ind.Field(idx)
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if value != nil {
hasValue = true
}
if tind.Kind() == reflect.Ptr {
if value == nil {
tindV := reflect.New(tind.Type()).Elem()
tind.Set(tindV)
} else {
tindV := reflect.New(tind.Type().Elem())
o.setFieldValue(tindV.Elem(), value)
tind.Set(tindV)
}
} else {
o.setFieldValue(tind, value)
}
cur++
}
if hasValue == false && isPtr {
val = reflect.New(val.Type()).Elem()
}
}
} else {
value := reflect.ValueOf(refs[cur]).Elem().Interface()
if isPtr && value == nil {
......@@ -312,16 +239,14 @@ func (o *rawSet) loopSetRefs(refs []interface{}, sIdxes [][]int, sInds []reflect
}
}
// query data and map to container
func (o *rawSet) QueryRow(containers ...interface{}) error {
if len(containers) == 0 {
panic(fmt.Errorf("<RawSeter.QueryRow> need at least one arg"))
}
refs := make([]interface{}, 0, len(containers))
sIdxes := make([][]int, 0)
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
......@@ -335,44 +260,123 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
sInds = append(sInds, ind)
eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
o.orm.alias.DbBaser.ReplaceMarks(&query)
args := getFlatParams(nil, o.args, o.orm.alias.TZ)
row := o.orm.db.QueryRow(query, args...)
if err := row.Scan(refs...); err == sql.ErrNoRows {
return ErrNoRows
} else if err != nil {
rows, err := o.orm.db.Query(query, args...)
if err != nil {
if err == sql.ErrNoRows {
return ErrNoRows
}
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
defer rows.Close()
if rows.Next() {
if structMode {
columns, err := rows.Columns()
if err != nil {
return err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return err
}
ind := sInds[0]
if ind.Kind() == reflect.Ptr {
if ind.IsNil() || !ind.IsValid() {
ind.Set(reflect.New(eTyps[0].Elem()))
}
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
} else {
if err := rows.Scan(refs...); err != nil {
return err
}
nInds := make([]reflect.Value, len(sInds))
o.loopSetRefs(refs, sInds, &nInds, eTyps, true)
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
} else {
return ErrNoRows
}
return nil
}
// query data rows and map to container
func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
refs := make([]interface{}, 0)
sIdxes := make([][]int, 0)
refs := make([]interface{}, 0, len(containers))
sInds := make([]reflect.Value, 0)
eTyps := make([]reflect.Type, 0)
structMode := false
var sMi *modelInfo
for _, container := range containers {
val := reflect.ValueOf(container)
sInd := reflect.Indirect(val)
......@@ -389,7 +393,20 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
sInds = append(sInds, sInd)
eTyps = append(eTyps, etyp)
o.loopInitRefs(typ, &refs, &sIdxes)
if typ.Kind() == reflect.Struct && typ.String() != "time.Time" {
if len(containers) > 1 {
panic(fmt.Errorf("<RawSeter.QueryRow> now support one struct only. see #384"))
}
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
sMi = mi
}
} else {
var ref interface{}
refs = append(refs, &ref)
}
}
query := o.query
......@@ -401,23 +418,100 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
return 0, err
}
nInds := make([]reflect.Value, len(sInds))
defer rows.Close()
var cnt int64
nInds := make([]reflect.Value, len(sInds))
sInd := sInds[0]
for rows.Next() {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sIdxes, sInds, &nInds, eTyps, cnt == 0)
if structMode {
columns, err := rows.Columns()
if err != nil {
return 0, err
}
columnsMp := make(map[string]interface{}, len(columns))
refs = make([]interface{}, 0, len(columns))
for _, col := range columns {
var ref interface{}
columnsMp[col] = &ref
refs = append(refs, &ref)
}
if err := rows.Scan(refs...); err != nil {
return 0, err
}
if cnt == 0 && !sInd.IsNil() {
sInd.Set(reflect.New(sInd.Type()).Elem())
}
var ind reflect.Value
if eTyps[0].Kind() == reflect.Ptr {
ind = reflect.New(eTyps[0].Elem())
} else {
ind = reflect.New(eTyps[0])
}
if ind.Kind() == reflect.Ptr {
ind = ind.Elem()
}
if sMi != nil {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
}
}
} else {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
var col string
if col = tags["column"]; len(col) == 0 {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
}
}
if eTyps[0].Kind() == reflect.Ptr {
ind = ind.Addr()
}
sInd = reflect.Append(sInd, ind)
} else {
if err := rows.Scan(refs...); err != nil {
return 0, err
}
o.loopSetRefs(refs, sInds, &nInds, eTyps, cnt == 0)
}
cnt++
}
if cnt > 0 {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
if structMode {
sInds[0].Set(sInd)
} else {
for i, sInd := range sInds {
nInd := nInds[i]
sInd.Set(nInd)
}
}
}
......@@ -455,6 +549,8 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
rs = r
}
defer rs.Close()
var (
refs []interface{}
cnt int64
......@@ -527,18 +623,22 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
return cnt, nil
}
// query data to []map[string]interface
func (o *rawSet) Values(container *[]Params) (int64, error) {
return o.readValues(container)
}
// query data to [][]interface
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
return o.readValues(container)
}
// query data to []interface
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
return o.readValues(container)
}
// return prepared raw statement for used in times.
func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o)
}
......
......@@ -1322,58 +1322,6 @@ func TestRawQueryRow(t *testing.T) {
}
}
type Tmp struct {
Skip0 string
Id int
Char *string
Skip1 int `orm:"-"`
Date time.Time
DateTime time.Time
}
Boolean = false
Text = ""
Int64 = 0
Uint = 0
tmp := new(Tmp)
cols = []string{
"int", "char", "date", "datetime", "boolean", "text", "int64", "uint",
}
query = fmt.Sprintf("SELECT NULL, %s%s%s FROM data WHERE id = ?", Q, strings.Join(cols, sep), Q)
values = []interface{}{
tmp, &Boolean, &Text, &Int64, &Uint,
}
err = dORM.Raw(query, 1).QueryRow(values...)
throwFailNow(t, err)
for _, col := range cols {
switch col {
case "id":
throwFail(t, AssertIs(tmp.Id, data_values[col]))
case "char":
c := tmp.Char
throwFail(t, AssertIs(*c, data_values[col]))
case "date":
v := tmp.Date.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_Date))
case "datetime":
v := tmp.DateTime.In(DefaultTimeLoc)
value := data_values[col].(time.Time).In(DefaultTimeLoc)
throwFail(t, AssertIs(v, value, test_DateTime))
case "boolean":
throwFail(t, AssertIs(Boolean, data_values[col]))
case "text":
throwFail(t, AssertIs(Text, data_values[col]))
case "int64":
throwFail(t, AssertIs(Int64, data_values[col]))
case "uint":
throwFail(t, AssertIs(Uint, data_values[col]))
}
}
var (
uid int
status *int
......@@ -1394,22 +1342,13 @@ func TestRawQueryRow(t *testing.T) {
func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote()
cols := []string{
"id", "boolean", "char", "text", "date", "datetime", "byte", "rune", "int", "int8", "int16", "int32",
"int64", "uint", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal",
}
var datas []*Data
var dids []int
sep := fmt.Sprintf("%s, %s", Q, Q)
query := fmt.Sprintf("SELECT %s%s%s, id FROM %sdata%s", Q, strings.Join(cols, sep), Q, Q, Q)
num, err := dORM.Raw(query).QueryRows(&datas, &dids)
query := fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
num, err := dORM.Raw(query).QueryRows(&datas)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas), 1))
throwFailNow(t, AssertIs(len(dids), 1))
throwFailNow(t, AssertIs(dids[0], 1))
ind := reflect.Indirect(reflect.ValueOf(datas[0]))
......@@ -1427,90 +1366,43 @@ func TestQueryRows(t *testing.T) {
throwFail(t, AssertIs(vu == value, true), value, vu)
}
type Tmp struct {
Id int
Name string
Skiped0 string `orm:"-"`
Pid *int
Skiped1 Data
Skiped2 *Data
}
var datas2 []Data
var (
ids []int
userNames []string
profileIds1 []int
profileIds2 []*int
createds []time.Time
updateds []time.Time
tmps1 []*Tmp
tmps2 []Tmp
)
cols = []string{
"id", "user_name", "profile_id", "profile_id", "id", "user_name", "profile_id", "id", "user_name", "profile_id", "created", "updated",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s ORDER BY id", Q, strings.Join(cols, sep), Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &userNames, &profileIds1, &profileIds2, &tmps1, &tmps2, &createds, &updateds)
query = fmt.Sprintf("SELECT * FROM %sdata%s", Q, Q)
num, err = dORM.Raw(query).QueryRows(&datas2)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(num, 1))
throwFailNow(t, AssertIs(len(datas2), 1))
var users []User
dORM.QueryTable("user").OrderBy("Id").All(&users)
for i := 0; i < 3; i++ {
id := ids[i]
name := userNames[i]
pid1 := profileIds1[i]
pid2 := profileIds2[i]
created := createds[i]
updated := updateds[i]
user := users[i]
throwFailNow(t, AssertIs(id, user.Id))
throwFailNow(t, AssertIs(name, user.UserName))
if user.Profile != nil {
throwFailNow(t, AssertIs(pid1, user.Profile.Id))
throwFailNow(t, AssertIs(*pid2, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(pid1, 0))
throwFailNow(t, AssertIs(pid2, nil))
}
throwFailNow(t, AssertIs(created, user.Created, test_Date))
throwFailNow(t, AssertIs(updated, user.Updated, test_DateTime))
tmp := tmps1[i]
tmp1 := *tmp
throwFailNow(t, AssertIs(tmp1.Id, user.Id))
throwFailNow(t, AssertIs(tmp1.Name, user.UserName))
if user.Profile != nil {
pid := tmp1.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp1.Pid, nil))
}
ind = reflect.Indirect(reflect.ValueOf(datas2[0]))
tmp2 := tmps2[i]
throwFailNow(t, AssertIs(tmp2.Id, user.Id))
throwFailNow(t, AssertIs(tmp2.Name, user.UserName))
if user.Profile != nil {
pid := tmp2.Pid
throwFailNow(t, AssertIs(*pid, user.Profile.Id))
} else {
throwFailNow(t, AssertIs(tmp2.Pid, nil))
for name, value := range Data_Values {
e := ind.FieldByName(name)
vu := e.Interface()
switch name {
case "Date":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_Date)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_Date)
case "DateTime":
vu = vu.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
value = value.(time.Time).In(DefaultTimeLoc).Format(test_DateTime)
}
throwFail(t, AssertIs(vu == value, true), value, vu)
}
type Sec struct {
Id int
Name string
}
var tmp []*Sec
query = fmt.Sprintf("SELECT NULL, NULL FROM %suser%s LIMIT 1", Q, Q)
num, err = dORM.Raw(query).QueryRows(&tmp)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(tmp[0], nil))
var ids []int
var usernames []string
query = fmt.Sprintf("SELECT %sid%s, %suser_name%s FROM %suser%s ORDER BY %sid%s ASC", Q, Q, Q, Q, Q, Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&ids, &usernames)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 3))
throwFailNow(t, AssertIs(len(ids), 3))
throwFailNow(t, AssertIs(ids[0], 2))
throwFailNow(t, AssertIs(usernames[0], "slene"))
throwFailNow(t, AssertIs(ids[1], 3))
throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody"))
}
func TestRawValues(t *testing.T) {
......@@ -1669,6 +1561,32 @@ func TestDelete(t *testing.T) {
num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 6))
qs = dORM.QueryTable("post")
num, err = qs.Filter("Id", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 4))
fmt.Println("...")
qs = dORM.QueryTable("comment")
num, err = qs.Filter("Post__User", 3).Delete()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
qs = dORM.QueryTable("comment")
num, err = qs.Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestTransaction(t *testing.T) {
......
......@@ -6,11 +6,13 @@ import (
"time"
)
// database driver
type Driver interface {
Name() string
Type() DriverType
}
// field info
type Fielder interface {
String() string
FieldType() int
......@@ -18,9 +20,11 @@ type Fielder interface {
RawValue() interface{}
}
// orm struct
type Ormer interface {
Read(interface{}, ...string) error
Insert(interface{}) (int64, error)
InsertMulti(int, interface{}) (int64, error)
Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error)
......@@ -34,11 +38,13 @@ type Ormer interface {
Driver() Driver
}
// insert prepared statement
type Inserter interface {
Insert(interface{}) (int64, error)
Close() error
}
// query seter
type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter
......@@ -59,6 +65,7 @@ type QuerySeter interface {
ValuesFlat(*ParamsList, string) (int64, error)
}
// model to model query struct
type QueryM2Mer interface {
Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error)
......@@ -67,11 +74,13 @@ type QueryM2Mer interface {
Count() (int64, error)
}
// raw query statement
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
}
// raw query seter
type RawSeter interface {
Exec() (sql.Result, error)
QueryRow(...interface{}) error
......@@ -83,6 +92,7 @@ type RawSeter interface {
Prepare() (RawPreparer, error)
}
// statement querier
type stmtQuerier interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
......@@ -90,6 +100,7 @@ type stmtQuerier interface {
QueryRow(args ...interface{}) *sql.Row
}
// db querier
type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error)
......@@ -97,19 +108,23 @@ type dbQuerier interface {
QueryRow(query string, args ...interface{}) *sql.Row
}
// transaction beginner
type txer interface {
Begin() (*sql.Tx, error)
}
// transaction ending
type txEnder interface {
Commit() error
Rollback() error
}
// base database struct
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
......
......@@ -10,6 +10,7 @@ import (
type StrTo string
// set string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
......@@ -18,77 +19,93 @@ func (f *StrTo) Set(v string) {
}
}
// clean string
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
// check string exist
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
// string to bool
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
// string to float32
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
// string to float64
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
// string to int
func (f StrTo) Int() (int, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int(v), err
}
// string to int8
func (f StrTo) Int8() (int8, error) {
v, err := strconv.ParseInt(f.String(), 10, 8)
return int8(v), err
}
// string to int16
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
// string to int32
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
// string to int64
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
// string to uint
func (f StrTo) Uint() (uint, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint(v), err
}
// string to uint8
func (f StrTo) Uint8() (uint8, error) {
v, err := strconv.ParseUint(f.String(), 10, 8)
return uint8(v), err
}
// string to uint16
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
// string to uint31
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
// string to uint64
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
// string to string
func (f StrTo) String() string {
if f.Exist() {
return string(f)
......@@ -96,6 +113,7 @@ func (f StrTo) String() string {
return ""
}
// interface to string
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
......@@ -134,6 +152,7 @@ func ToStr(value interface{}, args ...int) (s string) {
return s
}
// interface to int64
func ToInt64(value interface{}) (d int64) {
val := reflect.ValueOf(value)
switch value.(type) {
......@@ -147,6 +166,7 @@ func ToInt64(value interface{}) (d int64) {
return
}
// snake string, XxYy to xx_yy
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
......@@ -164,6 +184,7 @@ func snakeString(s string) string {
return strings.ToLower(string(data[:len(data)]))
}
// camel string, xx_yy to XxYy
func camelString(s string) string {
data := make([]byte, 0, len(s))
j := false
......@@ -190,6 +211,7 @@ func camelString(s string) string {
type argString []string
// get string by index from string slice
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
......@@ -201,6 +223,7 @@ func (a argString) Get(i int, args ...string) (r string) {
type argInt []int
// get int by index from int slice
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
......@@ -213,6 +236,7 @@ func (a argInt) Get(i int, args ...int) (r int) {
type argAny []interface{}
// get interface by index from interface slice
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
......@@ -223,15 +247,18 @@ func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
return
}
// parse time to string with location
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
// format time string
func timeFormat(t time.Time, format string) string {
return t.Format(format)
}
// get pointer indirect type
func indirectType(v reflect.Type) reflect.Type {
switch v.Kind() {
case reflect.Ptr:
......
package beego
import (
"bufio"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
......@@ -30,6 +33,14 @@ const (
var (
// supported http methods.
HTTPMETHOD = []string{"get", "post", "put", "delete", "patch", "options", "head"}
// these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
"RenderBytes", "Redirect", "Abort", "StopRun", "UrlFor", "ServeJson", "ServeJsonp",
"ServeXml", "Input", "ParseForm", "GetString", "GetStrings", "GetInt", "GetBool",
"GetFloat", "GetFile", "SaveToFile", "StartSession", "SetSession", "GetSession",
"DelSession", "SessionRegenerateID", "DestroySession", "IsAjax", "GetSecureCookie",
"SetSecureCookie", "XsrfToken", "CheckXsrfCookie", "XsrfFormHtml",
"GetControllerAndAction"}
)
type controllerInfo struct {
......@@ -77,7 +88,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
params := make(map[int]string)
for i, part := range parts {
if strings.HasPrefix(part, ":") {
expr := "(.+)"
expr := "(.*)"
//a user may choose to override the defult expression
// similar to expressjs: ‘/user/:id([0-9]+)’
if index := strings.Index(part, "("); index != -1 {
......@@ -100,7 +111,7 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
j++
}
if strings.HasPrefix(part, "*") {
expr := "(.+)"
expr := "(.*)"
if part == "*.*" {
params[j] = ":path"
parts[i] = "([^.]+).([^.]+)"
......@@ -218,8 +229,8 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
// Add auto router to ControllerRegistor.
// example beego.AddAuto(&MainContorlller{}),
// MainController has method List and Page.
// visit the url /main/list to exec List function
// /main/page to exec Page function.
// visit the url /main/list to execute List function
// /main/page to execute Page function.
func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
......@@ -232,14 +243,42 @@ func (p *ControllerRegistor) AddAuto(c ControllerInterface) {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
}
// Add auto router to ControllerRegistor with prefix.
// example beego.AddAutoPrefix("/admin",&MainContorlller{}),
// MainController has method List and Page.
// visit the url /admin/main/list to execute List function
// /admin/main/page to execute Page function.
func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) {
p.enableAuto = true
reflectVal := reflect.ValueOf(c)
rt := reflectVal.Type()
ct := reflect.Indirect(reflectVal).Type()
firstParam := strings.Trim(prefix, "/") + "/" + strings.ToLower(strings.TrimSuffix(ct.Name(), "Controller"))
if _, ok := p.autoRouter[firstParam]; ok {
return
} else {
p.autoRouter[firstParam] = make(map[string]reflect.Type)
}
for i := 0; i < rt.NumMethod(); i++ {
if !utils.InSlice(rt.Method(i).Name, exceptMethod) {
p.autoRouter[firstParam][rt.Method(i).Name] = ct
}
}
}
// [Deprecated] use InsertFilter.
// Add FilterFunc with pattern for action.
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) {
mr := buildFilter(pattern, filter)
func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc) error {
mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
switch action {
case "BeforeRouter":
p.filters[BeforeRouter] = append(p.filters[BeforeRouter], mr)
......@@ -253,13 +292,18 @@ func (p *ControllerRegistor) AddFilter(pattern, action string, filter FilterFunc
p.filters[FinishRouter] = append(p.filters[FinishRouter], mr)
}
p.enableFilter = true
return nil
}
// Add a FilterFunc with pattern rule and action constant.
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) {
mr := buildFilter(pattern, filter)
func (p *ControllerRegistor) InsertFilter(pattern string, pos int, filter FilterFunc) error {
mr, err := buildFilter(pattern, filter)
if err != nil {
return err
}
p.filters[pos] = append(p.filters[pos], mr)
p.enableFilter = true
return nil
}
// UrlFor does another controller handler in this request function.
......@@ -485,7 +529,9 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
// session init
if SessionOn {
context.Input.CruSession = GlobalSessions.SessionStart(w, r)
defer context.Input.CruSession.SessionRelease()
defer func() {
context.Input.CruSession.SessionRelease(w)
}()
}
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) {
......@@ -575,12 +621,11 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}
// pattern /admin url /admin 200 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && len(route.pattern) == n+1 &&
route.pattern[n] == '/' && route.pattern[:n] == requestPath {
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
http.Redirect(w, r, requestPath+"/", 301)
goto Admin
}
if requestPath[n-1] == '/' && n >= 2 && requestPath[:n-2] == route.pattern {
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" {
runrouter = route.controllerType
......@@ -857,3 +902,13 @@ func (w *responseWriter) WriteHeader(code int) {
w.started = true
w.writer.WriteHeader(code)
}
// hijacker for http
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.writer.(http.Hijacker)
if !ok {
println("supported?")
return nil, nil, errors.New("webserver doesn't support hijacking")
}
return hj.Hijack()
}
......
......@@ -198,3 +198,15 @@ func TestPrepare(t *testing.T) {
t.Errorf(w.Body.String() + "user define func can't run")
}
}
func TestAutoPrefix(t *testing.T) {
r, _ := http.NewRequest("GET", "/admin/test/list", nil)
w := httptest.NewRecorder()
handler := NewControllerRegistor()
handler.AddAutoPrefix("/admin", &TestController{})
handler.ServeHTTP(w, r)
if w.Body.String() != "i am list" {
t.Errorf("TestAutoPrefix can't run")
}
}
......
......@@ -28,21 +28,21 @@ Then in you web app init the global session manager
* Use **memory** as provider:
func init() {
globalSessions, _ = session.NewManager("memory", "gosessionid", 3600,"")
globalSessions, _ = session.NewManager("memory", `{"cookieName":"gosessionid","gclifetime":3600}`)
go globalSessions.GC()
}
* Use **file** as provider, the last param is the path where you want file to be stored:
func init() {
globalSessions, _ = session.NewManager("file", "gosessionid", 3600, "./tmp")
globalSessions, _ = session.NewManager("file",`{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","./tmp"}`)
go globalSessions.GC()
}
* Use **Redis** as provider, the last param is the Redis conn address,poolsize,password:
func init() {
globalSessions, _ = session.NewManager("redis", "gosessionid", 3600, "127.0.0.1:6379,100,astaxie")
globalSessions, _ = session.NewManager("redis", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","127.0.0.1:6379,100,astaxie"}`)
go globalSessions.GC()
}
......@@ -50,15 +50,24 @@ Then in you web app init the global session manager
func init() {
globalSessions, _ = session.NewManager(
"mysql", "gosessionid", 3600, "username:password@protocol(address)/dbname?param=value")
"mysql", `{"cookieName":"gosessionid","gclifetime":3600,"ProviderConfig","username:password@protocol(address)/dbname?param=value"}`)
go globalSessions.GC()
}
* Use **Cookie** as provider:
func init() {
globalSessions, _ = session.NewManager(
"cookie", `{"cookieName":"gosessionid","enableSetCookie":false,gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`)
go globalSessions.GC()
}
Finally in the handlerfunc you can use it like this
func login(w http.ResponseWriter, r *http.Request) {
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease()
defer sess.SessionRelease(w)
username := sess.Get("username")
fmt.Println(username)
if r.Method == "GET" {
......@@ -78,19 +87,19 @@ When you develop a web app, maybe you want to write own provider because you mus
Writing a provider is easy. You only need to define two struct types
(Session and Provider), which satisfy the interface definition.
Maybe you will find the **memory** provider as good example.
Maybe you will find the **memory** provider is a good example.
type SessionStore interface {
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease() // release the resource & save data to provider
Flush() error //delete all data
Set(key, value interface{}) error //set session value
Get(key interface{}) interface{} //get session value
Delete(key interface{}) error //delete session value
SessionID() string //back current sessionID
SessionRelease(w http.ResponseWriter) // release the resource & save data to provider & return the data
Flush() error //delete all data
}
type Provider interface {
SessionInit(maxlifetime int64, savePath string) error
SessionInit(gclifetime int64, config string) error
SessionRead(sid string) (SessionStore, error)
SessionExist(sid string) bool
SessionRegenerate(oldsid, sid string) (SessionStore, error)
......
package session
import (
"crypto/aes"
"crypto/cipher"
"encoding/json"
"net/http"
"net/url"
"sync"
)
var cookiepder = &CookieProvider{}
type CookieSessionStore struct {
sid string
values map[interface{}]interface{} //session data
lock sync.RWMutex
}
func (st *CookieSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
func (st *CookieSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
func (st *CookieSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
func (st *CookieSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
func (st *CookieSessionStore) SessionID() string {
return st.sid
}
func (st *CookieSessionStore) SessionRelease(w http.ResponseWriter) {
str, err := encodeCookie(cookiepder.block,
cookiepder.config.SecurityKey,
cookiepder.config.SecurityName,
st.values)
if err != nil {
return
}
cookie := &http.Cookie{Name: cookiepder.config.CookieName,
Value: url.QueryEscape(str),
Path: "/",
HttpOnly: true,
Secure: cookiepder.config.Secure}
http.SetCookie(w, cookie)
return
}
type cookieConfig struct {
SecurityKey string `json:"securityKey"`
BlockKey string `json:"blockKey"`
SecurityName string `json:"securityName"`
CookieName string `json:"cookieName"`
Secure bool `json:"secure"`
Maxage int `json:"maxage"`
}
type CookieProvider struct {
maxlifetime int64
config *cookieConfig
block cipher.Block
}
func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error {
pder.config = &cookieConfig{}
err := json.Unmarshal([]byte(config), pder.config)
if err != nil {
return err
}
if pder.config.BlockKey == "" {
pder.config.BlockKey = string(generateRandomKey(16))
}
if pder.config.SecurityName == "" {
pder.config.SecurityName = string(generateRandomKey(20))
}
pder.block, err = aes.NewCipher([]byte(pder.config.BlockKey))
if err != nil {
return err
}
return nil
}
func (pder *CookieProvider) SessionRead(sid string) (SessionStore, error) {
maps, _ := decodeCookie(pder.block,
pder.config.SecurityKey,
pder.config.SecurityName,
sid, pder.maxlifetime)
if maps == nil {
maps = make(map[interface{}]interface{})
}
rs := &CookieSessionStore{sid: sid, values: maps}
return rs, nil
}
func (pder *CookieProvider) SessionExist(sid string) bool {
return true
}
func (pder *CookieProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
return nil, nil
}
func (pder *CookieProvider) SessionDestroy(sid string) error {
return nil
}
func (pder *CookieProvider) SessionGC() {
return
}
func (pder *CookieProvider) SessionAll() int {
return 0
}
func (pder *CookieProvider) SessionUpdate(sid string) error {
return nil
}
func init() {
Register("cookie", cookiepder)
}
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCookie(t *testing.T) {
config := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
globalSessions, err := NewManager("cookie", config)
if err != nil {
t.Fatal("init cookie session err", err)
}
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
err = sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
sess.SessionRelease(w)
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}
......@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
......@@ -60,7 +61,7 @@ func (fs *FileSessionStore) SessionID() string {
return fs.sid
}
func (fs *FileSessionStore) SessionRelease() {
func (fs *FileSessionStore) SessionRelease(w http.ResponseWriter) {
defer fs.f.Close()
b, err := encodeGob(fs.values)
if err != nil {
......
package session
import (
"bytes"
"encoding/gob"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
......@@ -2,6 +2,7 @@ package session
import (
"container/list"
"net/http"
"sync"
"time"
)
......@@ -9,9 +10,9 @@ import (
var mempder = &MemProvider{list: list.New(), sessions: make(map[string]*list.Element)}
type MemSessionStore struct {
sid string //session id唯一标示
timeAccessed time.Time //最后访问时间
value map[interface{}]interface{} //session里面存储的值
sid string //session id
timeAccessed time.Time //last access time
value map[interface{}]interface{} //session store
lock sync.RWMutex
}
......@@ -51,8 +52,7 @@ func (st *MemSessionStore) SessionID() string {
return st.sid
}
func (st *MemSessionStore) SessionRelease() {
func (st *MemSessionStore) SessionRelease(w http.ResponseWriter) {
}
type MemProvider struct {
......
package session
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestMem(t *testing.T) {
globalSessions, _ := NewManager("memory", `{"cookieName":"gosessionid","gclifetime":10}`)
go globalSessions.GC()
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
sess := globalSessions.SessionStart(w, r)
defer sess.SessionRelease(w)
err := sess.Set("username", "astaxie")
if err != nil {
t.Fatal("set error,", err)
}
if username := sess.Get("username"); username != "astaxie" {
t.Fatal("get username error")
}
if cookiestr := w.Header().Get("Set-Cookie"); cookiestr == "" {
t.Fatal("setcookie error")
} else {
parts := strings.Split(strings.TrimSpace(cookiestr), ";")
for k, v := range parts {
nameval := strings.Split(v, "=")
if k == 0 && nameval[0] != "gosessionid" {
t.Fatal("error")
}
}
}
}
......@@ -9,6 +9,7 @@ package session
import (
"database/sql"
"net/http"
"sync"
"time"
......@@ -60,15 +61,15 @@ func (st *MysqlSessionStore) SessionID() string {
return st.sid
}
func (st *MysqlSessionStore) SessionRelease() {
func (st *MysqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
if len(st.values) > 0 {
b, err := encodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set `session_data`= ? where session_key=?", b, st.sid)
b, err := encodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set `session_data`=?, `session_expiry`=? where session_key=?",
b, time.Now().Unix(), st.sid)
}
type MysqlProvider struct {
......@@ -96,7 +97,8 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)", sid, "", time.Now().Unix())
c.Exec("insert into session(`session_key`,`session_data`,`session_expiry`) values(?,?,?)",
sid, "", time.Now().Unix())
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
......@@ -113,6 +115,7 @@ func (mp *MysqlProvider) SessionRead(sid string) (SessionStore, error) {
func (mp *MysqlProvider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=?", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
......
package session
import (
"net/http"
"strconv"
"strings"
"sync"
......@@ -58,16 +59,14 @@ func (rs *RedisSessionStore) SessionID() string {
return rs.sid
}
func (rs *RedisSessionStore) SessionRelease() {
func (rs *RedisSessionStore) SessionRelease(w http.ResponseWriter) {
defer rs.c.Close()
if len(rs.values) > 0 {
b, err := encodeGob(rs.values)
if err != nil {
return
}
rs.c.Do("SET", rs.sid, string(b))
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
b, err := encodeGob(rs.values)
if err != nil {
return
}
rs.c.Do("SET", rs.sid, string(b))
rs.c.Do("EXPIRE", rs.sid, rs.maxlifetime)
}
type RedisProvider struct {
......
package session
import (
"crypto/aes"
"encoding/json"
"testing"
)
......@@ -26,3 +28,82 @@ func Test_gob(t *testing.T) {
t.Error("decode int error")
}
}
func TestGenerate(t *testing.T) {
str := generateRandomKey(20)
if len(str) != 20 {
t.Fatal("generate length is not equal to 20")
}
}
func TestCookieEncodeDecode(t *testing.T) {
hashKey := "testhashKey"
blockkey := generateRandomKey(16)
block, err := aes.NewCipher(blockkey)
if err != nil {
t.Fatal("NewCipher:", err)
}
securityName := string(generateRandomKey(20))
val := make(map[interface{}]interface{})
val["name"] = "astaxie"
val["gender"] = "male"
str, err := encodeCookie(block, hashKey, securityName, val)
if err != nil {
t.Fatal("encodeCookie:", err)
}
dst := make(map[interface{}]interface{})
dst, err = decodeCookie(block, hashKey, securityName, str, 3600)
if err != nil {
t.Fatal("decodeCookie", err)
}
if dst["name"] != "astaxie" {
t.Fatal("dst get map error")
}
if dst["gender"] != "male" {
t.Fatal("dst get map error")
}
}
func TestParseConfig(t *testing.T) {
s := `{"cookieName":"gosessionid","gclifetime":3600}`
cf := new(managerConfig)
cf.EnableSetCookie = true
err := json.Unmarshal([]byte(s), cf)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
cc := `{"cookieName":"gosessionid","enableSetCookie":false,"gclifetime":3600,"ProviderConfig":"{\"cookieName\":\"gosessionid\",\"securityKey\":\"beegocookiehashkey\"}"}`
cf2 := new(managerConfig)
cf2.EnableSetCookie = true
err = json.Unmarshal([]byte(cc), cf2)
if err != nil {
t.Fatal("parse json error,", err)
}
if cf2.CookieName != "gosessionid" {
t.Fatal("parseconfig get cookiename error")
}
if cf2.Gclifetime != 3600 {
t.Fatal("parseconfig get gclifetime error")
}
if cf2.EnableSetCookie != false {
t.Fatal("parseconfig get enableSetCookie error")
}
cconfig := new(cookieConfig)
err = json.Unmarshal([]byte(cf2.ProviderConfig), cconfig)
if err != nil {
t.Fatal("parse ProviderConfig err,", err)
}
if cconfig.CookieName != "gosessionid" {
t.Fatal("ProviderConfig get cookieName error")
}
if cconfig.SecurityKey != "beegocookiehashkey" {
t.Fatal("ProviderConfig get securityKey error")
}
}
......
package session
import (
"bytes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"io"
"strconv"
"time"
)
func init() {
gob.Register([]interface{}{})
gob.Register(map[int]interface{}{})
gob.Register(map[string]interface{}{})
gob.Register(map[interface{}]interface{}{})
gob.Register(map[string]string{})
gob.Register(map[int]string{})
gob.Register(map[int]int{})
gob.Register(map[int]int64{})
}
func encodeGob(obj map[interface{}]interface{}) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return []byte(""), err
}
return buf.Bytes(), nil
}
func decodeGob(encoded []byte) (map[interface{}]interface{}, error) {
buf := bytes.NewBuffer(encoded)
dec := gob.NewDecoder(buf)
var out map[interface{}]interface{}
err := dec.Decode(&out)
if err != nil {
return nil, err
}
return out, nil
}
// generateRandomKey creates a random key with the given strength.
func generateRandomKey(strength int) []byte {
k := make([]byte, strength)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
// Encryption -----------------------------------------------------------------
// encrypt encrypts a value using the given block in counter mode.
//
// A random initialization vector (http://goo.gl/zF67k) with the length of the
// block size is prepended to the resulting ciphertext.
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
iv := generateRandomKey(block.BlockSize())
if iv == nil {
return nil, errors.New("encrypt: failed to generate random iv")
}
// Encrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
// Return iv + ciphertext.
return append(iv, value...), nil
}
// decrypt decrypts a value using the given block in counter mode.
//
// The value to be decrypted must be prepended by a initialization vector
// (http://goo.gl/zF67k) with the length of the block size.
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
size := block.BlockSize()
if len(value) > size {
// Extract iv.
iv := value[:size]
// Extract ciphertext.
value = value[size:]
// Decrypt it.
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(value, value)
return value, nil
}
return nil, errors.New("decrypt: the value could not be decrypted")
}
func encodeCookie(block cipher.Block, hashKey, name string, value map[interface{}]interface{}) (string, error) {
var err error
var b []byte
// 1. encodeGob.
if b, err = encodeGob(value); err != nil {
return "", err
}
// 2. Encrypt (optional).
if b, err = encrypt(block, b); err != nil {
return "", err
}
b = encode(b)
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
b = []byte(fmt.Sprintf("%s|%d|%s|", name, time.Now().UTC().Unix(), b))
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
// Append mac, remove name.
b = append(b, sig...)[len(name)+1:]
// 4. Encode to base64.
b = encode(b)
// Done.
return string(b), nil
}
func decodeCookie(block cipher.Block, hashKey, name, value string, gcmaxlifetime int64) (map[interface{}]interface{}, error) {
// 1. Decode from base64.
b, err := decode([]byte(value))
if err != nil {
return nil, err
}
// 2. Verify MAC. Value is "date|value|mac".
parts := bytes.SplitN(b, []byte("|"), 3)
if len(parts) != 3 {
return nil, errors.New("Decode: invalid value %v")
}
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])]...)
h := hmac.New(sha1.New, []byte(hashKey))
h.Write(b)
sig := h.Sum(nil)
if len(sig) != len(parts[2]) || subtle.ConstantTimeCompare(sig, parts[2]) != 1 {
return nil, errors.New("Decode: the value is not valid")
}
// 3. Verify date ranges.
var t1 int64
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
return nil, errors.New("Decode: invalid timestamp")
}
t2 := time.Now().UTC().Unix()
if t1 > t2 {
return nil, errors.New("Decode: timestamp is too new")
}
if t1 < t2-gcmaxlifetime {
return nil, errors.New("Decode: expired timestamp")
}
// 4. Decrypt (optional).
b, err = decode(parts[1])
if err != nil {
return nil, err
}
if b, err = decrypt(block, b); err != nil {
return nil, err
}
// 5. decodeGob.
if dst, err := decodeGob(b); err != nil {
return nil, err
} else {
return dst, nil
}
// Done.
return nil, nil
}
// Encoding -------------------------------------------------------------------
// encode encodes a value using base64.
func encode(value []byte) []byte {
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
base64.URLEncoding.Encode(encoded, value)
return encoded
}
// decode decodes a cookie using base64.
func decode(value []byte) ([]byte, error) {
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
b, err := base64.URLEncoding.Decode(decoded, value)
if err != nil {
return nil, err
}
return decoded[:b], nil
}
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!