package sq import ( "context" "database/sql" "github.com/jmoiron/sqlx" "gogs.mikescher.com/BlackForestBytes/goext/exerr" "gogs.mikescher.com/BlackForestBytes/goext/langext" "time" ) type TxStatus string const ( TxStatusInitial TxStatus = "INITIAL" TxStatusActive TxStatus = "ACTIVE" TxStatusComitted TxStatus = "COMMITTED" TxStatusRollback TxStatus = "ROLLBACK" ) type Tx interface { Queryable Rollback() error Commit() error Status() TxStatus } type transaction struct { constructorContext context.Context tx *sqlx.Tx id uint16 status TxStatus execCtr int queryCtr int db *database } func newTransaction(ctx context.Context, xtx *sqlx.Tx, txid uint16, db *database) Tx { return &transaction{ constructorContext: ctx, tx: xtx, id: txid, status: TxStatusInitial, execCtr: 0, queryCtr: 0, db: db, } } func (tx *transaction) Rollback() error { t0 := time.Now() preMeta := PreTxRollbackMeta{ConstructorContext: tx.constructorContext} for _, v := range tx.db.lstr { err := v.PreTxRollback(tx.id, preMeta) if err != nil { return exerr.Wrap(err, "failed to call SQL pre-rollback listener").Int("tx.id", int(tx.id)).Build() } } t1 := time.Now() result := tx.tx.Rollback() if result == nil { tx.status = TxStatusRollback } postMeta := PostTxRollbackMeta{ConstructorContext: tx.constructorContext, Init: t0, Start: t1, End: time.Now(), ExecCounter: tx.execCtr, QueryCounter: tx.queryCtr} for _, v := range tx.db.lstr { v.PostTxRollback(tx.id, result, postMeta) } return result } func (tx *transaction) Commit() error { t0 := time.Now() preMeta := PreTxCommitMeta{ConstructorContext: tx.constructorContext} for _, v := range tx.db.lstr { err := v.PreTxCommit(tx.id, preMeta) if err != nil { return exerr.Wrap(err, "failed to call SQL pre-commit listener").Int("tx.id", int(tx.id)).Build() } } t1 := time.Now() result := tx.tx.Commit() if result == nil { tx.status = TxStatusComitted } postMeta := PostTxCommitMeta{ConstructorContext: tx.constructorContext, Init: t0, Start: t1, End: time.Now(), ExecCounter: tx.execCtr, QueryCounter: tx.queryCtr} for _, v := range tx.db.lstr { v.PostTxCommit(tx.id, result, postMeta) } return result } func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { origsql := sqlstr t0 := time.Now() preMeta := PreExecMeta{Context: ctx, TransactionConstructorContext: tx.constructorContext} for _, v := range tx.db.lstr { err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep, preMeta) if err != nil { return nil, exerr.Wrap(err, "failed to call SQL pre-exec listener").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } } t1 := time.Now() res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep) tx.execCtr++ if tx.status == TxStatusInitial && err == nil { tx.status = TxStatusActive } postMeta := PostExecMeta{Context: ctx, TransactionConstructorContext: tx.constructorContext, Init: t0, Start: t1, End: time.Now()} for _, v := range tx.db.lstr { v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep, postMeta) } if err != nil { return nil, exerr.Wrap(err, "Failed to [exec] sql statement").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } return res, nil } func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { origsql := sqlstr t0 := time.Now() preMeta := PreQueryMeta{Context: ctx, TransactionConstructorContext: tx.constructorContext} for _, v := range tx.db.lstr { err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep, preMeta) if err != nil { return nil, exerr.Wrap(err, "failed to call SQL pre-query listener").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } } t1 := time.Now() rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep) tx.queryCtr++ if tx.status == TxStatusInitial && err == nil { tx.status = TxStatusActive } postMeta := PostQueryMeta{Context: ctx, TransactionConstructorContext: tx.constructorContext, Init: t0, Start: t1, End: time.Now()} for _, v := range tx.db.lstr { v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep, postMeta) } if err != nil { return nil, exerr.Wrap(err, "Failed to [query] sql statement").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build() } return rows, nil } func (tx *transaction) Status() TxStatus { return tx.status } func (tx *transaction) ListConverter() []DBTypeConverter { return tx.db.conv } func (tx *transaction) Traffic() (int, int) { return tx.execCtr, tx.queryCtr }