package sq

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"

	"github.com/jmoiron/sqlx"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
)

// StructScanMode selects how result rows are mapped onto structs.
type StructScanMode string

const (
	SModeFast     StructScanMode = "FAST"     // Use default sq.Scan, does not work with joined/resolved types and/or custom value converter
	SModeExtended StructScanMode = "EXTENDED" // Fully featured, possibly a tiny bit slower - default
)

// StructScanSafety selects how result columns without a matching struct field are treated.
type StructScanSafety string

const (
	Safe   StructScanSafety = "SAFE"   // return error for missing fields
	Unsafe StructScanSafety = "UNSAFE" // ignore missing fields
)

// InsertSingle inserts v as one row into tableName using the statement built by
// BuildInsertStatement and returns the driver result of the Exec call.
func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData) (sql.Result, error) {
	sqlstr, pp, err := BuildInsertStatement(q, tableName, v)
	if err != nil {
		return nil, err
	}

	sqlr, err := q.Exec(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}

	return sqlr, nil
}

// InsertAndQuerySingle inserts v into tableName and then reads the row back with
// `SELECT * FROM tableName WHERE idColumn = <id of v>`, scanning it with mode/sec.
//
// idColumn must be the `db` tag of an exported field of v, and that field must hold
// a non-zero value because it is used to locate the inserted row. tableName and
// idColumn are interpolated into the SQL text (identifiers cannot be bound as
// parameters), so they must not come from untrusted input.
func InsertAndQuerySingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string, mode StructScanMode, sec StructScanSafety) (TData, error) {
	idRVal, ok := lookupFieldByTag(reflect.ValueOf(v), "db", idColumn)
	if !ok {
		return *new(TData), fmt.Errorf("failed to find idColumn '%s' in %T", idColumn, v)
	}
	if idRVal.IsZero() {
		return *new(TData), fmt.Errorf("idColumn '%s' in %T must not be a zero value", idColumn, v)
	}

	idValue, err := convertValueToDB(q, idRVal.Interface())
	if err != nil {
		return *new(TData), err
	}

	_, err = InsertSingle[TData](ctx, q, tableName, v)
	if err != nil {
		return *new(TData), err
	}

	pp := PP{}

	//goland:noinspection ALL
	sqlstr := fmt.Sprintf("SELECT * FROM %s WHERE %s = :%s", tableName, idColumn, pp.Add(idValue))

	return QuerySingle[TData](ctx, q, sqlstr, pp, mode, sec)
}

// lookupFieldByTag returns the first exported field of the struct value rval whose
// struct tag `tagkey` equals tagval. ok is false when rval is not a struct value or
// no such field exists.
func lookupFieldByTag(rval reflect.Value, tagkey string, tagval string) (reflect.Value, bool) {
	if rval.Kind() != reflect.Struct {
		return reflect.Value{}, false
	}

	rtyp := rval.Type()
	for i := 0; i < rtyp.NumField(); i++ {
		rsfield := rtyp.Field(i)
		if rsfield.IsExported() && rsfield.Tag.Get(tagkey) == tagval {
			return rval.Field(i), true
		}
	}

	return reflect.Value{}, false
}

// fieldByTag is the panicking variant of lookupFieldByTag. It is kept for callers
// that treat a missing tag as a programming error; prefer lookupFieldByTag.
func fieldByTag(rval reflect.Value, tagkey string, tagval string) reflect.Value {
	if field, ok := lookupFieldByTag(rval, tagkey, tagval); ok {
		return field
	}
	panic(fmt.Sprintf("tag %s = '%s' not found in %s", tagkey, tagval, rval.Type().Name()))
}

