From aae8a706e90b3c3891b1c201ce68e5c6d8873a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20Schw=C3=B6rer?= Date: Sat, 13 Jan 2024 02:01:30 +0100 Subject: [PATCH] v0.0.369 autom. allow usage of existing converter for pointer-types (sq) --- goextVersion.go | 4 +-- sq/sq_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++++ sq/structscanner.go | 42 +++++++++++++++++++++++++------- 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/goextVersion.go b/goextVersion.go index b967936..5cfbdf5 100644 --- a/goextVersion.go +++ b/goextVersion.go @@ -1,5 +1,5 @@ package goext -const GoextVersion = "0.0.368" +const GoextVersion = "0.0.369" -const GoextVersionTimestamp = "2024-01-13T01:29:40+0100" +const GoextVersionTimestamp = "2024-01-13T02:01:30+0100" diff --git a/sq/sq_test.go b/sq/sq_test.go index 7e8aa41..bc36bb3 100644 --- a/sq/sq_test.go +++ b/sq/sq_test.go @@ -93,3 +93,62 @@ func TestTypeConverter2(t *testing.T) { tst.AssertEqual(t, "002", r.ID) tst.AssertEqual(t, t0.UnixNano(), r.Timestamp.UnixNano()) } + +func TestTypeConverter3(t *testing.T) { + + if !langext.InArray("sqlite3", sql.Drivers()) { + sqlite.RegisterAsSQLITE3() + } + + type RequestData struct { + ID string `db:"id"` + Timestamp *rfctime.UnixMilliTime `db:"timestamp"` + } + + ctx := context.Background() + + dbdir := t.TempDir() + dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3") + + tst.AssertNoErr(t, os.MkdirAll(dbdir, os.ModePerm)) + + url := fmt.Sprintf("file:%s?_pragma=journal_mode(%s)&_pragma=timeout(%d)&_pragma=foreign_keys(%s)&_pragma=busy_timeout(%d)", dbfile1, "DELETE", 1000, "true", 1000) + + xdb := tst.Must(sqlx.Open("sqlite", url))(t) + + db := NewDB(xdb) + db.RegisterDefaultConverter() + + _, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NULL, PRIMARY KEY (id) ) STRICT", PP{}) + tst.AssertNoErr(t, err) + + t0 := rfctime.NewUnixMilli(time.Date(2012, 03, 01, 16, 0, 0, 0, time.UTC)) + + _, err = InsertSingle(ctx, db, "requests", RequestData{ + ID: "001", + Timestamp: &t0, + }) + tst.AssertNoErr(t, err) + + _, err = InsertSingle(ctx, db, "requests", RequestData{ + ID: "002", + Timestamp: nil, + }) + tst.AssertNoErr(t, err) + + { + r1, err := QuerySingle[RequestData](ctx, db, "SELECT * FROM requests WHERE id = '001'", PP{}, SModeExtended, Safe) + tst.AssertNoErr(t, err) + fmt.Printf("%+v\n", r1) + tst.AssertEqual(t, "001", r1.ID) + tst.AssertEqual(t, t0.UnixNano(), r1.Timestamp.UnixNano()) + } + + { + r2, err := QuerySingle[RequestData](ctx, db, "SELECT * FROM requests WHERE id = '002'", PP{}, SModeExtended, Safe) + tst.AssertNoErr(t, err) + fmt.Printf("%+v\n", r2) + tst.AssertEqual(t, "002", r2.ID) + tst.AssertEqual(t, nil, r2.Timestamp) + } +} diff --git a/sq/structscanner.go b/sq/structscanner.go index ce15d4e..58aa778 100644 --- a/sq/structscanner.go +++ b/sq/structscanner.go @@ -7,6 +7,7 @@ import ( "github.com/jmoiron/sqlx/reflectx" "gogs.mikescher.com/BlackForestBytes/goext/langext" "reflect" + "strings" ) // forked from sqlx, but added ability to unmarshal optional-nested structs @@ -18,7 +19,7 @@ type StructScanner struct { fields [][]int values []any - converter []DBTypeConverter + converter []ssConverter columns []string } @@ -30,6 +31,11 @@ func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner { } } +type ssConverter struct { + Converter DBTypeConverter + RefCount int +} + func (r *StructScanner) Start(dest any) error { v := reflect.ValueOf(dest) @@ -49,7 +55,7 @@ func (r *StructScanner) Start(dest any) error { return fmt.Errorf("missing destination name %s in %T", columns[f], dest) } r.values = make([]interface{}, len(columns)) - r.converter = make([]DBTypeConverter, len(columns)) + r.converter = make([]ssConverter, len(columns)) return nil } @@ -143,13 +149,19 @@ func (r *StructScanner) StructScanExt(q Queryable, dest any) error { f.Set(reflect.Zero(f.Type())) // set to nil } else { - if r.converter[i] != nil { - val3 := val2.Elem().Interface() - conv3, err := r.converter[i].DBToModel(val3) + if r.converter[i].Converter != nil { + val3 := val2.Elem() + conv3, err := r.converter[i].Converter.DBToModel(val3.Interface()) if err != nil { return err } - f.Set(reflect.ValueOf(conv3)) + conv3RVal := reflect.ValueOf(conv3) + for j := 0; j < r.converter[i].RefCount; j++ { + newConv3Val := reflect.New(conv3RVal.Type()) + newConv3Val.Elem().Set(conv3RVal) + conv3RVal = newConv3Val + } + f.Set(conv3RVal) } else { f.Set(val2.Elem()) } @@ -184,7 +196,7 @@ func (r *StructScanner) StructScanBase(dest any) error { } // fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go -func fieldsByTraversalExtended(q Queryable, v reflect.Value, traversals [][]int, values []interface{}, converter []DBTypeConverter) error { +func fieldsByTraversalExtended(q Queryable, v reflect.Value, traversals [][]int, values []interface{}, converter []ssConverter) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -205,14 +217,26 @@ func fieldsByTraversalExtended(q Queryable, v reflect.Value, traversals [][]int, _v := langext.Ptr[any](nil) values[i] = _v foundConverter = true - converter[i] = conv + converter[i] = ssConverter{Converter: conv, RefCount: 0} break } } + if !foundConverter { + // also allow non-pointer converter for pointer-types + for _, conv := range q.ListConverter() { + if conv.ModelTypeString() == strings.TrimLeft(typeStr, "*") { + _v := langext.Ptr[any](nil) + values[i] = _v + foundConverter = true + converter[i] = ssConverter{Converter: conv, RefCount: len(typeStr) - len(strings.TrimLeft(typeStr, "*"))} // kind hacky way to get the amount of ptr before , but it works... + break + } + } + } if !foundConverter { values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface() - converter[i] = nil + converter[i] = ssConverter{Converter: nil, RefCount: -1} } } return nil