46668b81 by slene

some fix / add test

1 parent 10f4e822
...@@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) { ...@@ -208,7 +208,7 @@ func (t *dbTables) getJoinSql() (join string) {
208 208
209 switch { 209 switch {
210 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany: 210 case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
211 c1 = jt.fi.mi.fields.pk[0].column 211 c1 = jt.fi.mi.fields.pk.column
212 for _, ffi := range jt.mi.fields.fieldsRel { 212 for _, ffi := range jt.mi.fields.fieldsRel {
213 if jt.fi.mi == ffi.relModelInfo { 213 if jt.fi.mi == ffi.relModelInfo {
214 c2 = ffi.column 214 c2 = ffi.column
...@@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) { ...@@ -217,10 +217,10 @@ func (t *dbTables) getJoinSql() (join string) {
217 } 217 }
218 default: 218 default:
219 c1 = jt.fi.column 219 c1 = jt.fi.column
220 c2 = jt.fi.relModelInfo.fields.pk[0].column 220 c2 = jt.fi.relModelInfo.fields.pk.column
221 221
222 if jt.fi.reverse { 222 if jt.fi.reverse {
223 c1 = jt.mi.fields.pk[0].column 223 c1 = jt.mi.fields.pk.column
224 c2 = jt.fi.reverseFieldInfo.column 224 c2 = jt.fi.reverseFieldInfo.column
225 } 225 }
226 } 226 }
...@@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam ...@@ -263,6 +263,8 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, nam
263 if fi.reverseFieldInfo.fieldType == RelManyToMany { 263 if fi.reverseFieldInfo.fieldType == RelManyToMany {
264 mmi = fi.reverseFieldInfo.relThroughModelInfo 264 mmi = fi.reverseFieldInfo.relThroughModelInfo
265 } 265 }
266 default:
267 return
266 } 268 }
267 269
268 jt, _ := d.add(names, mmi, fi, fi.null == false) 270 jt, _ := d.add(names, mmi, fi, fi.null == false)
...@@ -434,40 +436,36 @@ type dbBase struct { ...@@ -434,40 +436,36 @@ type dbBase struct {
434 ins dbBaser 436 ins dbBaser
435 } 437 }
436 438
437 func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) { 439 func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
438 exist := true 440
439 columns := make([]string, 0, len(mi.fields.pk)) 441 fi := mi.fields.pk
440 values := make([]interface{}, 0, len(mi.fields.pk)) 442
441 for _, fi := range mi.fields.pk {
442 v := ind.Field(fi.fieldIndex) 443 v := ind.Field(fi.fieldIndex)
443 if fi.fieldType&IsIntegerField > 0 { 444 if fi.fieldType&IsIntegerField > 0 {
444 vu := v.Int() 445 vu := v.Int()
445 if exist {
446 exist = vu > 0 446 exist = vu > 0
447 } 447 value = vu
448 values = append(values, vu)
449 } else { 448 } else {
450 vu := v.String() 449 vu := v.String()
451 if exist {
452 exist = vu != "" 450 exist = vu != ""
451 value = vu
453 } 452 }
454 values = append(values, vu) 453
455 } 454 column = fi.column
456 columns = append(columns, fi.column) 455
457 } 456 return
458 return columns, values, exist
459 } 457 }
460 458
461 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) { 459 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
462 _, pkValues, _ := d.existPk(mi, ind) 460 _, pkValue, _ := d.existPk(mi, ind)
463 for _, column := range mi.fields.orders { 461 for _, column := range mi.fields.orders {
464 fi := mi.fields.columns[column] 462 fi := mi.fields.columns[column]
465 if fi.dbcol == false || fi.auto && skipAuto { 463 if fi.dbcol == false || fi.auto && skipAuto {
466 continue 464 continue
467 } 465 }
468 var value interface{} 466 var value interface{}
469 if i, ok := mi.fields.pk.Exist(fi); ok { 467 if fi.pk {
470 value = pkValues[i] 468 value = pkValue
471 } else { 469 } else {
472 field := ind.Field(fi.fieldIndex) 470 field := ind.Field(fi.fieldIndex)
473 if fi.isFielder { 471 if fi.isFielder {
...@@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, ...@@ -493,9 +491,8 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool,
493 if field.IsNil() { 491 if field.IsNil() {
494 value = nil 492 value = nil
495 } else { 493 } else {
496 _, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field)) 494 if _, vu, ok := d.existPk(fi.relModelInfo, reflect.Indirect(field)); ok {
497 if fok { 495 value = vu
498 value = fvalues[0]
499 } else { 496 } else {
500 value = nil 497 value = nil
501 } 498 }
...@@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i ...@@ -560,17 +557,15 @@ func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (i
560 } 557 }
561 558
562 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { 559 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
563 pkNames, pkValues, ok := d.existPk(mi, ind) 560 pkColumn, pkValue, ok := d.existPk(mi, ind)
564 if ok == false { 561 if ok == false {
565 return ErrMissPK 562 return ErrMissPK
566 } 563 }
567 564
568 pkColumns := strings.Join(pkNames, "` = ? AND `")
569
570 sels := strings.Join(mi.fields.dbcols, "`, `") 565 sels := strings.Join(mi.fields.dbcols, "`, `")
571 colsNum := len(mi.fields.dbcols) 566 colsNum := len(mi.fields.dbcols)
572 567
573 query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumns) 568 query := fmt.Sprintf("SELECT `%s` FROM `%s` WHERE `%s` = ?", sels, mi.table, pkColumn)
574 569
575 refs := make([]interface{}, colsNum) 570 refs := make([]interface{}, colsNum)
576 for i, _ := range refs { 571 for i, _ := range refs {
...@@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { ...@@ -578,8 +573,11 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
578 refs[i] = &ref 573 refs[i] = &ref
579 } 574 }
580 575
581 row := q.QueryRow(query, pkValues...) 576 row := q.QueryRow(query, pkValue)
582 if err := row.Scan(refs...); err != nil { 577 if err := row.Scan(refs...); err != nil {
578 if err == sql.ErrNoRows {
579 return ErrNoRows
580 }
583 return err 581 return err
584 } else { 582 } else {
585 elm := reflect.New(mi.addrField.Elem().Type()) 583 elm := reflect.New(mi.addrField.Elem().Type())
...@@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -618,7 +616,7 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
618 } 616 }
619 617
620 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { 618 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
621 pkNames, pkValues, ok := d.existPk(mi, ind) 619 pkName, pkValue, ok := d.existPk(mi, ind)
622 if ok == false { 620 if ok == false {
623 return 0, ErrMissPK 621 return 0, ErrMissPK
624 } 622 }
...@@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -627,12 +625,11 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
627 return 0, err 625 return 0, err
628 } 626 }
629 627
630 pkColumns := strings.Join(pkNames, "` = ? AND `")
631 setColumns := strings.Join(setNames, "` = ?, `") 628 setColumns := strings.Join(setNames, "` = ?, `")
632 629
633 query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns) 630 query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkName)
634 631
635 setValues = append(setValues, pkValues...) 632 setValues = append(setValues, pkValue)
636 633
637 if res, err := q.Exec(query, setValues...); err == nil { 634 if res, err := q.Exec(query, setValues...); err == nil {
638 return res.RowsAffected() 635 return res.RowsAffected()
...@@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -643,16 +640,14 @@ func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
643 } 640 }
644 641
645 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) { 642 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
646 names, values, ok := d.existPk(mi, ind) 643 pkName, pkValue, ok := d.existPk(mi, ind)
647 if ok == false { 644 if ok == false {
648 return 0, ErrMissPK 645 return 0, ErrMissPK
649 } 646 }
650 647
651 columns := strings.Join(names, "` = ? AND `") 648 query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, pkName)
652
653 query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
654 649
655 if res, err := q.Exec(query, values...); err == nil { 650 if res, err := q.Exec(query, pkValue); err == nil {
656 651
657 num, err := res.RowsAffected() 652 num, err := res.RowsAffected()
658 if err != nil { 653 if err != nil {
...@@ -660,17 +655,15 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e ...@@ -660,17 +655,15 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, e
660 } 655 }
661 656
662 if num > 0 { 657 if num > 0 {
663 if mi.fields.auto != nil { 658 if mi.fields.pk.auto {
664 ind.Field(mi.fields.auto.fieldIndex).SetInt(0) 659 ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
665 } 660 }
666 661
667 if len(names) == 1 { 662 err := d.deleteRels(q, mi, []interface{}{pkValue})
668 err := d.deleteRels(q, mi, values)
669 if err != nil { 663 if err != nil {
670 return num, err 664 return num, err
671 } 665 }
672 } 666 }
673 }
674 667
675 return num, err 668 return num, err
676 } else { 669 } else {
...@@ -683,13 +676,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -683,13 +676,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
683 columns := make([]string, 0, len(params)) 676 columns := make([]string, 0, len(params))
684 values := make([]interface{}, 0, len(params)) 677 values := make([]interface{}, 0, len(params))
685 for col, val := range params { 678 for col, val := range params {
686 column := snakeString(col) 679 if fi, ok := mi.fields.GetByAny(col); ok == false || fi.dbcol == false {
687 if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false { 680 panic(fmt.Sprintf("wrong field/column name `%s`", col))
688 panic(fmt.Sprintf("wrong field/column name `%s`", column)) 681 } else {
689 } 682 columns = append(columns, fi.column)
690 columns = append(columns, column)
691 values = append(values, val) 683 values = append(values, val)
692 } 684 }
685 }
693 686
694 if len(columns) == 0 { 687 if len(columns) == 0 {
695 panic("update params cannot empty") 688 panic("update params cannot empty")
...@@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro ...@@ -721,15 +714,13 @@ func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) erro
721 fi = fi.reverseFieldInfo 714 fi = fi.reverseFieldInfo
722 switch fi.onDelete { 715 switch fi.onDelete {
723 case od_CASCADE: 716 case od_CASCADE:
724 cond := NewCondition() 717 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
725 cond.And(fmt.Sprintf("%s__in", fi.name), args...)
726 _, err := d.DeleteBatch(q, nil, fi.mi, cond) 718 _, err := d.DeleteBatch(q, nil, fi.mi, cond)
727 if err != nil { 719 if err != nil {
728 return err 720 return err
729 } 721 }
730 case od_SET_DEFAULT, od_SET_NULL: 722 case od_SET_DEFAULT, od_SET_NULL:
731 cond := NewCondition() 723 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...)
732 cond.And(fmt.Sprintf("%s__in", fi.name), args...)
733 params := Params{fi.column: nil} 724 params := Params{fi.column: nil}
734 if fi.onDelete == od_SET_DEFAULT { 725 if fi.onDelete == od_SET_DEFAULT {
735 params[fi.column] = fi.initial.String() 726 params[fi.column] = fi.initial.String()
...@@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -757,13 +748,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
757 where, args := tables.getCondSql(cond, false) 748 where, args := tables.getCondSql(cond, false)
758 join := tables.getJoinSql() 749 join := tables.getJoinSql()
759 750
760 colsNum := len(mi.fields.pk) 751 cols := fmt.Sprintf("T0.`%s`", mi.fields.pk.column)
761 cols := make([]string, colsNum) 752 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", cols, mi.table, join, where)
762 for i, fi := range mi.fields.pk {
763 cols[i] = fi.column
764 }
765 colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
766 query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
767 753
768 var rs *sql.Rows 754 var rs *sql.Rows
769 if r, err := q.Query(query, args...); err != nil { 755 if r, err := q.Query(query, args...); err != nil {
...@@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -772,21 +758,15 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
772 rs = r 758 rs = r
773 } 759 }
774 760
775 refs := make([]interface{}, colsNum)
776 for i, _ := range refs {
777 var ref interface{} 761 var ref interface{}
778 refs[i] = &ref
779 }
780 762
781 args = make([]interface{}, 0) 763 args = make([]interface{}, 0)
782 cnt := 0 764 cnt := 0
783 for rs.Next() { 765 for rs.Next() {
784 if err := rs.Scan(refs...); err != nil { 766 if err := rs.Scan(&ref); err != nil {
785 return 0, err 767 return 0, err
786 } 768 }
787 for _, ref := range refs { 769 args = append(args, reflect.ValueOf(ref).Interface())
788 args = append(args, reflect.ValueOf(ref).Elem().Interface())
789 }
790 cnt++ 770 cnt++
791 } 771 }
792 772
...@@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -794,14 +774,8 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
794 return 0, nil 774 return 0, nil
795 } 775 }
796 776
797 if colsNum > 1 { 777 sql, args := d.ins.GetOperatorSql(mi, "in", args)
798 columns := strings.Join(cols, "` = ? AND `") 778 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, mi.fields.pk.column, sql)
799 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
800 } else {
801 var sql string
802 sql, args = d.ins.GetOperatorSql(mi, "in", args)
803 query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
804 }
805 779
806 if res, err := q.Exec(query, args...); err == nil { 780 if res, err := q.Exec(query, args...); err == nil {
807 num, err := res.RowsAffected() 781 num, err := res.RowsAffected()
...@@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -809,7 +783,7 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
809 return 0, err 783 return 0, err
810 } 784 }
811 785
812 if colsNum == 1 && num > 0 { 786 if num > 0 {
813 err := d.deleteRels(q, mi, args) 787 err := d.deleteRels(q, mi, args)
814 if err != nil { 788 if err != nil {
815 return num, err 789 return num, err
...@@ -980,16 +954,14 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface ...@@ -980,16 +954,14 @@ func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface
980 copy(params, args) 954 copy(params, args)
981 sql := "" 955 sql := ""
982 for i, arg := range args { 956 for i, arg := range args {
983 if len(mi.fields.pk) == 1 {
984 if md, ok := arg.(Modeler); ok { 957 if md, ok := arg.(Modeler); ok {
985 ind := reflect.Indirect(reflect.ValueOf(md)) 958 ind := reflect.Indirect(reflect.ValueOf(md))
986 if _, values, exist := d.existPk(mi, ind); exist { 959 if _, vu, exist := d.existPk(mi, ind); exist {
987 arg = values[0] 960 arg = vu
988 } else { 961 } else {
989 panic(fmt.Sprintf("`%s` need a valid args value", operator)) 962 panic(fmt.Sprintf("`%s` need a valid args value", operator))
990 } 963 }
991 } 964 }
992 }
993 params[i] = arg 965 params[i] = arg
994 } 966 }
995 if operator == "in" { 967 if operator == "in" {
...@@ -1175,7 +1147,7 @@ setValue: ...@@ -1175,7 +1147,7 @@ setValue:
1175 value = v 1147 value = v
1176 } 1148 }
1177 case fieldType&IsRelField > 0: 1149 case fieldType&IsRelField > 0:
1178 fieldType = fi.relModelInfo.fields.pk[0].fieldType 1150 fieldType = fi.relModelInfo.fields.pk.fieldType
1179 goto setValue 1151 goto setValue
1180 } 1152 }
1181 1153
...@@ -1236,12 +1208,12 @@ setValue: ...@@ -1236,12 +1208,12 @@ setValue:
1236 } 1208 }
1237 case fieldType&IsRelField > 0: 1209 case fieldType&IsRelField > 0:
1238 if value != nil { 1210 if value != nil {
1239 fieldType = fi.relModelInfo.fields.pk[0].fieldType 1211 fieldType = fi.relModelInfo.fields.pk.fieldType
1240 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) 1212 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
1241 md := mf.Interface().(Modeler) 1213 md := mf.Interface().(Modeler)
1242 md.Init(md) 1214 md.Init(md)
1243 field.Set(mf) 1215 field.Set(mf)
1244 f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex) 1216 f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
1245 field = &f 1217 field = &f
1246 goto setValue 1218 goto setValue
1247 } 1219 }
......
...@@ -9,24 +9,37 @@ import ( ...@@ -9,24 +9,37 @@ import (
9 9
10 const defaultMaxIdle = 30 10 const defaultMaxIdle = 30
11 11
12 type driverType int 12 type DriverType int
13 13
14 const ( 14 const (
15 _ driverType = iota 15 _ DriverType = iota
16 DR_MySQL 16 DR_MySQL
17 DR_Sqlite 17 DR_Sqlite
18 DR_Oracle 18 DR_Oracle
19 DR_Postgres 19 DR_Postgres
20 ) 20 )
21 21
22 type driver string
23
24 func (d driver) Type() DriverType {
25 a, _ := dataBaseCache.get(string(d))
26 return a.Driver
27 }
28
29 func (d driver) Name() string {
30 return string(d)
31 }
32
33 var _ Driver = new(driver)
34
22 var ( 35 var (
23 dataBaseCache = &_dbCache{cache: make(map[string]*alias)} 36 dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
24 drivers = map[string]driverType{ 37 drivers = map[string]DriverType{
25 "mysql": DR_MySQL, 38 "mysql": DR_MySQL,
26 "postgres": DR_Postgres, 39 "postgres": DR_Postgres,
27 "sqlite3": DR_Sqlite, 40 "sqlite3": DR_Sqlite,
28 } 41 }
29 dbBasers = map[driverType]dbBaser{ 42 dbBasers = map[DriverType]dbBaser{
30 DR_MySQL: newdbBaseMysql(), 43 DR_MySQL: newdbBaseMysql(),
31 DR_Sqlite: newdbBaseSqlite(), 44 DR_Sqlite: newdbBaseSqlite(),
32 DR_Oracle: newdbBaseMysql(), 45 DR_Oracle: newdbBaseMysql(),
...@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) { ...@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) {
63 76
64 type alias struct { 77 type alias struct {
65 Name string 78 Name string
79 Driver DriverType
66 DriverName string 80 DriverName string
67 DataSource string 81 DataSource string
68 MaxIdle int 82 MaxIdle int
...@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { ...@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
87 101
88 if dr, ok := drivers[driverName]; ok { 102 if dr, ok := drivers[driverName]; ok {
89 al.DbBaser = dbBasers[dr] 103 al.DbBaser = dbBasers[dr]
104 al.Driver = dr
90 } else { 105 } else {
91 err = fmt.Errorf("driver name `%s` have not registered", driverName) 106 err = fmt.Errorf("driver name `%s` have not registered", driverName)
92 goto end 107 goto end
...@@ -116,7 +131,7 @@ end: ...@@ -116,7 +131,7 @@ end:
116 } 131 }
117 } 132 }
118 133
119 func RegisterDriver(name string, typ driverType) { 134 func RegisterDriver(name string, typ DriverType) {
120 if t, ok := drivers[name]; ok == false { 135 if t, ok := drivers[name]; ok == false {
121 drivers[name] = typ 136 drivers[name] = typ
122 } else { 137 } else {
......
...@@ -49,6 +49,7 @@ type _modelCache struct { ...@@ -49,6 +49,7 @@ type _modelCache struct {
49 sync.RWMutex 49 sync.RWMutex
50 orders []string 50 orders []string
51 cache map[string]*modelInfo 51 cache map[string]*modelInfo
52 done bool
52 } 53 }
53 54
54 func (mc *_modelCache) all() map[string]*modelInfo { 55 func (mc *_modelCache) all() map[string]*modelInfo {
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
8 "strings" 8 "strings"
9 ) 9 )
10 10
11 func RegisterModel(model Modeler) { 11 func registerModel(model Modeler) {
12 info := newModelInfo(model) 12 info := newModelInfo(model)
13 model.Init(model) 13 model.Init(model)
14 table := model.GetTableName() 14 table := model.GetTableName()
...@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) { ...@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) {
27 modelCache.set(table, info) 27 modelCache.set(table, info)
28 } 28 }
29 29
30 func BootStrap() { 30 func bootStrap() {
31 modelCache.Lock() 31 if modelCache.done {
32 defer modelCache.Unlock() 32 return
33 }
33 34
34 var ( 35 var (
35 err error 36 err error
...@@ -59,14 +60,6 @@ func BootStrap() { ...@@ -59,14 +60,6 @@ func BootStrap() {
59 } 60 }
60 fi.relModelInfo = mii 61 fi.relModelInfo = mii
61 62
62 if fi.rel {
63
64 if mii.fields.pk.IsMulti() {
65 err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
66 goto end
67 }
68 }
69
70 switch fi.fieldType { 63 switch fi.fieldType {
71 case RelManyToMany: 64 case RelManyToMany:
72 if fi.relThrough != "" { 65 if fi.relThrough != "" {
...@@ -207,6 +200,25 @@ end: ...@@ -207,6 +200,25 @@ end:
207 fmt.Println(err) 200 fmt.Println(err)
208 os.Exit(2) 201 os.Exit(2)
209 } 202 }
203 }
204
205 func RegisterModel(models ...Modeler) {
206 if modelCache.done {
207 panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
208 }
209
210 for _, model := range models {
211 registerModel(model)
212 }
213 }
214
215 func BootStrap() {
216 if modelCache.done {
217 return
218 }
210 219
211 runCommand() 220 modelCache.Lock()
221 defer modelCache.Unlock()
222 bootStrap()
223 modelCache.done = true
212 } 224 }
......
...@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices { ...@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices {
32 return *f 32 return *f
33 } 33 }
34 34
35 type primaryKeys []*fieldInfo
36
37 func (p *primaryKeys) Add(fi *fieldInfo) {
38 *p = append(*p, fi)
39 }
40
41 func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
42 for i, v := range p {
43 if v == fi {
44 return i, true
45 }
46 }
47 return -1, false
48 }
49
50 func (p primaryKeys) IsMulti() bool {
51 return len(p) > 1
52 }
53
54 func (p primaryKeys) IsEmpty() bool {
55 return len(p) == 0
56 }
57
58 type fields struct { 35 type fields struct {
59 pk primaryKeys 36 pk *fieldInfo
60 auto *fieldInfo
61 columns map[string]*fieldInfo 37 columns map[string]*fieldInfo
62 fields map[string]*fieldInfo 38 fields map[string]*fieldInfo
63 fieldsLow map[string]*fieldInfo 39 fieldsLow map[string]*fieldInfo
......
...@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) { ...@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) {
50 if err != nil { 50 if err != nil {
51 break 51 break
52 } 52 }
53
53 added := info.fields.Add(fi) 54 added := info.fields.Add(fi)
54 if added == false { 55 if added == false {
55 err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) 56 err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
56 break 57 break
57 } 58 }
59
58 if fi.pk { 60 if fi.pk {
59 if info.fields.pk != nil { 61 if info.fields.pk != nil {
60 err = errors.New(fmt.Sprintf("one model must have one pk field only")) 62 err = errors.New(fmt.Sprintf("one model must have one pk field only"))
61 break 63 break
62 } else { 64 } else {
63 info.fields.pk.Add(fi) 65 info.fields.pk = fi
64 }
65 } 66 }
66 if fi.auto {
67 info.fields.auto = fi
68 } 67 }
68
69 fi.fieldIndex = i 69 fi.fieldIndex = i
70 fi.mi = info 70 fi.mi = info
71 } 71 }
72 72
73 if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false {
74 err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field"))
75 goto end
76 }
77
78 if err != nil { 73 if err != nil {
79 fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) 74 fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
80 os.Exit(2) 75 os.Exit(2)
81 } 76 }
82 77
83 end:
84 if err != nil {
85 fmt.Println(err)
86 os.Exit(2)
87 }
88 return 78 return
89 } 79 }
90 80
...@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { ...@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
125 info.fields.Add(fa) 115 info.fields.Add(fa)
126 info.fields.Add(f1) 116 info.fields.Add(f1)
127 info.fields.Add(f2) 117 info.fields.Add(f2)
128 info.fields.pk.Add(fa) 118 info.fields.pk = fa
129 return 119 return
130 } 120 }
......
1 package orm
2
3 import (
4 "fmt"
5 "os"
6 "time"
7
8 _ "github.com/bmizerany/pq"
9 _ "github.com/go-sql-driver/mysql"
10 _ "github.com/mattn/go-sqlite3"
11 )
12
13 type User struct {
14 Id int `orm:"auto"`
15 UserName string `orm:"size(30);unique"`
16 Email string `orm:"size(100)"`
17 Password string `orm:"size(100)"`
18 Status int16 `orm:"choices(0,1,2,3);defalut(0)"`
19 IsStaff bool `orm:"default(false)"`
20 IsActive bool `orm:"default(1)"`
21 Created time.Time `orm:"auto_now_add;type(date)"`
22 Updated time.Time `orm:"auto_now"`
23 Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
24 Posts []*Post `orm:"reverse(many)" json:"-"`
25 Manager `json:"-"`
26 }
27
28 func NewUser() *User {
29 obj := new(User)
30 obj.Manager.Init(obj)
31 return obj
32 }
33
34 type Profile struct {
35 Id int `orm:"auto"`
36 Age int16 ``
37 Money float64 ``
38 User *User `orm:"reverse(one)" json:"-"`
39 Manager `json:"-"`
40 }
41
42 func (u *Profile) TableName() string {
43 return "user_profile"
44 }
45
46 func NewProfile() *Profile {
47 obj := new(Profile)
48 obj.Manager.Init(obj)
49 return obj
50 }
51
52 type Post struct {
53 Id int `orm:"auto"`
54 User *User `orm:"rel(fk)"` //
55 Title string `orm:"size(60)"`
56 Content string ``
57 Created time.Time `orm:"auto_now_add"`
58 Updated time.Time `orm:"auto_now"`
59 Tags []*Tag `orm:"rel(m2m)"`
60 Manager `json:"-"`
61 }
62
63 func NewPost() *Post {
64 obj := new(Post)
65 obj.Manager.Init(obj)
66 return obj
67 }
68
69 type Tag struct {
70 Id int `orm:"auto"`
71 Name string `orm:"size(30)"`
72 Posts []*Post `orm:"reverse(many)" json:"-"`
73 Manager `json:"-"`
74 }
75
76 func NewTag() *Tag {
77 obj := new(Tag)
78 obj.Manager.Init(obj)
79 return obj
80 }
81
82 type Comment struct {
83 Id int `orm:"auto"`
84 Post *Post `orm:"rel(fk)"`
85 Content string ``
86 Parent *Comment `orm:"null;rel(fk)"`
87 Created time.Time `orm:"auto_now_add"`
88 Manager `json:"-"`
89 }
90
91 func NewComment() *Comment {
92 obj := new(Comment)
93 obj.Manager.Init(obj)
94 return obj
95 }
96
97 var DBARGS = struct {
98 Driver string
99 Source string
100 }{
101 os.Getenv("ORM_DRIVER"),
102 os.Getenv("ORM_SOURCE"),
103 }
104
105 var dORM Ormer
106
107 func init() {
108 RegisterModel(new(User))
109 RegisterModel(new(Profile))
110 RegisterModel(new(Post))
111 RegisterModel(new(Tag))
112 RegisterModel(new(Comment))
113
114 if DBARGS.Driver == "" || DBARGS.Source == "" {
115 fmt.Println(`need driver and source!
116
117 Default DB Drivers.
118
119 driver: url
120 mysql: https://github.com/go-sql-driver/mysql
121 sqlite3: https://github.com/mattn/go-sqlite3
122 postgres: https://github.com/bmizerany/pq
123
124 eg: mysql
125 ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
126 `)
127 os.Exit(2)
128 }
129
130 RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
131
132 BootStrap()
133
134 truncateTables()
135
136 dORM = NewOrm()
137 }
138
139 func truncateTables() {
140 logs := "truncate tables for test\n"
141 o := NewOrm()
142 for _, m := range modelCache.allOrdered() {
143 query := fmt.Sprintf("truncate table `%s`", m.table)
144 _, err := o.Raw(query).Exec()
145 logs += query + "\n"
146 if err != nil {
147 fmt.Println(logs)
148 fmt.Println(err)
149 os.Exit(2)
150 }
151 }
152 }
...@@ -9,13 +9,15 @@ import ( ...@@ -9,13 +9,15 @@ import (
9 ) 9 )
10 10
11 var ( 11 var (
12 ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
13 ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
14 ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
15 ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
16 DefaultRowsLimit = 1000 12 DefaultRowsLimit = 1000
17 DefaultRelsDepth = 5 13 DefaultRelsDepth = 5
18 DefaultTimeLoc = time.Local 14 DefaultTimeLoc = time.Local
15 ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
16 ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
17 ErrMultiRows = errors.New("<QuerySeter> return multi rows")
18 ErrNoRows = errors.New("<QuerySeter> not row found")
19 ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
20 ErrNotImplement = errors.New("have not implement")
19 ) 21 )
20 22
21 type Params map[string]interface{} 23 type Params map[string]interface{}
...@@ -27,13 +29,15 @@ type orm struct { ...@@ -27,13 +29,15 @@ type orm struct {
27 isTx bool 29 isTx bool
28 } 30 }
29 31
32 var _ Ormer = new(orm)
33
30 func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) { 34 func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) {
31 md.Init(md, true) 35 md.Init(md, true)
32 name := md.GetTableName() 36 name := md.GetTableName()
33 if mi, ok := modelCache.get(name); ok { 37 if mi, ok := modelCache.get(name); ok {
34 return mi, reflect.Indirect(reflect.ValueOf(md)) 38 return mi, reflect.Indirect(reflect.ValueOf(md))
35 } 39 }
36 panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name)) 40 panic(fmt.Sprintf("<orm> table name: `%s` not exists", name))
37 } 41 }
38 42
39 func (o *orm) Read(md Modeler) error { 43 func (o *orm) Read(md Modeler) error {
...@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) { ...@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) {
52 return id, err 56 return id, err
53 } 57 }
54 if id > 0 { 58 if id > 0 {
55 if mi.fields.auto != nil { 59 if mi.fields.pk.auto {
56 ind.Field(mi.fields.auto.fieldIndex).SetInt(id) 60 ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
57 } 61 }
58 } 62 }
59 return id, nil 63 return id, nil
...@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) { ...@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) {
75 return num, err 79 return num, err
76 } 80 }
77 if num > 0 { 81 if num > 0 {
78 if mi.fields.auto != nil { 82 if mi.fields.pk.auto {
79 ind.Field(mi.fields.auto.fieldIndex).SetInt(0) 83 ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
80 } 84 }
81 } 85 }
82 return num, nil 86 return num, nil
83 } 87 }
84 88
89 func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) {
90 // TODO
91 panic(ErrNotImplement)
92 return 0, nil
93 }
94
95 func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) {
96 // TODO
97 panic(ErrNotImplement)
98 return 0, nil
99 }
100
101 func (o *orm) LoadRel(md Modeler, name string) (int64, error) {
102 // TODO
103 panic(ErrNotImplement)
104 return 0, nil
105 }
106
85 func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { 107 func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
86 name := "" 108 name := ""
87 if table, ok := ptrStructOrTableName.(string); ok { 109 if table, ok := ptrStructOrTableName.(string); ok {
...@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error { ...@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error {
111 133
112 func (o *orm) Begin() error { 134 func (o *orm) Begin() error {
113 if o.isTx { 135 if o.isTx {
114 return ErrTXHasBegin 136 return ErrTxHasBegan
115 } 137 }
116 tx, err := o.alias.DB.Begin() 138 tx, err := o.alias.DB.Begin()
117 if err != nil { 139 if err != nil {
...@@ -124,24 +146,28 @@ func (o *orm) Begin() error { ...@@ -124,24 +146,28 @@ func (o *orm) Begin() error {
124 146
125 func (o *orm) Commit() error { 147 func (o *orm) Commit() error {
126 if o.isTx == false { 148 if o.isTx == false {
127 return ErrTXNotBegin 149 return ErrTxDone
128 } 150 }
129 err := o.db.(*sql.Tx).Commit() 151 err := o.db.(*sql.Tx).Commit()
130 if err == nil { 152 if err == nil {
131 o.isTx = false 153 o.isTx = false
132 o.db = o.alias.DB 154 o.db = o.alias.DB
155 } else if err == sql.ErrTxDone {
156 return ErrTxDone
133 } 157 }
134 return err 158 return err
135 } 159 }
136 160
137 func (o *orm) Rollback() error { 161 func (o *orm) Rollback() error {
138 if o.isTx == false { 162 if o.isTx == false {
139 return ErrTXNotBegin 163 return ErrTxDone
140 } 164 }
141 err := o.db.(*sql.Tx).Rollback() 165 err := o.db.(*sql.Tx).Rollback()
142 if err == nil { 166 if err == nil {
143 o.isTx = false 167 o.isTx = false
144 o.db = o.alias.DB 168 o.db = o.alias.DB
169 } else if err == sql.ErrTxDone {
170 return ErrTxDone
145 } 171 }
146 return err 172 return err
147 } 173 }
...@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter { ...@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter {
150 return newRawSet(o, query, args) 176 return newRawSet(o, query, args)
151 } 177 }
152 178
179 func (o *orm) Driver() Driver {
180 return driver(o.alias.Name)
181 }
182
153 func NewOrm() Ormer { 183 func NewOrm() Ormer {
184 BootStrap() // execute only once
185
154 o := new(orm) 186 o := new(orm)
155 err := o.Using("default") 187 err := o.Using("default")
156 if err != nil { 188 if err != nil {
......
...@@ -26,23 +26,24 @@ func NewCondition() *Condition { ...@@ -26,23 +26,24 @@ func NewCondition() *Condition {
26 return c 26 return c
27 } 27 }
28 28
29 func (c *Condition) And(expr string, args ...interface{}) *Condition { 29 func (c Condition) And(expr string, args ...interface{}) *Condition {
30 if expr == "" || len(args) == 0 { 30 if expr == "" || len(args) == 0 {
31 panic("<Condition.And> args cannot empty") 31 panic("<Condition.And> args cannot empty")
32 } 32 }
33 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) 33 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
34 return c 34 return &c
35 } 35 }
36 36
37 func (c *Condition) AndNot(expr string, args ...interface{}) *Condition { 37 func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
38 if expr == "" || len(args) == 0 { 38 if expr == "" || len(args) == 0 {
39 panic("<Condition.AndNot> args cannot empty") 39 panic("<Condition.AndNot> args cannot empty")
40 } 40 }
41 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) 41 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
42 return c 42 return &c
43 } 43 }
44 44
45 func (c *Condition) AndCond(cond *Condition) *Condition { 45 func (c *Condition) AndCond(cond *Condition) *Condition {
46 c = c.clone()
46 if c == cond { 47 if c == cond {
47 panic("cannot use self as sub cond") 48 panic("cannot use self as sub cond")
48 } 49 }
...@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition { ...@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
52 return c 53 return c
53 } 54 }
54 55
55 func (c *Condition) Or(expr string, args ...interface{}) *Condition { 56 func (c Condition) Or(expr string, args ...interface{}) *Condition {
56 if expr == "" || len(args) == 0 { 57 if expr == "" || len(args) == 0 {
57 panic("<Condition.Or> args cannot empty") 58 panic("<Condition.Or> args cannot empty")
58 } 59 }
59 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) 60 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
60 return c 61 return &c
61 } 62 }
62 63
63 func (c *Condition) OrNot(expr string, args ...interface{}) *Condition { 64 func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
64 if expr == "" || len(args) == 0 { 65 if expr == "" || len(args) == 0 {
65 panic("<Condition.OrNot> args cannot empty") 66 panic("<Condition.OrNot> args cannot empty")
66 } 67 }
67 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) 68 c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
68 return c 69 return &c
69 } 70 }
70 71
71 func (c *Condition) OrCond(cond *Condition) *Condition { 72 func (c *Condition) OrCond(cond *Condition) *Condition {
73 c = c.clone()
72 if c == cond { 74 if c == cond {
73 panic("cannot use self as sub cond") 75 panic("cannot use self as sub cond")
74 } 76 }
...@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool { ...@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool {
82 return len(c.params) == 0 84 return len(c.params) == 0
83 } 85 }
84 86
85 func (c Condition) Clone() *Condition { 87 func (c Condition) clone() *Condition {
86 params := c.params
87 c.params = make([]condValue, len(params))
88 copy(c.params, params)
89 return &c 88 return &c
90 } 89 }
91
92 func (c *Condition) Merge() (expr string, args []interface{}) {
93 return expr, args
94 }
......
...@@ -13,6 +13,8 @@ type insertSet struct { ...@@ -13,6 +13,8 @@ type insertSet struct {
13 closed bool 13 closed bool
14 } 14 }
15 15
16 var _ Inserter = new(insertSet)
17
16 func (o *insertSet) Insert(md Modeler) (int64, error) { 18 func (o *insertSet) Insert(md Modeler) (int64, error) {
17 if o.closed { 19 if o.closed {
18 return 0, ErrStmtClosed 20 return 0, ErrStmtClosed
...@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) { ...@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) {
28 return id, err 30 return id, err
29 } 31 }
30 if id > 0 { 32 if id > 0 {
31 if o.mi.fields.auto != nil { 33 if o.mi.fields.pk.auto {
32 ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id) 34 ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
33 } 35 }
34 } 36 }
35 return id, nil 37 return id, nil
36 } 38 }
37 39
38 func (o *insertSet) Close() error { 40 func (o *insertSet) Close() error {
41 if o.closed {
42 return ErrStmtClosed
43 }
39 o.closed = true 44 o.closed = true
40 return o.stmt.Close() 45 return o.stmt.Close()
41 } 46 }
......
...@@ -15,47 +15,43 @@ type querySet struct { ...@@ -15,47 +15,43 @@ type querySet struct {
15 orm *orm 15 orm *orm
16 } 16 }
17 17
18 func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter { 18 var _ QuerySeter = new(querySet)
19 o = o.clone() 19
20 func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
20 if o.cond == nil { 21 if o.cond == nil {
21 o.cond = NewCondition() 22 o.cond = NewCondition()
22 } 23 }
23 o.cond.And(expr, args...) 24 o.cond = o.cond.And(expr, args...)
24 return o 25 return &o
25 } 26 }
26 27
27 func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter { 28 func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
28 o = o.clone()
29 if o.cond == nil { 29 if o.cond == nil {
30 o.cond = NewCondition() 30 o.cond = NewCondition()
31 } 31 }
32 o.cond.AndNot(expr, args...) 32 o.cond = o.cond.AndNot(expr, args...)
33 return o 33 return &o
34 } 34 }
35 35
36 func (o *querySet) Limit(limit int, args ...int64) QuerySeter { 36 func (o querySet) Limit(limit int, args ...int64) QuerySeter {
37 o = o.clone()
38 o.limit = limit 37 o.limit = limit
39 if len(args) > 0 { 38 if len(args) > 0 {
40 o.offset = args[0] 39 o.offset = args[0]
41 } 40 }
42 return o 41 return &o
43 } 42 }
44 43
45 func (o *querySet) Offset(offset int64) QuerySeter { 44 func (o querySet) Offset(offset int64) QuerySeter {
46 o = o.clone()
47 o.offset = offset 45 o.offset = offset
48 return o 46 return &o
49 } 47 }
50 48
51 func (o *querySet) OrderBy(exprs ...string) QuerySeter { 49 func (o querySet) OrderBy(exprs ...string) QuerySeter {
52 o = o.clone()
53 o.orders = exprs 50 o.orders = exprs
54 return o 51 return &o
55 } 52 }
56 53
57 func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { 54 func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
58 o = o.clone()
59 var related []string 55 var related []string
60 if len(params) == 0 { 56 if len(params) == 0 {
61 o.relDepth = DefaultRelsDepth 57 o.relDepth = DefaultRelsDepth
...@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { ...@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
72 } 68 }
73 } 69 }
74 o.related = related 70 o.related = related
75 return o
76 }
77
78 func (o querySet) clone() *querySet {
79 if o.cond != nil {
80 o.cond = o.cond.Clone()
81 }
82 return &o 71 return &o
83 } 72 }
84 73
...@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error { ...@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error {
115 if num > 1 { 104 if num > 1 {
116 return ErrMultiRows 105 return ErrMultiRows
117 } 106 }
107 if num == 0 {
108 return ErrNoRows
109 }
118 return nil 110 return nil
119 } 111 }
120 112
......
...@@ -63,6 +63,8 @@ type rawSet struct { ...@@ -63,6 +63,8 @@ type rawSet struct {
63 orm *orm 63 orm *orm
64 } 64 }
65 65
66 var _ RawSeter = new(rawSet)
67
66 func (o rawSet) SetArgs(args ...interface{}) RawSeter { 68 func (o rawSet) SetArgs(args ...interface{}) RawSeter {
67 o.args = args 69 o.args = args
68 return &o 70 return &o
...@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) { ...@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) {
76 return getResult(res) 78 return getResult(res)
77 } 79 }
78 80
79 func (o *rawSet) Mapper(...interface{}) (int64, error) { 81 func (o *rawSet) QueryRow(...interface{}) error {
82 //TODO
83 return nil
84 }
85
86 func (o *rawSet) QueryRows(...interface{}) (int64, error) {
80 //TODO 87 //TODO
81 return 0, nil 88 return 0, nil
82 } 89 }
...@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
120 cols = columns 127 cols = columns
121 refs = make([]interface{}, len(cols)) 128 refs = make([]interface{}, len(cols))
122 for i, _ := range refs { 129 for i, _ := range refs {
123 var ref string 130 var ref sql.NullString
124 refs[i] = &ref 131 refs[i] = &ref
125 } 132 }
126 } 133 }
...@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
134 case 1: 141 case 1:
135 params := make(Params, len(cols)) 142 params := make(Params, len(cols))
136 for i, ref := range refs { 143 for i, ref := range refs {
137 value := reflect.Indirect(reflect.ValueOf(ref)).Interface() 144 value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
138 params[cols[i]] = value 145 params[cols[i]] = value.String
139 } 146 }
140 maps = append(maps, params) 147 maps = append(maps, params)
141 case 2: 148 case 2:
142 params := make(ParamsList, 0, len(cols)) 149 params := make(ParamsList, 0, len(cols))
143 for _, ref := range refs { 150 for _, ref := range refs {
144 value := reflect.Indirect(reflect.ValueOf(ref)).Interface() 151 value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
145 params = append(params, value) 152 params = append(params, value.String)
146 } 153 }
147 lists = append(lists, params) 154 lists = append(lists, params)
148 case 3: 155 case 3:
149 for _, ref := range refs { 156 for _, ref := range refs {
150 value := reflect.Indirect(reflect.ValueOf(ref)).Interface() 157 value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
151 list = append(list, value) 158 list = append(list, value.String)
152 } 159 }
153 } 160 }
154 161
......
...@@ -5,6 +5,11 @@ import ( ...@@ -5,6 +5,11 @@ import (
5 "reflect" 5 "reflect"
6 ) 6 )
7 7
8 type Driver interface {
9 Name() string
10 Type() DriverType
11 }
12
8 type Fielder interface { 13 type Fielder interface {
9 String() string 14 String() string
10 FieldType() int 15 FieldType() int
...@@ -26,12 +31,16 @@ type Ormer interface { ...@@ -26,12 +31,16 @@ type Ormer interface {
26 Insert(Modeler) (int64, error) 31 Insert(Modeler) (int64, error)
27 Update(Modeler) (int64, error) 32 Update(Modeler) (int64, error)
28 Delete(Modeler) (int64, error) 33 Delete(Modeler) (int64, error)
34 M2mAdd(Modeler, string, ...interface{}) (int64, error)
35 M2mDel(Modeler, string, ...interface{}) (int64, error)
36 LoadRel(Modeler, string) (int64, error)
29 QueryTable(interface{}) QuerySeter 37 QueryTable(interface{}) QuerySeter
30 Using(string) error 38 Using(string) error
31 Begin() error 39 Begin() error
32 Commit() error 40 Commit() error
33 Rollback() error 41 Rollback() error
34 Raw(string, ...interface{}) RawSeter 42 Raw(string, ...interface{}) RawSeter
43 Driver() Driver
35 } 44 }
36 45
37 type Inserter interface { 46 type Inserter interface {
...@@ -42,16 +51,15 @@ type Inserter interface { ...@@ -42,16 +51,15 @@ type Inserter interface {
42 type QuerySeter interface { 51 type QuerySeter interface {
43 Filter(string, ...interface{}) QuerySeter 52 Filter(string, ...interface{}) QuerySeter
44 Exclude(string, ...interface{}) QuerySeter 53 Exclude(string, ...interface{}) QuerySeter
54 SetCond(*Condition) QuerySeter
45 Limit(int, ...int64) QuerySeter 55 Limit(int, ...int64) QuerySeter
46 Offset(int64) QuerySeter 56 Offset(int64) QuerySeter
47 OrderBy(...string) QuerySeter 57 OrderBy(...string) QuerySeter
48 RelatedSel(...interface{}) QuerySeter 58 RelatedSel(...interface{}) QuerySeter
49 SetCond(*Condition) QuerySeter
50 Count() (int64, error) 59 Count() (int64, error)
51 Update(Params) (int64, error) 60 Update(Params) (int64, error)
52 Delete() (int64, error) 61 Delete() (int64, error)
53 PrepareInsert() (Inserter, error) 62 PrepareInsert() (Inserter, error)
54
55 All(interface{}) (int64, error) 63 All(interface{}) (int64, error)
56 One(Modeler) error 64 One(Modeler) error
57 Values(*[]Params, ...string) (int64, error) 65 Values(*[]Params, ...string) (int64, error)
...@@ -60,12 +68,15 @@ type QuerySeter interface { ...@@ -60,12 +68,15 @@ type QuerySeter interface {
60 } 68 }
61 69
62 type RawPreparer interface { 70 type RawPreparer interface {
71 Exec(...interface{}) (int64, error)
63 Close() error 72 Close() error
64 } 73 }
65 74
66 type RawSeter interface { 75 type RawSeter interface {
67 Exec() (int64, error) 76 Exec() (int64, error)
68 Mapper(...interface{}) (int64, error) 77 QueryRow(...interface{}) error
78 QueryRows(...interface{}) (int64, error)
79 SetArgs(...interface{}) RawSeter
69 Values(*[]Params) (int64, error) 80 Values(*[]Params) (int64, error)
70 ValuesList(*[]ParamsList) (int64, error) 81 ValuesList(*[]ParamsList) (int64, error)
71 ValuesFlat(*ParamsList) (int64, error) 82 ValuesFlat(*ParamsList) (int64, error)
......
...@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) { ...@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) {
171 return 171 return
172 } 172 }
173 173
174 type argAny []interface{}
175
176 func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
177 if i >= 0 && i < len(a) {
178 r = a[i]
179 }
180 if len(args) > 0 {
181 r = args[0]
182 }
183 return
184 }
185
174 func timeParse(dateString, format string) (time.Time, error) { 186 func timeParse(dateString, format string) (time.Time, error) {
175 tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) 187 tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
176 return tp, err 188 return tp, err
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!