package sq

import (
	"context"
	"database/sql"
	"github.com/jmoiron/sqlx"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
)

// TxStatus describes the lifecycle state of a Tx.
//
// In this package's transaction implementation a transaction starts as
// TxStatusInitial and becomes TxStatusActive after its first successful Exec
// or Query. It becomes TxStatusCommitted after a successful Commit and
// TxStatusRollback after a successful Rollback.
type TxStatus string

const (
	TxStatusInitial   TxStatus = "INITIAL"
	TxStatusActive    TxStatus = "ACTIVE"
	TxStatusCommitted TxStatus = "COMMITTED"
	TxStatusRollback  TxStatus = "ROLLBACK"
)

// TxStatusComitted is the original, misspelled name of TxStatusCommitted. It is
// kept with the identical value so existing callers keep compiling.
//
// Deprecated: use TxStatusCommitted instead.
const TxStatusComitted = TxStatusCommitted

// Tx is a database transaction. In addition to the statement methods it
// inherits from Queryable, it exposes the transaction's lifecycle.
type Tx interface {
	Queryable

	// Status reports the current lifecycle state of the transaction.
	Status() TxStatus

	// Commit commits the transaction.
	Commit() error

	// Rollback aborts the transaction.
	Rollback() error
}

// transaction is the Tx implementation backed by an *sqlx.Tx.
type transaction struct {
	// tx is the underlying sqlx transaction that statements are sent through.
	tx *sqlx.Tx

	// id identifies this transaction to listener hooks and in wrapped errors.
	id uint16

	// status is the lifecycle state returned by Status.
	status TxStatus

	// execCtr and queryCtr are the statement counters returned by Traffic.
	execCtr, queryCtr int

	// db is the database handle providing listeners (lstr) and type
	// converters (conv).
	db *database
}

// NewTransaction wraps xtx as a Tx identified by txid.
//
// The returned transaction starts in TxStatusInitial with both traffic
// counters at zero. It reports its activity to the listeners registered on db.
func NewTransaction(xtx *sqlx.Tx, txid uint16, db *database) Tx {
	t := &transaction{status: TxStatusInitial}
	t.tx = xtx
	t.id = txid
	t.db = db
	return t
}

// Rollback aborts the transaction.
//
// Every listener's PreTxRollback hook runs first. The first hook error is
// returned (wrapped, with the tx id attached), and the driver rollback is not
// attempted. Otherwise the underlying transaction is rolled back. Only on
// success is the status set to TxStatusRollback. Every listener's
// PostTxRollback hook then receives the driver result, which is also returned
// unwrapped.
func (tx *transaction) Rollback() error {
	for _, listener := range tx.db.lstr {
		if hookErr := listener.PreTxRollback(tx.id); hookErr != nil {
			return exerr.Wrap(hookErr, "failed to call SQL pre-rollback listener").Int("tx.id", int(tx.id)).Build()
		}
	}

	rbErr := tx.tx.Rollback()
	if rbErr == nil {
		tx.status = TxStatusRollback
	}

	for _, listener := range tx.db.lstr {
		listener.PostTxRollback(tx.id, rbErr)
	}

	return rbErr
}

// Commit commits the transaction.
//
// Every listener's PreTxCommit hook runs first. The first hook error is
// returned (wrapped, with the tx id attached), and the driver commit is not
// attempted. Otherwise the underlying transaction is committed. Only on
// success is the status set to TxStatusComitted. Every listener's PostTxCommit
// hook then receives the driver result, which is also returned unwrapped.
func (tx *transaction) Commit() error {
	for _, v := range tx.db.lstr {
		err := v.PreTxCommit(tx.id)
		if err != nil {
			return exerr.Wrap(err, "failed to call SQL pre-commit listener").Int("tx.id", int(tx.id)).Build()
		}
	}

	result := tx.tx.Commit()

	if result == nil {
		tx.status = TxStatusComitted
	}

	// Notify the commit hook, not the rollback hook. Previously listeners saw
	// every commit reported as a rollback and never got the post-commit event
	// matching the PreTxCommit call above.
	for _, v := range tx.db.lstr {
		v.PostTxCommit(tx.id, result)
	}

	return result
}

// Exec runs a named-parameter statement inside the transaction.
//
// Listeners' PreExec hooks may rewrite sqlstr and prep before execution. A hook
// error aborts the call before anything is sent to the database. Every
// statement that reaches the driver increments the exec counter returned by
// Traffic, whether or not it succeeds. The first successful statement moves the
// status from TxStatusInitial to TxStatusActive. PostExec hooks always run
// after the driver call. Driver errors are wrapped with the tx id, the original
// and rewritten SQL, and the parameters.
func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
	origsql := sqlstr
	for _, v := range tx.db.lstr {
		err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
		if err != nil {
			return nil, exerr.Wrap(err, "failed to call SQL pre-exec listener").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
		}
	}

	res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep)
	// Count every statement actually issued. Without this, Traffic always
	// reported zero executions.
	tx.execCtr++

	if tx.status == TxStatusInitial && err == nil {
		tx.status = TxStatusActive
	}

	for _, v := range tx.db.lstr {
		v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, exerr.Wrap(err, "Failed to [exec] sql statement").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
	}
	return res, nil
}

// Query runs a named-parameter query inside the transaction and returns the
// resulting rows.
//
// Listeners' PreQuery hooks may rewrite sqlstr and prep before execution. A
// hook error aborts the call before anything is sent to the database. Every
// query that reaches the driver increments the query counter returned by
// Traffic, whether or not it succeeds. The first successful query moves the
// status from TxStatusInitial to TxStatusActive. PostQuery hooks always run
// after the driver call. Driver errors are wrapped with the tx id, the original
// and rewritten SQL, and the parameters.
func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
	origsql := sqlstr
	for _, v := range tx.db.lstr {
		err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
		if err != nil {
			return nil, exerr.Wrap(err, "failed to call SQL pre-query listener").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
		}
	}

	rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep)
	// Count every query actually issued. Without this, Traffic always
	// reported zero queries.
	tx.queryCtr++

	if tx.status == TxStatusInitial && err == nil {
		tx.status = TxStatusActive
	}

	for _, v := range tx.db.lstr {
		v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep)
	}

	if err != nil {
		return nil, exerr.Wrap(err, "Failed to [query] sql statement").Int("tx.id", int(tx.id)).Str("original_sql", origsql).Str("sql", sqlstr).Any("sql_params", prep).Build()
	}
	return rows, nil
}

// Status returns the transaction's current lifecycle state as tracked by the
// transaction itself (its status field).
func (tx *transaction) Status() TxStatus {
	current := tx.status
	return current
}

// ListConverter returns the type converters registered on the parent
// database. The slice is returned as-is, not copied.
func (tx *transaction) ListConverter() []DBTypeConverter {
	converters := tx.db.conv
	return converters
}

// Traffic returns the transaction's statement counters: first the exec
// counter, then the query counter.
func (tx *transaction) Traffic() (int, int) {
	execs, queries := tx.execCtr, tx.queryCtr
	return execs, queries
}