From e3b8d2cc0f5cd69a7774af209dcebd0a7dd73b63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20Schw=C3=B6rer?= Date: Wed, 21 Dec 2022 15:34:59 +0100 Subject: [PATCH] v0.0.40 --- sq/database.go | 66 ++++++++++++++++++++++++++++++++++----------- sq/listener.go | 19 ++++++++----- sq/transaction.go | 68 ++++++++++++++++++++++++++++++++++++----------- 3 files changed, 115 insertions(+), 38 deletions(-) diff --git a/sq/database.go b/sq/database.go index c40ac80..089ee69 100644 --- a/sq/database.go +++ b/sq/database.go @@ -12,14 +12,14 @@ type DB interface { Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error) Ping(ctx context.Context) error BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) - SetListener(listener Listener) + AddListener(listener Listener) } type database struct { db *sqlx.DB txctr uint16 lock sync.Mutex - lstr Listener + lstr []Listener } func NewDB(db *sqlx.DB) DB { @@ -27,31 +27,50 @@ func NewDB(db *sqlx.DB) DB { db: db, txctr: 0, lock: sync.Mutex{}, + lstr: make([]Listener, 0), } } -func (db *database) SetListener(listener Listener) { - db.lstr = listener +func (db *database) AddListener(listener Listener) { + db.lstr = append(db.lstr, listener) } -func (db *database) Exec(ctx context.Context, sql string, prep PP) (sql.Result, error) { - if db.lstr != nil { - db.lstr.OnExec(nil, &sql, &prep) +func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { + origsql := sqlstr + for _, v := range db.lstr { + err := v.PreExec(nil, &sqlstr, &prep) + if err != nil { + return nil, err + } + } + + res, err := db.db.NamedExecContext(ctx, sqlstr, prep) + + for _, v := range db.lstr { + v.PostExec(nil, origsql, sqlstr, prep) } - res, err := db.db.NamedExecContext(ctx, sql, prep) if err != nil { return nil, err } return res, nil } -func (db *database) Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error) { - if db.lstr != nil { - db.lstr.OnQuery(nil, &sql, &prep) +func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { + origsql := sqlstr + for _, v := range db.lstr { + err := v.PreQuery(nil, &sqlstr, &prep) + if err != nil { + return nil, err + } + } + + rows, err := sqlx.NamedQueryContext(ctx, db.db, sqlstr, prep) + + for _, v := range db.lstr { + v.PostQuery(nil, origsql, sqlstr, prep) } - rows, err := db.db.NamedQueryContext(ctx, sql, prep) if err != nil { return nil, err } @@ -59,11 +78,19 @@ func (db *database) Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, } func (db *database) Ping(ctx context.Context) error { - if db.lstr != nil { - db.lstr.OnPing() + for _, v := range db.lstr { + err := v.PrePing() + if err != nil { + return err + } } err := db.db.PingContext(ctx) + + for _, v := range db.lstr { + v.PostPing(err) + } + if err != nil { return err } @@ -76,8 +103,11 @@ func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel db.txctr += 1 // with overflow ! db.lock.Unlock() - if db.lstr != nil { - db.lstr.OnTxBegin(txid) + for _, v := range db.lstr { + err := v.PreTxBegin(txid) + if err != nil { + return nil, err + } } xtx, err := db.db.BeginTxx(ctx, &sql.TxOptions{Isolation: iso}) @@ -85,5 +115,9 @@ func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel return nil, err } + for _, v := range db.lstr { + v.PostTxBegin(txid, err) + } + return NewTransaction(xtx, txid, db.lstr), nil } diff --git a/sq/listener.go b/sq/listener.go index d41f1c7..31624ea 100644 --- a/sq/listener.go +++ b/sq/listener.go @@ -1,10 +1,17 @@ package sq type Listener interface { - OnQuery(txID *uint16, sql *string, params *PP) - OnExec(txID *uint16, sql *string, params *PP) - OnPing() - OnTxBegin(txid uint16) - OnTxCommit(txid uint16) - OnTxRollback(txid uint16) + PrePing() error + PreTxBegin(txid uint16) error + PreTxCommit(txid uint16) error + PreTxRollback(txid uint16) error + PreQuery(txID *uint16, sql *string, params *PP) error + PreExec(txID *uint16, sql *string, params *PP) error + + PostPing(result error) + PostTxBegin(txid uint16, result error) + PostTxCommit(txid uint16, result error) + PostTxRollback(txid uint16, result error) + PostQuery(txID *uint16, sqlOriginal string, sqlReal string, params PP) + PostExec(txID *uint16, sqlOriginal string, sqlReal string, params PP) } diff --git a/sq/transaction.go b/sq/transaction.go index 50d111e..c0adaa2 100644 --- a/sq/transaction.go +++ b/sq/transaction.go @@ -17,10 +17,10 @@ type Tx interface { type transaction struct { tx *sqlx.Tx id uint16 - lstr Listener + lstr []Listener } -func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr Listener) Tx { +func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr []Listener) Tx { return &transaction{ tx: xtx, id: txid, @@ -29,39 +29,75 @@ func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr Listener) Tx { } func (tx *transaction) Rollback() error { - if tx.lstr != nil { - tx.lstr.OnTxRollback(tx.id) + for _, v := range tx.lstr { + err := v.PreTxRollback(tx.id) + if err != nil { + return err + } } - return tx.tx.Rollback() + result := tx.tx.Rollback() + + for _, v := range tx.lstr { + v.PostTxRollback(tx.id, result) + } + + return result } func (tx *transaction) Commit() error { - if tx.lstr != nil { - tx.lstr.OnTxCommit(tx.id) + for _, v := range tx.lstr { + err := v.PreTxCommit(tx.id) + if err != nil { + return err + } } - return tx.tx.Commit() + result := tx.tx.Commit() + + for _, v := range tx.lstr { + v.PostTxRollback(tx.id, result) + } + + return result } -func (tx *transaction) Exec(ctx context.Context, sql string, prep PP) (sql.Result, error) { - if tx.lstr != nil { - tx.lstr.OnExec(langext.Ptr(tx.id), &sql, &prep) +func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { + origsql := sqlstr + for _, v := range tx.lstr { + err := v.PreExec(langext.Ptr(tx.id), &sqlstr, &prep) + if err != nil { + return nil, err + } + } + + res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep) + + for _, v := range tx.lstr { + v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep) } - res, err := tx.tx.NamedExecContext(ctx, sql, prep) if err != nil { return nil, err } return res, nil } -func (tx *transaction) Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error) { - if tx.lstr != nil { - tx.lstr.OnQuery(langext.Ptr(tx.id), &sql, &prep) +func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { + origsql := sqlstr + for _, v := range tx.lstr { + err := v.PreQuery(langext.Ptr(tx.id), &sqlstr, &prep) + if err != nil { + return nil, err + } + } + + rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep) + + for _, v := range tx.lstr { + v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep) } - rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sql, prep) if err != nil { return nil, err }