package wmo

import (
	"context"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	ct "gogs.mikescher.com/BlackForestBytes/goext/cursortoken"
)

// List runs the filter's aggregation pipeline (plus the pagination stages built by
// CreatePagination) against the collection and returns one page of results together
// with the cursor token for the following page.
//
// If filter is nil, every document is listed, sorted ascending by _id. Otherwise
// filter.FilterQuery() supplies the leading pipeline stages and filter.Pagination()
// the primary and secondary sort fields/directions. The secondary sort is dropped
// when it names the same field as the primary one.
//
// If pageSize is nil, all matching documents are returned and the token is ct.End().
// Otherwise at most *pageSize documents are returned. The token is ct.End() when no
// further document exists; otherwise it is built from the last returned entry.
//
// An inTok in CTMEnd mode short-circuits: an empty slice and ct.End() are returned
// without querying the database.
func (c *Coll[TData]) List(ctx context.Context, filter ct.Filter, pageSize *int, inTok ct.CursorToken) ([]TData, ct.CursorToken, error) {
	if inTok.Mode == ct.CTMEnd {
		return make([]TData, 0), ct.End(), nil
	}

	pipeline := mongo.Pipeline{}
	pf1 := "_id"
	pd1 := ct.SortASC
	pf2 := "_id"
	pd2 := ct.SortASC

	if filter != nil {
		pipeline = filter.FilterQuery()
		pf1, pd1, pf2, pd2 = filter.Pagination()
	}

	sortPrimary := pf1
	sortDirPrimary := pd1
	sortSecondary := &pf2
	sortDirSecondary := &pd2

	// A secondary sort on the same field as the primary one is redundant.
	if pf1 == pf2 {
		sortSecondary = nil
		sortDirSecondary = nil
	}

	paginationPipeline, err := CreatePagination(c, inTok, sortPrimary, sortDirPrimary, sortSecondary, sortDirSecondary, pageSize)
	if err != nil {
		return nil, ct.CursorToken{}, err
	}

	pipeline = append(pipeline, paginationPipeline...)

	cursor, err := c.coll.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, ct.CursorToken{}, err
	}
	// Release the server-side cursor on every exit path. mongo.Cursor.Close is
	// idempotent, so this is harmless if decodeAll already closed it. The close error
	// is deliberately ignored because the result has already been determined.
	defer func() { _ = cursor.Close(ctx) }()

	// Fast branch: without a page size everything is returned, so there is never a next page.
	if pageSize == nil {
		entries, err := c.decodeAll(ctx, cursor)
		if err != nil {
			return nil, ct.CursorToken{}, err
		}
		return entries, ct.End(), nil
	}

	entities := make([]TData, 0, cursor.RemainingBatchLength())
	for len(entities) != *pageSize && cursor.Next(ctx) {
		var entry TData
		entry, err = c.decodeSingle(ctx, cursor)
		if err != nil {
			return nil, ct.CursorToken{}, err
		}
		entities = append(entities, entry)
	}
	// Next returns false both when the results are exhausted and when iteration
	// failed; only the latter is an error.
	if err = cursor.Err(); err != nil {
		return nil, ct.CursorToken{}, err
	}

	// An empty or short page cannot be followed by another one. The empty check also
	// guards the entities[len-1] access below (e.g. for *pageSize == 0).
	if len(entities) == 0 || len(entities) < *pageSize {
		return entities, ct.End(), nil
	}

	// Peek for one more document to decide whether a next page exists.
	if !cursor.TryNext(ctx) {
		if err = cursor.Err(); err != nil {
			return nil, ct.CursorToken{}, err
		}
		return entities, ct.End(), nil
	}

	last := entities[len(entities)-1]

	c.EnsureInitializedReflection(last)

	nextToken, err := c.createToken(sortPrimary, sortDirPrimary, sortSecondary, sortDirSecondary, last, pageSize)
	if err != nil {
		return nil, ct.CursorToken{}, err
	}

	return entities, nextToken, nil
}

// countRes is the shape of the single result document emitted by the `$count`
// stage that Count appends to its pipeline; that stage writes its total under
// the key "c".
type countRes struct {
	Count int64 `bson:"c"` // total produced by the $count stage
}

// Count returns the number of documents that pass the filter's aggregation
// pipeline. It appends a `$count` stage that stores the total under "c".
//
// A nil filter counts every document in the collection, matching List's handling
// of a nil filter.
func (c *Coll[TData]) Count(ctx context.Context, filter ct.Filter) (int64, error) {
	pipeline := mongo.Pipeline{}
	if filter != nil {
		pipeline = filter.FilterQuery()
	}

	pipeline = append(pipeline, bson.D{{Key: "$count", Value: "c"}})

	cursor, err := c.coll.Aggregate(ctx, pipeline)
	if err != nil {
		return 0, err
	}
	// Best-effort release of the server-side cursor; the count is already determined.
	defer func() { _ = cursor.Close(ctx) }()

	if !cursor.Next(ctx) {
		// $count emits no document when it receives no input, so an exhausted cursor
		// means zero. That holds only if iteration did not stop because of an error.
		if err = cursor.Err(); err != nil {
			return 0, err
		}
		return 0, nil
	}

	var v countRes
	if err = cursor.Decode(&v); err != nil {
		return 0, err
	}
	return v.Count, nil
}