// goext/sq/database.go
// (90 lines, 1.7 KiB, Go; last changed 2022-12-07 23:21:36 +01:00)

package sq
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"sync"
)
type DB interface {
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
Ping(ctx context.Context) error
BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
2022-12-07 23:25:21 +01:00
SetListener(listener Listener)
2022-12-07 23:21:36 +01:00
}
// database is the DB implementation returned by NewDB.
type database struct {
	db    *sqlx.DB   // wrapped sqlx handle that runs every statement
	txctr uint16     // id handed to the next transaction; wraps to 0 after 65535
	lock  sync.Mutex // serializes the read-and-increment of txctr
	lstr  Listener   // optional hook; nil means no notifications
}
// NewDB wraps an existing sqlx handle in a DB. The result hands out
// transaction id 0 first and has no listener attached; the counter and mutex
// rely on their zero values.
func NewDB(db *sqlx.DB) DB {
	wrapped := &database{db: db}
	return wrapped
}
// SetListener replaces the current listener; passing nil turns notifications off.
// NOTE(review): this write is not guarded by db.lock, and the other methods
// read db.lstr without locking. Install the listener before the DB is shared
// across goroutines.
func (db *database) SetListener(listener Listener) { db.lstr = listener }
// Exec runs query through sqlx's NamedExecContext, binding the named
// parameters from prep.
//
// If a listener is installed, it is notified first. It receives nil as its
// first argument (presumably meaning "outside a transaction"; confirm against
// Listener) and a pointer to prep. Any change the listener makes to prep is
// therefore visible to the statement that follows.
func (db *database) Exec(ctx context.Context, query string, prep PP) (sql.Result, error) {
	if l := db.lstr; l != nil {
		l.OnExec(nil, query, &prep)
	}

	result, err := db.db.NamedExecContext(ctx, query, prep)
	if err == nil {
		return result, nil
	}
	return nil, err
}
// Query runs query through sqlx's NamedQueryContext, binding the named
// parameters from prep. On success the rows are returned open, and the
// caller must Close them.
//
// If a listener is installed, it is notified first with nil as its first
// argument and a pointer to prep, so it can adjust the parameters before
// they are used.
func (db *database) Query(ctx context.Context, query string, prep PP) (*sqlx.Rows, error) {
	if l := db.lstr; l != nil {
		l.OnQuery(nil, query, &prep)
	}

	rows, err := db.db.NamedQueryContext(ctx, query, prep)
	if err == nil {
		return rows, nil
	}
	return nil, err
}
// Ping notifies the listener (if any) and then forwards to PingContext on the
// wrapped handle, returning its error unchanged.
func (db *database) Ping(ctx context.Context) error {
	if l := db.lstr; l != nil {
		l.OnPing()
	}
	return db.db.PingContext(ctx)
}
// BeginTransaction opens a transaction at isolation level iso and wraps it via
// NewTransaction.
//
// Each call takes the next id from a uint16 counter under db.lock. The counter
// wraps around after 65535, so ids are not unique over the lifetime of the DB.
// The listener is told about the new id before the driver is asked to begin.
// If BeginTxx fails, the id is still consumed and the notification has
// already been sent.
func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) {
	txid := func() uint16 {
		db.lock.Lock()
		defer db.lock.Unlock()
		id := db.txctr
		db.txctr++ // intentional wrap-around on overflow
		return id
	}()

	if l := db.lstr; l != nil {
		l.OnTxBegin(txid)
	}

	opts := &sql.TxOptions{Isolation: iso}
	xtx, err := db.db.BeginTxx(ctx, opts)
	if err == nil {
		return NewTransaction(xtx, txid, db.lstr), nil
	}
	return nil, err
}