package sq

import (
	"context"
	"database/sql"

	"github.com/jmoiron/sqlx"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
)

// TxStatus describes the lifecycle state of a Tx.
type TxStatus string

const (
	// TxStatusInitial is the state of a newly created transaction before any
	// statement has succeeded.
	TxStatusInitial TxStatus = "INITIAL"
	// TxStatusActive is set once the first Exec or Query on the transaction succeeded.
	TxStatusActive TxStatus = "ACTIVE"
	// TxStatusComitted is set after a successful Commit.
	// (The identifier's spelling is kept for API compatibility.)
	TxStatusComitted TxStatus = "COMMITTED"
	// TxStatusRollback is set after a successful Rollback.
	TxStatusRollback TxStatus = "ROLLBACK"
)

// Tx is a database transaction whose operations are reported to a list of Listeners.
type Tx interface {
	Rollback() error
	Commit() error
	Status() TxStatus
	Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
	Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
}

// transaction is the default Tx implementation wrapping an *sqlx.Tx.
type transaction struct {
	tx       *sqlx.Tx
	id       uint16
	lstr     []Listener
	status   TxStatus
	execCtr  int // number of statements sent to the database via Exec
	queryCtr int // number of statements sent to the database via Query
}

// NewTransaction wraps xtx as a Tx identified by txid. Every operation on the
// returned Tx is reported to the listeners in lstr (in slice order).
func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr []Listener) Tx {
	return &transaction{
		tx:       xtx,
		id:       txid,
		lstr:     lstr,
		status:   TxStatusInitial,
		execCtr:  0,
		queryCtr: 0,
	}
}

// Rollback aborts the transaction.
//
// Each listener's PreTxRollback hook runs first; the first hook error is
// returned as-is and the underlying transaction is not touched. Otherwise the
// underlying rollback is performed, the status becomes TxStatusRollback if it
// succeeded, PostTxRollback is called on every listener with the rollback
// result, and that result is returned.
func (tx *transaction) Rollback() error {
	for _, v := range tx.lstr {
		if err := v.PreTxRollback(tx.id); err != nil {
			return err
		}
	}

	result := tx.tx.Rollback()

	// Only record the new state when the driver actually rolled back.
	if result == nil {
		tx.status = TxStatusRollback
	}

	for _, v := range tx.lstr {
		v.PostTxRollback(tx.id, result)
	}

	return result
}

// Commit commits the transaction.
//
// Each listener's PreTxCommit hook runs first; the first hook error is
// returned as-is and the underlying transaction is not touched. Otherwise the
// underlying commit is performed, the status becomes TxStatusComitted if it
// succeeded, PostTxCommit is called on every listener with the commit result,
// and that result is returned.
func (tx *transaction) Commit() error {
	for _, v := range tx.lstr {
		if err := v.PreTxCommit(tx.id); err != nil {
			return err
		}
	}

	result := tx.tx.Commit()

	// Only record the new state when the driver actually committed.
	if result == nil {
		tx.status = TxStatusComitted
	}

	// Report the outcome through the commit hook (not the rollback hook).
	for _, v := range tx.lstr {
		v.PostTxCommit(tx.id, result)
	}

	return result
}

// Exec runs a named-parameter statement inside the transaction.
//
// Listeners' PreExec hooks receive pointers to sqlstr and prep and may rewrite
// them; the first hook error aborts the call. The (possibly rewritten)
// statement is executed via sqlx NamedExecContext. On success an INITIAL
// transaction becomes ACTIVE. PostExec is called on every listener with both
// the original and the executed SQL, regardless of the execution error.
func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
	origsql := sqlstr

	for _, v := range tx.lstr {
		if err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep); err != nil {
			return nil, err
		}
	}

	// Count every statement that is actually sent to the database.
	tx.execCtr++

	res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep)

	// A transaction becomes active with its first successful statement.
	if tx.status == TxStatusInitial && err == nil {
		tx.status = TxStatusActive
	}

	for _, v := range tx.lstr {
		v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, err
	}
	return res, nil
}

// Query runs a named-parameter query inside the transaction and returns the
// resulting rows; callers are responsible for closing them.
//
// Listeners' PreQuery hooks receive pointers to sqlstr and prep and may rewrite
// them; the first hook error aborts the call. On success an INITIAL
// transaction becomes ACTIVE. PostQuery is called on every listener with both
// the original and the executed SQL, regardless of the query error.
func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
	origsql := sqlstr

	for _, v := range tx.lstr {
		if err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep); err != nil {
			return nil, err
		}
	}

	// Count every statement that is actually sent to the database.
	tx.queryCtr++

	rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep)

	// A transaction becomes active with its first successful statement.
	if tx.status == TxStatusInitial && err == nil {
		tx.status = TxStatusActive
	}

	for _, v := range tx.lstr {
		v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, err
	}
	return rows, nil
}

// Status returns the current lifecycle state of the transaction.
func (tx *transaction) Status() TxStatus {
	return tx.status
}

// Traffic returns how many statements were sent to the database through Exec
// and Query, in that order. Statements rejected by a Pre* listener hook are
// not counted. Note: this method is not part of the Tx interface.
func (tx *transaction) Traffic() (int, int) {
	return tx.execCtr, tx.queryCtr
}