orm add postgres support
Showing
9 changed files
with
315 additions
and
121 deletions
| ... | @@ -49,28 +49,8 @@ type dbBase struct { | ... | @@ -49,28 +49,8 @@ type dbBase struct { |
| 49 | ins dbBaser | 49 | ins dbBaser |
| 50 | } | 50 | } |
| 51 | 51 | ||
| 52 | func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { | ||
| 53 | |||
| 54 | fi := mi.fields.pk | ||
| 55 | |||
| 56 | v := ind.Field(fi.fieldIndex) | ||
| 57 | if fi.fieldType&IsIntegerField > 0 { | ||
| 58 | vu := v.Int() | ||
| 59 | exist = vu > 0 | ||
| 60 | value = vu | ||
| 61 | } else { | ||
| 62 | vu := v.String() | ||
| 63 | exist = vu != "" | ||
| 64 | value = vu | ||
| 65 | } | ||
| 66 | |||
| 67 | column = fi.column | ||
| 68 | |||
| 69 | return | ||
| 70 | } | ||
| 71 | |||
| 72 | func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { | 52 | func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { |
| 73 | _, pkValue, _ := d.existPk(mi, ind) | 53 | _, pkValue, _ := getExistPk(mi, ind) |
| 74 | for _, column := range mi.fields.orders { | 54 | for _, column := range mi.fields.orders { |
| 75 | fi := mi.fields.columns[column] | 55 | fi := mi.fields.columns[column] |
| 76 | if fi.dbcol == false || fi.auto && skipAuto { | 56 | if fi.dbcol == false || fi.auto && skipAuto { |
| ... | @@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, | ... | @@ -104,7 +84,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, |
| 104 | if field.IsNil() { | 84 | if field.IsNil() { |
| 105 | value = nil | 85 | value = nil |
| 106 | } else { | 86 | } else { |
| 107 | if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok { | 87 | if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { |
| 108 | value = vu | 88 | value = vu |
| 109 | } else { | 89 | } else { |
| 110 | value = nil | 90 | value = nil |
| ... | @@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, | ... | @@ -159,6 +139,8 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, |
| 159 | 139 | ||
| 160 | d.ins.ReplaceMarks(&query) | 140 | d.ins.ReplaceMarks(&query) |
| 161 | 141 | ||
| 142 | d.ins.HasReturningID(mi, &query) | ||
| 143 | |||
| 162 | stmt, err := q.Prepare(query) | 144 | stmt, err := q.Prepare(query) |
| 163 | return stmt, query, err | 145 | return stmt, query, err |
| 164 | } | 146 | } |
| ... | @@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) | ... | @@ -169,15 +151,22 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value) |
| 169 | return 0, err | 151 | return 0, err |
| 170 | } | 152 | } |
| 171 | 153 | ||
| 172 | if res, err := stmt.Exec(values...); err == nil { | 154 | if d.ins.HasReturningID(mi, nil) { |
| 173 | return res.LastInsertId() | 155 | row := stmt.QueryRow(values...) |
| 156 | var id int64 | ||
| 157 | err := row.Scan(&id) | ||
| 158 | return id, err | ||
| 174 | } else { | 159 | } else { |
| 175 | return 0, err | 160 | if res, err := stmt.Exec(values...); err == nil { |
| 161 | return res.LastInsertId() | ||
| 162 | } else { | ||
| 163 | return 0, err | ||
| 164 | } | ||
| 176 | } | 165 | } |
| 177 | } | 166 | } |
| 178 | 167 | ||
| 179 | func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { | 168 | func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { |
| 180 | pkColumn, pkValue, ok := d.existPk(mi, ind) | 169 | pkColumn, pkValue, ok := getExistPk(mi, ind) |
| 181 | if ok == false { | 170 | if ok == false { |
| 182 | return ErrMissPK | 171 | return ErrMissPK |
| 183 | } | 172 | } |
| ... | @@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e | ... | @@ -237,15 +226,22 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e |
| 237 | 226 | ||
| 238 | d.ins.ReplaceMarks(&query) | 227 | d.ins.ReplaceMarks(&query) |
| 239 | 228 | ||
| 240 | if res, err := q.Exec(query, values...); err == nil { | 229 | if d.ins.HasReturningID(mi, &query) { |
| 241 | return res.LastInsertId() | 230 | row := q.QueryRow(query, values...) |
| 231 | var id int64 | ||
| 232 | err := row.Scan(&id) | ||
| 233 | return id, err | ||
| 242 | } else { | 234 | } else { |
| 243 | return 0, err | 235 | if res, err := q.Exec(query, values...); err == nil { |
| 236 | return res.LastInsertId() | ||
| 237 | } else { | ||
| 238 | return 0, err | ||
| 239 | } | ||
| 244 | } | 240 | } |
| 245 | } | 241 | } |
| 246 | 242 | ||
| 247 | func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { | 243 | func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { |
| 248 | pkName, pkValue, ok := d.existPk(mi, ind) | 244 | pkName, pkValue, ok := getExistPk(mi, ind) |
| 249 | if ok == false { | 245 | if ok == false { |
| 250 | return 0, ErrMissPK | 246 | return 0, ErrMissPK |
| 251 | } | 247 | } |
| ... | @@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e | ... | @@ -274,7 +270,7 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e |
| 274 | } | 270 | } |
| 275 | 271 | ||
| 276 | func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { | 272 | func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { |
| 277 | pkName, pkValue, ok := d.existPk(mi, ind) | 273 | pkName, pkValue, ok := getExistPk(mi, ind) |
| 278 | if ok == false { | 274 | if ok == false { |
| 279 | return 0, ErrMissPK | 275 | return 0, ErrMissPK |
| 280 | } | 276 | } |
| ... | @@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con | ... | @@ -429,7 +425,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con |
| 429 | return 0, nil | 425 | return 0, nil |
| 430 | } | 426 | } |
| 431 | 427 | ||
| 432 | sql, args := d.ins.GenerateOperatorSql(mi, "in", args) | 428 | sql, args := d.ins.GenerateOperatorSql(mi, mi.fields.pk, "in", args) |
| 433 | query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) | 429 | query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) |
| 434 | 430 | ||
| 435 | d.ins.ReplaceMarks(&query) | 431 | d.ins.ReplaceMarks(&query) |
| ... | @@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition | ... | @@ -616,75 +612,14 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition |
| 616 | return | 612 | return |
| 617 | } | 613 | } |
| 618 | 614 | ||
| 619 | func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params []interface{}) { | 615 | func (d *dbBase) GenerateOperatorSql(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}) (string, []interface{}) { |
| 620 | for _, arg := range args { | 616 | sql := "" |
| 621 | val := reflect.ValueOf(arg) | 617 | params := getFlatParams(fi, args) |
| 622 | |||
| 623 | if arg == nil { | ||
| 624 | params = append(params, arg) | ||
| 625 | continue | ||
| 626 | } | ||
| 627 | |||
| 628 | kind := val.Kind() | ||
| 629 | |||
| 630 | switch kind { | ||
| 631 | case reflect.Slice, reflect.Array: | ||
| 632 | var args []interface{} | ||
| 633 | for i := 0; i < val.Len(); i++ { | ||
| 634 | v := val.Index(i) | ||
| 635 | |||
| 636 | var vu interface{} | ||
| 637 | if v.CanInterface() { | ||
| 638 | vu = v.Interface() | ||
| 639 | } | ||
| 640 | |||
| 641 | if vu == nil { | ||
| 642 | continue | ||
| 643 | } | ||
| 644 | |||
| 645 | args = append(args, vu) | ||
| 646 | } | ||
| 647 | |||
| 648 | if len(args) > 0 { | ||
| 649 | p := d.getOperatorParams(operator, args) | ||
| 650 | params = append(params, p...) | ||
| 651 | } | ||
| 652 | |||
| 653 | case reflect.Ptr, reflect.Struct: | ||
| 654 | ind := reflect.Indirect(val) | ||
| 655 | |||
| 656 | if ind.Kind() == reflect.Struct { | ||
| 657 | typ := ind.Type() | ||
| 658 | name := getFullName(typ) | ||
| 659 | var value interface{} | ||
| 660 | if mmi, ok := modelCache.getByFN(name); ok { | ||
| 661 | if _, vu, exist := d.existPk(mmi, ind); exist { | ||
| 662 | value = vu | ||
| 663 | } | ||
| 664 | } | ||
| 665 | arg = value | ||
| 666 | |||
| 667 | if arg == nil { | ||
| 668 | panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name)) | ||
| 669 | } | ||
| 670 | } else { | ||
| 671 | arg = ind.Interface() | ||
| 672 | } | ||
| 673 | |||
| 674 | params = append(params, arg) | ||
| 675 | |||
| 676 | default: | ||
| 677 | params = append(params, arg) | ||
| 678 | } | ||
| 679 | 618 | ||
| 619 | if len(params) == 0 { | ||
| 620 | panic(fmt.Sprintf("operator `%s` need at least one args", operator)) | ||
| 680 | } | 621 | } |
| 681 | 622 | arg := params[0] | |
| 682 | return | ||
| 683 | } | ||
| 684 | |||
| 685 | func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { | ||
| 686 | sql := "" | ||
| 687 | params := d.getOperatorParams(operator, args) | ||
| 688 | 623 | ||
| 689 | if operator == "in" { | 624 | if operator == "in" { |
| 690 | marks := make([]string, len(params)) | 625 | marks := make([]string, len(params)) |
| ... | @@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte | ... | @@ -697,7 +632,6 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte |
| 697 | panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) | 632 | panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) |
| 698 | } | 633 | } |
| 699 | sql = d.ins.OperatorSql(operator) | 634 | sql = d.ins.OperatorSql(operator) |
| 700 | arg := params[0] | ||
| 701 | switch operator { | 635 | switch operator { |
| 702 | case "exact": | 636 | case "exact": |
| 703 | if arg == nil { | 637 | if arg == nil { |
| ... | @@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte | ... | @@ -731,6 +665,10 @@ func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []inte |
| 731 | return sql, params | 665 | return sql, params |
| 732 | } | 666 | } |
| 733 | 667 | ||
| 668 | func (d *dbBase) GenerateOperatorLeftCol(string, *string) { | ||
| 669 | |||
| 670 | } | ||
| 671 | |||
| 734 | func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { | 672 | func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) { |
| 735 | for i, column := range cols { | 673 | for i, column := range cols { |
| 736 | val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() | 674 | val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() |
| ... | @@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond | ... | @@ -1006,11 +944,11 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond |
| 1006 | cols = make([]string, 0, len(exprs)) | 944 | cols = make([]string, 0, len(exprs)) |
| 1007 | infos = make([]*fieldInfo, 0, len(exprs)) | 945 | infos = make([]*fieldInfo, 0, len(exprs)) |
| 1008 | for _, ex := range exprs { | 946 | for _, ex := range exprs { |
| 1009 | index, col, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) | 947 | index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) |
| 1010 | if suc == false { | 948 | if suc == false { |
| 1011 | panic(fmt.Errorf("unknown field/column name `%s`", ex)) | 949 | panic(fmt.Errorf("unknown field/column name `%s`", ex)) |
| 1012 | } | 950 | } |
| 1013 | cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q)) | 951 | cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) |
| 1014 | infos = append(infos, fi) | 952 | infos = append(infos, fi) |
| 1015 | } | 953 | } |
| 1016 | } else { | 954 | } else { |
| ... | @@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string { | ... | @@ -1137,3 +1075,7 @@ func (d *dbBase) TableQuote() string { |
| 1137 | func (d *dbBase) ReplaceMarks(query *string) { | 1075 | func (d *dbBase) ReplaceMarks(query *string) { |
| 1138 | // default use `?` as mark, do nothing | 1076 | // default use `?` as mark, do nothing |
| 1139 | } | 1077 | } |
| 1078 | |||
| 1079 | func (d *dbBase) HasReturningID(*modelInfo, *string) bool { | ||
| 1080 | return false | ||
| 1081 | } | ... | ... |
| 1 | package orm | 1 | package orm |
| 2 | 2 | ||
| 3 | import ( | 3 | import ( |
| 4 | "fmt" | ||
| 4 | "strconv" | 5 | "strconv" |
| 5 | ) | 6 | ) |
| 6 | 7 | ||
| ... | @@ -29,6 +30,23 @@ func (d *dbBasePostgres) OperatorSql(operator string) string { | ... | @@ -29,6 +30,23 @@ func (d *dbBasePostgres) OperatorSql(operator string) string { |
| 29 | return postgresOperators[operator] | 30 | return postgresOperators[operator] |
| 30 | } | 31 | } |
| 31 | 32 | ||
| 33 | func (d *dbBasePostgres) GenerateOperatorLeftCol(operator string, leftCol *string) { | ||
| 34 | switch operator { | ||
| 35 | case "contains", "startswith", "endswith": | ||
| 36 | *leftCol = fmt.Sprintf("%s::text", *leftCol) | ||
| 37 | case "iexact", "icontains", "istartswith", "iendswith": | ||
| 38 | *leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol) | ||
| 39 | } | ||
| 40 | } | ||
| 41 | |||
| 42 | func (d *dbBasePostgres) SupportUpdateJoin() bool { | ||
| 43 | return false | ||
| 44 | } | ||
| 45 | |||
| 46 | func (d *dbBasePostgres) MaxLimit() uint64 { | ||
| 47 | return 0 | ||
| 48 | } | ||
| 49 | |||
| 32 | func (d *dbBasePostgres) TableQuote() string { | 50 | func (d *dbBasePostgres) TableQuote() string { |
| 33 | return `"` | 51 | return `"` |
| 34 | } | 52 | } |
| ... | @@ -59,7 +77,15 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { | ... | @@ -59,7 +77,15 @@ func (d *dbBasePostgres) ReplaceMarks(query *string) { |
| 59 | *query = string(data) | 77 | *query = string(data) |
| 60 | } | 78 | } |
| 61 | 79 | ||
| 62 | // func (d *dbBasePostgres) | 80 | func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool) { |
| 81 | if mi.fields.pk.auto { | ||
| 82 | if query != nil { | ||
| 83 | *query = fmt.Sprintf(`%s RETURNING "%s"`, *query, mi.fields.pk.column) | ||
| 84 | } | ||
| 85 | has = true | ||
| 86 | } | ||
| 87 | return | ||
| 88 | } | ||
| 63 | 89 | ||
| 64 | func newdbBasePostgres() dbBaser { | 90 | func newdbBasePostgres() dbBaser { |
| 65 | b := new(dbBasePostgres) | 91 | b := new(dbBasePostgres) | ... | ... |
| ... | @@ -177,7 +177,7 @@ func (t *dbTables) getJoinSql() (join string) { | ... | @@ -177,7 +177,7 @@ func (t *dbTables) getJoinSql() (join string) { |
| 177 | return | 177 | return |
| 178 | } | 178 | } |
| 179 | 179 | ||
| 180 | func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) { | 180 | func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { |
| 181 | var ( | 181 | var ( |
| 182 | ffi *fieldInfo | 182 | ffi *fieldInfo |
| 183 | jtl *dbTable | 183 | jtl *dbTable |
| ... | @@ -236,7 +236,6 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam | ... | @@ -236,7 +236,6 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam |
| 236 | } else { | 236 | } else { |
| 237 | index = jtl.index | 237 | index = jtl.index |
| 238 | } | 238 | } |
| 239 | column = fi.column | ||
| 240 | info = fi | 239 | info = fi |
| 241 | if jtl != nil { | 240 | if jtl != nil { |
| 242 | name = jtl.name + ExprSep + fi.name | 241 | name = jtl.name + ExprSep + fi.name |
| ... | @@ -256,14 +255,14 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam | ... | @@ -256,14 +255,14 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam |
| 256 | 255 | ||
| 257 | if exist == false { | 256 | if exist == false { |
| 258 | index = "" | 257 | index = "" |
| 259 | column = "" | ||
| 260 | name = "" | 258 | name = "" |
| 259 | info = nil | ||
| 261 | success = false | 260 | success = false |
| 262 | return | 261 | return |
| 263 | } | 262 | } |
| 264 | } | 263 | } |
| 265 | 264 | ||
| 266 | success = index != "" && column != "" | 265 | success = index != "" && info != nil |
| 267 | return | 266 | return |
| 268 | } | 267 | } |
| 269 | 268 | ||
| ... | @@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ | ... | @@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ |
| 305 | exprs = exprs[:num] | 304 | exprs = exprs[:num] |
| 306 | } | 305 | } |
| 307 | 306 | ||
| 308 | index, column, _, _, suc := d.parseExprs(mi, exprs) | 307 | index, _, fi, suc := d.parseExprs(mi, exprs) |
| 309 | if suc == false { | 308 | if suc == false { |
| 310 | panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) | 309 | panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) |
| 311 | } | 310 | } |
| ... | @@ -314,9 +313,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ | ... | @@ -314,9 +313,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params [ |
| 314 | operator = "exact" | 313 | operator = "exact" |
| 315 | } | 314 | } |
| 316 | 315 | ||
| 317 | operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args) | 316 | operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args) |
| 318 | 317 | ||
| 319 | where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql) | 318 | leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) |
| 319 | d.base.GenerateOperatorLeftCol(operator, &leftCol) | ||
| 320 | |||
| 321 | where += fmt.Sprintf("%s %s ", leftCol, operSql) | ||
| 320 | params = append(params, args...) | 322 | params = append(params, args...) |
| 321 | 323 | ||
| 322 | } | 324 | } |
| ... | @@ -345,12 +347,12 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) { | ... | @@ -345,12 +347,12 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) { |
| 345 | } | 347 | } |
| 346 | exprs := strings.Split(order, ExprSep) | 348 | exprs := strings.Split(order, ExprSep) |
| 347 | 349 | ||
| 348 | index, column, _, _, suc := d.parseExprs(d.mi, exprs) | 350 | index, _, fi, suc := d.parseExprs(d.mi, exprs) |
| 349 | if suc == false { | 351 | if suc == false { |
| 350 | panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) | 352 | panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) |
| 351 | } | 353 | } |
| 352 | 354 | ||
| 353 | orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc)) | 355 | orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc)) |
| 354 | } | 356 | } |
| 355 | 357 | ||
| 356 | orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) | 358 | orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", ")) |
| ... | @@ -365,7 +367,11 @@ func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits s | ... | @@ -365,7 +367,11 @@ func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits s |
| 365 | // no limit | 367 | // no limit |
| 366 | if offset > 0 { | 368 | if offset > 0 { |
| 367 | maxLimit := d.base.MaxLimit() | 369 | maxLimit := d.base.MaxLimit() |
| 368 | limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) | 370 | if maxLimit == 0 { |
| 371 | limits = fmt.Sprintf("OFFSET %d", offset) | ||
| 372 | } else { | ||
| 373 | limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset) | ||
| 374 | } | ||
| 369 | } | 375 | } |
| 370 | } else if offset <= 0 { | 376 | } else if offset <= 0 { |
| 371 | limits = fmt.Sprintf("LIMIT %d", limit) | 377 | limits = fmt.Sprintf("LIMIT %d", limit) | ... | ... |
orm/db_utils.go
0 → 100644
| 1 | package orm | ||
| 2 | |||
| 3 | import ( | ||
| 4 | "fmt" | ||
| 5 | "reflect" | ||
| 6 | "time" | ||
| 7 | ) | ||
| 8 | |||
| 9 | func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) { | ||
| 10 | fi := mi.fields.pk | ||
| 11 | |||
| 12 | v := ind.Field(fi.fieldIndex) | ||
| 13 | if fi.fieldType&IsIntegerField > 0 { | ||
| 14 | vu := v.Int() | ||
| 15 | exist = vu > 0 | ||
| 16 | value = vu | ||
| 17 | } else { | ||
| 18 | vu := v.String() | ||
| 19 | exist = vu != "" | ||
| 20 | value = vu | ||
| 21 | } | ||
| 22 | |||
| 23 | column = fi.column | ||
| 24 | return | ||
| 25 | } | ||
| 26 | |||
| 27 | func getFlatParams(fi *fieldInfo, args []interface{}) (params []interface{}) { | ||
| 28 | |||
| 29 | outFor: | ||
| 30 | for _, arg := range args { | ||
| 31 | val := reflect.ValueOf(arg) | ||
| 32 | |||
| 33 | if arg == nil { | ||
| 34 | params = append(params, arg) | ||
| 35 | continue | ||
| 36 | } | ||
| 37 | |||
| 38 | switch v := arg.(type) { | ||
| 39 | case []byte: | ||
| 40 | case time.Time: | ||
| 41 | if fi != nil && fi.fieldType == TypeDateField { | ||
| 42 | arg = v.Format(format_Date) | ||
| 43 | } else { | ||
| 44 | arg = v.Format(format_DateTime) | ||
| 45 | } | ||
| 46 | default: | ||
| 47 | kind := val.Kind() | ||
| 48 | switch kind { | ||
| 49 | case reflect.Slice, reflect.Array: | ||
| 50 | |||
| 51 | var args []interface{} | ||
| 52 | for i := 0; i < val.Len(); i++ { | ||
| 53 | v := val.Index(i) | ||
| 54 | |||
| 55 | var vu interface{} | ||
| 56 | if v.CanInterface() { | ||
| 57 | vu = v.Interface() | ||
| 58 | } | ||
| 59 | |||
| 60 | if vu == nil { | ||
| 61 | continue | ||
| 62 | } | ||
| 63 | |||
| 64 | args = append(args, vu) | ||
| 65 | } | ||
| 66 | |||
| 67 | if len(args) > 0 { | ||
| 68 | p := getFlatParams(fi, args) | ||
| 69 | params = append(params, p...) | ||
| 70 | } | ||
| 71 | continue outFor | ||
| 72 | |||
| 73 | case reflect.Ptr, reflect.Struct: | ||
| 74 | ind := reflect.Indirect(val) | ||
| 75 | |||
| 76 | if ind.Kind() == reflect.Struct { | ||
| 77 | typ := ind.Type() | ||
| 78 | name := getFullName(typ) | ||
| 79 | var value interface{} | ||
| 80 | if mmi, ok := modelCache.getByFN(name); ok { | ||
| 81 | if _, vu, exist := getExistPk(mmi, ind); exist { | ||
| 82 | value = vu | ||
| 83 | } | ||
| 84 | } | ||
| 85 | arg = value | ||
| 86 | |||
| 87 | if arg == nil { | ||
| 88 | panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name)) | ||
| 89 | } | ||
| 90 | } else { | ||
| 91 | arg = ind.Interface() | ||
| 92 | } | ||
| 93 | } | ||
| 94 | } | ||
| 95 | params = append(params, arg) | ||
| 96 | } | ||
| 97 | return | ||
| 98 | } |
| ... | @@ -302,7 +302,8 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a | ... | @@ -302,7 +302,8 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a |
| 302 | queries := strings.Split(initSQLs[DBARGS.Driver], ";") | 302 | queries := strings.Split(initSQLs[DBARGS.Driver], ";") |
| 303 | 303 | ||
| 304 | for _, query := range queries { | 304 | for _, query := range queries { |
| 305 | if strings.TrimSpace(query) == "" { | 305 | query = strings.TrimSpace(query) |
| 306 | if len(query) == 0 { | ||
| 306 | continue | 307 | continue |
| 307 | } | 308 | } |
| 308 | _, err := dORM.Raw(query).Exec() | 309 | _, err := dORM.Raw(query).Exec() | ... | ... |
| ... | @@ -22,7 +22,11 @@ func NewLog(out io.Writer) *Log { | ... | @@ -22,7 +22,11 @@ func NewLog(out io.Writer) *Log { |
| 22 | func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { | 22 | func debugLogQueies(alias *alias, operaton, query string, t time.Time, err error, args ...interface{}) { |
| 23 | sub := time.Now().Sub(t) / 1e5 | 23 | sub := time.Now().Sub(t) / 1e5 |
| 24 | elsp := float64(int(sub)) / 10.0 | 24 | elsp := float64(int(sub)) / 10.0 |
| 25 | con := fmt.Sprintf(" - %s - [Queries/%s] - [%11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, operaton, elsp, query) | 25 | flag := " OK" |
| 26 | if err != nil { | ||
| 27 | flag = "FAIL" | ||
| 28 | } | ||
| 29 | con := fmt.Sprintf(" - %s - [Queries/%s] - [%s / %11s / %7.1fms] - [%s]", t.Format(format_DateTime), alias.Name, flag, operaton, elsp, query) | ||
| 26 | cons := make([]string, 0, len(args)) | 30 | cons := make([]string, 0, len(args)) |
| 27 | for _, arg := range args { | 31 | for _, arg := range args { |
| 28 | cons = append(cons, fmt.Sprintf("%v", arg)) | 32 | cons = append(cons, fmt.Sprintf("%v", arg)) | ... | ... |
| ... | @@ -27,12 +27,16 @@ func (o *rawPrepare) Close() error { | ... | @@ -27,12 +27,16 @@ func (o *rawPrepare) Close() error { |
| 27 | func newRawPreparer(rs *rawSet) (RawPreparer, error) { | 27 | func newRawPreparer(rs *rawSet) (RawPreparer, error) { |
| 28 | o := new(rawPrepare) | 28 | o := new(rawPrepare) |
| 29 | o.rs = rs | 29 | o.rs = rs |
| 30 | st, err := rs.orm.db.Prepare(rs.query) | 30 | |
| 31 | query := rs.query | ||
| 32 | rs.orm.alias.DbBaser.ReplaceMarks(&query) | ||
| 33 | |||
| 34 | st, err := rs.orm.db.Prepare(query) | ||
| 31 | if err != nil { | 35 | if err != nil { |
| 32 | return nil, err | 36 | return nil, err |
| 33 | } | 37 | } |
| 34 | if Debug { | 38 | if Debug { |
| 35 | o.stmt = newStmtQueryLog(rs.orm.alias, st, rs.query) | 39 | o.stmt = newStmtQueryLog(rs.orm.alias, st, query) |
| 36 | } else { | 40 | } else { |
| 37 | o.stmt = st | 41 | o.stmt = st |
| 38 | } | 42 | } |
| ... | @@ -53,7 +57,11 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { | ... | @@ -53,7 +57,11 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { |
| 53 | } | 57 | } |
| 54 | 58 | ||
| 55 | func (o *rawSet) Exec() (sql.Result, error) { | 59 | func (o *rawSet) Exec() (sql.Result, error) { |
| 56 | return o.orm.db.Exec(o.query, o.args...) | 60 | query := o.query |
| 61 | o.orm.alias.DbBaser.ReplaceMarks(&query) | ||
| 62 | |||
| 63 | args := getFlatParams(nil, o.args) | ||
| 64 | return o.orm.db.Exec(query, args...) | ||
| 57 | } | 65 | } |
| 58 | 66 | ||
| 59 | func (o *rawSet) QueryRow(...interface{}) error { | 67 | func (o *rawSet) QueryRow(...interface{}) error { |
| ... | @@ -85,8 +93,13 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { | ... | @@ -85,8 +93,13 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { |
| 85 | panic(fmt.Sprintf("unsupport read values type `%T`", container)) | 93 | panic(fmt.Sprintf("unsupport read values type `%T`", container)) |
| 86 | } | 94 | } |
| 87 | 95 | ||
| 96 | query := o.query | ||
| 97 | o.orm.alias.DbBaser.ReplaceMarks(&query) | ||
| 98 | |||
| 99 | args := getFlatParams(nil, o.args) | ||
| 100 | |||
| 88 | var rs *sql.Rows | 101 | var rs *sql.Rows |
| 89 | if r, err := o.orm.db.Query(o.query, o.args...); err != nil { | 102 | if r, err := o.orm.db.Query(query, args...); err != nil { |
| 90 | return 0, err | 103 | return 0, err |
| 91 | } else { | 104 | } else { |
| 92 | rs = r | 105 | rs = r | ... | ... |
| ... | @@ -51,7 +51,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e | ... | @@ -51,7 +51,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e |
| 51 | if v2, vo := b.(time.Time); vo { | 51 | if v2, vo := b.(time.Time); vo { |
| 52 | if arg.Get(1) != nil { | 52 | if arg.Get(1) != nil { |
| 53 | format := ToStr(arg.Get(1)) | 53 | format := ToStr(arg.Get(1)) |
| 54 | ok = v.Format(format) == v2.Format(format) | 54 | a = v.Format(format) |
| 55 | b = v2.Format(format) | ||
| 56 | ok = a == b | ||
| 55 | } else { | 57 | } else { |
| 56 | err = fmt.Errorf("compare datetime miss format") | 58 | err = fmt.Errorf("compare datetime miss format") |
| 57 | goto wrongArg | 59 | goto wrongArg |
| ... | @@ -363,6 +365,10 @@ func TestExpr(t *testing.T) { | ... | @@ -363,6 +365,10 @@ func TestExpr(t *testing.T) { |
| 363 | num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() | 365 | num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count() |
| 364 | throwFail(t, err) | 366 | throwFail(t, err) |
| 365 | throwFail(t, AssertIs(num, T_Equal, 1)) | 367 | throwFail(t, AssertIs(num, T_Equal, 1)) |
| 368 | |||
| 369 | num, err = qs.Filter("created", time.Now()).Count() | ||
| 370 | throwFail(t, err) | ||
| 371 | throwFail(t, AssertIs(num, T_Equal, 3)) | ||
| 366 | } | 372 | } |
| 367 | 373 | ||
| 368 | func TestOperators(t *testing.T) { | 374 | func TestOperators(t *testing.T) { |
| ... | @@ -722,6 +728,102 @@ func TestRaw(t *testing.T) { | ... | @@ -722,6 +728,102 @@ func TestRaw(t *testing.T) { |
| 722 | throwFail(t, AssertIs(list[1], T_Equal, "3")) | 728 | throwFail(t, AssertIs(list[1], T_Equal, "3")) |
| 723 | throwFail(t, AssertIs(list[2], T_Equal, "")) | 729 | throwFail(t, AssertIs(list[2], T_Equal, "")) |
| 724 | } | 730 | } |
| 731 | |||
| 732 | pre, err := dORM.Raw("INSERT INTO tag (name) VALUES (?)").Prepare() | ||
| 733 | throwFail(t, err) | ||
| 734 | if pre != nil { | ||
| 735 | r, err := pre.Exec("name1") | ||
| 736 | throwFail(t, err) | ||
| 737 | |||
| 738 | tid, err := r.LastInsertId() | ||
| 739 | throwFail(t, err) | ||
| 740 | throwFail(t, AssertIs(tid, T_Large, 0)) | ||
| 741 | |||
| 742 | r, err = pre.Exec("name2") | ||
| 743 | throwFail(t, err) | ||
| 744 | |||
| 745 | id, err := r.LastInsertId() | ||
| 746 | throwFail(t, err) | ||
| 747 | throwFail(t, AssertIs(id, T_Equal, tid+1)) | ||
| 748 | |||
| 749 | r, err = pre.Exec("name3") | ||
| 750 | throwFail(t, err) | ||
| 751 | |||
| 752 | id, err = r.LastInsertId() | ||
| 753 | throwFail(t, err) | ||
| 754 | throwFail(t, AssertIs(id, T_Equal, tid+2)) | ||
| 755 | |||
| 756 | err = pre.Close() | ||
| 757 | throwFail(t, err) | ||
| 758 | |||
| 759 | res, err := dORM.Raw("DELETE FROM tag WHERE name IN (?, ?, ?)", []string{"name1", "name2", "name3"}).Exec() | ||
| 760 | throwFail(t, err) | ||
| 761 | |||
| 762 | num, err := res.RowsAffected() | ||
| 763 | throwFail(t, err) | ||
| 764 | throwFail(t, AssertIs(num, T_Equal, 3)) | ||
| 765 | } | ||
| 766 | |||
| 767 | case IsPostgres: | ||
| 768 | |||
| 769 | res, err := dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "testing", "slene").Exec() | ||
| 770 | throwFail(t, err) | ||
| 771 | num, err := res.RowsAffected() | ||
| 772 | throwFail(t, AssertIs(num, T_Equal, 1), err) | ||
| 773 | |||
| 774 | res, err = dORM.Raw(`UPDATE "user" SET "user_name" = ? WHERE "user_name" = ?`, "slene", "testing").Exec() | ||
| 775 | throwFail(t, err) | ||
| 776 | num, err = res.RowsAffected() | ||
| 777 | throwFail(t, AssertIs(num, T_Equal, 1), err) | ||
| 778 | |||
| 779 | var maps []Params | ||
| 780 | num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).Values(&maps) | ||
| 781 | throwFail(t, err) | ||
| 782 | throwFail(t, AssertIs(num, T_Equal, 1)) | ||
| 783 | if num == 1 { | ||
| 784 | throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene")) | ||
| 785 | } | ||
| 786 | |||
| 787 | var lists []ParamsList | ||
| 788 | num, err = dORM.Raw(`SELECT "user_name" FROM "user" WHERE "status" = ?`, 1).ValuesList(&lists) | ||
| 789 | throwFail(t, err) | ||
| 790 | throwFail(t, AssertIs(num, T_Equal, 1)) | ||
| 791 | if num == 1 { | ||
| 792 | throwFail(t, AssertIs(lists[0][0], T_Equal, "slene")) | ||
| 793 | } | ||
| 794 | |||
| 795 | var list ParamsList | ||
| 796 | num, err = dORM.Raw(`SELECT "profile_id" FROM "user" ORDER BY id ASC`).ValuesFlat(&list) | ||
| 797 | throwFail(t, err) | ||
| 798 | throwFail(t, AssertIs(num, T_Equal, 3)) | ||
| 799 | if num == 3 { | ||
| 800 | throwFail(t, AssertIs(list[0], T_Equal, "2")) | ||
| 801 | throwFail(t, AssertIs(list[1], T_Equal, "3")) | ||
| 802 | throwFail(t, AssertIs(list[2], T_Equal, "")) | ||
| 803 | } | ||
| 804 | |||
| 805 | pre, err := dORM.Raw(`INSERT INTO "tag" ("name") VALUES (?) RETURNING "id"`).Prepare() | ||
| 806 | throwFail(t, err) | ||
| 807 | if pre != nil { | ||
| 808 | _, err := pre.Exec("name1") | ||
| 809 | throwFail(t, err) | ||
| 810 | |||
| 811 | _, err = pre.Exec("name2") | ||
| 812 | throwFail(t, err) | ||
| 813 | |||
| 814 | _, err = pre.Exec("name3") | ||
| 815 | throwFail(t, err) | ||
| 816 | |||
| 817 | err = pre.Close() | ||
| 818 | throwFail(t, err) | ||
| 819 | |||
| 820 | res, err := dORM.Raw(`DELETE FROM "tag" WHERE "name" IN (?, ?, ?)`, []string{"name1", "name2", "name3"}).Exec() | ||
| 821 | throwFail(t, err) | ||
| 822 | |||
| 823 | num, err := res.RowsAffected() | ||
| 824 | throwFail(t, err) | ||
| 825 | throwFail(t, AssertIs(num, T_Equal, 3)) | ||
| 826 | } | ||
| 725 | } | 827 | } |
| 726 | } | 828 | } |
| 727 | 829 | ... | ... |
| ... | @@ -121,10 +121,12 @@ type dbBaser interface { | ... | @@ -121,10 +121,12 @@ type dbBaser interface { |
| 121 | DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) | 121 | DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) |
| 122 | Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) | 122 | Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) |
| 123 | OperatorSql(string) string | 123 | OperatorSql(string) string |
| 124 | GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) | 124 | GenerateOperatorSql(*modelInfo, *fieldInfo, string, []interface{}) (string, []interface{}) |
| 125 | GenerateOperatorLeftCol(string, *string) | ||
| 125 | PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) | 126 | PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) |
| 126 | ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) | 127 | ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) |
| 127 | MaxLimit() uint64 | 128 | MaxLimit() uint64 |
| 128 | TableQuote() string | 129 | TableQuote() string |
| 129 | ReplaceMarks(*string) | 130 | ReplaceMarks(*string) |
| 131 | HasReturningID(*modelInfo, *string) bool | ||
| 130 | } | 132 | } | ... | ... |
-
Please register or sign in to post a comment