package sq import ( "context" "database/sql" "errors" "fmt" "github.com/jmoiron/sqlx" "reflect" "strings" ) type StructScanMode string const ( SModeFast StructScanMode = "FAST" SModeExtended StructScanMode = "EXTENDED" ) type StructScanSafety string const ( Safe StructScanSafety = "SAFE" // return error for missing fields Unsafe StructScanSafety = "UNSAFE" // ignore missing fields ) func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData) (sql.Result, error) { rval := reflect.ValueOf(v) rtyp := rval.Type() columns := make([]string, 0) params := make([]string, 0) pp := PP{} for i := 0; i < rtyp.NumField(); i++ { rsfield := rtyp.Field(i) rvfield := rval.Field(i) if !rsfield.IsExported() { continue } columnName := rsfield.Tag.Get("db") if columnName == "" || columnName == "-" { continue } paramkey := fmt.Sprintf("_%s", columnName) columns = append(columns, "\""+columnName+"\"") params = append(params, ":"+paramkey) pp[paramkey] = rvfield.Interface() } sqlstr := fmt.Sprintf("INSERT"+" INTO \"%s\" (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), strings.Join(params, ", ")) sqlr, err := q.Exec(ctx, sqlstr, pp) if err != nil { return nil, err } return sqlr, nil } func QuerySingle[TData any](ctx context.Context, q Queryable, sql string, pp PP, mode StructScanMode, sec StructScanSafety) (TData, error) { rows, err := q.Query(ctx, sql, pp) if err != nil { return *new(TData), err } data, err := ScanSingle[TData](rows, mode, sec, true) if err != nil { return *new(TData), err } return data, nil } func QueryAll[TData any](ctx context.Context, q Queryable, sql string, pp PP, mode StructScanMode, sec StructScanSafety) ([]TData, error) { rows, err := q.Query(ctx, sql, pp) if err != nil { return nil, err } data, err := ScanAll[TData](rows, mode, sec, true) if err != nil { return nil, err } return data, nil } func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) { if rows.Next() { var strscan *StructScanner if sec == Safe { strscan = NewStructScanner(rows, false) var data TData err := strscan.Start(&data) if err != nil { return *new(TData), err } } else if sec == Unsafe { strscan = NewStructScanner(rows, true) var data TData err := strscan.Start(&data) if err != nil { return *new(TData), err } } else { return *new(TData), errors.New("unknown value for <sec>") } var data TData if mode == SModeFast { err := strscan.StructScanBase(&data) if err != nil { return *new(TData), err } } else if mode == SModeExtended { err := strscan.StructScanExt(&data) if err != nil { return *new(TData), err } } else { return *new(TData), errors.New("unknown value for <mode>") } if rows.Next() { if close { _ = rows.Close() } return *new(TData), errors.New("sql returned more than one row") } if close { err := rows.Close() if err != nil { return *new(TData), err } } if err := rows.Err(); err != nil { return *new(TData), err } return data, nil } else { if close { _ = rows.Close() } return *new(TData), sql.ErrNoRows } } func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) { var strscan *StructScanner if sec == Safe { strscan = NewStructScanner(rows, false) var data TData err := strscan.Start(&data) if err != nil { return nil, err } } else if sec == Unsafe { strscan = NewStructScanner(rows, true) var data TData err := strscan.Start(&data) if err != nil { return nil, err } } else { return nil, errors.New("unknown value for <sec>") } res := make([]TData, 0) for rows.Next() { if mode == SModeFast { var data TData err := strscan.StructScanBase(&data) if err != nil { return nil, err } res = append(res, data) } else if mode == SModeExtended { var data TData err := strscan.StructScanExt(&data) if err != nil { return nil, err } res = append(res, data) } else { return nil, errors.New("unknown value for <mode>") } } if close { err := strscan.rows.Close() if err != nil { return nil, err } } if err := rows.Err(); err != nil { return nil, err } return res, nil }