b766f65c by slene

#436 support insert multi

1 parent 6f3a759b
...@@ -51,7 +51,13 @@ type dbBase struct { ...@@ -51,7 +51,13 @@ type dbBase struct {
51 51
52 var _ dbBaser = new(dbBase) 52 var _ dbBaser = new(dbBase)
53 53
54 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { 54 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, err error) {
55 var columns []string
56
57 if names != nil {
58 columns = *names
59 }
60
55 for _, column := range cols { 61 for _, column := range cols {
56 var fi *fieldInfo 62 var fi *fieldInfo
57 if fi, _ = mi.fields.GetByAny(column); fi != nil { 63 if fi, _ = mi.fields.GetByAny(column); fi != nil {
...@@ -64,11 +70,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, ...@@ -64,11 +70,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
64 } 70 }
65 value, err := d.collectFieldValue(mi, fi, ind, insert, tz) 71 value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
66 if err != nil { 72 if err != nil {
67 return nil, nil, err 73 return nil, err
68 } 74 }
75
76 if names != nil {
69 columns = append(columns, column) 77 columns = append(columns, column)
78 }
79
70 values = append(values, value) 80 values = append(values, value)
71 } 81 }
82
83 if names != nil {
84 *names = columns
85 }
86
72 return 87 return
73 } 88 }
74 89
...@@ -166,7 +181,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, ...@@ -166,7 +181,7 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
166 } 181 }
167 182
168 func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 183 func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
169 _, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) 184 values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
170 if err != nil { 185 if err != nil {
171 return 0, err 186 return 0, err
172 } 187 }
...@@ -192,7 +207,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo ...@@ -192,7 +207,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
192 // if specify cols length > 0, then use it for where condition. 207 // if specify cols length > 0, then use it for where condition.
193 if len(cols) > 0 { 208 if len(cols) > 0 {
194 var err error 209 var err error
195 whereCols, args, err = d.collectValues(mi, ind, cols, false, false, tz) 210 whereCols = make([]string, 0, len(cols))
211 args, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz)
196 if err != nil { 212 if err != nil {
197 return err 213 return err
198 } 214 }
...@@ -202,7 +218,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo ...@@ -202,7 +218,7 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
202 if ok == false { 218 if ok == false {
203 return ErrMissPK 219 return ErrMissPK
204 } 220 }
205 whereCols = append(whereCols, pkColumn) 221 whereCols = []string{pkColumn}
206 args = append(args, pkValue) 222 args = append(args, pkValue)
207 } 223 }
208 224
...@@ -244,15 +260,72 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo ...@@ -244,15 +260,72 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Lo
244 } 260 }
245 261
246 func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 262 func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) {
247 names, values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, tz) 263 names := make([]string, 0, len(mi.fields.dbcols)-1)
264 values, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
248 if err != nil { 265 if err != nil {
249 return 0, err 266 return 0, err
250 } 267 }
251 268
252 return d.InsertValue(q, mi, names, values) 269 return d.InsertValue(q, mi, false, names, values)
270 }
271
272 func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) {
273 var (
274 cnt int64
275 nums int
276 values []interface{}
277 names []string
278 )
279
280 // typ := reflect.Indirect(mi.addrField).Type()
281
282 length := sind.Len()
283
284 for i := 1; i <= length; i++ {
285
286 ind := reflect.Indirect(sind.Index(i - 1))
287
288 // Is this needed ?
289 // if !ind.Type().AssignableTo(typ) {
290 // return cnt, ErrArgs
291 // }
292
293 if i == 1 {
294 vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, tz)
295 if err != nil {
296 return cnt, err
297 }
298 values = make([]interface{}, bulk*len(vus))
299 nums += copy(values, vus)
300
301 } else {
302
303 vus, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz)
304 if err != nil {
305 return cnt, err
306 }
307
308 if len(vus) != len(names) {
309 return cnt, ErrArgs
310 }
311
312 nums += copy(values[nums:], vus)
313 }
314
315 if i > 1 && i%bulk == 0 || length == i {
316 num, err := d.InsertValue(q, mi, true, names, values[:nums])
317 if err != nil {
318 return cnt, err
319 }
320 cnt += num
321 nums = 0
322 }
323 }
324
325 return cnt, nil
253 } 326 }
254 327
255 func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) { 328 func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
256 Q := d.ins.TableQuote() 329 Q := d.ins.TableQuote()
257 330
258 marks := make([]string, len(names)) 331 marks := make([]string, len(names))
...@@ -264,21 +337,30 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values ...@@ -264,21 +337,30 @@ func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values
264 qmarks := strings.Join(marks, ", ") 337 qmarks := strings.Join(marks, ", ")
265 columns := strings.Join(names, sep) 338 columns := strings.Join(names, sep)
266 339
340 multi := len(values) / len(names)
341
342 if isMulti {
343 qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
344 }
345
267 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) 346 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
268 347
269 d.ins.ReplaceMarks(&query) 348 d.ins.ReplaceMarks(&query)
270 349
271 if d.ins.HasReturningID(mi, &query) { 350 if isMulti || !d.ins.HasReturningID(mi, &query) {
272 row := q.QueryRow(query, values...)
273 var id int64
274 err := row.Scan(&id)
275 return id, err
276 } else {
277 if res, err := q.Exec(query, values...); err == nil { 351 if res, err := q.Exec(query, values...); err == nil {
352 if isMulti {
353 return res.RowsAffected()
354 }
278 return res.LastInsertId() 355 return res.LastInsertId()
279 } else { 356 } else {
280 return 0, err 357 return 0, err
281 } 358 }
359 } else {
360 row := q.QueryRow(query, values...)
361 var id int64
362 err := row.Scan(&id)
363 return id, err
282 } 364 }
283 } 365 }
284 366
...@@ -288,12 +370,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ...@@ -288,12 +370,17 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
288 return 0, ErrMissPK 370 return 0, ErrMissPK
289 } 371 }
290 372
373 var setNames []string
374
291 // if specify cols length is zero, then commit all columns. 375 // if specify cols length is zero, then commit all columns.
292 if len(cols) == 0 { 376 if len(cols) == 0 {
293 cols = mi.fields.dbcols 377 cols = mi.fields.dbcols
378 setNames = make([]string, 0, len(mi.fields.dbcols)-1)
379 } else {
380 setNames = make([]string, 0, len(cols))
294 } 381 }
295 382
296 setNames, setValues, err := d.collectValues(mi, ind, cols, true, false, tz) 383 setValues, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz)
297 if err != nil { 384 if err != nil {
298 return 0, err 385 return 0, err
299 } 386 }
......
...@@ -214,8 +214,6 @@ loopFor: ...@@ -214,8 +214,6 @@ loopFor:
214 fi, ok = mmi.fields.GetByAny(ex) 214 fi, ok = mmi.fields.GetByAny(ex)
215 } 215 }
216 216
217 // fmt.Println(ex, fi.name, fiN)
218
219 _ = okN 217 _ = okN
220 218
221 if ok { 219 if ok {
......
...@@ -25,6 +25,7 @@ var ( ...@@ -25,6 +25,7 @@ var (
25 ErrMultiRows = errors.New("<QuerySeter> return multi rows") 25 ErrMultiRows = errors.New("<QuerySeter> return multi rows")
26 ErrNoRows = errors.New("<QuerySeter> no row found") 26 ErrNoRows = errors.New("<QuerySeter> no row found")
27 ErrStmtClosed = errors.New("<QuerySeter> stmt already closed") 27 ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
28 ErrArgs = errors.New("<Ormer> args error may be empty")
28 ErrNotImplement = errors.New("have not implement") 29 ErrNotImplement = errors.New("have not implement")
29 ) 30 )
30 31
...@@ -39,11 +40,11 @@ type orm struct { ...@@ -39,11 +40,11 @@ type orm struct {
39 40
40 var _ Ormer = new(orm) 41 var _ Ormer = new(orm)
41 42
42 func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { 43 func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
43 val := reflect.ValueOf(md) 44 val := reflect.ValueOf(md)
44 ind = reflect.Indirect(val) 45 ind = reflect.Indirect(val)
45 typ := ind.Type() 46 typ := ind.Type()
46 if val.Kind() != reflect.Ptr { 47 if needPtr && val.Kind() != reflect.Ptr {
47 panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ))) 48 panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
48 } 49 }
49 name := getFullName(typ) 50 name := getFullName(typ)
...@@ -62,7 +63,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo { ...@@ -62,7 +63,7 @@ func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
62 } 63 }
63 64
64 func (o *orm) Read(md interface{}, cols ...string) error { 65 func (o *orm) Read(md interface{}, cols ...string) error {
65 mi, ind := o.getMiInd(md) 66 mi, ind := o.getMiInd(md, true)
66 err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) 67 err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
67 if err != nil { 68 if err != nil {
68 return err 69 return err
...@@ -71,12 +72,18 @@ func (o *orm) Read(md interface{}, cols ...string) error { ...@@ -71,12 +72,18 @@ func (o *orm) Read(md interface{}, cols ...string) error {
71 } 72 }
72 73
73 func (o *orm) Insert(md interface{}) (int64, error) { 74 func (o *orm) Insert(md interface{}) (int64, error) {
74 mi, ind := o.getMiInd(md) 75 mi, ind := o.getMiInd(md, true)
75 id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ) 76 id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
76 if err != nil { 77 if err != nil {
77 return id, err 78 return id, err
78 } 79 }
79 if id > 0 { 80
81 o.setPk(mi, ind, id)
82
83 return id, nil
84 }
85
86 func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
80 if mi.fields.pk.auto { 87 if mi.fields.pk.auto {
81 if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 { 88 if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
82 ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id)) 89 ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
...@@ -84,12 +91,44 @@ func (o *orm) Insert(md interface{}) (int64, error) { ...@@ -84,12 +91,44 @@ func (o *orm) Insert(md interface{}) (int64, error) {
84 ind.Field(mi.fields.pk.fieldIndex).SetInt(id) 91 ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
85 } 92 }
86 } 93 }
94 }
95
96 func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
97 var cnt int64
98
99 sind := reflect.Indirect(reflect.ValueOf(mds))
100
101 switch sind.Kind() {
102 case reflect.Array, reflect.Slice:
103 if sind.Len() == 0 {
104 return cnt, ErrArgs
87 } 105 }
88 return id, nil 106 default:
107 return cnt, ErrArgs
108 }
109
110 if bulk <= 1 {
111 for i := 0; i < sind.Len(); i++ {
112 ind := sind.Index(i)
113 mi, _ := o.getMiInd(ind.Interface(), false)
114 id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
115 if err != nil {
116 return cnt, err
117 }
118
119 o.setPk(mi, ind, id)
120
121 cnt += 1
122 }
123 } else {
124 mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
125 return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
126 }
127 return cnt, nil
89 } 128 }
90 129
91 func (o *orm) Update(md interface{}, cols ...string) (int64, error) { 130 func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
92 mi, ind := o.getMiInd(md) 131 mi, ind := o.getMiInd(md, true)
93 num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols) 132 num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
94 if err != nil { 133 if err != nil {
95 return num, err 134 return num, err
...@@ -98,25 +137,19 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) { ...@@ -98,25 +137,19 @@ func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
98 } 137 }
99 138
100 func (o *orm) Delete(md interface{}) (int64, error) { 139 func (o *orm) Delete(md interface{}) (int64, error) {
101 mi, ind := o.getMiInd(md) 140 mi, ind := o.getMiInd(md, true)
102 num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ) 141 num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ)
103 if err != nil { 142 if err != nil {
104 return num, err 143 return num, err
105 } 144 }
106 if num > 0 { 145 if num > 0 {
107 if mi.fields.pk.auto { 146 o.setPk(mi, ind, 0)
108 if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
109 ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
110 } else {
111 ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
112 }
113 }
114 } 147 }
115 return num, nil 148 return num, nil
116 } 149 }
117 150
118 func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer { 151 func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
119 mi, ind := o.getMiInd(md) 152 mi, ind := o.getMiInd(md, true)
120 fi := o.getFieldInfo(mi, name) 153 fi := o.getFieldInfo(mi, name)
121 154
122 switch { 155 switch {
...@@ -197,7 +230,7 @@ func (o *orm) QueryRelated(md interface{}, name string) QuerySeter { ...@@ -197,7 +230,7 @@ func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
197 } 230 }
198 231
199 func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) { 232 func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
200 mi, ind := o.getMiInd(md) 233 mi, ind := o.getMiInd(md, true)
201 fi := o.getFieldInfo(mi, name) 234 fi := o.getFieldInfo(mi, name)
202 235
203 _, _, exist := getExistPk(mi, ind) 236 _, _, exist := getExistPk(mi, ind)
......
...@@ -44,7 +44,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { ...@@ -44,7 +44,8 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
44 44
45 names := []string{mfi.column, rfi.column} 45 names := []string{mfi.column, rfi.column}
46 46
47 var nums int64 47 values := make([]interface{}, 0, len(models)*2)
48
48 for _, md := range models { 49 for _, md := range models {
49 50
50 ind := reflect.Indirect(reflect.ValueOf(md)) 51 ind := reflect.Indirect(reflect.ValueOf(md))
...@@ -59,16 +60,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) { ...@@ -59,16 +60,11 @@ func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
59 } 60 }
60 } 61 }
61 62
62 values := []interface{}{v1, v2} 63 values = append(values, v1, v2)
63 _, err := dbase.InsertValue(orm.db, mi, names, values)
64 if err != nil {
65 return nums, err
66 }
67 64
68 nums += 1
69 } 65 }
70 66
71 return nums, nil 67 return dbase.InsertValue(orm.db, mi, true, names, values)
72 } 68 }
73 69
74 func (o *queryM2M) Remove(mds ...interface{}) (int64, error) { 70 func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
......
...@@ -21,6 +21,7 @@ type Fielder interface { ...@@ -21,6 +21,7 @@ type Fielder interface {
21 type Ormer interface { 21 type Ormer interface {
22 Read(interface{}, ...string) error 22 Read(interface{}, ...string) error
23 Insert(interface{}) (int64, error) 23 Insert(interface{}) (int64, error)
24 InsertMulti(int, interface{}) (int64, error)
24 Update(interface{}, ...string) (int64, error) 25 Update(interface{}, ...string) (int64, error)
25 Delete(interface{}) (int64, error) 26 Delete(interface{}) (int64, error)
26 LoadRelated(interface{}, string, ...interface{}) (int64, error) 27 LoadRelated(interface{}, string, ...interface{}) (int64, error)
...@@ -109,7 +110,8 @@ type txEnder interface { ...@@ -109,7 +110,8 @@ type txEnder interface {
109 type dbBaser interface { 110 type dbBaser interface {
110 Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error 111 Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
111 Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 112 Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
112 InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error) 113 InsertMulti(dbQuerier, *modelInfo, reflect.Value, int, *time.Location) (int64, error)
114 InsertValue(dbQuerier, *modelInfo, bool, []string, []interface{}) (int64, error)
113 InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 115 InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
114 Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) 116 Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
115 Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 117 Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!