45345fa7 by slene

orm add postgres support

1 parent 449fbe82
...@@ -49,28 +49,8 @@ type dbBase struct { ...@@ -49,28 +49,8 @@ type dbBase struct {
49 ins dbBaser 49 ins dbBaser
50 } 50 }
51 51
52 func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
53
54 fi := mi.fields.pk
55
56 v := ind.Field(fi.fieldIndex)
57 if fi.fieldType&IsIntegerField > 0 {
58 vu := v.Int()
59 exist = vu > 0
60 value = vu
61 } else {
62 vu := v.String()
63 exist = vu != ""
64 value = vu
65 }
66
67 column = fi.column
68
69 return
70 }
71
72 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { 52 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
73 _, pkValue, _ := d.existPk(mi, ind) 53 _, pkValue, _ := getExistPk(mi, ind)
74 for _, column := range mi.fields.orders { 54 for _, column := range mi.fields.orders {
75 fi := mi.fields.columns[column] 55 fi := mi.fields.columns[column]
76 if fi.dbcol == false || fi.auto && skipAuto { 56 if fi.dbcol == false || fi.auto && skipAuto {
...@@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, ...@@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
104 if field.IsNil() { 84 if field.IsNil() {
105 value = nil 85 value = nil
106 } else { 86 } else {
107 if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok { 87 if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok {
108 value = vu 88 value = vu
109 } else { 89 } else {
110 value = nil 90 value = nil
...@@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, ...@@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
159 139
160 d.ins.ReplaceMarks(&query) 140 d.ins.ReplaceMarks(&query)
161 141
142 d.ins.HasReturningID(mi, &query)
143
162 stmt, err := q.Prepare(query) 144 stmt, err := q.Prepare(query)
163 return stmt, query, err 145 return stmt, query, err
164 } 146 }
...@@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) ...@@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value)
169 return 0, err 151 return 0, err
170 } 152 }
171 153
172 if res, err := stmt.Exec(values...); err == nil { 154 if d.ins.HasReturningID(mi, nil) {
173 return res.LastInsertId() 155 row := stmt.QueryRow(values...)
156 var id int64
157 err := row.Scan(&id)
158 return id, err
174 } else { 159 } else {
175 return 0, err 160 if res, err := stmt.Exec(values...); err == nil {
161 return res.LastInsertId()
162 } else {
163 return 0, err
164 }
176 } 165 }
177 } 166 }
178 167
179 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { 168 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
180 pkColumn, pkValue, ok := d.existPk(mi, ind) 169 pkColumn, pkValue, ok := getExistPk(mi, ind)
181 if ok == false { 170 if ok == false {
182 return ErrMissPK 171 return ErrMissPK
183 } 172 }
...@@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
237 226
238 d.ins.ReplaceMarks(&query) 227 d.ins.ReplaceMarks(&query)
239 228
240 if res, err := q.Exec(query, values...); err == nil { 229 if d.ins.HasReturningID(mi, &query) {
241 return res.LastInsertId() 230 row := q.QueryRow(query, values...)
231 var id int64
232 err := row.Scan(&id)
233 return id, err
242 } else { 234 } else {
243 return 0, err 235 if res, err := q.Exec(query, values...); err == nil {
236 return res.LastInsertId()
237 } else {
238 return 0, err
239 }
244 } 240 }
245 } 241 }
246 242
247 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { 243 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
248 pkName, pkValue, ok := d.existPk(mi, ind) 244 pkName, pkValue, ok := getExistPk(mi, ind)
249 if ok == false { 245 if ok == false {
250 return 0, ErrMissPK 246 return 0, ErrMissPK
251 } 247 }
...@@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
274 } 270 }
275 271
276 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { 272 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
277 pkName, pkValue, ok := d.existPk(mi, ind) 273 pkName, pkValue, ok := getExistPk(mi, ind)
278 if ok == false { 274 if ok == false {
279 return 0, ErrMissPK 275 return 0, ErrMissPK
280 } 276 }
...@@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
429 return 0, nil 425 return 0, nil
430 } 426 }
431 427
432 sql, args := d.ins.GenerateOperatorSql(mi, "in", args) 428 sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args)
433 query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) 429 query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
434 430
435 d.ins.ReplaceMarks(&query) 431 d.ins.ReplaceMarks(&query)
...@@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition ...@@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
616 return 612 return
617 } 613 }
618 614
619 func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) { 615 func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) {
620 for _, arg := range args { 616 sql := ""
621 val := reflect.ValueOf(arg) 617 params := getFlatParams(fi, args)
622
623 if arg == nil {
624 params = append(params, arg)
625 continue
626 }
627
628 kind := val.Kind()
629
630 switch kind {
631 case reflect.Slice, reflect.Array:
632 var args []interface{}
633 for i := 0; i < val.Len(); i++ {
634 v := val.Index(i)
635
636 var vu interface{}
637 if v.CanInterface() {
638 vu = v.Interface()
639 }
640
641 if vu == nil {
642 continue
643 }
644
645 args = append(args, vu)
646 }
647
648 if len(args) > 0 {
649 p := d.getOperatorParams(operator, args)
650 params = append(params, p...)
651 }
652
653 case reflect.Ptr, reflect.Struct:
654 ind := reflect.Indirect(val)
655
656 if ind.Kind() == reflect.Struct {
657 typ := ind.Type()
658 name := getFullName(typ)
659 var value interface{}
660 if mmi, ok := modelCache.getByFN(name); ok {
661 if _, vu, exist := d.existPk(mmi, ind); exist {
662 value = vu
663 }
664 }
665 arg = value
666
667 if arg == nil {
668 panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name))
669 }
670 } else {
671 arg = ind.Interface()
672 }
673
674 params = append(params, arg)
675
676 default:
677 params = append(params, arg)
678 }
679 618
619 if len(params) == 0 {
620 panic(fmt.Sprintf("operator `%s` need at least one args", operator))
680 } 621 }
681 622 arg := params[0]
682 return
683 }
684
685 func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
686 sql := ""
687 params := d.getOperatorParams(operator, args)
688 623
689 if operator == "in" { 624 if operator == "in" {
690 marks := make([]string, len(params)) 625 marks := make([]string, len(params))
...@@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte ...@@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte
697 panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) 632 panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
698 } 633 }
699 sql = d.ins.OperatorSql(operator) 634 sql = d.ins.OperatorSql(operator)
700 arg := params[0]
701 switch operator { 635 switch operator {
702 case "exact": 636 case "exact":
703 if arg == nil { 637 if arg == nil {
...@@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte ...@@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte
731 return sql, params 665 return sql, params
732 } 666 }
733 667
668 func (d *dbBase) GenerateOperatorLeftCol(string, *string) {
669
670 }
671
734 func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { 672 func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
735 for i, column := range cols { 673 for i, column := range cols {
736 val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() 674 val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
...@@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond ...@@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
1006 cols = make([]string, 0, len(exprs)) 944 cols = make([]string, 0, len(exprs))
1007 infos = make([]*fieldInfo, 0, len(exprs)) 945 infos = make([]*fieldInfo, 0, len(exprs))
1008 for _, ex := range exprs { 946 for _, ex := range exprs {
1009 index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) 947 index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
1010 if suc == false { 948 if suc == false {
1011 panic(fmt.Errorf("unknown field/column name `%s`", ex)) 949 panic(fmt.Errorf("unknown field/column name `%s`", ex))
1012 } 950 }
1013 cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q)) 951 cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q))
1014 infos = append(infos, fi) 952 infos = append(infos, fi)
1015 } 953 }
1016 } else { 954 } else {
...@@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string { ...@@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string {
1137 func (d *dbBase) ReplaceMarks(query *string) { 1075 func (d *dbBase) ReplaceMarks(query *string) {
1138 // default use `?` as mark, do nothing 1076 // default use `?` as mark, do nothing
1139 } 1077 }
1078
1079 func (d *dbBase) HasReturningID(*modelInfo, *string) bool {
1080 return false
1081 }
......
1 package orm 1 package orm
2 2
3 import ( 3 import (
4 "fmt"
4 "strconv" 5 "strconv"
5 ) 6 )
6 7
...@@ -29,6 +30,23 @@ func (d *dbBasePostgres) OperatorSql(operator string) string { ...@@ -29,6 +30,23 @@ func (d *dbBasePostgres) OperatorSql(operator string) string {
29 return postgresOperators[operator] 30 return postgresOperators[operator]
30 } 31 }
31 32
33 func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) {
34 switch operator {
35 case "contains", "startswith", "endswith":
36 *leftCol = fmt.Sprintf("%s::text", *leftCol)
37 case "iexact", "icontains", "istartswith", "iendswith":
38 *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
39 }
40 }
41
42 func (d *dbBasePostgres) SupportUpdateJoin() bool {
43 return false
44 }
45
46 func (d *dbBasePostgres) MaxLimit() uint64 {
47 return 0
48 }
49
32 func (d *dbBasePostgres) TableQuote() string { 50 func (d *dbBasePostgres) TableQuote() string {
33 return `"` 51 return `"`
34 } 52 }
...@@ -59,7 +77,15 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { ...@@ -59,7 +77,15 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) {
59 *query = string(data) 77 *query = string(data)
60 } 78 }
61 79
62 // func (d *dbBasePostgres) 80 func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) {
81 if mi.fields.pk.auto {
82 if query != nil {
83 *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column)
84 }
85 has = true
86 }
87 return
88 }
63 89
64 func newdbBasePostgres() dbBaser { 90 func newdbBasePostgres() dbBaser {
65 b := new(dbBasePostgres) 91 b := new(dbBasePostgres)
......
...@@ -177,7 +177,7 @@ func (t *dbTables) getJoinSql() (join string) { ...@@ -177,7 +177,7 @@ func (t *dbTables) getJoinSql() (join string) {
177 return 177 return
178 } 178 }
179 179
180 func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) { 180 func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
181 var ( 181 var (
182 ffi *fieldInfo 182 ffi *fieldInfo
183 jtl *dbTable 183 jtl *dbTable
...@@ -236,7 +236,6 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam ...@@ -236,7 +236,6 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
236 } else { 236 } else {
237 index = jtl.index 237 index = jtl.index
238 } 238 }
239 column = fi.column
240 info = fi 239 info = fi
241 if jtl != nil { 240 if jtl != nil {
242 name = jtl.name + ExprSep + fi.name 241 name = jtl.name + ExprSep + fi.name
...@@ -256,14 +255,14 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam ...@@ -256,14 +255,14 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
256 255
257 if exist == false { 256 if exist == false {
258 index = "" 257 index = ""
259 column = ""
260 name = "" 258 name = ""
259 info = nil
261 success = false 260 success = false
262 return 261 return
263 } 262 }
264 } 263 }
265 264
266 success = index != "" && column != "" 265 success = index != "" && info != nil
267 return 266 return
268 } 267 }
269 268
...@@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ ...@@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
305 exprs = exprs[:num] 304 exprs = exprs[:num]
306 } 305 }
307 306
308 index, column, _, _, suc := d.parseExprs(mi, exprs) 307 index, _, fi, suc := d.parseExprs(mi, exprs)
309 if suc == false { 308 if suc == false {
310 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) 309 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
311 } 310 }
...@@ -314,9 +313,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ ...@@ -314,9 +313,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [
314 operator = "exact" 313 operator = "exact"
315 } 314 }
316 315
317 operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args) 316 operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args)
318 317
319 where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql) 318 leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
319 d.base.GenerateOperatorLeftCol(operator, &leftCol)
320
321 where += fmt.Sprintf("%s %s ", leftCol, operSql)
320 params = append(params, args...) 322 params = append(params, args...)
321 323
322 } 324 }
...@@ -345,12 +347,12 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) { ...@@ -345,12 +347,12 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
345 } 347 }
346 exprs := strings.Split(order, ExprSep) 348 exprs := strings.Split(order, ExprSep)
347 349
348 index, column, _, _, suc := d.parseExprs(d.mi, exprs) 350 index, _, fi, suc := d.parseExprs(d.mi, exprs)
349 if suc == false { 351 if suc == false {
350 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) 352 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
351 } 353 }
352 354
353 orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc)) 355 orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
354 } 356 }
355 357
356 orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) 358 orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
...@@ -365,7 +367,11 @@ func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits s ...@@ -365,7 +367,11 @@ func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits s
365 // no limit 367 // no limit
366 if offset > 0 { 368 if offset > 0 {
367 maxLimit := d.base.MaxLimit() 369 maxLimit := d.base.MaxLimit()
368 limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) 370 if maxLimit == 0 {
371 limits = fmt.Sprintf("OFFSET %d", offset)
372 } else {
373 limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
374 }
369 } 375 }
370 } else if offset <= 0 { 376 } else if offset <= 0 {
371 limits = fmt.Sprintf("LIMIT %d", limit) 377 limits = fmt.Sprintf("LIMIT %d", limit)
......
1 package orm
2
3 import (
4 "fmt"
5 "reflect"
6 "time"
7 )
8
9 func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
10 fi := mi.fields.pk
11
12 v := ind.Field(fi.fieldIndex)
13 if fi.fieldType&IsIntegerField > 0 {
14 vu := v.Int()
15 exist = vu > 0
16 value = vu
17 } else {
18 vu := v.String()
19 exist = vu != ""
20 value = vu
21 }
22
23 column = fi.column
24 return
25 }
26
27 func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) {
28
29 outFor:
30 for _, arg := range args {
31 val := reflect.ValueOf(arg)
32
33 if arg == nil {
34 params = append(params, arg)
35 continue
36 }
37
38 switch v := arg.(type) {
39 case []byte:
40 case time.Time:
41 if fi != nil && fi.fieldType == TypeDateField {
42 arg = v.Format(format_Date)
43 } else {
44 arg = v.Format(format_DateTime)
45 }
46 default:
47 kind := val.Kind()
48 switch kind {
49 case reflect.Slice, reflect.Array:
50
51 var args []interface{}
52 for i := 0; i < val.Len(); i++ {
53 v := val.Index(i)
54
55 var vu interface{}
56 if v.CanInterface() {
57 vu = v.Interface()
58 }
59
60 if vu == nil {
61 continue
62 }
63
64 args = append(args, vu)
65 }
66
67 if len(args) > 0 {
68 p := getFlatParams(fi, args)
69 params = append(params, p...)
70 }
71 continue outFor
72
73 case reflect.Ptr, reflect.Struct:
74 ind := reflect.Indirect(val)
75
76 if ind.Kind() == reflect.Struct {
77 typ := ind.Type()
78 name := getFullName(typ)
79 var value interface{}
80 if mmi, ok := modelCache.getByFN(name); ok {
81 if _, vu, exist := getExistPk(mmi, ind); exist {
82 value = vu
83 }
84 }
85 arg = value
86
87 if arg == nil {
88 panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
89 }
90 } else {
91 arg = ind.Interface()
92 }
93 }
94 }
95 params = append(params, arg)
96 }
97 return
98 }
...@@ -302,7 +302,8 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a ...@@ -302,7 +302,8 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a
302 queries := strings.Split(initSQLs[DBARGS.Driver], ";") 302 queries := strings.Split(initSQLs[DBARGS.Driver], ";")
303 303
304 for _, query := range queries { 304 for _, query := range queries {
305 if strings.TrimSpace(query) == "" { 305 query = strings.TrimSpace(query)
306 if len(query) == 0 {
306 continue 307 continue
307 } 308 }
308 _, err := dORM.Raw(query).Exec() 309 _, err := dORM.Raw(query).Exec()
......
...@@ -22,7 +22,11 @@ func NewLog(out io.Writer) *Log { ...@@ -22,7 +22,11 @@ func NewLog(out io.Writer) *Log {
22 func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { 22 func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) {
23 sub := time.Now().Sub(t) / 1e5 23 sub := time.Now().Sub(t) / 1e5
24 elsp := float64(int(sub)) / 10.0 24 elsp := float64(int(sub)) / 10.0
25 con := fmt.Sprintf(" - %s - [Queries/%s] - [%11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, operaton, elsp, query) 25 flag := " OK"
26 if err != nil {
27 flag = "FAIL"
28 }
29 con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query)
26 cons := make([]string, 0, len(args)) 30 cons := make([]string, 0, len(args))
27 for _, arg := range args { 31 for _, arg := range args {
28 cons = append(cons, fmt.Sprintf("%v", arg)) 32 cons = append(cons, fmt.Sprintf("%v", arg))
......
...@@ -27,12 +27,16 @@ func (o *rawPrepare) Close() error { ...@@ -27,12 +27,16 @@ func (o *rawPrepare) Close() error {
27 func newRawPreparer(rs *rawSet) (RawPreparer, error) { 27 func newRawPreparer(rs *rawSet) (RawPreparer, error) {
28 o := new(rawPrepare) 28 o := new(rawPrepare)
29 o.rs = rs 29 o.rs = rs
30 st, err := rs.orm.db.Prepare(rs.query) 30
31 query := rs.query
32 rs.orm.alias.DbBaser.ReplaceMarks(&query)
33
34 st, err := rs.orm.db.Prepare(query)
31 if err != nil { 35 if err != nil {
32 return nil, err 36 return nil, err
33 } 37 }
34 if Debug { 38 if Debug {
35 o.stmt = newStmtQueryLog(rs.orm.alias, st, rs.query) 39 o.stmt = newStmtQueryLog(rs.orm.alias, st, query)
36 } else { 40 } else {
37 o.stmt = st 41 o.stmt = st
38 } 42 }
...@@ -53,7 +57,11 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { ...@@ -53,7 +57,11 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter {
53 } 57 }
54 58
55 func (o *rawSet) Exec() (sql.Result, error) { 59 func (o *rawSet) Exec() (sql.Result, error) {
56 return o.orm.db.Exec(o.query, o.args...) 60 query := o.query
61 o.orm.alias.DbBaser.ReplaceMarks(&query)
62
63 args := getFlatParams(nil, o.args)
64 return o.orm.db.Exec(query, args...)
57 } 65 }
58 66
59 func (o *rawSet) QueryRow(...interface{}) error { 67 func (o *rawSet) QueryRow(...interface{}) error {
...@@ -85,8 +93,13 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -85,8 +93,13 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
85 panic(fmt.Sprintf("unsupport read values type `%T`", container)) 93 panic(fmt.Sprintf("unsupport read values type `%T`", container))
86 } 94 }
87 95
96 query := o.query
97 o.orm.alias.DbBaser.ReplaceMarks(&query)
98
99 args := getFlatParams(nil, o.args)
100
88 var rs *sql.Rows 101 var rs *sql.Rows
89 if r, err := o.orm.db.Query(o.query, o.args...); err != nil { 102 if r, err := o.orm.db.Query(query, args...); err != nil {
90 return 0, err 103 return 0, err
91 } else { 104 } else {
92 rs = r 105 rs = r
......
...@@ -51,7 +51,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e ...@@ -51,7 +51,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e
51 if v2, vo := b.(time.Time); vo { 51 if v2, vo := b.(time.Time); vo {
52 if arg.Get(1) != nil { 52 if arg.Get(1) != nil {
53 format := ToStr(arg.Get(1)) 53 format := ToStr(arg.Get(1))
54 ok = v.Format(format) == v2.Format(format) 54 a = v.Format(format)
55 b = v2.Format(format)
56 ok = a == b
55 } else { 57 } else {
56 err = fmt.Errorf("compare datetime miss format") 58 err = fmt.Errorf("compare datetime miss format")
57 goto wrongArg 59 goto wrongArg
...@@ -363,6 +365,10 @@ func TestExpr(t *testing.T) { ...@@ -363,6 +365,10 @@ func TestExpr(t *testing.T) {
363 num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() 365 num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count()
364 throwFail(t, err) 366 throwFail(t, err)
365 throwFail(t, AssertIs(num, T_Equal, 1)) 367 throwFail(t, AssertIs(num, T_Equal, 1))
368
369 num, err = qs.Filter("created", time.Now()).Count()
370 throwFail(t, err)
371 throwFail(t, AssertIs(num, T_Equal, 3))
366 } 372 }
367 373
368 func TestOperators(t *testing.T) { 374 func TestOperators(t *testing.T) {
...@@ -722,6 +728,102 @@ func TestRaw(t *testing.T) { ...@@ -722,6 +728,102 @@ func TestRaw(t *testing.T) {
722 throwFail(t, AssertIs(list[1], T_Equal, "3")) 728 throwFail(t, AssertIs(list[1], T_Equal, "3"))
723 throwFail(t, AssertIs(list[2], T_Equal, "")) 729 throwFail(t, AssertIs(list[2], T_Equal, ""))
724 } 730 }
731
732 pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare()
733 throwFail(t, err)
734 if pre != nil {
735 r, err := pre.Exec("name1")
736 throwFail(t, err)
737
738 tid, err := r.LastInsertId()
739 throwFail(t, err)
740 throwFail(t, AssertIs(tid, T_Large, 0))
741
742 r, err = pre.Exec("name2")
743 throwFail(t, err)
744
745 id, err := r.LastInsertId()
746 throwFail(t, err)
747 throwFail(t, AssertIs(id, T_Equal, tid+1))
748
749 r, err = pre.Exec("name3")
750 throwFail(t, err)
751
752 id, err = r.LastInsertId()
753 throwFail(t, err)
754 throwFail(t, AssertIs(id, T_Equal, tid+2))
755
756 err = pre.Close()
757 throwFail(t, err)
758
759 res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec()
760 throwFail(t, err)
761
762 num, err := res.RowsAffected()
763 throwFail(t, err)
764 throwFail(t, AssertIs(num, T_Equal, 3))
765 }
766
767 case IsPostgres:
768
769 res, err := dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "testing", "slene").Exec()
770 throwFail(t, err)
771 num, err := res.RowsAffected()
772 throwFail(t, AssertIs(num, T_Equal, 1), err)
773
774 res, err = dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "slene", "testing").Exec()
775 throwFail(t, err)
776 num, err = res.RowsAffected()
777 throwFail(t, AssertIs(num, T_Equal, 1), err)
778
779 var maps []Params
780 num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).Values(&maps)
781 throwFail(t, err)
782 throwFail(t, AssertIs(num, T_Equal, 1))
783 if num == 1 {
784 throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene"))
785 }
786
787 var lists []ParamsList
788 num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).ValuesList(&lists)
789 throwFail(t, err)
790 throwFail(t, AssertIs(num, T_Equal, 1))
791 if num == 1 {
792 throwFail(t, AssertIs(lists[0][0], T_Equal, "slene"))
793 }
794
795 var list ParamsList
796 num, err = dORM.Raw(`SELECT "profile_id" FROM "user" ORDER BY id ASC`).ValuesFlat(&list)
797 throwFail(t, err)
798 throwFail(t, AssertIs(num, T_Equal, 3))
799 if num == 3 {
800 throwFail(t, AssertIs(list[0], T_Equal, "2"))
801 throwFail(t, AssertIs(list[1], T_Equal, "3"))
802 throwFail(t, AssertIs(list[2], T_Equal, ""))
803 }
804
805 pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare()
806 throwFail(t, err)
807 if pre != nil {
808 _, err := pre.Exec("name1")
809 throwFail(t, err)
810
811 _, err = pre.Exec("name2")
812 throwFail(t, err)
813
814 _, err = pre.Exec("name3")
815 throwFail(t, err)
816
817 err = pre.Close()
818 throwFail(t, err)
819
820 res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec()
821 throwFail(t, err)
822
823 num, err := res.RowsAffected()
824 throwFail(t, err)
825 throwFail(t, AssertIs(num, T_Equal, 3))
826 }
725 } 827 }
726 } 828 }
727 829
......
...@@ -121,10 +121,12 @@ type dbBaser interface { ...@@ -121,10 +121,12 @@ type dbBaser interface {
121 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) 121 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
122 Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) 122 Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
123 OperatorSql(string) string 123 OperatorSql(string) string
124 GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) 124 GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{})
125 GenerateOperatorLeftCol(string, *string)
125 PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) 126 PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
126 ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) 127 ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
127 MaxLimit() uint64 128 MaxLimit() uint64
128 TableQuote() string 129 TableQuote() string
129 ReplaceMarks(*string) 130 ReplaceMarks(*string)
131 HasReturningID(*modelInfo, *string) bool
130 } 132 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!