d7c3727f by Peter Fern

Add support for basic type pointer fields

Allows models like:

```
type User struct {
	Id    int64
	Name  string
	Email *string `orm:"null"`
}
```

This helps a lot when doing JSON marshalling/unmarshalling.

Pointer fields should always be declared with the NULL orm tag for sanity, this
probably requires documentation.
1 parent 03eb1fc1
...@@ -122,6 +122,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val ...@@ -122,6 +122,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
122 if nb.Valid { 122 if nb.Valid {
123 value = nb.Bool 123 value = nb.Bool
124 } 124 }
125 } else if field.Kind() == reflect.Ptr {
126 if field.IsNil() {
127 value = nil
128 } else {
129 value = field.Elem().Bool()
130 }
125 } else { 131 } else {
126 value = field.Bool() 132 value = field.Bool()
127 } 133 }
...@@ -131,6 +137,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val ...@@ -131,6 +137,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
131 if ns.Valid { 137 if ns.Valid {
132 value = ns.String 138 value = ns.String
133 } 139 }
140 } else if field.Kind() == reflect.Ptr {
141 if field.IsNil() {
142 value = nil
143 } else {
144 value = field.Elem().String()
145 }
134 } else { 146 } else {
135 value = field.String() 147 value = field.String()
136 } 148 }
...@@ -140,6 +152,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val ...@@ -140,6 +152,12 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
140 if nf.Valid { 152 if nf.Valid {
141 value = nf.Float64 153 value = nf.Float64
142 } 154 }
155 } else if field.Kind() == reflect.Ptr {
156 if field.IsNil() {
157 value = nil
158 } else {
159 value = field.Elem().Float()
160 }
143 } else { 161 } else {
144 vu := field.Interface() 162 vu := field.Interface()
145 if _, ok := vu.(float32); ok { 163 if _, ok := vu.(float32); ok {
...@@ -161,13 +179,27 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val ...@@ -161,13 +179,27 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
161 default: 179 default:
162 switch { 180 switch {
163 case fi.fieldType&IsPostiveIntegerField > 0: 181 case fi.fieldType&IsPostiveIntegerField > 0:
164 value = field.Uint() 182 if field.Kind() == reflect.Ptr {
183 if field.IsNil() {
184 value = nil
185 } else {
186 value = field.Elem().Uint()
187 }
188 } else {
189 value = field.Uint()
190 }
165 case fi.fieldType&IsIntegerField > 0: 191 case fi.fieldType&IsIntegerField > 0:
166 if ni, ok := field.Interface().(sql.NullInt64); ok { 192 if ni, ok := field.Interface().(sql.NullInt64); ok {
167 value = nil 193 value = nil
168 if ni.Valid { 194 if ni.Valid {
169 value = ni.Int64 195 value = ni.Int64
170 } 196 }
197 } else if field.Kind() == reflect.Ptr {
198 if field.IsNil() {
199 value = nil
200 } else {
201 value = field.Elem().Int()
202 }
171 } else { 203 } else {
172 value = field.Int() 204 value = field.Int()
173 } 205 }
...@@ -1177,6 +1209,11 @@ setValue: ...@@ -1177,6 +1209,11 @@ setValue:
1177 nb.Valid = true 1209 nb.Valid = true
1178 } 1210 }
1179 field.Set(reflect.ValueOf(nb)) 1211 field.Set(reflect.ValueOf(nb))
1212 } else if field.Kind() == reflect.Ptr {
1213 if value != nil {
1214 v := value.(bool)
1215 field.Set(reflect.ValueOf(&v))
1216 }
1180 } else { 1217 } else {
1181 if value == nil { 1218 if value == nil {
1182 value = false 1219 value = false
...@@ -1194,6 +1231,11 @@ setValue: ...@@ -1194,6 +1231,11 @@ setValue:
1194 ns.Valid = true 1231 ns.Valid = true
1195 } 1232 }
1196 field.Set(reflect.ValueOf(ns)) 1233 field.Set(reflect.ValueOf(ns))
1234 } else if field.Kind() == reflect.Ptr {
1235 if value != nil {
1236 v := value.(string)
1237 field.Set(reflect.ValueOf(&v))
1238 }
1197 } else { 1239 } else {
1198 if value == nil { 1240 if value == nil {
1199 value = "" 1241 value = ""
...@@ -1208,6 +1250,56 @@ setValue: ...@@ -1208,6 +1250,56 @@ setValue:
1208 } 1250 }
1209 field.Set(reflect.ValueOf(value)) 1251 field.Set(reflect.ValueOf(value))
1210 } 1252 }
1253 case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr:
1254 if value != nil {
1255 v := uint8(value.(uint64))
1256 field.Set(reflect.ValueOf(&v))
1257 }
1258 case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr:
1259 if value != nil {
1260 v := uint16(value.(uint64))
1261 field.Set(reflect.ValueOf(&v))
1262 }
1263 case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr:
1264 if value != nil {
1265 if field.Type() == reflect.TypeOf(new(uint)) {
1266 v := uint(value.(uint64))
1267 field.Set(reflect.ValueOf(&v))
1268 } else {
1269 v := uint32(value.(uint64))
1270 field.Set(reflect.ValueOf(&v))
1271 }
1272 }
1273 case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr:
1274 if value != nil {
1275 v := value.(uint64)
1276 field.Set(reflect.ValueOf(&v))
1277 }
1278 case fieldType == TypeBitField && field.Kind() == reflect.Ptr:
1279 if value != nil {
1280 v := int8(value.(int64))
1281 field.Set(reflect.ValueOf(&v))
1282 }
1283 case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr:
1284 if value != nil {
1285 v := int16(value.(int64))
1286 field.Set(reflect.ValueOf(&v))
1287 }
1288 case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr:
1289 if value != nil {
1290 if field.Type() == reflect.TypeOf(new(int)) {
1291 v := int(value.(int64))
1292 field.Set(reflect.ValueOf(&v))
1293 } else {
1294 v := int32(value.(int64))
1295 field.Set(reflect.ValueOf(&v))
1296 }
1297 }
1298 case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr:
1299 if value != nil {
1300 v := value.(int64)
1301 field.Set(reflect.ValueOf(&v))
1302 }
1211 case fieldType&IsIntegerField > 0: 1303 case fieldType&IsIntegerField > 0:
1212 if fieldType&IsPostiveIntegerField > 0 { 1304 if fieldType&IsPostiveIntegerField > 0 {
1213 if isNative { 1305 if isNative {
...@@ -1244,6 +1336,16 @@ setValue: ...@@ -1244,6 +1336,16 @@ setValue:
1244 nf.Valid = true 1336 nf.Valid = true
1245 } 1337 }
1246 field.Set(reflect.ValueOf(nf)) 1338 field.Set(reflect.ValueOf(nf))
1339 } else if field.Kind() == reflect.Ptr {
1340 if value != nil {
1341 if field.Type() == reflect.TypeOf(new(float32)) {
1342 v := float32(value.(float64))
1343 field.Set(reflect.ValueOf(&v))
1344 } else {
1345 v := value.(float64)
1346 field.Set(reflect.ValueOf(&v))
1347 }
1348 }
1247 } else { 1349 } else {
1248 1350
1249 if value == nil { 1351 if value == nil {
......
...@@ -155,6 +155,24 @@ type DataNull struct { ...@@ -155,6 +155,24 @@ type DataNull struct {
155 NullBool sql.NullBool `orm:"null"` 155 NullBool sql.NullBool `orm:"null"`
156 NullFloat64 sql.NullFloat64 `orm:"null"` 156 NullFloat64 sql.NullFloat64 `orm:"null"`
157 NullInt64 sql.NullInt64 `orm:"null"` 157 NullInt64 sql.NullInt64 `orm:"null"`
158 BooleanPtr *bool `orm:"null"`
159 CharPtr *string `orm:"null;size(50)"`
160 TextPtr *string `orm:"null;type(text)"`
161 BytePtr *byte `orm:"null"`
162 RunePtr *rune `orm:"null"`
163 IntPtr *int `orm:"null"`
164 Int8Ptr *int8 `orm:"null"`
165 Int16Ptr *int16 `orm:"null"`
166 Int32Ptr *int32 `orm:"null"`
167 Int64Ptr *int64 `orm:"null"`
168 UintPtr *uint `orm:"null"`
169 Uint8Ptr *uint8 `orm:"null"`
170 Uint16Ptr *uint16 `orm:"null"`
171 Uint32Ptr *uint32 `orm:"null"`
172 Uint64Ptr *uint64 `orm:"null"`
173 Float32Ptr *float32 `orm:"null"`
174 Float64Ptr *float64 `orm:"null"`
175 DecimalPtr *float64 `orm:"digits(8);decimals(4);null"`
158 } 176 }
159 177
160 type String string 178 type String string
......
...@@ -111,45 +111,73 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col ...@@ -111,45 +111,73 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
111 111
112 // return field type as type constant from reflect.Value 112 // return field type as type constant from reflect.Value
113 func getFieldType(val reflect.Value) (ft int, err error) { 113 func getFieldType(val reflect.Value) (ft int, err error) {
114 elm := reflect.Indirect(val) 114 switch val.Type() {
115 switch elm.Kind() { 115 case reflect.TypeOf(new(int8)):
116 case reflect.Int8:
117 ft = TypeBitField 116 ft = TypeBitField
118 case reflect.Int16: 117 case reflect.TypeOf(new(int16)):
119 ft = TypeSmallIntegerField 118 ft = TypeSmallIntegerField
120 case reflect.Int32, reflect.Int: 119 case reflect.TypeOf(new(int32)),
120 reflect.TypeOf(new(int)):
121 ft = TypeIntegerField 121 ft = TypeIntegerField
122 case reflect.Int64: 122 case reflect.TypeOf(new(int64)):
123 ft = TypeBigIntegerField 123 ft = TypeBigIntegerField
124 case reflect.Uint8: 124 case reflect.TypeOf(new(uint8)):
125 ft = TypePositiveBitField 125 ft = TypePositiveBitField
126 case reflect.Uint16: 126 case reflect.TypeOf(new(uint16)):
127 ft = TypePositiveSmallIntegerField 127 ft = TypePositiveSmallIntegerField
128 case reflect.Uint32, reflect.Uint: 128 case reflect.TypeOf(new(uint32)),
129 reflect.TypeOf(new(uint)):
129 ft = TypePositiveIntegerField 130 ft = TypePositiveIntegerField
130 case reflect.Uint64: 131 case reflect.TypeOf(new(uint64)):
131 ft = TypePositiveBigIntegerField 132 ft = TypePositiveBigIntegerField
132 case reflect.Float32, reflect.Float64: 133 case reflect.TypeOf(new(float32)),
134 reflect.TypeOf(new(float64)):
133 ft = TypeFloatField 135 ft = TypeFloatField
134 case reflect.Bool: 136 case reflect.TypeOf(new(bool)):
135 ft = TypeBooleanField 137 ft = TypeBooleanField
136 case reflect.String: 138 case reflect.TypeOf(new(string)):
137 ft = TypeCharField 139 ft = TypeCharField
138 default: 140 default:
139 if elm.Interface() == nil { 141 elm := reflect.Indirect(val)
140 panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val)) 142 switch elm.Kind() {
141 } 143 case reflect.Int8:
142 switch elm.Interface().(type) { 144 ft = TypeBitField
143 case sql.NullInt64: 145 case reflect.Int16:
146 ft = TypeSmallIntegerField
147 case reflect.Int32, reflect.Int:
148 ft = TypeIntegerField
149 case reflect.Int64:
144 ft = TypeBigIntegerField 150 ft = TypeBigIntegerField
145 case sql.NullFloat64: 151 case reflect.Uint8:
152 ft = TypePositiveBitField
153 case reflect.Uint16:
154 ft = TypePositiveSmallIntegerField
155 case reflect.Uint32, reflect.Uint:
156 ft = TypePositiveIntegerField
157 case reflect.Uint64:
158 ft = TypePositiveBigIntegerField
159 case reflect.Float32, reflect.Float64:
146 ft = TypeFloatField 160 ft = TypeFloatField
147 case sql.NullBool: 161 case reflect.Bool:
148 ft = TypeBooleanField 162 ft = TypeBooleanField
149 case sql.NullString: 163 case reflect.String:
150 ft = TypeCharField 164 ft = TypeCharField
151 case time.Time: 165 default:
152 ft = TypeDateTimeField 166 if elm.Interface() == nil {
167 panic(fmt.Errorf("%s is nil pointer, may be miss setting tag", val))
168 }
169 switch elm.Interface().(type) {
170 case sql.NullInt64:
171 ft = TypeBigIntegerField
172 case sql.NullFloat64:
173 ft = TypeFloatField
174 case sql.NullBool:
175 ft = TypeBooleanField
176 case sql.NullString:
177 ft = TypeCharField
178 case time.Time:
179 ft = TypeDateTimeField
180 }
153 } 181 }
154 } 182 }
155 if ft&IsFieldType == 0 { 183 if ft&IsFieldType == 0 {
......
...@@ -287,6 +287,25 @@ func TestNullDataTypes(t *testing.T) { ...@@ -287,6 +287,25 @@ func TestNullDataTypes(t *testing.T) {
287 throwFail(t, AssertIs(d.NullInt64.Valid, false)) 287 throwFail(t, AssertIs(d.NullInt64.Valid, false))
288 throwFail(t, AssertIs(d.NullFloat64.Valid, false)) 288 throwFail(t, AssertIs(d.NullFloat64.Valid, false))
289 289
290 throwFail(t, AssertIs(d.BooleanPtr, nil))
291 throwFail(t, AssertIs(d.CharPtr, nil))
292 throwFail(t, AssertIs(d.TextPtr, nil))
293 throwFail(t, AssertIs(d.BytePtr, nil))
294 throwFail(t, AssertIs(d.RunePtr, nil))
295 throwFail(t, AssertIs(d.IntPtr, nil))
296 throwFail(t, AssertIs(d.Int8Ptr, nil))
297 throwFail(t, AssertIs(d.Int16Ptr, nil))
298 throwFail(t, AssertIs(d.Int32Ptr, nil))
299 throwFail(t, AssertIs(d.Int64Ptr, nil))
300 throwFail(t, AssertIs(d.UintPtr, nil))
301 throwFail(t, AssertIs(d.Uint8Ptr, nil))
302 throwFail(t, AssertIs(d.Uint16Ptr, nil))
303 throwFail(t, AssertIs(d.Uint32Ptr, nil))
304 throwFail(t, AssertIs(d.Uint64Ptr, nil))
305 throwFail(t, AssertIs(d.Float32Ptr, nil))
306 throwFail(t, AssertIs(d.Float64Ptr, nil))
307 throwFail(t, AssertIs(d.DecimalPtr, nil))
308
290 _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec() 309 _, err = dORM.Raw(`INSERT INTO data_null (boolean) VALUES (?)`, nil).Exec()
291 throwFail(t, err) 310 throwFail(t, err)
292 311
...@@ -294,12 +313,49 @@ func TestNullDataTypes(t *testing.T) { ...@@ -294,12 +313,49 @@ func TestNullDataTypes(t *testing.T) {
294 err = dORM.Read(&d) 313 err = dORM.Read(&d)
295 throwFail(t, err) 314 throwFail(t, err)
296 315
316 booleanPtr := true
317 charPtr := string("test")
318 textPtr := string("test")
319 bytePtr := byte('t')
320 runePtr := rune('t')
321 intPtr := int(42)
322 int8Ptr := int8(42)
323 int16Ptr := int16(42)
324 int32Ptr := int32(42)
325 int64Ptr := int64(42)
326 uintPtr := uint(42)
327 uint8Ptr := uint8(42)
328 uint16Ptr := uint16(42)
329 uint32Ptr := uint32(42)
330 uint64Ptr := uint64(42)
331 float32Ptr := float32(42.0)
332 float64Ptr := float64(42.0)
333 decimalPtr := float64(42.0)
334
297 d = DataNull{ 335 d = DataNull{
298 DateTime: time.Now(), 336 DateTime: time.Now(),
299 NullString: sql.NullString{String: "test", Valid: true}, 337 NullString: sql.NullString{String: "test", Valid: true},
300 NullBool: sql.NullBool{Bool: true, Valid: true}, 338 NullBool: sql.NullBool{Bool: true, Valid: true},
301 NullInt64: sql.NullInt64{Int64: 42, Valid: true}, 339 NullInt64: sql.NullInt64{Int64: 42, Valid: true},
302 NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true}, 340 NullFloat64: sql.NullFloat64{Float64: 42.42, Valid: true},
341 BooleanPtr: &booleanPtr,
342 CharPtr: &charPtr,
343 TextPtr: &textPtr,
344 BytePtr: &bytePtr,
345 RunePtr: &runePtr,
346 IntPtr: &intPtr,
347 Int8Ptr: &int8Ptr,
348 Int16Ptr: &int16Ptr,
349 Int32Ptr: &int32Ptr,
350 Int64Ptr: &int64Ptr,
351 UintPtr: &uintPtr,
352 Uint8Ptr: &uint8Ptr,
353 Uint16Ptr: &uint16Ptr,
354 Uint32Ptr: &uint32Ptr,
355 Uint64Ptr: &uint64Ptr,
356 Float32Ptr: &float32Ptr,
357 Float64Ptr: &float64Ptr,
358 DecimalPtr: &decimalPtr,
303 } 359 }
304 360
305 id, err = dORM.Insert(&d) 361 id, err = dORM.Insert(&d)
...@@ -321,6 +377,25 @@ func TestNullDataTypes(t *testing.T) { ...@@ -321,6 +377,25 @@ func TestNullDataTypes(t *testing.T) {
321 377
322 throwFail(t, AssertIs(d.NullFloat64.Valid, true)) 378 throwFail(t, AssertIs(d.NullFloat64.Valid, true))
323 throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42)) 379 throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
380
381 throwFail(t, AssertIs(*d.BooleanPtr, booleanPtr))
382 throwFail(t, AssertIs(*d.CharPtr, charPtr))
383 throwFail(t, AssertIs(*d.TextPtr, textPtr))
384 throwFail(t, AssertIs(*d.BytePtr, bytePtr))
385 throwFail(t, AssertIs(*d.RunePtr, runePtr))
386 throwFail(t, AssertIs(*d.IntPtr, intPtr))
387 throwFail(t, AssertIs(*d.Int8Ptr, int8Ptr))
388 throwFail(t, AssertIs(*d.Int16Ptr, int16Ptr))
389 throwFail(t, AssertIs(*d.Int32Ptr, int32Ptr))
390 throwFail(t, AssertIs(*d.Int64Ptr, int64Ptr))
391 throwFail(t, AssertIs(*d.UintPtr, uintPtr))
392 throwFail(t, AssertIs(*d.Uint8Ptr, uint8Ptr))
393 throwFail(t, AssertIs(*d.Uint16Ptr, uint16Ptr))
394 throwFail(t, AssertIs(*d.Uint32Ptr, uint32Ptr))
395 throwFail(t, AssertIs(*d.Uint64Ptr, uint64Ptr))
396 throwFail(t, AssertIs(*d.Float32Ptr, float32Ptr))
397 throwFail(t, AssertIs(*d.Float64Ptr, float64Ptr))
398 throwFail(t, AssertIs(*d.DecimalPtr, decimalPtr))
324 } 399 }
325 400
326 func TestDataCustomTypes(t *testing.T) { 401 func TestDataCustomTypes(t *testing.T) {
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!