85630002 by slene

orm operator args now support multi types eg: []int []*int *int, Model *Model

1 parent 9047d21e
...@@ -949,21 +949,76 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition ...@@ -949,21 +949,76 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
949 return 949 return
950 } 950 }
951 951
952 func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { 952 func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) {
953 params := make([]interface{}, len(args)) 953 for _, arg := range args {
954 copy(params, args) 954 val := reflect.ValueOf(arg)
955 sql := "" 955
956 for i, arg := range args { 956 if arg == nil {
957 if md, ok := arg.(Modeler); ok { 957 params = append(params, arg)
958 ind := reflect.Indirect(reflect.ValueOf(md)) 958 continue
959 if _, vu, exist := d.existPk(mi, ind); exist { 959 }
960 arg = vu 960
961 kind := val.Kind()
962
963 switch kind {
964 case reflect.Slice, reflect.Array:
965 var args []interface{}
966 for i := 0; i < val.Len(); i++ {
967 v := val.Index(i)
968
969 var vu interface{}
970 if v.CanInterface() {
971 vu = v.Interface()
972 }
973
974 if vu == nil {
975 continue
976 }
977
978 args = append(args, vu)
979 }
980
981 if len(args) > 0 {
982 p := d.getOperatorParams(operator, args)
983 params = append(params, p...)
984 }
985
986 case reflect.Ptr, reflect.Struct:
987 ind := reflect.Indirect(val)
988
989 if ind.Kind() == reflect.Struct {
990 typ := ind.Type()
991 fullName := typ.PkgPath() + "." + typ.Name()
992 var value interface{}
993 if mmi, ok := modelCache.get(fullName); ok {
994 if _, vu, exist := d.existPk(mmi, ind); exist {
995 value = vu
996 }
997 }
998 arg = value
999
1000 if arg == nil {
1001 panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%v`", operator, val.Type()))
1002 }
961 } else { 1003 } else {
962 panic(fmt.Sprintf("`%s` need a valid args value", operator)) 1004 arg = ind.Interface()
963 } 1005 }
1006
1007 params = append(params, arg)
1008
1009 default:
1010 params = append(params, arg)
964 } 1011 }
965 params[i] = arg 1012
966 } 1013 }
1014
1015 return
1016 }
1017
1018 func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
1019 sql := ""
1020 params := d.getOperatorParams(operator, args)
1021
967 if operator == "in" { 1022 if operator == "in" {
968 marks := make([]string, len(params)) 1023 marks := make([]string, len(params))
969 for i, _ := range marks { 1024 for i, _ := range marks {
......
...@@ -18,7 +18,7 @@ fmt.Println(o.Delete(user)) ...@@ -18,7 +18,7 @@ fmt.Println(o.Delete(user))
18 o := orm.NewOrm() 18 o := orm.NewOrm()
19 user := User{Id: 1} 19 user := User{Id: 1}
20 20
21 o.Read(&user) 21 err = o.Read(&user)
22 22
23 if err == sql.ErrNoRows { 23 if err == sql.ErrNoRows {
24 fmt.Println("查询不到") 24 fmt.Println("查询不到")
......
...@@ -16,7 +16,10 @@ const ( ...@@ -16,7 +16,10 @@ const (
16 16
17 var ( 17 var (
18 errLog *log.Logger 18 errLog *log.Logger
19 modelCache = &_modelCache{cache: make(map[string]*modelInfo)} 19 modelCache = &_modelCache{
20 cache: make(map[string]*modelInfo),
21 cacheByFN: make(map[string]*modelInfo),
22 }
20 supportTag = map[string]int{ 23 supportTag = map[string]int{
21 "null": 1, 24 "null": 1,
22 "blank": 1, 25 "blank": 1,
...@@ -47,9 +50,10 @@ func init() { ...@@ -47,9 +50,10 @@ func init() {
47 50
48 type _modelCache struct { 51 type _modelCache struct {
49 sync.RWMutex 52 sync.RWMutex
50 orders []string 53 orders []string
51 cache map[string]*modelInfo 54 cache map[string]*modelInfo
52 done bool 55 cacheByFN map[string]*modelInfo
56 done bool
53 } 57 }
54 58
55 func (mc *_modelCache) all() map[string]*modelInfo { 59 func (mc *_modelCache) all() map[string]*modelInfo {
...@@ -70,12 +74,16 @@ func (mc *_modelCache) allOrdered() []*modelInfo { ...@@ -70,12 +74,16 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
70 74
71 func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { 75 func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
72 mi, ok = mc.cache[table] 76 mi, ok = mc.cache[table]
77 if ok == false {
78 mi, ok = mc.cacheByFN[table]
79 }
73 return 80 return
74 } 81 }
75 82
76 func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo { 83 func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
77 mii := mc.cache[table] 84 mii := mc.cache[table]
78 mc.cache[table] = mi 85 mc.cache[table] = mi
86 mc.cacheByFN[mi.fullName] = mi
79 if mii == nil { 87 if mii == nil {
80 mc.orders = append(mc.orders, table) 88 mc.orders = append(mc.orders, table)
81 } 89 }
......
...@@ -410,6 +410,15 @@ func TestOperators(t *testing.T) { ...@@ -410,6 +410,15 @@ func TestOperators(t *testing.T) {
410 num, err = qs.Filter("status__in", 1, 2).Count() 410 num, err = qs.Filter("status__in", 1, 2).Count()
411 throwFail(t, err) 411 throwFail(t, err)
412 throwFail(t, AssertIs(num, T_Equal, 2)) 412 throwFail(t, AssertIs(num, T_Equal, 2))
413
414 num, err = qs.Filter("status__in", []int{1, 2}).Count()
415 throwFail(t, err)
416 throwFail(t, AssertIs(num, T_Equal, 2))
417
418 n1, n2 := 1, 2
419 num, err = qs.Filter("status__in", []*int{&n1}, &n2).Count()
420 throwFail(t, err)
421 throwFail(t, AssertIs(num, T_Equal, 2))
413 } 422 }
414 423
415 func TestAll(t *testing.T) { 424 func TestAll(t *testing.T) {
...@@ -684,5 +693,49 @@ func TestDelete(t *testing.T) { ...@@ -684,5 +693,49 @@ func TestDelete(t *testing.T) {
684 } 693 }
685 694
686 func TestTransaction(t *testing.T) { 695 func TestTransaction(t *testing.T) {
696 o := NewOrm()
697 err := o.Begin()
698 throwFail(t, err)
699
700 var names = []string{"1", "2", "3"}
701
702 var user User
703 user.UserName = names[0]
704 id, err := o.Insert(&user)
705 throwFail(t, err)
706 throwFail(t, AssertIs(id, T_Large, 0))
707
708 num, err := o.QueryTable("user").Filter("user_name", "slene").Update(Params{"user_name": names[1]})
709 throwFail(t, err)
710 throwFail(t, AssertIs(num, T_Large, 0))
711
712 switch o.Driver().Type() {
713 case DR_MySQL:
714 id, err := o.Raw("INSERT INTO user (user_name) VALUES (?)", names[2]).Exec()
715 throwFail(t, err)
716 throwFail(t, AssertIs(id, T_Large, 0))
717 }
718
719 err = o.Rollback()
720 throwFail(t, err)
721
722 num, err = o.QueryTable("user").Filter("user_name__in", &user).Count()
723 throwFail(t, err)
724 throwFail(t, AssertIs(num, T_Equal, 0))
725
726 err = o.Begin()
727 throwFail(t, err)
728
729 user.UserName = "commit"
730 id, err = o.Insert(&user)
731 throwFail(t, err)
732 throwFail(t, AssertIs(id, T_Large, 0))
733
734 o.Commit()
735 throwFail(t, err)
736
737 num, err = o.QueryTable("user").Filter("user_name", "commit").Delete()
738 throwFail(t, err)
739 throwFail(t, AssertIs(num, T_Equal, 1))
687 740
688 } 741 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!