package sq

import (
	"context"
	"fmt"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	pag "gogs.mikescher.com/BlackForestBytes/goext/pagination"
)

// Paginate loads one page of rows from `table` that match `filter` and decodes them into TData.
//
// The filter supplies the WHERE condition, an optional JOIN clause plus the joined table names
// (via filter.SQL), and the ORDER BY fields (via filter.Sort). A nil filter is replaced by
// NewEmptyPaginateFilter().
//
// `page` is 1-based. If `limit` is nil, no LIMIT/OFFSET clause is emitted and every matching row
// is returned. Otherwise `*limit + 1` rows are requested at offset `*limit*(page-1)`; the extra
// look-ahead row only signals that more data exists and is trimmed before returning.
//
// A second COUNT(*) query is issued only when the fetched rows may not be the complete result
// (page > 1, or the first page overflowed the limit); otherwise the totals are derived from the
// fetched rows directly.
//
// On error the returned slice is nil and the Pagination is the zero value.
func Paginate[TData any](ctx context.Context, q Queryable, table string, filter PaginateFilter, scanMode StructScanMode, scanSec StructScanSafety, page int, limit *int) ([]TData, pag.Pagination, error) {
	if filter == nil {
		filter = NewEmptyPaginateFilter()
	}

	prepParams := PP{}

	sortOrder := filter.Sort()
	sortCond := ""
	if len(sortOrder) > 0 {
		sortCond = "ORDER BY "
		for i, v := range sortOrder {
			if i > 0 {
				sortCond += ", "
			}
			sortCond += v.Field + " " + string(v.Direction)
		}
	}

	pageCond := ""
	if limit != nil {
		// Request one row more than the page size so we can tell whether another page exists.
		pageCond += fmt.Sprintf("LIMIT :%s OFFSET :%s", prepParams.Add(*limit+1), prepParams.Add(*limit*(page-1)))
	}

	filterCond, joinCond, joinTables := filter.SQL(prepParams)

	selectCond := table + ".*"
	for _, v := range joinTables {
		selectCond += ", " + v + ".*"
	}

	sqlQueryData := "SELECT " + selectCond + " FROM " + table + " " + joinCond + " WHERE ( " + filterCond + " ) " + sortCond + " " + pageCond
	sqlQueryCount := "SELECT " + "COUNT(*)" + " FROM " + table + " " + joinCond + " WHERE ( " + filterCond + " ) "

	rows, err := q.Query(ctx, sqlQueryData, prepParams)
	if err != nil {
		return nil, pag.Pagination{}, exerr.Wrap(err, "failed to list paginated entries from DB").Str("table", table).Any("filter", filter).Int("page", page).Any("limit", limit).Build()
	}

	entities, err := ScanAll[TData](ctx, q, rows, scanMode, scanSec, true)
	if err != nil {
		return nil, pag.Pagination{}, exerr.Wrap(err, "failed to decode paginated entries from DB").Str("table", table).Int("page", page).Any("limit", limit).Str("scanMode", string(scanMode)).Str("scanSec", string(scanSec)).Build()
	}

	if page == 1 && (limit == nil || len(entities) <= *limit) {
		// Everything fit on the first page: the fetched rows are the complete result set,
		// so no COUNT(*) round-trip is needed.
		return entities, pag.Pagination{
			Page:             1,
			Limit:            langext.Coalesce(limit, len(entities)),
			TotalPages:       1,
			TotalItems:       len(entities),
			CurrentPageCount: len(entities),
		}, nil
	}

	countRows, err := q.Query(ctx, sqlQueryCount, prepParams)
	if err != nil {
		return nil, pag.Pagination{}, exerr.Wrap(err, "failed to query total-count of paginated entries from DB").Str("table", table).Build()
	}
	// Only the first row is read, so the result set is never drained; close it explicitly
	// to release the underlying connection.
	defer func() { _ = countRows.Close() }()

	if !countRows.Next() {
		// Next() also returns false on an iteration error; report that instead of "no rows".
		if iterErr := countRows.Err(); iterErr != nil {
			return nil, pag.Pagination{}, exerr.Wrap(iterErr, "failed to iterate total-count of paginated entries from DB").Str("table", table).Build()
		}
		return nil, pag.Pagination{}, exerr.New(exerr.TypeSQLDecode, "SQL COUNT(*) query returned no rows").Str("table", table).Any("filter", filter).Build()
	}

	var countRes int
	err = countRows.Scan(&countRes)
	if err != nil {
		return nil, pag.Pagination{}, exerr.Wrap(err, "failed to decode total-count of paginated entries from DB").Str("table", table).Build()
	}

	// limit can be nil here (page > 1 without a limit); only trim the look-ahead row when a
	// LIMIT was actually applied.
	if limit != nil && len(entities) > *limit {
		entities = entities[:*limit]
	}

	return entities, pag.Pagination{
		Page:             page,
		Limit:            langext.Coalesce(limit, countRes),
		TotalPages:       pag.CalcPaginationTotalPages(countRes, langext.Coalesce(limit, countRes)),
		TotalItems:       countRes,
		CurrentPageCount: len(entities),
	}, nil
}

// Count returns the number of rows in `table` (joined as described by the filter's JOIN clause)
// that match the filter's WHERE condition. A nil filter is replaced by NewEmptyPaginateFilter().
//
// It returns an error if the query fails, iteration fails, the COUNT(*) query yields no row,
// or the count cannot be scanned into an int.
func Count(ctx context.Context, q Queryable, table string, filter PaginateFilter) (int, error) {
	if filter == nil {
		filter = NewEmptyPaginateFilter()
	}

	prepParams := PP{}

	filterCond, joinCond, _ := filter.SQL(prepParams)

	sqlQueryCount := "SELECT " + "COUNT(*)" + " FROM " + table + " " + joinCond + " WHERE ( " + filterCond + " )"

	countRows, err := q.Query(ctx, sqlQueryCount, prepParams)
	if err != nil {
		return 0, exerr.Wrap(err, "failed to query count of entries from DB").Str("table", table).Build()
	}
	// Only the first row is read, so the result set is never drained; close it explicitly
	// to release the underlying connection.
	defer func() { _ = countRows.Close() }()

	if !countRows.Next() {
		// Next() also returns false on an iteration error; report that instead of "no rows".
		if iterErr := countRows.Err(); iterErr != nil {
			return 0, exerr.Wrap(iterErr, "failed to iterate count of entries from DB").Str("table", table).Build()
		}
		return 0, exerr.New(exerr.TypeSQLDecode, "SQL COUNT(*) query returned no rows").Str("table", table).Any("filter", filter).Build()
	}

	var countRes int
	if err := countRows.Scan(&countRes); err != nil {
		return 0, exerr.Wrap(err, "failed to decode count of entries from DB").Str("table", table).Build()
	}

	return countRes, nil
}