package mongoext

import (
	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/bsoncodec"
	"go.mongodb.org/mongo-driver/bson/bsontype"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	"gogs.mikescher.com/BlackForestBytes/goext/rfctime"
	"reflect"
)

// CreateGoExtBsonRegistry returns a BSON codec registry that can decode the
// goext rfctime types in addition to everything the mongo-driver handles by
// default.
//
// The registry is assembled in this order:
//  1. each rfctime type is registered as the decoder for both its value type
//     and its pointer type;
//  2. the driver's default value encoders and decoders are added;
//  3. the bson primitive codecs are added;
//  4. BSON embedded documents decoded into an empty interface are mapped to
//     primitive.M.
func CreateGoExtBsonRegistry() *bsoncodec.Registry {
	builder := bsoncodec.NewRegistryBuilder()

	// decodeBoth registers dec as the decoder for the type of dec itself and
	// for the type of ptr, which callers pass as a pointer to that same type.
	decodeBoth := func(dec bsoncodec.ValueDecoder, ptr any) {
		builder.RegisterTypeDecoder(reflect.TypeOf(dec), dec)
		builder.RegisterTypeDecoder(reflect.TypeOf(ptr), dec)
	}

	decodeBoth(rfctime.RFC3339Time{}, &rfctime.RFC3339Time{})
	decodeBoth(rfctime.RFC3339NanoTime{}, &rfctime.RFC3339NanoTime{})
	decodeBoth(rfctime.UnixTime{}, &rfctime.UnixTime{})
	decodeBoth(rfctime.UnixMilliTime{}, &rfctime.UnixMilliTime{})
	decodeBoth(rfctime.UnixNanoTime{}, &rfctime.UnixNanoTime{})
	decodeBoth(rfctime.Date{}, &rfctime.Date{})
	decodeBoth(rfctime.SecondsF64(0), langext.Ptr(rfctime.SecondsF64(0)))

	bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(builder)
	bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(builder)
	bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(builder)

	// Without this entry, unmarshalling an embedded document into `any` yields
	// []primitive.E (primitive.D). That then JSON-marshals as a list of
	// key/value objects instead of a plain JSON object. It must come after
	// RegisterDefaultDecoders so that it overrides the default mapping.
	// NOTE(review): only bsontype.EmbeddedDocument is remapped here; confirm
	// whether top-level documents (bsontype.Type(0)) decoded into `any` need
	// the same mapping.
	builder.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(primitive.M{}))

	return builder.Build()
}