goext/sq/database.go

152 lines
3.1 KiB
Go
Raw Permalink Normal View History

2022-12-07 23:21:36 +01:00
package sq
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
2023-12-29 19:25:36 +01:00
"gogs.mikescher.com/BlackForestBytes/goext/langext"
2022-12-07 23:21:36 +01:00
"sync"
)
type DB interface {
2023-12-29 19:25:36 +01:00
Queryable
2022-12-07 23:21:36 +01:00
Ping(ctx context.Context) error
BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
2022-12-21 15:34:59 +01:00
AddListener(listener Listener)
2022-12-22 10:06:25 +01:00
Exit() error
2023-12-29 19:25:36 +01:00
RegisterConverter(DBTypeConverter)
RegisterDefaultConverter()
2022-12-07 23:21:36 +01:00
}
type database struct {
db *sqlx.DB
txctr uint16
lock sync.Mutex
2022-12-21 15:34:59 +01:00
lstr []Listener
2023-12-29 19:25:36 +01:00
conv []DBTypeConverter
2022-12-07 23:21:36 +01:00
}
func NewDB(db *sqlx.DB) DB {
return &database{
db: db,
txctr: 0,
lock: sync.Mutex{},
2022-12-21 15:34:59 +01:00
lstr: make([]Listener, 0),
2022-12-07 23:21:36 +01:00
}
}
2022-12-21 15:34:59 +01:00
func (db *database) AddListener(listener Listener) {
db.lstr = append(db.lstr, listener)
2022-12-07 23:21:36 +01:00
}
2022-12-21 15:34:59 +01:00
func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
origsql := sqlstr
for _, v := range db.lstr {
2022-12-21 15:41:41 +01:00
err := v.PreExec(ctx, nil, &sqlstr, &prep)
2022-12-21 15:34:59 +01:00
if err != nil {
return nil, err
}
}
res, err := db.db.NamedExecContext(ctx, sqlstr, prep)
for _, v := range db.lstr {
v.PostExec(nil, origsql, sqlstr, prep)
2022-12-07 23:21:36 +01:00
}
if err != nil {
return nil, err
}
return res, nil
}
2022-12-21 15:34:59 +01:00
func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
origsql := sqlstr
for _, v := range db.lstr {
2022-12-21 15:41:41 +01:00
err := v.PreQuery(ctx, nil, &sqlstr, &prep)
2022-12-21 15:34:59 +01:00
if err != nil {
return nil, err
}
}
rows, err := sqlx.NamedQueryContext(ctx, db.db, sqlstr, prep)
for _, v := range db.lstr {
v.PostQuery(nil, origsql, sqlstr, prep)
2022-12-07 23:21:36 +01:00
}
if err != nil {
return nil, err
}
return rows, nil
}
func (db *database) Ping(ctx context.Context) error {
2022-12-21 15:34:59 +01:00
for _, v := range db.lstr {
2022-12-21 15:41:41 +01:00
err := v.PrePing(ctx)
2022-12-21 15:34:59 +01:00
if err != nil {
return err
}
2022-12-07 23:21:36 +01:00
}
err := db.db.PingContext(ctx)
2022-12-21 15:34:59 +01:00
for _, v := range db.lstr {
v.PostPing(err)
}
2022-12-07 23:21:36 +01:00
if err != nil {
return err
}
return nil
}
func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) {
db.lock.Lock()
txid := db.txctr
db.txctr += 1 // with overflow !
db.lock.Unlock()
2022-12-21 15:34:59 +01:00
for _, v := range db.lstr {
2022-12-21 15:41:41 +01:00
err := v.PreTxBegin(ctx, txid)
2022-12-21 15:34:59 +01:00
if err != nil {
return nil, err
}
2022-12-07 23:21:36 +01:00
}
xtx, err := db.db.BeginTxx(ctx, &sql.TxOptions{Isolation: iso})
if err != nil {
return nil, err
}
2022-12-21 15:34:59 +01:00
for _, v := range db.lstr {
v.PostTxBegin(txid, err)
}
2023-12-29 19:25:36 +01:00
return NewTransaction(xtx, txid, db), nil
2022-12-07 23:21:36 +01:00
}
2022-12-22 10:06:25 +01:00
// Exit closes the wrapped *sqlx.DB and returns the error from Close, if any.
func (db *database) Exit() error {
	closeErr := db.db.Close()
	return closeErr
}
2023-12-29 19:25:36 +01:00
// ListConverter returns the registered converters. The returned slice is the
// internal slice, not a copy.
func (db *database) ListConverter() []DBTypeConverter {
	registered := db.conv
	return registered
}
// RegisterConverter registers conv. It first removes any converter already
// registered for the same model type (compared via ModelTypeString), so conv
// replaces it. conv is then appended last.
func (db *database) RegisterConverter(conv DBTypeConverter) {
	sameModel := func(other DBTypeConverter) bool {
		return other.ModelTypeString() == conv.ModelTypeString()
	}

	kept := langext.ArrFilter(db.conv, func(other DBTypeConverter) bool { return !sameModel(other) })
	db.conv = append(kept, conv)
}
// RegisterDefaultConverter registers the package's built-in converters in a
// fixed order through RegisterConverter. Each one therefore replaces any
// converter already registered for the same model type.
func (db *database) RegisterDefaultConverter() {
	defaults := []DBTypeConverter{
		ConverterBoolToBit,
		ConverterTimeToUnixMillis,
		ConverterRFCUnixMilliTimeToUnixMillis,
		ConverterRFCUnixNanoTimeToUnixNanos,
		ConverterRFCUnixTimeToUnixSeconds,
		ConverterRFC339TimeToString,
		ConverterRFC339NanoTimeToString,
	}

	for _, c := range defaults {
		db.RegisterConverter(c)
	}
}