3745bb72 by slene

orm.Read support specify condition fields, orm.Update and QuerySeter All/One support omit fields.

1 parent 55fe3ba5
...@@ -49,10 +49,17 @@ type dbBase struct { ...@@ -49,10 +49,17 @@ type dbBase struct {
49 ins dbBaser 49 ins dbBaser
50 } 50 }
51 51
52 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { 52 var _ dbBaser = new(dbBase)
53
54 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
53 _, pkValue, _ := getExistPk(mi, ind) 55 _, pkValue, _ := getExistPk(mi, ind)
54 for _, column := range mi.fields.orders { 56 for _, column := range cols {
55 fi := mi.fields.columns[column] 57 var fi *fieldInfo
58 if fi, _ = mi.fields.GetByAny(column); fi != nil {
59 column = fi.column
60 } else {
61 panic(fmt.Sprintf("wrong db field/column name `%s` for model `%s`", column, mi.fullName))
62 }
56 if fi.dbcol == false || fi.auto && skipAuto { 63 if fi.dbcol == false || fi.auto && skipAuto {
57 continue 64 continue
58 } 65 }
...@@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, ...@@ -160,7 +167,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
160 } 167 }
161 168
162 func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 169 func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
163 _, values, err := d.collectValues(mi, ind, true, true, tz) 170 _, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
164 if err != nil { 171 if err != nil {
165 return 0, err 172 return 0, err
166 } 173 }
...@@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, ...@@ -179,10 +186,25 @@ func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value,
179 } 186 }
180 } 187 }
181 188
182 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) error { 189 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) error {
183 pkColumn, pkValue, ok := getExistPk(mi, ind) 190 var whereCols []string
184 if ok == false { 191 var args []interface{}
185 return ErrMissPK 192
193 // if specify cols length > 0, then use it for where condition.
194 if len(cols) > 0 {
195 var err error
196 whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz)
197 if err != nil {
198 return err
199 }
200 } else {
201 // default use pk value as where condtion.
202 pkColumn, pkValue, ok := getExistPk(mi, ind)
203 if ok == false {
204 return ErrMissPK
205 }
206 whereCols = append(whereCols, pkColumn)
207 args = append(args, pkValue)
186 } 208 }
187 209
188 Q := d.ins.TableQuote() 210 Q := d.ins.TableQuote()
...@@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo ...@@ -191,7 +213,10 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
191 sels := strings.Join(mi.fields.dbcols, sep) 213 sels := strings.Join(mi.fields.dbcols, sep)
192 colsNum := len(mi.fields.dbcols) 214 colsNum := len(mi.fields.dbcols)
193 215
194 query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q) 216 sep = fmt.Sprintf("%s = ? AND %s", Q, Q)
217 wheres := strings.Join(whereCols, sep)
218
219 query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q)
195 220
196 refs := make([]interface{}, colsNum) 221 refs := make([]interface{}, colsNum)
197 for i, _ := range refs { 222 for i, _ := range refs {
...@@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo ...@@ -201,7 +226,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
201 226
202 d.ins.ReplaceMarks(&query) 227 d.ins.ReplaceMarks(&query)
203 228
204 row := q.QueryRow(query, pkValue) 229 row := q.QueryRow(query, args...)
205 if err := row.Scan(refs...); err != nil { 230 if err := row.Scan(refs...); err != nil {
206 if err == sql.ErrNoRows { 231 if err == sql.ErrNoRows {
207 return ErrNoRows 232 return ErrNoRows
...@@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo ...@@ -220,7 +245,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
220 } 245 }
221 246
222 func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 247 func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
223 names, values, err := d.collectValues(mi, ind, true, true, tz) 248 names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz)
224 if err != nil { 249 if err != nil {
225 return 0, err 250 return 0, err
226 } 251 }
...@@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ...@@ -254,12 +279,18 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
254 } 279 }
255 } 280 }
256 281
257 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 282 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) {
258 pkName, pkValue, ok := getExistPk(mi, ind) 283 pkName, pkValue, ok := getExistPk(mi, ind)
259 if ok == false { 284 if ok == false {
260 return 0, ErrMissPK 285 return 0, ErrMissPK
261 } 286 }
262 setNames, setValues, err := d.collectValues(mi, ind, true, false, tz) 287
288 // if specify cols length is zero, then commit all columns.
289 if len(cols) == 0 {
290 cols = mi.fields.dbcols
291 }
292
293 setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz)
263 if err != nil { 294 if err != nil {
264 return 0, err 295 return 0, err
265 } 296 }
...@@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -473,7 +504,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
473 return 0, nil 504 return 0, nil
474 } 505 }
475 506
476 func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location) (int64, error) { 507 func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) {
477 508
478 val := reflect.ValueOf(container) 509 val := reflect.ValueOf(container)
479 ind := reflect.Indirect(val) 510 ind := reflect.Indirect(val)
...@@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -513,6 +544,41 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
513 544
514 Q := d.ins.TableQuote() 545 Q := d.ins.TableQuote()
515 546
547 var tCols []string
548 if len(cols) > 0 {
549 hasRel := len(qs.related) > 0 || qs.relDepth > 0
550 tCols = make([]string, 0, len(cols))
551 var maps map[string]bool
552 if hasRel {
553 maps = make(map[string]bool)
554 }
555 for _, col := range cols {
556 if fi, ok := mi.fields.GetByAny(col); ok {
557 tCols = append(tCols, fi.column)
558 if hasRel {
559 maps[fi.column] = true
560 }
561 } else {
562 panic(fmt.Sprintf("wrong field/column name `%s`", col))
563 }
564 }
565 if hasRel {
566 for _, fi := range mi.fields.fieldsDB {
567 if fi.fieldType&IsRelField > 0 {
568 if maps[fi.column] == false {
569 tCols = append(tCols, fi.column)
570 }
571 }
572 }
573 }
574 } else {
575 tCols = mi.fields.dbcols
576 }
577
578 colsNum := len(tCols)
579 sep := fmt.Sprintf("%s, T0.%s", Q, Q)
580 sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q)
581
516 tables := newDbTables(mi, d.ins) 582 tables := newDbTables(mi, d.ins)
517 tables.parseRelated(qs.related, qs.relDepth) 583 tables.parseRelated(qs.related, qs.relDepth)
518 584
...@@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -521,18 +587,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
521 limit := tables.getLimitSql(mi, offset, rlimit) 587 limit := tables.getLimitSql(mi, offset, rlimit)
522 join := tables.getJoinSql() 588 join := tables.getJoinSql()
523 589
524 colsNum := len(mi.fields.dbcols)
525 sep := fmt.Sprintf("%s, T0.%s", Q, Q)
526 cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
527 for _, tbl := range tables.tables { 590 for _, tbl := range tables.tables {
528 if tbl.sel { 591 if tbl.sel {
529 colsNum += len(tbl.mi.fields.dbcols) 592 colsNum += len(tbl.mi.fields.dbcols)
530 sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) 593 sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
531 cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) 594 sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
532 } 595 }
533 } 596 }
534 597
535 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit) 598 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
536 599
537 d.ins.ReplaceMarks(&query) 600 d.ins.ReplaceMarks(&query)
538 601
...@@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -565,8 +628,8 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
565 cacheM := make(map[string]*modelInfo) 628 cacheM := make(map[string]*modelInfo)
566 trefs := refs 629 trefs := refs
567 630
568 d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)], tz) 631 d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz)
569 trefs = refs[len(mi.fields.dbcols):] 632 trefs = refs[len(tCols):]
570 633
571 for _, tbl := range tables.tables { 634 for _, tbl := range tables.tables {
572 if tbl.sel { 635 if tbl.sel {
......
...@@ -53,9 +53,9 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { ...@@ -53,9 +53,9 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
53 panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) 53 panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
54 } 54 }
55 55
56 func (o *orm) Read(md interface{}) error { 56 func (o *orm) Read(md interface{}, cols ...string) error {
57 mi, ind := o.getMiInd(md) 57 mi, ind := o.getMiInd(md)
58 err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ) 58 err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
59 if err != nil { 59 if err != nil {
60 return err 60 return err
61 } 61 }
...@@ -80,9 +80,9 @@ func (o *orm) Insert(md interface{}) (int64, error) { ...@@ -80,9 +80,9 @@ func (o *orm) Insert(md interface{}) (int64, error) {
80 return id, nil 80 return id, nil
81 } 81 }
82 82
83 func (o *orm) Update(md interface{}) (int64, error) { 83 func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
84 mi, ind := o.getMiInd(md) 84 mi, ind := o.getMiInd(md)
85 num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ) 85 num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
86 if err != nil { 86 if err != nil {
87 return num, err 87 return num, err
88 } 88 }
......
...@@ -105,12 +105,12 @@ func (o *querySet) PrepareInsert() (Inserter, error) { ...@@ -105,12 +105,12 @@ func (o *querySet) PrepareInsert() (Inserter, error) {
105 return newInsertSet(o.orm, o.mi) 105 return newInsertSet(o.orm, o.mi)
106 } 106 }
107 107
108 func (o *querySet) All(container interface{}) (int64, error) { 108 func (o *querySet) All(container interface{}, cols ...string) (int64, error) {
109 return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ) 109 return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
110 } 110 }
111 111
112 func (o *querySet) One(container interface{}) error { 112 func (o *querySet) One(container interface{}, cols ...string) error {
113 num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ) 113 num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container, o.orm.alias.TZ, cols)
114 if err != nil { 114 if err != nil {
115 return err 115 return err
116 } 116 }
......
...@@ -20,9 +20,9 @@ type Fielder interface { ...@@ -20,9 +20,9 @@ type Fielder interface {
20 } 20 }
21 21
22 type Ormer interface { 22 type Ormer interface {
23 Read(interface{}) error 23 Read(interface{}, ...string) error
24 Insert(interface{}) (int64, error) 24 Insert(interface{}) (int64, error)
25 Update(interface{}) (int64, error) 25 Update(interface{}, ...string) (int64, error)
26 Delete(interface{}) (int64, error) 26 Delete(interface{}) (int64, error)
27 M2mAdd(interface{}, string, ...interface{}) (int64, error) 27 M2mAdd(interface{}, string, ...interface{}) (int64, error)
28 M2mDel(interface{}, string, ...interface{}) (int64, error) 28 M2mDel(interface{}, string, ...interface{}) (int64, error)
...@@ -53,8 +53,8 @@ type QuerySeter interface { ...@@ -53,8 +53,8 @@ type QuerySeter interface {
53 Update(Params) (int64, error) 53 Update(Params) (int64, error)
54 Delete() (int64, error) 54 Delete() (int64, error)
55 PrepareInsert() (Inserter, error) 55 PrepareInsert() (Inserter, error)
56 All(interface{}) (int64, error) 56 All(interface{}, ...string) (int64, error)
57 One(interface{}) error 57 One(interface{}, ...string) error
58 Values(*[]Params, ...string) (int64, error) 58 Values(*[]Params, ...string) (int64, error)
59 ValuesList(*[]ParamsList, ...string) (int64, error) 59 ValuesList(*[]ParamsList, ...string) (int64, error)
60 ValuesFlat(*ParamsList, string) (int64, error) 60 ValuesFlat(*ParamsList, string) (int64, error)
...@@ -111,12 +111,12 @@ type txEnder interface { ...@@ -111,12 +111,12 @@ type txEnder interface {
111 } 111 }
112 112
113 type dbBaser interface { 113 type dbBaser interface {
114 Read(dbQuerier, *modelInfo, reflect.Value, *time.Location) error 114 Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
115 Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 115 Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
116 InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 116 InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
117 Update(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 117 Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
118 Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 118 Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
119 ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location) (int64, error) 119 ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, *time.Location, []string) (int64, error)
120 SupportUpdateJoin() bool 120 SupportUpdateJoin() bool
121 UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error) 121 UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params, *time.Location) (int64, error)
122 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error) 122 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition, *time.Location) (int64, error)
......
...@@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) { ...@@ -38,6 +38,11 @@ func (f StrTo) Float64() (float64, error) {
38 return strconv.ParseFloat(f.String(), 64) 38 return strconv.ParseFloat(f.String(), 64)
39 } 39 }
40 40
41 func (f StrTo) Int() (int, error) {
42 v, err := strconv.ParseInt(f.String(), 10, 32)
43 return int(v), err
44 }
45
41 func (f StrTo) Int8() (int8, error) { 46 func (f StrTo) Int8() (int8, error) {
42 v, err := strconv.ParseInt(f.String(), 10, 8) 47 v, err := strconv.ParseInt(f.String(), 10, 8)
43 return int8(v), err 48 return int8(v), err
...@@ -58,6 +63,11 @@ func (f StrTo) Int64() (int64, error) { ...@@ -58,6 +63,11 @@ func (f StrTo) Int64() (int64, error) {
58 return int64(v), err 63 return int64(v), err
59 } 64 }
60 65
66 func (f StrTo) Uint() (uint, error) {
67 v, err := strconv.ParseUint(f.String(), 10, 32)
68 return uint(v), err
69 }
70
61 func (f StrTo) Uint8() (uint8, error) { 71 func (f StrTo) Uint8() (uint8, error) {
62 v, err := strconv.ParseUint(f.String(), 10, 8) 72 v, err := strconv.ParseUint(f.String(), 10, 8)
63 return uint8(v), err 73 return uint8(v), err
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!