// InsertMultiple inserts all elements of vArr into tableName, issuing one multi-row
// INSERT per chunk of at most maxBatch elements, and returns one sql.Result per
// executed statement. A non-positive maxBatch inserts everything in one statement.
//
// All statements are built before any is executed, so a build error writes nothing.
// If an Exec fails, chunks executed before it are not rolled back by this function.
func InsertMultiple[TData any](ctx context.Context, q Queryable, tableName string, vArr []TData, maxBatch int) ([]sql.Result, error) {
	if len(vArr) == 0 {
		return make([]sql.Result, 0), nil
	}

	if maxBatch <= 0 {
		// a non-positive chunk size cannot partition the input; use a single batch
		maxBatch = len(vArr)
	}

	chunks := langext.ArrChunk(vArr, maxBatch)

	sqlstrArr := make([]string, 0, len(chunks))
	ppArr := make([]PP, 0, len(chunks))

	for _, chunk := range chunks {
		sqlstr, pp, err := BuildInsertMultipleStatement(q, tableName, chunk)
		if err != nil {
			return nil, exerr.Wrap(err, "").Build()
		}

		sqlstrArr = append(sqlstrArr, sqlstr)
		ppArr = append(ppArr, pp)
	}

	res := make([]sql.Result, 0, len(sqlstrArr))
	for i := range sqlstrArr {
		sqlr, err := q.Exec(ctx, sqlstrArr[i], ppArr[i])
		if err != nil {
			return nil, err
		}

		res = append(res, sqlr)
	}

	return res, nil
}

// UpdateSingle updates the row of tableName identified by the idColumn field of v,
// using the statement built by BuildUpdateStatement, and returns the Exec result.
func UpdateSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string) (sql.Result, error) {
	sqlstr, pp, err := BuildUpdateStatement(q, tableName, v, idColumn)
	if err != nil {
		return nil, err
	}

	sqlr, err := q.Exec(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}

	return sqlr, nil
}

// UpdateAndQuerySingle updates the row identified by the idColumn field of v and
// then reads it back with `SELECT * FROM tableName WHERE idColumn = <id of v>`.
//
// idColumn must be the `db` tag of an exported field of v holding a non-zero value.
// tableName and idColumn are interpolated into the SQL text, so they must not come
// from untrusted input.
func UpdateAndQuerySingle[TData any](ctx context.Context, q Queryable, tableName string, v TData, idColumn string, mode StructScanMode, sec StructScanSafety) (TData, error) {
	idRVal, ok := lookupFieldByTag(reflect.ValueOf(v), "db", idColumn)
	if !ok {
		return *new(TData), fmt.Errorf("failed to find idColumn '%s' in %T", idColumn, v)
	}
	if idRVal.IsZero() {
		return *new(TData), fmt.Errorf("idColumn '%s' in %T must not be a zero value", idColumn, v)
	}

	idValue, err := convertValueToDB(q, idRVal.Interface())
	if err != nil {
		return *new(TData), err
	}

	_, err = UpdateSingle[TData](ctx, q, tableName, v, idColumn)
	if err != nil {
		return *new(TData), err
	}

	pp := PP{}

	//goland:noinspection ALL
	sqlstr := fmt.Sprintf("SELECT * FROM %s WHERE %s = :%s", tableName, idColumn, pp.Add(idValue))

	return QuerySingle[TData](ctx, q, sqlstr, pp, mode, sec)
}

// QuerySingle runs sqlstr and scans exactly one result row into TData. It returns
// sql.ErrNoRows when the query yields no rows and an error when it yields more than
// one. The rows are always closed.
func QuerySingle[TData any](ctx context.Context, q Queryable, sqlstr string, pp PP, mode StructScanMode, sec StructScanSafety) (TData, error) {
	rows, err := q.Query(ctx, sqlstr, pp)
	if err != nil {
		return *new(TData), err
	}

	data, err := ScanSingle[TData](ctx, q, rows, mode, sec, true)
	if err != nil {
		return *new(TData), err
	}

	return data, nil
}

// QuerySingleOpt is like QuerySingle but returns (nil, nil) when the query yields
// no rows.
func QuerySingleOpt[TData any](ctx context.Context, q Queryable, sqlstr string, pp PP, mode StructScanMode, sec StructScanSafety) (*TData, error) {
	rows, err := q.Query(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}

	data, err := ScanSingle[TData](ctx, q, rows, mode, sec, true)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}

	return &data, nil
}

// QueryAll runs sqlstr and scans every result row into a slice of TData. The rows
// are always closed.
func QueryAll[TData any](ctx context.Context, q Queryable, sqlstr string, pp PP, mode StructScanMode, sec StructScanSafety) ([]TData, error) {
	rows, err := q.Query(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}

	data, err := ScanAll[TData](ctx, q, rows, mode, sec, true)
	if err != nil {
		return nil, err
	}

	return data, nil
}

