diff --git a/goextVersion.go b/goextVersion.go
index db68496..4380b67 100644
--- a/goextVersion.go
+++ b/goextVersion.go
@@ -1,5 +1,5 @@
 package goext
 
-const GoextVersion = "0.0.382"
+const GoextVersion = "0.0.383"
 
-const GoextVersionTimestamp = "2024-02-09T12:25:01+0100"
+const GoextVersionTimestamp = "2024-02-09T15:17:51+0100"
diff --git a/langext/array.go b/langext/array.go
index fb08de7..1b80cee 100644
--- a/langext/array.go
+++ b/langext/array.go
@@ -479,3 +479,33 @@ func JoinString(arr []string, delimiter string) string {
 
 	return str
 }
+
+// ArrChunk splits the array into chunks of at most `chunkSize` elements.
+// The original order is kept.
+// The last chunk may contain fewer than `chunkSize` elements.
+//
+// (chunkSize == -1) means no chunking
+//
+// see https://www.php.net/manual/en/function.array-chunk.php
+func ArrChunk[T any](arr []T, chunkSize int) [][]T {
+	if chunkSize == -1 {
+		return [][]T{arr}
+	}
+
+	res := make([][]T, 0, 1+len(arr)/chunkSize)
+
+	i := 0
+	for i < len(arr) {
+
+		right := i + chunkSize
+		if right >= len(arr) {
+			right = len(arr)
+		}
+
+		res = append(res, arr[i:right])
+
+		i = right
+	}
+
+	return res
+}
diff --git a/sq/builder.go b/sq/builder.go
index 92e2a17..f337d53 100644
--- a/sq/builder.go
+++ b/sq/builder.go
@@ -1,13 +1,14 @@
 package sq
 
 import (
+	"errors"
 	"fmt"
 	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
 	"reflect"
 	"strings"
 )
 
-func BuildUpdateStatement(q Queryable, tableName string, obj any, idColumn string) (string, PP, error) {
+func BuildUpdateStatement[TData any](q Queryable, tableName string, obj TData, idColumn string) (string, PP, error) {
 
 	rval := reflect.ValueOf(obj)
 	rtyp := rval.Type()
@@ -70,7 +71,7 @@ func BuildUpdateStatement(q Queryable, tableName string, obj any, idColumn strin
 	return fmt.Sprintf("UPDATE %s SET %s WHERE %s", tableName, strings.Join(setClauses, ", "), matchClause), params, nil
 }
 
-func BuildInsertStatement(q Queryable, tableName string, obj any) (string, PP, error) {
+func BuildInsertStatement[TData any](q Queryable, tableName string, obj TData) (string, PP, error) {
 
 	rval := reflect.ValueOf(obj)
 	rtyp := rval.Type()
@@ -118,3 +119,81 @@ func BuildInsertStatement(q Queryable, tableName string, obj any) (string, PP, e
 	//goland:noinspection SqlNoDataSourceInspection
 	return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(fields, ", "), strings.Join(values, ", ")), params, nil
 }
+
+func BuildInsertMultipleStatement[TData any](q Queryable, tableName string, vArr []TData) (string, PP, error) {
+
+	if len(vArr) == 0 {
+		return "", nil, errors.New("no data supplied")
+	}
+
+	rtyp := reflect.ValueOf(vArr[0]).Type()
+
+	sqlPrefix := ""
+	{
+		columns := make([]string, 0)
+
+		for i := 0; i < rtyp.NumField(); i++ {
+			rsfield := rtyp.Field(i)
+
+			if !rsfield.IsExported() {
+				continue
+			}
+
+			columnName := rsfield.Tag.Get("db")
+			if columnName == "" || columnName == "-" {
+				continue
+			}
+
+			columns = append(columns, "\""+columnName+"\"")
+		}
+
+		sqlPrefix = fmt.Sprintf("INSERT"+" INTO \"%s\" (%s) VALUES", tableName, strings.Join(columns, ", "))
+	}
+
+	pp := PP{}
+
+	sqlValuesArr := make([]string, 0)
+
+	for _, v := range vArr {
+
+		rval := reflect.ValueOf(v)
+
+		params := make([]string, 0)
+
+		for i := 0; i < rtyp.NumField(); i++ {
+
+			rsfield := rtyp.Field(i)
+			rvfield := rval.Field(i)
+
+			if !rsfield.IsExported() {
+				continue
+			}
+
+			columnName := rsfield.Tag.Get("db")
+			if columnName == "" || columnName == "-" {
+				continue
+			}
+
+			if rsfield.Type.Kind() == reflect.Ptr && rvfield.IsNil() {
+
+				params = append(params, "NULL")
+
+			} else {
+
+				val, err := convertValueToDB(q, rvfield.Interface())
+				if err != nil {
+					return "", nil, err
+				}
+
+				params = append(params, ":"+pp.Add(val))
+
+			}
+		}
+
+		sqlValuesArr = append(sqlValuesArr, fmt.Sprintf("(%s)", strings.Join(params, ", ")))
+	}
+
+	sqlstr := fmt.Sprintf("%s %s", sqlPrefix, strings.Join(sqlValuesArr, ", "))
+
+	return sqlstr, pp, nil
+}
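For orientation, a minimal, self-contained usage sketch of the new langext.ArrChunk helper added above (illustrative only, not part of the patch; expected output shown in comments):

package main

import (
	"fmt"

	"gogs.mikescher.com/BlackForestBytes/goext/langext"
)

func main() {
	// Chunks of size 2: the order is kept and the last chunk may be smaller.
	fmt.Println(langext.ArrChunk([]int{1, 2, 3, 4, 5}, 2)) // [[1 2] [3 4] [5]]

	// chunkSize == -1 disables chunking: the whole input is returned as a single chunk.
	fmt.Println(langext.ArrChunk([]int{1, 2, 3, 4, 5}, -1)) // [[1 2 3 4 5]]
}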
"NULL") + + } else { + + val, err := convertValueToDB(q, rvfield.Interface()) + if err != nil { + return "", nil, err + } + + params = append(params, ":"+pp.Add(val)) + + } + } + + sqlValuesArr = append(sqlValuesArr, fmt.Sprintf("(%s)", strings.Join(params, ", "))) + } + + sqlstr := fmt.Sprintf("%s %s", sqlPrefix, strings.Join(sqlValuesArr, ", ")) + + return sqlstr, pp, nil +} diff --git a/sq/main_test.go b/sq/main_test.go new file mode 100644 index 0000000..ffea457 --- /dev/null +++ b/sq/main_test.go @@ -0,0 +1,15 @@ +package sq + +import ( + "gogs.mikescher.com/BlackForestBytes/goext/exerr" + "gogs.mikescher.com/BlackForestBytes/goext/langext" + "os" + "testing" +) + +func TestMain(m *testing.M) { + if !exerr.Initialized() { + exerr.Init(exerr.ErrorPackageConfigInit{ZeroLogErrTraces: langext.PFalse, ZeroLogAllTraces: langext.PFalse}) + } + os.Exit(m.Run()) +} diff --git a/sq/scanner.go b/sq/scanner.go index 949c20a..fd03ab7 100644 --- a/sq/scanner.go +++ b/sq/scanner.go @@ -4,10 +4,9 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/jmoiron/sqlx" - "reflect" - "strings" + "gogs.mikescher.com/BlackForestBytes/goext/exerr" + "gogs.mikescher.com/BlackForestBytes/goext/langext" ) type StructScanMode string @@ -26,42 +25,61 @@ const ( func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData) (sql.Result, error) { - rval := reflect.ValueOf(v) - rtyp := rval.Type() + sqlstr, pp, err := BuildInsertStatement(q, tableName, v) + if err != nil { + return nil, err + } - columns := make([]string, 0) - params := make([]string, 0) - pp := PP{} + sqlr, err := q.Exec(ctx, sqlstr, pp) + if err != nil { + return nil, err + } - for i := 0; i < rtyp.NumField(); i++ { + return sqlr, nil +} - rsfield := rtyp.Field(i) - rvfield := rval.Field(i) +func InsertMultiple[TData any](ctx context.Context, q Queryable, tableName string, vArr []TData, maxBatch int) ([]sql.Result, error) { - if !rsfield.IsExported() { - continue + if len(vArr) == 0 { + return make([]sql.Result, 0), nil + } + + chunks := langext.ArrChunk(vArr, maxBatch) + + sqlstrArr := make([]string, 0) + ppArr := make([]PP, 0) + + for _, chunk := range chunks { + + sqlstr, pp, err := BuildInsertMultipleStatement(q, tableName, chunk) + if err != nil { + return nil, exerr.Wrap(err, "").Build() } - columnName := rsfield.Tag.Get("db") - if columnName == "" || columnName == "-" { - continue - } + sqlstrArr = append(sqlstrArr, sqlstr) + ppArr = append(ppArr, pp) + } - paramkey := fmt.Sprintf("_%s", columnName) + res := make([]sql.Result, 0, len(sqlstrArr)) - columns = append(columns, "\""+columnName+"\"") - params = append(params, ":"+paramkey) - - val, err := convertValueToDB(q, rvfield.Interface()) + for i := 0; i < len(sqlstrArr); i++ { + sqlr, err := q.Exec(ctx, sqlstrArr[i], ppArr[i]) if err != nil { return nil, err } - pp[paramkey] = val - + res = append(res, sqlr) } - sqlstr := fmt.Sprintf("INSERT"+" INTO \"%s\" (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), strings.Join(params, ", ")) + return res, nil +} + +func UpdateSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string) (sql.Result, error) { + + sqlstr, pp, err := BuildUpdateStatement(q, tableName, v, idColumn) + if err != nil { + return nil, err + } sqlr, err := q.Exec(ctx, sqlstr, pp) if err != nil { diff --git a/sq/scanner_test.go b/sq/scanner_test.go new file mode 100644 index 0000000..7ec7d1c --- /dev/null +++ b/sq/scanner_test.go @@ -0,0 +1,237 @@ +package sq + +import ( + "context" + 
"database/sql" + "fmt" + "github.com/glebarez/go-sqlite" + "github.com/jmoiron/sqlx" + "gogs.mikescher.com/BlackForestBytes/goext/langext" + "gogs.mikescher.com/BlackForestBytes/goext/tst" + "path/filepath" + "testing" +) + +func TestInsertSingle(t *testing.T) { + + type request struct { + ID string `json:"id" db:"id"` + Timestamp int `json:"timestamp" db:"timestamp"` + StrVal string `json:"strVal" db:"str_val"` + FloatVal float64 `json:"floatVal" db:"float_val"` + Dummy bool `json:"dummyBool" db:"dummy_bool"` + JsonVal JsonObj `json:"jsonVal" db:"json_val"` + } + + if !langext.InArray("sqlite3", sql.Drivers()) { + sqlite.RegisterAsSQLITE3() + } + + ctx := context.Background() + + dbdir := t.TempDir() + dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3") + + url := fmt.Sprintf("file:%s?_pragma=journal_mode(%s)&_pragma=timeout(%d)&_pragma=foreign_keys(%s)&_pragma=busy_timeout(%d)", dbfile1, "DELETE", 1000, "true", 1000) + + xdb := tst.Must(sqlx.Open("sqlite", url))(t) + + db := NewDB(xdb) + db.RegisterDefaultConverter() + + _, err := db.Exec(ctx, ` + CREATE TABLE requests ( + id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + str_val TEXT NOT NULL, + float_val REAL NOT NULL, + dummy_bool INTEGER NOT NULL CHECK(dummy_bool IN (0, 1)), + json_val TEXT NOT NULL, + PRIMARY KEY (id) + ) STRICT +`, PP{}) + tst.AssertNoErr(t, err) + + _, err = InsertSingle(ctx, db, "requests", request{ + ID: "9927", + Timestamp: 12321, + StrVal: "hello world", + Dummy: true, + FloatVal: 3.14159, + JsonVal: JsonObj{ + "firs": 1, + "second": true, + }, + }) + tst.AssertNoErr(t, err) +} + +func TestUpdateSingle(t *testing.T) { + + type request struct { + ID string `json:"id" db:"id"` + Timestamp int `json:"timestamp" db:"timestamp"` + StrVal string `json:"strVal" db:"str_val"` + FloatVal float64 `json:"floatVal" db:"float_val"` + Dummy bool `json:"dummyBool" db:"dummy_bool"` + JsonVal JsonObj `json:"jsonVal" db:"json_val"` + } + + if !langext.InArray("sqlite3", sql.Drivers()) { + sqlite.RegisterAsSQLITE3() + } + + ctx := context.Background() + + dbdir := t.TempDir() + dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3") + + url := fmt.Sprintf("file:%s?_pragma=journal_mode(%s)&_pragma=timeout(%d)&_pragma=foreign_keys(%s)&_pragma=busy_timeout(%d)", dbfile1, "DELETE", 1000, "true", 1000) + + xdb := tst.Must(sqlx.Open("sqlite", url))(t) + + db := NewDB(xdb) + db.RegisterDefaultConverter() + + _, err := db.Exec(ctx, ` + CREATE TABLE requests ( + id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + str_val TEXT NOT NULL, + float_val REAL NOT NULL, + dummy_bool INTEGER NOT NULL CHECK(dummy_bool IN (0, 1)), + json_val TEXT NOT NULL, + PRIMARY KEY (id) + ) STRICT +`, PP{}) + tst.AssertNoErr(t, err) + + _, err = InsertSingle(ctx, db, "requests", request{ + ID: "9927", + Timestamp: 12321, + StrVal: "hello world", + Dummy: true, + FloatVal: 3.14159, + JsonVal: JsonObj{ + "first": 1, + "second": true, + }, + }) + tst.AssertNoErr(t, err) + + v, err := QuerySingle[request](ctx, db, "SELECT * FROM requests WHERE id = '9927' LIMIT 1", PP{}, SModeExtended, Safe) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, v.Timestamp, 12321) + tst.AssertEqual(t, v.StrVal, "hello world") + tst.AssertEqual(t, v.Dummy, true) + tst.AssertEqual(t, v.FloatVal, 3.14159) + tst.AssertStrRepEqual(t, v.JsonVal["first"], 1) + tst.AssertStrRepEqual(t, v.JsonVal["second"], true) + + _, err = UpdateSingle(ctx, db, "requests", request{ + ID: "9927", + Timestamp: 9999, + StrVal: "9999 hello world", + Dummy: false, + FloatVal: 123.222, + 
+		JsonVal: JsonObj{
+			"first":  2,
+			"second": false,
+		},
+	}, "id")
+	tst.AssertNoErr(t, err)
+
+	v, err = QuerySingle[request](ctx, db, "SELECT * FROM requests WHERE id = '9927' LIMIT 1", PP{}, SModeExtended, Safe)
+	tst.AssertNoErr(t, err)
+
+	tst.AssertEqual(t, v.Timestamp, 9999)
+	tst.AssertEqual(t, v.StrVal, "9999 hello world")
+	tst.AssertEqual(t, v.Dummy, false)
+	tst.AssertEqual(t, v.FloatVal, 123.222)
+	tst.AssertStrRepEqual(t, v.JsonVal["first"], 2)
+	tst.AssertStrRepEqual(t, v.JsonVal["second"], false)
+}
+
+func TestInsertMultiple(t *testing.T) {
+
+	type request struct {
+		ID        string  `json:"id" db:"id"`
+		Timestamp int     `json:"timestamp" db:"timestamp"`
+		StrVal    string  `json:"strVal" db:"str_val"`
+		FloatVal  float64 `json:"floatVal" db:"float_val"`
+		Dummy     bool    `json:"dummyBool" db:"dummy_bool"`
+		JsonVal   JsonObj `json:"jsonVal" db:"json_val"`
+	}
+
+	if !langext.InArray("sqlite3", sql.Drivers()) {
+		sqlite.RegisterAsSQLITE3()
+	}
+
+	ctx := context.Background()
+
+	dbdir := t.TempDir()
+	dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3")
+
+	url := fmt.Sprintf("file:%s?_pragma=journal_mode(%s)&_pragma=timeout(%d)&_pragma=foreign_keys(%s)&_pragma=busy_timeout(%d)", dbfile1, "DELETE", 1000, "true", 1000)
+
+	xdb := tst.Must(sqlx.Open("sqlite", url))(t)
+
+	db := NewDB(xdb)
+	db.RegisterDefaultConverter()
+
+	_, err := db.Exec(ctx, `
+	CREATE TABLE requests (
+		id TEXT NOT NULL,
+		timestamp INTEGER NOT NULL,
+		str_val TEXT NOT NULL,
+		float_val REAL NOT NULL,
+		dummy_bool INTEGER NOT NULL CHECK(dummy_bool IN (0, 1)),
+		json_val TEXT NOT NULL,
+		PRIMARY KEY (id)
+	) STRICT
+`, PP{})
+	tst.AssertNoErr(t, err)
+
+	_, err = InsertMultiple(ctx, db, "requests", []request{
+		{
+			ID:        "1",
+			Timestamp: 1000,
+			StrVal:    "one",
+			Dummy:     true,
+			FloatVal:  0.1,
+			JsonVal: JsonObj{
+				"arr": []int{0},
+			},
+		},
+		{
+			ID:        "2",
+			Timestamp: 2000,
+			StrVal:    "two",
+			Dummy:     true,
+			FloatVal:  0.2,
+			JsonVal: JsonObj{
+				"arr": []int{0, 0},
+			},
+		},
+		{
+			ID:        "3",
+			Timestamp: 3000,
+			StrVal:    "three",
+			Dummy:     true,
+			FloatVal:  0.3,
+			JsonVal: JsonObj{
+				"arr": []int{0, 0, 0},
+			},
+		},
+	}, -1)
+	tst.AssertNoErr(t, err)
+
+	_, err = QuerySingle[request](ctx, db, "SELECT * FROM requests WHERE id = '1' LIMIT 1", PP{}, SModeExtended, Safe)
+	tst.AssertNoErr(t, err)
+
+	_, err = QuerySingle[request](ctx, db, "SELECT * FROM requests WHERE id = '2' LIMIT 1", PP{}, SModeExtended, Safe)
+	tst.AssertNoErr(t, err)
+
+	_, err = QuerySingle[request](ctx, db, "SELECT * FROM requests WHERE id = '3' LIMIT 1", PP{}, SModeExtended, Safe)
+	tst.AssertNoErr(t, err)
+}
diff --git a/tst/assertions.go b/tst/assertions.go
index 5bcfaed..02248e0 100644
--- a/tst/assertions.go
+++ b/tst/assertions.go
@@ -2,6 +2,7 @@ package tst
 
 import (
 	"encoding/hex"
+	"fmt"
 	"reflect"
 	"runtime/debug"
 	"testing"
@@ -125,3 +126,17 @@ func AssertNoErr(t *testing.T, anerr error) {
 		t.Error("Function returned an error: " + anerr.Error() + "\n" + string(debug.Stack()))
 	}
 }
+
+func AssertStrRepEqual(t *testing.T, actual any, expected any) {
+	t.Helper()
+	if fmt.Sprintf("%v", actual) != fmt.Sprintf("%v", expected) {
+		t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
+	}
+}
+
+func AssertStrRepNotEqual(t *testing.T, actual any, expected any) {
+	t.Helper()
+	if fmt.Sprintf("%v", actual) == fmt.Sprintf("%v", expected) {
+		t.Errorf("values do not differ: Actual: '%v', Expected: '%v'", actual, expected)
+	}
+}
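Finally, an illustrative sketch of the new string-representation assertions (not part of the patch; the package name and test function are hypothetical). AssertStrRepEqual compares the fmt "%v" renderings of both values, which is why the tests above can compare a JSON-decoded float64 against an untyped integer constant:

package tst_test

import (
	"testing"

	"gogs.mikescher.com/BlackForestBytes/goext/tst"
)

func TestStrRepSketch(t *testing.T) {
	// Both sides render as "1" via fmt.Sprintf("%v", ...), so this passes even though
	// reflect.DeepEqual(float64(1), 1) would report them as unequal (different types).
	tst.AssertStrRepEqual(t, float64(1), 1)

	// "1.5" and "1" differ, so the negated assertion passes.
	tst.AssertStrRepNotEqual(t, float64(1.5), 1)
}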