package sq

import (
	"context"
	"database/sql"
	"github.com/jmoiron/sqlx"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	"sync"
)

// DB is the database handle returned by NewDB.
//
// It embeds Queryable (declared elsewhere in this package; presumably the
// Exec/Query pair implemented by *database below — TODO confirm) and adds
// connection-level operations.
type DB interface {
	Queryable

	// Ping checks the connection; in this file's implementation the
	// registered listeners' PrePing/PostPing hooks run around the ping.
	Ping(ctx context.Context) error
	// BeginTransaction starts a new transaction with the given isolation level.
	BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
	// AddListener appends a listener that is notified around DB operations.
	AddListener(listener Listener)
	// Exit closes the underlying database handle.
	Exit() error
	// RegisterConverter adds a type converter, replacing any existing converter
	// that reports the same ModelTypeString.
	RegisterConverter(DBTypeConverter)
}

// DBOptions configures NewDB. Both fields are optional pointers; NewDB
// treats a nil value as true (it falls back to true via langext.Coalesce).
type DBOptions struct {
	// RegisterDefaultConverter controls whether the built-in converter set is registered.
	RegisterDefaultConverter *bool
	// RegisterCommentTrimmer controls whether the CommentTrimmer listener is added.
	RegisterCommentTrimmer *bool
}

// database is the DB implementation returned by NewDB.
type database struct {
	db    *sqlx.DB          // underlying sqlx handle all queries are sent to
	txctr uint16            // next transaction id; incremented in BeginTransaction and wraps around on overflow
	lock  sync.Mutex        // guards txctr (lstr and conv are not guarded by it in this file)
	lstr  []Listener        // listeners, invoked in registration order
	conv  []DBTypeConverter // converters; RegisterConverter keeps at most one per ModelTypeString
}

// NewDB wraps an existing sqlx handle in a DB.
//
// Unless disabled via opt, the default converter set is registered and the
// CommentTrimmer listener is installed. A nil option field counts as enabled.
func NewDB(db *sqlx.DB, opt DBOptions) DB {
	// txctr and lock are intentionally left at their zero values.
	wrapped := &database{
		db:   db,
		lstr: make([]Listener, 0),
	}

	wantConverters := langext.Coalesce(opt.RegisterDefaultConverter, true)
	if wantConverters {
		wrapped.registerDefaultConverter()
	}

	wantTrimmer := langext.Coalesce(opt.RegisterCommentTrimmer, true)
	if wantTrimmer {
		wrapped.AddListener(CommentTrimmer)
	}

	return wrapped
}

// AddListener appends listener to the end of the listener chain; listeners
// are invoked in the order they were added.
func (db *database) AddListener(listener Listener) {
	chain := db.lstr
	chain = append(chain, listener)
	db.lstr = chain
}

// Exec runs a statement with named parameters via sqlx.NamedExecContext.
//
// Each listener's PreExec runs first and may modify the SQL text and the
// parameters through the pointers it receives; the first listener error
// aborts the call. After execution every listener receives PostExec with the
// original and the final SQL, whether or not execution succeeded. Execution
// errors are wrapped with the SQL text and parameters attached.
func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
	sqlBefore := sqlstr

	for _, hook := range db.lstr {
		if hookErr := hook.PreExec(ctx, nil, &sqlstr, &prep); hookErr != nil {
			return nil, exerr.Wrap(hookErr, "failed to call SQL pre-exec listener").
				Str("original_sql", sqlBefore).
				Str("sql", sqlstr).
				Any("sql_params", prep).
				Build()
		}
	}

	result, execErr := db.db.NamedExecContext(ctx, sqlstr, prep)

	for _, hook := range db.lstr {
		hook.PostExec(nil, sqlBefore, sqlstr, prep)
	}

	if execErr == nil {
		return result, nil
	}
	return nil, exerr.Wrap(execErr, "Failed to [exec] sql statement").
		Str("original_sql", sqlBefore).
		Str("sql", sqlstr).
		Any("sql_params", prep).
		Build()
}

