2022-12-22 15:49:10 +01:00
package sq
import (
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
2023-12-29 19:25:36 +01:00
"gogs.mikescher.com/BlackForestBytes/goext/langext"
2022-12-22 15:49:10 +01:00
"reflect"
2024-01-13 02:01:30 +01:00
"strings"
2022-12-22 15:49:10 +01:00
)
// forked from sqlx, but added ability to unmarshal optional-nested structs
type StructScanner struct {
rows * sqlx . Rows
Mapper * reflectx . Mapper
unsafe bool
2023-12-29 19:25:36 +01:00
fields [ ] [ ] int
values [ ] any
2024-01-13 02:01:30 +01:00
converter [ ] ssConverter
2023-12-29 19:25:36 +01:00
columns [ ] string
2022-12-22 15:49:10 +01:00
}
// NewStructScanner creates a StructScanner that reads from rows and maps
// columns to struct fields through their "db" tag.
//
// If unsafe is true, columns without a matching destination field are
// tolerated instead of being reported as an error by Start.
func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner {
	scanner := new(StructScanner)
	scanner.rows = rows
	scanner.Mapper = reflectx.NewMapper("db")
	scanner.unsafe = unsafe
	return scanner
}
2024-01-13 02:01:30 +01:00
// ssConverter describes how the scanned value of a single column is converted
// before it is assigned to its destination field.
type ssConverter struct {
// Converter is the value converter applied to the column; it is nil when no
// converter matched the type of the destination field.
Converter DBTypeConverter
// RefCount is the number of pointer levels the converted value gets wrapped
// in before assignment: 0 for an exact type match, the number of leading '*'
// of the field type when the converter matched the pointee type, and -1 when
// there is no converter.
RefCount int
}
2022-12-22 15:49:10 +01:00
func ( r * StructScanner ) Start ( dest any ) error {
v := reflect . ValueOf ( dest )
if v . Kind ( ) != reflect . Ptr {
return errors . New ( "must pass a pointer, not a value, to StructScan destination" )
}
columns , err := r . rows . Columns ( )
if err != nil {
return err
}
r . columns = columns
r . fields = r . Mapper . TraversalsByName ( v . Type ( ) , columns )
// if we are not unsafe and are missing fields, return an error
if f , err := missingFields ( r . fields ) ; err != nil && ! r . unsafe {
return fmt . Errorf ( "missing destination name %s in %T" , columns [ f ] , dest )
}
r . values = make ( [ ] interface { } , len ( columns ) )
2024-01-13 02:01:30 +01:00
r . converter = make ( [ ] ssConverter , len ( columns ) )
2022-12-22 15:49:10 +01:00
return nil
}
2022-12-23 10:11:01 +01:00
// StructScanExt forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
2023-12-29 19:25:36 +01:00
// does also work with nullabel structs (from LEFT JOIN's)
// does also work with custom value converters
func ( r * StructScanner ) StructScanExt ( q Queryable , dest any ) error {
2022-12-22 15:49:10 +01:00
v := reflect . ValueOf ( dest )
if v . Kind ( ) != reflect . Ptr {
return errors . New ( "must pass a pointer, not a value, to StructScan destination" )
}
// ========= STEP 1 :: =========
v = v . Elem ( )
2023-12-29 19:25:36 +01:00
err := fieldsByTraversalExtended ( q , v , r . fields , r . values , r . converter )
2022-12-22 15:49:10 +01:00
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r . rows . Scan ( r . values ... )
if err != nil {
return err
}
nullStructs := make ( map [ string ] bool )
for i , traversal := range r . fields {
if len ( traversal ) == 0 {
continue
}
isnsil := reflect . ValueOf ( r . values [ i ] ) . Elem ( ) . IsNil ( )
for i := 1 ; i < len ( traversal ) ; i ++ {
canParentNil := reflectx . FieldByIndexes ( v , traversal [ 0 : i ] ) . Kind ( ) == reflect . Pointer
k := fmt . Sprintf ( "%v" , traversal [ 0 : i ] )
if v , ok := nullStructs [ k ] ; ok {
nullStructs [ k ] = canParentNil && v && isnsil
} else {
nullStructs [ k ] = canParentNil && isnsil
}
}
}
forcenulled := make ( map [ string ] bool )
for i , traversal := range r . fields {
if len ( traversal ) == 0 {
continue
}
anyparentnull := false
for i := 1 ; i < len ( traversal ) ; i ++ {
k := fmt . Sprintf ( "%v" , traversal [ 0 : i ] )
if nv , ok := nullStructs [ k ] ; ok && nv {
if _ , ok := forcenulled [ k ] ; ! ok {
f := reflectx . FieldByIndexes ( v , traversal [ 0 : i ] )
2022-12-23 19:11:18 +01:00
f . Set ( reflect . Zero ( f . Type ( ) ) ) // set to nil
2022-12-22 15:49:10 +01:00
forcenulled [ k ] = true
}
anyparentnull = true
break
}
}
if anyparentnull {
continue
}
f := reflectx . FieldByIndexes ( v , traversal )
val1 := reflect . ValueOf ( r . values [ i ] )
val2 := val1 . Elem ( )
if val2 . IsNil ( ) {
if f . Kind ( ) != reflect . Pointer {
return errors . New ( fmt . Sprintf ( "Cannot set field %v to NULL value from column '%s' (type: %s)" , traversal , r . columns [ i ] , f . Type ( ) . String ( ) ) )
}
2022-12-23 19:11:18 +01:00
f . Set ( reflect . Zero ( f . Type ( ) ) ) // set to nil
2022-12-22 15:49:10 +01:00
} else {
2024-01-13 02:01:30 +01:00
if r . converter [ i ] . Converter != nil {
val3 := val2 . Elem ( )
conv3 , err := r . converter [ i ] . Converter . DBToModel ( val3 . Interface ( ) )
2023-12-29 19:25:36 +01:00
if err != nil {
return err
}
2024-01-13 02:01:30 +01:00
conv3RVal := reflect . ValueOf ( conv3 )
for j := 0 ; j < r . converter [ i ] . RefCount ; j ++ {
newConv3Val := reflect . New ( conv3RVal . Type ( ) )
newConv3Val . Elem ( ) . Set ( conv3RVal )
conv3RVal = newConv3Val
}
f . Set ( conv3RVal )
2023-12-29 19:25:36 +01:00
} else {
f . Set ( val2 . Elem ( ) )
}
2022-12-22 15:49:10 +01:00
}
}
return r . rows . Err ( )
}
2022-12-23 10:11:01 +01:00
// StructScanBase forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
2022-12-22 15:49:10 +01:00
// without (relevant) changes
func ( r * StructScanner ) StructScanBase ( dest any ) error {
v := reflect . ValueOf ( dest )
if v . Kind ( ) != reflect . Ptr {
return errors . New ( "must pass a pointer, not a value, to StructScan destination" )
}
v = v . Elem ( )
2022-12-22 15:59:12 +01:00
err := fieldsByTraversalBase ( v , r . fields , r . values , true )
2022-12-22 15:49:10 +01:00
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r . rows . Scan ( r . values ... )
if err != nil {
return err
}
return r . rows . Err ( )
}
// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
2024-01-13 02:01:30 +01:00
func fieldsByTraversalExtended ( q Queryable , v reflect . Value , traversals [ ] [ ] int , values [ ] interface { } , converter [ ] ssConverter ) error {
2022-12-22 15:49:10 +01:00
v = reflect . Indirect ( v )
if v . Kind ( ) != reflect . Struct {
return errors . New ( "argument not a struct" )
}
for i , traversal := range traversals {
if len ( traversal ) == 0 {
values [ i ] = new ( interface { } )
continue
}
f := reflectx . FieldByIndexes ( v , traversal )
2023-12-29 19:25:36 +01:00
typeStr := f . Type ( ) . String ( )
foundConverter := false
for _ , conv := range q . ListConverter ( ) {
if conv . ModelTypeString ( ) == typeStr {
_v := langext . Ptr [ any ] ( nil )
values [ i ] = _v
foundConverter = true
2024-01-13 02:01:30 +01:00
converter [ i ] = ssConverter { Converter : conv , RefCount : 0 }
2023-12-29 19:25:36 +01:00
break
}
}
2024-01-13 02:01:30 +01:00
if ! foundConverter {
// also allow non-pointer converter for pointer-types
for _ , conv := range q . ListConverter ( ) {
if conv . ModelTypeString ( ) == strings . TrimLeft ( typeStr , "*" ) {
_v := langext . Ptr [ any ] ( nil )
values [ i ] = _v
foundConverter = true
converter [ i ] = ssConverter { Converter : conv , RefCount : len ( typeStr ) - len ( strings . TrimLeft ( typeStr , "*" ) ) } // kind hacky way to get the amount of ptr before <f>, but it works...
break
}
}
}
2023-12-29 19:25:36 +01:00
if ! foundConverter {
values [ i ] = reflect . New ( reflect . PointerTo ( f . Type ( ) ) ) . Interface ( )
2024-01-13 02:01:30 +01:00
converter [ i ] = ssConverter { Converter : nil , RefCount : - 1 }
2023-12-29 19:25:36 +01:00
}
2022-12-22 15:49:10 +01:00
}
return nil
}
2022-12-22 15:59:12 +01:00
// fieldsByTraversalBase fills values with one entry per traversal of the
// struct v (or the struct v points to).
//
// Forked from fieldsByTraversal in github.com/jmoiron/sqlx@v1.3.5/sqlx.go.
//
// An empty traversal yields a fresh *interface{}. Otherwise the entry is the
// address of the addressed field when ptrs is true, or the field's current
// value when ptrs is false. An error is returned if v is not a struct.
func fieldsByTraversalBase(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
	target := reflect.Indirect(v)
	if kind := target.Kind(); kind != reflect.Struct {
		return errors.New("argument not a struct")
	}

	for idx := range traversals {
		path := traversals[idx]
		switch {
		case len(path) == 0:
			values[idx] = new(interface{})
		case ptrs:
			values[idx] = reflectx.FieldByIndexes(target, path).Addr().Interface()
		default:
			values[idx] = reflectx.FieldByIndexes(target, path).Interface()
		}
	}

	return nil
}
2022-12-22 15:49:10 +01:00
// missingFields looks for the first empty traversal, i.e. the first column
// that has no destination field.
//
// Forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go.
//
// It returns the index of that traversal together with a non-nil error, or
// (0, nil) if every traversal is non-empty.
func missingFields(transversals [][]int) (field int, err error) {
	for idx := 0; idx < len(transversals); idx++ {
		if len(transversals[idx]) > 0 {
			continue
		}
		return idx, errors.New("missing field")
	}
	return 0, nil
}