package wmo

import (
	"context"
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	ct "gogs.mikescher.com/BlackForestBytes/goext/cursortoken"
	"gogs.mikescher.com/BlackForestBytes/goext/exerr"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
)

// List returns one page of entities from the collection using cursor-token
// pagination.
//
// filter supplies the base aggregation pipeline (FilterQuery) as well as the
// primary/secondary sort fields and directions (Pagination). A nil filter
// lists the whole collection sorted ascending by "_id". pageSize limits the
// number of returned entities; a nil pageSize returns all remaining entities
// in one call. inTok is the position to continue from.
//
// The returned token is ct.End() when no further entities are available,
// otherwise it is a token that continues after the last returned entity.
func (c *Coll[TData]) List(ctx context.Context, filter ct.Filter, pageSize *int, inTok ct.CursorToken) ([]TData, ct.CursorToken, error) {
	if inTok.Mode == ct.CTMEnd {
		return make([]TData, 0), ct.End(), nil
	}

	pipeline := mongo.Pipeline{}
	pf1 := "_id"
	pd1 := ct.SortASC
	pf2 := "_id"
	pd2 := ct.SortASC

	if filter != nil {
		pipeline = filter.FilterQuery()
		pf1, pd1, pf2, pd2 = filter.Pagination()
	}

	sortPrimary := pf1
	sortDirPrimary := pd1
	sortSecondary := &pf2
	sortDirSecondary := &pd2

	// A secondary sort on the same field as the primary one is meaningless.
	if pf1 == pf2 {
		sortSecondary = nil
		sortDirSecondary = nil
	}

	paginationPipeline, err := createPaginationPipeline(c, inTok, sortPrimary, sortDirPrimary, sortSecondary, sortDirSecondary, pageSize)
	if err != nil {
		return nil, ct.CursorToken{}, exerr.
			Wrap(err, "failed to create pagination").
			WithType(exerr.TypeCursorTokenDecode).
			Str("collection", c.Name()).
			Any("inTok", inTok).
			Any("sortPrimary", sortPrimary).
			Any("sortDirPrimary", sortDirPrimary).
			Any("sortSecondary", sortSecondary).
			Any("sortDirSecondary", sortDirSecondary).
			Any("pageSize", pageSize).
			Build()
	}

	pipeline = append(pipeline, paginationPipeline...)

	for _, ppl := range c.extraModPipeline {
		pipeline = langext.ArrConcat(pipeline, ppl(ctx))
	}

	cursor, err := c.coll.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, ct.CursorToken{}, exerr.Wrap(err, "mongo-aggregation failed").Any("pipeline", pipeline).Str("collection", c.Name()).Build()
	}

	// Always release the cursor, also on early returns. Close is idempotent,
	// so this is safe even if the cursor was already closed while decoding.
	// A close error cannot be acted upon here and is deliberately dropped.
	defer func() { _ = cursor.Close(ctx) }()

	// fast branch: no page size means "everything that is left"
	if pageSize == nil {
		entries, err := c.decodeAll(ctx, cursor)
		if err != nil {
			return nil, ct.CursorToken{}, exerr.Wrap(err, "failed to all-decode entities").Build()
		}
		return entries, ct.End(), nil
	}

	limit := *pageSize

	entities := make([]TData, 0, cursor.RemainingBatchLength())
	for len(entities) != limit && cursor.Next(ctx) {
		entry, err := c.decodeSingle(ctx, cursor)
		if err != nil {
			return nil, ct.CursorToken{}, exerr.Wrap(err, "failed to decode entity").Build()
		}
		entities = append(entities, entry)
	}

	// The pagination pipeline limits to pageSize+1 documents, so one more
	// document after a full page means that another page exists.
	hasMore := len(entities) == limit && cursor.TryNext(ctx)

	// Next/TryNext return false both on exhaustion and on failure - do not
	// report a failed iteration as a (truncated) last page.
	if err := cursor.Err(); err != nil {
		return nil, ct.CursorToken{}, exerr.Wrap(err, "failed to iterate mongo-cursor").Str("collection", c.Name()).Build()
	}

	// Without a last entity (e.g. pageSize == 0) no continuation token can be
	// built, so an empty result always ends the listing.
	if !hasMore || len(entities) == 0 {
		return entities, ct.End(), nil
	}

	last := entities[len(entities)-1]

	c.EnsureInitializedReflection(last)

	nextToken, err := c.createToken(sortPrimary, sortDirPrimary, sortSecondary, sortDirSecondary, last, pageSize)
	if err != nil {
		return nil, ct.CursorToken{}, exerr.Wrap(err, "failed to create (out)-token").Build()
	}

	return entities, nextToken, nil
}

// Count returns the number of documents produced by the filter's pipeline.
// A nil filter counts every document in the collection, mirroring how List
// treats a nil filter.
func (c *Coll[TData]) Count(ctx context.Context, filter ct.RawFilter) (int64, error) {
	type countRes struct {
		Count int64 `bson:"c"`
	}

	pipeline := mongo.Pipeline{}
	if filter != nil {
		pipeline = filter.FilterQuery()
	}

	pipeline = append(pipeline, bson.D{{Key: "$count", Value: "c"}})

	cursor, err := c.coll.Aggregate(ctx, pipeline)
	if err != nil {
		return 0, exerr.Wrap(err, "mongo-aggregation failed").Any("pipeline", pipeline).Str("collection", c.Name()).Build()
	}

	// Always release the cursor; a close error cannot be acted upon here and
	// is deliberately dropped.
	defer func() { _ = cursor.Close(ctx) }()

	if !cursor.Next(ctx) {
		// Next returns false on exhaustion as well as on failure.
		if err := cursor.Err(); err != nil {
			return 0, exerr.Wrap(err, "failed to iterate mongo-cursor").Str("collection", c.Name()).Build()
		}

		// $count emits no document at all when its input is empty.
		return 0, nil
	}

	v := countRes{}
	if err := cursor.Decode(&v); err != nil {
		return 0, exerr.Wrap(err, "failed to decode entity").Build()
	}

	return v.Count, nil
}

