d043ebcd by slene

orm support complete m2m operation api / auto load related api

1 parent e11c40ee
...@@ -151,7 +151,11 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex ...@@ -151,7 +151,11 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
151 } 151 }
152 152
153 if mi.model != nil { 153 if mi.model != nil {
154 for _, names := range getTableUnique(mi.addrField) { 154 allnames := getTableUnique(mi.addrField)
155 if !mi.manual && len(mi.uniques) > 0 {
156 allnames = append(allnames, mi.uniques)
157 }
158 for _, names := range allnames {
155 cols := make([]string, 0, len(names)) 159 cols := make([]string, 0, len(names))
156 for _, name := range names { 160 for _, name := range names {
157 if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol { 161 if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
......
...@@ -52,7 +52,6 @@ type dbBase struct { ...@@ -52,7 +52,6 @@ type dbBase struct {
52 var _ dbBaser = new(dbBase) 52 var _ dbBaser = new(dbBase)
53 53
54 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) { 54 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
55 _, pkValue, _ := getExistPk(mi, ind)
56 for _, column := range cols { 55 for _, column := range cols {
57 var fi *fieldInfo 56 var fi *fieldInfo
58 if fi, _ = mi.fields.GetByAny(column); fi != nil { 57 if fi, _ = mi.fields.GetByAny(column); fi != nil {
...@@ -63,9 +62,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, ...@@ -63,9 +62,20 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
63 if fi.dbcol == false || fi.auto && skipAuto { 62 if fi.dbcol == false || fi.auto && skipAuto {
64 continue 63 continue
65 } 64 }
65 value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
66 if err != nil {
67 return nil, nil, err
68 }
69 columns = append(columns, column)
70 values = append(values, value)
71 }
72 return
73 }
74
75 func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
66 var value interface{} 76 var value interface{}
67 if fi.pk { 77 if fi.pk {
68 value = pkValue 78 _, value, _ = getExistPk(mi, ind)
69 } else { 79 } else {
70 field := ind.Field(fi.fieldIndex) 80 field := ind.Field(fi.fieldIndex)
71 if fi.isFielder { 81 if fi.isFielder {
...@@ -111,7 +121,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, ...@@ -111,7 +121,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
111 } 121 }
112 } 122 }
113 if fi.null == false && value == nil { 123 if fi.null == false && value == nil {
114 return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName)) 124 return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
115 } 125 }
116 } 126 }
117 } 127 }
...@@ -135,10 +145,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, ...@@ -135,10 +145,7 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
135 } 145 }
136 } 146 }
137 } 147 }
138 columns = append(columns, column) 148 return value, nil
139 values = append(values, value)
140 }
141 return
142 } 149 }
143 150
144 func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { 151 func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
...@@ -250,6 +257,10 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time. ...@@ -250,6 +257,10 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
250 return 0, err 257 return 0, err
251 } 258 }
252 259
260 return d.InsertValue(q, mi, names, values)
261 }
262
263 func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
253 Q := d.ins.TableQuote() 264 Q := d.ins.TableQuote()
254 265
255 marks := make([]string, len(names)) 266 marks := make([]string, len(names))
...@@ -653,10 +664,12 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -653,10 +664,12 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
653 trefs = refs[len(tCols):] 664 trefs = refs[len(tCols):]
654 665
655 for _, tbl := range tables.tables { 666 for _, tbl := range tables.tables {
667 // loop selected tables
656 if tbl.sel { 668 if tbl.sel {
657 last := mind 669 last := mind
658 names := "" 670 names := ""
659 mmi := mi 671 mmi := mi
672 // loop cascade models
660 for _, name := range tbl.names { 673 for _, name := range tbl.names {
661 names += name 674 names += name
662 if val, ok := cacheV[names]; ok { 675 if val, ok := cacheV[names]; ok {
...@@ -665,8 +678,10 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -665,8 +678,10 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
665 } else { 678 } else {
666 fi := mmi.fields.GetByName(name) 679 fi := mmi.fields.GetByName(name)
667 lastm := mmi 680 lastm := mmi
668 mmi := fi.relModelInfo 681 mmi = fi.relModelInfo
669 field := reflect.Indirect(last.Field(fi.fieldIndex)) 682 field := last
683 if last.Kind() != reflect.Invalid {
684 field = reflect.Indirect(last.Field(fi.fieldIndex))
670 if field.IsValid() { 685 if field.IsValid() {
671 d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) 686 d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz)
672 for _, fi := range mmi.fields.fieldsReverse { 687 for _, fi := range mmi.fields.fieldsReverse {
...@@ -679,14 +694,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -679,14 +694,15 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
679 } 694 }
680 } 695 }
681 } 696 }
682 cacheV[names] = &field
683 cacheM[names] = mmi
684 last = field 697 last = field
685 } 698 }
686 trefs = trefs[len(mmi.fields.dbcols):]
687 } 699 }
700 cacheV[names] = &field
701 cacheM[names] = mmi
688 } 702 }
689 } 703 }
704 trefs = trefs[len(mmi.fields.dbcols):]
705 }
690 } 706 }
691 707
692 if one { 708 if one {
......
...@@ -100,22 +100,29 @@ func (t *dbTables) parseRelated(rels []string, depth int) { ...@@ -100,22 +100,29 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
100 exs = strings.Split(s, ExprSep) 100 exs = strings.Split(s, ExprSep)
101 names = make([]string, 0, len(exs)) 101 names = make([]string, 0, len(exs))
102 mmi = t.mi 102 mmi = t.mi
103 cansel = true 103 cancel = true
104 jtl *dbTable 104 jtl *dbTable
105 ) 105 )
106
107 inner := true
108
106 for _, ex := range exs { 109 for _, ex := range exs {
107 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany { 110 if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
108 names = append(names, fi.name) 111 names = append(names, fi.name)
109 mmi = fi.relModelInfo 112 mmi = fi.relModelInfo
110 113
111 jt := t.set(names, mmi, fi, fi.null == false) 114 if fi.null {
115 inner = false
116 }
117
118 jt := t.set(names, mmi, fi, inner)
112 jt.jtl = jtl 119 jt.jtl = jtl
113 120
114 if fi.reverse { 121 if fi.reverse {
115 cansel = false 122 cancel = false
116 } 123 }
117 124
118 if cansel { 125 if cancel {
119 jt.sel = depth > 0 126 jt.sel = depth > 0
120 127
121 if i < relsNum { 128 if i < relsNum {
...@@ -178,9 +185,8 @@ func (t *dbTables) getJoinSql() (join string) { ...@@ -178,9 +185,8 @@ func (t *dbTables) getJoinSql() (join string) {
178 return 185 return
179 } 186 }
180 187
181 func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) { 188 func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
182 var ( 189 var (
183 ffi *fieldInfo
184 jtl *dbTable 190 jtl *dbTable
185 mmi = mi 191 mmi = mi
186 ) 192 )
...@@ -188,15 +194,16 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string ...@@ -188,15 +194,16 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
188 num := len(exprs) - 1 194 num := len(exprs) - 1
189 names := make([]string, 0) 195 names := make([]string, 0)
190 196
197 inner := true
198
191 for i, ex := range exprs { 199 for i, ex := range exprs {
192 exist := false
193 200
194 check:
195 fi, ok := mmi.fields.GetByAny(ex) 201 fi, ok := mmi.fields.GetByAny(ex)
196 202
197 if ok { 203 if ok {
198 204
199 if num != i { 205 isRel := fi.rel || fi.reverse
206
200 names = append(names, fi.name) 207 names = append(names, fi.name)
201 208
202 switch { 209 switch {
...@@ -207,54 +214,47 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string ...@@ -207,54 +214,47 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
207 } 214 }
208 case fi.reverse: 215 case fi.reverse:
209 mmi = fi.reverseFieldInfo.mi 216 mmi = fi.reverseFieldInfo.mi
210 if fi.reverseFieldInfo.fieldType == RelManyToMany {
211 mmi = fi.reverseFieldInfo.relThroughModelInfo
212 } 217 }
213 default: 218
214 return 219 if isRel && (fi.mi.isThrough == false || num != i) {
220 if fi.null {
221 inner = false
215 } 222 }
216 223
217 jt, _ := d.add(names, mmi, fi, fi.null == false) 224 jt, _ := t.add(names, mmi, fi, inner)
218 jt.jtl = jtl 225 jt.jtl = jtl
219 jtl = jt 226 jtl = jt
220
221 if fi.rel && fi.fieldType == RelManyToMany {
222 ex = fi.relModelInfo.name
223 goto check
224 }
225
226 if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
227 ex = fi.reverseFieldInfo.mi.name
228 goto check
229 } 227 }
230 228
231 exist = true 229 if num == i {
232 230 if i == 0 || jtl == nil {
233 } else {
234
235 if ffi == nil {
236 index = "T0" 231 index = "T0"
237 } else { 232 } else {
238 index = jtl.index 233 index = jtl.index
239 } 234 }
235
240 info = fi 236 info = fi
241 if jtl != nil { 237
242 name = jtl.name + ExprSep + fi.name 238 if jtl == nil {
243 } else {
244 name = fi.name 239 name = fi.name
240 } else {
241 name = jtl.name + ExprSep + fi.name
245 } 242 }
246 243
247 switch fi.fieldType { 244 switch {
248 case RelManyToMany, RelReverseMany: 245 case fi.rel:
249 default: 246
250 exist = true 247 case fi.reverse:
248 switch fi.reverseFieldInfo.fieldType {
249 case RelOneToOne, RelForeignKey:
250 index = jtl.index
251 info = fi.reverseFieldInfo.mi.fields.pk
252 name = info.name
251 } 253 }
252 } 254 }
253
254 ffi = fi
255 } 255 }
256 256
257 if exist == false { 257 } else {
258 index = "" 258 index = ""
259 name = "" 259 name = ""
260 info = nil 260 info = nil
...@@ -267,16 +267,15 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string ...@@ -267,16 +267,15 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
267 return 267 return
268 } 268 }
269 269
270 func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) { 270 func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
271 if cond == nil || cond.IsEmpty() { 271 if cond == nil || cond.IsEmpty() {
272 return 272 return
273 } 273 }
274 274
275 Q := d.base.TableQuote() 275 Q := t.base.TableQuote()
276 276
277 mi := d.mi 277 mi := t.mi
278 278
279 // outFor:
280 for i, p := range cond.params { 279 for i, p := range cond.params {
281 if i > 0 { 280 if i > 0 {
282 if p.isOr { 281 if p.isOr {
...@@ -289,7 +288,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe ...@@ -289,7 +288,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
289 where += "NOT " 288 where += "NOT "
290 } 289 }
291 if p.isCond { 290 if p.isCond {
292 w, ps := d.getCondSql(p.cond, true, tz) 291 w, ps := t.getCondSql(p.cond, true, tz)
293 if w != "" { 292 if w != "" {
294 w = fmt.Sprintf("( %s) ", w) 293 w = fmt.Sprintf("( %s) ", w)
295 } 294 }
...@@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe ...@@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
305 exprs = exprs[:num] 304 exprs = exprs[:num]
306 } 305 }
307 306
308 index, _, fi, suc := d.parseExprs(mi, exprs) 307 index, _, fi, suc := t.parseExprs(mi, exprs)
309 if suc == false { 308 if suc == false {
310 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep))) 309 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
311 } 310 }
...@@ -314,10 +313,10 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe ...@@ -314,10 +313,10 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
314 operator = "exact" 313 operator = "exact"
315 } 314 }
316 315
317 operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args, tz) 316 operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
318 317
319 leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q) 318 leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
320 d.base.GenerateOperatorLeftCol(fi, operator, &leftCol) 319 t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
321 320
322 where += fmt.Sprintf("%s %s ", leftCol, operSql) 321 where += fmt.Sprintf("%s %s ", leftCol, operSql)
323 params = append(params, args...) 322 params = append(params, args...)
...@@ -332,12 +331,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe ...@@ -332,12 +331,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
332 return 331 return
333 } 332 }
334 333
335 func (d *dbTables) getOrderSql(orders []string) (orderSql string) { 334 func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
336 if len(orders) == 0 { 335 if len(orders) == 0 {
337 return 336 return
338 } 337 }
339 338
340 Q := d.base.TableQuote() 339 Q := t.base.TableQuote()
341 340
342 orderSqls := make([]string, 0, len(orders)) 341 orderSqls := make([]string, 0, len(orders))
343 for _, order := range orders { 342 for _, order := range orders {
...@@ -348,7 +347,7 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) { ...@@ -348,7 +347,7 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
348 } 347 }
349 exprs := strings.Split(order, ExprSep) 348 exprs := strings.Split(order, ExprSep)
350 349
351 index, _, fi, suc := d.parseExprs(d.mi, exprs) 350 index, _, fi, suc := t.parseExprs(t.mi, exprs)
352 if suc == false { 351 if suc == false {
353 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep))) 352 panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
354 } 353 }
...@@ -360,14 +359,14 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) { ...@@ -360,14 +359,14 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
360 return 359 return
361 } 360 }
362 361
363 func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) { 362 func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
364 if limit == 0 { 363 if limit == 0 {
365 limit = int64(DefaultRowsLimit) 364 limit = int64(DefaultRowsLimit)
366 } 365 }
367 if limit < 0 { 366 if limit < 0 {
368 // no limit 367 // no limit
369 if offset > 0 { 368 if offset > 0 {
370 maxLimit := d.base.MaxLimit() 369 maxLimit := t.base.MaxLimit()
371 if maxLimit == 0 { 370 if maxLimit == 0 {
372 limits = fmt.Sprintf("OFFSET %d", offset) 371 limits = fmt.Sprintf("OFFSET %d", offset)
373 } else { 372 } else {
......
...@@ -121,7 +121,6 @@ func bootStrap() { ...@@ -121,7 +121,6 @@ func bootStrap() {
121 err = errors.New(msg) 121 err = errors.New(msg)
122 goto end 122 goto end
123 } 123 }
124 err = nil
125 } else { 124 } else {
126 i := newM2MModelInfo(mi, mii) 125 i := newM2MModelInfo(mi, mii)
127 if fi.relTable != "" { 126 if fi.relTable != "" {
...@@ -135,6 +134,8 @@ func bootStrap() { ...@@ -135,6 +134,8 @@ func bootStrap() {
135 fi.relTable = i.table 134 fi.relTable = i.table
136 fi.relThroughModelInfo = i 135 fi.relThroughModelInfo = i
137 } 136 }
137
138 fi.relThroughModelInfo.isThrough = true
138 } 139 }
139 } 140 }
140 } 141 }
...@@ -152,6 +153,7 @@ func bootStrap() { ...@@ -152,6 +153,7 @@ func bootStrap() {
152 break 153 break
153 } 154 }
154 } 155 }
156
155 if inModel == false { 157 if inModel == false {
156 rmi := fi.relModelInfo 158 rmi := fi.relModelInfo
157 ffi := new(fieldInfo) 159 ffi := new(fieldInfo)
...@@ -185,9 +187,34 @@ func bootStrap() { ...@@ -185,9 +187,34 @@ func bootStrap() {
185 } 187 }
186 } 188 }
187 189
190 models = modelCache.all()
188 for _, mi := range models { 191 for _, mi := range models {
189 if fields, ok := mi.fields.fieldsByType[RelReverseOne]; ok { 192 for _, fi := range mi.fields.fieldsRel {
190 for _, fi := range fields { 193 switch fi.fieldType {
194 case RelManyToMany:
195 for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
196 switch ffi.fieldType {
197 case RelOneToOne, RelForeignKey:
198 if ffi.relModelInfo == fi.relModelInfo {
199 fi.reverseFieldInfoTwo = ffi
200 }
201 if ffi.relModelInfo == mi {
202 fi.reverseField = ffi.name
203 fi.reverseFieldInfo = ffi
204 }
205 }
206 }
207
208 if fi.reverseFieldInfoTwo == nil {
209 err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
210 fi.relThroughModelInfo.fullName)
211 goto end
212 }
213 }
214 }
215 for _, fi := range mi.fields.fieldsReverse {
216 switch fi.fieldType {
217 case RelReverseOne:
191 found := false 218 found := false
192 mForA: 219 mForA:
193 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] { 220 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
...@@ -195,6 +222,9 @@ func bootStrap() { ...@@ -195,6 +222,9 @@ func bootStrap() {
195 found = true 222 found = true
196 fi.reverseField = ffi.name 223 fi.reverseField = ffi.name
197 fi.reverseFieldInfo = ffi 224 fi.reverseFieldInfo = ffi
225
226 ffi.reverseField = fi.name
227 ffi.reverseFieldInfo = fi
198 break mForA 228 break mForA
199 } 229 }
200 } 230 }
...@@ -202,10 +232,7 @@ func bootStrap() { ...@@ -202,10 +232,7 @@ func bootStrap() {
202 err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) 232 err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
203 goto end 233 goto end
204 } 234 }
205 } 235 case RelReverseMany:
206 }
207 if fields, ok := mi.fields.fieldsByType[RelReverseMany]; ok {
208 for _, fi := range fields {
209 found := false 236 found := false
210 mForB: 237 mForB:
211 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] { 238 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
...@@ -213,6 +240,10 @@ func bootStrap() { ...@@ -213,6 +240,10 @@ func bootStrap() {
213 found = true 240 found = true
214 fi.reverseField = ffi.name 241 fi.reverseField = ffi.name
215 fi.reverseFieldInfo = ffi 242 fi.reverseFieldInfo = ffi
243
244 ffi.reverseField = fi.name
245 ffi.reverseFieldInfo = fi
246
216 break mForB 247 break mForB
217 } 248 }
218 } 249 }
...@@ -221,14 +252,20 @@ func bootStrap() { ...@@ -221,14 +252,20 @@ func bootStrap() {
221 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] { 252 for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
222 if ffi.relModelInfo == mi { 253 if ffi.relModelInfo == mi {
223 found = true 254 found = true
224 fi.reverseField = ffi.name 255
225 fi.reverseFieldInfo = ffi 256 fi.reverseField = ffi.reverseFieldInfoTwo.name
257 fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
258 fi.relThroughModelInfo = ffi.relThroughModelInfo
259 fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
260 fi.reverseFieldInfoM2M = ffi
261 ffi.reverseFieldInfoM2M = fi
262
226 break mForC 263 break mForC
227 } 264 }
228 } 265 }
229 } 266 }
230 if found == false { 267 if found == false {
231 err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName) 268 err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
232 goto end 269 goto end
233 } 270 }
234 } 271 }
......
...@@ -103,6 +103,8 @@ type fieldInfo struct { ...@@ -103,6 +103,8 @@ type fieldInfo struct {
103 reverse bool 103 reverse bool
104 reverseField string 104 reverseField string
105 reverseFieldInfo *fieldInfo 105 reverseFieldInfo *fieldInfo
106 reverseFieldInfoTwo *fieldInfo
107 reverseFieldInfoM2M *fieldInfo
106 relTable string 108 relTable string
107 relThrough string 109 relThrough string
108 relThroughModelInfo *modelInfo 110 relThroughModelInfo *modelInfo
......
...@@ -16,6 +16,8 @@ type modelInfo struct { ...@@ -16,6 +16,8 @@ type modelInfo struct {
16 fields *fields 16 fields *fields
17 manual bool 17 manual bool
18 addrField reflect.Value 18 addrField reflect.Value
19 uniques []string
20 isThrough bool
19 } 21 }
20 22
21 func newModelInfo(val reflect.Value) (info *modelInfo) { 23 func newModelInfo(val reflect.Value) (info *modelInfo) {
...@@ -118,5 +120,7 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { ...@@ -118,5 +120,7 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
118 info.fields.Add(f1) 120 info.fields.Add(f1)
119 info.fields.Add(f2) 121 info.fields.Add(f2)
120 info.fields.pk = fa 122 info.fields.pk = fa
123
124 info.uniques = []string{f1.column, f2.column}
121 return 125 return
122 } 126 }
......
...@@ -103,6 +103,7 @@ type Profile struct { ...@@ -103,6 +103,7 @@ type Profile struct {
103 Age int16 103 Age int16
104 Money float64 104 Money float64
105 User *User `orm:"reverse(one)" json:"-"` 105 User *User `orm:"reverse(one)" json:"-"`
106 BestPost *Post `orm:"rel(one);null"`
106 } 107 }
107 108
108 func (u *Profile) TableName() string { 109 func (u *Profile) TableName() string {
...@@ -138,6 +139,7 @@ func NewPost() *Post { ...@@ -138,6 +139,7 @@ func NewPost() *Post {
138 type Tag struct { 139 type Tag struct {
139 Id int 140 Id int
140 Name string `orm:"size(30)"` 141 Name string `orm:"size(30)"`
142 BestPost *Post `orm:"rel(one);null"`
141 Posts []*Post `orm:"reverse(many)" json:"-"` 143 Posts []*Post `orm:"reverse(many)" json:"-"`
142 } 144 }
143 145
......
...@@ -18,7 +18,7 @@ var ( ...@@ -18,7 +18,7 @@ var (
18 Debug = false 18 Debug = false
19 DebugLog = NewLog(os.Stderr) 19 DebugLog = NewLog(os.Stderr)
20 DefaultRowsLimit = 1000 20 DefaultRowsLimit = 1000
21 DefaultRelsDepth = 5 21 DefaultRelsDepth = 2
22 DefaultTimeLoc = time.Local 22 DefaultTimeLoc = time.Local
23 ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin") 23 ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
24 ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin") 24 ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
...@@ -53,6 +53,14 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) { ...@@ -53,6 +53,14 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
53 panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) 53 panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
54 } 54 }
55 55
56 func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
57 fi, ok := mi.fields.GetByAny(name)
58 if !ok {
59 panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
60 }
61 return fi
62 }
63
56 func (o *orm) Read(md interface{}, cols ...string) error { 64 func (o *orm) Read(md interface{}, cols ...string) error {
57 mi, ind := o.getMiInd(md) 65 mi, ind := o.getMiInd(md)
58 err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols) 66 err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
...@@ -107,22 +115,152 @@ func (o *orm) Delete(md interface{}) (int64, error) { ...@@ -107,22 +115,152 @@ func (o *orm) Delete(md interface{}) (int64, error) {
107 return num, nil 115 return num, nil
108 } 116 }
109 117
110 func (o *orm) M2mAdd(md interface{}, name string, mds ...interface{}) (int64, error) { 118 func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
111 // TODO 119 mi, ind := o.getMiInd(md)
112 panic(ErrNotImplement) 120 fi := o.getFieldInfo(mi, name)
113 return 0, nil 121
122 if fi.fieldType != RelManyToMany {
123 panic(fmt.Errorf("<Ormer.QueryM2M> name `%s` for model `%s` is not a m2m field", fi.name, mi.fullName))
124 }
125
126 return newQueryM2M(md, o, mi, fi, ind)
127 }
128
129 func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
130 _, fi, ind, qseter := o.queryRelated(md, name)
131
132 qs := qseter.(*querySet)
133
134 var relDepth int
135 var limit, offset int64
136 var order string
137 for i, arg := range args {
138 switch i {
139 case 0:
140 if v, ok := arg.(bool); ok {
141 if v {
142 relDepth = DefaultRelsDepth
143 }
144 } else if v, ok := arg.(int); ok {
145 relDepth = v
146 }
147 case 1:
148 limit = ToInt64(arg)
149 case 2:
150 offset = ToInt64(arg)
151 case 3:
152 order, _ = arg.(string)
153 }
154 }
155
156 switch fi.fieldType {
157 case RelOneToOne, RelForeignKey, RelReverseOne:
158 limit = 1
159 offset = 0
160 }
161
162 qs.limit = limit
163 qs.offset = offset
164 qs.relDepth = relDepth
165
166 if len(order) > 0 {
167 qs.orders = []string{order}
168 }
169
170 find := ind.Field(fi.fieldIndex)
171
172 var nums int64
173 var err error
174 switch fi.fieldType {
175 case RelOneToOne, RelForeignKey, RelReverseOne:
176 val := reflect.New(find.Type().Elem())
177 container := val.Interface()
178 err = qs.One(container)
179 if err == nil {
180 find.Set(val)
181 nums = 1
182 }
183 default:
184 nums, err = qs.All(find.Addr().Interface())
185 }
186
187 return nums, err
188 }
189
190 func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
191 // is this api needed ?
192 _, _, _, qs := o.queryRelated(md, name)
193 return qs
114 } 194 }
115 195
116 func (o *orm) M2mDel(md interface{}, name string, mds ...interface{}) (int64, error) { 196 func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
117 // TODO 197 mi, ind := o.getMiInd(md)
118 panic(ErrNotImplement) 198 fi := o.getFieldInfo(mi, name)
119 return 0, nil 199
200 _, _, exist := getExistPk(mi, ind)
201 if exist == false {
202 panic(ErrMissPK)
203 }
204
205 var qs *querySet
206
207 switch fi.fieldType {
208 case RelOneToOne, RelForeignKey, RelManyToMany:
209 if !fi.inModel {
210 break
211 }
212 qs = o.getRelQs(md, mi, fi)
213 case RelReverseOne, RelReverseMany:
214 if !fi.inModel {
215 break
216 }
217 qs = o.getReverseQs(md, mi, fi)
218 }
219
220 if qs == nil {
221 panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field"))
222 }
223
224 return mi, fi, ind, qs
120 } 225 }
121 226
122 func (o *orm) LoadRel(md interface{}, name string) (int64, error) { 227 func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
123 // TODO 228 switch fi.fieldType {
124 panic(ErrNotImplement) 229 case RelReverseOne, RelReverseMany:
125 return 0, nil 230 default:
231 panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
232 }
233
234 var q *querySet
235
236 if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
237 q = newQuerySet(o, fi.relModelInfo).(*querySet)
238 q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
239 } else {
240 q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
241 q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
242 }
243
244 return q
245 }
246
247 func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
248 switch fi.fieldType {
249 case RelOneToOne, RelForeignKey, RelManyToMany:
250 default:
251 panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
252 }
253
254 q := newQuerySet(o, fi.relModelInfo).(*querySet)
255 q.cond = NewCondition()
256
257 if fi.fieldType == RelManyToMany {
258 q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
259 } else {
260 q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
261 }
262
263 return q
126 } 264 }
127 265
128 func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) { 266 func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
......
1 package orm
2
3 import (
4 "reflect"
5 )
6
7 type queryM2M struct {
8 md interface{}
9 mi *modelInfo
10 fi *fieldInfo
11 qs *querySet
12 ind reflect.Value
13 }
14
15 func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
16 fi := o.fi
17 mi := fi.relThroughModelInfo
18 mfi := fi.reverseFieldInfo
19 rfi := fi.reverseFieldInfoTwo
20
21 orm := o.qs.orm
22 dbase := orm.alias.DbBaser
23
24 var models []interface{}
25
26 for _, md := range mds {
27 val := reflect.ValueOf(md)
28 if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
29 for i := 0; i < val.Len(); i++ {
30 v := val.Index(i)
31 if v.CanInterface() {
32 models = append(models, v.Interface())
33 }
34 }
35 } else {
36 models = append(models, md)
37 }
38 }
39
40 _, v1, exist := getExistPk(o.mi, o.ind)
41 if exist == false {
42 panic(ErrMissPK)
43 }
44
45 names := []string{mfi.column, rfi.column}
46
47 var nums int64
48 for _, md := range models {
49
50 ind := reflect.Indirect(reflect.ValueOf(md))
51
52 var v2 interface{}
53 if ind.Kind() != reflect.Struct {
54 v2 = ind.Interface()
55 } else {
56 _, v2, exist = getExistPk(fi.relModelInfo, ind)
57 if exist == false {
58 panic(ErrMissPK)
59 }
60 }
61
62 values := []interface{}{v1, v2}
63 _, err := dbase.InsertValue(orm.db, mi, names, values)
64 if err != nil {
65 return nums, err
66 }
67
68 nums += 1
69 }
70
71 return nums, nil
72 }
73
74 func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
75 fi := o.fi
76 qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
77
78 nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
79 if err != nil {
80 return nums, err
81 }
82 return nums, nil
83 }
84
85 func (o *queryM2M) Exist(md interface{}) bool {
86 fi := o.fi
87 return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
88 Filter(fi.reverseFieldInfoTwo.name, md).Exist()
89 }
90
91 func (o *queryM2M) Clear() (int64, error) {
92 fi := o.fi
93 return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
94 }
95
96 func (o *queryM2M) Count() (int64, error) {
97 fi := o.fi
98 return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
99 }
100
101 var _ QueryM2Mer = new(queryM2M)
102
103 func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
104 qm2m := new(queryM2M)
105 qm2m.md = md
106 qm2m.mi = mi
107 qm2m.fi = fi
108 qm2m.ind = ind
109 qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet)
110 return qm2m
111 }
...@@ -48,9 +48,9 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b ...@@ -48,9 +48,9 @@ func ValuesCompare(is bool, a interface{}, args ...interface{}) (err error, ok b
48 ok = is && ok || !is && !ok 48 ok = is && ok || !is && !ok
49 if !ok { 49 if !ok {
50 if is { 50 if is {
51 err = fmt.Errorf("expected: a == `%v`, get `%v`", b, a) 51 err = fmt.Errorf("expected: `%v`, get `%v`", b, a)
52 } else { 52 } else {
53 err = fmt.Errorf("expected: a != `%v`, get `%v`", b, a) 53 err = fmt.Errorf("expected: `%v`, get `%v`", b, a)
54 } 54 }
55 } 55 }
56 56
...@@ -419,7 +419,7 @@ func TestInsertTestData(t *testing.T) { ...@@ -419,7 +419,7 @@ func TestInsertTestData(t *testing.T) {
419 throwFail(t, AssertIs(id, 4)) 419 throwFail(t, AssertIs(id, 4))
420 420
421 tags := []*Tag{ 421 tags := []*Tag{
422 &Tag{Name: "golang"}, 422 &Tag{Name: "golang", BestPost: &Post{Id: 2}},
423 &Tag{Name: "example"}, 423 &Tag{Name: "example"},
424 &Tag{Name: "format"}, 424 &Tag{Name: "format"},
425 &Tag{Name: "c++"}, 425 &Tag{Name: "c++"},
...@@ -454,7 +454,13 @@ The program—and web server—godoc processes Go source files to extract docume ...@@ -454,7 +454,13 @@ The program—and web server—godoc processes Go source files to extract docume
454 id, err := dORM.Insert(post) 454 id, err := dORM.Insert(post)
455 throwFail(t, err) 455 throwFail(t, err)
456 throwFail(t, AssertIs(id > 0, true)) 456 throwFail(t, AssertIs(id > 0, true))
457 // dORM.M2mAdd(post, "tags", post.Tags) 457
458 num := len(post.Tags)
459 if num > 0 {
460 nums, err := dORM.QueryM2M(post, "tags").Add(post.Tags)
461 throwFailNow(t, err)
462 throwFailNow(t, AssertIs(nums, num))
463 }
458 } 464 }
459 465
460 for _, comment := range comments { 466 for _, comment := range comments {
...@@ -590,6 +596,68 @@ func TestOperators(t *testing.T) { ...@@ -590,6 +596,68 @@ func TestOperators(t *testing.T) {
590 throwFail(t, AssertIs(num, 2)) 596 throwFail(t, AssertIs(num, 2))
591 } 597 }
592 598
599 func TestSetCond(t *testing.T) {
600 cond := NewCondition()
601 cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000)
602
603 qs := dORM.QueryTable("user")
604 num, err := qs.SetCond(cond1).Count()
605 throwFail(t, err)
606 throwFail(t, AssertIs(num, 1))
607
608 cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene"))
609 num, err = qs.SetCond(cond2).Count()
610 throwFail(t, err)
611 throwFail(t, AssertIs(num, 2))
612 }
613
614 func TestLimit(t *testing.T) {
615 var posts []*Post
616 qs := dORM.QueryTable("post")
617 num, err := qs.Limit(1).All(&posts)
618 throwFail(t, err)
619 throwFail(t, AssertIs(num, 1))
620
621 num, err = qs.Limit(-1).All(&posts)
622 throwFail(t, err)
623 throwFail(t, AssertIs(num, 4))
624
625 num, err = qs.Limit(-1, 2).All(&posts)
626 throwFail(t, err)
627 throwFail(t, AssertIs(num, 2))
628
629 num, err = qs.Limit(0, 2).All(&posts)
630 throwFail(t, err)
631 throwFail(t, AssertIs(num, 2))
632 }
633
634 func TestOffset(t *testing.T) {
635 var posts []*Post
636 qs := dORM.QueryTable("post")
637 num, err := qs.Limit(1).Offset(2).All(&posts)
638 throwFail(t, err)
639 throwFail(t, AssertIs(num, 1))
640
641 num, err = qs.Offset(2).All(&posts)
642 throwFail(t, err)
643 throwFail(t, AssertIs(num, 2))
644 }
645
646 func TestOrderBy(t *testing.T) {
647 qs := dORM.QueryTable("user")
648 num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count()
649 throwFail(t, err)
650 throwFail(t, AssertIs(num, 1))
651
652 num, err = qs.OrderBy("status").Filter("user_name", "slene").Count()
653 throwFail(t, err)
654 throwFail(t, AssertIs(num, 1))
655
656 num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count()
657 throwFail(t, err)
658 throwFail(t, AssertIs(num, 1))
659 }
660
593 func TestAll(t *testing.T) { 661 func TestAll(t *testing.T) {
594 var users []*User 662 var users []*User
595 qs := dORM.QueryTable("user") 663 qs := dORM.QueryTable("user")
...@@ -758,66 +826,292 @@ func TestRelatedSel(t *testing.T) { ...@@ -758,66 +826,292 @@ func TestRelatedSel(t *testing.T) {
758 throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody")) 826 throwFailNow(t, AssertIs(posts[3].User.UserName, "nobody"))
759 } 827 }
760 828
761 func TestSetCond(t *testing.T) { 829 func TestReverseQuery(t *testing.T) {
762 cond := NewCondition() 830 var profile Profile
763 cond1 := cond.And("profile__isnull", false).AndNot("status__in", 1).Or("profile__age__gt", 2000) 831 err := dORM.QueryTable("user_profile").Filter("User", 3).One(&profile)
832 throwFailNow(t, err)
833 throwFailNow(t, AssertIs(profile.Age, 30))
764 834
765 qs := dORM.QueryTable("user") 835 profile = Profile{}
766 num, err := qs.SetCond(cond1).Count() 836 err = dORM.QueryTable("user_profile").Filter("User__UserName", "astaxie").One(&profile)
767 throwFail(t, err) 837 throwFailNow(t, err)
768 throwFail(t, AssertIs(num, 1)) 838 throwFailNow(t, AssertIs(profile.Age, 30))
769 839
770 cond2 := cond.AndCond(cond1).OrCond(cond.And("user_name", "slene")) 840 var user User
771 num, err = qs.SetCond(cond2).Count() 841 err = dORM.QueryTable("user").Filter("Posts__Title", "Examples").One(&user)
772 throwFail(t, err) 842 throwFailNow(t, err)
773 throwFail(t, AssertIs(num, 2)) 843 throwFailNow(t, AssertIs(user.UserName, "astaxie"))
774 } 844
845 user = User{}
846 err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").Limit(1).One(&user)
847 throwFailNow(t, err)
848 throwFailNow(t, AssertIs(user.UserName, "astaxie"))
849
850 user = User{}
851 err = dORM.QueryTable("user").Filter("Posts__User__UserName", "astaxie").RelatedSel().Limit(1).One(&user)
852 throwFailNow(t, err)
853 throwFailNow(t, AssertIs(user.UserName, "astaxie"))
854 throwFailNow(t, AssertIs(user.Profile == nil, false))
855 throwFailNow(t, AssertIs(user.Profile.Age, 30))
775 856
776 func TestLimit(t *testing.T) {
777 var posts []*Post 857 var posts []*Post
778 qs := dORM.QueryTable("post") 858 num, err := dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").All(&posts)
779 num, err := qs.Limit(1).All(&posts) 859 throwFailNow(t, err)
780 throwFail(t, err) 860 throwFailNow(t, AssertIs(num, 3))
781 throwFail(t, AssertIs(num, 1)) 861 throwFailNow(t, AssertIs(posts[0].Title, "Introduction"))
782 862
783 num, err = qs.Limit(-1).All(&posts) 863 posts = []*Post{}
784 throwFail(t, err) 864 num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").Filter("User__UserName", "slene").All(&posts)
785 throwFail(t, AssertIs(num, 4)) 865 throwFailNow(t, err)
866 throwFailNow(t, AssertIs(num, 1))
867 throwFailNow(t, AssertIs(posts[0].Title, "Introduction"))
786 868
787 num, err = qs.Limit(-1, 2).All(&posts) 869 posts = []*Post{}
788 throwFail(t, err) 870 num, err = dORM.QueryTable("post").Filter("Tags__Tag__Name", "golang").
789 throwFail(t, AssertIs(num, 2)) 871 Filter("User__UserName", "slene").RelatedSel().All(&posts)
872 throwFailNow(t, err)
873 throwFailNow(t, AssertIs(num, 1))
874 throwFailNow(t, AssertIs(posts[0].User == nil, false))
875 throwFailNow(t, AssertIs(posts[0].User.UserName, "slene"))
790 876
791 num, err = qs.Limit(0, 2).All(&posts) 877 var tags []*Tag
792 throwFail(t, err) 878 num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").All(&tags)
793 throwFail(t, AssertIs(num, 2)) 879 throwFailNow(t, err)
880 throwFailNow(t, AssertIs(num, 1))
881 throwFailNow(t, AssertIs(tags[0].Name, "golang"))
882
883 tags = []*Tag{}
884 num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").
885 Filter("BestPost__User__UserName", "astaxie").All(&tags)
886 throwFailNow(t, err)
887 throwFailNow(t, AssertIs(num, 1))
888 throwFailNow(t, AssertIs(tags[0].Name, "golang"))
889
890 tags = []*Tag{}
891 num, err = dORM.QueryTable("tag").Filter("Posts__Post__Title", "Introduction").
892 Filter("BestPost__User__UserName", "astaxie").RelatedSel().All(&tags)
893 throwFailNow(t, err)
894 throwFailNow(t, AssertIs(num, 1))
895 throwFailNow(t, AssertIs(tags[0].Name, "golang"))
896 throwFailNow(t, AssertIs(tags[0].BestPost == nil, false))
897 throwFailNow(t, AssertIs(tags[0].BestPost.Title, "Examples"))
898 throwFailNow(t, AssertIs(tags[0].BestPost.User == nil, false))
899 throwFailNow(t, AssertIs(tags[0].BestPost.User.UserName, "astaxie"))
794 } 900 }
795 901
796 func TestOffset(t *testing.T) { 902 func TestLoadRelated(t *testing.T) {
797 var posts []*Post 903 // load reverse foreign key
798 qs := dORM.QueryTable("post") 904 user := User{Id: 3}
799 num, err := qs.Limit(1).Offset(2).All(&posts)
800 throwFail(t, err)
801 throwFail(t, AssertIs(num, 1))
802 905
803 num, err = qs.Offset(2).All(&posts) 906 err := dORM.Read(&user)
804 throwFail(t, err) 907 throwFailNow(t, err)
805 throwFail(t, AssertIs(num, 2)) 908
909 num, err := dORM.LoadRelated(&user, "Posts")
910 throwFailNow(t, err)
911 throwFailNow(t, AssertIs(num, 2))
912 throwFailNow(t, AssertIs(len(user.Posts), 2))
913 throwFailNow(t, AssertIs(user.Posts[0].User.Id, 3))
914
915 num, err = dORM.LoadRelated(&user, "Posts", true)
916 throwFailNow(t, err)
917 throwFailNow(t, AssertIs(len(user.Posts), 2))
918 throwFailNow(t, AssertIs(user.Posts[0].User.UserName, "astaxie"))
919
920 num, err = dORM.LoadRelated(&user, "Posts", true, 1)
921 throwFailNow(t, err)
922 throwFailNow(t, AssertIs(len(user.Posts), 1))
923
924 num, err = dORM.LoadRelated(&user, "Posts", true, 0, 0, "-Id")
925 throwFailNow(t, err)
926 throwFailNow(t, AssertIs(len(user.Posts), 2))
927 throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
928
929 num, err = dORM.LoadRelated(&user, "Posts", true, 1, 1, "Id")
930 throwFailNow(t, err)
931 throwFailNow(t, AssertIs(len(user.Posts), 1))
932 throwFailNow(t, AssertIs(user.Posts[0].Title, "Formatting"))
933
934 // load reverse one to one
935 profile := Profile{Id: 3}
936 profile.BestPost = &Post{Id: 2}
937 num, err = dORM.Update(&profile, "BestPost")
938 throwFailNow(t, err)
939 throwFailNow(t, AssertIs(num, 1))
940
941 err = dORM.Read(&profile)
942 throwFailNow(t, err)
943
944 num, err = dORM.LoadRelated(&profile, "User")
945 throwFailNow(t, err)
946 throwFailNow(t, AssertIs(num, 1))
947 throwFailNow(t, AssertIs(profile.User == nil, false))
948 throwFailNow(t, AssertIs(profile.User.UserName, "astaxie"))
949
950 num, err = dORM.LoadRelated(&profile, "User", true)
951 throwFailNow(t, err)
952 throwFailNow(t, AssertIs(num, 1))
953 throwFailNow(t, AssertIs(profile.User == nil, false))
954 throwFailNow(t, AssertIs(profile.User.UserName, "astaxie"))
955 throwFailNow(t, AssertIs(profile.User.Profile.Age, profile.Age))
956
957 // load rel one to one
958 err = dORM.Read(&user)
959 throwFailNow(t, err)
960
961 num, err = dORM.LoadRelated(&user, "Profile")
962 throwFailNow(t, err)
963 throwFailNow(t, AssertIs(num, 1))
964 throwFailNow(t, AssertIs(user.Profile == nil, false))
965 throwFailNow(t, AssertIs(user.Profile.Age, 30))
966
967 num, err = dORM.LoadRelated(&user, "Profile", true)
968 throwFailNow(t, err)
969 throwFailNow(t, AssertIs(num, 1))
970 throwFailNow(t, AssertIs(user.Profile == nil, false))
971 throwFailNow(t, AssertIs(user.Profile.Age, 30))
972 throwFailNow(t, AssertIs(user.Profile.BestPost == nil, false))
973 throwFailNow(t, AssertIs(user.Profile.BestPost.Title, "Examples"))
974
975 post := Post{Id: 2}
976
977 // load rel foreign key
978 err = dORM.Read(&post)
979 throwFailNow(t, err)
980
981 num, err = dORM.LoadRelated(&post, "User")
982 throwFailNow(t, err)
983 throwFailNow(t, AssertIs(num, 1))
984 throwFailNow(t, AssertIs(post.User == nil, false))
985 throwFailNow(t, AssertIs(post.User.UserName, "astaxie"))
986
987 num, err = dORM.LoadRelated(&post, "User", true)
988 throwFailNow(t, err)
989 throwFailNow(t, AssertIs(num, 1))
990 throwFailNow(t, AssertIs(post.User == nil, false))
991 throwFailNow(t, AssertIs(post.User.UserName, "astaxie"))
992 throwFailNow(t, AssertIs(post.User.Profile == nil, false))
993 throwFailNow(t, AssertIs(post.User.Profile.Age, 30))
994
995 // load rel m2m
996 post = Post{Id: 2}
997
998 err = dORM.Read(&post)
999 throwFailNow(t, err)
1000
1001 num, err = dORM.LoadRelated(&post, "Tags")
1002 throwFailNow(t, err)
1003 throwFailNow(t, AssertIs(num, 2))
1004 throwFailNow(t, AssertIs(len(post.Tags), 2))
1005 throwFailNow(t, AssertIs(post.Tags[0].Name, "golang"))
1006
1007 num, err = dORM.LoadRelated(&post, "Tags", true)
1008 throwFailNow(t, err)
1009 throwFailNow(t, AssertIs(num, 2))
1010 throwFailNow(t, AssertIs(len(post.Tags), 2))
1011 throwFailNow(t, AssertIs(post.Tags[0].Name, "golang"))
1012 throwFailNow(t, AssertIs(post.Tags[0].BestPost == nil, false))
1013 throwFailNow(t, AssertIs(post.Tags[0].BestPost.User.UserName, "astaxie"))
1014
1015 // load reverse m2m
1016 tag := Tag{Id: 1}
1017
1018 err = dORM.Read(&tag)
1019 throwFailNow(t, err)
1020
1021 num, err = dORM.LoadRelated(&tag, "Posts")
1022 throwFailNow(t, err)
1023 throwFailNow(t, AssertIs(num, 3))
1024 throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
1025 throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
1026 throwFailNow(t, AssertIs(tag.Posts[0].User.Profile == nil, true))
1027
1028 num, err = dORM.LoadRelated(&tag, "Posts", true)
1029 throwFailNow(t, err)
1030 throwFailNow(t, AssertIs(num, 3))
1031 throwFailNow(t, AssertIs(tag.Posts[0].Title, "Introduction"))
1032 throwFailNow(t, AssertIs(tag.Posts[0].User.Id, 2))
1033 throwFailNow(t, AssertIs(tag.Posts[0].User.UserName, "slene"))
806 } 1034 }
807 1035
808 func TestOrderBy(t *testing.T) { 1036 func TestQueryM2M(t *testing.T) {
809 qs := dORM.QueryTable("user") 1037 post := Post{Id: 4}
810 num, err := qs.OrderBy("-status").Filter("user_name", "nobody").Count() 1038 m2m := dORM.QueryM2M(&post, "Tags")
811 throwFail(t, err)
812 throwFail(t, AssertIs(num, 1))
813 1039
814 num, err = qs.OrderBy("status").Filter("user_name", "slene").Count() 1040 tag1 := []*Tag{&Tag{Name: "TestTag1"}, &Tag{Name: "TestTag2"}}
815 throwFail(t, err) 1041 tag2 := &Tag{Name: "TestTag3"}
816 throwFail(t, AssertIs(num, 1)) 1042 tag3 := []interface{}{&Tag{Name: "TestTag4"}}
817 1043
818 num, err = qs.OrderBy("-profile__age").Filter("user_name", "astaxie").Count() 1044 tags := []interface{}{tag1[0], tag1[1], tag2, tag3[0]}
819 throwFail(t, err) 1045
820 throwFail(t, AssertIs(num, 1)) 1046 for _, tag := range tags {
1047 _, err := dORM.Insert(tag)
1048 throwFailNow(t, err)
1049 }
1050
1051 num, err := m2m.Add(tag1)
1052 throwFailNow(t, err)
1053 throwFailNow(t, AssertIs(num, 2))
1054
1055 num, err = m2m.Add(tag2)
1056 throwFailNow(t, err)
1057 throwFailNow(t, AssertIs(num, 1))
1058
1059 num, err = m2m.Add(tag3)
1060 throwFailNow(t, err)
1061 throwFailNow(t, AssertIs(num, 1))
1062
1063 num, err = m2m.Count()
1064 throwFailNow(t, err)
1065 throwFailNow(t, AssertIs(num, 5))
1066
1067 num, err = m2m.Remove(tag3)
1068 throwFailNow(t, err)
1069 throwFailNow(t, AssertIs(num, 1))
1070
1071 num, err = m2m.Count()
1072 throwFailNow(t, err)
1073 throwFailNow(t, AssertIs(num, 4))
1074
1075 exist := m2m.Exist(tag2)
1076 throwFailNow(t, AssertIs(exist, true))
1077
1078 num, err = m2m.Remove(tag2)
1079 throwFailNow(t, err)
1080 throwFailNow(t, AssertIs(num, 1))
1081
1082 exist = m2m.Exist(tag2)
1083 throwFailNow(t, AssertIs(exist, false))
1084
1085 num, err = m2m.Count()
1086 throwFailNow(t, err)
1087 throwFailNow(t, AssertIs(num, 3))
1088
1089 num, err = m2m.Clear()
1090 throwFailNow(t, err)
1091 throwFailNow(t, AssertIs(num, 3))
1092
1093 num, err = m2m.Count()
1094 throwFailNow(t, err)
1095 throwFailNow(t, AssertIs(num, 0))
1096 }
1097
1098 func TestQueryRelate(t *testing.T) {
1099 // post := &Post{Id: 2}
1100
1101 // qs := dORM.QueryRelate(post, "Tags")
1102 // num, err := qs.Count()
1103 // throwFailNow(t, err)
1104 // throwFailNow(t, AssertIs(num, 2))
1105
1106 // var tags []*Tag
1107 // num, err = qs.All(&tags)
1108 // throwFailNow(t, err)
1109 // throwFailNow(t, AssertIs(num, 2))
1110 // throwFailNow(t, AssertIs(tags[0].Name, "golang"))
1111
1112 // num, err = dORM.QueryTable("Tag").Filter("Posts__Post", 2).Count()
1113 // throwFailNow(t, err)
1114 // throwFailNow(t, AssertIs(num, 2))
821 } 1115 }
822 1116
823 func TestPrepareInsert(t *testing.T) { 1117 func TestPrepareInsert(t *testing.T) {
......
...@@ -24,9 +24,8 @@ type Ormer interface { ...@@ -24,9 +24,8 @@ type Ormer interface {
24 Insert(interface{}) (int64, error) 24 Insert(interface{}) (int64, error)
25 Update(interface{}, ...string) (int64, error) 25 Update(interface{}, ...string) (int64, error)
26 Delete(interface{}) (int64, error) 26 Delete(interface{}) (int64, error)
27 M2mAdd(interface{}, string, ...interface{}) (int64, error) 27 LoadRelated(interface{}, string, ...interface{}) (int64, error)
28 M2mDel(interface{}, string, ...interface{}) (int64, error) 28 QueryM2M(interface{}, string) QueryM2Mer
29 LoadRel(interface{}, string) (int64, error)
30 QueryTable(interface{}) QuerySeter 29 QueryTable(interface{}) QuerySeter
31 Using(string) error 30 Using(string) error
32 Begin() error 31 Begin() error
...@@ -61,6 +60,14 @@ type QuerySeter interface { ...@@ -61,6 +60,14 @@ type QuerySeter interface {
61 ValuesFlat(*ParamsList, string) (int64, error) 60 ValuesFlat(*ParamsList, string) (int64, error)
62 } 61 }
63 62
63 type QueryM2Mer interface {
64 Add(...interface{}) (int64, error)
65 Remove(...interface{}) (int64, error)
66 Exist(interface{}) bool
67 Clear() (int64, error)
68 Count() (int64, error)
69 }
70
64 type RawPreparer interface { 71 type RawPreparer interface {
65 Exec(...interface{}) (sql.Result, error) 72 Exec(...interface{}) (sql.Result, error)
66 Close() error 73 Close() error
...@@ -114,6 +121,7 @@ type txEnder interface { ...@@ -114,6 +121,7 @@ type txEnder interface {
114 type dbBaser interface { 121 type dbBaser interface {
115 Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error 122 Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
116 Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 123 Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
124 InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
117 InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 125 InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
118 Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error) 126 Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
119 Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error) 127 Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
...@@ -139,4 +147,5 @@ type dbBaser interface { ...@@ -139,4 +147,5 @@ type dbBaser interface {
139 ShowTablesQuery() string 147 ShowTablesQuery() string
140 ShowColumnsQuery(string) string 148 ShowColumnsQuery(string) string
141 IndexExists(dbQuerier, string, string) bool 149 IndexExists(dbQuerier, string, string) bool
150 collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
142 } 151 }
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!