6c41e6dd by slene

orm add sqlite3 support, may be support postgres in next commit

1 parent 9631c663
...@@ -43,395 +43,8 @@ var ( ...@@ -43,395 +43,8 @@ var (
43 "isnull": true, 43 "isnull": true,
44 // "search": true, 44 // "search": true,
45 } 45 }
46 operatorsSQL = map[string]string{
47 "exact": "= ?",
48 "iexact": "LIKE ?",
49 "contains": "LIKE BINARY ?",
50 "icontains": "LIKE ?",
51 // "regex": "REGEXP BINARY ?",
52 // "iregex": "REGEXP ?",
53 "gt": "> ?",
54 "gte": ">= ?",
55 "lt": "< ?",
56 "lte": "<= ?",
57 "startswith": "LIKE BINARY ?",
58 "endswith": "LIKE BINARY ?",
59 "istartswith": "LIKE ?",
60 "iendswith": "LIKE ?",
61 }
62 ) 46 )
63 47
64 type dbTable struct {
65 id int
66 index string
67 name string
68 names []string
69 sel bool
70 inner bool
71 mi *modelInfo
72 fi *fieldInfo
73 jtl *dbTable
74 }
75
76 type dbTables struct {
77 tablesM map[string]*dbTable
78 tables []*dbTable
79 mi *modelInfo
80 base dbBaser
81 }
82
83 func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
84 name := strings.Join(names, ExprSep)
85 if j, ok := t.tablesM[name]; ok {
86 j.name = name
87 j.mi = mi
88 j.fi = fi
89 j.inner = inner
90 } else {
91 i := len(t.tables) + 1
92 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
93 t.tablesM[name] = jt
94 t.tables = append(t.tables, jt)
95 }
96 return t.tablesM[name]
97 }
98
99 func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
100 name := strings.Join(names, ExprSep)
101 if _, ok := t.tablesM[name]; ok == false {
102 i := len(t.tables) + 1
103 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
104 t.tablesM[name] = jt
105 t.tables = append(t.tables, jt)
106 return jt, true
107 }
108 return t.tablesM[name], false
109 }
110
111 func (t *dbTables) get(name string) (*dbTable, bool) {
112 j, ok := t.tablesM[name]
113 return j, ok
114 }
115
116 func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
117 if depth < 0 || fi.fieldType == RelManyToMany {
118 return related
119 }
120
121 if prefix == "" {
122 prefix = fi.name
123 } else {
124 prefix = prefix + ExprSep + fi.name
125 }
126 related = append(related, prefix)
127
128 depth--
129 for _, fi := range fi.relModelInfo.fields.fieldsRel {
130 related = t.loopDepth(depth, prefix, fi, related)
131 }
132
133 return related
134 }
135
136 func (t *dbTables) parseRelated(rels []string, depth int) {
137
138 relsNum := len(rels)
139 related := make([]string, relsNum)
140 copy(related, rels)
141
142 relDepth := depth
143
144 if relsNum != 0 {
145 relDepth = 0
146 }
147
148 relDepth--
149 for _, fi := range t.mi.fields.fieldsRel {
150 related = t.loopDepth(relDepth, "", fi, related)
151 }
152
153 for i, s := range related {
154 var (
155 exs = strings.Split(s, ExprSep)
156 names = make([]string, 0, len(exs))
157 mmi = t.mi
158 cansel = true
159 jtl *dbTable
160 )
161 for _, ex := range exs {
162 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
163 names = append(names, fi.name)
164 mmi = fi.relModelInfo
165
166 jt := t.set(names, mmi, fi, fi.null == false)
167 jt.jtl = jtl
168
169 if fi.reverse {
170 cansel = false
171 }
172
173 if cansel {
174 jt.sel = depth > 0
175
176 if i < relsNum {
177 jt.sel = true
178 }
179 }
180
181 jtl = jt
182
183 } else {
184 panic(fmt.Sprintf("unknown model/table name `%s`", ex))
185 }
186 }
187 }
188 }
189
190 func (t *dbTables) getJoinSql() (join string) {
191 for _, jt := range t.tables {
192 if jt.inner {
193 join += "INNER JOIN "
194 } else {
195 join += "LEFT OUTER JOIN "
196 }
197 var (
198 table string
199 t1, t2 string
200 c1, c2 string
201 )
202 t1 = "T0"
203 if jt.jtl != nil {
204 t1 = jt.jtl.index
205 }
206 t2 = jt.index
207 table = jt.mi.table
208
209 switch {
210 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
211 c1 = jt.fi.mi.fields.pk.column
212 for _, ffi := range jt.mi.fields.fieldsRel {
213 if jt.fi.mi == ffi.relModelInfo {
214 c2 = ffi.column
215 break
216 }
217 }
218 default:
219 c1 = jt.fi.column
220 c2 = jt.fi.relModelInfo.fields.pk.column
221
222 if jt.fi.reverse {
223 c1 = jt.mi.fields.pk.column
224 c2 = jt.fi.reverseFieldInfo.column
225 }
226 }
227
228 join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2,
229 t2, c2, t1, c1)
230 }
231 return
232 }
233
234 func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
235 var (
236 ffi *fieldInfo
237 jtl *dbTable
238 mmi = mi
239 )
240
241 num := len(exprs) - 1
242 names := make([]string, 0)
243
244 for i, ex := range exprs {
245 exist := false
246
247 check:
248 fi, ok := mmi.fields.GetByAny(ex)
249
250 if ok {
251
252 if num != i {
253 names = append(names, fi.name)
254
255 switch {
256 case fi.rel:
257 mmi = fi.relModelInfo
258 if fi.fieldType == RelManyToMany {
259 mmi = fi.relThroughModelInfo
260 }
261 case fi.reverse:
262 mmi = fi.reverseFieldInfo.mi
263 if fi.reverseFieldInfo.fieldType == RelManyToMany {
264 mmi = fi.reverseFieldInfo.relThroughModelInfo
265 }
266 default:
267 return
268 }
269
270 jt, _ := d.add(names, mmi, fi, fi.null == false)
271 jt.jtl = jtl
272 jtl = jt
273
274 if fi.rel && fi.fieldType == RelManyToMany {
275 ex = fi.relModelInfo.name
276 goto check
277 }
278
279 if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
280 ex = fi.reverseFieldInfo.mi.name
281 goto check
282 }
283
284 exist = true
285
286 } else {
287
288 if ffi == nil {
289 index = "T0"
290 } else {
291 index = jtl.index
292 }
293 column = fi.column
294 info = fi
295 if jtl != nil {
296 name = jtl.name + ExprSep + fi.name
297 } else {
298 name = fi.name
299 }
300
301 switch fi.fieldType {
302 case RelManyToMany, RelReverseMany:
303 default:
304 exist = true
305 }
306 }
307
308 ffi = fi
309 }
310
311 if exist == false {
312 index = ""
313 column = ""
314 name = ""
315 success = false
316 return
317 }
318 }
319
320 success = index != "" && column != ""
321 return
322 }
323
324 func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
325 if cond == nil || cond.IsEmpty() {
326 return
327 }
328
329 mi := d.mi
330
331 // outFor:
332 for i, p := range cond.params {
333 if i > 0 {
334 if p.isOr {
335 where += "OR "
336 } else {
337 where += "AND "
338 }
339 }
340 if p.isNot {
341 where += "NOT "
342 }
343 if p.isCond {
344 w, ps := d.getCondSql(p.cond, true)
345 if w != "" {
346 w = fmt.Sprintf("( %s) ", w)
347 }
348 where += w
349 params = append(params, ps...)
350 } else {
351 exprs := p.exprs
352
353 num := len(exprs) - 1
354 operator := ""
355 if operators[exprs[num]] {
356 operator = exprs[num]
357 exprs = exprs[:num]
358 }
359
360 index, column, _, _, suc := d.parseExprs(mi, exprs)
361 if suc == false {
362 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
363 }
364
365 if operator == "" {
366 operator = "exact"
367 }
368
369 operSql, args := d.base.GetOperatorSql(mi, operator, p.args)
370
371 where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql)
372 params = append(params, args...)
373
374 }
375 }
376
377 if sub == false && where != "" {
378 where = "WHERE " + where
379 }
380
381 return
382 }
383
384 func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
385 if len(orders) == 0 {
386 return
387 }
388
389 orderSqls := make([]string, 0, len(orders))
390 for _, order := range orders {
391 asc := "ASC"
392 if order[0] == '-' {
393 asc = "DESC"
394 order = order[1:]
395 }
396 exprs := strings.Split(order, ExprSep)
397
398 index, column, _, _, suc := d.parseExprs(d.mi, exprs)
399 if suc == false {
400 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
401 }
402
403 orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc))
404 }
405
406 orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
407 return
408 }
409
410 func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) {
411 if limit == 0 {
412 limit = DefaultRowsLimit
413 }
414 if limit < 0 {
415 // no limit
416 if offset > 0 {
417 limits = fmt.Sprintf("LIMIT 18446744073709551615 OFFSET %d", offset)
418 }
419 } else if offset <= 0 {
420 limits = fmt.Sprintf("LIMIT %d", limit)
421 } else {
422 limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
423 }
424 return
425 }
426
427 func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
428 tables := &dbTables{}
429 tables.tablesM = make(map[string]*dbTable)
430 tables.mi = mi
431 tables.base = base
432 return tables
433 }
434
435 type dbBase struct { 48 type dbBase struct {
436 ins dbBaser 49 ins dbBaser
437 } 50 }
...@@ -528,6 +141,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, ...@@ -528,6 +141,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
528 } 141 }
529 142
530 func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { 143 func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
144 Q := d.ins.TableQuote()
145
531 dbcols := make([]string, 0, len(mi.fields.dbcols)) 146 dbcols := make([]string, 0, len(mi.fields.dbcols))
532 marks := make([]string, 0, len(mi.fields.dbcols)) 147 marks := make([]string, 0, len(mi.fields.dbcols))
533 for _, fi := range mi.fields.fieldsDB { 148 for _, fi := range mi.fields.fieldsDB {
...@@ -537,9 +152,13 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, ...@@ -537,9 +152,13 @@ func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string,
537 } 152 }
538 } 153 }
539 qmarks := strings.Join(marks, ", ") 154 qmarks := strings.Join(marks, ", ")
540 columns := strings.Join(dbcols, "`,`") 155 sep := fmt.Sprintf("%s, %s", Q, Q)
156 columns := strings.Join(dbcols, sep)
157
158 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
159
160 d.ins.ReplaceMarks(&query)
541 161
542 query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
543 stmt, err := q.Prepare(query) 162 stmt, err := q.Prepare(query)
544 return stmt, query, err 163 return stmt, query, err
545 } 164 }
...@@ -563,10 +182,13 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { ...@@ -563,10 +182,13 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
563 return ErrMissPK 182 return ErrMissPK
564 } 183 }
565 184
566 sels := strings.Join(mi.fields.dbcols, "`, `") 185 Q := d.ins.TableQuote()
186
187 sep := fmt.Sprintf("%s, %s", Q, Q)
188 sels := strings.Join(mi.fields.dbcols, sep)
567 colsNum := len(mi.fields.dbcols) 189 colsNum := len(mi.fields.dbcols)
568 190
569 query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn) 191 query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ?", Q, sels, Q, Q, mi.table, Q, Q, pkColumn, Q)
570 192
571 refs := make([]interface{}, colsNum) 193 refs := make([]interface{}, colsNum)
572 for i, _ := range refs { 194 for i, _ := range refs {
...@@ -574,6 +196,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { ...@@ -574,6 +196,8 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
574 refs[i] = &ref 196 refs[i] = &ref
575 } 197 }
576 198
199 d.ins.ReplaceMarks(&query)
200
577 row := q.QueryRow(query, pkValue) 201 row := q.QueryRow(query, pkValue)
578 if err := row.Scan(refs...); err != nil { 202 if err := row.Scan(refs...); err != nil {
579 if err == sql.ErrNoRows { 203 if err == sql.ErrNoRows {
...@@ -598,14 +222,20 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -598,14 +222,20 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
598 return 0, err 222 return 0, err
599 } 223 }
600 224
225 Q := d.ins.TableQuote()
226
601 marks := make([]string, len(names)) 227 marks := make([]string, len(names))
602 for i, _ := range marks { 228 for i, _ := range marks {
603 marks[i] = "?" 229 marks[i] = "?"
604 } 230 }
231
232 sep := fmt.Sprintf("%s, %s", Q, Q)
605 qmarks := strings.Join(marks, ", ") 233 qmarks := strings.Join(marks, ", ")
606 columns := strings.Join(names, "`,`") 234 columns := strings.Join(names, sep)
607 235
608 query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks) 236 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
237
238 d.ins.ReplaceMarks(&query)
609 239
610 if res, err := q.Exec(query, values...); err == nil { 240 if res, err := q.Exec(query, values...); err == nil {
611 return res.LastInsertId() 241 return res.LastInsertId()
...@@ -624,11 +254,16 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -624,11 +254,16 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
624 return 0, err 254 return 0, err
625 } 255 }
626 256
627 setColumns := strings.Join(setNames, "` = ?, `") 257 setValues = append(setValues, pkValue)
628 258
629 query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName) 259 Q := d.ins.TableQuote()
630 260
631 setValues = append(setValues, pkValue) 261 sep := fmt.Sprintf("%s = ?, %s", Q, Q)
262 setColumns := strings.Join(setNames, sep)
263
264 query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q)
265
266 d.ins.ReplaceMarks(&query)
632 267
633 if res, err := q.Exec(query, setValues...); err == nil { 268 if res, err := q.Exec(query, setValues...); err == nil {
634 return res.RowsAffected() 269 return res.RowsAffected()
...@@ -644,7 +279,11 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -644,7 +279,11 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
644 return 0, ErrMissPK 279 return 0, ErrMissPK
645 } 280 }
646 281
647 query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName) 282 Q := d.ins.TableQuote()
283
284 query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, pkName, Q)
285
286 d.ins.ReplaceMarks(&query)
648 287
649 if res, err := q.Exec(query, pkValue); err == nil { 288 if res, err := q.Exec(query, pkValue); err == nil {
650 289
...@@ -694,11 +333,24 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -694,11 +333,24 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
694 333
695 where, args := tables.getCondSql(cond, false) 334 where, args := tables.getCondSql(cond, false)
696 335
336 values = append(values, args...)
337
697 join := tables.getJoinSql() 338 join := tables.getJoinSql()
698 339
699 query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where) 340 var query string
700 341
701 values = append(values, args...) 342 Q := d.ins.TableQuote()
343
344 if d.ins.SupportUpdateJoin() {
345 cols := strings.Join(columns, fmt.Sprintf("%s = ?, T0.%s", Q, Q))
346 query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET T0.%s%s%s = ? %s", Q, mi.table, Q, join, Q, cols, Q, where)
347 } else {
348 cols := strings.Join(columns, fmt.Sprintf("%s = ?, %s", Q, Q))
349 supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where)
350 query = fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s IN ( %s )", Q, mi.table, Q, Q, cols, Q, Q, mi.fields.pk.column, Q, supQuery)
351 }
352
353 d.ins.ReplaceMarks(&query)
702 354
703 if res, err := q.Exec(query, values...); err == nil { 355 if res, err := q.Exec(query, values...); err == nil {
704 return res.RowsAffected() 356 return res.RowsAffected()
...@@ -744,11 +396,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -744,11 +396,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
744 panic("delete operation cannot execute without condition") 396 panic("delete operation cannot execute without condition")
745 } 397 }
746 398
399 Q := d.ins.TableQuote()
400
747 where, args := tables.getCondSql(cond, false) 401 where, args := tables.getCondSql(cond, false)
748 join := tables.getJoinSql() 402 join := tables.getJoinSql()
749 403
750 cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column) 404 cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q)
751 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where) 405 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where)
406
407 d.ins.ReplaceMarks(&query)
752 408
753 var rs *sql.Rows 409 var rs *sql.Rows
754 if r, err := q.Query(query, args...); err != nil { 410 if r, err := q.Query(query, args...); err != nil {
...@@ -773,8 +429,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -773,8 +429,10 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
773 return 0, nil 429 return 0, nil
774 } 430 }
775 431
776 sql, args := d.ins.GetOperatorSql(mi, "in", args) 432 sql, args := d.ins.GenerateOperatorSql(mi, "in", args)
777 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql) 433 query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql)
434
435 d.ins.ReplaceMarks(&query)
778 436
779 if res, err := q.Exec(query, args...); err == nil { 437 if res, err := q.Exec(query, args...); err == nil {
780 num, err := res.RowsAffected() 438 num, err := res.RowsAffected()
...@@ -831,24 +489,30 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -831,24 +489,30 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
831 offset = 0 489 offset = 0
832 } 490 }
833 491
492 Q := d.ins.TableQuote()
493
834 tables := newDbTables(mi, d.ins) 494 tables := newDbTables(mi, d.ins)
835 tables.parseRelated(qs.related, qs.relDepth) 495 tables.parseRelated(qs.related, qs.relDepth)
836 496
837 where, args := tables.getCondSql(cond, false) 497 where, args := tables.getCondSql(cond, false)
838 orderBy := tables.getOrderSql(qs.orders) 498 orderBy := tables.getOrderSql(qs.orders)
839 limit := tables.getLimitSql(offset, rlimit) 499 limit := tables.getLimitSql(mi, offset, rlimit)
840 join := tables.getJoinSql() 500 join := tables.getJoinSql()
841 501
842 colsNum := len(mi.fields.dbcols) 502 colsNum := len(mi.fields.dbcols)
843 cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`")) 503 sep := fmt.Sprintf("%s, T0.%s", Q, Q)
504 cols := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(mi.fields.dbcols, sep), Q)
844 for _, tbl := range tables.tables { 505 for _, tbl := range tables.tables {
845 if tbl.sel { 506 if tbl.sel {
846 colsNum += len(tbl.mi.fields.dbcols) 507 colsNum += len(tbl.mi.fields.dbcols)
847 cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`")) 508 sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q)
509 cols += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q)
848 } 510 }
849 } 511 }
850 512
851 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit) 513 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", cols, Q, mi.table, Q, join, where, orderBy, limit)
514
515 d.ins.ReplaceMarks(&query)
852 516
853 var rs *sql.Rows 517 var rs *sql.Rows
854 if r, err := q.Query(query, args...); err != nil { 518 if r, err := q.Query(query, args...); err != nil {
...@@ -940,7 +604,11 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition ...@@ -940,7 +604,11 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
940 tables.getOrderSql(qs.orders) 604 tables.getOrderSql(qs.orders)
941 join := tables.getJoinSql() 605 join := tables.getJoinSql()
942 606
943 query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where) 607 Q := d.ins.TableQuote()
608
609 query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s", Q, mi.table, Q, join, where)
610
611 d.ins.ReplaceMarks(&query)
944 612
945 row := q.QueryRow(query, args...) 613 row := q.QueryRow(query, args...)
946 614
...@@ -1014,7 +682,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params ...@@ -1014,7 +682,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params
1014 return 682 return
1015 } 683 }
1016 684
1017 func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) { 685 func (d *dbBase) GenerateOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
1018 sql := "" 686 sql := ""
1019 params := d.getOperatorParams(operator, args) 687 params := d.getOperatorParams(operator, args)
1020 688
...@@ -1028,7 +696,7 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface ...@@ -1028,7 +696,7 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
1028 if len(params) > 1 { 696 if len(params) > 1 {
1029 panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params))) 697 panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
1030 } 698 }
1031 sql = operatorsSQL[operator] 699 sql = d.ins.OperatorSql(operator)
1032 arg := params[0] 700 arg := params[0]
1033 switch operator { 701 switch operator {
1034 case "exact": 702 case "exact":
...@@ -1073,13 +741,13 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, ...@@ -1073,13 +741,13 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
1073 741
1074 value, err := d.getValue(fi, val) 742 value, err := d.getValue(fi, val)
1075 if err != nil { 743 if err != nil {
1076 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) 744 panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
1077 } 745 }
1078 746
1079 _, err = d.setValue(fi, value, &field) 747 _, err = d.setValue(fi, value, &field)
1080 748
1081 if err != nil { 749 if err != nil {
1082 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error())) 750 panic(fmt.Sprintf("Raw value: `%v` %s", val, err.Error()))
1083 } 751 }
1084 } 752 }
1085 } 753 }
...@@ -1090,6 +758,7 @@ func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) { ...@@ -1090,6 +758,7 @@ func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
1090 } 758 }
1091 759
1092 var value interface{} 760 var value interface{}
761 var tErr error
1093 762
1094 var str *StrTo 763 var str *StrTo
1095 switch v := val.(type) { 764 switch v := val.(type) {
...@@ -1119,7 +788,8 @@ setValue: ...@@ -1119,7 +788,8 @@ setValue:
1119 if str != nil { 788 if str != nil {
1120 b, err := str.Bool() 789 b, err := str.Bool()
1121 if err != nil { 790 if err != nil {
1122 return nil, err 791 tErr = err
792 goto end
1123 } 793 }
1124 value = b 794 value = b
1125 } 795 }
...@@ -1140,14 +810,23 @@ setValue: ...@@ -1140,14 +810,23 @@ setValue:
1140 } 810 }
1141 } 811 }
1142 if str != nil { 812 if str != nil {
1143 format := format_DateTime 813 s := str.String()
814 var format string
1144 if fi.fieldType == TypeDateField { 815 if fi.fieldType == TypeDateField {
1145 format = format_Date 816 format = format_Date
817 if len(s) > 10 {
818 s = s[:10]
819 }
820 } else {
821 format = format_DateTime
822 if len(s) > 19 {
823 s = s[:19]
824 }
1146 } 825 }
1147 s := str.String()
1148 t, err := timeParse(s, format) 826 t, err := timeParse(s, format)
1149 if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" { 827 if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
1150 return nil, err 828 tErr = err
829 goto end
1151 } 830 }
1152 value = t 831 value = t
1153 } 832 }
...@@ -1173,7 +852,8 @@ setValue: ...@@ -1173,7 +852,8 @@ setValue:
1173 _, err = str.Uint64() 852 _, err = str.Uint64()
1174 } 853 }
1175 if err != nil { 854 if err != nil {
1176 return nil, err 855 tErr = err
856 goto end
1177 } 857 }
1178 if fieldType&IsPostiveIntegerField > 0 { 858 if fieldType&IsPostiveIntegerField > 0 {
1179 v, _ := str.Uint64() 859 v, _ := str.Uint64()
...@@ -1196,15 +876,23 @@ setValue: ...@@ -1196,15 +876,23 @@ setValue:
1196 if str != nil { 876 if str != nil {
1197 v, err := str.Float64() 877 v, err := str.Float64()
1198 if err != nil { 878 if err != nil {
1199 return nil, err 879 tErr = err
880 goto end
1200 } 881 }
1201 value = v 882 value = v
1202 } 883 }
1203 case fieldType&IsRelField > 0: 884 case fieldType&IsRelField > 0:
1204 fieldType = fi.relModelInfo.fields.pk.fieldType 885 fi = fi.relModelInfo.fields.pk
886 fieldType = fi.fieldType
1205 goto setValue 887 goto setValue
1206 } 888 }
1207 889
890 end:
891 if tErr != nil {
892 err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr)
893 return nil, err
894 }
895
1208 return value, nil 896 return value, nil
1209 897
1210 } 898 }
...@@ -1275,6 +963,7 @@ setValue: ...@@ -1275,6 +963,7 @@ setValue:
1275 fd := field.Addr().Interface().(Fielder) 963 fd := field.Addr().Interface().(Fielder)
1276 err := fd.SetRaw(value) 964 err := fd.SetRaw(value)
1277 if err != nil { 965 if err != nil {
966 err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err)
1278 return nil, err 967 return nil, err
1279 } 968 }
1280 } 969 }
...@@ -1311,6 +1000,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond ...@@ -1311,6 +1000,8 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
1311 1000
1312 hasExprs := len(exprs) > 0 1001 hasExprs := len(exprs) > 0
1313 1002
1003 Q := d.ins.TableQuote()
1004
1314 if hasExprs { 1005 if hasExprs {
1315 cols = make([]string, 0, len(exprs)) 1006 cols = make([]string, 0, len(exprs))
1316 infos = make([]*fieldInfo, 0, len(exprs)) 1007 infos = make([]*fieldInfo, 0, len(exprs))
...@@ -1319,26 +1010,26 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond ...@@ -1319,26 +1010,26 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
1319 if suc == false { 1010 if suc == false {
1320 panic(fmt.Errorf("unknown field/column name `%s`", ex)) 1011 panic(fmt.Errorf("unknown field/column name `%s`", ex))
1321 } 1012 }
1322 cols = append(cols, fmt.Sprintf("%s.`%s` `%s`", index, col, name)) 1013 cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, col, Q, Q, name, Q))
1323 infos = append(infos, fi) 1014 infos = append(infos, fi)
1324 } 1015 }
1325 } else { 1016 } else {
1326 cols = make([]string, 0, len(mi.fields.dbcols)) 1017 cols = make([]string, 0, len(mi.fields.dbcols))
1327 infos = make([]*fieldInfo, 0, len(exprs)) 1018 infos = make([]*fieldInfo, 0, len(exprs))
1328 for _, fi := range mi.fields.fieldsDB { 1019 for _, fi := range mi.fields.fieldsDB {
1329 cols = append(cols, fmt.Sprintf("T0.`%s` `%s`", fi.column, fi.name)) 1020 cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q))
1330 infos = append(infos, fi) 1021 infos = append(infos, fi)
1331 } 1022 }
1332 } 1023 }
1333 1024
1334 where, args := tables.getCondSql(cond, false) 1025 where, args := tables.getCondSql(cond, false)
1335 orderBy := tables.getOrderSql(qs.orders) 1026 orderBy := tables.getOrderSql(qs.orders)
1336 limit := tables.getLimitSql(qs.offset, qs.limit) 1027 limit := tables.getLimitSql(mi, qs.offset, qs.limit)
1337 join := tables.getJoinSql() 1028 join := tables.getJoinSql()
1338 1029
1339 sels := strings.Join(cols, ", ") 1030 sels := strings.Join(cols, ", ")
1340 1031
1341 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit) 1032 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s%s", sels, Q, mi.table, Q, join, where, orderBy, limit)
1342 1033
1343 var rs *sql.Rows 1034 var rs *sql.Rows
1344 if r, err := q.Query(query, args...); err != nil { 1035 if r, err := q.Query(query, args...); err != nil {
...@@ -1430,3 +1121,19 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond ...@@ -1430,3 +1121,19 @@ func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Cond
1430 1121
1431 return cnt, nil 1122 return cnt, nil
1432 } 1123 }
1124
1125 func (d *dbBase) SupportUpdateJoin() bool {
1126 return true
1127 }
1128
1129 func (d *dbBase) MaxLimit() uint64 {
1130 return 18446744073709551615
1131 }
1132
1133 func (d *dbBase) TableQuote() string {
1134 return "`"
1135 }
1136
1137 func (d *dbBase) ReplaceMarks(query *string) {
1138 // default use `?` as mark, do nothing
1139 }
......
1 package orm 1 package orm
2 2
3 var mysqlOperators = map[string]string{
4 "exact": "= ?",
5 "iexact": "LIKE ?",
6 "contains": "LIKE BINARY ?",
7 "icontains": "LIKE ?",
8 // "regex": "REGEXP BINARY ?",
9 // "iregex": "REGEXP ?",
10 "gt": "> ?",
11 "gte": ">= ?",
12 "lt": "< ?",
13 "lte": "<= ?",
14 "startswith": "LIKE BINARY ?",
15 "endswith": "LIKE BINARY ?",
16 "istartswith": "LIKE ?",
17 "iendswith": "LIKE ?",
18 }
19
3 type dbBaseMysql struct { 20 type dbBaseMysql struct {
4 dbBase 21 dbBase
5 } 22 }
6 23
7 func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) { 24 var _ dbBaser = new(dbBaseMysql)
8 return d.dbBase.GetOperatorSql(mi, operator, args) 25
26 func (d *dbBaseMysql) OperatorSql(operator string) string {
27 return mysqlOperators[operator]
9 } 28 }
10 29
11 func newdbBaseMysql() dbBaser { 30 func newdbBaseMysql() dbBaser {
......
...@@ -4,6 +4,12 @@ type dbBaseOracle struct { ...@@ -4,6 +4,12 @@ type dbBaseOracle struct {
4 dbBase 4 dbBase
5 } 5 }
6 6
7 var _ dbBaser = new(dbBaseOracle)
8
9 func (d *dbBase) OperatorSql(operator string) string {
10 return ""
11 }
12
7 func newdbBaseOracle() dbBaser { 13 func newdbBaseOracle() dbBaser {
8 b := new(dbBaseOracle) 14 b := new(dbBaseOracle)
9 b.ins = b 15 b.ins = b
......
1 package orm 1 package orm
2 2
3 import (
4 "strconv"
5 )
6
7 var postgresOperators = map[string]string{
8 "exact": "= ?",
9 "iexact": "= UPPER(?)",
10 "contains": "LIKE ?",
11 "icontains": "LIKE UPPER(?)",
12 "gt": "> ?",
13 "gte": ">= ?",
14 "lt": "< ?",
15 "lte": "<= ?",
16 "startswith": "LIKE ?",
17 "endswith": "LIKE ?",
18 "istartswith": "LIKE UPPER(?)",
19 "iendswith": "LIKE UPPER(?)",
20 }
21
3 type dbBasePostgres struct { 22 type dbBasePostgres struct {
4 dbBase 23 dbBase
5 } 24 }
6 25
26 var _ dbBaser = new(dbBasePostgres)
27
28 func (d *dbBasePostgres) OperatorSql(operator string) string {
29 return postgresOperators[operator]
30 }
31
32 func (d *dbBasePostgres) TableQuote() string {
33 return `"`
34 }
35
36 func (d *dbBasePostgres) ReplaceMarks(query *string) {
37 q := *query
38 num := 0
39 for _, c := range q {
40 if c == '?' {
41 num += 1
42 }
43 }
44 if num == 0 {
45 return
46 }
47 data := make([]byte, 0, len(q)+num)
48 num = 1
49 for i := 0; i < len(q); i++ {
50 c := q[i]
51 if c == '?' {
52 data = append(data, '$')
53 data = append(data, []byte(strconv.Itoa(num))...)
54 num += 1
55 } else {
56 data = append(data, c)
57 }
58 }
59 *query = string(data)
60 }
61
62 // func (d *dbBasePostgres)
63
7 func newdbBasePostgres() dbBaser { 64 func newdbBasePostgres() dbBaser {
8 b := new(dbBasePostgres) 65 b := new(dbBasePostgres)
9 b.ins = b 66 b.ins = b
......
1 package orm 1 package orm
2 2
3 var sqliteOperators = map[string]string{
4 "exact": "= ?",
5 "iexact": "LIKE ? ESCAPE '\\'",
6 "contains": "LIKE ? ESCAPE '\\'",
7 "icontains": "LIKE ? ESCAPE '\\'",
8 "gt": "> ?",
9 "gte": ">= ?",
10 "lt": "< ?",
11 "lte": "<= ?",
12 "startswith": "LIKE ? ESCAPE '\\'",
13 "endswith": "LIKE ? ESCAPE '\\'",
14 "istartswith": "LIKE ? ESCAPE '\\'",
15 "iendswith": "LIKE ? ESCAPE '\\'",
16 }
17
3 type dbBaseSqlite struct { 18 type dbBaseSqlite struct {
4 dbBase 19 dbBase
5 } 20 }
6 21
22 var _ dbBaser = new(dbBaseSqlite)
23
24 func (d *dbBaseSqlite) OperatorSql(operator string) string {
25 return sqliteOperators[operator]
26 }
27
28 func (d *dbBaseSqlite) SupportUpdateJoin() bool {
29 return false
30 }
31
32 func (d *dbBaseSqlite) MaxLimit() uint64 {
33 return 9223372036854775807
34 }
35
7 func newdbBaseSqlite() dbBaser { 36 func newdbBaseSqlite() dbBaser {
8 b := new(dbBaseSqlite) 37 b := new(dbBaseSqlite)
9 b.ins = b 38 b.ins = b
......
1 package orm
2
3 import (
4 "fmt"
5 "strings"
6 )
7
8 type dbTable struct {
9 id int
10 index string
11 name string
12 names []string
13 sel bool
14 inner bool
15 mi *modelInfo
16 fi *fieldInfo
17 jtl *dbTable
18 }
19
20 type dbTables struct {
21 tablesM map[string]*dbTable
22 tables []*dbTable
23 mi *modelInfo
24 base dbBaser
25 }
26
27 func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
28 name := strings.Join(names, ExprSep)
29 if j, ok := t.tablesM[name]; ok {
30 j.name = name
31 j.mi = mi
32 j.fi = fi
33 j.inner = inner
34 } else {
35 i := len(t.tables) + 1
36 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
37 t.tablesM[name] = jt
38 t.tables = append(t.tables, jt)
39 }
40 return t.tablesM[name]
41 }
42
43 func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
44 name := strings.Join(names, ExprSep)
45 if _, ok := t.tablesM[name]; ok == false {
46 i := len(t.tables) + 1
47 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
48 t.tablesM[name] = jt
49 t.tables = append(t.tables, jt)
50 return jt, true
51 }
52 return t.tablesM[name], false
53 }
54
55 func (t *dbTables) get(name string) (*dbTable, bool) {
56 j, ok := t.tablesM[name]
57 return j, ok
58 }
59
60 func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
61 if depth < 0 || fi.fieldType == RelManyToMany {
62 return related
63 }
64
65 if prefix == "" {
66 prefix = fi.name
67 } else {
68 prefix = prefix + ExprSep + fi.name
69 }
70 related = append(related, prefix)
71
72 depth--
73 for _, fi := range fi.relModelInfo.fields.fieldsRel {
74 related = t.loopDepth(depth, prefix, fi, related)
75 }
76
77 return related
78 }
79
80 func (t *dbTables) parseRelated(rels []string, depth int) {
81
82 relsNum := len(rels)
83 related := make([]string, relsNum)
84 copy(related, rels)
85
86 relDepth := depth
87
88 if relsNum != 0 {
89 relDepth = 0
90 }
91
92 relDepth--
93 for _, fi := range t.mi.fields.fieldsRel {
94 related = t.loopDepth(relDepth, "", fi, related)
95 }
96
97 for i, s := range related {
98 var (
99 exs = strings.Split(s, ExprSep)
100 names = make([]string, 0, len(exs))
101 mmi = t.mi
102 cansel = true
103 jtl *dbTable
104 )
105 for _, ex := range exs {
106 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
107 names = append(names, fi.name)
108 mmi = fi.relModelInfo
109
110 jt := t.set(names, mmi, fi, fi.null == false)
111 jt.jtl = jtl
112
113 if fi.reverse {
114 cansel = false
115 }
116
117 if cansel {
118 jt.sel = depth > 0
119
120 if i < relsNum {
121 jt.sel = true
122 }
123 }
124
125 jtl = jt
126
127 } else {
128 panic(fmt.Sprintf("unknown model/table name `%s`", ex))
129 }
130 }
131 }
132 }
133
134 func (t *dbTables) getJoinSql() (join string) {
135 Q := t.base.TableQuote()
136
137 for _, jt := range t.tables {
138 if jt.inner {
139 join += "INNER JOIN "
140 } else {
141 join += "LEFT OUTER JOIN "
142 }
143 var (
144 table string
145 t1, t2 string
146 c1, c2 string
147 )
148 t1 = "T0"
149 if jt.jtl != nil {
150 t1 = jt.jtl.index
151 }
152 t2 = jt.index
153 table = jt.mi.table
154
155 switch {
156 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
157 c1 = jt.fi.mi.fields.pk.column
158 for _, ffi := range jt.mi.fields.fieldsRel {
159 if jt.fi.mi == ffi.relModelInfo {
160 c2 = ffi.column
161 break
162 }
163 }
164 default:
165 c1 = jt.fi.column
166 c2 = jt.fi.relModelInfo.fields.pk.column
167
168 if jt.fi.reverse {
169 c1 = jt.mi.fields.pk.column
170 c2 = jt.fi.reverseFieldInfo.column
171 }
172 }
173
174 join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
175 t2, Q, c2, Q, t1, Q, c1, Q)
176 }
177 return
178 }
179
180 func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
181 var (
182 ffi *fieldInfo
183 jtl *dbTable
184 mmi = mi
185 )
186
187 num := len(exprs) - 1
188 names := make([]string, 0)
189
190 for i, ex := range exprs {
191 exist := false
192
193 check:
194 fi, ok := mmi.fields.GetByAny(ex)
195
196 if ok {
197
198 if num != i {
199 names = append(names, fi.name)
200
201 switch {
202 case fi.rel:
203 mmi = fi.relModelInfo
204 if fi.fieldType == RelManyToMany {
205 mmi = fi.relThroughModelInfo
206 }
207 case fi.reverse:
208 mmi = fi.reverseFieldInfo.mi
209 if fi.reverseFieldInfo.fieldType == RelManyToMany {
210 mmi = fi.reverseFieldInfo.relThroughModelInfo
211 }
212 default:
213 return
214 }
215
216 jt, _ := d.add(names, mmi, fi, fi.null == false)
217 jt.jtl = jtl
218 jtl = jt
219
220 if fi.rel && fi.fieldType == RelManyToMany {
221 ex = fi.relModelInfo.name
222 goto check
223 }
224
225 if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
226 ex = fi.reverseFieldInfo.mi.name
227 goto check
228 }
229
230 exist = true
231
232 } else {
233
234 if ffi == nil {
235 index = "T0"
236 } else {
237 index = jtl.index
238 }
239 column = fi.column
240 info = fi
241 if jtl != nil {
242 name = jtl.name + ExprSep + fi.name
243 } else {
244 name = fi.name
245 }
246
247 switch fi.fieldType {
248 case RelManyToMany, RelReverseMany:
249 default:
250 exist = true
251 }
252 }
253
254 ffi = fi
255 }
256
257 if exist == false {
258 index = ""
259 column = ""
260 name = ""
261 success = false
262 return
263 }
264 }
265
266 success = index != "" && column != ""
267 return
268 }
269
270 func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
271 if cond == nil || cond.IsEmpty() {
272 return
273 }
274
275 Q := d.base.TableQuote()
276
277 mi := d.mi
278
279 // outFor:
280 for i, p := range cond.params {
281 if i > 0 {
282 if p.isOr {
283 where += "OR "
284 } else {
285 where += "AND "
286 }
287 }
288 if p.isNot {
289 where += "NOT "
290 }
291 if p.isCond {
292 w, ps := d.getCondSql(p.cond, true)
293 if w != "" {
294 w = fmt.Sprintf("( %s) ", w)
295 }
296 where += w
297 params = append(params, ps...)
298 } else {
299 exprs := p.exprs
300
301 num := len(exprs) - 1
302 operator := ""
303 if operators[exprs[num]] {
304 operator = exprs[num]
305 exprs = exprs[:num]
306 }
307
308 index, column, _, _, suc := d.parseExprs(mi, exprs)
309 if suc == false {
310 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
311 }
312
313 if operator == "" {
314 operator = "exact"
315 }
316
317 operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args)
318
319 where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql)
320 params = append(params, args...)
321
322 }
323 }
324
325 if sub == false && where != "" {
326 where = "WHERE " + where
327 }
328
329 return
330 }
331
332 func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
333 if len(orders) == 0 {
334 return
335 }
336
337 Q := d.base.TableQuote()
338
339 orderSqls := make([]string, 0, len(orders))
340 for _, order := range orders {
341 asc := "ASC"
342 if order[0] == '-' {
343 asc = "DESC"
344 order = order[1:]
345 }
346 exprs := strings.Split(order, ExprSep)
347
348 index, column, _, _, suc := d.parseExprs(d.mi, exprs)
349 if suc == false {
350 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
351 }
352
353 orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc))
354 }
355
356 orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
357 return
358 }
359
360 func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits string) {
361 if limit == 0 {
362 limit = DefaultRowsLimit
363 }
364 if limit < 0 {
365 // no limit
366 if offset > 0 {
367 maxLimit := d.base.MaxLimit()
368 limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
369 }
370 } else if offset <= 0 {
371 limits = fmt.Sprintf("LIMIT %d", limit)
372 } else {
373 limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
374 }
375 return
376 }
377
378 func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
379 tables := &dbTables{}
380 tables.tablesM = make(map[string]*dbTable)
381 tables.mi = mi
382 tables.base = base
383 return tables
384 }
...@@ -79,7 +79,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) { ...@@ -79,7 +79,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
79 func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { 79 func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
80 info = new(modelInfo) 80 info = new(modelInfo)
81 info.fields = newFields() 81 info.fields = newFields()
82 info.table = m1.table + "_" + m2.table + "_rel" 82 info.table = m1.table + "_" + m2.table + "s"
83 info.name = camelString(info.table) 83 info.name = camelString(info.table)
84 info.fullName = m1.pkg + "." + info.name 84 info.fullName = m1.pkg + "." + info.name
85 85
......
...@@ -3,10 +3,11 @@ package orm ...@@ -3,10 +3,11 @@ package orm
3 import ( 3 import (
4 "fmt" 4 "fmt"
5 "os" 5 "os"
6 "strings"
6 "time" 7 "time"
7 8
8 _ "github.com/bmizerany/pq"
9 _ "github.com/go-sql-driver/mysql" 9 _ "github.com/go-sql-driver/mysql"
10 _ "github.com/lib/pq"
10 _ "github.com/mattn/go-sqlite3" 11 _ "github.com/mattn/go-sqlite3"
11 ) 12 )
12 13
...@@ -95,8 +96,178 @@ var DBARGS = struct { ...@@ -95,8 +96,178 @@ var DBARGS = struct {
95 os.Getenv("ORM_DEBUG"), 96 os.Getenv("ORM_DEBUG"),
96 } 97 }
97 98
99 var (
100 IsMysql = DBARGS.Driver == "mysql"
101 IsSqlite = DBARGS.Driver == "sqlite3"
102 IsPostgres = DBARGS.Driver == "postgres"
103 )
104
98 var dORM Ormer 105 var dORM Ormer
99 106
107 var initSQLs = map[string]string{
108 "mysql": "DROP TABLE IF EXISTS `user_profile`;\n" +
109 "DROP TABLE IF EXISTS `user`;\n" +
110 "DROP TABLE IF EXISTS `post`;\n" +
111 "DROP TABLE IF EXISTS `tag`;\n" +
112 "DROP TABLE IF EXISTS `post_tags`;\n" +
113 "DROP TABLE IF EXISTS `comment`;\n" +
114 "CREATE TABLE `user_profile` (\n" +
115 " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
116 " `age` smallint NOT NULL,\n" +
117 " `money` double precision NOT NULL\n" +
118 ") ENGINE=INNODB;\n" +
119 "CREATE TABLE `user` (\n" +
120 " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
121 " `user_name` varchar(30) NOT NULL UNIQUE,\n" +
122 " `email` varchar(100) NOT NULL,\n" +
123 " `password` varchar(100) NOT NULL,\n" +
124 " `status` smallint NOT NULL,\n" +
125 " `is_staff` bool NOT NULL,\n" +
126 " `is_active` bool NOT NULL,\n" +
127 " `created` date NOT NULL,\n" +
128 " `updated` datetime NOT NULL,\n" +
129 " `profile_id` integer\n" +
130 ") ENGINE=INNODB;\n" +
131 "CREATE TABLE `post` (\n" +
132 " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
133 " `user_id` integer NOT NULL,\n" +
134 " `title` varchar(60) NOT NULL,\n" +
135 " `content` longtext NOT NULL,\n" +
136 " `created` datetime NOT NULL,\n" +
137 " `updated` datetime NOT NULL\n" +
138 ") ENGINE=INNODB;\n" +
139 "CREATE TABLE `tag` (\n" +
140 " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
141 " `name` varchar(30) NOT NULL\n" +
142 ") ENGINE=INNODB;\n" +
143 "CREATE TABLE `post_tags` (\n" +
144 " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
145 " `post_id` integer NOT NULL,\n" +
146 " `tag_id` integer NOT NULL,\n" +
147 " UNIQUE (`post_id`, `tag_id`)\n" +
148 ") ENGINE=INNODB;\n" +
149 "CREATE TABLE `comment` (\n" +
150 " `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
151 " `post_id` integer NOT NULL,\n" +
152 " `content` longtext NOT NULL,\n" +
153 " `parent_id` integer,\n" +
154 " `created` datetime NOT NULL\n" +
155 ") ENGINE=INNODB;\n" +
156 "CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" +
157 "CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" +
158 "CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" +
159 "CREATE INDEX `comment_63f17a16` ON `comment` (`parent_id`);",
160
161 "sqlite3": `
162 DROP TABLE IF EXISTS "user_profile";
163 DROP TABLE IF EXISTS "user";
164 DROP TABLE IF EXISTS "post";
165 DROP TABLE IF EXISTS "tag";
166 DROP TABLE IF EXISTS "post_tags";
167 DROP TABLE IF EXISTS "comment";
168 CREATE TABLE "user_profile" (
169 "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
170 "age" smallint NOT NULL,
171 "money" real NOT NULL
172 );
173 CREATE TABLE "user" (
174 "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
175 "user_name" varchar(30) NOT NULL UNIQUE,
176 "email" varchar(100) NOT NULL,
177 "password" varchar(100) NOT NULL,
178 "status" smallint NOT NULL,
179 "is_staff" bool NOT NULL,
180 "is_active" bool NOT NULL,
181 "created" date NOT NULL,
182 "updated" datetime NOT NULL,
183 "profile_id" integer
184 );
185 CREATE TABLE "post" (
186 "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
187 "user_id" integer NOT NULL,
188 "title" varchar(60) NOT NULL,
189 "content" text NOT NULL,
190 "created" datetime NOT NULL,
191 "updated" datetime NOT NULL
192 );
193 CREATE TABLE "tag" (
194 "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
195 "name" varchar(30) NOT NULL
196 );
197 CREATE TABLE "post_tags" (
198 "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
199 "post_id" integer NOT NULL,
200 "tag_id" integer NOT NULL,
201 UNIQUE ("post_id", "tag_id")
202 );
203 CREATE TABLE "comment" (
204 "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
205 "post_id" integer NOT NULL,
206 "content" text NOT NULL,
207 "parent_id" integer,
208 "created" datetime NOT NULL
209 );
210 CREATE INDEX "user_141c6eec" ON "user" ("profile_id");
211 CREATE INDEX "post_fbfc09f1" ON "post" ("user_id");
212 CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id");
213 CREATE INDEX "comment_63f17a16" ON "comment" ("parent_id");
214 `,
215
216 "postgres": `
217 DROP TABLE IF EXISTS "user_profile";
218 DROP TABLE IF EXISTS "user";
219 DROP TABLE IF EXISTS "post";
220 DROP TABLE IF EXISTS "tag";
221 DROP TABLE IF EXISTS "post_tags";
222 DROP TABLE IF EXISTS "comment";
223 CREATE TABLE "user_profile" (
224 "id" serial NOT NULL PRIMARY KEY,
225 "age" smallint NOT NULL,
226 "money" double precision NOT NULL
227 );
228 CREATE TABLE "user" (
229 "id" serial NOT NULL PRIMARY KEY,
230 "user_name" varchar(30) NOT NULL UNIQUE,
231 "email" varchar(100) NOT NULL,
232 "password" varchar(100) NOT NULL,
233 "status" smallint NOT NULL,
234 "is_staff" boolean NOT NULL,
235 "is_active" boolean NOT NULL,
236 "created" date NOT NULL,
237 "updated" timestamp with time zone NOT NULL,
238 "profile_id" integer
239 );
240 CREATE TABLE "post" (
241 "id" serial NOT NULL PRIMARY KEY,
242 "user_id" integer NOT NULL,
243 "title" varchar(60) NOT NULL,
244 "content" text NOT NULL,
245 "created" timestamp with time zone NOT NULL,
246 "updated" timestamp with time zone NOT NULL
247 );
248 CREATE TABLE "tag" (
249 "id" serial NOT NULL PRIMARY KEY,
250 "name" varchar(30) NOT NULL
251 );
252 CREATE TABLE "post_tags" (
253 "id" serial NOT NULL PRIMARY KEY,
254 "post_id" integer NOT NULL,
255 "tag_id" integer NOT NULL,
256 UNIQUE ("post_id", "tag_id")
257 );
258 CREATE TABLE "comment" (
259 "id" serial NOT NULL PRIMARY KEY,
260 "post_id" integer NOT NULL,
261 "content" text NOT NULL,
262 "parent_id" integer,
263 "created" timestamp with time zone NOT NULL
264 );
265 CREATE INDEX "user_profile_id" ON "user" ("profile_id");
266 CREATE INDEX "post_user_id" ON "post" ("user_id");
267 CREATE INDEX "comment_post_id" ON "comment" ("post_id");
268 CREATE INDEX "comment_parent_id" ON "comment" ("parent_id");
269 `}
270
100 func init() { 271 func init() {
101 RegisterModel(new(User)) 272 RegisterModel(new(User))
102 RegisterModel(new(Profile)) 273 RegisterModel(new(Profile))
...@@ -114,7 +285,7 @@ Default DB Drivers. ...@@ -114,7 +285,7 @@ Default DB Drivers.
114 driver: url 285 driver: url
115 mysql: https://github.com/go-sql-driver/mysql 286 mysql: https://github.com/go-sql-driver/mysql
116 sqlite3: https://github.com/mattn/go-sqlite3 287 sqlite3: https://github.com/mattn/go-sqlite3
117 postgres: https://github.com/bmizerany/pq 288 postgres: https://github.com/lib/pq
118 289
119 eg: mysql 290 eg: mysql
120 ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm 291 ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
...@@ -126,20 +297,16 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a ...@@ -126,20 +297,16 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/a
126 297
127 BootStrap() 298 BootStrap()
128 299
129 truncateTables()
130
131 dORM = NewOrm() 300 dORM = NewOrm()
132 }
133 301
134 func truncateTables() { 302 queries := strings.Split(initSQLs[DBARGS.Driver], ";")
135 logs := "truncate tables for test\n" 303
136 o := NewOrm() 304 for _, query := range queries {
137 for _, m := range modelCache.allOrdered() { 305 if strings.TrimSpace(query) == "" {
138 query := fmt.Sprintf("truncate table `%s`", m.table) 306 continue
139 _, err := o.Raw(query).Exec() 307 }
140 logs += query + "\n" 308 _, err := dORM.Raw(query).Exec()
141 if err != nil { 309 if err != nil {
142 fmt.Println(logs)
143 fmt.Println(err) 310 fmt.Println(err)
144 os.Exit(2) 311 os.Exit(2)
145 } 312 }
......
...@@ -135,7 +135,7 @@ func (d *dbQueryLog) Commit() error { ...@@ -135,7 +135,7 @@ func (d *dbQueryLog) Commit() error {
135 135
136 func (d *dbQueryLog) Rollback() error { 136 func (d *dbQueryLog) Rollback() error {
137 a := time.Now() 137 a := time.Now()
138 err := d.db.(txEnder).Commit() 138 err := d.db.(txEnder).Rollback()
139 debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err) 139 debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
140 return err 140 return err
141 } 141 }
......
...@@ -6,39 +6,17 @@ import ( ...@@ -6,39 +6,17 @@ import (
6 "reflect" 6 "reflect"
7 ) 7 )
8 8
9 func getResult(res sql.Result) (int64, error) {
10 if num, err := res.LastInsertId(); err != nil {
11 return 0, err
12 } else {
13 if num > 0 {
14 return num, nil
15 }
16 }
17 if num, err := res.RowsAffected(); err != nil {
18 return num, err
19 } else {
20 if num > 0 {
21 return num, nil
22 }
23 }
24 return 0, nil
25 }
26
27 type rawPrepare struct { 9 type rawPrepare struct {
28 rs *rawSet 10 rs *rawSet
29 stmt stmtQuerier 11 stmt stmtQuerier
30 closed bool 12 closed bool
31 } 13 }
32 14
33 func (o *rawPrepare) Exec(args ...interface{}) (int64, error) { 15 func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) {
34 if o.closed { 16 if o.closed {
35 return 0, ErrStmtClosed 17 return nil, ErrStmtClosed
36 }
37 res, err := o.stmt.Exec(args...)
38 if err != nil {
39 return 0, err
40 } 18 }
41 return getResult(res) 19 return o.stmt.Exec(args...)
42 } 20 }
43 21
44 func (o *rawPrepare) Close() error { 22 func (o *rawPrepare) Close() error {
...@@ -74,12 +52,8 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter { ...@@ -74,12 +52,8 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter {
74 return &o 52 return &o
75 } 53 }
76 54
77 func (o *rawSet) Exec() (int64, error) { 55 func (o *rawSet) Exec() (sql.Result, error) {
78 res, err := o.orm.db.Exec(o.query, o.args...) 56 return o.orm.db.Exec(o.query, o.args...)
79 if err != nil {
80 return 0, err
81 }
82 return getResult(res)
83 } 57 }
84 58
85 func (o *rawSet) QueryRow(...interface{}) error { 59 func (o *rawSet) QueryRow(...interface{}) error {
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
4 "bytes" 4 "bytes"
5 "fmt" 5 "fmt"
6 "io/ioutil" 6 "io/ioutil"
7 "os"
7 "path/filepath" 8 "path/filepath"
8 "reflect" 9 "reflect"
9 "runtime" 10 "runtime"
...@@ -12,6 +13,8 @@ import ( ...@@ -12,6 +13,8 @@ import (
12 "time" 13 "time"
13 ) 14 )
14 15
16 var _ = os.PathSeparator
17
15 type T_Code int 18 type T_Code int
16 19
17 const ( 20 const (
...@@ -60,9 +63,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e ...@@ -60,9 +63,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e
60 ok = is && ok || !is && !ok 63 ok = is && ok || !is && !ok
61 if !ok { 64 if !ok {
62 if is { 65 if is {
63 err = fmt.Errorf("should: a == b, a = `%v`, b = `%v`", a, b) 66 err = fmt.Errorf("expected: a == `%v`, get `%v`", b, a)
64 } else { 67 } else {
65 err = fmt.Errorf("should: a != b, a = `%v`, b = `%v`", a, b) 68 err = fmt.Errorf("expected: a != `%v`, get `%v`", b, a)
66 } 69 }
67 } 70 }
68 case T_Less, T_Large: 71 case T_Less, T_Large:
...@@ -89,9 +92,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e ...@@ -89,9 +92,9 @@ func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err e
89 ok = is && ok || !is && !ok 92 ok = is && ok || !is && !ok
90 if !ok { 93 if !ok {
91 if is { 94 if is {
92 err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[0], f1, f2) 95 err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[0], f1, f2)
93 } else { 96 } else {
94 err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[1], f1, f2) 97 err = fmt.Errorf("should: a %s b, but a = `%v`, b = `%v`", opts[1], f1, f2)
95 } 98 }
96 } 99 }
97 } 100 }
...@@ -122,32 +125,51 @@ func getCaller(skip int) string { ...@@ -122,32 +125,51 @@ func getCaller(skip int) string {
122 fun := runtime.FuncForPC(pc) 125 fun := runtime.FuncForPC(pc)
123 _, fn := filepath.Split(file) 126 _, fn := filepath.Split(file)
124 data, err := ioutil.ReadFile(file) 127 data, err := ioutil.ReadFile(file)
125 code := "" 128 var codes []string
126 if err == nil { 129 if err == nil {
127 lines := bytes.Split(data, []byte{'\n'}) 130 lines := bytes.Split(data, []byte{'\n'})
128 code = strings.TrimSpace(string(lines[line-1])) 131 n := 10
132 for i := 0; i < n; i++ {
133 o := line - n
134 if o < 0 {
135 continue
136 }
137 cur := o + i + 1
138 flag := " "
139 if cur == line {
140 flag = ">>"
141 }
142 code := fmt.Sprintf(" %s %5d: %s", flag, cur, strings.TrimSpace(string(lines[o+i])))
143 if code != "" {
144 codes = append(codes, code)
145 }
146 }
129 } 147 }
130 funName := fun.Name() 148 funName := fun.Name()
131 if i := strings.LastIndex(funName, "."); i > -1 { 149 if i := strings.LastIndex(funName, "."); i > -1 {
132 funName = funName[i+1:] 150 funName = funName[i+1:]
133 } 151 }
134 return fmt.Sprintf("%s:%d: %s: %s", fn, line, funName, code) 152 return fmt.Sprintf("%s:%d: \n%s", fn, line, strings.Join(codes, "\n"))
135 } 153 }
136 154
137 func throwFail(t *testing.T, err error, args ...interface{}) { 155 func throwFail(t *testing.T, err error, args ...interface{}) {
138 if err != nil { 156 if err != nil {
139 params := []interface{}{"\n", getCaller(2), "\n", err, "\n"} 157 con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
140 params = append(params, args...) 158 if len(args) > 0 {
141 t.Error(params...) 159 con += fmt.Sprint(args...)
160 }
161 t.Error(con)
142 t.Fail() 162 t.Fail()
143 } 163 }
144 } 164 }
145 165
146 func throwFailNow(t *testing.T, err error, args ...interface{}) { 166 func throwFailNow(t *testing.T, err error, args ...interface{}) {
147 if err != nil { 167 if err != nil {
148 params := []interface{}{"\n", getCaller(2), "\n", err, "\n"} 168 con := fmt.Sprintf("\t\nError: %s\n%s\n", err.Error(), getCaller(2))
149 params = append(params, args...) 169 if len(args) > 0 {
150 t.Error(params...) 170 con += fmt.Sprint(args...)
171 }
172 t.Error(con)
151 t.FailNow() 173 t.FailNow()
152 } 174 }
153 } 175 }
...@@ -165,8 +187,8 @@ func TestCRUD(t *testing.T) { ...@@ -165,8 +187,8 @@ func TestCRUD(t *testing.T) {
165 profile.Age = 30 187 profile.Age = 30
166 profile.Money = 1234.12 188 profile.Money = 1234.12
167 id, err := dORM.Insert(profile) 189 id, err := dORM.Insert(profile)
168 throwFailNow(t, err) 190 throwFail(t, err)
169 throwFailNow(t, AssertIs(id, T_Large, 0)) 191 throwFail(t, AssertIs(id, T_Equal, 1))
170 192
171 user := NewUser() 193 user := NewUser()
172 user.UserName = "slene" 194 user.UserName = "slene"
...@@ -177,51 +199,53 @@ func TestCRUD(t *testing.T) { ...@@ -177,51 +199,53 @@ func TestCRUD(t *testing.T) {
177 user.IsActive = true 199 user.IsActive = true
178 200
179 id, err = dORM.Insert(user) 201 id, err = dORM.Insert(user)
180 throwFailNow(t, err) 202 throwFail(t, err)
181 throwFailNow(t, AssertIs(id, T_Large, 0)) 203 throwFail(t, AssertIs(id, T_Equal, 1))
182 204
183 u := &User{Id: user.Id} 205 u := &User{Id: user.Id}
184 err = dORM.Read(u) 206 err = dORM.Read(u)
185 throwFailNow(t, err) 207 throwFail(t, err)
186 208
187 throwFailNow(t, AssertIs(u.UserName, T_Equal, "slene")) 209 throwFail(t, AssertIs(u.UserName, T_Equal, "slene"))
188 throwFailNow(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com")) 210 throwFail(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com"))
189 throwFailNow(t, AssertIs(u.Password, T_Equal, "pass")) 211 throwFail(t, AssertIs(u.Password, T_Equal, "pass"))
190 throwFailNow(t, AssertIs(u.Status, T_Equal, 3)) 212 throwFail(t, AssertIs(u.Status, T_Equal, 3))
191 throwFailNow(t, AssertIs(u.IsStaff, T_Equal, true)) 213 throwFail(t, AssertIs(u.IsStaff, T_Equal, true))
192 throwFailNow(t, AssertIs(u.IsActive, T_Equal, true)) 214 throwFail(t, AssertIs(u.IsActive, T_Equal, true))
193 throwFailNow(t, AssertIs(u.Created, T_Equal, user.Created, format_Date)) 215 throwFail(t, AssertIs(u.Created, T_Equal, user.Created, format_Date))
194 throwFailNow(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime)) 216 throwFail(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime))
195 217
196 user.UserName = "astaxie" 218 user.UserName = "astaxie"
197 user.Profile = profile 219 user.Profile = profile
198 num, err := dORM.Update(user) 220 num, err := dORM.Update(user)
199 throwFailNow(t, err) 221 throwFail(t, err)
200 throwFailNow(t, AssertIs(num, T_Equal, 1)) 222 throwFail(t, AssertIs(num, T_Equal, 1))
201 223
202 u = &User{Id: user.Id} 224 u = &User{Id: user.Id}
203 err = dORM.Read(u) 225 err = dORM.Read(u)
204 throwFailNow(t, err) 226 throwFail(t, err)
205 227
206 throwFailNow(t, AssertIs(u.UserName, T_Equal, "astaxie")) 228 if err == nil {
207 throwFailNow(t, AssertIs(u.Profile.Id, T_Equal, profile.Id)) 229 throwFail(t, AssertIs(u.UserName, T_Equal, "astaxie"))
230 throwFail(t, AssertIs(u.Profile.Id, T_Equal, profile.Id))
231 }
208 232
209 num, err = dORM.Delete(profile) 233 num, err = dORM.Delete(profile)
210 throwFailNow(t, err) 234 throwFail(t, err)
211 throwFailNow(t, AssertIs(num, T_Equal, 1)) 235 throwFail(t, AssertIs(num, T_Equal, 1))
212 236
213 u = &User{Id: user.Id} 237 u = &User{Id: user.Id}
214 err = dORM.Read(u) 238 err = dORM.Read(u)
215 throwFailNow(t, err) 239 throwFail(t, err)
216 throwFailNow(t, AssertIs(true, T_Equal, u.Profile == nil)) 240 throwFail(t, AssertIs(true, T_Equal, u.Profile == nil))
217 241
218 num, err = dORM.Delete(user) 242 num, err = dORM.Delete(user)
219 throwFailNow(t, err) 243 throwFail(t, err)
220 throwFailNow(t, AssertIs(num, T_Equal, 1)) 244 throwFail(t, AssertIs(num, T_Equal, 1))
221 245
222 u = &User{Id: 100} 246 u = &User{Id: 100}
223 err = dORM.Read(u) 247 err = dORM.Read(u)
224 throwFailNow(t, AssertIs(err, T_Equal, ErrNoRows)) 248 throwFail(t, AssertIs(err, T_Equal, ErrNoRows))
225 } 249 }
226 250
227 func TestInsertTestData(t *testing.T) { 251 func TestInsertTestData(t *testing.T) {
...@@ -232,8 +256,8 @@ func TestInsertTestData(t *testing.T) { ...@@ -232,8 +256,8 @@ func TestInsertTestData(t *testing.T) {
232 profile.Money = 1234.12 256 profile.Money = 1234.12
233 257
234 id, err := dORM.Insert(profile) 258 id, err := dORM.Insert(profile)
235 throwFailNow(t, err) 259 throwFail(t, err)
236 throwFailNow(t, AssertIs(id, T_Large, 0)) 260 throwFail(t, AssertIs(id, T_Equal, 2))
237 261
238 user := NewUser() 262 user := NewUser()
239 user.UserName = "slene" 263 user.UserName = "slene"
...@@ -247,16 +271,16 @@ func TestInsertTestData(t *testing.T) { ...@@ -247,16 +271,16 @@ func TestInsertTestData(t *testing.T) {
247 users = append(users, user) 271 users = append(users, user)
248 272
249 id, err = dORM.Insert(user) 273 id, err = dORM.Insert(user)
250 throwFailNow(t, err) 274 throwFail(t, err)
251 throwFailNow(t, AssertIs(id, T_Large, 0)) 275 throwFail(t, AssertIs(id, T_Equal, 2))
252 276
253 profile = NewProfile() 277 profile = NewProfile()
254 profile.Age = 30 278 profile.Age = 30
255 profile.Money = 4321.09 279 profile.Money = 4321.09
256 280
257 id, err = dORM.Insert(profile) 281 id, err = dORM.Insert(profile)
258 throwFailNow(t, err) 282 throwFail(t, err)
259 throwFailNow(t, AssertIs(id, T_Large, 0)) 283 throwFail(t, AssertIs(id, T_Equal, 3))
260 284
261 user = NewUser() 285 user = NewUser()
262 user.UserName = "astaxie" 286 user.UserName = "astaxie"
...@@ -270,8 +294,8 @@ func TestInsertTestData(t *testing.T) { ...@@ -270,8 +294,8 @@ func TestInsertTestData(t *testing.T) {
270 users = append(users, user) 294 users = append(users, user)
271 295
272 id, err = dORM.Insert(user) 296 id, err = dORM.Insert(user)
273 throwFailNow(t, err) 297 throwFail(t, err)
274 throwFailNow(t, AssertIs(id, T_Large, 0)) 298 throwFail(t, AssertIs(id, T_Equal, 3))
275 299
276 user = NewUser() 300 user = NewUser()
277 user.UserName = "nobody" 301 user.UserName = "nobody"
...@@ -284,8 +308,8 @@ func TestInsertTestData(t *testing.T) { ...@@ -284,8 +308,8 @@ func TestInsertTestData(t *testing.T) {
284 users = append(users, user) 308 users = append(users, user)
285 309
286 id, err = dORM.Insert(user) 310 id, err = dORM.Insert(user)
287 throwFailNow(t, err) 311 throwFail(t, err)
288 throwFailNow(t, AssertIs(id, T_Large, 0)) 312 throwFail(t, AssertIs(id, T_Equal, 4))
289 313
290 tags := []*Tag{ 314 tags := []*Tag{
291 &Tag{Name: "golang"}, 315 &Tag{Name: "golang"},
...@@ -315,21 +339,21 @@ The program—and web server—godoc processes Go source files to extract docume ...@@ -315,21 +339,21 @@ The program—and web server—godoc processes Go source files to extract docume
315 339
316 for _, tag := range tags { 340 for _, tag := range tags {
317 id, err := dORM.Insert(tag) 341 id, err := dORM.Insert(tag)
318 throwFailNow(t, err) 342 throwFail(t, err)
319 throwFailNow(t, AssertIs(id, T_Large, 0)) 343 throwFail(t, AssertIs(id, T_Large, 0))
320 } 344 }
321 345
322 for _, post := range posts { 346 for _, post := range posts {
323 id, err := dORM.Insert(post) 347 id, err := dORM.Insert(post)
324 throwFailNow(t, err) 348 throwFail(t, err)
325 throwFailNow(t, AssertIs(id, T_Large, 0)) 349 throwFail(t, AssertIs(id, T_Large, 0))
326 // dORM.M2mAdd(post, "tags", post.Tags) 350 // dORM.M2mAdd(post, "tags", post.Tags)
327 } 351 }
328 352
329 for _, comment := range comments { 353 for _, comment := range comments {
330 id, err := dORM.Insert(comment) 354 id, err := dORM.Insert(comment)
331 throwFailNow(t, err) 355 throwFail(t, err)
332 throwFailNow(t, AssertIs(id, T_Large, 0)) 356 throwFail(t, AssertIs(id, T_Large, 0))
333 } 357 }
334 } 358 }
335 359
...@@ -359,9 +383,17 @@ func TestOperators(t *testing.T) { ...@@ -359,9 +383,17 @@ func TestOperators(t *testing.T) {
359 throwFail(t, err) 383 throwFail(t, err)
360 throwFail(t, AssertIs(num, T_Equal, 2)) 384 throwFail(t, AssertIs(num, T_Equal, 2))
361 385
386 var shouldNum int
387
388 if IsSqlite {
389 shouldNum = 2
390 } else {
391 shouldNum = 0
392 }
393
362 num, err = qs.Filter("user_name__contains", "E").Count() 394 num, err = qs.Filter("user_name__contains", "E").Count()
363 throwFail(t, err) 395 throwFail(t, err)
364 throwFail(t, AssertIs(num, T_Equal, 0)) 396 throwFail(t, AssertIs(num, T_Equal, shouldNum))
365 397
366 num, err = qs.Filter("user_name__icontains", "E").Count() 398 num, err = qs.Filter("user_name__icontains", "E").Count()
367 throwFail(t, err) 399 throwFail(t, err)
...@@ -391,9 +423,15 @@ func TestOperators(t *testing.T) { ...@@ -391,9 +423,15 @@ func TestOperators(t *testing.T) {
391 throwFail(t, err) 423 throwFail(t, err)
392 throwFail(t, AssertIs(num, T_Equal, 1)) 424 throwFail(t, AssertIs(num, T_Equal, 1))
393 425
426 if IsSqlite {
427 shouldNum = 1
428 } else {
429 shouldNum = 0
430 }
431
394 num, err = qs.Filter("user_name__startswith", "S").Count() 432 num, err = qs.Filter("user_name__startswith", "S").Count()
395 throwFail(t, err) 433 throwFail(t, err)
396 throwFail(t, AssertIs(num, T_Equal, 0)) 434 throwFail(t, AssertIs(num, T_Equal, shouldNum))
397 435
398 num, err = qs.Filter("user_name__istartswith", "S").Count() 436 num, err = qs.Filter("user_name__istartswith", "S").Count()
399 throwFail(t, err) 437 throwFail(t, err)
...@@ -403,9 +441,15 @@ func TestOperators(t *testing.T) { ...@@ -403,9 +441,15 @@ func TestOperators(t *testing.T) {
403 throwFail(t, err) 441 throwFail(t, err)
404 throwFail(t, AssertIs(num, T_Equal, 2)) 442 throwFail(t, AssertIs(num, T_Equal, 2))
405 443
444 if IsSqlite {
445 shouldNum = 2
446 } else {
447 shouldNum = 0
448 }
449
406 num, err = qs.Filter("user_name__endswith", "E").Count() 450 num, err = qs.Filter("user_name__endswith", "E").Count()
407 throwFail(t, err) 451 throwFail(t, err)
408 throwFail(t, AssertIs(num, T_Equal, 0)) 452 throwFail(t, AssertIs(num, T_Equal, shouldNum))
409 453
410 num, err = qs.Filter("user_name__iendswith", "E").Count() 454 num, err = qs.Filter("user_name__iendswith", "E").Count()
411 throwFail(t, err) 455 throwFail(t, err)
...@@ -537,7 +581,6 @@ func TestRelatedSel(t *testing.T) { ...@@ -537,7 +581,6 @@ func TestRelatedSel(t *testing.T) {
537 throwFail(t, err) 581 throwFail(t, err)
538 throwFail(t, AssertIs(num, T_Equal, 1)) 582 throwFail(t, AssertIs(num, T_Equal, 1))
539 throwFail(t, AssertNot(user.Profile, T_Equal, nil)) 583 throwFail(t, AssertNot(user.Profile, T_Equal, nil))
540 throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
541 if user.Profile != nil { 584 if user.Profile != nil {
542 throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28)) 585 throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
543 } 586 }
...@@ -617,7 +660,7 @@ func TestOrderBy(t *testing.T) { ...@@ -617,7 +660,7 @@ func TestOrderBy(t *testing.T) {
617 func TestPrepareInsert(t *testing.T) { 660 func TestPrepareInsert(t *testing.T) {
618 qs := dORM.QueryTable("user") 661 qs := dORM.QueryTable("user")
619 i, err := qs.PrepareInsert() 662 i, err := qs.PrepareInsert()
620 throwFail(t, err) 663 throwFailNow(t, err)
621 664
622 var user User 665 var user User
623 user.UserName = "testing1" 666 user.UserName = "testing1"
...@@ -641,15 +684,18 @@ func TestPrepareInsert(t *testing.T) { ...@@ -641,15 +684,18 @@ func TestPrepareInsert(t *testing.T) {
641 } 684 }
642 685
643 func TestRaw(t *testing.T) { 686 func TestRaw(t *testing.T) {
644 switch dORM.Driver().Type() { 687 switch {
645 case DR_MySQL: 688 case IsMysql || IsSqlite:
646 num, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec() 689
690 res, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec()
647 throwFail(t, err) 691 throwFail(t, err)
648 throwFail(t, AssertIs(num, T_Equal, 1)) 692 num, err := res.RowsAffected()
693 throwFail(t, AssertIs(num, T_Equal, 1), err)
649 694
650 num, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec() 695 res, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec()
651 throwFail(t, err) 696 throwFail(t, err)
652 throwFail(t, AssertIs(num, T_Equal, 1)) 697 num, err = res.RowsAffected()
698 throwFail(t, AssertIs(num, T_Equal, 1), err)
653 699
654 var maps []Params 700 var maps []Params
655 num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps) 701 num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps)
...@@ -681,11 +727,18 @@ func TestRaw(t *testing.T) { ...@@ -681,11 +727,18 @@ func TestRaw(t *testing.T) {
681 727
682 func TestUpdate(t *testing.T) { 728 func TestUpdate(t *testing.T) {
683 qs := dORM.QueryTable("user") 729 qs := dORM.QueryTable("user")
684 num, err := qs.Filter("user_name", "slene").Update(Params{ 730 num, err := qs.Filter("user_name", "slene").Filter("is_staff", false).Update(Params{
685 "is_staff": true, 731 "is_staff": true,
686 }) 732 })
687 throwFail(t, err) 733 throwFail(t, err)
688 throwFail(t, AssertIs(num, T_Equal, 1)) 734 throwFail(t, AssertIs(num, T_Equal, 1))
735
736 // with join
737 num, err = qs.Filter("user_name", "slene").Filter("profile__age", 28).Filter("is_staff", true).Update(Params{
738 "is_staff": false,
739 })
740 throwFail(t, err)
741 throwFail(t, AssertIs(num, T_Equal, 1))
689 } 742 }
690 743
691 func TestDelete(t *testing.T) { 744 func TestDelete(t *testing.T) {
...@@ -701,48 +754,54 @@ func TestDelete(t *testing.T) { ...@@ -701,48 +754,54 @@ func TestDelete(t *testing.T) {
701 } 754 }
702 755
703 func TestTransaction(t *testing.T) { 756 func TestTransaction(t *testing.T) {
757 // this test worked when database support transaction
758
704 o := NewOrm() 759 o := NewOrm()
705 err := o.Begin() 760 err := o.Begin()
706 throwFail(t, err) 761 throwFail(t, err)
707 762
708 var names = []string{"1", "2", "3"} 763 var names = []string{"1", "2", "3"}
709 764
710 var user User 765 var tag Tag
711 user.UserName = names[0] 766 tag.Name = names[0]
712 id, err := o.Insert(&user) 767 id, err := o.Insert(&tag)
713 throwFail(t, err) 768 throwFail(t, err)
714 throwFail(t, AssertIs(id, T_Large, 0)) 769 throwFail(t, AssertIs(id, T_Large, 0))
715 770
716 num, err := o.QueryTable("user").Filter("user_name", "slene").Update(Params{"user_name": names[1]}) 771 num, err := o.QueryTable("tag").Filter("name", "golang").Update(Params{"name": names[1]})
717 throwFail(t, err) 772 throwFail(t, err)
718 throwFail(t, AssertIs(num, T_Large, 0)) 773 throwFail(t, AssertIs(num, T_Equal, 1))
719 774
720 switch o.Driver().Type() { 775 switch {
721 case DR_MySQL: 776 case IsMysql || IsSqlite:
722 id, err := o.Raw("INSERT INTO user (user_name) VALUES (?)", names[2]).Exec() 777 res, err := o.Raw("INSERT INTO tag (name) VALUES (?)", names[2]).Exec()
778 throwFail(t, err)
779 if err == nil {
780 id, err = res.LastInsertId()
723 throwFail(t, err) 781 throwFail(t, err)
724 throwFail(t, AssertIs(id, T_Large, 0)) 782 throwFail(t, AssertIs(id, T_Large, 0))
725 } 783 }
784 }
726 785
727 err = o.Rollback() 786 err = o.Rollback()
728 throwFail(t, err) 787 throwFail(t, err)
729 788
730 num, err = o.QueryTable("user").Filter("user_name__in", &user).Count() 789 num, err = o.QueryTable("tag").Filter("name__in", names).Count()
731 throwFail(t, err) 790 throwFail(t, err)
732 throwFail(t, AssertIs(num, T_Equal, 0)) 791 throwFail(t, AssertIs(num, T_Equal, 0))
733 792
734 err = o.Begin() 793 err = o.Begin()
735 throwFail(t, err) 794 throwFail(t, err)
736 795
737 user.UserName = "commit" 796 tag.Name = "commit"
738 id, err = o.Insert(&user) 797 id, err = o.Insert(&tag)
739 throwFail(t, err) 798 throwFail(t, err)
740 throwFail(t, AssertIs(id, T_Large, 0)) 799 throwFail(t, AssertIs(id, T_Large, 0))
741 800
742 o.Commit() 801 o.Commit()
743 throwFail(t, err) 802 throwFail(t, err)
744 803
745 num, err = o.QueryTable("user").Filter("user_name", "commit").Delete() 804 num, err = o.QueryTable("tag").Filter("name", "commit").Delete()
746 throwFail(t, err) 805 throwFail(t, err)
747 throwFail(t, AssertIs(num, T_Equal, 1)) 806 throwFail(t, AssertIs(num, T_Equal, 1))
748 807
......
...@@ -60,12 +60,12 @@ type QuerySeter interface { ...@@ -60,12 +60,12 @@ type QuerySeter interface {
60 } 60 }
61 61
62 type RawPreparer interface { 62 type RawPreparer interface {
63 Exec(...interface{}) (int64, error) 63 Exec(...interface{}) (sql.Result, error)
64 Close() error 64 Close() error
65 } 65 }
66 66
67 type RawSeter interface { 67 type RawSeter interface {
68 Exec() (int64, error) 68 Exec() (sql.Result, error)
69 QueryRow(...interface{}) error 69 QueryRow(...interface{}) error
70 QueryRows(...interface{}) (int64, error) 70 QueryRows(...interface{}) (int64, error)
71 SetArgs(...interface{}) RawSeter 71 SetArgs(...interface{}) RawSeter
...@@ -116,10 +116,15 @@ type dbBaser interface { ...@@ -116,10 +116,15 @@ type dbBaser interface {
116 Update(dbQuerier, *modelInfo, reflect.Value) (int64, error) 116 Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
117 Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error) 117 Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
118 ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error) 118 ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
119 SupportUpdateJoin() bool
119 UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error) 120 UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
120 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) 121 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
121 Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error) 122 Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
122 GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{}) 123 OperatorSql(string) string
124 GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
123 PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error) 125 PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
124 ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error) 126 ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
127 MaxLimit() uint64
128 TableQuote() string
129 ReplaceMarks(*string)
125 } 130 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!