goext/sq/structscanner.go

252 lines
5.7 KiB
Go
Raw Permalink Normal View History

2022-12-22 15:49:10 +01:00
package sq
import (
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
2023-12-29 19:25:36 +01:00
"gogs.mikescher.com/BlackForestBytes/goext/langext"
2022-12-22 15:49:10 +01:00
"reflect"
)
// forked from sqlx, but added ability to unmarshal optional-nested structs
type StructScanner struct {
rows *sqlx.Rows
Mapper *reflectx.Mapper
unsafe bool
2023-12-29 19:25:36 +01:00
fields [][]int
values []any
converter []DBTypeConverter
columns []string
2022-12-22 15:49:10 +01:00
}
// NewStructScanner creates a StructScanner reading from rows.
//
// The scanner maps columns to struct fields through a reflectx mapper
// built for the "db" struct tag. When unsafe is true, Start tolerates
// columns that have no matching destination field.
func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner {
	scanner := new(StructScanner)
	scanner.rows = rows
	scanner.Mapper = reflectx.NewMapper("db")
	scanner.unsafe = unsafe
	return scanner
}
func (r *StructScanner) Start(dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
columns, err := r.rows.Columns()
if err != nil {
return err
}
r.columns = columns
r.fields = r.Mapper.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(r.fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
}
r.values = make([]interface{}, len(columns))
2023-12-29 19:25:36 +01:00
r.converter = make([]DBTypeConverter, len(columns))
2022-12-22 15:49:10 +01:00
return nil
}
2022-12-23 10:11:01 +01:00
// StructScanExt forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
2023-12-29 19:25:36 +01:00
// does also work with nullabel structs (from LEFT JOIN's)
// does also work with custom value converters
func (r *StructScanner) StructScanExt(q Queryable, dest any) error {
2022-12-22 15:49:10 +01:00
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
// ========= STEP 1 :: =========
v = v.Elem()
2023-12-29 19:25:36 +01:00
err := fieldsByTraversalExtended(q, v, r.fields, r.values, r.converter)
2022-12-22 15:49:10 +01:00
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r.rows.Scan(r.values...)
if err != nil {
return err
}
nullStructs := make(map[string]bool)
for i, traversal := range r.fields {
if len(traversal) == 0 {
continue
}
isnsil := reflect.ValueOf(r.values[i]).Elem().IsNil()
for i := 1; i < len(traversal); i++ {
canParentNil := reflectx.FieldByIndexes(v, traversal[0:i]).Kind() == reflect.Pointer
k := fmt.Sprintf("%v", traversal[0:i])
if v, ok := nullStructs[k]; ok {
nullStructs[k] = canParentNil && v && isnsil
} else {
nullStructs[k] = canParentNil && isnsil
}
}
}
forcenulled := make(map[string]bool)
for i, traversal := range r.fields {
if len(traversal) == 0 {
continue
}
anyparentnull := false
for i := 1; i < len(traversal); i++ {
k := fmt.Sprintf("%v", traversal[0:i])
if nv, ok := nullStructs[k]; ok && nv {
if _, ok := forcenulled[k]; !ok {
f := reflectx.FieldByIndexes(v, traversal[0:i])
2022-12-23 19:11:18 +01:00
f.Set(reflect.Zero(f.Type())) // set to nil
2022-12-22 15:49:10 +01:00
forcenulled[k] = true
}
anyparentnull = true
break
}
}
if anyparentnull {
continue
}
f := reflectx.FieldByIndexes(v, traversal)
val1 := reflect.ValueOf(r.values[i])
val2 := val1.Elem()
if val2.IsNil() {
if f.Kind() != reflect.Pointer {
return errors.New(fmt.Sprintf("Cannot set field %v to NULL value from column '%s' (type: %s)", traversal, r.columns[i], f.Type().String()))
}
2022-12-23 19:11:18 +01:00
f.Set(reflect.Zero(f.Type())) // set to nil
2022-12-22 15:49:10 +01:00
} else {
2023-12-29 19:25:36 +01:00
if r.converter[i] != nil {
val3 := val2.Elem().Interface()
conv3, err := r.converter[i].DBToModel(val3)
if err != nil {
return err
}
f.Set(reflect.ValueOf(conv3))
} else {
f.Set(val2.Elem())
}
2022-12-22 15:49:10 +01:00
}
}
return r.rows.Err()
}
2022-12-23 10:11:01 +01:00
// StructScanBase forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
2022-12-22 15:49:10 +01:00
// without (relevant) changes
func (r *StructScanner) StructScanBase(dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
v = v.Elem()
2022-12-22 15:59:12 +01:00
err := fieldsByTraversalBase(v, r.fields, r.values, true)
2022-12-22 15:49:10 +01:00
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r.rows.Scan(r.values...)
if err != nil {
return err
}
return r.rows.Err()
}
// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
2023-12-29 19:25:36 +01:00
func fieldsByTraversalExtended(q Queryable, v reflect.Value, traversals [][]int, values []interface{}, converter []DBTypeConverter) error {
2022-12-22 15:49:10 +01:00
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
2023-12-29 19:25:36 +01:00
typeStr := f.Type().String()
foundConverter := false
for _, conv := range q.ListConverter() {
if conv.ModelTypeString() == typeStr {
_v := langext.Ptr[any](nil)
values[i] = _v
foundConverter = true
converter[i] = conv
break
}
}
if !foundConverter {
values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface()
converter[i] = nil
}
2022-12-22 15:49:10 +01:00
}
return nil
}
2022-12-22 15:59:12 +01:00
// fieldsByTraversalBase fills values with one entry per traversal, based on
// fieldsByTraversal from github.com/jmoiron/sqlx@v1.3.5/sqlx.go.
//
// Each traversal is a field index path into the struct v. An empty traversal
// gets a throwaway *interface{}; otherwise the entry is the address of the
// field when ptrs is true, or the field's current value when ptrs is false.
//
// An error is returned when v (after dereferencing) is not a struct.
func fieldsByTraversalBase(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
	target := reflect.Indirect(v)
	if target.Kind() != reflect.Struct {
		return errors.New("argument not a struct")
	}

	for idx := range traversals {
		path := traversals[idx]
		if len(path) == 0 {
			values[idx] = new(interface{})
			continue
		}

		field := reflectx.FieldByIndexes(target, path)
		switch {
		case ptrs:
			values[idx] = field.Addr().Interface()
		default:
			values[idx] = field.Interface()
		}
	}

	return nil
}
2022-12-22 15:49:10 +01:00
// missingFields reports the first empty traversal, based on missingFields
// from github.com/jmoiron/sqlx@v1.3.5/sqlx.go.
//
// It returns the index of the first entry with length zero together with a
// non-nil error, or (0, nil) when every entry is non-empty.
func missingFields(transversals [][]int) (field int, err error) {
	for idx := 0; idx < len(transversals); idx++ {
		if len(transversals[idx]) != 0 {
			continue
		}
		return idx, errors.New("missing field")
	}
	return 0, nil
}