package wmo

import (
	"context"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	pag "gogs.mikescher.com/BlackForestBytes/goext/pagination"
)

// Paginate runs an aggregation over the collection and returns one page of
// decoded entities together with pagination metadata.
//
// Parameters:
//   - filter: optional (may be nil). Supplies the filter pipeline stages
//     (FilterQuery) and the sort specification (Sort). When nil, no filter
//     stages and no $sort stage are added.
//   - page: 1-based page number. Values below 1 are treated as page 1.
//     When limit is nil the page is forced to 1.
//   - limit: optional page size. When nil, no $skip/$limit stages are added
//     and all matching documents are returned as a single page.
//
// Returns the decoded entities of the requested page, a pag.Pagination
// describing that page, and a non-nil error if an aggregation, the cursor
// iteration, or decoding fails.
//
// NOTE(review): when limit is nil, TotalItems is len(entities), which is
// measured after c.extraModPipeline has run; when limit is set, TotalItems
// comes from a $count over the filter stages only. These can differ if an
// extra-mod pipeline adds or removes documents — confirm this is intended.
func (c *Coll[TData]) Paginate(ctx context.Context, filter pag.MongoFilter, page int, limit *int) ([]TData, pag.Pagination, error) {
	type totalCountResult struct {
		Count int `bson:"count"`
	}

	// Pages are 1-based: page 0 would otherwise yield a negative $skip
	// (limit * -1), which MongoDB rejects.
	if page < 1 {
		page = 1
	}

	pipelineSort := mongo.Pipeline{}
	pipelineFilter := mongo.Pipeline{}
	sort := bson.D{}

	if filter != nil {
		pipelineFilter = filter.FilterQuery(ctx)
		sort = filter.Sort(ctx)
	}

	if len(sort) != 0 {
		pipelineSort = append(pipelineSort, bson.D{{Key: "$sort", Value: sort}})
	}

	pipelinePaginate := mongo.Pipeline{}
	if limit != nil {
		pipelinePaginate = append(pipelinePaginate, bson.D{{Key: "$skip", Value: *limit * (page - 1)}})
		pipelinePaginate = append(pipelinePaginate, bson.D{{Key: "$limit", Value: *limit}})
	} else {
		// Without a limit everything is returned as one single page.
		page = 1
	}

	pipelineCount := mongo.Pipeline{bson.D{{Key: "$count", Value: "count"}}}

	extrModPipelineResolved := mongo.Pipeline{}
	for _, ppl := range c.extraModPipeline {
		extrModPipelineResolved = langext.ArrConcat(extrModPipelineResolved, ppl(ctx))
	}

	// pipelineSort appears twice: before $skip/$limit so the page boundaries
	// are well-defined, and again after the extra-mod stages (presumably so the
	// final order still matches the requested sort if those stages reorder).
	pipelineList := langext.ArrConcat(pipelineFilter, pipelineSort, pipelinePaginate, extrModPipelineResolved, pipelineSort)
	pipelineTotalCount := langext.ArrConcat(pipelineFilter, pipelineCount)

	cursorList, err := c.coll.Aggregate(ctx, pipelineList)
	if err != nil {
		return nil, pag.Pagination{}, exerr.Wrap(err, "mongo-aggregation failed").Any("pipeline", pipelineList).Str("collection", c.Name()).Build()
	}
	// Release the cursor (and its implicit session) on every return path.
	// decodeAll may already have closed it; Close on a closed cursor is a no-op,
	// and a cleanup-time close error is not actionable here, so it is dropped.
	defer func() { _ = cursorList.Close(ctx) }()

	entities, err := c.decodeAll(ctx, cursorList)
	if err != nil {
		return nil, pag.Pagination{}, exerr.Wrap(err, "failed to all-decode entities").Str("collection", c.Name()).Build()
	}

	var tcRes totalCountResult

	if limit == nil {
		// optimization, limit==nil, so we query all entities anyway, just use the array length
		tcRes.Count = len(entities)
	} else {
		cursorTotalCount, err := c.coll.Aggregate(ctx, pipelineTotalCount)
		if err != nil {
			return nil, pag.Pagination{}, exerr.Wrap(err, "mongo-aggregation failed").Any("pipeline", pipelineTotalCount).Str("collection", c.Name()).Build()
		}
		// Only Next is called once below, so the cursor is never exhausted by
		// iteration; close it explicitly. Close errors here are not actionable.
		defer func() { _ = cursorTotalCount.Close(ctx) }()

		if cursorTotalCount.Next(ctx) {
			err = cursorTotalCount.Decode(&tcRes)
			if err != nil {
				return nil, pag.Pagination{}, exerr.Wrap(err, "failed to decode mongo-aggregation $count result").Any("pipeline", pipelineTotalCount).Str("collection", c.Name()).Build()
			}
		} else if err = cursorTotalCount.Err(); err != nil {
			// Next returns false both when there is no document and when
			// iteration failed; only the former means "zero matches".
			return nil, pag.Pagination{}, exerr.Wrap(err, "failed to read mongo-aggregation $count result").Any("pipeline", pipelineTotalCount).Str("collection", c.Name()).Build()
		} else {
			tcRes.Count = 0 // $count emits no document when nothing matches
		}
	}

	// The page size reported back: the requested limit, or (without a limit)
	// the total count, since everything fits on the single page.
	effectiveLimit := langext.Coalesce(limit, tcRes.Count)

	paginationObj := pag.Pagination{
		Page:             page,
		Limit:            effectiveLimit,
		TotalPages:       pag.CalcPaginationTotalPages(tcRes.Count, effectiveLimit),
		TotalItems:       tcRes.Count,
		CurrentPageCount: len(entities),
	}

	return entities, paginationObj, nil
}