bce35c70 by slene

init orm project, beta, unstable

1 parent ccbf116f
1 ## beego orm
2
3 a powerful orm framework
4
5 now, beta, unstable, may be changing some api make your app build failed.
6
7 ## TODO
8 - some unrealized api
9 - examples
10 - docs
11 - support postgres
12 - support sqlite
...\ No newline at end of file ...\ No newline at end of file
1 package orm
2
3 import (
4 "flag"
5 "fmt"
6 "os"
7 )
8
9 func printHelp() {
10
11 }
12
13 func getSqlAll() (sql string) {
14 for _, mi := range modelCache.allOrdered() {
15 _ = mi
16 }
17 return
18 }
19
20 func runCommand() {
21 if len(os.Args) < 2 || os.Args[1] != "orm" {
22 return
23 }
24
25 _ = flag.NewFlagSet("orm command", flag.ExitOnError)
26
27 args := argString(os.Args[2:])
28 cmd := args.Get(0)
29
30 switch cmd {
31 case "syncdb":
32 case "sqlall":
33 sql := getSqlAll()
34 fmt.Println(sql)
35 default:
36 if cmd != "" {
37 fmt.Printf("unknown command %s", cmd)
38 } else {
39 printHelp()
40 }
41
42 os.Exit(2)
43 }
44 }
1 package orm
2
3 import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "reflect"
8 "strings"
9 "time"
10 )
11
12 const (
13 format_Date = "2006-01-02"
14 format_DateTime = "2006-01-02 15:04:05"
15 )
16
17 var (
18 ErrMissPK = errors.New("missed pk value")
19 )
20
21 var (
22 operators = map[string]bool{
23 "exact": true,
24 "iexact": true,
25 "contains": true,
26 "icontains": true,
27 // "regex": true,
28 // "iregex": true,
29 "gt": true,
30 "gte": true,
31 "lt": true,
32 "lte": true,
33 "startswith": true,
34 "endswith": true,
35 "istartswith": true,
36 "iendswith": true,
37 "in": true,
38 // "range": true,
39 // "year": true,
40 // "month": true,
41 // "day": true,
42 // "week_day": true,
43 "isnull": true,
44 // "search": true,
45 }
46 operatorsSQL = map[string]string{
47 "exact": "= ?",
48 "iexact": "LIKE ?",
49 "contains": "LIKE BINARY ?",
50 "icontains": "LIKE ?",
51 // "regex": "REGEXP BINARY ?",
52 // "iregex": "REGEXP ?",
53 "gt": "> ?",
54 "gte": ">= ?",
55 "lt": "< ?",
56 "lte": "<= ?",
57 "startswith": "LIKE BINARY ?",
58 "endswith": "LIKE BINARY ?",
59 "istartswith": "LIKE ?",
60 "iendswith": "LIKE ?",
61 }
62 )
63
64 type dbTable struct {
65 id int
66 index string
67 name string
68 names []string
69 sel bool
70 inner bool
71 mi *modelInfo
72 fi *fieldInfo
73 jtl *dbTable
74 }
75
76 type dbTables struct {
77 tablesM map[string]*dbTable
78 tables []*dbTable
79 mi *modelInfo
80 base dbBaser
81 }
82
83 func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
84 name := strings.Join(names, ExprSep)
85 if j, ok := t.tablesM[name]; ok {
86 j.name = name
87 j.mi = mi
88 j.fi = fi
89 j.inner = inner
90 } else {
91 i := len(t.tables) + 1
92 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
93 t.tablesM[name] = jt
94 t.tables = append(t.tables, jt)
95 }
96 return t.tablesM[name]
97 }
98
99 func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
100 name := strings.Join(names, ExprSep)
101 if _, ok := t.tablesM[name]; ok == false {
102 i := len(t.tables) + 1
103 jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
104 t.tablesM[name] = jt
105 t.tables = append(t.tables, jt)
106 return jt, true
107 }
108 return t.tablesM[name], false
109 }
110
111 func (t *dbTables) get(name string) (*dbTable, bool) {
112 j, ok := t.tablesM[name]
113 return j, ok
114 }
115
116 func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
117 if depth < 0 || fi.fieldType == RelManyToMany {
118 return related
119 }
120
121 if prefix == "" {
122 prefix = fi.name
123 } else {
124 prefix = prefix + ExprSep + fi.name
125 }
126 related = append(related, prefix)
127
128 depth--
129 for _, fi := range fi.relModelInfo.fields.fieldsRel {
130 related = t.loopDepth(depth, prefix, fi, related)
131 }
132
133 return related
134 }
135
136 func (t *dbTables) parseRelated(rels []string, depth int) {
137
138 relsNum := len(rels)
139 related := make([]string, relsNum)
140 copy(related, rels)
141
142 relDepth := depth
143
144 if relsNum != 0 {
145 relDepth = 0
146 }
147
148 relDepth--
149 for _, fi := range t.mi.fields.fieldsRel {
150 related = t.loopDepth(relDepth, "", fi, related)
151 }
152
153 for i, s := range related {
154 var (
155 exs = strings.Split(s, ExprSep)
156 names = make([]string, 0, len(exs))
157 mmi = t.mi
158 cansel = true
159 jtl *dbTable
160 )
161 for _, ex := range exs {
162 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
163 names = append(names, fi.name)
164 mmi = fi.relModelInfo
165
166 jt := t.set(names, mmi, fi, fi.null == false)
167 jt.jtl = jtl
168
169 if fi.reverse {
170 cansel = false
171 }
172
173 if cansel {
174 jt.sel = depth > 0
175
176 if i < relsNum {
177 jt.sel = true
178 }
179 }
180
181 jtl = jt
182
183 } else {
184 panic(fmt.Sprintf("unknown model/table name `%s`", ex))
185 }
186 }
187 }
188 }
189
190 func (t *dbTables) getJoinSql() (join string) {
191 for _, jt := range t.tables {
192 if jt.inner {
193 join += "INNER JOIN "
194 } else {
195 join += "LEFT OUTER JOIN "
196 }
197 var (
198 table string
199 t1, t2 string
200 c1, c2 string
201 )
202 t1 = "T0"
203 if jt.jtl != nil {
204 t1 = jt.jtl.index
205 }
206 t2 = jt.index
207 table = jt.mi.table
208
209 switch {
210 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
211 c1 = jt.fi.mi.fields.pk[0].column
212 for _, ffi := range jt.mi.fields.fieldsRel {
213 if jt.fi.mi == ffi.relModelInfo {
214 c2 = ffi.column
215 break
216 }
217 }
218 default:
219 c1 = jt.fi.column
220 c2 = jt.fi.relModelInfo.fields.pk[0].column
221
222 if jt.fi.reverse {
223 c1 = jt.mi.fields.pk[0].column
224 c2 = jt.fi.reverseFieldInfo.column
225 }
226 }
227
228 join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2,
229 t2, c2, t1, c1)
230 }
231 return
232 }
233
234 func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column string, info *fieldInfo, success bool) {
235 var (
236 ffi *fieldInfo
237 jtl *dbTable
238 mmi = mi
239 )
240
241 num := len(exprs) - 1
242 names := make([]string, 0)
243
244 for i, ex := range exprs {
245 exist := false
246
247 check:
248 fi, ok := mmi.fields.GetByAny(ex)
249
250 if ok {
251
252 if num != i {
253 names = append(names, fi.name)
254
255 switch {
256 case fi.rel:
257 mmi = fi.relModelInfo
258 if fi.fieldType == RelManyToMany {
259 mmi = fi.relThroughModelInfo
260 }
261 case fi.reverse:
262 mmi = fi.reverseFieldInfo.mi
263 if fi.reverseFieldInfo.fieldType == RelManyToMany {
264 mmi = fi.reverseFieldInfo.relThroughModelInfo
265 }
266 }
267
268 jt, _ := d.add(names, mmi, fi, fi.null == false)
269 jt.jtl = jtl
270 jtl = jt
271
272 if fi.rel && fi.fieldType == RelManyToMany {
273 ex = fi.relModelInfo.name
274 goto check
275 }
276
277 if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
278 ex = fi.reverseFieldInfo.mi.name
279 goto check
280 }
281
282 exist = true
283
284 } else {
285
286 if ffi == nil {
287 index = "T0"
288 } else {
289 index = jtl.index
290 }
291 column = fi.column
292 info = fi
293
294 switch fi.fieldType {
295 case RelManyToMany, RelReverseMany:
296 default:
297 exist = true
298 }
299 }
300
301 ffi = fi
302 }
303
304 if exist == false {
305 index = ""
306 column = ""
307 success = false
308 return
309 }
310 }
311
312 success = index != "" && column != ""
313 return
314 }
315
316 func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
317 if cond == nil || cond.IsEmpty() {
318 return
319 }
320
321 mi := d.mi
322
323 // outFor:
324 for i, p := range cond.params {
325 if i > 0 {
326 if p.isOr {
327 where += "OR "
328 } else {
329 where += "AND "
330 }
331 }
332 if p.isNot {
333 where += "NOT "
334 }
335 if p.isCond {
336 w, ps := d.getCondSql(p.cond, true)
337 if w != "" {
338 w = fmt.Sprintf("( %s) ", w)
339 }
340 where += w
341 params = append(params, ps...)
342 } else {
343 exprs := p.exprs
344
345 num := len(exprs) - 1
346 operator := ""
347 if operators[exprs[num]] {
348 operator = exprs[num]
349 exprs = exprs[:num]
350 }
351
352 index, column, _, suc := d.parseExprs(mi, exprs)
353 if suc == false {
354 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
355 }
356
357 if operator == "" {
358 operator = "exact"
359 }
360
361 operSql, args := d.base.GetOperatorSql(mi, operator, p.args)
362
363 where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql)
364 params = append(params, args...)
365
366 }
367 }
368
369 if sub == false && where != "" {
370 where = "WHERE " + where
371 }
372
373 return
374 }
375
376 func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
377 if len(orders) == 0 {
378 return
379 }
380
381 orderSqls := make([]string, 0, len(orders))
382 for _, order := range orders {
383 asc := "ASC"
384 if order[0] == '-' {
385 asc = "DESC"
386 order = order[1:]
387 }
388 exprs := strings.Split(order, ExprSep)
389
390 index, column, _, suc := d.parseExprs(d.mi, exprs)
391 if suc == false {
392 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
393 }
394
395 orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc))
396 }
397
398 orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
399 return
400 }
401
402 func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) {
403 if limit == 0 {
404 limit = DefaultRowsLimit
405 }
406 if limit < 0 {
407 // no limit
408 if offset > 0 {
409 limits = fmt.Sprintf("OFFSET %d", offset)
410 }
411 } else if offset <= 0 {
412 limits = fmt.Sprintf("LIMIT %d", limit)
413 } else {
414 limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
415 }
416 return
417 }
418
419 func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
420 tables := &dbTables{}
421 tables.tablesM = make(map[string]*dbTable)
422 tables.mi = mi
423 tables.base = base
424 return tables
425 }
426
427 type dbBase struct {
428 ins dbBaser
429 }
430
431 func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) {
432 exist := true
433 columns := make([]string, 0, len(mi.fields.pk))
434 values := make([]interface{}, 0, len(mi.fields.pk))
435 for _, fi := range mi.fields.pk {
436 v := ind.Field(fi.fieldIndex)
437 if fi.fieldType&IsIntegerField > 0 {
438 vu := v.Int()
439 if exist {
440 exist = vu > 0
441 }
442 values = append(values, vu)
443 } else {
444 vu := v.String()
445 if exist {
446 exist = vu != ""
447 }
448 values = append(values, vu)
449 }
450 columns = append(columns, fi.column)
451 }
452 return columns, values, exist
453 }
454
455 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
456 _, pkValues, _ := d.existPk(mi, ind)
457 for _, column := range mi.fields.orders {
458 fi := mi.fields.columns[column]
459 if fi.dbcol == false || fi.auto && skipAuto {
460 continue
461 }
462 var value interface{}
463 if i, ok := mi.fields.pk.Exist(fi); ok {
464 value = pkValues[i]
465 } else {
466 field := ind.Field(fi.fieldIndex)
467 if fi.isFielder {
468 f := field.Addr().Interface().(Fielder)
469 value = f.RawValue()
470 } else {
471 switch fi.fieldType {
472 case TypeBooleanField:
473 value = field.Bool()
474 case TypeCharField, TypeTextField:
475 value = field.String()
476 case TypeFloatField, TypeDecimalField:
477 value = field.Float()
478 case TypeDateField, TypeDateTimeField:
479 value = field.Interface()
480 default:
481 switch {
482 case fi.fieldType&IsPostiveIntegerField > 0:
483 value = field.Uint()
484 case fi.fieldType&IsIntegerField > 0:
485 value = field.Int()
486 case fi.fieldType&IsRelField > 0:
487 if field.IsNil() {
488 value = nil
489 } else {
490 _, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field))
491 if fok {
492 value = fvalues[0]
493 } else {
494 value = nil
495 }
496 }
497 if fi.null == false && value == nil {
498 return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
499 }
500 }
501 }
502 }
503 switch fi.fieldType {
504 case TypeDateField, TypeDateTimeField:
505 if fi.auto_now || fi.auto_now_add && insert {
506 tnow := time.Now()
507 if fi.fieldType == TypeDateField {
508 value = timeFormat(tnow, format_Date)
509 } else {
510 value = timeFormat(tnow, format_DateTime)
511 }
512 if fi.isFielder {
513 f := field.Addr().Interface().(Fielder)
514 f.SetRaw(tnow)
515 } else {
516 field.Set(reflect.ValueOf(tnow))
517 }
518 }
519 }
520 }
521 columns = append(columns, column)
522 values = append(values, value)
523 }
524 return
525 }
526
527 func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (*sql.Stmt, error) {
528 dbcols := make([]string, 0, len(mi.fields.dbcols))
529 marks := make([]string, 0, len(mi.fields.dbcols))
530 for _, fi := range mi.fields.fieldsDB {
531 if fi.auto == false {
532 dbcols = append(dbcols, fi.column)
533 marks = append(marks, "?")
534 }
535 }
536 qmarks := strings.Join(marks, ", ")
537 columns := strings.Join(dbcols, "`,`")
538
539 query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
540 return q.Prepare(query)
541 }
542
543 func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (int64, error) {
544 _, values, err := d.collectValues(mi, ind, true, true)
545 if err != nil {
546 return 0, err
547 }
548
549 if res, err := stmt.Exec(values...); err == nil {
550 return res.LastInsertId()
551 } else {
552 return 0, err
553 }
554 }
555
556 func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
557 names, values, err := d.collectValues(mi, ind, true, true)
558 if err != nil {
559 return 0, err
560 }
561
562 marks := make([]string, len(names))
563 for i, _ := range marks {
564 marks[i] = "?"
565 }
566 qmarks := strings.Join(marks, ", ")
567 columns := strings.Join(names, "`,`")
568
569 query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
570
571 if res, err := q.Exec(query, values...); err == nil {
572 return res.LastInsertId()
573 } else {
574 return 0, err
575 }
576 }
577
578 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
579 pkNames, pkValues, ok := d.existPk(mi, ind)
580 if ok == false {
581 return 0, ErrMissPK
582 }
583 setNames, setValues, err := d.collectValues(mi, ind, true, false)
584 if err != nil {
585 return 0, err
586 }
587
588 pkColumns := strings.Join(pkNames, "` = ? AND `")
589 setColumns := strings.Join(setNames, "` = ?, `")
590
591 query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns)
592
593 setValues = append(setValues, pkValues...)
594
595 if res, err := q.Exec(query, setValues...); err == nil {
596 return res.RowsAffected()
597 } else {
598 return 0, err
599 }
600 return 0, nil
601 }
602
603 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
604 names, values, ok := d.existPk(mi, ind)
605 if ok == false {
606 return 0, ErrMissPK
607 }
608
609 columns := strings.Join(names, "` = ? AND `")
610
611 query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
612
613 if res, err := q.Exec(query, values...); err == nil {
614
615 num, err := res.RowsAffected()
616 if err != nil {
617 return 0, err
618 }
619
620 if num > 0 {
621 if mi.fields.auto != nil {
622 ind.Field(mi.fields.auto.fieldIndex).SetInt(0)
623 }
624
625 if len(names) == 1 {
626 err := d.deleteRels(q, mi, values)
627 if err != nil {
628 return num, err
629 }
630 }
631 }
632
633 return num, err
634 } else {
635 return 0, err
636 }
637 return 0, nil
638 }
639
640 func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) {
641 columns := make([]string, 0, len(params))
642 values := make([]interface{}, 0, len(params))
643 for col, val := range params {
644 column := snakeString(col)
645 if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false {
646 panic(fmt.Sprintf("wrong field/column name `%s`", column))
647 }
648 columns = append(columns, column)
649 values = append(values, val)
650 }
651
652 if len(columns) == 0 {
653 panic("update params cannot empty")
654 }
655
656 tables := newDbTables(mi, d.ins)
657 if qs != nil {
658 tables.parseRelated(qs.related, qs.relDepth)
659 }
660
661 where, args := tables.getCondSql(cond, false)
662
663 join := tables.getJoinSql()
664
665 query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where)
666
667 values = append(values, args...)
668
669 if res, err := q.Exec(query, values...); err == nil {
670 return res.RowsAffected()
671 } else {
672 return 0, err
673 }
674 return 0, nil
675 }
676
677 func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error {
678 for _, fi := range mi.fields.fieldsReverse {
679 fi = fi.reverseFieldInfo
680 switch fi.onDelete {
681 case od_CASCADE:
682 cond := NewCondition()
683 cond.And(fmt.Sprintf("%s__in", fi.name), args...)
684 _, err := d.DeleteBatch(q, nil, fi.mi, cond)
685 if err != nil {
686 return err
687 }
688 case od_SET_DEFAULT, od_SET_NULL:
689 cond := NewCondition()
690 cond.And(fmt.Sprintf("%s__in", fi.name), args...)
691 params := Params{fi.column: nil}
692 if fi.onDelete == od_SET_DEFAULT {
693 params[fi.column] = fi.initial.String()
694 }
695 _, err := d.UpdateBatch(q, nil, fi.mi, cond, params)
696 if err != nil {
697 return err
698 }
699 case od_DO_NOTHING:
700 }
701 }
702 return nil
703 }
704
705 func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) {
706 tables := newDbTables(mi, d.ins)
707 if qs != nil {
708 tables.parseRelated(qs.related, qs.relDepth)
709 }
710
711 if cond == nil || cond.IsEmpty() {
712 panic("delete operation cannot execute without condition")
713 }
714
715 where, args := tables.getCondSql(cond, false)
716 join := tables.getJoinSql()
717
718 colsNum := len(mi.fields.pk)
719 cols := make([]string, colsNum)
720 for i, fi := range mi.fields.pk {
721 cols[i] = fi.column
722 }
723 colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
724 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
725
726 var rs *sql.Rows
727 if r, err := q.Query(query, args...); err != nil {
728 return 0, err
729 } else {
730 rs = r
731 }
732
733 refs := make([]interface{}, colsNum)
734 for i, _ := range refs {
735 var ref string
736 refs[i] = &ref
737 }
738
739 args = make([]interface{}, 0)
740 cnt := 0
741 for rs.Next() {
742 if err := rs.Scan(refs...); err != nil {
743 return 0, err
744 }
745 for _, ref := range refs {
746 args = append(args, reflect.ValueOf(ref).Elem().Interface())
747 }
748 cnt++
749 }
750
751 if cnt == 0 {
752 return 0, nil
753 }
754
755 if colsNum > 1 {
756 columns := strings.Join(cols, "` = ? AND `")
757 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
758 } else {
759 var sql string
760 sql, args = d.ins.GetOperatorSql(mi, "in", args)
761 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
762 }
763
764 if res, err := q.Exec(query, args...); err == nil {
765 num, err := res.RowsAffected()
766 if err != nil {
767 return 0, err
768 }
769
770 if colsNum == 1 && num > 0 {
771 err := d.deleteRels(q, mi, args)
772 if err != nil {
773 return num, err
774 }
775 }
776
777 return num, nil
778 } else {
779 return 0, err
780 }
781
782 return 0, nil
783 }
784
785 func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) {
786
787 val := reflect.ValueOf(container)
788 ind := reflect.Indirect(val)
789 typ := ind.Type()
790
791 errTyp := true
792
793 one := true
794
795 if val.Kind() == reflect.Ptr {
796 tp := typ
797 if ind.Kind() == reflect.Slice {
798 one = false
799 if ind.Type().Elem().Kind() == reflect.Ptr {
800 tp = ind.Type().Elem().Elem()
801 }
802 }
803 errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName
804 }
805
806 if errTyp {
807 panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName))
808 }
809
810 rlimit := qs.limit
811 offset := qs.offset
812 if one {
813 rlimit = 0
814 offset = 0
815 }
816
817 tables := newDbTables(mi, d.ins)
818 tables.parseRelated(qs.related, qs.relDepth)
819
820 where, args := tables.getCondSql(cond, false)
821 orderBy := tables.getOrderSql(qs.orders)
822 limit := tables.getLimitSql(offset, rlimit)
823 join := tables.getJoinSql()
824
825 colsNum := len(mi.fields.dbcols)
826 cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`"))
827 for _, tbl := range tables.tables {
828 if tbl.sel {
829 colsNum += len(tbl.mi.fields.dbcols)
830 cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`"))
831 }
832 }
833
834 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit)
835
836 var rs *sql.Rows
837 if r, err := q.Query(query, args...); err != nil {
838 return 0, err
839 } else {
840 rs = r
841 }
842
843 refs := make([]interface{}, colsNum)
844 for i, _ := range refs {
845 var ref string
846 refs[i] = &ref
847 }
848
849 slice := ind
850
851 var cnt int64
852 for rs.Next() {
853 if one && cnt == 0 || one == false {
854 if err := rs.Scan(refs...); err != nil {
855 return 0, err
856 }
857
858 elm := reflect.New(mi.addrField.Elem().Type())
859 md := elm.Interface().(Modeler)
860 md.Init(md)
861 mind := reflect.Indirect(elm)
862
863 cacheV := make(map[string]*reflect.Value)
864 cacheM := make(map[string]*modelInfo)
865 trefs := refs
866
867 d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)])
868 trefs = refs[len(mi.fields.dbcols):]
869
870 for _, tbl := range tables.tables {
871 if tbl.sel {
872 last := mind
873 names := ""
874 mmi := mi
875 for _, name := range tbl.names {
876 names += name
877 if val, ok := cacheV[names]; ok {
878 last = *val
879 mmi = cacheM[names]
880 } else {
881 fi := mmi.fields.GetByName(name)
882 lastm := mmi
883 mmi := fi.relModelInfo
884 field := reflect.Indirect(last.Field(fi.fieldIndex))
885 d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)])
886 for _, fi := range mmi.fields.fieldsReverse {
887 if fi.reverseFieldInfo.mi == lastm {
888 if fi.reverseFieldInfo != nil {
889 field.Field(fi.fieldIndex).Set(last.Addr())
890 }
891 }
892 }
893 trefs = trefs[len(mmi.fields.dbcols):]
894 cacheV[names] = &field
895 cacheM[names] = mmi
896 last = field
897 }
898 }
899 }
900 }
901
902 if one {
903 ind.Set(mind)
904 } else {
905 slice = reflect.Append(slice, mind.Addr())
906 }
907 }
908 cnt++
909 }
910
911 if one == false {
912 ind.Set(slice)
913 }
914
915 return cnt, nil
916 }
917
918 func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) {
919 tables := newDbTables(mi, d.ins)
920 tables.parseRelated(qs.related, qs.relDepth)
921
922 where, args := tables.getCondSql(cond, false)
923 tables.getOrderSql(qs.orders)
924 join := tables.getJoinSql()
925
926 query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where)
927
928 row := q.QueryRow(query, args...)
929
930 err = row.Scan(&cnt)
931 return
932 }
933
934 func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
935 params := make([]interface{}, len(args))
936 copy(params, args)
937 sql := ""
938 for i, arg := range args {
939 if len(mi.fields.pk) == 1 {
940 if md, ok := arg.(Modeler); ok {
941 ind := reflect.Indirect(reflect.ValueOf(md))
942 if _, values, exist := d.existPk(mi, ind); exist {
943 arg = values[0]
944 } else {
945 panic(fmt.Sprintf("`%s` need a valid args value", operator))
946 }
947 }
948 }
949 params[i] = arg
950 }
951 if operator == "in" {
952 marks := make([]string, len(params))
953 for i, _ := range marks {
954 marks[i] = "?"
955 }
956 sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
957 } else {
958 if len(params) > 1 {
959 panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
960 }
961 sql = operatorsSQL[operator]
962 arg := params[0]
963 switch operator {
964 case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith":
965 param := strings.Replace(ToStr(arg), `%`, `\%`, -1)
966 switch operator {
967 case "iexact", "contains", "icontains":
968 param = fmt.Sprintf("%%%s%%", param)
969 case "startswith", "istartswith":
970 param = fmt.Sprintf("%s%%", param)
971 case "endswith", "iendswith":
972 param = fmt.Sprintf("%%%s", param)
973 }
974 params[0] = param
975 case "isnull":
976 if b, ok := arg.(bool); ok {
977 if b {
978 sql = "IS NULL"
979 } else {
980 sql = "IS NOT NULL"
981 }
982 params = nil
983 } else {
984 panic(fmt.Sprintf("operator `%s` need a bool value not `%T`", operator, arg))
985 }
986 }
987 }
988 return sql, params
989 }
990
991 func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
992 for i, column := range cols {
993 val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
994
995 fi := mi.fields.GetByColumn(column)
996
997 field := ind.Field(fi.fieldIndex)
998
999 value, err := d.getValue(fi, val)
1000 if err != nil {
1001 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
1002 }
1003
1004 _, err = d.setValue(fi, value, &field)
1005
1006 if err != nil {
1007 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
1008 }
1009 }
1010 }
1011
1012 func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
1013 if val == nil {
1014 return nil, nil
1015 }
1016
1017 var value interface{}
1018
1019 var str *StrTo
1020 switch v := val.(type) {
1021 case []byte:
1022 s := StrTo(string(v))
1023 str = &s
1024 case string:
1025 s := StrTo(v)
1026 str = &s
1027 }
1028
1029 fieldType := fi.fieldType
1030
1031 setValue:
1032 switch {
1033 case fieldType == TypeBooleanField:
1034 if str == nil {
1035 switch v := val.(type) {
1036 case int64:
1037 b := v == 1
1038 value = b
1039 default:
1040 s := StrTo(ToStr(v))
1041 str = &s
1042 }
1043 }
1044 if str != nil {
1045 b, err := str.Bool()
1046 if err != nil {
1047 return nil, err
1048 }
1049 value = b
1050 }
1051 case fieldType == TypeCharField || fieldType == TypeTextField:
1052 s := str.String()
1053 if str == nil {
1054 s = ToStr(val)
1055 }
1056 value = s
1057 case fieldType == TypeDateField || fieldType == TypeDateTimeField:
1058 if str == nil {
1059 switch v := val.(type) {
1060 case time.Time:
1061 value = v
1062 default:
1063 s := StrTo(ToStr(v))
1064 str = &s
1065 }
1066 }
1067 if str != nil {
1068 format := format_DateTime
1069 if fi.fieldType == TypeDateField {
1070 format = format_Date
1071 }
1072 s := str.String()
1073 t, err := timeParse(s, format)
1074 if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
1075 return nil, err
1076 }
1077 value = t
1078 }
1079 case fieldType&IsIntegerField > 0:
1080 if str == nil {
1081 s := StrTo(ToStr(val))
1082 str = &s
1083 }
1084 if str != nil {
1085 var err error
1086 switch fieldType {
1087 case TypeSmallIntegerField:
1088 _, err = str.Int16()
1089 case TypeIntegerField:
1090 _, err = str.Int32()
1091 case TypeBigIntegerField:
1092 _, err = str.Int64()
1093 case TypePositiveSmallIntegerField:
1094 _, err = str.Uint16()
1095 case TypePositiveIntegerField:
1096 _, err = str.Uint32()
1097 case TypePositiveBigIntegerField:
1098 _, err = str.Uint64()
1099 }
1100 if err != nil {
1101 return nil, err
1102 }
1103 if fieldType&IsPostiveIntegerField > 0 {
1104 v, _ := str.Uint64()
1105 value = v
1106 } else {
1107 v, _ := str.Int64()
1108 value = v
1109 }
1110 }
1111 case fieldType == TypeFloatField || fieldType == TypeDecimalField:
1112 if str == nil {
1113 switch v := val.(type) {
1114 case float64:
1115 value = v
1116 default:
1117 s := StrTo(ToStr(v))
1118 str = &s
1119 }
1120 }
1121 if str != nil {
1122 v, err := str.Float64()
1123 if err != nil {
1124 return nil, err
1125 }
1126 value = v
1127 }
1128 case fieldType&IsRelField > 0:
1129 fieldType = fi.relModelInfo.fields.pk[0].fieldType
1130 goto setValue
1131 }
1132
1133 return value, nil
1134
1135 }
1136
1137 func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {
1138
1139 fieldType := fi.fieldType
1140 isNative := fi.isFielder == false
1141
1142 setValue:
1143 switch {
1144 case fieldType == TypeBooleanField:
1145 if isNative {
1146 field.SetBool(value.(bool))
1147 }
1148 case fieldType == TypeCharField || fieldType == TypeTextField:
1149 if isNative {
1150 field.SetString(value.(string))
1151 }
1152 case fieldType == TypeDateField || fieldType == TypeDateTimeField:
1153 if isNative {
1154 field.Set(reflect.ValueOf(value))
1155 }
1156 case fieldType&IsIntegerField > 0:
1157 if fieldType&IsPostiveIntegerField > 0 {
1158 if isNative {
1159 field.SetUint(value.(uint64))
1160 }
1161 } else {
1162 if isNative {
1163 field.SetInt(value.(int64))
1164 }
1165 }
1166 case fieldType == TypeFloatField || fieldType == TypeDecimalField:
1167 if isNative {
1168 field.SetFloat(value.(float64))
1169 }
1170 case fieldType&IsRelField > 0:
1171 fieldType = fi.relModelInfo.fields.pk[0].fieldType
1172 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
1173 md := mf.Interface().(Modeler)
1174 md.Init(md)
1175 field.Set(mf)
1176 f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
1177 field = &f
1178 goto setValue
1179 }
1180
1181 if isNative == false {
1182 fd := field.Addr().Interface().(Fielder)
1183 err := fd.SetRaw(value)
1184 if err != nil {
1185 return nil, err
1186 }
1187 }
1188
1189 return value, nil
1190 }
1191
1192 func (d *dbBase) xsetValue(fi *fieldInfo, val interface{}, field *reflect.Value) (interface{}, error) {
1193 if val == nil {
1194 return nil, nil
1195 }
1196
1197 var value interface{}
1198
1199 var str *StrTo
1200 switch v := val.(type) {
1201 case []byte:
1202 s := StrTo(string(v))
1203 str = &s
1204 case string:
1205 s := StrTo(v)
1206 str = &s
1207 }
1208
1209 fieldType := fi.fieldType
1210 isNative := fi.isFielder == false
1211
1212 setValue:
1213 switch {
1214 case fieldType == TypeBooleanField:
1215 if str == nil {
1216 switch v := val.(type) {
1217 case int64:
1218 b := v == 1
1219 if isNative {
1220 field.SetBool(b)
1221 }
1222 value = b
1223 default:
1224 s := StrTo(ToStr(v))
1225 str = &s
1226 }
1227 }
1228 if str != nil {
1229 b, err := str.Bool()
1230 if err != nil {
1231 return nil, err
1232 }
1233 if isNative {
1234 field.SetBool(b)
1235 }
1236 value = b
1237 }
1238 case fieldType == TypeCharField || fieldType == TypeTextField:
1239 s := str.String()
1240 if str == nil {
1241 s = ToStr(val)
1242 }
1243 if isNative {
1244 field.SetString(s)
1245 }
1246 value = s
1247 case fieldType == TypeDateField || fieldType == TypeDateTimeField:
1248 if str == nil {
1249 switch v := val.(type) {
1250 case time.Time:
1251 if isNative {
1252 field.Set(reflect.ValueOf(v))
1253 }
1254 value = v
1255 default:
1256 s := StrTo(ToStr(v))
1257 str = &s
1258 }
1259 }
1260 if str != nil {
1261 format := format_DateTime
1262 if fi.fieldType == TypeDateField {
1263 format = format_Date
1264 }
1265
1266 t, err := timeParse(str.String(), format)
1267 if err != nil {
1268 return nil, err
1269 }
1270 if isNative {
1271 field.Set(reflect.ValueOf(t))
1272 }
1273 value = t
1274 }
1275 case fieldType&IsIntegerField > 0:
1276 if str == nil {
1277 s := StrTo(ToStr(val))
1278 str = &s
1279 }
1280 if str != nil {
1281 var err error
1282 switch fieldType {
1283 case TypeSmallIntegerField:
1284 value, err = str.Int16()
1285 case TypeIntegerField:
1286 value, err = str.Int32()
1287 case TypeBigIntegerField:
1288 value, err = str.Int64()
1289 case TypePositiveSmallIntegerField:
1290 value, err = str.Uint16()
1291 case TypePositiveIntegerField:
1292 value, err = str.Uint32()
1293 case TypePositiveBigIntegerField:
1294 value, err = str.Uint64()
1295 }
1296 if err != nil {
1297 return nil, err
1298 }
1299 if fieldType&IsPostiveIntegerField > 0 {
1300 v, _ := str.Uint64()
1301 if isNative {
1302 field.SetUint(v)
1303 }
1304 } else {
1305 v, _ := str.Int64()
1306 if isNative {
1307 field.SetInt(v)
1308 }
1309 }
1310 }
1311 case fieldType == TypeFloatField || fieldType == TypeDecimalField:
1312 if str == nil {
1313 switch v := val.(type) {
1314 case float64:
1315 if isNative {
1316 field.SetFloat(v)
1317 }
1318 value = v
1319 default:
1320 s := StrTo(ToStr(v))
1321 str = &s
1322 }
1323 }
1324 if str != nil {
1325 v, err := str.Float64()
1326 if err != nil {
1327 return nil, err
1328 }
1329 if isNative {
1330 field.SetFloat(v)
1331 }
1332 value = v
1333 }
1334 case fieldType&IsRelField > 0:
1335 fieldType = fi.relModelInfo.fields.pk[0].fieldType
1336 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
1337 md := mf.Interface().(Modeler)
1338 md.Init(md)
1339 field.Set(mf)
1340 f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
1341 field = &f
1342 goto setValue
1343 }
1344
1345 if isNative == false {
1346 fd := field.Addr().Interface().(Fielder)
1347 err := fd.SetRaw(value)
1348 if err != nil {
1349 return nil, err
1350 }
1351 }
1352
1353 return value, nil
1354 }
1355
1356 func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) {
1357
1358 var (
1359 maps []Params
1360 lists []ParamsList
1361 list ParamsList
1362 )
1363
1364 typ := 0
1365 switch container.(type) {
1366 case *[]Params:
1367 typ = 1
1368 case *[]ParamsList:
1369 typ = 2
1370 case *ParamsList:
1371 typ = 3
1372 default:
1373 panic(fmt.Sprintf("unsupport read values type `%T`", container))
1374 }
1375
1376 tables := newDbTables(mi, d.ins)
1377
1378 var (
1379 cols []string
1380 infos []*fieldInfo
1381 )
1382
1383 hasExprs := len(exprs) > 0
1384
1385 if hasExprs {
1386 cols = make([]string, 0, len(exprs))
1387 infos = make([]*fieldInfo, 0, len(exprs))
1388 for _, ex := range exprs {
1389 index, col, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
1390 if suc == false {
1391 panic(fmt.Errorf("unknown field/column name `%s`", ex))
1392 }
1393 cols = append(cols, fmt.Sprintf("%s.`%s`", index, col))
1394 infos = append(infos, fi)
1395 }
1396 } else {
1397 cols = make([]string, 0, len(mi.fields.dbcols))
1398 infos = make([]*fieldInfo, 0, len(exprs))
1399 for _, fi := range mi.fields.fieldsDB {
1400 cols = append(cols, fmt.Sprintf("T0.`%s`", fi.column))
1401 infos = append(infos, fi)
1402 }
1403 }
1404
1405 where, args := tables.getCondSql(cond, false)
1406 orderBy := tables.getOrderSql(qs.orders)
1407 limit := tables.getLimitSql(qs.offset, qs.limit)
1408 join := tables.getJoinSql()
1409
1410 sels := strings.Join(cols, ", ")
1411
1412 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit)
1413
1414 var rs *sql.Rows
1415 if r, err := q.Query(query, args...); err != nil {
1416 return 0, err
1417 } else {
1418 rs = r
1419 }
1420
1421 refs := make([]interface{}, len(cols))
1422 for i, _ := range refs {
1423 var ref string
1424 refs[i] = &ref
1425 }
1426
1427 var cnt int64
1428 for rs.Next() {
1429 if err := rs.Scan(refs...); err != nil {
1430 return 0, err
1431 }
1432
1433 switch typ {
1434 case 1:
1435 params := make(Params, len(cols))
1436 for i, ref := range refs {
1437 fi := infos[i]
1438
1439 val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
1440
1441 value, err := d.getValue(fi, val)
1442 if err != nil {
1443 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
1444 }
1445
1446 if hasExprs {
1447 params[exprs[i]] = value
1448 } else {
1449 params[mi.fields.dbcols[i]] = value
1450 }
1451 }
1452 maps = append(maps, params)
1453 case 2:
1454 params := make(ParamsList, 0, len(cols))
1455 for i, ref := range refs {
1456 fi := infos[i]
1457
1458 val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
1459
1460 value, err := d.getValue(fi, val)
1461 if err != nil {
1462 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
1463 }
1464
1465 params = append(params, value)
1466 }
1467 lists = append(lists, params)
1468 case 3:
1469 for i, ref := range refs {
1470 fi := infos[i]
1471
1472 val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
1473
1474 value, err := d.getValue(fi, val)
1475 if err != nil {
1476 panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
1477 }
1478
1479 list = append(list, value)
1480 }
1481 }
1482
1483 cnt++
1484 }
1485
1486 switch v := container.(type) {
1487 case *[]Params:
1488 *v = maps
1489 case *[]ParamsList:
1490 *v = lists
1491 case *ParamsList:
1492 *v = list
1493 }
1494
1495 return cnt, nil
1496 }
1 package orm
2
3 import (
4 "database/sql"
5 "fmt"
6 "os"
7 "sync"
8 )
9
10 const defaultMaxIdle = 30
11
12 type driverType int
13
14 const (
15 _ driverType = iota
16 DR_MySQL
17 DR_Sqlite
18 DR_Oracle
19 DR_Postgres
20 )
21
22 var (
23 dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
24 drivers = make(map[string]driverType)
25 dbBasers = map[driverType]dbBaser{
26 DR_MySQL: newdbBaseMysql(),
27 DR_Sqlite: newdbBaseSqlite(),
28 DR_Oracle: newdbBaseMysql(),
29 DR_Postgres: newdbBasePostgres(),
30 }
31 )
32
33 type _dbCache struct {
34 mux sync.RWMutex
35 cache map[string]*alias
36 }
37
38 func (ac *_dbCache) add(name string, al *alias) (added bool) {
39 ac.mux.Lock()
40 defer ac.mux.Unlock()
41 if _, ok := ac.cache[name]; ok == false {
42 ac.cache[name] = al
43 added = true
44 }
45 return
46 }
47
48 func (ac *_dbCache) get(name string) (al *alias, ok bool) {
49 ac.mux.RLock()
50 defer ac.mux.RUnlock()
51 al, ok = ac.cache[name]
52 return
53 }
54
55 func (ac *_dbCache) getDefault() (al *alias) {
56 al, _ = ac.get("default")
57 return
58 }
59
60 type alias struct {
61 Name string
62 DriverName string
63 DataSource string
64 MaxIdle int
65 DB *sql.DB
66 DbBaser dbBaser
67 }
68
69 func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
70 if maxIdle <= 0 {
71 maxIdle = defaultMaxIdle
72 }
73
74 al := new(alias)
75 al.Name = name
76 al.DriverName = driverName
77 al.DataSource = dataSource
78 al.MaxIdle = maxIdle
79
80 var (
81 err error
82 )
83
84 if dr, ok := drivers[driverName]; ok {
85 al.DbBaser = dbBasers[dr]
86 } else {
87 err = fmt.Errorf("driver name `%s` have not registered", driverName)
88 goto end
89 }
90
91 if dataBaseCache.add(name, al) == false {
92 err = fmt.Errorf("db name `%s` already registered, cannot reuse", name)
93 goto end
94 }
95
96 al.DB, err = sql.Open(driverName, dataSource)
97 if err != nil {
98 err = fmt.Errorf("register db `%s`, %s", name, err.Error())
99 goto end
100 }
101
102 err = al.DB.Ping()
103 if err != nil {
104 err = fmt.Errorf("register db `%s`, %s", name, err.Error())
105 goto end
106 }
107
108 end:
109 if err != nil {
110 fmt.Println(err.Error())
111 os.Exit(2)
112 }
113 }
114
115 func RegisterDriver(name string, typ driverType) {
116 if _, ok := drivers[name]; ok == false {
117 drivers[name] = typ
118 } else {
119 fmt.Println("name `%s` db driver already registered")
120 os.Exit(2)
121 }
122 }
123
124 func init() {
125 // RegisterDriver("mysql", DR_MySQL)
126 RegisterDriver("mymysql", DR_MySQL)
127 }
1 package orm
2
3 type dbBaseMysql struct {
4 dbBase
5 }
6
7 func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) {
8 return d.dbBase.GetOperatorSql(mi, operator, args)
9 }
10
11 func newdbBaseMysql() dbBaser {
12 b := new(dbBaseMysql)
13 b.ins = b
14 return b
15 }
1 package orm
2
3 type dbBaseOracle struct {
4 dbBase
5 }
6
7 func newdbBaseOracle() dbBaser {
8 b := new(dbBaseOracle)
9 b.ins = b
10 return b
11 }
1 package orm
2
3 type dbBasePostgres struct {
4 dbBase
5 }
6
7 func newdbBasePostgres() dbBaser {
8 b := new(dbBasePostgres)
9 b.ins = b
10 return b
11 }
1 package orm
2
3 type dbBaseSqlite struct {
4 dbBase
5 }
6
7 func newdbBaseSqlite() dbBaser {
8 b := new(dbBaseSqlite)
9 b.ins = b
10 return b
11 }
File mode changed
1 package orm
2
3 import (
4 "log"
5 "os"
6 "sync"
7 )
8
9 const (
10 od_CASCADE = "cascade"
11 od_SET_NULL = "set_null"
12 od_SET_DEFAULT = "set_default"
13 od_DO_NOTHING = "do_nothing"
14 defaultStructTagName = "orm"
15 )
16
17 var (
18 errLog *log.Logger
19 modelCache = &_modelCache{cache: make(map[string]*modelInfo)}
20 supportTag = map[string]int{
21 "null": 1,
22 "blank": 1,
23 "index": 1,
24 "unique": 1,
25 "pk": 1,
26 "auto": 1,
27 "auto_now": 1,
28 "auto_now_add": 1,
29 "max_length": 2,
30 "choices": 2,
31 "column": 2,
32 "default": 2,
33 "rel": 2,
34 "reverse": 2,
35 "rel_table": 2,
36 "rel_through": 2,
37 "digits": 2,
38 "decimals": 2,
39 "on_delete": 2,
40 }
41 )
42
43 func init() {
44 errLog = log.New(os.Stderr, "[ORM] ", log.Ldate|log.Ltime|log.Lshortfile)
45 }
46
47 type _modelCache struct {
48 sync.RWMutex
49 orders []string
50 cache map[string]*modelInfo
51 }
52
53 func (mc *_modelCache) all() map[string]*modelInfo {
54 m := make(map[string]*modelInfo, len(mc.cache))
55 for k, v := range mc.cache {
56 m[k] = v
57 }
58 return m
59 }
60
61 func (mc *_modelCache) allOrdered() []*modelInfo {
62 m := make([]*modelInfo, 0, len(mc.orders))
63 for _, v := range mc.cache {
64 m = append(m, v)
65 }
66 return m
67 }
68
69 func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
70 mi, ok = mc.cache[table]
71 return
72 }
73
74 func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
75 mii := mc.cache[table]
76 mc.cache[table] = mi
77 if mii == nil {
78 mc.orders = append(mc.orders, table)
79 }
80 return mii
81 }
1 package orm
2
3 import (
4 "errors"
5 "fmt"
6 "os"
7 "reflect"
8 "strings"
9 )
10
11 func RegisterModel(model Modeler) {
12 info := newModelInfo(model)
13 model.Init(model)
14 table := model.GetTableName()
15 if _, ok := modelCache.get(table); ok {
16 fmt.Printf("model <%T> redeclared, must be unique\n", model)
17 os.Exit(2)
18 }
19 if info.fields.pk == nil {
20 fmt.Printf("model <%T> need a primary key field\n", model)
21 os.Exit(2)
22 }
23 info.table = table
24 info.pkg = getPkgPath(model)
25 info.model = model
26 info.manual = true
27 modelCache.set(table, info)
28 }
29
30 func BootStrap() {
31 modelCache.Lock()
32 defer modelCache.Unlock()
33
34 var (
35 err error
36 models map[string]*modelInfo
37 )
38
39 if dataBaseCache.getDefault() == nil {
40 err = fmt.Errorf("must have one register alias named `default`")
41 goto end
42 }
43
44 models = modelCache.all()
45 for _, mi := range models {
46 for _, fi := range mi.fields.columns {
47 if fi.rel || fi.reverse {
48 elm := fi.addrValue.Type().Elem()
49 switch fi.fieldType {
50 case RelReverseMany, RelManyToMany:
51 elm = elm.Elem()
52 }
53
54 tn := getTableName(reflect.New(elm).Interface().(Modeler))
55 mii, ok := modelCache.get(tn)
56 if ok == false || mii.pkg != elm.PkgPath() {
57 err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
58 goto end
59 }
60 fi.relModelInfo = mii
61
62 if fi.rel {
63
64 if mii.fields.pk.IsMulti() {
65 err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
66 goto end
67 }
68 }
69
70 switch fi.fieldType {
71 case RelManyToMany:
72 if fi.relThrough != "" {
73 msg := fmt.Sprintf("filed `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
74 if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
75 pn := fi.relThrough[:i]
76 mn := fi.relThrough[i+1:]
77 tn := snakeString(mn)
78 rmi, ok := modelCache.get(tn)
79 if ok == false || pn != rmi.pkg {
80 err = errors.New(msg + " cannot find table")
81 goto end
82 }
83
84 fi.relThroughModelInfo = rmi
85 fi.relTable = rmi.table
86
87 } else {
88 err = errors.New(msg)
89 goto end
90 }
91 err = nil
92 } else {
93 i := newM2MModelInfo(mi, mii)
94 if fi.relTable != "" {
95 i.table = fi.relTable
96 }
97
98 if v := modelCache.set(i.table, i); v != nil {
99 err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
100 goto end
101 }
102 fi.relTable = i.table
103 fi.relThroughModelInfo = i
104 }
105 }
106 }
107 }
108 }
109
110 models = modelCache.all()
111 for _, mi := range models {
112 for _, fi := range mi.fields.fieldsRel {
113 switch fi.fieldType {
114 case RelForeignKey, RelOneToOne, RelManyToMany:
115 inModel := false
116 for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
117 if ffi.relModelInfo == mi {
118 inModel = true
119 break
120 }
121 }
122 if inModel == false {
123 rmi := fi.relModelInfo
124 ffi := new(fieldInfo)
125 ffi.name = mi.name
126 ffi.column = ffi.name
127 ffi.fullName = rmi.fullName + "." + ffi.name
128 ffi.reverse = true
129 ffi.relModelInfo = mi
130 ffi.mi = rmi
131 if fi.fieldType == RelOneToOne {
132 ffi.fieldType = RelReverseOne
133 } else {
134 ffi.fieldType = RelReverseMany
135 }
136 if rmi.fields.Add(ffi) == false {
137 added := false
138 for cnt := 0; cnt < 5; cnt++ {
139 ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
140 ffi.column = ffi.name
141 ffi.fullName = rmi.fullName + "." + ffi.name
142 if added = rmi.fields.Add(ffi); added {
143 break
144 }
145 }
146 if added == false {
147 panic(fmt.Sprintf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
148 }
149 }
150 }
151 }
152 }
153 }
154
155 for _, mi := range models {
156 if fields, ok := mi.fields.fieldsByType[RelReverseOne]; ok {
157 for _, fi := range fields {
158 found := false
159 mForA:
160 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
161 if ffi.relModelInfo == mi {
162 found = true
163 fi.reverseField = ffi.name
164 fi.reverseFieldInfo = ffi
165 break mForA
166 }
167 }
168 if found == false {
169 err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
170 goto end
171 }
172 }
173 }
174 if fields, ok := mi.fields.fieldsByType[RelReverseMany]; ok {
175 for _, fi := range fields {
176 found := false
177 mForB:
178 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
179 if ffi.relModelInfo == mi {
180 found = true
181 fi.reverseField = ffi.name
182 fi.reverseFieldInfo = ffi
183 break mForB
184 }
185 }
186 if found == false {
187 mForC:
188 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
189 if ffi.relModelInfo == mi {
190 found = true
191 fi.reverseField = ffi.name
192 fi.reverseFieldInfo = ffi
193 break mForC
194 }
195 }
196 }
197 if found == false {
198 err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
199 goto end
200 }
201 }
202 }
203 }
204
205 end:
206 if err != nil {
207 fmt.Println(err)
208 os.Exit(2)
209 }
210
211 runCommand()
212 }
1 package orm
2
3 import (
4 "errors"
5 "fmt"
6 "strconv"
7 "time"
8 )
9
10 const (
11 // bool
12 TypeBooleanField = 1 << iota
13
14 // string
15 TypeCharField
16
17 // string
18 TypeTextField
19
20 // time.Time
21 TypeDateField
22 // time.Time
23 TypeDateTimeField
24
25 // int16
26 TypeSmallIntegerField
27 // int32
28 TypeIntegerField
29 // int64
30 TypeBigIntegerField
31 // uint16
32 TypePositiveSmallIntegerField
33 // uint32
34 TypePositiveIntegerField
35 // uint64
36 TypePositiveBigIntegerField
37
38 // float64
39 TypeFloatField
40 // float64
41 TypeDecimalField
42
43 RelForeignKey
44 RelOneToOne
45 RelManyToMany
46 RelReverseOne
47 RelReverseMany
48 )
49
50 const (
51 IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
52 IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 7 << 8
53 IsRelField = ^-RelReverseMany >> 12 << 13
54 IsFieldType = ^-RelReverseMany<<1 + 1
55 )
56
57 // A true/false field.
58 type BooleanField bool
59
60 func (e BooleanField) Value() bool {
61 return bool(e)
62 }
63
64 func (e *BooleanField) Set(d bool) {
65 *e = BooleanField(d)
66 }
67
68 func (e *BooleanField) String() string {
69 return strconv.FormatBool(e.Value())
70 }
71
72 func (e *BooleanField) FieldType() int {
73 return TypeBooleanField
74 }
75
76 func (e *BooleanField) SetRaw(value interface{}) error {
77 switch d := value.(type) {
78 case bool:
79 e.Set(d)
80 case string:
81 v, err := StrTo(d).Bool()
82 if err != nil {
83 e.Set(v)
84 }
85 return err
86 default:
87 return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value))
88 }
89 return nil
90 }
91
92 func (e *BooleanField) RawValue() interface{} {
93 return e.Value()
94 }
95
96 // A string field
97 // required values tag: max_length
98 // The max_length is enforced at the database level and in models’s validation.
99 // eg: `max_length:"120"`
100 type CharField string
101
102 func (e CharField) Value() string {
103 return string(e)
104 }
105
106 func (e *CharField) Set(d string) {
107 *e = CharField(d)
108 }
109
110 func (e *CharField) String() string {
111 return e.Value()
112 }
113
114 func (e *CharField) FieldType() int {
115 return TypeCharField
116 }
117
118 func (e *CharField) SetRaw(value interface{}) error {
119 switch d := value.(type) {
120 case string:
121 e.Set(d)
122 default:
123 return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value))
124 }
125 return nil
126 }
127
128 func (e *CharField) RawValue() interface{} {
129 return e.Value()
130 }
131
132 // A date, represented in go by a time.Time instance.
133 // only date values like 2006-01-02
134 // Has a few extra, optional attr tag:
135 //
136 // auto_now:
137 // Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
138 // Note that the current date is always used; it’s not just a default value that you can override.
139 //
140 // auto_now_add:
141 // Automatically set the field to now when the object is first created. Useful for creation of timestamps.
142 // Note that the current date is always used; it’s not just a default value that you can override.
143 //
144 // eg: `attr:"auto_now"` or `attr:"auto_now_add"`
145 type DateField time.Time
146
147 func (e DateField) Value() time.Time {
148 return time.Time(e)
149 }
150
151 func (e *DateField) Set(d time.Time) {
152 *e = DateField(d)
153 }
154
155 func (e *DateField) String() string {
156 return e.Value().String()
157 }
158
159 func (e *DateField) FieldType() int {
160 return TypeDateField
161 }
162
163 func (e *DateField) SetRaw(value interface{}) error {
164 switch d := value.(type) {
165 case time.Time:
166 e.Set(d)
167 case string:
168 v, err := timeParse(d, format_Date)
169 if err != nil {
170 e.Set(v)
171 }
172 return err
173 default:
174 return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value))
175 }
176 return nil
177 }
178
179 func (e *DateField) RawValue() interface{} {
180 return e.Value()
181 }
182
183 // A date, represented in go by a time.Time instance.
184 // datetime values like 2006-01-02 15:04:05
185 // Takes the same extra arguments as DateField.
186 type DateTimeField time.Time
187
188 func (e DateTimeField) Value() time.Time {
189 return time.Time(e)
190 }
191
192 func (e *DateTimeField) Set(d time.Time) {
193 *e = DateTimeField(d)
194 }
195
196 func (e *DateTimeField) String() string {
197 return e.Value().String()
198 }
199
200 func (e *DateTimeField) FieldType() int {
201 return TypeDateTimeField
202 }
203
204 func (e *DateTimeField) SetRaw(value interface{}) error {
205 switch d := value.(type) {
206 case time.Time:
207 e.Set(d)
208 case string:
209 v, err := timeParse(d, format_DateTime)
210 if err != nil {
211 e.Set(v)
212 }
213 return err
214 default:
215 return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value))
216 }
217 return nil
218 }
219
220 func (e *DateTimeField) RawValue() interface{} {
221 return e.Value()
222 }
223
224 // A floating-point number represented in go by a float32 value.
225 type FloatField float64
226
227 func (e FloatField) Value() float64 {
228 return float64(e)
229 }
230
231 func (e *FloatField) Set(d float64) {
232 *e = FloatField(d)
233 }
234
235 func (e *FloatField) String() string {
236 return ToStr(e.Value(), -1, 32)
237 }
238
239 func (e *FloatField) FieldType() int {
240 return TypeFloatField
241 }
242
243 func (e *FloatField) SetRaw(value interface{}) error {
244 switch d := value.(type) {
245 case float32:
246 e.Set(float64(d))
247 case float64:
248 e.Set(d)
249 case string:
250 v, err := StrTo(d).Float64()
251 if err != nil {
252 e.Set(v)
253 }
254 default:
255 return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value))
256 }
257 return nil
258 }
259
260 func (e *FloatField) RawValue() interface{} {
261 return e.Value()
262 }
263
264 // -32768 to 32767
265 type SmallIntegerField int16
266
267 func (e SmallIntegerField) Value() int16 {
268 return int16(e)
269 }
270
271 func (e *SmallIntegerField) Set(d int16) {
272 *e = SmallIntegerField(d)
273 }
274
275 func (e *SmallIntegerField) String() string {
276 return ToStr(e.Value())
277 }
278
279 func (e *SmallIntegerField) FieldType() int {
280 return TypeSmallIntegerField
281 }
282
283 func (e *SmallIntegerField) SetRaw(value interface{}) error {
284 switch d := value.(type) {
285 case int16:
286 e.Set(d)
287 case string:
288 v, err := StrTo(d).Int16()
289 if err != nil {
290 e.Set(v)
291 }
292 default:
293 return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value))
294 }
295 return nil
296 }
297
298 func (e *SmallIntegerField) RawValue() interface{} {
299 return e.Value()
300 }
301
302 // -2147483648 to 2147483647
303 type IntegerField int32
304
305 func (e IntegerField) Value() int32 {
306 return int32(e)
307 }
308
309 func (e *IntegerField) Set(d int32) {
310 *e = IntegerField(d)
311 }
312
313 func (e *IntegerField) String() string {
314 return ToStr(e.Value())
315 }
316
317 func (e *IntegerField) FieldType() int {
318 return TypeIntegerField
319 }
320
321 func (e *IntegerField) SetRaw(value interface{}) error {
322 switch d := value.(type) {
323 case int32:
324 e.Set(d)
325 case string:
326 v, err := StrTo(d).Int32()
327 if err != nil {
328 e.Set(v)
329 }
330 default:
331 return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value))
332 }
333 return nil
334 }
335
336 func (e *IntegerField) RawValue() interface{} {
337 return e.Value()
338 }
339
340 // -9223372036854775808 to 9223372036854775807.
341 type BigIntegerField int64
342
343 func (e BigIntegerField) Value() int64 {
344 return int64(e)
345 }
346
347 func (e *BigIntegerField) Set(d int64) {
348 *e = BigIntegerField(d)
349 }
350
351 func (e *BigIntegerField) String() string {
352 return ToStr(e.Value())
353 }
354
355 func (e *BigIntegerField) FieldType() int {
356 return TypeBigIntegerField
357 }
358
359 func (e *BigIntegerField) SetRaw(value interface{}) error {
360 switch d := value.(type) {
361 case int64:
362 e.Set(d)
363 case string:
364 v, err := StrTo(d).Int64()
365 if err != nil {
366 e.Set(v)
367 }
368 default:
369 return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value))
370 }
371 return nil
372 }
373
374 func (e *BigIntegerField) RawValue() interface{} {
375 return e.Value()
376 }
377
378 // 0 to 65535
379 type PositiveSmallIntegerField uint16
380
381 func (e PositiveSmallIntegerField) Value() uint16 {
382 return uint16(e)
383 }
384
385 func (e *PositiveSmallIntegerField) Set(d uint16) {
386 *e = PositiveSmallIntegerField(d)
387 }
388
389 func (e *PositiveSmallIntegerField) String() string {
390 return ToStr(e.Value())
391 }
392
393 func (e *PositiveSmallIntegerField) FieldType() int {
394 return TypePositiveSmallIntegerField
395 }
396
397 func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
398 switch d := value.(type) {
399 case uint16:
400 e.Set(d)
401 case string:
402 v, err := StrTo(d).Uint16()
403 if err != nil {
404 e.Set(v)
405 }
406 default:
407 return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value))
408 }
409 return nil
410 }
411
412 func (e *PositiveSmallIntegerField) RawValue() interface{} {
413 return e.Value()
414 }
415
416 // 0 to 4294967295
417 type PositiveIntegerField uint32
418
419 func (e PositiveIntegerField) Value() uint32 {
420 return uint32(e)
421 }
422
423 func (e *PositiveIntegerField) Set(d uint32) {
424 *e = PositiveIntegerField(d)
425 }
426
427 func (e *PositiveIntegerField) String() string {
428 return ToStr(e.Value())
429 }
430
431 func (e *PositiveIntegerField) FieldType() int {
432 return TypePositiveIntegerField
433 }
434
435 func (e *PositiveIntegerField) SetRaw(value interface{}) error {
436 switch d := value.(type) {
437 case uint32:
438 e.Set(d)
439 case string:
440 v, err := StrTo(d).Uint32()
441 if err != nil {
442 e.Set(v)
443 }
444 default:
445 return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value))
446 }
447 return nil
448 }
449
450 func (e *PositiveIntegerField) RawValue() interface{} {
451 return e.Value()
452 }
453
454 // 0 to 18446744073709551615
455 type PositiveBigIntegerField uint64
456
457 func (e PositiveBigIntegerField) Value() uint64 {
458 return uint64(e)
459 }
460
461 func (e *PositiveBigIntegerField) Set(d uint64) {
462 *e = PositiveBigIntegerField(d)
463 }
464
465 func (e *PositiveBigIntegerField) String() string {
466 return ToStr(e.Value())
467 }
468
469 func (e *PositiveBigIntegerField) FieldType() int {
470 return TypePositiveIntegerField
471 }
472
473 func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
474 switch d := value.(type) {
475 case uint64:
476 e.Set(d)
477 case string:
478 v, err := StrTo(d).Uint64()
479 if err != nil {
480 e.Set(v)
481 }
482 default:
483 return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value))
484 }
485 return nil
486 }
487
488 func (e *PositiveBigIntegerField) RawValue() interface{} {
489 return e.Value()
490 }
491
492 // A large text field.
493 type TextField string
494
495 func (e TextField) Value() string {
496 return string(e)
497 }
498
499 func (e *TextField) Set(d string) {
500 *e = TextField(d)
501 }
502
503 func (e *TextField) String() string {
504 return e.Value()
505 }
506
507 func (e *TextField) FieldType() int {
508 return TypeTextField
509 }
510
511 func (e *TextField) SetRaw(value interface{}) error {
512 switch d := value.(type) {
513 case string:
514 e.Set(d)
515 default:
516 return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value))
517 }
518 return nil
519 }
520
521 func (e *TextField) RawValue() interface{} {
522 return e.Value()
523 }
1 package orm
2
3 import (
4 "errors"
5 "fmt"
6 "reflect"
7 "strings"
8 )
9
10 type fieldChoices []StrTo
11
12 func (f *fieldChoices) Add(s StrTo) {
13 if f.Have(s) == false {
14 *f = append(*f, s)
15 }
16 }
17
18 func (f *fieldChoices) Clear() {
19 *f = fieldChoices([]StrTo{})
20 }
21
22 func (f *fieldChoices) Have(s StrTo) bool {
23 for _, v := range *f {
24 if v == s {
25 return true
26 }
27 }
28 return false
29 }
30
31 func (f *fieldChoices) Clone() fieldChoices {
32 return *f
33 }
34
35 type primaryKeys []*fieldInfo
36
37 func (p *primaryKeys) Add(fi *fieldInfo) {
38 *p = append(*p, fi)
39 }
40
41 func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
42 for i, v := range p {
43 if v == fi {
44 return i, true
45 }
46 }
47 return -1, false
48 }
49
50 func (p primaryKeys) IsMulti() bool {
51 return len(p) > 1
52 }
53
54 func (p primaryKeys) IsEmpty() bool {
55 return len(p) == 0
56 }
57
58 type fields struct {
59 pk primaryKeys
60 auto *fieldInfo
61 columns map[string]*fieldInfo
62 fields map[string]*fieldInfo
63 fieldsLow map[string]*fieldInfo
64 fieldsByType map[int][]*fieldInfo
65 fieldsRel []*fieldInfo
66 fieldsReverse []*fieldInfo
67 fieldsDB []*fieldInfo
68 rels []*fieldInfo
69 orders []string
70 dbcols []string
71 }
72
73 func (f *fields) Add(fi *fieldInfo) (added bool) {
74 if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
75 f.columns[fi.column] = fi
76 f.fields[fi.name] = fi
77 f.fieldsLow[strings.ToLower(fi.name)] = fi
78 } else {
79 return
80 }
81 if _, ok := f.fieldsByType[fi.fieldType]; ok == false {
82 f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
83 }
84 f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
85 f.orders = append(f.orders, fi.column)
86 if fi.dbcol {
87 f.dbcols = append(f.dbcols, fi.column)
88 f.fieldsDB = append(f.fieldsDB, fi)
89 }
90 if fi.rel {
91 f.fieldsRel = append(f.fieldsRel, fi)
92 }
93 if fi.reverse {
94 f.fieldsReverse = append(f.fieldsReverse, fi)
95 }
96 return true
97 }
98
99 func (f *fields) GetByName(name string) *fieldInfo {
100 return f.fields[name]
101 }
102
103 func (f *fields) GetByColumn(column string) *fieldInfo {
104 return f.columns[column]
105 }
106
107 func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
108 if fi, ok := f.fields[name]; ok {
109 return fi, ok
110 }
111 if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok {
112 return fi, ok
113 }
114 if fi, ok := f.columns[name]; ok {
115 return fi, ok
116 }
117 return nil, false
118 }
119
120 func newFields() *fields {
121 f := new(fields)
122 f.fields = make(map[string]*fieldInfo)
123 f.fieldsLow = make(map[string]*fieldInfo)
124 f.columns = make(map[string]*fieldInfo)
125 f.fieldsByType = make(map[int][]*fieldInfo)
126 return f
127 }
128
129 type fieldInfo struct {
130 mi *modelInfo
131 fieldIndex int
132 fieldType int
133 dbcol bool
134 inModel bool
135 name string
136 fullName string
137 column string
138 addrValue *reflect.Value
139 sf *reflect.StructField
140 auto bool
141 pk bool
142 null bool
143 blank bool
144 index bool
145 unique bool
146 initial StrTo
147 choices fieldChoices
148 maxLength int
149 auto_now bool
150 auto_now_add bool
151 rel bool
152 reverse bool
153 reverseField string
154 reverseFieldInfo *fieldInfo
155 relTable string
156 relThrough string
157 relThroughModelInfo *modelInfo
158 relModelInfo *modelInfo
159 digits int
160 decimals int
161 isFielder bool
162 onDelete string
163 }
164
165 func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
166 var (
167 tag string
168 tagValue string
169 choices fieldChoices
170 values fieldChoices
171 initial StrTo
172 fieldType int
173 attrs map[string]bool
174 tags map[string]string
175 parts []string
176 addrField reflect.Value
177 )
178
179 fi = new(fieldInfo)
180
181 if field.Kind() != reflect.Ptr && field.Kind() != reflect.Slice && field.CanAddr() {
182 addrField = field.Addr()
183 } else {
184 addrField = field
185 }
186
187 parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags)
188
189 digits := tags["digits"]
190 decimals := tags["decimals"]
191 maxLength := tags["max_length"]
192 onDelete := tags["on_delete"]
193
194 checkType:
195 switch f := addrField.Interface().(type) {
196 case Fielder:
197 fi.isFielder = true
198 if field.Kind() == reflect.Ptr {
199 err = fmt.Errorf("the model Fielder can not be use ptr")
200 goto end
201 }
202 fieldType = f.FieldType()
203 if fieldType&IsRelField > 0 {
204 err = fmt.Errorf("unsupport rel type custom field")
205 goto end
206 }
207 default:
208 tag = "rel"
209 tagValue = tags[tag]
210 if tagValue != "" {
211 switch tagValue {
212 case "fk":
213 fieldType = RelForeignKey
214 break checkType
215 case "one":
216 fieldType = RelOneToOne
217 break checkType
218 case "m2m":
219 fieldType = RelManyToMany
220 if tv := tags["rel_table"]; tv != "" {
221 fi.relTable = tv
222 } else if tv := tags["rel_through"]; tv != "" {
223 fi.relThrough = tv
224 }
225 break checkType
226 default:
227 err = fmt.Errorf("error")
228 goto wrongTag
229 }
230 }
231 tag = "reverse"
232 tagValue = tags[tag]
233 if tagValue != "" {
234 switch tagValue {
235 case "one":
236 fieldType = RelReverseOne
237 break checkType
238 case "many":
239 fieldType = RelReverseMany
240 break checkType
241 default:
242 err = fmt.Errorf("error")
243 goto wrongTag
244 }
245 }
246
247 fieldType, err = getFieldType(addrField)
248 if err != nil {
249 goto end
250 }
251 if fieldType == TypeTextField && maxLength != "" {
252 fieldType = TypeCharField
253 }
254 if fieldType == TypeFloatField && (digits != "" || decimals != "") {
255 fieldType = TypeDecimalField
256 }
257 if fieldType == TypeDateTimeField && attrs["date"] {
258 fieldType = TypeDateField
259 }
260 }
261
262 switch fieldType {
263 case RelForeignKey, RelOneToOne, RelReverseOne:
264 if _, ok := addrField.Interface().(Modeler); ok == false {
265 err = fmt.Errorf("rel/reverse:one field must be implements Modeler")
266 goto end
267 }
268 if field.Kind() != reflect.Ptr {
269 err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name())
270 goto end
271 }
272 case RelManyToMany, RelReverseMany:
273 if field.Kind() != reflect.Slice {
274 err = fmt.Errorf("rel/reverse:many field must be slice")
275 goto end
276 } else {
277 if field.Type().Elem().Kind() != reflect.Ptr {
278 err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name())
279 goto end
280 }
281 if _, ok := reflect.New(field.Type().Elem()).Elem().Interface().(Modeler); ok == false {
282 err = fmt.Errorf("rel/reverse:many slice element must be implements Modeler")
283 goto end
284 }
285 }
286 }
287
288 if fieldType&IsFieldType == 0 {
289 err = fmt.Errorf("wrong field type")
290 goto end
291 }
292
293 fi.fieldType = fieldType
294 fi.name = sf.Name
295 fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
296 fi.addrValue = &addrField
297 fi.sf = &sf
298 fi.fullName = mi.fullName + "." + sf.Name
299
300 fi.null = attrs["null"]
301 fi.blank = attrs["blank"]
302 fi.index = attrs["index"]
303 fi.auto = attrs["auto"]
304 fi.pk = attrs["pk"]
305 fi.unique = attrs["unique"]
306
307 switch fieldType {
308 case RelManyToMany, RelReverseMany, RelReverseOne:
309 fi.null = false
310 fi.blank = false
311 fi.index = false
312 fi.auto = false
313 fi.pk = false
314 fi.unique = false
315 default:
316 fi.dbcol = true
317 }
318
319 switch fieldType {
320 case RelForeignKey, RelOneToOne, RelManyToMany:
321 fi.rel = true
322 if fieldType == RelOneToOne {
323 fi.unique = true
324 }
325 case RelReverseMany, RelReverseOne:
326 fi.reverse = true
327 }
328
329 if fi.rel && fi.dbcol {
330 switch onDelete {
331 case od_CASCADE, od_DO_NOTHING:
332 case od_SET_DEFAULT:
333 if tags["default"] == "" {
334 err = errors.New("on_delete: set_default need set field a default value")
335 goto end
336 }
337 case od_SET_NULL:
338 if fi.null == false {
339 err = errors.New("on_delete: set_null need set field null")
340 goto end
341 }
342 default:
343 if onDelete == "" {
344 onDelete = od_CASCADE
345 } else {
346 err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
347 goto end
348 }
349 }
350
351 fi.onDelete = onDelete
352 }
353
354 switch fieldType {
355 case TypeBooleanField:
356 case TypeCharField:
357 if maxLength != "" {
358 v, e := StrTo(maxLength).Int32()
359 if e != nil {
360 err = fmt.Errorf("wrong maxLength value `%s`", maxLength)
361 } else {
362 fi.maxLength = int(v)
363 }
364 } else {
365 err = fmt.Errorf("maxLength must be specify")
366 }
367 case TypeTextField:
368 fi.index = false
369 fi.unique = false
370 case TypeDateField, TypeDateTimeField:
371 if attrs["auto_now"] {
372 fi.auto_now = true
373 } else if attrs["auto_now_add"] {
374 fi.auto_now_add = true
375 }
376 case TypeFloatField:
377 case TypeDecimalField:
378 d1 := digits
379 d2 := decimals
380 v1, er1 := StrTo(d1).Int16()
381 v2, er2 := StrTo(d2).Int16()
382 if er1 != nil || er2 != nil {
383 err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
384 goto end
385 }
386 fi.digits = int(v1)
387 fi.decimals = int(v2)
388 default:
389 switch {
390 case fieldType&IsIntegerField > 0:
391 case fieldType&IsRelField > 0:
392 }
393 }
394
395 if fieldType&IsIntegerField == 0 {
396 if fi.auto {
397 err = fmt.Errorf("non-integer type cannot set auto")
398 goto end
399 }
400
401 if fi.pk || fi.index || fi.unique {
402 if fieldType != TypeCharField && fieldType != RelOneToOne {
403 err = fmt.Errorf("cannot set pk/index/unique")
404 goto end
405 }
406 }
407 }
408
409 if fi.auto || fi.pk {
410 if fi.auto {
411 fi.pk = true
412 }
413 fi.null = false
414 fi.blank = false
415 fi.index = false
416 fi.unique = false
417 }
418
419 if fi.unique {
420 fi.null = false
421 fi.blank = false
422 fi.index = false
423 }
424
425 parts = strings.Split(tags["choices"], ",")
426 if len(parts) > 1 {
427 for _, v := range parts {
428 choices.Add(StrTo(strings.TrimSpace(v)))
429 }
430 }
431
432 initial.Clear()
433 if v, ok := tags["default"]; ok {
434 initial.Set(v)
435 }
436
437 if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField {
438 // can not set default
439 choices.Clear()
440 initial.Clear()
441 }
442
443 values = choices.Clone()
444
445 if initial.Exist() {
446 values.Add(initial)
447 }
448
449 for i, v := range values {
450 switch fieldType {
451 case TypeBooleanField:
452 _, err = v.Bool()
453 case TypeFloatField, TypeDecimalField:
454 _, err = v.Float64()
455 case TypeSmallIntegerField:
456 _, err = v.Int16()
457 case TypeIntegerField:
458 _, err = v.Int32()
459 case TypeBigIntegerField:
460 _, err = v.Int64()
461 case TypePositiveSmallIntegerField:
462 _, err = v.Uint16()
463 case TypePositiveIntegerField:
464 _, err = v.Uint32()
465 case TypePositiveBigIntegerField:
466 _, err = v.Uint64()
467 }
468 if err != nil {
469 if initial.Exist() && len(values) == i {
470 tag, tagValue = "default", tags["default"]
471 } else {
472 tag, tagValue = "choices", tags["choices"]
473 }
474 goto wrongTag
475 }
476 }
477
478 if len(choices) > 0 && initial.Exist() {
479 if choices.Have(initial) == false {
480 err = fmt.Errorf("default value `%s` not in choices `%s`", tags["default"], tags["choices"])
481 goto end
482 }
483 }
484
485 fi.choices = choices
486 fi.initial = initial
487
488 end:
489 if err != nil {
490 return nil, err
491 }
492 return
493 wrongTag:
494 return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err)
495 }
1 package orm
2
3 import (
4 "errors"
5 "fmt"
6 "os"
7 "reflect"
8 )
9
10 type modelInfo struct {
11 pkg string
12 name string
13 fullName string
14 table string
15 model Modeler
16 fields *fields
17 manual bool
18 addrField reflect.Value
19 }
20
21 func newModelInfo(model Modeler) (info *modelInfo) {
22 var (
23 err error
24 fi *fieldInfo
25 sf reflect.StructField
26 )
27
28 info = &modelInfo{}
29 info.fields = newFields()
30
31 val := reflect.ValueOf(model)
32 ind := reflect.Indirect(val)
33 typ := ind.Type()
34
35 info.addrField = ind.Addr()
36
37 info.name = typ.Name()
38 info.fullName = typ.PkgPath() + "." + typ.Name()
39
40 for i := 0; i < ind.NumField(); i++ {
41 field := ind.Field(i)
42 sf = ind.Type().Field(i)
43 if field.CanAddr() {
44 addr := field.Addr()
45 if _, ok := addr.Interface().(*Manager); ok {
46 continue
47 }
48 }
49 fi, err = newFieldInfo(info, field, sf)
50 if err != nil {
51 break
52 }
53 added := info.fields.Add(fi)
54 if added == false {
55 err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
56 break
57 }
58 if fi.pk {
59 if info.fields.pk != nil {
60 err = errors.New(fmt.Sprintf("one model must have one pk field only"))
61 break
62 } else {
63 info.fields.pk.Add(fi)
64 }
65 }
66 if fi.auto {
67 info.fields.auto = fi
68 }
69 fi.fieldIndex = i
70 fi.mi = info
71 }
72
73 if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false {
74 err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field"))
75 goto end
76 }
77
78 if err != nil {
79 fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
80 os.Exit(2)
81 }
82
83 end:
84 if err != nil {
85 fmt.Println(err)
86 os.Exit(2)
87 }
88 return
89 }
90
91 func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
92 info = new(modelInfo)
93 info.fields = newFields()
94 info.table = m1.table + "_" + m2.table + "_rel"
95 info.name = camelString(info.table)
96 info.fullName = m1.pkg + "." + info.name
97
98 fa := new(fieldInfo)
99 f1 := new(fieldInfo)
100 f2 := new(fieldInfo)
101 fa.fieldType = TypeBigIntegerField
102 fa.auto = true
103 fa.pk = true
104 fa.dbcol = true
105
106 f1.dbcol = true
107 f2.dbcol = true
108 f1.fieldType = RelForeignKey
109 f2.fieldType = RelForeignKey
110 f1.name = camelString(m1.table)
111 f2.name = camelString(m2.table)
112 f1.fullName = info.fullName + "." + f1.name
113 f2.fullName = info.fullName + "." + f2.name
114 f1.column = m1.table + "_id"
115 f2.column = m2.table + "_id"
116 f1.rel = true
117 f2.rel = true
118 f1.relTable = m1.table
119 f2.relTable = m2.table
120 f1.relModelInfo = m1
121 f2.relModelInfo = m2
122 f1.mi = info
123 f2.mi = info
124
125 info.fields.Add(fa)
126 info.fields.Add(f1)
127 info.fields.Add(f2)
128 info.fields.pk.Add(fa)
129 return
130 }
1 package orm
2
3 import ()
4
5 // non cleaned field errors
6 type FieldErrors map[string]error
7
8 func (fe FieldErrors) Get(name string) error {
9 return fe[name]
10 }
11
12 func (fe FieldErrors) Set(name string, value error) {
13 fe[name] = value
14 }
15
16 type Manager struct {
17 ins Modeler
18 inited bool
19 }
20
21 // func (m *Manager) init(model reflect.Value) {
22 // elm := model.Elem()
23 // for i := 0; i < elm.NumField(); i++ {
24 // field := elm.Field(i)
25 // if _, ok := field.Interface().(Fielder); ok && field.CanSet() {
26 // if field.Elem().Kind() != reflect.Struct {
27 // field.Set(reflect.New(field.Type().Elem()))
28 // }
29 // }
30 // }
31 // }
32
33 func (m *Manager) Init(model Modeler) Modeler {
34 if m.inited {
35 return m.ins
36 }
37 m.inited = true
38 m.ins = model
39 return model
40 }
41
42 func (m *Manager) IsInited() bool {
43 return m.inited
44 }
45
46 func (m *Manager) Clean() FieldErrors {
47 return nil
48 }
49
50 func (m *Manager) CleanFields(name string) FieldErrors {
51 return nil
52 }
53
54 func (m *Manager) GetTableName() string {
55 return getTableName(m.ins)
56 }
1 package orm
2
3 import (
4 "fmt"
5 "reflect"
6 "strings"
7 "time"
8 )
9
10 func getTableName(model Modeler) string {
11 val := reflect.ValueOf(model)
12 ind := reflect.Indirect(val)
13 fun := val.MethodByName("TableName")
14 if fun.IsValid() {
15 vals := fun.Call([]reflect.Value{})
16 if len(vals) > 0 {
17 val := vals[0]
18 if val.Kind() == reflect.String {
19 return val.String()
20 }
21 }
22 }
23 return snakeString(ind.Type().Name())
24 }
25
26 func getPkgPath(model Modeler) string {
27 val := reflect.ValueOf(model)
28 return val.Type().Elem().PkgPath()
29 }
30
31 func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
32 column := strings.ToLower(col)
33 if column == "" {
34 column = snakeString(sf.Name)
35 }
36 switch ft {
37 case RelForeignKey, RelOneToOne:
38 column = column + "_id"
39 case RelManyToMany, RelReverseMany, RelReverseOne:
40 column = sf.Name
41 }
42 return column
43 }
44
45 func getFieldType(val reflect.Value) (ft int, err error) {
46 elm := reflect.Indirect(val)
47 switch elm.Kind() {
48 case reflect.Int16:
49 ft = TypeSmallIntegerField
50 case reflect.Int32, reflect.Int:
51 ft = TypeIntegerField
52 case reflect.Int64:
53 ft = TypeBigIntegerField
54 case reflect.Uint16:
55 ft = TypePositiveSmallIntegerField
56 case reflect.Uint32:
57 ft = TypePositiveIntegerField
58 case reflect.Uint64:
59 ft = TypePositiveBigIntegerField
60 case reflect.Float32, reflect.Float64:
61 ft = TypeFloatField
62 case reflect.Bool:
63 ft = TypeBooleanField
64 case reflect.String:
65 ft = TypeTextField
66 case reflect.Invalid:
67 default:
68 if elm.CanInterface() {
69 if _, ok := elm.Interface().(time.Time); ok {
70 ft = TypeDateTimeField
71 }
72 }
73 }
74 if ft&IsFieldType == 0 {
75 err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
76 }
77 return
78 }
79
80 func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
81 attr := make(map[string]bool)
82 tag := make(map[string]string)
83 for _, v := range strings.Split(data, ";") {
84 v = strings.TrimSpace(v)
85 if supportTag[v] == 1 {
86 attr[v] = true
87 } else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
88 name := v[:i]
89 if supportTag[name] == 2 {
90 v = v[i+1 : len(v)-1]
91 tag[name] = v
92 }
93 }
94 }
95 *attrs = attr
96 *tags = tag
97 }
1 package orm
2
3 import (
4 "database/sql"
5 "errors"
6 "fmt"
7 "time"
8 )
9
10 var (
11 ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
12 ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
13 ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
14 ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
15 DefaultRowsLimit = 1000
16 DefaultRelsDepth = 5
17 DefaultTimeLoc = time.Local
18 )
19
20 type Params map[string]interface{}
21 type ParamsList []interface{}
22
23 type orm struct {
24 alias *alias
25 db dbQuerier
26 isTx bool
27 }
28
29 func (o *orm) Object(md Modeler) ObjectSeter {
30 name := md.GetTableName()
31 if mi, ok := modelCache.get(name); ok {
32 return newObject(o, mi, md)
33 }
34 panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name))
35 }
36
37 func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
38 name := ""
39 if table, ok := ptrStructOrTableName.(string); ok {
40 name = snakeString(table)
41 } else if m, ok := ptrStructOrTableName.(Modeler); ok {
42 name = m.GetTableName()
43 }
44 if mi, ok := modelCache.get(name); ok {
45 return newQuerySet(o, mi)
46 }
47 panic(fmt.Sprintf("<orm.SetTable> table name: `%s` not exists", name))
48 }
49
50 func (o *orm) Using(name string) error {
51 if o.isTx {
52 panic("<orm.Using> transaction has been start, cannot change db")
53 }
54 if al, ok := dataBaseCache.get(name); ok {
55 o.alias = al
56 o.db = al.DB
57 } else {
58 return errors.New(fmt.Sprintf("<orm.Using> unknown db alias name `%s`", name))
59 }
60 return nil
61 }
62
63 func (o *orm) Begin() error {
64 if o.isTx {
65 return ErrTXHasBegin
66 }
67 tx, err := o.alias.DB.Begin()
68 if err != nil {
69 return err
70 }
71 o.isTx = true
72 o.db = tx
73 return nil
74 }
75
76 func (o *orm) Commit() error {
77 if o.isTx == false {
78 return ErrTXNotBegin
79 }
80 err := o.db.(*sql.Tx).Commit()
81 if err == nil {
82 o.isTx = false
83 o.db = o.alias.DB
84 }
85 return err
86 }
87
88 func (o *orm) Rollback() error {
89 if o.isTx == false {
90 return ErrTXNotBegin
91 }
92 err := o.db.(*sql.Tx).Rollback()
93 if err == nil {
94 o.isTx = false
95 o.db = o.alias.DB
96 }
97 return err
98 }
99
100 func (o *orm) Raw(query string, args ...interface{}) RawSeter {
101 return newRawSet(o, query, args)
102 }
103
104 func NewOrm() Ormer {
105 o := new(orm)
106 err := o.Using("default")
107 if err != nil {
108 panic(err)
109 }
110 return o
111 }
1 package orm
2
3 import (
4 "strings"
5 )
6
7 const (
8 ExprSep = "__"
9 )
10
11 type condValue struct {
12 exprs []string
13 args []interface{}
14 cond *Condition
15 isOr bool
16 isNot bool
17 isCond bool
18 }
19
20 type Condition struct {
21 params []condValue
22 }
23
24 func NewCondition() *Condition {
25 c := &Condition{}
26 return c
27 }
28
29 func (c *Condition) And(expr string, args ...interface{}) *Condition {
30 if expr == "" || len(args) == 0 {
31 panic("<Condition.And> args cannot empty")
32 }
33 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
34 return c
35 }
36
37 func (c *Condition) AndNot(expr string, args ...interface{}) *Condition {
38 if expr == "" || len(args) == 0 {
39 panic("<Condition.AndNot> args cannot empty")
40 }
41 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
42 return c
43 }
44
45 func (c *Condition) AndCond(cond *Condition) *Condition {
46 if c == cond {
47 panic("cannot use self as sub cond")
48 }
49 if cond != nil {
50 c.params = append(c.params, condValue{cond: cond, isCond: true})
51 }
52 return c
53 }
54
55 func (c *Condition) Or(expr string, args ...interface{}) *Condition {
56 if expr == "" || len(args) == 0 {
57 panic("<Condition.Or> args cannot empty")
58 }
59 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
60 return c
61 }
62
63 func (c *Condition) OrNot(expr string, args ...interface{}) *Condition {
64 if expr == "" || len(args) == 0 {
65 panic("<Condition.OrNot> args cannot empty")
66 }
67 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
68 return c
69 }
70
71 func (c *Condition) OrCond(cond *Condition) *Condition {
72 if c == cond {
73 panic("cannot use self as sub cond")
74 }
75 if cond != nil {
76 c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true})
77 }
78 return c
79 }
80
81 func (c *Condition) IsEmpty() bool {
82 return len(c.params) == 0
83 }
84
85 func (c Condition) Clone() *Condition {
86 params := c.params
87 c.params = make([]condValue, len(params))
88 copy(c.params, params)
89 return &c
90 }
91
92 func (c *Condition) Merge() (expr string, args []interface{}) {
93 return expr, args
94 }
1 package orm
2
3 import (
4 "database/sql"
5 "fmt"
6 "reflect"
7 )
8
9 type insertSet struct {
10 mi *modelInfo
11 orm *orm
12 stmt *sql.Stmt
13 closed bool
14 }
15
16 func (o *insertSet) Insert(md Modeler) (int64, error) {
17 if o.closed {
18 return 0, ErrStmtClosed
19 }
20 val := reflect.ValueOf(md)
21 ind := reflect.Indirect(val)
22 if val.Type() != o.mi.addrField.Type() {
23 panic(fmt.Sprintf("<Inserter.Insert> need type `%s` but found `%s`", o.mi.addrField.Type(), val.Type()))
24 }
25 id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind)
26 if err != nil {
27 return id, err
28 }
29 if id > 0 {
30 if o.mi.fields.auto != nil {
31 ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id)
32 }
33 }
34 return id, nil
35 }
36
37 func (o *insertSet) Close() error {
38 o.closed = true
39 return o.stmt.Close()
40 }
41
42 func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
43 bi := new(insertSet)
44 bi.orm = orm
45 bi.mi = mi
46 st, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
47 if err != nil {
48 return nil, err
49 }
50 bi.stmt = st
51 return bi, nil
52 }
53
54 type object struct {
55 ind reflect.Value
56 mi *modelInfo
57 orm *orm
58 }
59
60 func (o *object) Insert() (int64, error) {
61 id, err := o.orm.alias.DbBaser.Insert(o.orm.db, o.mi, o.ind)
62 if err != nil {
63 return id, err
64 }
65 if id > 0 {
66 if o.mi.fields.auto != nil {
67 o.ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id)
68 }
69 }
70 return id, nil
71 }
72
73 func (o *object) Update() (int64, error) {
74 num, err := o.orm.alias.DbBaser.Update(o.orm.db, o.mi, o.ind)
75 if err != nil {
76 return num, err
77 }
78 return 0, nil
79 }
80
81 func (o *object) Delete() (int64, error) {
82 return o.orm.alias.DbBaser.Delete(o.orm.db, o.mi, o.ind)
83 }
84
85 func newObject(orm *orm, mi *modelInfo, md Modeler) ObjectSeter {
86 o := new(object)
87 ind := reflect.Indirect(reflect.ValueOf(md))
88 o.ind = ind
89 o.mi = mi
90 o.orm = orm
91 return o
92 }
1 package orm
2
3 import (
4 "fmt"
5 )
6
7 type querySet struct {
8 mi *modelInfo
9 cond *Condition
10 related []string
11 relDepth int
12 limit int
13 offset int64
14 orders []string
15 orm *orm
16 }
17
18 func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter {
19 if o.cond == nil {
20 o.cond = NewCondition()
21 }
22 o.cond.And(expr, args...)
23 return o.Clone()
24 }
25
26 func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter {
27 if o.cond == nil {
28 o.cond = NewCondition()
29 }
30 o.cond.AndNot(expr, args...)
31 return o.Clone()
32 }
33
34 func (o *querySet) Limit(limit int, args ...int64) QuerySeter {
35 o.limit = limit
36 if len(args) > 0 {
37 o.offset = args[0]
38 }
39 return o.Clone()
40 }
41
42 func (o *querySet) Offset(offset int64) QuerySeter {
43 o.offset = offset
44 return o.Clone()
45 }
46
47 func (o *querySet) OrderBy(orders ...string) QuerySeter {
48 o.orders = orders
49 return o.Clone()
50 }
51
52 func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
53 var related []string
54 if len(params) == 0 {
55 o.relDepth = DefaultRelsDepth
56 } else {
57 for _, p := range params {
58 switch val := p.(type) {
59 case string:
60 related = append(o.related, val)
61 case int:
62 o.relDepth = val
63 default:
64 panic(fmt.Sprintf("<querySet.RelatedSel> wrong param kind: %v", val))
65 }
66 }
67 }
68 o.related = related
69 return o.Clone()
70 }
71
72 func (o querySet) Clone() QuerySeter {
73 if o.cond != nil {
74 o.cond = o.cond.Clone()
75 }
76 return &o
77 }
78
79 func (o *querySet) SetCond(cond *Condition) error {
80 o.cond = cond
81 return nil
82 }
83
84 func (o *querySet) Count() (int64, error) {
85 return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond)
86 }
87
88 func (o *querySet) Update(values Params) (int64, error) {
89 return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values)
90 }
91
92 func (o *querySet) Delete() (int64, error) {
93 return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond)
94 }
95
96 func (o *querySet) PrepareInsert() (Inserter, error) {
97 return newInsertSet(o.orm, o.mi)
98 }
99
100 func (o *querySet) All(container interface{}) (int64, error) {
101 return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
102 }
103
104 func (o *querySet) One(container Modeler) error {
105 num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
106 if err != nil {
107 return err
108 }
109 if num > 1 {
110 return ErrMultiRows
111 }
112 return nil
113 }
114
115 func (o *querySet) Values(results *[]Params, args ...string) (int64, error) {
116 return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, args, results)
117 }
118
119 func (o *querySet) ValuesList(results *[]ParamsList, args ...string) (int64, error) {
120 return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, args, results)
121 }
122
123 func (o *querySet) ValuesFlat(result *ParamsList, arg string) (int64, error) {
124 return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{arg}, result)
125 }
126
127 func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
128 o := new(querySet)
129 o.mi = mi
130 o.orm = orm
131 return o
132 }
1 package orm
2
3 import (
4 "database/sql"
5 "fmt"
6 "reflect"
7 )
8
9 func getResult(res sql.Result) (int64, error) {
10 if num, err := res.LastInsertId(); err != nil {
11 return 0, err
12 } else {
13 if num > 0 {
14 return num, nil
15 }
16 }
17 if num, err := res.RowsAffected(); err != nil {
18 return num, err
19 } else {
20 if num > 0 {
21 return num, nil
22 }
23 }
24 return 0, nil
25 }
26
27 type rawPrepare struct {
28 rs *rawSet
29 stmt *sql.Stmt
30 closed bool
31 }
32
33 func (o *rawPrepare) Exec(args ...interface{}) (int64, error) {
34 if o.closed {
35 return 0, ErrStmtClosed
36 }
37 res, err := o.stmt.Exec(args...)
38 if err != nil {
39 return 0, err
40 }
41 return getResult(res)
42 }
43
44 func (o *rawPrepare) Close() error {
45 o.closed = true
46 return o.stmt.Close()
47 }
48
49 func newRawPreparer(rs *rawSet) (RawPreparer, error) {
50 o := new(rawPrepare)
51 o.rs = rs
52 st, err := rs.orm.db.Prepare(rs.query)
53 if err != nil {
54 return nil, err
55 }
56 o.stmt = st
57 return o, nil
58 }
59
60 type rawSet struct {
61 query string
62 args []interface{}
63 orm *orm
64 }
65
66 func (o rawSet) SetArgs(args ...interface{}) RawSeter {
67 o.args = args
68 return &o
69 }
70
71 func (o *rawSet) Exec() (int64, error) {
72 res, err := o.orm.db.Exec(o.query, o.args...)
73 if err != nil {
74 return 0, err
75 }
76 return getResult(res)
77 }
78
79 func (o *rawSet) Mapper(...interface{}) (int64, error) {
80 //TODO
81 return 0, nil
82 }
83
84 func (o *rawSet) readValues(container interface{}) (int64, error) {
85 var (
86 maps []Params
87 lists []ParamsList
88 list ParamsList
89 )
90
91 typ := 0
92 switch container.(type) {
93 case *[]Params:
94 typ = 1
95 case *[]ParamsList:
96 typ = 2
97 case *ParamsList:
98 typ = 3
99 default:
100 panic(fmt.Sprintf("unsupport read values type `%T`", container))
101 }
102
103 var rs *sql.Rows
104 if r, err := o.orm.db.Query(o.query, o.args...); err != nil {
105 return 0, err
106 } else {
107 rs = r
108 }
109
110 var (
111 refs []interface{}
112 cnt int64
113 cols []string
114 )
115 for rs.Next() {
116 if cnt == 0 {
117 if columns, err := rs.Columns(); err != nil {
118 return 0, err
119 } else {
120 cols = columns
121 refs = make([]interface{}, len(cols))
122 for i, _ := range refs {
123 var ref string
124 refs[i] = &ref
125 }
126 }
127 }
128
129 if err := rs.Scan(refs...); err != nil {
130 return 0, err
131 }
132
133 switch typ {
134 case 1:
135 params := make(Params, len(cols))
136 for i, ref := range refs {
137 value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
138 params[cols[i]] = value
139 }
140 maps = append(maps, params)
141 case 2:
142 params := make(ParamsList, 0, len(cols))
143 for _, ref := range refs {
144 value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
145 params = append(params, value)
146 }
147 lists = append(lists, params)
148 case 3:
149 for _, ref := range refs {
150 value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
151 list = append(list, value)
152 }
153 }
154
155 cnt++
156 }
157
158 switch v := container.(type) {
159 case *[]Params:
160 *v = maps
161 case *[]ParamsList:
162 *v = lists
163 case *ParamsList:
164 *v = list
165 }
166
167 return cnt, nil
168 }
169
170 func (o *rawSet) Values(container *[]Params) (int64, error) {
171 return o.readValues(container)
172 }
173
174 func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
175 return o.readValues(container)
176 }
177
178 func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
179 return o.readValues(container)
180 }
181
182 func (o *rawSet) Prepare() (RawPreparer, error) {
183 return newRawPreparer(o)
184 }
185
186 func newRawSet(orm *orm, query string, args []interface{}) RawSeter {
187 o := new(rawSet)
188 o.query = query
189 o.args = args
190 o.orm = orm
191 return o
192 }
1 package orm
2
3 import (
4 "database/sql"
5 "reflect"
6 )
7
8 type Fielder interface {
9 String() string
10 FieldType() int
11 SetRaw(interface{}) error
12 RawValue() interface{}
13 }
14
15 type Modeler interface {
16 Init(Modeler) Modeler
17 IsInited() bool
18 Clean() FieldErrors
19 CleanFields(string) FieldErrors
20 GetTableName() string
21 }
22
23 type Ormer interface {
24 Object(Modeler) ObjectSeter
25 QueryTable(interface{}) QuerySeter
26 Using(string) error
27 Begin() error
28 Commit() error
29 Rollback() error
30 Raw(string, ...interface{}) RawSeter
31 }
32
33 type ObjectSeter interface {
34 Insert() (int64, error)
35 Update() (int64, error)
36 Delete() (int64, error)
37 }
38
39 type Inserter interface {
40 Insert(Modeler) (int64, error)
41 Close() error
42 }
43
44 type QuerySeter interface {
45 Filter(string, ...interface{}) QuerySeter
46 Exclude(string, ...interface{}) QuerySeter
47 Limit(int, ...int64) QuerySeter
48 Offset(int64) QuerySeter
49 OrderBy(...string) QuerySeter
50 RelatedSel(...interface{}) QuerySeter
51 Clone() QuerySeter
52 SetCond(*Condition) error
53 Count() (int64, error)
54 Update(Params) (int64, error)
55 Delete() (int64, error)
56 PrepareInsert() (Inserter, error)
57
58 All(interface{}) (int64, error)
59 One(Modeler) error
60 Values(*[]Params, ...string) (int64, error)
61 ValuesList(*[]ParamsList, ...string) (int64, error)
62 ValuesFlat(*ParamsList, string) (int64, error)
63 }
64
65 type RawPreparer interface {
66 Close() error
67 }
68
69 type RawSeter interface {
70 Exec() (int64, error)
71 Mapper(...interface{}) (int64, error)
72 Values(*[]Params) (int64, error)
73 ValuesList(*[]ParamsList) (int64, error)
74 ValuesFlat(*ParamsList) (int64, error)
75 Prepare() (RawPreparer, error)
76 }
77
78 type dbQuerier interface {
79 Prepare(query string) (*sql.Stmt, error)
80 Exec(query string, args ...interface{}) (sql.Result, error)
81 Query(query string, args ...interface{}) (*sql.Rows, error)
82 QueryRow(query string, args ...interface{}) *sql.Row
83 }
84
85 type dbBaser interface {
86 Insert(dbQuerier, *modelInfo, reflect.Value) (int64, error)
87 InsertStmt(*sql.Stmt, *modelInfo, reflect.Value) (int64, error)
88 Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
89 Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
90 ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
91 UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
92 DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
93 Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
94 GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
95 PrepareInsert(dbQuerier, *modelInfo) (*sql.Stmt, error)
96 ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
97 }
1 package orm
2
3 import (
4 "fmt"
5 "strconv"
6 "strings"
7 "time"
8 )
9
10 type StrTo string
11
12 func (f *StrTo) Set(v string) {
13 if v != "" {
14 *f = StrTo(v)
15 } else {
16 f.Clear()
17 }
18 }
19
20 func (f *StrTo) Clear() {
21 *f = StrTo(0x1E)
22 }
23
24 func (f StrTo) Exist() bool {
25 return string(f) != string(0x1E)
26 }
27
28 func (f StrTo) Bool() (bool, error) {
29 return strconv.ParseBool(f.String())
30 }
31
32 func (f StrTo) Float32() (float32, error) {
33 v, err := strconv.ParseFloat(f.String(), 32)
34 return float32(v), err
35 }
36
37 func (f StrTo) Float64() (float64, error) {
38 return strconv.ParseFloat(f.String(), 64)
39 }
40
41 func (f StrTo) Int16() (int16, error) {
42 v, err := strconv.ParseInt(f.String(), 10, 16)
43 return int16(v), err
44 }
45
46 func (f StrTo) Int32() (int32, error) {
47 v, err := strconv.ParseInt(f.String(), 10, 32)
48 return int32(v), err
49 }
50
51 func (f StrTo) Int64() (int64, error) {
52 v, err := strconv.ParseInt(f.String(), 10, 64)
53 return int64(v), err
54 }
55
56 func (f StrTo) Uint16() (uint16, error) {
57 v, err := strconv.ParseUint(f.String(), 10, 16)
58 return uint16(v), err
59 }
60
61 func (f StrTo) Uint32() (uint32, error) {
62 v, err := strconv.ParseUint(f.String(), 10, 32)
63 return uint32(v), err
64 }
65
66 func (f StrTo) Uint64() (uint64, error) {
67 v, err := strconv.ParseUint(f.String(), 10, 64)
68 return uint64(v), err
69 }
70
71 func (f StrTo) String() string {
72 if f.Exist() {
73 return string(f)
74 }
75 return ""
76 }
77
78 func ToStr(value interface{}, args ...int) (s string) {
79 switch v := value.(type) {
80 case bool:
81 s = strconv.FormatBool(v)
82 case float32:
83 s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
84 case float64:
85 s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
86 case int:
87 s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
88 case int16:
89 s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
90 case int32:
91 s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
92 case int64:
93 s = strconv.FormatInt(v, argInt(args).Get(0, 10))
94 case uint:
95 s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
96 case uint16:
97 s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
98 case uint32:
99 s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
100 case uint64:
101 s = strconv.FormatUint(v, argInt(args).Get(0, 10))
102 case string:
103 s = v
104 default:
105 s = fmt.Sprintf("%v", v)
106 }
107 return s
108 }
109
110 func snakeString(s string) string {
111 data := make([]byte, 0, len(s)*2)
112 j := false
113 num := len(s)
114 for i := 0; i < num; i++ {
115 d := s[i]
116 if i > 0 && d >= 'A' && d <= 'Z' && j {
117 data = append(data, '_')
118 }
119 if d != '_' {
120 j = true
121 }
122 data = append(data, d)
123 }
124 return strings.ToLower(string(data[:len(data)]))
125 }
126
127 func camelString(s string) string {
128 data := make([]byte, 0, len(s))
129 j := false
130 k := false
131 num := len(s) - 1
132 for i := 0; i <= num; i++ {
133 d := s[i]
134 if k == false && d >= 'A' && d <= 'Z' {
135 k = true
136 }
137 if d >= 'a' && d <= 'z' && (j || k == false) {
138 d = d - 32
139 j = false
140 k = true
141 }
142 if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
143 j = true
144 continue
145 }
146 data = append(data, d)
147 }
148 return string(data[:len(data)])
149 }
150
151 type argString []string
152
153 func (a argString) Get(i int, args ...string) (r string) {
154 if i >= 0 && i < len(a) {
155 r = a[i]
156 } else if len(args) > 0 {
157 r = args[0]
158 }
159 return
160 }
161
162 type argInt []int
163
164 func (a argInt) Get(i int, args ...int) (r int) {
165 if i >= 0 && i < len(a) {
166 r = a[i]
167 }
168 if len(args) > 0 {
169 r = args[0]
170 }
171 return
172 }
173
174 func timeParse(dateString, format string) (time.Time, error) {
175 tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
176 return tp, err
177 }
178
179 func timeFormat(t time.Time, format string) string {
180 return t.Format(format)
181 }
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!