goext/sq/scanner.go

412 lines
8.9 KiB
Go
Raw Permalink Normal View History

2022-12-11 03:12:02 +01:00
package sq
import (
2023-05-28 19:41:24 +02:00
"context"
2022-12-11 03:12:02 +01:00
"database/sql"
"errors"
2024-02-09 15:40:09 +01:00
"fmt"
2022-12-11 03:12:02 +01:00
"github.com/jmoiron/sqlx"
2024-02-09 15:17:51 +01:00
"gogs.mikescher.com/BlackForestBytes/goext/exerr"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
2024-02-09 15:40:09 +01:00
"reflect"
2022-12-11 03:12:02 +01:00
)
2022-12-22 15:49:10 +01:00
type (
	// StructScanMode selects the strategy used to map a result row onto a Go struct.
	StructScanMode string

	// StructScanSafety controls how strictly result columns must match the target struct.
	StructScanSafety string
)

const (
	// SModeFast maps rows with the basic struct scan (StructScanBase). It does not work
	// with joined/resolved types or custom value converters.
	SModeFast StructScanMode = "FAST"

	// SModeExtended maps rows with the fully featured scan (StructScanExt). It is the
	// intended default and may be marginally slower than SModeFast.
	SModeExtended StructScanMode = "EXTENDED"

	// Safe makes scanning fail with an error when fields are missing.
	Safe StructScanSafety = "SAFE"

	// Unsafe makes scanning silently ignore missing fields.
	Unsafe StructScanSafety = "UNSAFE"
)
2023-05-28 19:41:24 +02:00
func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData) (sql.Result, error) {
2024-02-09 15:17:51 +01:00
sqlstr, pp, err := BuildInsertStatement(q, tableName, v)
if err != nil {
return nil, err
}
sqlr, err := q.Exec(ctx, sqlstr, pp)
if err != nil {
return nil, err
}
2023-05-28 19:41:24 +02:00
2024-02-09 15:17:51 +01:00
return sqlr, nil
}
2023-05-28 19:41:24 +02:00
2024-02-09 15:58:21 +01:00
func InsertAndQuerySingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string, mode StructScanMode, sec StructScanSafety) (TData, error) {
rval := reflect.ValueOf(v)
2024-02-12 18:17:49 +01:00
idRVal := fieldByTag(rval, "db", idColumn)
2024-02-09 15:58:21 +01:00
if !idRVal.IsValid() || idRVal.IsZero() {
return *new(TData), fmt.Errorf("failed to find idColumn '%s' in %T", idColumn, v)
}
idValue, err := convertValueToDB(q, idRVal.Interface())
if err != nil {
return *new(TData), err
}
_, err = InsertSingle[TData](ctx, q, tableName, v)
if err != nil {
return *new(TData), err
}
pp := PP{}
//goland:noinspection ALL
sqlstr := fmt.Sprintf("SELECT * FROM %s WHERE %s = :%s", tableName, idColumn, pp.Add(idValue))
return QuerySingle[TData](ctx, q, sqlstr, pp, mode, sec)
}
2024-02-12 18:17:49 +01:00
// fieldByTag returns the exported field of the struct held by rval whose struct tag
// under tagkey equals tagval exactly. A pointer to a struct is dereferenced first.
//
// It returns the zero reflect.Value (IsValid() == false) if rval is not a struct or a
// pointer to one, if tagval is empty, or if no exported field matches. Callers can then
// report a proper error instead of the library panicking. Only explicit tags are
// considered; untagged fields are never matched.
func fieldByTag(rval reflect.Value, tagkey string, tagval string) reflect.Value {
	// An empty tagval would otherwise match the first exported field without that tag.
	if tagval == "" {
		return reflect.Value{}
	}

	rval = reflect.Indirect(rval)
	if rval.Kind() != reflect.Struct {
		return reflect.Value{}
	}

	rtyp := rval.Type()
	for i := 0; i < rtyp.NumField(); i++ {
		rsfield := rtyp.Field(i)
		if rsfield.IsExported() && rsfield.Tag.Get(tagkey) == tagval {
			return rval.Field(i)
		}
	}
	return reflect.Value{}
}
2024-02-09 15:17:51 +01:00
func InsertMultiple[TData any](ctx context.Context, q Queryable, tableName string, vArr []TData, maxBatch int) ([]sql.Result, error) {
2023-05-28 19:41:24 +02:00
2024-02-09 15:17:51 +01:00
if len(vArr) == 0 {
return make([]sql.Result, 0), nil
}
2023-05-28 19:41:24 +02:00
2024-02-09 15:17:51 +01:00
chunks := langext.ArrChunk(vArr, maxBatch)
2023-05-28 19:41:24 +02:00
2024-02-09 15:17:51 +01:00
sqlstrArr := make([]string, 0)
ppArr := make([]PP, 0)
for _, chunk := range chunks {
sqlstr, pp, err := BuildInsertMultipleStatement(q, tableName, chunk)
if err != nil {
return nil, exerr.Wrap(err, "").Build()
2023-05-28 19:41:24 +02:00
}
2024-02-09 15:17:51 +01:00
sqlstrArr = append(sqlstrArr, sqlstr)
ppArr = append(ppArr, pp)
}
2023-05-28 19:41:24 +02:00
2024-02-09 15:17:51 +01:00
res := make([]sql.Result, 0, len(sqlstrArr))
2023-12-29 19:25:36 +01:00
2024-02-09 15:17:51 +01:00
for i := 0; i < len(sqlstrArr); i++ {
sqlr, err := q.Exec(ctx, sqlstrArr[i], ppArr[i])
2023-12-29 19:25:36 +01:00
if err != nil {
return nil, err
}
2024-02-09 15:17:51 +01:00
res = append(res, sqlr)
2023-05-28 19:41:24 +02:00
}
2024-02-09 15:17:51 +01:00
return res, nil
}
// UpdateSingle writes v to the row of tableName identified by v's idColumn.
//
// The UPDATE statement and its parameters come from BuildUpdateStatement. The result of
// executing it on q is returned; on any error the result is nil.
func UpdateSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string) (sql.Result, error) {
	stmt, params, buildErr := BuildUpdateStatement(q, tableName, v, idColumn)
	if buildErr != nil {
		return nil, buildErr
	}

	res, execErr := q.Exec(ctx, stmt, params)
	if execErr != nil {
		return nil, execErr
	}
	return res, nil
}
2024-02-09 15:40:09 +01:00
func UpdateAndQuerySingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string, mode StructScanMode, sec StructScanSafety) (TData, error) {
rval := reflect.ValueOf(v)
2024-02-12 18:17:49 +01:00
idRVal := fieldByTag(rval, "db", idColumn)
2024-02-09 15:40:09 +01:00
if !idRVal.IsValid() || idRVal.IsZero() {
return *new(TData), fmt.Errorf("failed to find idColumn '%s' in %T", idColumn, v)
}
idValue, err := convertValueToDB(q, idRVal.Interface())
if err != nil {
return *new(TData), err
}
_, err = UpdateSingle[TData](ctx, q, tableName, v, idColumn)
if err != nil {
return *new(TData), err
}
pp := PP{}
//goland:noinspection ALL
sqlstr := fmt.Sprintf("SELECT * FROM %s WHERE %s = :%s", tableName, idColumn, pp.Add(idValue))
return QuerySingle[TData](ctx, q, sqlstr, pp, mode, sec)
}
2023-05-28 19:41:24 +02:00
// QuerySingle executes sqlstr with the parameters pp and decodes its only result row
// into a TData via ScanSingle (see there for mode and sec).
//
// It fails with sql.ErrNoRows when the query yields no row, and with an error when it
// yields more than one row.
func QuerySingle[TData any](ctx context.Context, q Queryable, sqlstr string, pp PP, mode StructScanMode, sec StructScanSafety) (TData, error) {
	rows, err := q.Query(ctx, sqlstr, pp)
	if err != nil {
		return *new(TData), err
	}
	// This function owns rows. ScanSingle is asked to close them, but this deferred Close
	// guarantees the connection is released even if scanning returns early.
	// sql.Rows.Close is idempotent, so closing twice is harmless.
	defer func() { _ = rows.Close() }()

	data, err := ScanSingle[TData](ctx, q, rows, mode, sec, true)
	if err != nil {
		return *new(TData), err
	}
	return data, nil
}
2024-02-09 15:20:46 +01:00
// QuerySingleOpt is like QuerySingle but treats an empty result as "not found".
//
// It returns (nil, nil) when the query yields no row and a pointer to the decoded row
// when it yields exactly one. Any other failure, including more than one row, is
// returned as an error.
func QuerySingleOpt[TData any](ctx context.Context, q Queryable, sqlstr string, pp PP, mode StructScanMode, sec StructScanSafety) (*TData, error) {
	rows, err := q.Query(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}
	// This function owns rows. Guarantee they are released on every path; Close is idempotent.
	defer func() { _ = rows.Close() }()

	data, err := ScanSingle[TData](ctx, q, rows, mode, sec, true)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &data, nil
}
2023-05-28 19:41:24 +02:00
// QueryAll executes sqlstr with the parameters pp and decodes every result row into a
// TData via ScanAll (see there for mode and sec). An empty result yields an empty,
// non-nil slice.
func QueryAll[TData any](ctx context.Context, q Queryable, sqlstr string, pp PP, mode StructScanMode, sec StructScanSafety) ([]TData, error) {
	rows, err := q.Query(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}
	// This function owns rows. ScanAll is asked to close them, but this deferred Close
	// guarantees release even if scanning returns early; Close is idempotent.
	defer func() { _ = rows.Close() }()

	data, err := ScanAll[TData](ctx, q, rows, mode, sec, true)
	if err != nil {
		return nil, err
	}
	return data, nil
}
2023-12-29 19:25:36 +01:00
// ScanSingle decodes exactly one row from rows into a TData.
//
// It advances rows once and maps that row with a StructScanner configured by sec
// (Safe: missing fields are an error; Unsafe: they are ignored). The scan strategy is
// chosen by mode (SModeFast or SModeExtended). The returned error is:
//   - the iteration error from rows.Err(), if fetching the first row failed;
//   - sql.ErrNoRows, if the result set is empty;
//   - "sql returned more than one row", if a second row exists;
//   - ctx.Err(), if ctx is done once the row has been read;
//   - any error from configuring the scanner or decoding the row.
//
// If close is true, rows is closed before returning on every path, success or error.
// Otherwise the caller keeps ownership of rows.
func ScanSingle[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
	if close {
		// Releases rows on every early return below. sql.Rows.Close is idempotent, so the
		// explicit Close on the success path (whose error is reported) does not conflict.
		defer func() { _ = rows.Close() }()
	}

	if !rows.Next() {
		// Next also returns false when fetching failed. Surface that error instead of
		// disguising it as an empty result set.
		if err := rows.Err(); err != nil {
			return *new(TData), err
		}
		return *new(TData), sql.ErrNoRows
	}

	var unsafeScan bool
	switch sec {
	case Safe:
		unsafeScan = false
	case Unsafe:
		unsafeScan = true
	default:
		return *new(TData), errors.New("unknown value for <sec>")
	}

	strscan := NewStructScanner(rows, unsafeScan)
	var probe TData
	if err := strscan.Start(&probe); err != nil {
		return *new(TData), err
	}

	var data TData
	switch mode {
	case SModeFast:
		if err := strscan.StructScanBase(&data); err != nil {
			return *new(TData), err
		}
	case SModeExtended:
		if err := strscan.StructScanExt(q, &data); err != nil {
			return *new(TData), err
		}
	default:
		return *new(TData), errors.New("unknown value for <mode>")
	}

	if rows.Next() {
		return *new(TData), errors.New("sql returned more than one row")
	}

	if close {
		if err := rows.Close(); err != nil {
			return *new(TData), err
		}
	}

	if err := rows.Err(); err != nil {
		return *new(TData), err
	}

	if err := ctx.Err(); err != nil {
		return *new(TData), err
	}

	return data, nil
}
2023-12-29 19:25:36 +01:00
// ScanAll decodes every remaining row of rows into a slice of TData.
//
// Rows are mapped with a StructScanner configured by sec (Safe: missing fields are an
// error; Unsafe: they are ignored), using the strategy selected by mode. ctx is checked
// before each row, and iteration aborts with ctx.Err() once it is done. An empty result
// yields an empty, non-nil slice. An invalid mode is only detected when the first row is
// processed.
//
// If close is true, rows is closed before returning on every path, success or error.
// Otherwise the caller keeps ownership of rows.
func ScanAll[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
	if close {
		// Releases rows on every early return below; sql.Rows.Close is idempotent.
		defer func() { _ = rows.Close() }()
	}

	var unsafeScan bool
	switch sec {
	case Safe:
		unsafeScan = false
	case Unsafe:
		unsafeScan = true
	default:
		return nil, errors.New("unknown value for <sec>")
	}

	strscan := NewStructScanner(rows, unsafeScan)
	var probe TData
	if err := strscan.Start(&probe); err != nil {
		return nil, err
	}

	res := make([]TData, 0)
	for rows.Next() {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		var data TData
		switch mode {
		case SModeFast:
			if err := strscan.StructScanBase(&data); err != nil {
				return nil, err
			}
		case SModeExtended:
			if err := strscan.StructScanExt(q, &data); err != nil {
				return nil, err
			}
		default:
			return nil, errors.New("unknown value for <mode>")
		}
		res = append(res, data)
	}

	if close {
		if err := rows.Close(); err != nil {
			return nil, err
		}
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return res, nil
}
2024-03-11 16:40:41 +01:00
// IterateAll decodes the remaining rows of rows one at a time and passes each decoded
// value to consumer, in result order, without collecting them.
//
// Rows are mapped with a StructScanner configured by sec (Safe: missing fields are an
// error; Unsafe: they are ignored), using the strategy selected by mode. ctx is checked
// before each row. Iteration stops at the first error: a ctx, scan or mode error is
// returned as is, and a consumer error is returned wrapped via exerr.
//
// The returned count is the number of rows successfully handed to consumer, including
// when an error is returned. If close is true, rows is closed before returning on every
// path. Otherwise the caller keeps ownership of rows.
func IterateAll[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool, consumer func(v TData) error) (int, error) {
	if close {
		// Releases rows on every early return below (including consumer failures);
		// sql.Rows.Close is idempotent.
		defer func() { _ = rows.Close() }()
	}

	var unsafeScan bool
	switch sec {
	case Safe:
		unsafeScan = false
	case Unsafe:
		unsafeScan = true
	default:
		return 0, errors.New("unknown value for <sec>")
	}

	strscan := NewStructScanner(rows, unsafeScan)
	var probe TData
	if err := strscan.Start(&probe); err != nil {
		return 0, err
	}

	rcount := 0
	for rows.Next() {
		if err := ctx.Err(); err != nil {
			return rcount, err
		}

		var data TData
		switch mode {
		case SModeFast:
			if err := strscan.StructScanBase(&data); err != nil {
				return rcount, err
			}
		case SModeExtended:
			if err := strscan.StructScanExt(q, &data); err != nil {
				return rcount, err
			}
		default:
			return rcount, errors.New("unknown value for <mode>")
		}

		if err := consumer(data); err != nil {
			return rcount, exerr.Wrap(err, "").Build()
		}
		rcount++
	}

	if close {
		if err := rows.Close(); err != nil {
			return rcount, err
		}
	}

	if err := rows.Err(); err != nil {
		return rcount, err
	}

	return rcount, nil
}