package sq import ( "context" "database/sql" "github.com/jmoiron/sqlx" "gogs.mikescher.com/BlackForestBytes/goext/langext" ) type Tx interface { Rollback() error Commit() error Exec(ctx context.Context, sql string, prep PP) (sql.Result, error) Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error) } type transaction struct { tx *sqlx.Tx id uint16 lstr []Listener } func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr []Listener) Tx { return &transaction{ tx: xtx, id: txid, lstr: lstr, } } func (tx *transaction) Rollback() error { for _, v := range tx.lstr { err := v.PreTxRollback(tx.id) if err != nil { return err } } result := tx.tx.Rollback() for _, v := range tx.lstr { v.PostTxRollback(tx.id, result) } return result } func (tx *transaction) Commit() error { for _, v := range tx.lstr { err := v.PreTxCommit(tx.id) if err != nil { return err } } result := tx.tx.Commit() for _, v := range tx.lstr { v.PostTxRollback(tx.id, result) } return result } func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { origsql := sqlstr for _, v := range tx.lstr { err := v.PreExec(langext.Ptr(tx.id), &sqlstr, &prep) if err != nil { return nil, err } } res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep) for _, v := range tx.lstr { v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep) } if err != nil { return nil, err } return res, nil } func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { origsql := sqlstr for _, v := range tx.lstr { err := v.PreQuery(langext.Ptr(tx.id), &sqlstr, &prep) if err != nil { return nil, err } } rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep) for _, v := range tx.lstr { v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep) } if err != nil { return nil, err } return rows, nil }