9e3ebc88 by slene

Merge pull request #513 from hobeone/develop

add support for sql.Null* types, thx hobeone
2 parents d05270d2 6e00cfb4
...@@ -103,16 +103,37 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val ...@@ -103,16 +103,37 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
103 } else { 103 } else {
104 switch fi.fieldType { 104 switch fi.fieldType {
105 case TypeBooleanField: 105 case TypeBooleanField:
106 if nb, ok := field.Interface().(sql.NullBool); ok {
107 value = nil
108 if nb.Valid {
109 value = nb.Bool
110 }
111 } else {
106 value = field.Bool() 112 value = field.Bool()
113 }
107 case TypeCharField, TypeTextField: 114 case TypeCharField, TypeTextField:
115 if ns, ok := field.Interface().(sql.NullString); ok {
116 value = nil
117 if ns.Valid {
118 value = ns.String
119 }
120 } else {
108 value = field.String() 121 value = field.String()
122 }
109 case TypeFloatField, TypeDecimalField: 123 case TypeFloatField, TypeDecimalField:
124 if nf, ok := field.Interface().(sql.NullFloat64); ok {
125 value = nil
126 if nf.Valid {
127 value = nf.Float64
128 }
129 } else {
110 vu := field.Interface() 130 vu := field.Interface()
111 if _, ok := vu.(float32); ok { 131 if _, ok := vu.(float32); ok {
112 value, _ = StrTo(ToStr(vu)).Float64() 132 value, _ = StrTo(ToStr(vu)).Float64()
113 } else { 133 } else {
114 value = field.Float() 134 value = field.Float()
115 } 135 }
136 }
116 case TypeDateField, TypeDateTimeField: 137 case TypeDateField, TypeDateTimeField:
117 value = field.Interface() 138 value = field.Interface()
118 if t, ok := value.(time.Time); ok { 139 if t, ok := value.(time.Time); ok {
...@@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val ...@@ -124,7 +145,14 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
124 case fi.fieldType&IsPostiveIntegerField > 0: 145 case fi.fieldType&IsPostiveIntegerField > 0:
125 value = field.Uint() 146 value = field.Uint()
126 case fi.fieldType&IsIntegerField > 0: 147 case fi.fieldType&IsIntegerField > 0:
148 if ni, ok := field.Interface().(sql.NullInt64); ok {
149 value = nil
150 if ni.Valid {
151 value = ni.Int64
152 }
153 } else {
127 value = field.Int() 154 value = field.Int()
155 }
128 case fi.fieldType&IsRelField > 0: 156 case fi.fieldType&IsRelField > 0:
129 if field.IsNil() { 157 if field.IsNil() {
130 value = nil 158 value = nil
...@@ -1122,18 +1150,38 @@ setValue: ...@@ -1122,18 +1150,38 @@ setValue:
1122 switch { 1150 switch {
1123 case fieldType == TypeBooleanField: 1151 case fieldType == TypeBooleanField:
1124 if isNative { 1152 if isNative {
1153 if nb, ok := field.Interface().(sql.NullBool); ok {
1154 if value == nil {
1155 nb.Valid = false
1156 } else {
1157 nb.Bool = value.(bool)
1158 nb.Valid = true
1159 }
1160 field.Set(reflect.ValueOf(nb))
1161 } else {
1125 if value == nil { 1162 if value == nil {
1126 value = false 1163 value = false
1127 } 1164 }
1128 field.SetBool(value.(bool)) 1165 field.SetBool(value.(bool))
1129 } 1166 }
1167 }
1130 case fieldType == TypeCharField || fieldType == TypeTextField: 1168 case fieldType == TypeCharField || fieldType == TypeTextField:
1131 if isNative { 1169 if isNative {
1170 if ns, ok := field.Interface().(sql.NullString); ok {
1171 if value == nil {
1172 ns.Valid = false
1173 } else {
1174 ns.String = value.(string)
1175 ns.Valid = true
1176 }
1177 field.Set(reflect.ValueOf(ns))
1178 } else {
1132 if value == nil { 1179 if value == nil {
1133 value = "" 1180 value = ""
1134 } 1181 }
1135 field.SetString(value.(string)) 1182 field.SetString(value.(string))
1136 } 1183 }
1184 }
1137 case fieldType == TypeDateField || fieldType == TypeDateTimeField: 1185 case fieldType == TypeDateField || fieldType == TypeDateTimeField:
1138 if isNative { 1186 if isNative {
1139 if value == nil { 1187 if value == nil {
...@@ -1151,19 +1199,40 @@ setValue: ...@@ -1151,19 +1199,40 @@ setValue:
1151 } 1199 }
1152 } else { 1200 } else {
1153 if isNative { 1201 if isNative {
1202 if ni, ok := field.Interface().(sql.NullInt64); ok {
1203 if value == nil {
1204 ni.Valid = false
1205 } else {
1206 ni.Int64 = value.(int64)
1207 ni.Valid = true
1208 }
1209 field.Set(reflect.ValueOf(ni))
1210 } else {
1154 if value == nil { 1211 if value == nil {
1155 value = int64(0) 1212 value = int64(0)
1156 } 1213 }
1157 field.SetInt(value.(int64)) 1214 field.SetInt(value.(int64))
1158 } 1215 }
1159 } 1216 }
1217 }
1160 case fieldType == TypeFloatField || fieldType == TypeDecimalField: 1218 case fieldType == TypeFloatField || fieldType == TypeDecimalField:
1161 if isNative { 1219 if isNative {
1220 if nf, ok := field.Interface().(sql.NullFloat64); ok {
1221 if value == nil {
1222 nf.Valid = false
1223 } else {
1224 nf.Float64 = value.(float64)
1225 nf.Valid = true
1226 }
1227 field.Set(reflect.ValueOf(nf))
1228 } else {
1229
1162 if value == nil { 1230 if value == nil {
1163 value = float64(0) 1231 value = float64(0)
1164 } 1232 }
1165 field.SetFloat(value.(float64)) 1233 field.SetFloat(value.(float64))
1166 } 1234 }
1235 }
1167 case fieldType&IsRelField > 0: 1236 case fieldType&IsRelField > 0:
1168 if value != nil { 1237 if value != nil {
1169 fieldType = fi.relModelInfo.fields.pk.fieldType 1238 fieldType = fi.relModelInfo.fields.pk.fieldType
......
1 package orm 1 package orm
2 2
3 import ( 3 import (
4 "database/sql"
4 "encoding/json" 5 "encoding/json"
5 "fmt" 6 "fmt"
6 "os" 7 "os"
...@@ -137,6 +138,10 @@ type DataNull struct { ...@@ -137,6 +138,10 @@ type DataNull struct {
137 Float32 float32 `orm:"null"` 138 Float32 float32 `orm:"null"`
138 Float64 float64 `orm:"null"` 139 Float64 float64 `orm:"null"`
139 Decimal float64 `orm:"digits(8);decimals(4);null"` 140 Decimal float64 `orm:"digits(8);decimals(4);null"`
141 NullString sql.NullString `orm:"null"`
142 NullBool sql.NullBool `orm:"null"`
143 NullFloat64 sql.NullFloat64 `orm:"null"`
144 NullInt64 sql.NullInt64 `orm:"null"`
140 } 145 }
141 146
142 // only for mysql 147 // only for mysql
...@@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm ...@@ -303,9 +308,8 @@ go test -v github.com/astaxie/beego/orm
303 308
304 309
305 #### Sqlite3 310 #### Sqlite3
306 touch /path/to/orm_test.db
307 export ORM_DRIVER=sqlite3 311 export ORM_DRIVER=sqlite3
308 export ORM_SOURCE=/path/to/orm_test.db 312 export ORM_SOURCE='file:memory_test?mode=memory'
309 go test -v github.com/astaxie/beego/orm 313 go test -v github.com/astaxie/beego/orm
310 314
311 315
......
1 package orm 1 package orm
2 2
3 import ( 3 import (
4 "database/sql"
4 "fmt" 5 "fmt"
5 "reflect" 6 "reflect"
6 "strings" 7 "strings"
...@@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col ...@@ -98,30 +99,29 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
98 // return field type as type constant from reflect.Value 99 // return field type as type constant from reflect.Value
99 func getFieldType(val reflect.Value) (ft int, err error) { 100 func getFieldType(val reflect.Value) (ft int, err error) {
100 elm := reflect.Indirect(val) 101 elm := reflect.Indirect(val)
101 switch elm.Kind() { 102 switch elm.Interface().(type) {
102 case reflect.Int8: 103 case int8:
103 ft = TypeBitField 104 ft = TypeBitField
104 case reflect.Int16: 105 case int16:
105 ft = TypeSmallIntegerField 106 ft = TypeSmallIntegerField
106 case reflect.Int32, reflect.Int: 107 case int32, int:
107 ft = TypeIntegerField 108 ft = TypeIntegerField
108 case reflect.Int64: 109 case int64, sql.NullInt64:
109 ft = TypeBigIntegerField 110 ft = TypeBigIntegerField
110 case reflect.Uint8: 111 case uint8:
111 ft = TypePositiveBitField 112 ft = TypePositiveBitField
112 case reflect.Uint16: 113 case uint16:
113 ft = TypePositiveSmallIntegerField 114 ft = TypePositiveSmallIntegerField
114 case reflect.Uint32, reflect.Uint: 115 case uint32, uint:
115 ft = TypePositiveIntegerField 116 ft = TypePositiveIntegerField
116 case reflect.Uint64: 117 case uint64:
117 ft = TypePositiveBigIntegerField 118 ft = TypePositiveBigIntegerField
118 case reflect.Float32, reflect.Float64: 119 case float32, float64, sql.NullFloat64:
119 ft = TypeFloatField 120 ft = TypeFloatField
120 case reflect.Bool: 121 case bool, sql.NullBool:
121 ft = TypeBooleanField 122 ft = TypeBooleanField
122 case reflect.String: 123 case string, sql.NullString:
123 ft = TypeCharField 124 ft = TypeCharField
124 case reflect.Invalid:
125 default: 125 default:
126 if elm.CanInterface() { 126 if elm.CanInterface() {
127 if _, ok := elm.Interface().(time.Time); ok { 127 if _, ok := elm.Interface().(time.Time); ok {
......
...@@ -2,6 +2,7 @@ package orm ...@@ -2,6 +2,7 @@ package orm
2 2
3 import ( 3 import (
4 "bytes" 4 "bytes"
5 "database/sql"
5 "fmt" 6 "fmt"
6 "io/ioutil" 7 "io/ioutil"
7 "os" 8 "os"
...@@ -258,12 +259,45 @@ func TestNullDataTypes(t *testing.T) { ...@@ -258,12 +259,45 @@ func TestNullDataTypes(t *testing.T) {
258 err = dORM.Read(&d) 259 err = dORM.Read(&d)
259 throwFail(t, err) 260 throwFail(t, err)
260 261
262 throwFail(t, AssertIs(d.NullBool.Valid, false))
263 throwFail(t, AssertIs(d.NullString.Valid, false))
264 throwFail(t, AssertIs(d.NullInt64.Valid, false))
265 throwFail(t, AssertIs(d.NullFloat64.Valid, false))
266
261 _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() 267 _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
262 throwFail(t, err) 268 throwFail(t, err)
263 269
264 d = DataNull{Id: 2} 270 d = DataNull{Id: 2}
265 err = dORM.Read(&d) 271 err = dORM.Read(&d)
266 throwFail(t, err) 272 throwFail(t, err)
273
274 d = DataNull{
275 DateTime: time.Now(),
276 NullString: sql.NullString{"test", true},
277 NullBool: sql.NullBool{true, true},
278 NullInt64: sql.NullInt64{42, true},
279 NullFloat64: sql.NullFloat64{42.42, true},
280 }
281
282 id, err = dORM.Insert(&d)
283 throwFail(t, err)
284 throwFail(t, AssertIs(id, 3))
285
286 d = DataNull{Id: 3}
287 err = dORM.Read(&d)
288 throwFail(t, err)
289
290 throwFail(t, AssertIs(d.NullBool.Valid, true))
291 throwFail(t, AssertIs(d.NullBool.Bool, true))
292
293 throwFail(t, AssertIs(d.NullString.Valid, true))
294 throwFail(t, AssertIs(d.NullString.String, "test"))
295
296 throwFail(t, AssertIs(d.NullInt64.Valid, true))
297 throwFail(t, AssertIs(d.NullInt64.Int64, 42))
298
299 throwFail(t, AssertIs(d.NullFloat64.Valid, true))
300 throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
267 } 301 }
268 302
269 func TestCRUD(t *testing.T) { 303 func TestCRUD(t *testing.T) {
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!