diff --git a/sq/database.go b/sq/database.go index 089ee69..830368d 100644 --- a/sq/database.go +++ b/sq/database.go @@ -38,7 +38,7 @@ func (db *database) AddListener(listener Listener) { func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { origsql := sqlstr for _, v := range db.lstr { - err := v.PreExec(nil, &sqlstr, &prep) + err := v.PreExec(ctx, nil, &sqlstr, &prep) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Resul func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { origsql := sqlstr for _, v := range db.lstr { - err := v.PreQuery(nil, &sqlstr, &prep) + err := v.PreQuery(ctx, nil, &sqlstr, &prep) if err != nil { return nil, err } @@ -79,7 +79,7 @@ func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Ro func (db *database) Ping(ctx context.Context) error { for _, v := range db.lstr { - err := v.PrePing() + err := v.PrePing(ctx) if err != nil { return err } @@ -104,7 +104,7 @@ func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel db.lock.Unlock() for _, v := range db.lstr { - err := v.PreTxBegin(txid) + err := v.PreTxBegin(ctx, txid) if err != nil { return nil, err } diff --git a/sq/listener.go b/sq/listener.go index 31624ea..6b8158b 100644 --- a/sq/listener.go +++ b/sq/listener.go @@ -1,12 +1,14 @@ package sq +import "context" + type Listener interface { - PrePing() error - PreTxBegin(txid uint16) error + PrePing(ctx context.Context) error + PreTxBegin(ctx context.Context, txid uint16) error PreTxCommit(txid uint16) error PreTxRollback(txid uint16) error - PreQuery(txID *uint16, sql *string, params *PP) error - PreExec(txID *uint16, sql *string, params *PP) error + PreQuery(ctx context.Context, txID *uint16, sql *string, params *PP) error + PreExec(ctx context.Context, txID *uint16, sql *string, params *PP) error PostPing(result error) PostTxBegin(txid uint16, result error) diff --git a/sq/transaction.go b/sq/transaction.go index c0adaa2..9216ac4 100644 --- a/sq/transaction.go +++ b/sq/transaction.go @@ -65,7 +65,7 @@ func (tx *transaction) Commit() error { func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { origsql := sqlstr for _, v := range tx.lstr { - err := v.PreExec(langext.Ptr(tx.id), &sqlstr, &prep) + err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep) if err != nil { return nil, err } @@ -86,7 +86,7 @@ func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Re func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { origsql := sqlstr for _, v := range tx.lstr { - err := v.PreQuery(langext.Ptr(tx.id), &sqlstr, &prep) + err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep) if err != nil { return nil, err }