46668b81 by slene

some fix / add test

1 parent 10f4e822
...@@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) { ...@@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) {
208 208
209 switch { 209 switch {
210 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: 210 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
211 c1 = jt.fi.mi.fields.pk[0].column 211 c1 = jt.fi.mi.fields.pk.column
212 for _, ffi := range jt.mi.fields.fieldsRel { 212 for _, ffi := range jt.mi.fields.fieldsRel {
213 if jt.fi.mi == ffi.relModelInfo { 213 if jt.fi.mi == ffi.relModelInfo {
214 c2 = ffi.column 214 c2 = ffi.column
...@@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) { ...@@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) {
217 } 217 }
218 default: 218 default:
219 c1 = jt.fi.column 219 c1 = jt.fi.column
220 c2 = jt.fi.relModelInfo.fields.pk[0].column 220 c2 = jt.fi.relModelInfo.fields.pk.column
221 221
222 if jt.fi.reverse { 222 if jt.fi.reverse {
223 c1 = jt.mi.fields.pk[0].column 223 c1 = jt.mi.fields.pk.column
224 c2 = jt.fi.reverseFieldInfo.column 224 c2 = jt.fi.reverseFieldInfo.column
225 } 225 }
226 } 226 }
...@@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam ...@@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
263 if fi.reverseFieldInfo.fieldType == RelManyToMany { 263 if fi.reverseFieldInfo.fieldType == RelManyToMany {
264 mmi = fi.reverseFieldInfo.relThroughModelInfo 264 mmi = fi.reverseFieldInfo.relThroughModelInfo
265 } 265 }
266 default:
267 return
266 } 268 }
267 269
268 jt, _ := d.add(names, mmi, fi, fi.null == false) 270 jt, _ := d.add(names, mmi, fi, fi.null == false)
...@@ -434,40 +436,36 @@ type dbBase struct { ...@@ -434,40 +436,36 @@ type dbBase struct {
434 ins dbBaser 436 ins dbBaser
435 } 437 }
436 438
437 func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) { 439 func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
438 exist := true 440
439 columns := make([]string, 0, len(mi.fields.pk)) 441 fi := mi.fields.pk
440 values := make([]interface{}, 0, len(mi.fields.pk)) 442
441 for _, fi := range mi.fields.pk { 443 v := ind.Field(fi.fieldIndex)
442 v := ind.Field(fi.fieldIndex) 444 if fi.fieldType&IsIntegerField > 0 {
443 if fi.fieldType&IsIntegerField > 0 { 445 vu := v.Int()
444 vu := v.Int() 446 exist = vu > 0
445 if exist { 447 value = vu
446 exist = vu > 0 448 } else {
447 } 449 vu := v.String()
448 values = append(values, vu) 450 exist = vu != ""
449 } else { 451 value = vu
450 vu := v.String()
451 if exist {
452 exist = vu != ""
453 }
454 values = append(values, vu)
455 }
456 columns = append(columns, fi.column)
457 } 452 }
458 return columns, values, exist 453
454 column = fi.column
455
456 return
459 } 457 }
460 458
461 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { 459 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
462 _, pkValues, _ := d.existPk(mi, ind) 460 _, pkValue, _ := d.existPk(mi, ind)
463 for _, column := range mi.fields.orders { 461 for _, column := range mi.fields.orders {
464 fi := mi.fields.columns[column] 462 fi := mi.fields.columns[column]
465 if fi.dbcol == false || fi.auto && skipAuto { 463 if fi.dbcol == false || fi.auto && skipAuto {
466 continue 464 continue
467 } 465 }
468 var value interface{} 466 var value interface{}
469 if i, ok := mi.fields.pk.Exist(fi); ok { 467 if fi.pk {
470 value = pkValues[i] 468 value = pkValue
471 } else { 469 } else {
472 field := ind.Field(fi.fieldIndex) 470 field := ind.Field(fi.fieldIndex)
473 if fi.isFielder { 471 if fi.isFielder {
...@@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, ...@@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
493 if field.IsNil() { 491 if field.IsNil() {
494 value = nil 492 value = nil
495 } else { 493 } else {
496 _, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field)) 494 if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
497 if fok { 495 value = vu
498 value = fvalues[0]
499 } else { 496 } else {
500 value = nil 497 value = nil
501 } 498 }
...@@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i ...@@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i
560 } 557 }
561 558
562 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { 559 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
563 pkNames, pkValues, ok := d.existPk(mi, ind) 560 pkColumn, pkValue, ok := d.existPk(mi, ind)
564 if ok == false { 561 if ok == false {
565 return ErrMissPK 562 return ErrMissPK
566 } 563 }
567 564
568 pkColumns := strings.Join(pkNames, "` = ? AND `")
569
570 sels := strings.Join(mi.fields.dbcols, "`, `") 565 sels := strings.Join(mi.fields.dbcols, "`, `")
571 colsNum := len(mi.fields.dbcols) 566 colsNum := len(mi.fields.dbcols)
572 567
573 query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns) 568 query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn)
574 569
575 refs := make([]interface{}, colsNum) 570 refs := make([]interface{}, colsNum)
576 for i, _ := range refs { 571 for i, _ := range refs {
...@@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { ...@@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
578 refs[i] = &ref 573 refs[i] = &ref
579 } 574 }
580 575
581 row := q.QueryRow(query, pkValues...) 576 row := q.QueryRow(query, pkValue)
582 if err := row.Scan(refs...); err != nil { 577 if err := row.Scan(refs...); err != nil {
578 if err == sql.ErrNoRows {
579 return ErrNoRows
580 }
583 return err 581 return err
584 } else { 582 } else {
585 elm := reflect.New(mi.addrField.Elem().Type()) 583 elm := reflect.New(mi.addrField.Elem().Type())
...@@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
618 } 616 }
619 617
620 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { 618 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
621 pkNames, pkValues, ok := d.existPk(mi, ind) 619 pkName, pkValue, ok := d.existPk(mi, ind)
622 if ok == false { 620 if ok == false {
623 return 0, ErrMissPK 621 return 0, ErrMissPK
624 } 622 }
...@@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
627 return 0, err 625 return 0, err
628 } 626 }
629 627
630 pkColumns := strings.Join(pkNames, "` = ? AND `")
631 setColumns := strings.Join(setNames, "` = ?, `") 628 setColumns := strings.Join(setNames, "` = ?, `")
632 629
633 query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns) 630 query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName)
634 631
635 setValues = append(setValues, pkValues...) 632 setValues = append(setValues, pkValue)
636 633
637 if res, err := q.Exec(query, setValues...); err == nil { 634 if res, err := q.Exec(query, setValues...); err == nil {
638 return res.RowsAffected() 635 return res.RowsAffected()
...@@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
643 } 640 }
644 641
645 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { 642 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
646 names, values, ok := d.existPk(mi, ind) 643 pkName, pkValue, ok := d.existPk(mi, ind)
647 if ok == false { 644 if ok == false {
648 return 0, ErrMissPK 645 return 0, ErrMissPK
649 } 646 }
650 647
651 columns := strings.Join(names, "` = ? AND `") 648 query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName)
652 649
653 query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns) 650 if res, err := q.Exec(query, pkValue); err == nil {
654
655 if res, err := q.Exec(query, values...); err == nil {
656 651
657 num, err := res.RowsAffected() 652 num, err := res.RowsAffected()
658 if err != nil { 653 if err != nil {
...@@ -660,15 +655,13 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -660,15 +655,13 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
660 } 655 }
661 656
662 if num > 0 { 657 if num > 0 {
663 if mi.fields.auto != nil { 658 if mi.fields.pk.auto {
664 ind.Field(mi.fields.auto.fieldIndex).SetInt(0) 659 ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
665 } 660 }
666 661
667 if len(names) == 1 { 662 err := d.deleteRels(q, mi, []interface{}{pkValue})
668 err := d.deleteRels(q, mi, values) 663 if err != nil {
669 if err != nil { 664 return num, err
670 return num, err
671 }
672 } 665 }
673 } 666 }
674 667
...@@ -683,12 +676,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -683,12 +676,12 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
683 columns := make([]string, 0, len(params)) 676 columns := make([]string, 0, len(params))
684 values := make([]interface{}, 0, len(params)) 677 values := make([]interface{}, 0, len(params))
685 for col, val := range params { 678 for col, val := range params {
686 column := snakeString(col) 679 if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
687 if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false { 680 panic(fmt.Sprintf("wrong field/column name `%s`", col))
688 panic(fmt.Sprintf("wrong field/column name `%s`", column)) 681 } else {
682 columns = append(columns, fi.column)
683 values = append(values, val)
689 } 684 }
690 columns = append(columns, column)
691 values = append(values, val)
692 } 685 }
693 686
694 if len(columns) == 0 { 687 if len(columns) == 0 {
...@@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro ...@@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
721 fi = fi.reverseFieldInfo 714 fi = fi.reverseFieldInfo
722 switch fi.onDelete { 715 switch fi.onDelete {
723 case od_CASCADE: 716 case od_CASCADE:
724 cond := NewCondition() 717 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
725 cond.And(fmt.Sprintf("%s__in", fi.name), args...)
726 _, err := d.DeleteBatch(q, nil, fi.mi, cond) 718 _, err := d.DeleteBatch(q, nil, fi.mi, cond)
727 if err != nil { 719 if err != nil {
728 return err 720 return err
729 } 721 }
730 case od_SET_DEFAULT, od_SET_NULL: 722 case od_SET_DEFAULT, od_SET_NULL:
731 cond := NewCondition() 723 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
732 cond.And(fmt.Sprintf("%s__in", fi.name), args...)
733 params := Params{fi.column: nil} 724 params := Params{fi.column: nil}
734 if fi.onDelete == od_SET_DEFAULT { 725 if fi.onDelete == od_SET_DEFAULT {
735 params[fi.column] = fi.initial.String() 726 params[fi.column] = fi.initial.String()
...@@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
757 where, args := tables.getCondSql(cond, false) 748 where, args := tables.getCondSql(cond, false)
758 join := tables.getJoinSql() 749 join := tables.getJoinSql()
759 750
760 colsNum := len(mi.fields.pk) 751 cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column)
761 cols := make([]string, colsNum) 752 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where)
762 for i, fi := range mi.fields.pk {
763 cols[i] = fi.column
764 }
765 colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
766 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
767 753
768 var rs *sql.Rows 754 var rs *sql.Rows
769 if r, err := q.Query(query, args...); err != nil { 755 if r, err := q.Query(query, args...); err != nil {
...@@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
772 rs = r 758 rs = r
773 } 759 }
774 760
775 refs := make([]interface{}, colsNum) 761 var ref interface{}
776 for i, _ := range refs {
777 var ref interface{}
778 refs[i] = &ref
779 }
780 762
781 args = make([]interface{}, 0) 763 args = make([]interface{}, 0)
782 cnt := 0 764 cnt := 0
783 for rs.Next() { 765 for rs.Next() {
784 if err := rs.Scan(refs...); err != nil { 766 if err := rs.Scan(&ref); err != nil {
785 return 0, err 767 return 0, err
786 } 768 }
787 for _, ref := range refs { 769 args = append(args, reflect.ValueOf(ref).Interface())
788 args = append(args, reflect.ValueOf(ref).Elem().Interface())
789 }
790 cnt++ 770 cnt++
791 } 771 }
792 772
...@@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
794 return 0, nil 774 return 0, nil
795 } 775 }
796 776
797 if colsNum > 1 { 777 sql, args := d.ins.GetOperatorSql(mi, "in", args)
798 columns := strings.Join(cols, "` = ? AND `") 778 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql)
799 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
800 } else {
801 var sql string
802 sql, args = d.ins.GetOperatorSql(mi, "in", args)
803 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
804 }
805 779
806 if res, err := q.Exec(query, args...); err == nil { 780 if res, err := q.Exec(query, args...); err == nil {
807 num, err := res.RowsAffected() 781 num, err := res.RowsAffected()
...@@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
809 return 0, err 783 return 0, err
810 } 784 }
811 785
812 if colsNum == 1 && num > 0 { 786 if num > 0 {
813 err := d.deleteRels(q, mi, args) 787 err := d.deleteRels(q, mi, args)
814 if err != nil { 788 if err != nil {
815 return num, err 789 return num, err
...@@ -980,14 +954,12 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface ...@@ -980,14 +954,12 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
980 copy(params, args) 954 copy(params, args)
981 sql := "" 955 sql := ""
982 for i, arg := range args { 956 for i, arg := range args {
983 if len(mi.fields.pk) == 1 { 957 if md, ok := arg.(Modeler); ok {
984 if md, ok := arg.(Modeler); ok { 958 ind := reflect.Indirect(reflect.ValueOf(md))
985 ind := reflect.Indirect(reflect.ValueOf(md)) 959 if _, vu, exist := d.existPk(mi, ind); exist {
986 if _, values, exist := d.existPk(mi, ind); exist { 960 arg = vu
987 arg = values[0] 961 } else {
988 } else { 962 panic(fmt.Sprintf("`%s` need a valid args value", operator))
989 panic(fmt.Sprintf("`%s` need a valid args value", operator))
990 }
991 } 963 }
992 } 964 }
993 params[i] = arg 965 params[i] = arg
...@@ -1175,7 +1147,7 @@ setValue: ...@@ -1175,7 +1147,7 @@ setValue:
1175 value = v 1147 value = v
1176 } 1148 }
1177 case fieldType&IsRelField > 0: 1149 case fieldType&IsRelField > 0:
1178 fieldType = fi.relModelInfo.fields.pk[0].fieldType 1150 fieldType = fi.relModelInfo.fields.pk.fieldType
1179 goto setValue 1151 goto setValue
1180 } 1152 }
1181 1153
...@@ -1236,12 +1208,12 @@ setValue: ...@@ -1236,12 +1208,12 @@ setValue:
1236 } 1208 }
1237 case fieldType&IsRelField > 0: 1209 case fieldType&IsRelField > 0:
1238 if value != nil { 1210 if value != nil {
1239 fieldType = fi.relModelInfo.fields.pk[0].fieldType 1211 fieldType = fi.relModelInfo.fields.pk.fieldType
1240 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) 1212 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
1241 md := mf.Interface().(Modeler) 1213 md := mf.Interface().(Modeler)
1242 md.Init(md) 1214 md.Init(md)
1243 field.Set(mf) 1215 field.Set(mf)
1244 f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex) 1216 f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
1245 field = &f 1217 field = &f
1246 goto setValue 1218 goto setValue
1247 } 1219 }
......
...@@ -9,24 +9,37 @@ import ( ...@@ -9,24 +9,37 @@ import (
9 9
10 const defaultMaxIdle = 30 10 const defaultMaxIdle = 30
11 11
12 type driverType int 12 type DriverType int
13 13
14 const ( 14 const (
15 _ driverType = iota 15 _ DriverType = iota
16 DR_MySQL 16 DR_MySQL
17 DR_Sqlite 17 DR_Sqlite
18 DR_Oracle 18 DR_Oracle
19 DR_Postgres 19 DR_Postgres
20 ) 20 )
21 21
22 type driver string
23
24 func (d driver) Type() DriverType {
25 a, _ := dataBaseCache.get(string(d))
26 return a.Driver
27 }
28
29 func (d driver) Name() string {
30 return string(d)
31 }
32
33 var _ Driver = new(driver)
34
22 var ( 35 var (
23 dataBaseCache = &_dbCache{cache: make(map[string]*alias)} 36 dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
24 drivers = map[string]driverType{ 37 drivers = map[string]DriverType{
25 "mysql": DR_MySQL, 38 "mysql": DR_MySQL,
26 "postgres": DR_Postgres, 39 "postgres": DR_Postgres,
27 "sqlite3": DR_Sqlite, 40 "sqlite3": DR_Sqlite,
28 } 41 }
29 dbBasers = map[driverType]dbBaser{ 42 dbBasers = map[DriverType]dbBaser{
30 DR_MySQL: newdbBaseMysql(), 43 DR_MySQL: newdbBaseMysql(),
31 DR_Sqlite: newdbBaseSqlite(), 44 DR_Sqlite: newdbBaseSqlite(),
32 DR_Oracle: newdbBaseMysql(), 45 DR_Oracle: newdbBaseMysql(),
...@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) { ...@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) {
63 76
64 type alias struct { 77 type alias struct {
65 Name string 78 Name string
79 Driver DriverType
66 DriverName string 80 DriverName string
67 DataSource string 81 DataSource string
68 MaxIdle int 82 MaxIdle int
...@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { ...@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
87 101
88 if dr, ok := drivers[driverName]; ok { 102 if dr, ok := drivers[driverName]; ok {
89 al.DbBaser = dbBasers[dr] 103 al.DbBaser = dbBasers[dr]
104 al.Driver = dr
90 } else { 105 } else {
91 err = fmt.Errorf("driver name `%s` have not registered", driverName) 106 err = fmt.Errorf("driver name `%s` have not registered", driverName)
92 goto end 107 goto end
...@@ -116,7 +131,7 @@ end: ...@@ -116,7 +131,7 @@ end:
116 } 131 }
117 } 132 }
118 133
119 func RegisterDriver(name string, typ driverType) { 134 func RegisterDriver(name string, typ DriverType) {
120 if t, ok := drivers[name]; ok == false { 135 if t, ok := drivers[name]; ok == false {
121 drivers[name] = typ 136 drivers[name] = typ
122 } else { 137 } else {
......
...@@ -49,6 +49,7 @@ type _modelCache struct { ...@@ -49,6 +49,7 @@ type _modelCache struct {
49 sync.RWMutex 49 sync.RWMutex
50 orders []string 50 orders []string
51 cache map[string]*modelInfo 51 cache map[string]*modelInfo
52 done bool
52 } 53 }
53 54
54 func (mc *_modelCache) all() map[string]*modelInfo { 55 func (mc *_modelCache) all() map[string]*modelInfo {
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
8 "strings" 8 "strings"
9 ) 9 )
10 10
11 func RegisterModel(model Modeler) { 11 func registerModel(model Modeler) {
12 info := newModelInfo(model) 12 info := newModelInfo(model)
13 model.Init(model) 13 model.Init(model)
14 table := model.GetTableName() 14 table := model.GetTableName()
...@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) { ...@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) {
27 modelCache.set(table, info) 27 modelCache.set(table, info)
28 } 28 }
29 29
30 func BootStrap() { 30 func bootStrap() {
31 modelCache.Lock() 31 if modelCache.done {
32 defer modelCache.Unlock() 32 return
33 }
33 34
34 var ( 35 var (
35 err error 36 err error
...@@ -59,14 +60,6 @@ func BootStrap() { ...@@ -59,14 +60,6 @@ func BootStrap() {
59 } 60 }
60 fi.relModelInfo = mii 61 fi.relModelInfo = mii
61 62
62 if fi.rel {
63
64 if mii.fields.pk.IsMulti() {
65 err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
66 goto end
67 }
68 }
69
70 switch fi.fieldType { 63 switch fi.fieldType {
71 case RelManyToMany: 64 case RelManyToMany:
72 if fi.relThrough != "" { 65 if fi.relThrough != "" {
...@@ -207,6 +200,25 @@ end: ...@@ -207,6 +200,25 @@ end:
207 fmt.Println(err) 200 fmt.Println(err)
208 os.Exit(2) 201 os.Exit(2)
209 } 202 }
203 }
210 204
211 runCommand() 205 func RegisterModel(models ...Modeler) {
206 if modelCache.done {
207 panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
208 }
209
210 for _, model := range models {
211 registerModel(model)
212 }
213 }
214
215 func BootStrap() {
216 if modelCache.done {
217 return
218 }
219
220 modelCache.Lock()
221 defer modelCache.Unlock()
222 bootStrap()
223 modelCache.done = true
212 } 224 }
......
...@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices { ...@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices {
32 return *f 32 return *f
33 } 33 }
34 34
35 type primaryKeys []*fieldInfo
36
37 func (p *primaryKeys) Add(fi *fieldInfo) {
38 *p = append(*p, fi)
39 }
40
41 func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
42 for i, v := range p {
43 if v == fi {
44 return i, true
45 }
46 }
47 return -1, false
48 }
49
50 func (p primaryKeys) IsMulti() bool {
51 return len(p) > 1
52 }
53
54 func (p primaryKeys) IsEmpty() bool {
55 return len(p) == 0
56 }
57
58 type fields struct { 35 type fields struct {
59 pk primaryKeys 36 pk *fieldInfo
60 auto *fieldInfo
61 columns map[string]*fieldInfo 37 columns map[string]*fieldInfo
62 fields map[string]*fieldInfo 38 fields map[string]*fieldInfo
63 fieldsLow map[string]*fieldInfo 39 fieldsLow map[string]*fieldInfo
......
...@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) { ...@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) {
50 if err != nil { 50 if err != nil {
51 break 51 break
52 } 52 }
53
53 added := info.fields.Add(fi) 54 added := info.fields.Add(fi)
54 if added == false { 55 if added == false {
55 err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) 56 err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
56 break 57 break
57 } 58 }
59
58 if fi.pk { 60 if fi.pk {
59 if info.fields.pk != nil { 61 if info.fields.pk != nil {
60 err = errors.New(fmt.Sprintf("one model must have one pk field only")) 62 err = errors.New(fmt.Sprintf("one model must have one pk field only"))
61 break 63 break
62 } else { 64 } else {
63 info.fields.pk.Add(fi) 65 info.fields.pk = fi
64 } 66 }
65 } 67 }
66 if fi.auto { 68
67 info.fields.auto = fi
68 }
69 fi.fieldIndex = i 69 fi.fieldIndex = i
70 fi.mi = info 70 fi.mi = info
71 } 71 }
72 72
73 if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false {
74 err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field"))
75 goto end
76 }
77
78 if err != nil { 73 if err != nil {
79 fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) 74 fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
80 os.Exit(2) 75 os.Exit(2)
81 } 76 }
82 77
83 end:
84 if err != nil {
85 fmt.Println(err)
86 os.Exit(2)
87 }
88 return 78 return
89 } 79 }
90 80
...@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { ...@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
125 info.fields.Add(fa) 115 info.fields.Add(fa)
126 info.fields.Add(f1) 116 info.fields.Add(f1)
127 info.fields.Add(f2) 117 info.fields.Add(f2)
128 info.fields.pk.Add(fa) 118 info.fields.pk = fa
129 return 119 return
130 } 120 }
......
1 package orm
2
3 import (
4 "fmt"
5 "os"
6 "time"
7
8 _ "github.com/bmizerany/pq"
9 _ "github.com/go-sql-driver/mysql"
10 _ "github.com/mattn/go-sqlite3"
11 )
12
13 type User struct {
14 Id int `orm:"auto"`
15 UserName string `orm:"size(30);unique"`
16 Email string `orm:"size(100)"`
17 Password string `orm:"size(100)"`
18 Status int16 `orm:"choices(0,1,2,3);defalut(0)"`
19 IsStaff bool `orm:"default(false)"`
20 IsActive bool `orm:"default(1)"`
21 Created time.Time `orm:"auto_now_add;type(date)"`
22 Updated time.Time `orm:"auto_now"`
23 Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
24 Posts []*Post `orm:"reverse(many)" json:"-"`
25 Manager `json:"-"`
26 }
27
28 func NewUser() *User {
29 obj := new(User)
30 obj.Manager.Init(obj)
31 return obj
32 }
33
34 type Profile struct {
35 Id int `orm:"auto"`
36 Age int16 ``
37 Money float64 ``
38 User *User `orm:"reverse(one)" json:"-"`
39 Manager `json:"-"`
40 }
41
42 func (u *Profile) TableName() string {
43 return "user_profile"
44 }
45
46 func NewProfile() *Profile {
47 obj := new(Profile)
48 obj.Manager.Init(obj)
49 return obj
50 }
51
52 type Post struct {
53 Id int `orm:"auto"`
54 User *User `orm:"rel(fk)"` //
55 Title string `orm:"size(60)"`
56 Content string ``
57 Created time.Time `orm:"auto_now_add"`
58 Updated time.Time `orm:"auto_now"`
59 Tags []*Tag `orm:"rel(m2m)"`
60 Manager `json:"-"`
61 }
62
63 func NewPost() *Post {
64 obj := new(Post)
65 obj.Manager.Init(obj)
66 return obj
67 }
68
69 type Tag struct {
70 Id int `orm:"auto"`
71 Name string `orm:"size(30)"`
72 Posts []*Post `orm:"reverse(many)" json:"-"`
73 Manager `json:"-"`
74 }
75
76 func NewTag() *Tag {
77 obj := new(Tag)
78 obj.Manager.Init(obj)
79 return obj
80 }
81
82 type Comment struct {
83 Id int `orm:"auto"`
84 Post *Post `orm:"rel(fk)"`
85 Content string ``
86 Parent *Comment `orm:"null;rel(fk)"`
87 Created time.Time `orm:"auto_now_add"`
88 Manager `json:"-"`
89 }
90
91 func NewComment() *Comment {
92 obj := new(Comment)
93 obj.Manager.Init(obj)
94 return obj
95 }
96
97 var DBARGS = struct {
98 Driver string
99 Source string
100 }{
101 os.Getenv("ORM_DRIVER"),
102 os.Getenv("ORM_SOURCE"),
103 }
104
105 var dORM Ormer
106
107 func init() {
108 RegisterModel(new(User))
109 RegisterModel(new(Profile))
110 RegisterModel(new(Post))
111 RegisterModel(new(Tag))
112 RegisterModel(new(Comment))
113
114 if DBARGS.Driver == "" || DBARGS.Source == "" {
115 fmt.Println(`need driver and source!
116
117 Default DB Drivers.
118
119 driver: url
120 mysql: https://github.com/go-sql-driver/mysql
121 sqlite3: https://github.com/mattn/go-sqlite3
122 postgres: https://github.com/bmizerany/pq
123
124 eg: mysql
125 ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
126 `)
127 os.Exit(2)
128 }
129
130 RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
131
132 BootStrap()
133
134 truncateTables()
135
136 dORM = NewOrm()
137 }
138
139 func truncateTables() {
140 logs := "truncate tables for test\n"
141 o := NewOrm()
142 for _, m := range modelCache.allOrdered() {
143 query := fmt.Sprintf("truncate table `%s`", m.table)
144 _, err := o.Raw(query).Exec()
145 logs += query + "\n"
146 if err != nil {
147 fmt.Println(logs)
148 fmt.Println(err)
149 os.Exit(2)
150 }
151 }
152 }
...@@ -9,13 +9,15 @@ import ( ...@@ -9,13 +9,15 @@ import (
9 ) 9 )
10 10
11 var ( 11 var (
12 ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
13 ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
14 ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
15 ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
16 DefaultRowsLimit = 1000 12 DefaultRowsLimit = 1000
17 DefaultRelsDepth = 5 13 DefaultRelsDepth = 5
18 DefaultTimeLoc = time.Local 14 DefaultTimeLoc = time.Local
15 ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
16 ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
17 ErrMultiRows = errors.New("<QuerySeter> return multi rows")
18 ErrNoRows = errors.New("<QuerySeter> not row found")
19 ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
20 ErrNotImplement = errors.New("have not implement")
19 ) 21 )
20 22
21 type Params map[string]interface{} 23 type Params map[string]interface{}
...@@ -27,13 +29,15 @@ type orm struct { ...@@ -27,13 +29,15 @@ type orm struct {
27 isTx bool 29 isTx bool
28 } 30 }
29 31
32 var _ Ormer = new(orm)
33
30 func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) { 34 func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) {
31 md.Init(md, true) 35 md.Init(md, true)
32 name := md.GetTableName() 36 name := md.GetTableName()
33 if mi, ok := modelCache.get(name); ok { 37 if mi, ok := modelCache.get(name); ok {
34 return mi, reflect.Indirect(reflect.ValueOf(md)) 38 return mi, reflect.Indirect(reflect.ValueOf(md))
35 } 39 }
36 panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name)) 40 panic(fmt.Sprintf("<orm> table name: `%s` not exists", name))
37 } 41 }
38 42
39 func (o *orm) Read(md Modeler) error { 43 func (o *orm) Read(md Modeler) error {
...@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) { ...@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) {
52 return id, err 56 return id, err
53 } 57 }
54 if id > 0 { 58 if id > 0 {
55 if mi.fields.auto != nil { 59 if mi.fields.pk.auto {
56 ind.Field(mi.fields.auto.fieldIndex).SetInt(id) 60 ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
57 } 61 }
58 } 62 }
59 return id, nil 63 return id, nil
...@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) { ...@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) {
75 return num, err 79 return num, err
76 } 80 }
77 if num > 0 { 81 if num > 0 {
78 if mi.fields.auto != nil { 82 if mi.fields.pk.auto {
79 ind.Field(mi.fields.auto.fieldIndex).SetInt(0) 83 ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
80 } 84 }
81 } 85 }
82 return num, nil 86 return num, nil
83 } 87 }
84 88
89 func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) {
90 // TODO
91 panic(ErrNotImplement)
92 return 0, nil
93 }
94
95 func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) {
96 // TODO
97 panic(ErrNotImplement)
98 return 0, nil
99 }
100
101 func (o *orm) LoadRel(md Modeler, name string) (int64, error) {
102 // TODO
103 panic(ErrNotImplement)
104 return 0, nil
105 }
106
85 func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { 107 func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
86 name := "" 108 name := ""
87 if table, ok := ptrStructOrTableName.(string); ok { 109 if table, ok := ptrStructOrTableName.(string); ok {
...@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error { ...@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error {
111 133
112 func (o *orm) Begin() error { 134 func (o *orm) Begin() error {
113 if o.isTx { 135 if o.isTx {
114 return ErrTXHasBegin 136 return ErrTxHasBegan
115 } 137 }
116 tx, err := o.alias.DB.Begin() 138 tx, err := o.alias.DB.Begin()
117 if err != nil { 139 if err != nil {
...@@ -124,24 +146,28 @@ func (o *orm) Begin() error { ...@@ -124,24 +146,28 @@ func (o *orm) Begin() error {
124 146
125 func (o *orm) Commit() error { 147 func (o *orm) Commit() error {
126 if o.isTx == false { 148 if o.isTx == false {
127 return ErrTXNotBegin 149 return ErrTxDone
128 } 150 }
129 err := o.db.(*sql.Tx).Commit() 151 err := o.db.(*sql.Tx).Commit()
130 if err == nil { 152 if err == nil {
131 o.isTx = false 153 o.isTx = false
132 o.db = o.alias.DB 154 o.db = o.alias.DB
155 } else if err == sql.ErrTxDone {
156 return ErrTxDone
133 } 157 }
134 return err 158 return err
135 } 159 }
136 160
137 func (o *orm) Rollback() error { 161 func (o *orm) Rollback() error {
138 if o.isTx == false { 162 if o.isTx == false {
139 return ErrTXNotBegin 163 return ErrTxDone
140 } 164 }
141 err := o.db.(*sql.Tx).Rollback() 165 err := o.db.(*sql.Tx).Rollback()
142 if err == nil { 166 if err == nil {
143 o.isTx = false 167 o.isTx = false
144 o.db = o.alias.DB 168 o.db = o.alias.DB
169 } else if err == sql.ErrTxDone {
170 return ErrTxDone
145 } 171 }
146 return err 172 return err
147 } 173 }
...@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter { ...@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter {
150 return newRawSet(o, query, args) 176 return newRawSet(o, query, args)
151 } 177 }
152 178
179 func (o *orm) Driver() Driver {
180 return driver(o.alias.Name)
181 }
182
153 func NewOrm() Ormer { 183 func NewOrm() Ormer {
184 BootStrap() // execute only once
185
154 o := new(orm) 186 o := new(orm)
155 err := o.Using("default") 187 err := o.Using("default")
156 if err != nil { 188 if err != nil {
......
...@@ -26,23 +26,24 @@ func NewCondition() *Condition { ...@@ -26,23 +26,24 @@ func NewCondition() *Condition {
26 return c 26 return c
27 } 27 }
28 28
29 func (c *Condition) And(expr string, args ...interface{}) *Condition { 29 func (c Condition) And(expr string, args ...interface{}) *Condition {
30 if expr == "" || len(args) == 0 { 30 if expr == "" || len(args) == 0 {
31 panic("<Condition.And> args cannot empty") 31 panic("<Condition.And> args cannot empty")
32 } 32 }
33 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) 33 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
34 return c 34 return &c
35 } 35 }
36 36
37 func (c *Condition) AndNot(expr string, args ...interface{}) *Condition { 37 func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
38 if expr == "" || len(args) == 0 { 38 if expr == "" || len(args) == 0 {
39 panic("<Condition.AndNot> args cannot empty") 39 panic("<Condition.AndNot> args cannot empty")
40 } 40 }
41 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) 41 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
42 return c 42 return &c
43 } 43 }
44 44
45 func (c *Condition) AndCond(cond *Condition) *Condition { 45 func (c *Condition) AndCond(cond *Condition) *Condition {
46 c = c.clone()
46 if c == cond { 47 if c == cond {
47 panic("cannot use self as sub cond") 48 panic("cannot use self as sub cond")
48 } 49 }
...@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition { ...@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
52 return c 53 return c
53 } 54 }
54 55
55 func (c *Condition) Or(expr string, args ...interface{}) *Condition { 56 func (c Condition) Or(expr string, args ...interface{}) *Condition {
56 if expr == "" || len(args) == 0 { 57 if expr == "" || len(args) == 0 {
57 panic("<Condition.Or> args cannot empty") 58 panic("<Condition.Or> args cannot empty")
58 } 59 }
59 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) 60 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
60 return c 61 return &c
61 } 62 }
62 63
63 func (c *Condition) OrNot(expr string, args ...interface{}) *Condition { 64 func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
64 if expr == "" || len(args) == 0 { 65 if expr == "" || len(args) == 0 {
65 panic("<Condition.OrNot> args cannot empty") 66 panic("<Condition.OrNot> args cannot empty")
66 } 67 }
67 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) 68 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
68 return c 69 return &c
69 } 70 }
70 71
71 func (c *Condition) OrCond(cond *Condition) *Condition { 72 func (c *Condition) OrCond(cond *Condition) *Condition {
73 c = c.clone()
72 if c == cond { 74 if c == cond {
73 panic("cannot use self as sub cond") 75 panic("cannot use self as sub cond")
74 } 76 }
...@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool { ...@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool {
82 return len(c.params) == 0 84 return len(c.params) == 0
83 } 85 }
84 86
85 func (c Condition) Clone() *Condition { 87 func (c Condition) clone() *Condition {
86 params := c.params
87 c.params = make([]condValue, len(params))
88 copy(c.params, params)
89 return &c 88 return &c
90 } 89 }
91
92 func (c *Condition) Merge() (expr string, args []interface{}) {
93 return expr, args
94 }
......
...@@ -13,6 +13,8 @@ type insertSet struct { ...@@ -13,6 +13,8 @@ type insertSet struct {
13 closed bool 13 closed bool
14 } 14 }
15 15
16 var _ Inserter = new(insertSet)
17
16 func (o *insertSet) Insert(md Modeler) (int64, error) { 18 func (o *insertSet) Insert(md Modeler) (int64, error) {
17 if o.closed { 19 if o.closed {
18 return 0, ErrStmtClosed 20 return 0, ErrStmtClosed
...@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) { ...@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) {
28 return id, err 30 return id, err
29 } 31 }
30 if id > 0 { 32 if id > 0 {
31 if o.mi.fields.auto != nil { 33 if o.mi.fields.pk.auto {
32 ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id) 34 ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
33 } 35 }
34 } 36 }
35 return id, nil 37 return id, nil
36 } 38 }
37 39
38 func (o *insertSet) Close() error { 40 func (o *insertSet) Close() error {
41 if o.closed {
42 return ErrStmtClosed
43 }
39 o.closed = true 44 o.closed = true
40 return o.stmt.Close() 45 return o.stmt.Close()
41 } 46 }
......
...@@ -15,47 +15,43 @@ type querySet struct { ...@@ -15,47 +15,43 @@ type querySet struct {
15 orm *orm 15 orm *orm
16 } 16 }
17 17
18 func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter { 18 var _ QuerySeter = new(querySet)
19 o = o.clone() 19
20 func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
20 if o.cond == nil { 21 if o.cond == nil {
21 o.cond = NewCondition() 22 o.cond = NewCondition()
22 } 23 }
23 o.cond.And(expr, args...) 24 o.cond = o.cond.And(expr, args...)
24 return o 25 return &o
25 } 26 }
26 27
27 func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter { 28 func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
28 o = o.clone()
29 if o.cond == nil { 29 if o.cond == nil {
30 o.cond = NewCondition() 30 o.cond = NewCondition()
31 } 31 }
32 o.cond.AndNot(expr, args...) 32 o.cond = o.cond.AndNot(expr, args...)
33 return o 33 return &o
34 } 34 }
35 35
36 func (o *querySet) Limit(limit int, args ...int64) QuerySeter { 36 func (o querySet) Limit(limit int, args ...int64) QuerySeter {
37 o = o.clone()
38 o.limit = limit 37 o.limit = limit
39 if len(args) > 0 { 38 if len(args) > 0 {
40 o.offset = args[0] 39 o.offset = args[0]
41 } 40 }
42 return o 41 return &o
43 } 42 }
44 43
45 func (o *querySet) Offset(offset int64) QuerySeter { 44 func (o querySet) Offset(offset int64) QuerySeter {
46 o = o.clone()
47 o.offset = offset 45 o.offset = offset
48 return o 46 return &o
49 } 47 }
50 48
51 func (o *querySet) OrderBy(exprs ...string) QuerySeter { 49 func (o querySet) OrderBy(exprs ...string) QuerySeter {
52 o = o.clone()
53 o.orders = exprs 50 o.orders = exprs
54 return o 51 return &o
55 } 52 }
56 53
57 func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { 54 func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
58 o = o.clone()
59 var related []string 55 var related []string
60 if len(params) == 0 { 56 if len(params) == 0 {
61 o.relDepth = DefaultRelsDepth 57 o.relDepth = DefaultRelsDepth
...@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { ...@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
72 } 68 }
73 } 69 }
74 o.related = related 70 o.related = related
75 return o
76 }
77
78 func (o querySet) clone() *querySet {
79 if o.cond != nil {
80 o.cond = o.cond.Clone()
81 }
82 return &o 71 return &o
83 } 72 }
84 73
...@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error { ...@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error {
115 if num > 1 { 104 if num > 1 {
116 return ErrMultiRows 105 return ErrMultiRows
117 } 106 }
107 if num == 0 {
108 return ErrNoRows
109 }
118 return nil 110 return nil
119 } 111 }
120 112
......
...@@ -63,6 +63,8 @@ type rawSet struct { ...@@ -63,6 +63,8 @@ type rawSet struct {
63 orm *orm 63 orm *orm
64 } 64 }
65 65
66 var _ RawSeter = new(rawSet)
67
66 func (o rawSet) SetArgs(args ...interface{}) RawSeter { 68 func (o rawSet) SetArgs(args ...interface{}) RawSeter {
67 o.args = args 69 o.args = args
68 return &o 70 return &o
...@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) { ...@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) {
76 return getResult(res) 78 return getResult(res)
77 } 79 }
78 80
79 func (o *rawSet) Mapper(...interface{}) (int64, error) { 81 func (o *rawSet) QueryRow(...interface{}) error {
82 //TODO
83 return nil
84 }
85
86 func (o *rawSet) QueryRows(...interface{}) (int64, error) {
80 //TODO 87 //TODO
81 return 0, nil 88 return 0, nil
82 } 89 }
...@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
120 cols = columns 127 cols = columns
121 refs = make([]interface{}, len(cols)) 128 refs = make([]interface{}, len(cols))
122 for i, _ := range refs { 129 for i, _ := range refs {
123 var ref string 130 var ref sql.NullString
124 refs[i] = &ref 131 refs[i] = &ref
125 } 132 }
126 } 133 }
...@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
134 case 1: 141 case 1:
135 params := make(Params, len(cols)) 142 params := make(Params, len(cols))
136 for i, ref := range refs { 143 for i, ref := range refs {
137 value := reflect.Indirect(reflect.ValueOf(ref)).Interface() 144 value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
138 params[cols[i]] = value 145 params[cols[i]] = value.String
139 } 146 }
140 maps = append(maps, params) 147 maps = append(maps, params)
141 case 2: 148 case 2:
142 params := make(ParamsList, 0, len(cols)) 149 params := make(ParamsList, 0, len(cols))
143 for _, ref := range refs { 150 for _, ref := range refs {
144 value := reflect.Indirect(reflect.ValueOf(ref)).Interface() 151 value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
145 params = append(params, value) 152 params = append(params, value.String)
146 } 153 }
147 lists = append(lists, params) 154 lists = append(lists, params)
148 case 3: 155 case 3:
149 for _, ref := range refs { 156 for _, ref := range refs {
150 value := reflect.Indirect(reflect.ValueOf(ref)).Interface() 157 value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
151 list = append(list, value) 158 list = append(list, value.String)
152 } 159 }
153 } 160 }
154 161
......
1 package orm
2
3 import (
4 "bytes"
5 "fmt"
6 "io/ioutil"
7 "path/filepath"
8 "reflect"
9 "runtime"
10 "strings"
11 "testing"
12 "time"
13 )
14
15 type T_Code int
16
17 const (
18 // =
19 T_Equal T_Code = iota
20 // <
21 T_Less
22 // >
23 T_Large
24 // elment in slice/array
25 // T_In
26 // key exists in map
27 // T_KeyExist
28 // index != -1
29 // T_Contain
30 // index == 0
31 // T_StartWith
32 // index == len(x) - 1
33 // T_EndWith
34 )
35
36 func ValuesCompare(is bool, a interface{}, o T_Code, args ...interface{}) (err error, ok bool) {
37 if len(args) == 0 {
38 return fmt.Errorf("miss args"), false
39 }
40 b := args[0]
41 arg := argAny(args)
42 switch o {
43 case T_Equal:
44 switch v := a.(type) {
45 case reflect.Kind:
46 ok = reflect.ValueOf(b).Kind() == v
47 case time.Time:
48 if v2, vo := b.(time.Time); vo {
49 if arg.Get(1) != nil {
50 format := ToStr(arg.Get(1))
51 ok = v.Format(format) == v2.Format(format)
52 } else {
53 err = fmt.Errorf("compare datetime miss format")
54 goto wrongArg
55 }
56 }
57 default:
58 ok = ToStr(a) == ToStr(b)
59 }
60 ok = is && ok || !is && !ok
61 if !ok {
62 if is {
63 err = fmt.Errorf("should: a == b, a = `%v`, b = `%v`", a, b)
64 } else {
65 err = fmt.Errorf("should: a != b, a = `%v`, b = `%v`", a, b)
66 }
67 }
68 case T_Less, T_Large:
69 as := ToStr(a)
70 bs := ToStr(b)
71 f1, er := StrTo(as).Float64()
72 if er != nil {
73 err = fmt.Errorf("wrong type need numeric: `%v`", a)
74 goto wrongArg
75 }
76 f2, er := StrTo(bs).Float64()
77 if er != nil {
78 err = fmt.Errorf("wrong type need numeric: `%v`", b)
79 goto wrongArg
80 }
81 var opts []string
82 if o == T_Less {
83 opts = []string{"<", ">="}
84 ok = f1 < f2
85 } else {
86 opts = []string{">", "<="}
87 ok = f1 > f2
88 }
89 ok = is && ok || !is && !ok
90 if !ok {
91 if is {
92 err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[0], f1, f2)
93 } else {
94 err = fmt.Errorf("should: a %s b, a = `%v`, b = `%v`", opts[1], f1, f2)
95 }
96 }
97 }
98 wrongArg:
99 if err != nil {
100 return err, false
101 }
102
103 return nil, true
104 }
105
106 func AssertIs(a interface{}, o T_Code, args ...interface{}) error {
107 if err, ok := ValuesCompare(true, a, o, args...); ok == false {
108 return err
109 }
110 return nil
111 }
112
113 func AssertNot(a interface{}, o T_Code, args ...interface{}) error {
114 if err, ok := ValuesCompare(false, a, o, args...); ok == false {
115 return err
116 }
117 return nil
118 }
119
120 func getCaller(skip int) string {
121 pc, file, line, _ := runtime.Caller(skip)
122 fun := runtime.FuncForPC(pc)
123 _, fn := filepath.Split(file)
124 data, err := ioutil.ReadFile(file)
125 code := ""
126 if err == nil {
127 lines := bytes.Split(data, []byte{'\n'})
128 code = strings.TrimSpace(string(lines[line-1]))
129 }
130 funName := fun.Name()
131 if i := strings.LastIndex(funName, "."); i > -1 {
132 funName = funName[i+1:]
133 }
134 return fmt.Sprintf("%s:%d: %s: %s", fn, line, funName, code)
135 }
136
137 func throwFail(t *testing.T, err error, args ...interface{}) {
138 if err != nil {
139 params := []interface{}{"\n", getCaller(2), "\n", err, "\n"}
140 params = append(params, args...)
141 t.Error(params...)
142 t.Fail()
143 }
144 }
145
146 func throwFailNow(t *testing.T, err error, args ...interface{}) {
147 if err != nil {
148 params := []interface{}{"\n", getCaller(2), "\n", err, "\n"}
149 params = append(params, args...)
150 t.Error(params...)
151 t.FailNow()
152 }
153 }
154
155 func TestCRUD(t *testing.T) {
156 profile := NewProfile()
157 profile.Age = 30
158 profile.Money = 1234.12
159 id, err := dORM.Insert(profile)
160 throwFailNow(t, err)
161 throwFailNow(t, AssertIs(id, T_Large, 0))
162
163 user := NewUser()
164 user.UserName = "slene"
165 user.Email = "vslene@gmail.com"
166 user.Password = "pass"
167 user.Status = 3
168 user.IsStaff = true
169 user.IsActive = true
170
171 id, err = dORM.Insert(user)
172 throwFailNow(t, err)
173 throwFailNow(t, AssertIs(id, T_Large, 0))
174
175 u := &User{Id: user.Id}
176 err = dORM.Read(u)
177 throwFailNow(t, err)
178
179 throwFailNow(t, AssertIs(u.UserName, T_Equal, "slene"))
180 throwFailNow(t, AssertIs(u.Email, T_Equal, "vslene@gmail.com"))
181 throwFailNow(t, AssertIs(u.Password, T_Equal, "pass"))
182 throwFailNow(t, AssertIs(u.Status, T_Equal, 3))
183 throwFailNow(t, AssertIs(u.IsStaff, T_Equal, true))
184 throwFailNow(t, AssertIs(u.IsActive, T_Equal, true))
185 throwFailNow(t, AssertIs(u.Created, T_Equal, user.Created, format_Date))
186 throwFailNow(t, AssertIs(u.Updated, T_Equal, user.Updated, format_DateTime))
187
188 user.UserName = "astaxie"
189 user.Profile = profile
190 num, err := dORM.Update(user)
191 throwFailNow(t, err)
192 throwFailNow(t, AssertIs(num, T_Equal, 1))
193
194 u = &User{Id: user.Id}
195 err = dORM.Read(u)
196 throwFailNow(t, err)
197
198 throwFailNow(t, AssertIs(u.UserName, T_Equal, "astaxie"))
199 throwFailNow(t, AssertIs(u.Profile.Id, T_Equal, profile.Id))
200
201 num, err = dORM.Delete(profile)
202 throwFailNow(t, err)
203 throwFailNow(t, AssertIs(num, T_Equal, 1))
204
205 u = &User{Id: user.Id}
206 err = dORM.Read(u)
207 throwFailNow(t, err)
208 throwFailNow(t, AssertIs(true, T_Equal, u.Profile == nil))
209
210 num, err = dORM.Delete(user)
211 throwFailNow(t, err)
212 throwFailNow(t, AssertIs(num, T_Equal, 1))
213
214 u = &User{Id: 100}
215 err = dORM.Read(u)
216 throwFailNow(t, AssertIs(err, T_Equal, ErrNoRows))
217 }
218
219 func TestInsertTestData(t *testing.T) {
220 var users []*User
221
222 profile := NewProfile()
223 profile.Age = 28
224 profile.Money = 1234.12
225
226 id, err := dORM.Insert(profile)
227 throwFailNow(t, err)
228 throwFailNow(t, AssertIs(id, T_Large, 0))
229
230 user := NewUser()
231 user.UserName = "slene"
232 user.Email = "vslene@gmail.com"
233 user.Password = "pass"
234 user.Status = 1
235 user.IsStaff = false
236 user.IsActive = true
237 user.Profile = profile
238
239 users = append(users, user)
240
241 id, err = dORM.Insert(user)
242 throwFailNow(t, err)
243 throwFailNow(t, AssertIs(id, T_Large, 0))
244
245 profile = NewProfile()
246 profile.Age = 30
247 profile.Money = 4321.09
248
249 id, err = dORM.Insert(profile)
250 throwFailNow(t, err)
251 throwFailNow(t, AssertIs(id, T_Large, 0))
252
253 user = NewUser()
254 user.UserName = "astaxie"
255 user.Email = "astaxie@gmail.com"
256 user.Password = "password"
257 user.Status = 2
258 user.IsStaff = true
259 user.IsActive = false
260 user.Profile = profile
261
262 users = append(users, user)
263
264 id, err = dORM.Insert(user)
265 throwFailNow(t, err)
266 throwFailNow(t, AssertIs(id, T_Large, 0))
267
268 user = NewUser()
269 user.UserName = "nobody"
270 user.Email = "nobody@gmail.com"
271 user.Password = "nobody"
272 user.Status = 3
273 user.IsStaff = false
274 user.IsActive = false
275
276 users = append(users, user)
277
278 id, err = dORM.Insert(user)
279 throwFailNow(t, err)
280 throwFailNow(t, AssertIs(id, T_Large, 0))
281
282 tags := []*Tag{
283 &Tag{Name: "golang"},
284 &Tag{Name: "example"},
285 &Tag{Name: "format"},
286 &Tag{Name: "c++"},
287 }
288
289 posts := []*Post{
290 &Post{User: users[0], Tags: []*Tag{tags[0]}, Title: "Introduction", Content: `Go is a new language. Although it borrows ideas from existing languages, it has unusual properties that make effective Go programs different in character from programs written in its relatives. A straightforward translation of a C++ or Java program into Go is unlikely to produce a satisfactory result—Java programs are written in Java, not Go. On the other hand, thinking about the problem from a Go perspective could produce a successful but quite different program. In other words, to write Go well, it's important to understand its properties and idioms. It's also important to know the established conventions for programming in Go, such as naming, formatting, program construction, and so on, so that programs you write will be easy for other Go programmers to understand.
291 This document gives tips for writing clear, idiomatic Go code. It augments the language specification, the Tour of Go, and How to Write Go Code, all of which you should read first.`},
292 &Post{User: users[1], Tags: []*Tag{tags[0], tags[1]}, Title: "Examples", Content: `The Go package sources are intended to serve not only as the core library but also as examples of how to use the language. Moreover, many of the packages contain working, self-contained executable examples you can run directly from the golang.org web site, such as this one (click on the word "Example" to open it up). If you have a question about how to approach a problem or how something might be implemented, the documentation, code and examples in the library can provide answers, ideas and background.`},
293 &Post{User: users[1], Tags: []*Tag{tags[0], tags[2]}, Title: "Formatting", Content: `Formatting issues are the most contentious but the least consequential. People can adapt to different formatting styles but it's better if they don't have to, and less time is devoted to the topic if everyone adheres to the same style. The problem is how to approach this Utopia without a long prescriptive style guide.
294 With Go we take an unusual approach and let the machine take care of most formatting issues. The gofmt program (also available as go fmt, which operates at the package level rather than source file level) reads a Go program and emits the source in a standard style of indentation and vertical alignment, retaining and if necessary reformatting comments. If you want to know how to handle some new layout situation, run gofmt; if the answer doesn't seem right, rearrange your program (or file a bug about gofmt), don't work around it.`},
295 &Post{User: users[2], Tags: []*Tag{tags[3]}, Title: "Commentary", Content: `Go provides C-style /* */ block comments and C++-style // line comments. Line comments are the norm; block comments appear mostly as package comments, but are useful within an expression or to disable large swaths of code.
296 The program—and web server—godoc processes Go source files to extract documentation about the contents of the package. Comments that appear before top-level declarations, with no intervening newlines, are extracted along with the declaration to serve as explanatory text for the item. The nature and style of these comments determines the quality of the documentation godoc produces.`},
297 }
298
299 comments := []*Comment{
300 &Comment{Post: posts[0], Content: "a comment"},
301 &Comment{Post: posts[1], Content: "yes"},
302 &Comment{Post: posts[1]},
303 &Comment{Post: posts[1]},
304 &Comment{Post: posts[2]},
305 &Comment{Post: posts[2]},
306 }
307
308 for _, tag := range tags {
309 id, err := dORM.Insert(tag)
310 throwFailNow(t, err)
311 throwFailNow(t, AssertIs(id, T_Large, 0))
312 }
313
314 for _, post := range posts {
315 id, err := dORM.Insert(post)
316 throwFailNow(t, err)
317 throwFailNow(t, AssertIs(id, T_Large, 0))
318 // dORM.M2mAdd(post, "tags", post.Tags)
319 }
320
321 for _, comment := range comments {
322 id, err := dORM.Insert(comment)
323 throwFailNow(t, err)
324 throwFailNow(t, AssertIs(id, T_Large, 0))
325 }
326 }
327
328 func TestExpr(t *testing.T) {
329 qs := dORM.QueryTable("User")
330 qs = dORM.QueryTable("user")
331 num, err := qs.Filter("UserName", "slene").Filter("user_name", "slene").Filter("profile__Age", 28).Count()
332 throwFail(t, err)
333 throwFail(t, AssertIs(num, T_Equal, 1))
334 }
335
336 func TestOperators(t *testing.T) {
337 qs := dORM.QueryTable("user")
338 num, err := qs.Filter("user_name", "slene").Count()
339 throwFail(t, err)
340 throwFail(t, AssertIs(num, T_Equal, 1))
341
342 num, err = qs.Filter("user_name__exact", "slene").Count()
343 throwFail(t, err)
344 throwFail(t, AssertIs(num, T_Equal, 1))
345
346 num, err = qs.Filter("user_name__iexact", "Slene").Count()
347 throwFail(t, err)
348 throwFail(t, AssertIs(num, T_Equal, 1))
349
350 num, err = qs.Filter("user_name__contains", "e").Count()
351 throwFail(t, err)
352 throwFail(t, AssertIs(num, T_Equal, 2))
353
354 num, err = qs.Filter("user_name__contains", "E").Count()
355 throwFail(t, err)
356 throwFail(t, AssertIs(num, T_Equal, 0))
357
358 num, err = qs.Filter("user_name__icontains", "E").Count()
359 throwFail(t, err)
360 throwFail(t, AssertIs(num, T_Equal, 2))
361
362 num, err = qs.Filter("user_name__icontains", "E").Count()
363 throwFail(t, err)
364 throwFail(t, AssertIs(num, T_Equal, 2))
365
366 num, err = qs.Filter("status__gt", 1).Count()
367 throwFail(t, err)
368 throwFail(t, AssertIs(num, T_Equal, 2))
369
370 num, err = qs.Filter("status__gte", 1).Count()
371 throwFail(t, err)
372 throwFail(t, AssertIs(num, T_Equal, 3))
373
374 num, err = qs.Filter("status__lt", 3).Count()
375 throwFail(t, err)
376 throwFail(t, AssertIs(num, T_Equal, 2))
377
378 num, err = qs.Filter("status__lte", 3).Count()
379 throwFail(t, err)
380 throwFail(t, AssertIs(num, T_Equal, 3))
381
382 num, err = qs.Filter("user_name__startswith", "s").Count()
383 throwFail(t, err)
384 throwFail(t, AssertIs(num, T_Equal, 1))
385
386 num, err = qs.Filter("user_name__startswith", "S").Count()
387 throwFail(t, err)
388 throwFail(t, AssertIs(num, T_Equal, 0))
389
390 num, err = qs.Filter("user_name__istartswith", "S").Count()
391 throwFail(t, err)
392 throwFail(t, AssertIs(num, T_Equal, 1))
393
394 num, err = qs.Filter("user_name__endswith", "e").Count()
395 throwFail(t, err)
396 throwFail(t, AssertIs(num, T_Equal, 2))
397
398 num, err = qs.Filter("user_name__endswith", "E").Count()
399 throwFail(t, err)
400 throwFail(t, AssertIs(num, T_Equal, 0))
401
402 num, err = qs.Filter("user_name__iendswith", "E").Count()
403 throwFail(t, err)
404 throwFail(t, AssertIs(num, T_Equal, 2))
405
406 num, err = qs.Filter("profile__isnull", true).Count()
407 throwFail(t, err)
408 throwFail(t, AssertIs(num, T_Equal, 1))
409
410 num, err = qs.Filter("status__in", 1, 2).Count()
411 throwFail(t, err)
412 throwFail(t, AssertIs(num, T_Equal, 2))
413 }
414
415 func TestAll(t *testing.T) {
416 var users []*User
417 qs := dORM.QueryTable("user")
418 num, err := qs.All(&users)
419 throwFail(t, err)
420 throwFail(t, AssertIs(num, T_Equal, 3))
421
422 qs = dORM.QueryTable("user")
423 num, err = qs.Filter("user_name", "nothing").All(&users)
424 throwFail(t, err)
425 throwFail(t, AssertIs(num, T_Equal, 0))
426 }
427
428 func TestOne(t *testing.T) {
429 var user User
430 qs := dORM.QueryTable("user")
431 err := qs.One(&user)
432 throwFail(t, AssertIs(err, T_Equal, ErrMultiRows))
433
434 err = qs.Filter("user_name", "nothing").One(&user)
435 throwFail(t, AssertIs(err, T_Equal, ErrNoRows))
436 }
437
438 func TestValues(t *testing.T) {
439 var maps []Params
440 qs := dORM.QueryTable("user")
441
442 num, err := qs.Values(&maps)
443 throwFail(t, err)
444 throwFail(t, AssertIs(num, T_Equal, 3))
445 if num == 3 {
446 throwFail(t, AssertIs(maps[0]["UserName"], T_Equal, "slene"))
447 throwFail(t, AssertIs(maps[2]["Profile"], T_Equal, nil))
448 }
449
450 num, err = qs.Values(&maps, "UserName", "Profile__Age")
451 throwFail(t, err)
452 throwFail(t, AssertIs(num, T_Equal, 3))
453 if num == 3 {
454 throwFail(t, AssertIs(maps[0]["UserName"], T_Equal, "slene"))
455 throwFail(t, AssertIs(maps[0]["Profile__Age"], T_Equal, 28))
456 throwFail(t, AssertIs(maps[2]["Profile__Age"], T_Equal, nil))
457 }
458 }
459
460 func TestValuesList(t *testing.T) {
461 var list []ParamsList
462 qs := dORM.QueryTable("user")
463
464 num, err := qs.ValuesList(&list)
465 throwFail(t, err)
466 throwFail(t, AssertIs(num, T_Equal, 3))
467 if num == 3 {
468 throwFail(t, AssertIs(list[0][1], T_Equal, "slene"))
469 throwFail(t, AssertIs(list[2][9], T_Equal, nil))
470 }
471
472 num, err = qs.ValuesList(&list, "UserName", "Profile__Age")
473 throwFail(t, err)
474 throwFail(t, AssertIs(num, T_Equal, 3))
475 if num == 3 {
476 throwFail(t, AssertIs(list[0][0], T_Equal, "slene"))
477 throwFail(t, AssertIs(list[0][1], T_Equal, 28))
478 throwFail(t, AssertIs(list[2][1], T_Equal, nil))
479 }
480 }
481
482 func TestValuesFlat(t *testing.T) {
483 var list ParamsList
484 qs := dORM.QueryTable("user")
485
486 num, err := qs.OrderBy("id").ValuesFlat(&list, "UserName")
487 throwFail(t, err)
488 throwFail(t, AssertIs(num, T_Equal, 3))
489 if num == 3 {
490 throwFail(t, AssertIs(list[0], T_Equal, "slene"))
491 throwFail(t, AssertIs(list[1], T_Equal, "astaxie"))
492 throwFail(t, AssertIs(list[2], T_Equal, "nobody"))
493 }
494 }
495
496 func TestRelatedSel(t *testing.T) {
497 qs := dORM.QueryTable("user")
498 num, err := qs.Filter("profile__age", 28).Count()
499 throwFail(t, err)
500 throwFail(t, AssertIs(num, T_Equal, 1))
501
502 num, err = qs.Filter("profile__age__gt", 28).Count()
503 throwFail(t, err)
504 throwFail(t, AssertIs(num, T_Equal, 1))
505
506 num, err = qs.Filter("profile__user__profile__age__gt", 28).Count()
507 throwFail(t, err)
508 throwFail(t, AssertIs(num, T_Equal, 1))
509
510 var user User
511 err = qs.Filter("user_name", "slene").RelatedSel("profile").One(&user)
512 throwFail(t, err)
513 throwFail(t, AssertIs(num, T_Equal, 1))
514 throwFail(t, AssertNot(user.Profile, T_Equal, nil))
515 if user.Profile != nil {
516 throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
517 }
518
519 err = qs.Filter("user_name", "slene").RelatedSel().One(&user)
520 throwFail(t, err)
521 throwFail(t, AssertIs(num, T_Equal, 1))
522 throwFail(t, AssertNot(user.Profile, T_Equal, nil))
523 throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
524 if user.Profile != nil {
525 throwFail(t, AssertIs(user.Profile.Age, T_Equal, 28))
526 }
527
528 err = qs.Filter("user_name", "nobody").RelatedSel("profile").One(&user)
529 throwFail(t, AssertIs(num, T_Equal, 1))
530 throwFail(t, AssertIs(user.Profile, T_Equal, nil))
531
532 qs = dORM.QueryTable("user_profile")
533 num, err = qs.Filter("user__username", "slene").Count()
534 throwFail(t, err)
535 throwFail(t, AssertIs(num, T_Equal, 1))
536 }
537
538 func TestSetCond(t *testing.T) {
539 cond := NewCondition()
540 cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
541
542 qs := dORM.QueryTable("user")
543 num, err := qs.SetCond(cond1).Count()
544 throwFail(t, err)
545 throwFail(t, AssertIs(num, T_Equal, 1))
546
547 cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene"))
548 num, err = qs.SetCond(cond2).Count()
549 throwFail(t, err)
550 throwFail(t, AssertIs(num, T_Equal, 2))
551 }
552
553 func TestLimit(t *testing.T) {
554 var posts []*Post
555 qs := dORM.QueryTable("post")
556 num, err := qs.Limit(1).All(&posts)
557 throwFail(t, err)
558 throwFail(t, AssertIs(num, T_Equal, 1))
559
560 num, err = qs.Limit(-1).All(&posts)
561 throwFail(t, err)
562 throwFail(t, AssertIs(num, T_Equal, 4))
563
564 num, err = qs.Limit(-1, 2).All(&posts)
565 throwFail(t, err)
566 throwFail(t, AssertIs(num, T_Equal, 2))
567
568 num, err = qs.Limit(0, 2).All(&posts)
569 throwFail(t, err)
570 throwFail(t, AssertIs(num, T_Equal, 2))
571 }
572
573 func TestOffset(t *testing.T) {
574 var posts []*Post
575 qs := dORM.QueryTable("post")
576 num, err := qs.Limit(1).Offset(2).All(&posts)
577 throwFail(t, err)
578 throwFail(t, AssertIs(num, T_Equal, 1))
579
580 num, err = qs.Offset(2).All(&posts)
581 throwFail(t, err)
582 throwFail(t, AssertIs(num, T_Equal, 2))
583 }
584
585 func TestOrderBy(t *testing.T) {
586 qs := dORM.QueryTable("user")
587 num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
588 throwFail(t, err)
589 throwFail(t, AssertIs(num, T_Equal, 1))
590
591 num, err = qs.OrderBy("status").Filter("user_name", "slene").Count()
592 throwFail(t, err)
593 throwFail(t, AssertIs(num, T_Equal, 1))
594
595 num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
596 throwFail(t, err)
597 throwFail(t, AssertIs(num, T_Equal, 1))
598 }
599
600 func TestPrepareInsert(t *testing.T) {
601 qs := dORM.QueryTable("user")
602 i, err := qs.PrepareInsert()
603 throwFail(t, err)
604
605 var user User
606 user.UserName = "testing1"
607 num, err := i.Insert(&user)
608 throwFail(t, err)
609 throwFail(t, AssertIs(num, T_Large, 0))
610
611 user.UserName = "testing2"
612 num, err = i.Insert(&user)
613 throwFail(t, err)
614 throwFail(t, AssertIs(num, T_Large, 0))
615
616 num, err = qs.Filter("user_name__in", "testing1", "testing2").Delete()
617 throwFail(t, err)
618 throwFail(t, AssertIs(num, T_Equal, 2))
619
620 err = i.Close()
621 throwFail(t, err)
622 err = i.Close()
623 throwFail(t, AssertIs(err, T_Equal, ErrStmtClosed))
624 }
625
626 func TestRaw(t *testing.T) {
627 switch dORM.Driver().Type() {
628 case DR_MySQL:
629 num, err := dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "testing", "slene").Exec()
630 throwFail(t, err)
631 throwFail(t, AssertIs(num, T_Equal, 1))
632
633 num, err = dORM.Raw("UPDATE user SET user_name = ? WHERE user_name = ?", "slene", "testing").Exec()
634 throwFail(t, err)
635 throwFail(t, AssertIs(num, T_Equal, 1))
636
637 var maps []Params
638 num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).Values(&maps)
639 throwFail(t, err)
640 throwFail(t, AssertIs(num, T_Equal, 1))
641 if num == 1 {
642 throwFail(t, AssertIs(maps[0]["user_name"], T_Equal, "slene"))
643 }
644
645 var lists []ParamsList
646 num, err = dORM.Raw("SELECT user_name FROM user WHERE status = ?", 1).ValuesList(&lists)
647 throwFail(t, err)
648 throwFail(t, AssertIs(num, T_Equal, 1))
649 if num == 1 {
650 throwFail(t, AssertIs(lists[0][0], T_Equal, "slene"))
651 }
652
653 var list ParamsList
654 num, err = dORM.Raw("SELECT profile_id FROM user ORDER BY id ASC").ValuesFlat(&list)
655 throwFail(t, err)
656 throwFail(t, AssertIs(num, T_Equal, 3))
657 if num == 3 {
658 throwFail(t, AssertIs(list[0], T_Equal, "2"))
659 throwFail(t, AssertIs(list[1], T_Equal, "3"))
660 throwFail(t, AssertIs(list[2], T_Equal, ""))
661 }
662 }
663 }
664
665 func TestUpdate(t *testing.T) {
666 qs := dORM.QueryTable("user")
667 num, err := qs.Filter("user_name", "slene").Update(Params{
668 "is_staff": true,
669 })
670 throwFail(t, err)
671 throwFail(t, AssertIs(num, T_Equal, 1))
672 }
673
674 func TestDelete(t *testing.T) {
675 qs := dORM.QueryTable("user_profile")
676 num, err := qs.Filter("user__user_name", "slene").Delete()
677 throwFail(t, err)
678 throwFail(t, AssertIs(num, T_Equal, 1))
679
680 qs = dORM.QueryTable("user")
681 num, err = qs.Filter("user_name", "slene").Filter("profile__isnull", true).Count()
682 throwFail(t, err)
683 throwFail(t, AssertIs(num, T_Equal, 1))
684 }
685
686 func TestTransaction(t *testing.T) {
687
688 }
...@@ -5,6 +5,11 @@ import ( ...@@ -5,6 +5,11 @@ import (
5 "reflect" 5 "reflect"
6 ) 6 )
7 7
8 type Driver interface {
9 Name() string
10 Type() DriverType
11 }
12
8 type Fielder interface { 13 type Fielder interface {
9 String() string 14 String() string
10 FieldType() int 15 FieldType() int
...@@ -26,12 +31,16 @@ type Ormer interface { ...@@ -26,12 +31,16 @@ type Ormer interface {
26 Insert(Modeler) (int64, error) 31 Insert(Modeler) (int64, error)
27 Update(Modeler) (int64, error) 32 Update(Modeler) (int64, error)
28 Delete(Modeler) (int64, error) 33 Delete(Modeler) (int64, error)
34 M2mAdd(Modeler, string, ...interface{}) (int64, error)
35 M2mDel(Modeler, string, ...interface{}) (int64, error)
36 LoadRel(Modeler, string) (int64, error)
29 QueryTable(interface{}) QuerySeter 37 QueryTable(interface{}) QuerySeter
30 Using(string) error 38 Using(string) error
31 Begin() error 39 Begin() error
32 Commit() error 40 Commit() error
33 Rollback() error 41 Rollback() error
34 Raw(string, ...interface{}) RawSeter 42 Raw(string, ...interface{}) RawSeter
43 Driver() Driver
35 } 44 }
36 45
37 type Inserter interface { 46 type Inserter interface {
...@@ -42,16 +51,15 @@ type Inserter interface { ...@@ -42,16 +51,15 @@ type Inserter interface {
42 type QuerySeter interface { 51 type QuerySeter interface {
43 Filter(string, ...interface{}) QuerySeter 52 Filter(string, ...interface{}) QuerySeter
44 Exclude(string, ...interface{}) QuerySeter 53 Exclude(string, ...interface{}) QuerySeter
54 SetCond(*Condition) QuerySeter
45 Limit(int, ...int64) QuerySeter 55 Limit(int, ...int64) QuerySeter
46 Offset(int64) QuerySeter 56 Offset(int64) QuerySeter
47 OrderBy(...string) QuerySeter 57 OrderBy(...string) QuerySeter
48 RelatedSel(...interface{}) QuerySeter 58 RelatedSel(...interface{}) QuerySeter
49 SetCond(*Condition) QuerySeter
50 Count() (int64, error) 59 Count() (int64, error)
51 Update(Params) (int64, error) 60 Update(Params) (int64, error)
52 Delete() (int64, error) 61 Delete() (int64, error)
53 PrepareInsert() (Inserter, error) 62 PrepareInsert() (Inserter, error)
54
55 All(interface{}) (int64, error) 63 All(interface{}) (int64, error)
56 One(Modeler) error 64 One(Modeler) error
57 Values(*[]Params, ...string) (int64, error) 65 Values(*[]Params, ...string) (int64, error)
...@@ -60,12 +68,15 @@ type QuerySeter interface { ...@@ -60,12 +68,15 @@ type QuerySeter interface {
60 } 68 }
61 69
62 type RawPreparer interface { 70 type RawPreparer interface {
71 Exec(...interface{}) (int64, error)
63 Close() error 72 Close() error
64 } 73 }
65 74
66 type RawSeter interface { 75 type RawSeter interface {
67 Exec() (int64, error) 76 Exec() (int64, error)
68 Mapper(...interface{}) (int64, error) 77 QueryRow(...interface{}) error
78 QueryRows(...interface{}) (int64, error)
79 SetArgs(...interface{}) RawSeter
69 Values(*[]Params) (int64, error) 80 Values(*[]Params) (int64, error)
70 ValuesList(*[]ParamsList) (int64, error) 81 ValuesList(*[]ParamsList) (int64, error)
71 ValuesFlat(*ParamsList) (int64, error) 82 ValuesFlat(*ParamsList) (int64, error)
......
...@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) { ...@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) {
171 return 171 return
172 } 172 }
173 173
174 type argAny []interface{}
175
176 func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
177 if i >= 0 && i < len(a) {
178 r = a[i]
179 }
180 if len(args) > 0 {
181 r = args[0]
182 }
183 return
184 }
185
174 func timeParse(dateString, format string) (time.Time, error) { 186 func timeParse(dateString, format string) (time.Time, error) {
175 tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) 187 tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
176 return tp, err 188 return tp, err
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!