From bbb33e9fd6b3f6722b727cd74e85cb8243440ff3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20Schw=C3=B6rer?= Date: Thu, 22 Dec 2022 15:49:10 +0100 Subject: [PATCH] v0.0.44 --- dataext/bufferedReadCloser.go | 8 +- sq/scanner.go | 108 +++++++++++++++--- sq/structscanner.go | 201 ++++++++++++++++++++++++++++++++++ 3 files changed, 300 insertions(+), 17 deletions(-) create mode 100644 sq/structscanner.go diff --git a/dataext/bufferedReadCloser.go b/dataext/bufferedReadCloser.go index 59bc825..376f1aa 100644 --- a/dataext/bufferedReadCloser.go +++ b/dataext/bufferedReadCloser.go @@ -8,10 +8,10 @@ import ( type brcMode int const ( - modeSourceReading = 0 - modeSourceFinished = 1 - modeBufferReading = 2 - modeBufferFinished = 3 + modeSourceReading brcMode = 0 + modeSourceFinished brcMode = 1 + modeBufferReading brcMode = 2 + modeBufferFinished brcMode = 3 ) type BufferedReadCloser interface { diff --git a/sq/scanner.go b/sq/scanner.go index 646961b..b34fb3f 100644 --- a/sq/scanner.go +++ b/sq/scanner.go @@ -6,24 +6,75 @@ import ( "github.com/jmoiron/sqlx" ) -func ScanSingle[TData any](rows *sqlx.Rows, close bool) (TData, error) { +type StructScanMode string + +const ( + SModeFast StructScanMode = "FAST" + SModeExtended StructScanMode = "EXTENDED" +) + +type StructScanSafety string + +const ( + Safe StructScanSafety = "SAFE" + Unsafe StructScanSafety = "UNSAFE" +) + +func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) { if rows.Next() { + var strscan *StructScanner + + if sec == Safe { + strscan = NewStructScanner(rows, false) + var data TData + err := strscan.Start(&data) + if err != nil { + return *new(TData), err + } + } else if sec == Unsafe { + strscan = NewStructScanner(rows, true) + var data TData + err := strscan.Start(&data) + if err != nil { + return *new(TData), err + } + } else { + return *new(TData), errors.New("unknown value for ") + } + var data TData - err := rows.StructScan(&data) - if err != nil { - return *new(TData), err + + if mode == SModeFast { + err := strscan.StructScanBase(&data) + if err != nil { + return *new(TData), err + } + } else if mode == SModeExtended { + var data TData + err := strscan.StructScanExt(&data) + if err != nil { + return *new(TData), err + } + } else { + return *new(TData), errors.New("unknown value for ") } + if rows.Next() { - _ = rows.Close() - return *new(TData), errors.New("sql returned more than onw row") + if close { + _ = rows.Close() + } + return *new(TData), errors.New("sql returned more than one row") } + if close { - err = rows.Close() + err := rows.Close() if err != nil { return *new(TData), err } } + return data, nil + } else { if close { _ = rows.Close() @@ -32,18 +83,49 @@ func ScanSingle[TData any](rows *sqlx.Rows, close bool) (TData, error) { } } -func ScanAll[TData any](rows *sqlx.Rows, close bool) ([]TData, error) { - res := make([]TData, 0) - for rows.Next() { +func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) { + var strscan *StructScanner + + if sec == Safe { + strscan = NewStructScanner(rows, false) var data TData - err := rows.StructScan(&data) + err := strscan.Start(&data) if err != nil { return nil, err } - res = append(res, data) + } else if sec == Unsafe { + strscan = NewStructScanner(rows, true) + var data TData + err := strscan.Start(&data) + if err != nil { + return nil, err + } + } else { + return nil, errors.New("unknown value for ") + } + + res := make([]TData, 0) + for rows.Next() { + if mode == SModeFast { + var data TData + err := strscan.StructScanBase(&data) + if err != nil { + return nil, err + } + res = append(res, data) + } else if mode == SModeExtended { + var data TData + err := strscan.StructScanExt(&data) + if err != nil { + return nil, err + } + res = append(res, data) + } else { + return nil, errors.New("unknown value for ") + } } if close { - err := rows.Close() + err := strscan.rows.Close() if err != nil { return nil, err } diff --git a/sq/structscanner.go b/sq/structscanner.go new file mode 100644 index 0000000..d2307e5 --- /dev/null +++ b/sq/structscanner.go @@ -0,0 +1,201 @@ +package sq + +import ( + "errors" + "fmt" + "github.com/jmoiron/sqlx" + "github.com/jmoiron/sqlx/reflectx" + "reflect" +) + +// forked from sqlx, but added ability to unmarshal optional-nested structs + +type StructScanner struct { + rows *sqlx.Rows + Mapper *reflectx.Mapper + unsafe bool + + fields [][]int + values []any + columns []string +} + +func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner { + return &StructScanner{ + rows: rows, + Mapper: reflectx.NewMapper("db"), + unsafe: unsafe, + } +} + +func (r *StructScanner) Start(dest any) error { + v := reflect.ValueOf(dest) + + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + + columns, err := r.rows.Columns() + if err != nil { + return err + } + + r.columns = columns + r.fields = r.Mapper.TraversalsByName(v.Type(), columns) + // if we are not unsafe and are missing fields, return an error + if f, err := missingFields(r.fields); err != nil && !r.unsafe { + return fmt.Errorf("missing destination name %s in %T", columns[f], dest) + } + r.values = make([]interface{}, len(columns)) + + return nil +} + +// StructScan forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go +// does also wok with nullabel structs (from LEFT JOIN's) +func (r *StructScanner) StructScanExt(dest any) error { + v := reflect.ValueOf(dest) + + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + + // ========= STEP 1 :: ========= + + v = v.Elem() + + err := fieldsByTraversal(v, r.fields, r.values) + if err != nil { + return err + } + // scan into the struct field pointers and append to our results + err = r.rows.Scan(r.values...) + if err != nil { + return err + } + + nullStructs := make(map[string]bool) + + for i, traversal := range r.fields { + if len(traversal) == 0 { + continue + } + + isnsil := reflect.ValueOf(r.values[i]).Elem().IsNil() + + for i := 1; i < len(traversal); i++ { + + canParentNil := reflectx.FieldByIndexes(v, traversal[0:i]).Kind() == reflect.Pointer + + k := fmt.Sprintf("%v", traversal[0:i]) + if v, ok := nullStructs[k]; ok { + + nullStructs[k] = canParentNil && v && isnsil + + } else { + nullStructs[k] = canParentNil && isnsil + } + } + + } + + forcenulled := make(map[string]bool) + + for i, traversal := range r.fields { + if len(traversal) == 0 { + continue + } + + anyparentnull := false + for i := 1; i < len(traversal); i++ { + k := fmt.Sprintf("%v", traversal[0:i]) + if nv, ok := nullStructs[k]; ok && nv { + + if _, ok := forcenulled[k]; !ok { + f := reflectx.FieldByIndexes(v, traversal[0:i]) + f.Set(reflect.New(f.Type().Elem())) // set to nil + forcenulled[k] = true + } + + anyparentnull = true + break + + } + } + + if anyparentnull { + continue + } + + f := reflectx.FieldByIndexes(v, traversal) + + val1 := reflect.ValueOf(r.values[i]) + val2 := val1.Elem() + val3 := val2.Elem() + + if val2.IsNil() { + if f.Kind() != reflect.Pointer { + return errors.New(fmt.Sprintf("Cannot set field %v to NULL value from column '%s' (type: %s)", traversal, r.columns[i], f.Type().String())) + } + + f.Set(reflect.New(f.Type().Elem())) // set to nil + } else { + f.Set(val3) + } + + } + + return r.rows.Err() +} + +// StructScan forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go +// without (relevant) changes +func (r *StructScanner) StructScanBase(dest any) error { + v := reflect.ValueOf(dest) + + if v.Kind() != reflect.Ptr { + return errors.New("must pass a pointer, not a value, to StructScan destination") + } + + v = v.Elem() + + err := fieldsByTraversal(v, r.fields, r.values) + if err != nil { + return err + } + // scan into the struct field pointers and append to our results + err = r.rows.Scan(r.values...) + if err != nil { + return err + } + return r.rows.Err() +} + +// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error { + v = reflect.Indirect(v) + if v.Kind() != reflect.Struct { + return errors.New("argument not a struct") + } + + for i, traversal := range traversals { + if len(traversal) == 0 { + values[i] = new(interface{}) + continue + } + f := reflectx.FieldByIndexes(v, traversal) + + values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface() + } + return nil +} + +// missingFields forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go +func missingFields(transversals [][]int) (field int, err error) { + for i, t := range transversals { + if len(t) == 0 { + return i, errors.New("missing field") + } + } + return 0, nil +}