// ListWithCount combines Count and List: it first determines the number of
// entities matching the filter and then fetches one page of them.
//
// If either step fails, its error is returned unchanged together with a nil
// slice, an empty token and a count of 0.
func (c *Coll[TData]) ListWithCount(ctx context.Context, filter ct.Filter, pageSize *int, inTok ct.CursorToken) ([]TData, ct.CursorToken, int64, error) {
	// NOTE: Possible optimization: Cache count in CursorToken, then fetch count only on first page.
	total, countErr := c.Count(ctx, filter)
	if countErr != nil {
		return nil, ct.CursorToken{}, 0, countErr
	}

	entries, nextTok, listErr := c.List(ctx, filter, pageSize, inTok)
	if listErr != nil {
		return nil, ct.CursorToken{}, 0, listErr
	}

	return entries, nextTok, total, nil
}

// createPaginationPipeline builds the aggregation stages that implement
// cursor-token pagination:
//
//   - a $match stage that skips everything up to and including the position
//     stored in token (only for ct.CTMNormal; ct.CTMStart adds no condition
//     and ct.CTMEnd adds a condition that never matches),
//   - a $sort stage on fieldPrimary and - if it is set and differs from the
//     primary field - on fieldSecondary as tie-breaker,
//   - a $limit stage of pageSize+1 if pageSize is not nil. The extra document
//     lets the caller detect whether another page exists.
//
// An error is returned if a token value cannot be converted to its mongo
// type or if token.Mode is unknown.
func createPaginationPipeline[TData any](coll *Coll[TData], token ct.CursorToken, fieldPrimary string, sortPrimary ct.SortDirection, fieldSecondary *string, sortSecondary *ct.SortDirection, pageSize *int) ([]bson.D, error) {

	cond := bson.A{}
	sort := bson.D{}

	valuePrimary, err := coll.getTokenValueAsMongoType(token.ValuePrimary, fieldPrimary)
	if err != nil {
		return nil, exerr.Wrap(err, "failed to get (primary) token-value as mongo-type").Build()
	}

	switch sortPrimary {
	case ct.SortASC:
		// We sort ASC on <field> - so we want all entries newer ($gt) than the $primary
		cond = append(cond, bson.M{fieldPrimary: bson.M{"$gt": valuePrimary}})
		sort = append(sort, bson.E{Key: fieldPrimary, Value: +1})
	case ct.SortDESC:
		// We sort DESC on <field> - so we want all entries older ($lt) than the $primary
		cond = append(cond, bson.M{fieldPrimary: bson.M{"$lt": valuePrimary}})
		sort = append(sort, bson.E{Key: fieldPrimary, Value: -1})
	}

	if fieldSecondary != nil && sortSecondary != nil && *fieldSecondary != fieldPrimary {

		valueSecondary, err := coll.getTokenValueAsMongoType(token.ValueSecondary, *fieldSecondary)
		if err != nil {
			return nil, exerr.Wrap(err, "failed to get (secondary) token-value as mongo-type").Build()
		}

		// The tie-break condition below compares the secondary field, so the
		// result must also be ordered by the secondary field - otherwise
		// entries sharing the same primary value could be skipped or repeated
		// on the next page.
		switch *sortSecondary {
		case ct.SortASC:
			// the conflict-resolution condition, for entries with the _same_ <field> as the $primary we take the ones with a greater $secondary (= newer)
			cond = append(cond, bson.M{"$and": bson.A{
				bson.M{fieldPrimary: valuePrimary},
				bson.M{*fieldSecondary: bson.M{"$gt": valueSecondary}},
			}})

			sort = append(sort, bson.E{Key: *fieldSecondary, Value: +1})
		case ct.SortDESC:
			// the conflict-resolution condition, for entries with the _same_ <field> as the $primary we take the ones with a smaller $secondary (= older)
			cond = append(cond, bson.M{"$and": bson.A{
				bson.M{fieldPrimary: valuePrimary},
				bson.M{*fieldSecondary: bson.M{"$lt": valueSecondary}},
			}})

			sort = append(sort, bson.E{Key: *fieldSecondary, Value: -1})
		}
	}

	pipeline := make([]bson.D, 0, 3)

	switch token.Mode {
	case ct.CTMStart:
		// first page: no gt/lt condition
	case ct.CTMNormal:
		pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.M{"$or": cond}}})
	case ct.CTMEnd:
		// constant-false condition: nothing is left to return
		pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.M{"$expr": bson.M{"$eq": bson.A{"1", "0"}}}}})
	default:
		return nil, exerr.New(exerr.TypeInternal, "unknown ct mode: "+string(token.Mode)).Any("token.Mode", token.Mode).Build()
	}

	pipeline = append(pipeline, bson.D{{Key: "$sort", Value: sort}})

	if pageSize != nil {
		// one more than requested, so the caller can tell if a next page exists
		pipeline = append(pipeline, bson.D{{Key: "$limit", Value: int64(*pageSize + 1)}})
	}

	return pipeline, nil
}