package sq import ( "context" "database/sql" "github.com/jmoiron/sqlx" "gogs.mikescher.com/BlackForestBytes/goext/exerr" "gogs.mikescher.com/BlackForestBytes/goext/langext" ) type TxStatus string const ( TxStatusInitial TxStatus = "INITIAL" TxStatusActive TxStatus = "ACTIVE" TxStatusComitted TxStatus = "COMMITTED" TxStatusRollback TxStatus = "ROLLBACK" ) type Tx interface { Queryable Rollback() error Commit() error Status() TxStatus } type transaction struct { tx *sqlx.Tx id uint16 status TxStatus execCtr int queryCtr int db *database } func NewTransaction(xtx *sqlx.Tx, txid uint16, db *database) Tx { return &transaction{ tx: xtx, id: txid, status: TxStatusInitial, execCtr: 0, queryCtr: 0, db: db, } } func (tx *transaction) Rollback() error { for _, v := range tx.db.lstr { err := v.PreTxRollback(tx.id) if err != nil { return exerr.Wrap(err, "failed to call SQL pre-rollback listener").Int("tx.id", int(tx.id)).Build() } } result := tx.tx.Rollback() if result == nil { tx.status = TxStatusRollback } for _, v := range tx.db.lstr { v.PostTxRollback(tx.id, result) } return result } func (tx *transaction) Commit() error { for _, v := range tx.db.lstr { err := v.PreTxCommit(tx.id) if err != nil { return exerr.Wrap(err, "failed to call SQL pre-commit listener").Int("tx.id", int(tx.id)).Build() } } result := tx.tx.Commit() if result == nil { tx.status = TxStatusComitted } for _, v := range tx.db.lstr { v.PostTxRollback(tx.id, result) } return result } func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { origsql := sqlstr for _, v := range tx.db.lstr { err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep) if err != nil { return nil, exerr.Wrap(err, "failed to call SQL pre-exec listener").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } } res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep) if tx.status == TxStatusInitial && err == nil { tx.status = TxStatusActive } for _, v := range tx.db.lstr { v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep) } if err != nil { return nil, exerr.Wrap(err, "Failed to [exec] sql statement").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } return res, nil } func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { origsql := sqlstr for _, v := range tx.db.lstr { err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep) if err != nil { return nil, exerr.Wrap(err, "failed to call SQL pre-query listener").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } } rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep) if tx.status == TxStatusInitial && err == nil { tx.status = TxStatusActive } for _, v := range tx.db.lstr { v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep) } if err != nil { return nil, exerr.Wrap(err, "Failed to [query] sql statement").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } return rows, nil } func (tx *transaction) Status() TxStatus { return tx.status } func (tx *transaction) ListConverter() []DBTypeConverter { return tx.db.conv } func (tx *transaction) Traffic() (int, int) { return tx.execCtr, tx.queryCtr }