goext/sq/transaction.go
2022-12-07 23:21:36 +01:00

70 lines
1.3 KiB
Go

package sq
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
)
// Tx is a handle to a single database transaction.
//
// Statements issued through Exec and Query run inside the transaction,
// which is finished by calling Commit or Rollback.
type Tx interface {
	// Exec runs a statement that returns no rows, binding its parameters
	// from prep.
	Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error)

	// Query runs a statement that returns rows, binding its parameters from
	// prep. The caller must close the returned rows.
	Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error)

	// Commit makes the transaction's changes permanent.
	Commit() error

	// Rollback discards the transaction's changes.
	Rollback() error
}
// transaction is the default Tx implementation. It wraps an sqlx
// transaction and reports each operation to an optional Listener.
type transaction struct {
	tx   *sqlx.Tx // underlying sqlx transaction that executes the work
	id   uint16   // identifier passed to every listener callback
	lstr Listener // optional observer; nil means no callbacks are made
}
// NewTransaction wraps xtx in a Tx.
//
// txid is passed to every listener callback so that events can be
// attributed to this transaction. lstr may be nil, in which case no
// callbacks are made.
func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr Listener) Tx {
	t := new(transaction)
	t.tx, t.id, t.lstr = xtx, txid, lstr
	return t
}
// Rollback aborts the underlying transaction and returns the driver's
// error unchanged.
//
// The listener, if set, is notified through OnTxRollback before the
// rollback is attempted. It is therefore notified even when the rollback
// itself fails.
func (tx *transaction) Rollback() error {
	if listener := tx.lstr; listener != nil {
		listener.OnTxRollback(tx.id)
	}
	return tx.tx.Rollback()
}
// Commit commits the underlying transaction and returns the driver's error
// unchanged.
//
// The listener, if set, is notified through OnTxCommit before the commit
// is attempted. It is therefore notified even when the commit itself
// fails.
func (tx *transaction) Commit() error {
	if listener := tx.lstr; listener != nil {
		listener.OnTxCommit(tx.id)
	}
	return tx.tx.Commit()
}
// Exec runs sqlstr inside the transaction using sqlx named-parameter
// binding, taking parameter values from prep.
//
// If a listener is set, OnExec is called first with this transaction's id,
// the statement text and a pointer to prep. Because the statement is
// executed with prep after the callback returns, any change the listener
// makes through that pointer is what actually gets bound.
//
// On failure the driver's error is returned together with a nil Result, so
// callers never see a partially populated Result.
//
// The parameter is deliberately not called "sql": that name would shadow
// the imported database/sql package inside this body.
func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
	if tx.lstr != nil {
		tx.lstr.OnExec(langext.Ptr(tx.id), sqlstr, &prep)
	}

	res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep)
	if err != nil {
		return nil, err
	}
	return res, nil
}
// Query runs sqlstr inside the transaction using sqlx named-parameter
// binding (sqlx.NamedQueryContext against the wrapped *sqlx.Tx), taking
// parameter values from prep.
//
// If a listener is set, OnQuery is called first with this transaction's
// id, the statement text and a pointer to prep. Because the query is
// executed with prep after the callback returns, any change the listener
// makes through that pointer is what actually gets bound.
//
// On success the caller owns the returned rows and must Close them. On
// failure the rows are nil.
//
// The parameter is deliberately not called "sql": that name would shadow
// the imported database/sql package inside this body.
func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
	if tx.lstr != nil {
		tx.lstr.OnQuery(langext.Ptr(tx.id), sqlstr, &prep)
	}

	rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep)
	if err != nil {
		return nil, err
	}
	return rows, nil
}