package sq

import (
	"context"
	"database/sql"
	"sync"

	"github.com/jmoiron/sqlx"
)

// DB is a thin wrapper around a *sqlx.DB that runs registered Listener hooks
// before and after every statement, ping and transaction begin.
type DB interface {
	// Exec runs a named statement and returns its sql.Result.
	Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
	// Query runs a named query and returns the resulting rows.
	Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
	// Ping checks the connection to the database.
	Ping(ctx context.Context) error
	// BeginTransaction opens a new transaction with the given isolation level.
	BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
	// AddListener registers a listener that is invoked on subsequent operations.
	AddListener(listener Listener)
}

// database is the default DB implementation.
type database struct {
	db    *sqlx.DB
	txctr uint16     // next transaction id; wraps around on overflow
	lock  sync.Mutex // guards txctr
	lstr  []Listener
}

// NewDB wraps db in a DB with no listeners registered.
func NewDB(db *sqlx.DB) DB {
	return &database{
		db:   db,
		lstr: make([]Listener, 0),
	}
}

// AddListener appends listener to the list of hooks.
//
// The listener slice is not protected by the mutex, so listeners should be
// registered before the DB is used from multiple goroutines.
func (db *database) AddListener(listener Listener) {
	db.lstr = append(db.lstr, listener)
}

// Exec executes the named statement sqlstr with the parameters in prep.
//
// Every listener's PreExec is called first and receives pointers to the
// statement and the parameters (presumably so it can rewrite them - confirm
// against the Listener contract); the first PreExec error aborts the call
// before the database is touched. PostExec is called on every listener after
// the statement ran, whether or not it succeeded, with both the original and
// the possibly rewritten statement.
func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
	origsql := sqlstr

	for _, v := range db.lstr {
		if err := v.PreExec(ctx, nil, &sqlstr, &prep); err != nil {
			return nil, err
		}
	}

	res, err := db.db.NamedExecContext(ctx, sqlstr, prep)

	for _, v := range db.lstr {
		v.PostExec(nil, origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, err
	}

	return res, nil
}

// Query runs the named query sqlstr with the parameters in prep.
//
// Listener handling mirrors Exec: PreQuery may abort the call, PostQuery is
// always invoked once the query has been issued.
func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
	origsql := sqlstr

	for _, v := range db.lstr {
		if err := v.PreQuery(ctx, nil, &sqlstr, &prep); err != nil {
			return nil, err
		}
	}

	rows, err := sqlx.NamedQueryContext(ctx, db.db, sqlstr, prep)

	for _, v := range db.lstr {
		v.PostQuery(nil, origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, err
	}

	return rows, nil
}

// Ping pings the underlying database.
//
// The first PrePing error aborts the call. PostPing is invoked on every
// listener with the ping result (nil on success), which is then returned.
func (db *database) Ping(ctx context.Context) error {
	for _, v := range db.lstr {
		if err := v.PrePing(ctx); err != nil {
			return err
		}
	}

	err := db.db.PingContext(ctx)

	for _, v := range db.lstr {
		v.PostPing(err)
	}

	return err
}

// BeginTransaction starts a transaction with isolation level iso and returns
// it wrapped as a Tx carrying a freshly allocated transaction id.
//
// The first PreTxBegin error aborts the call before a transaction is opened.
// PostTxBegin is invoked on every listener with the outcome of the begin call,
// including the failure case, so every PreTxBegin that succeeded is paired
// with a PostTxBegin - consistent with Exec, Query and Ping.
func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) {
	db.lock.Lock()
	txid := db.txctr
	db.txctr += 1 // with overflow: the uint16 counter wraps, so ids repeat eventually
	db.lock.Unlock()

	for _, v := range db.lstr {
		if err := v.PreTxBegin(ctx, txid); err != nil {
			return nil, err
		}
	}

	xtx, err := db.db.BeginTxx(ctx, &sql.TxOptions{Isolation: iso})

	// Notify listeners before bailing out so that they also observe a failed begin.
	for _, v := range db.lstr {
		v.PostTxBegin(txid, err)
	}

	if err != nil {
		return nil, err
	}

	return NewTransaction(xtx, txid, db.lstr), nil
}