// Query runs a row-returning statement with named parameters via
// sqlx.NamedQueryContext.
//
// Listener handling mirrors Exec: PreQuery hooks may modify the SQL and
// parameters, and the first hook error aborts the call. PostQuery hooks always
// run after the query, before its error is inspected. The returned rows are
// not closed here.
func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
	sqlBefore := sqlstr

	for _, hook := range db.lstr {
		if hookErr := hook.PreQuery(ctx, nil, &sqlstr, &prep); hookErr != nil {
			return nil, exerr.Wrap(hookErr, "failed to call SQL pre-query listener").
				Str("original_sql", sqlBefore).
				Str("sql", sqlstr).
				Any("sql_params", prep).
				Build()
		}
	}

	resultRows, queryErr := sqlx.NamedQueryContext(ctx, db.db, sqlstr, prep)

	for _, hook := range db.lstr {
		hook.PostQuery(nil, sqlBefore, sqlstr, prep)
	}

	if queryErr == nil {
		return resultRows, nil
	}
	return nil, exerr.Wrap(queryErr, "Failed to [query] sql statement").
		Str("original_sql", sqlBefore).
		Str("sql", sqlstr).
		Any("sql_params", prep).
		Build()
}

// Ping verifies the database connection with PingContext.
//
// A PrePing listener error is returned unchanged and stops the ping. Every
// listener's PostPing receives the raw ping result. A ping failure is then
// returned wrapped.
func (db *database) Ping(ctx context.Context) error {
	for _, hook := range db.lstr {
		if hookErr := hook.PrePing(ctx); hookErr != nil {
			return hookErr
		}
	}

	pingErr := db.db.PingContext(ctx)

	for _, hook := range db.lstr {
		hook.PostPing(pingErr)
	}

	if pingErr == nil {
		return nil
	}
	return exerr.Wrap(pingErr, "Failed to [ping] sql database").Build()
}

// BeginTransaction starts a new transaction with the given isolation level.
//
// Each transaction gets an id from a uint16 counter; the counter is
// incremented under db.lock and wraps around on overflow. A PreTxBegin
// listener error is returned unchanged and no transaction is started.
// Otherwise every listener's PostTxBegin receives the outcome of BeginTxx,
// including a failure, so each PreTxBegin is paired with a PostTxBegin. This
// matches how Exec, Query and Ping notify their post-listeners.
func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) {
	db.lock.Lock()
	txid := db.txctr
	db.txctr += 1 // with overflow !
	db.lock.Unlock()

	for _, v := range db.lstr {
		err := v.PreTxBegin(ctx, txid)
		if err != nil {
			return nil, err
		}
	}

	xtx, err := db.db.BeginTxx(ctx, &sql.TxOptions{Isolation: iso})

	// Notify listeners before inspecting err. Previously an early return here
	// skipped PostTxBegin entirely, so listeners never saw a failed begin.
	for _, v := range db.lstr {
		v.PostTxBegin(txid, err)
	}

	if err != nil {
		return nil, exerr.Wrap(err, "Failed to start sql transaction").Build()
	}

	return NewTransaction(xtx, txid, db), nil
}

// Exit closes the underlying sqlx handle and returns the result of Close.
func (db *database) Exit() error {
	handle := db.db
	closeErr := handle.Close()
	return closeErr
}

// ListConverter returns the registered converters. The returned slice is the
// internal slice itself, not a copy.
func (db *database) ListConverter() []DBTypeConverter {
	registered := db.conv
	return registered
}

// RegisterConverter adds conv to the converter list. Any converter already
// registered with the same ModelTypeString is removed first, so conv
// replaces it and is placed at the end of the list.
func (db *database) RegisterConverter(conv DBTypeConverter) {
	keepExisting := func(existing DBTypeConverter) bool {
		return existing.ModelTypeString() != conv.ModelTypeString()
	}
	remaining := langext.ArrFilter(db.conv, keepExisting)
	db.conv = append(remaining, conv)
}

// registerDefaultConverter registers the built-in converter set in a fixed
// order: bool, time, the RFC time variants, JSON, and the exerr enum types.
func (db *database) registerDefaultConverter() {
	builtins := []DBTypeConverter{
		ConverterBoolToBit,

		ConverterTimeToUnixMillis,

		ConverterRFCUnixMilliTimeToUnixMillis,
		ConverterRFCUnixNanoTimeToUnixNanos,
		ConverterRFCUnixTimeToUnixSeconds,
		ConverterRFC339TimeToString,
		ConverterRFC339NanoTimeToString,
		ConverterRFCDateToString,
		ConverterRFCTimeToString,
		ConverterRFCSecondsF64ToString,

		ConverterJsonObjToString,
		ConverterJsonArrToString,

		ConverterExErrCategoryToString,
		ConverterExErrSeverityToString,
		ConverterExErrTypeToString,
	}

	for _, c := range builtins {
		db.RegisterConverter(c)
	}
}