// startScannerForType creates a StructScanner over rows for the given safety level
// and initializes it for the target type TData.
func startScannerForType[TData any](rows *sqlx.Rows, sec StructScanSafety) (*StructScanner, error) {
	var ignoreMissing bool
	switch sec {
	case Safe:
		ignoreMissing = false
	case Unsafe:
		ignoreMissing = true
	default:
		return nil, fmt.Errorf("unknown value for StructScanSafety: '%s'", sec)
	}

	strscan := NewStructScanner(rows, ignoreMissing)

	var data TData
	if err := strscan.Start(&data); err != nil {
		return nil, err
	}

	return strscan, nil
}

// scanCurrentRow scans the row strscan is currently positioned on into a new TData,
// using the StructScanBase (fast) or StructScanExt (extended) path.
func scanCurrentRow[TData any](q Queryable, strscan *StructScanner, mode StructScanMode) (TData, error) {
	var data TData

	switch mode {
	case SModeFast:
		if err := strscan.StructScanBase(&data); err != nil {
			return *new(TData), err
		}
	case SModeExtended:
		if err := strscan.StructScanExt(q, &data); err != nil {
			return *new(TData), err
		}
	default:
		return *new(TData), fmt.Errorf("unknown value for StructScanMode: '%s'", mode)
	}

	return data, nil
}

// ScanSingle scans exactly one row from rows into TData.
//
// It returns sql.ErrNoRows when rows is empty (unless iteration itself failed, in
// which case that error is returned), and an error when more than one row exists.
// When close is true, rows is closed before returning on every path; otherwise the
// caller keeps ownership of rows.
func ScanSingle[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
	if close {
		// Rows.Close is idempotent; this releases the connection on every error path.
		defer func() { _ = rows.Close() }()
	}

	if !rows.Next() {
		// Next also returns false on an iteration error - do not report that as "no rows".
		if err := rows.Err(); err != nil {
			return *new(TData), err
		}
		return *new(TData), sql.ErrNoRows
	}

	strscan, err := startScannerForType[TData](rows, sec)
	if err != nil {
		return *new(TData), err
	}

	data, err := scanCurrentRow[TData](q, strscan, mode)
	if err != nil {
		return *new(TData), err
	}

	if rows.Next() {
		return *new(TData), errors.New("sql returned more than one row")
	}

	if close {
		// explicit close so its error is reported on the success path
		if err := rows.Close(); err != nil {
			return *new(TData), err
		}
	}

	if err := rows.Err(); err != nil {
		return *new(TData), err
	}

	if err := ctx.Err(); err != nil {
		return *new(TData), err
	}

	return data, nil
}

// ScanAll scans every remaining row from rows into a slice of TData, checking ctx
// before each row. When close is true, rows is closed before returning on every path.
func ScanAll[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
	if close {
		// Rows.Close is idempotent; this releases the connection on every error path.
		defer func() { _ = rows.Close() }()
	}

	strscan, err := startScannerForType[TData](rows, sec)
	if err != nil {
		return nil, err
	}

	res := make([]TData, 0)
	for rows.Next() {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		data, err := scanCurrentRow[TData](q, strscan, mode)
		if err != nil {
			return nil, err
		}

		res = append(res, data)
	}

	if close {
		if err := rows.Close(); err != nil {
			return nil, err
		}
	}

	if err := rows.Err(); err != nil {
		return nil, err
	}

	return res, nil
}

// IterateAll scans every remaining row from rows and passes each value to consumer,
// checking ctx before each row. It returns the number of rows successfully consumed;
// a consumer error stops the iteration and is returned wrapped. When close is true,
// rows is closed before returning on every path.
func IterateAll[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool, consumer func(ctx context.Context, v TData) error) (int, error) {
	if close {
		// Rows.Close is idempotent; this releases the connection on every error path.
		defer func() { _ = rows.Close() }()
	}

	strscan, err := startScannerForType[TData](rows, sec)
	if err != nil {
		return 0, err
	}

	rcount := 0
	for rows.Next() {
		if err := ctx.Err(); err != nil {
			return rcount, err
		}

		data, err := scanCurrentRow[TData](q, strscan, mode)
		if err != nil {
			return rcount, err
		}

		if err := consumer(ctx, data); err != nil {
			return rcount, exerr.Wrap(err, "").Build()
		}

		rcount++
	}

	if close {
		if err := rows.Close(); err != nil {
			return rcount, err
		}
	}

	if err := rows.Err(); err != nil {
		return rcount, err
	}

	return rcount, nil
}