package sq

import (
	"context"
	"database/sql"
	"sync"

	"github.com/jmoiron/sqlx"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
)

// DB is the top-level database handle. In addition to the embedded Queryable
// it can ping the database, open transactions, register listeners and type
// converters, and close the underlying connection.
type DB interface {
	Queryable

	Ping(ctx context.Context) error
	BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
	AddListener(listener Listener)
	Exit() error
	RegisterConverter(DBTypeConverter)
}

// DBOptions configures NewDB. Both options are passed through
// langext.Coalesce with a fallback of true, so leaving a field nil enables
// the corresponding feature.
type DBOptions struct {
	// RegisterDefaultConverter controls whether the built-in converters are registered.
	RegisterDefaultConverter *bool
	// RegisterCommentTrimmer controls whether CommentTrimmer is added as a listener.
	RegisterCommentTrimmer *bool
}

// database is the DB implementation returned by NewDB.
type database struct {
	db    *sqlx.DB
	txctr uint16     // id handed to the next transaction
	lock  sync.Mutex // held while reading/incrementing txctr
	lstr  []Listener
	conv  []DBTypeConverter
}

// NewDB wraps db in a DB and applies opt (see DBOptions for the defaults).
func NewDB(db *sqlx.DB, opt DBOptions) DB {
	sqdb := &database{
		db:   db,
		lstr: make([]Listener, 0),
	}

	if langext.Coalesce(opt.RegisterDefaultConverter, true) {
		sqdb.registerDefaultConverter()
	}

	if langext.Coalesce(opt.RegisterCommentTrimmer, true) {
		sqdb.AddListener(CommentTrimmer)
	}

	return sqdb
}

// AddListener appends listener to the list of listeners that are invoked
// around Exec, Query, Ping and BeginTransaction.
func (db *database) AddListener(listener Listener) {
	db.lstr = append(db.lstr, listener)
}

// Exec runs sqlstr with the named parameters prep outside of a transaction.
//
// Every PreExec listener receives pointers to the statement and the
// parameters (so it is able to rewrite them); the first listener error aborts
// the call. PostExec listeners run after the statement, whether or not it
// succeeded, and are given both the original and the effective SQL.
func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
	origsql := sqlstr

	for _, v := range db.lstr {
		err := v.PreExec(ctx, nil, &sqlstr, &prep)
		if err != nil {
			return nil, exerr.Wrap(err, "failed to call SQL pre-exec listener").Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
		}
	}

	res, err := db.db.NamedExecContext(ctx, sqlstr, prep)

	for _, v := range db.lstr {
		v.PostExec(nil, origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, exerr.Wrap(err, "Failed to [exec] sql statement").Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
	}

	return res, nil
}

// Query runs sqlstr with the named parameters prep outside of a transaction
// and returns the resulting rows.
//
// Every PreQuery listener receives pointers to the statement and the
// parameters (so it is able to rewrite them); the first listener error aborts
// the call. PostQuery listeners run after the query, whether or not it
// succeeded, and are given both the original and the effective SQL.
func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
	origsql := sqlstr

	for _, v := range db.lstr {
		err := v.PreQuery(ctx, nil, &sqlstr, &prep)
		if err != nil {
			return nil, exerr.Wrap(err, "failed to call SQL pre-query listener").Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
		}
	}

	rows, err := sqlx.NamedQueryContext(ctx, db.db, sqlstr, prep)

	for _, v := range db.lstr {
		v.PostQuery(nil, origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, exerr.Wrap(err, "Failed to [query] sql statement").Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
	}

	return rows, nil
}

// Ping pings the underlying database.
//
// A PrePing listener error aborts the call and is returned as-is (unwrapped).
// PostPing listeners always receive the raw ping result, including on failure.
func (db *database) Ping(ctx context.Context) error {
	for _, v := range db.lstr {
		err := v.PrePing(ctx)
		if err != nil {
			return err
		}
	}

	err := db.db.PingContext(ctx)

	for _, v := range db.lstr {
		v.PostPing(err)
	}

	if err != nil {
		return exerr.Wrap(err, "Failed to [ping] sql database").Build()
	}

	return nil
}

// BeginTransaction starts a new transaction with isolation level iso.
//
// Each call consumes one transaction id from an internal uint16 counter. A
// PreTxBegin listener error aborts the call and is returned as-is (unwrapped).
// PostTxBegin listeners are always notified with the outcome of the begin
// call, including on failure.
func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) {
	db.lock.Lock()
	txid := db.txctr
	db.txctr += 1 // the uint16 counter is intentionally allowed to wrap around
	db.lock.Unlock()

	for _, v := range db.lstr {
		err := v.PreTxBegin(ctx, txid)
		if err != nil {
			return nil, err
		}
	}

	xtx, err := db.db.BeginTxx(ctx, &sql.TxOptions{Isolation: iso})

	// Notify the listeners before the error check (same order as Exec, Query
	// and Ping); otherwise a failed begin would never reach them and the err
	// argument would always be nil.
	for _, v := range db.lstr {
		v.PostTxBegin(txid, err)
	}

	if err != nil {
		return nil, exerr.Wrap(err, "Failed to start sql transaction").Build()
	}

	return NewTransaction(xtx, txid, db), nil
}

// Exit closes the underlying sqlx.DB and returns its Close error unchanged.
func (db *database) Exit() error {
	return db.db.Close()
}

// ListConverter returns the registered type converters. The internal slice is
// returned directly, not a copy.
func (db *database) ListConverter() []DBTypeConverter {
	return db.conv
}

// RegisterConverter registers conv. The existing converters are first filtered
// on ModelTypeString so that an earlier converter for the same model type is
// replaced instead of duplicated; conv is then appended at the end.
func (db *database) RegisterConverter(conv DBTypeConverter) {
	db.conv = langext.ArrFilter(db.conv, func(v DBTypeConverter) bool { return v.ModelTypeString() != conv.ModelTypeString() })
	db.conv = append(db.conv, conv)
}

// registerDefaultConverter registers the package's built-in converters.
func (db *database) registerDefaultConverter() {
	db.RegisterConverter(ConverterBoolToBit)

	db.RegisterConverter(ConverterTimeToUnixMillis)

	db.RegisterConverter(ConverterRFCUnixMilliTimeToUnixMillis)
	db.RegisterConverter(ConverterRFCUnixNanoTimeToUnixNanos)
	db.RegisterConverter(ConverterRFCUnixTimeToUnixSeconds)
	db.RegisterConverter(ConverterRFC339TimeToString)
	db.RegisterConverter(ConverterRFC339NanoTimeToString)
	db.RegisterConverter(ConverterRFCDateToString)
	db.RegisterConverter(ConverterRFCTimeToString)
	db.RegisterConverter(ConverterRFCSecondsF64ToString)

	db.RegisterConverter(ConverterJsonObjToString)
	db.RegisterConverter(ConverterJsonArrToString)

	db.RegisterConverter(ConverterExErrCategoryToString)
	db.RegisterConverter(ConverterExErrSeverityToString)
	db.RegisterConverter(ConverterExErrTypeToString)
}