diff --git a/mongoext/registry.go b/mongoext/registry.go index 982102f..36038d7 100644 --- a/mongoext/registry.go +++ b/mongoext/registry.go @@ -3,29 +3,20 @@ package mongoext import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" - "go.mongodb.org/mongo-driver/bson/bsonrw" "gogs.mikescher.com/BlackForestBytes/goext/rfctime" "reflect" ) func CreateGoExtBsonRegistry() *bsoncodec.Registry { - var primitiveCodecs bson.PrimitiveCodecs - rb := bsoncodec.NewRegistryBuilder() + + rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) + rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339NanoTime{}) + bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb) bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb) - rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{}) - - primitiveCodecs.RegisterPrimitiveCodecs(rb) + bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb) return rb.Build() } - -func encodeRFC3339Time(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) { - -} - -func decodeRFC3339Time(ec bsoncodec.EncodeContext, vr bsonrw.ValueReader, val reflect.Value) { - -} diff --git a/rfctime/rfc3339.go b/rfctime/rfc3339.go index d89419b..33427dd 100644 --- a/rfctime/rfc3339.go +++ b/rfctime/rfc3339.go @@ -70,7 +70,10 @@ func (t *RFC3339Time) UnmarshalText(data []byte) error { func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { if bt == bsontype.Null { - //t = nil + // we can't set nil in UnmarshalBSONValue (so we use default(struct)) + // https://stackoverflow.com/questions/75167597 + // https://jira.mongodb.org/browse/GODRIVER-2252 + *t = RFC3339Time{} return nil } if bt != bsontype.DateTime { diff --git a/rfctime/rfc3339Nano.go b/rfctime/rfc3339Nano.go index b6bef36..8868dc8 100644 --- a/rfctime/rfc3339Nano.go +++ b/rfctime/rfc3339Nano.go @@ -5,7 +5,10 @@ import ( "errors" "fmt" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" + "reflect" "time" ) @@ -67,7 +70,10 @@ func (t *RFC3339NanoTime) UnmarshalText(data []byte) error { func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error { if bt == bsontype.Null { - //t = nil + // we can't set nil in UnmarshalBSONValue (so we use default(struct)) + // https://stackoverflow.com/questions/75167597 + // https://jira.mongodb.org/browse/GODRIVER-2252 + *t = RFC3339NanoTime{} return nil } if bt != bsontype.DateTime { @@ -86,6 +92,32 @@ func (t RFC3339NanoTime) MarshalBSONValue() (bsontype.Type, []byte, error) { return bson.MarshalValue(time.Time(t)) } +func (t RFC3339NanoTime) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { + if val.Kind() == reflect.Ptr && val.IsNil() { + if !val.CanSet() { + return errors.New("ValueUnmarshalerDecodeValue") + } + val.Set(reflect.New(val.Type().Elem())) + } + + tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr) + if err != nil { + return err + } + + if val.Kind() == reflect.Ptr && len(src) == 0 { + val.Set(reflect.Zero(val.Type())) + return nil + } + + err = t.UnmarshalBSONValue(tp, src) + if err != nil { + return err + } + + return nil +} + func (t RFC3339NanoTime) Serialize() string { return t.Time().Format(t.FormatStr()) }