goext/mongo/bson/bsoncodec/registry_test.go

453 lines
14 KiB
Go
Raw Normal View History

2023-06-18 15:50:55 +02:00
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
// TestRegistry covers RegistryBuilder registration (hook, type, and kind
// encoders), Registry encoder/decoder lookup, and type-map lookup.
func TestRegistry(t *testing.T) {
	t.Run("Register", func(t *testing.T) {
		// Each codec carries a distinct num so that cmp.Equal, which compares the
		// pointed-to values rather than pointer identity, can tell the codecs
		// apart. With identical zero values, a stale registration (e.g. fc1 kept
		// instead of fc3) would still compare equal and the test could not fail.
		fc1, fc2, fc3, fc4 := &fakeCodec{num: 1}, &fakeCodec{num: 2}, &fakeCodec{num: 3}, &fakeCodec{num: 4}
		t.Run("interface", func(t *testing.T) {
			var t1f *testInterface1
			var t2f *testInterface2
			var t4f *testInterface4
			ips := []interfaceValueEncoder{
				{i: reflect.TypeOf(t1f).Elem(), ve: fc1},
				{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
				{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
				{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
			}
			// Re-registering testInterface1 is expected to replace fc1 with fc3
			// while keeping testInterface1 in its original position.
			want := []interfaceValueEncoder{
				{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
				{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
				{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
			}
			rb := NewRegistryBuilder()
			for _, ip := range ips {
				rb.RegisterHookEncoder(ip.i, ip.ve)
			}
			got := rb.interfaceEncoders
			if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
				t.Errorf("The registered interfaces are not correct. got %v; want %v", got, want)
			}
		})
		t.Run("type", func(t *testing.T) {
			ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
			// The second registration for ft1 (fc3) should overwrite the first (fc1).
			rb := NewRegistryBuilder().
				RegisterTypeEncoder(reflect.TypeOf(ft1), fc1).
				RegisterTypeEncoder(reflect.TypeOf(ft2), fc2).
				RegisterTypeEncoder(reflect.TypeOf(ft1), fc3).
				RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
			want := []struct {
				t reflect.Type
				c ValueEncoder
			}{
				{reflect.TypeOf(ft1), fc3},
				{reflect.TypeOf(ft2), fc2},
				{reflect.TypeOf(ft4), fc4},
			}
			got := rb.typeEncoders
			for _, s := range want {
				wantT, wantC := s.t, s.c
				gotC, exists := got[wantT]
				if !exists {
					t.Errorf("Did not find type in the type registry: %v", wantT)
					continue // nothing to compare; avoid a second, misleading error
				}
				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
					t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
				}
			}
		})
		t.Run("kind", func(t *testing.T) {
			k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
			// The second registration for k1 (fc3) should overwrite the first (fc1).
			rb := NewRegistryBuilder().
				RegisterDefaultEncoder(k1, fc1).
				RegisterDefaultEncoder(k2, fc2).
				RegisterDefaultEncoder(k1, fc3).
				RegisterDefaultEncoder(k4, fc4)
			want := []struct {
				k reflect.Kind
				c ValueEncoder
			}{
				{k1, fc3},
				{k2, fc2},
				{k4, fc4},
			}
			got := rb.kindEncoders
			for _, s := range want {
				wantK, wantC := s.k, s.c
				gotC, exists := got[wantK]
				if !exists {
					t.Errorf("Did not find kind in the kind registry: %v", wantK)
					continue // nothing to compare; avoid a second, misleading error
				}
				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
					t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
				}
			}
		})
		t.Run("RegisterDefault", func(t *testing.T) {
			// One table entry per kind. Each subtest registers two codecs for the
			// same kind and checks that the second registration replaces the first.
			kindCases := []struct {
				name  string
				kind  reflect.Kind
				label string // used in failure messages
			}{
				{"MapCodec", reflect.Map, "map"},
				{"StructCodec", reflect.Struct, "struct"},
				{"SliceCodec", reflect.Slice, "slice"},
				{"ArrayCodec", reflect.Array, "array"},
			}
			for _, kc := range kindCases {
				t.Run(kc.name, func(t *testing.T) {
					codec := fakeCodec{num: 1}
					codec2 := fakeCodec{num: 2}
					rb := NewRegistryBuilder()
					rb.RegisterDefaultEncoder(kc.kind, codec)
					if got := rb.kindEncoders[kc.kind]; got != codec {
						t.Errorf("Did not properly set the %s codec. got %v; want %v", kc.label, got, codec)
					}
					rb.RegisterDefaultEncoder(kc.kind, codec2)
					if got := rb.kindEncoders[kc.kind]; got != codec2 {
						t.Errorf("Did not properly set the %s codec. got %v; want %v", kc.label, got, codec2)
					}
				})
			}
		})
		// Lookup is nested under Register. fc1 and fc2 are shadowed below with
		// value codecs, while fc3 is still the *fakeCodec declared for Register.
		t.Run("Lookup", func(t *testing.T) {
			type Codec interface {
				ValueEncoder
				ValueDecoder
			}
			var arrinstance [12]int
			arr := reflect.TypeOf(arrinstance)
			slc := reflect.TypeOf(make([]int, 12))
			m := reflect.TypeOf(make(map[string]int))
			strct := reflect.TypeOf(struct{ Foo string }{})
			ft1 := reflect.PtrTo(reflect.TypeOf(fakeType1{}))
			ft2 := reflect.TypeOf(fakeType2{})
			ft3 := reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
			ti1 := reflect.TypeOf((*testInterface1)(nil)).Elem()
			ti2 := reflect.TypeOf((*testInterface2)(nil)).Elem()
			ti1Impl := reflect.TypeOf(testInterface1Impl{})
			ti2Impl := reflect.TypeOf(testInterface2Impl{})
			ti3 := reflect.TypeOf((*testInterface3)(nil)).Elem()
			ti3Impl := reflect.TypeOf(testInterface3Impl{})
			ti3ImplPtr := reflect.TypeOf((*testInterface3Impl)(nil))
			fc1, fc2 := fakeCodec{num: 1}, fakeCodec{num: 2}
			fsc, fslcc, fmc := new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
			pc := NewPointerCodec()
			reg := NewRegistryBuilder().
				RegisterTypeEncoder(ft1, fc1).
				RegisterTypeEncoder(ft2, fc2).
				RegisterTypeEncoder(ti1, fc1).
				RegisterDefaultEncoder(reflect.Struct, fsc).
				RegisterDefaultEncoder(reflect.Slice, fslcc).
				RegisterDefaultEncoder(reflect.Array, fslcc).
				RegisterDefaultEncoder(reflect.Map, fmc).
				RegisterDefaultEncoder(reflect.Ptr, pc).
				RegisterTypeDecoder(ft1, fc1).
				RegisterTypeDecoder(ft2, fc2).
				RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 decoder
				RegisterDefaultDecoder(reflect.Struct, fsc).
				RegisterDefaultDecoder(reflect.Slice, fslcc).
				RegisterDefaultDecoder(reflect.Array, fslcc).
				RegisterDefaultDecoder(reflect.Map, fmc).
				RegisterDefaultDecoder(reflect.Ptr, pc).
				RegisterHookEncoder(ti2, fc2).
				RegisterHookDecoder(ti2, fc2).
				RegisterHookEncoder(ti3, fc3).
				RegisterHookDecoder(ti3, fc3).
				Build()
			testCases := []struct {
				name      string
				t         reflect.Type
				wantcodec Codec
				wanterr   error
				testcache bool
			}{
				{
					"type registry (pointer)",
					ft1,
					fc1,
					nil,
					false,
				},
				{
					"type registry (non-pointer)",
					ft2,
					fc2,
					nil,
					false,
				},
				{
					// lookup an interface type and expect that the registered encoder is returned
					"interface with type encoder",
					ti1,
					fc1,
					nil,
					true,
				},
				{
					// lookup a type that implements an interface and expect that the default struct codec is returned
					"interface implementation with type encoder",
					ti1Impl,
					fsc,
					nil,
					false,
				},
				{
					// lookup an interface type and expect that the registered hook is returned
					"interface with hook",
					ti2,
					fc2,
					nil,
					false,
				},
				{
					// lookup a type that implements an interface and expect that the registered hook is returned
					"interface implementation with hook",
					ti2Impl,
					fc2,
					nil,
					false,
				},
				{
					// lookup a pointer to a type where the pointer implements an interface and expect that the
					// registered hook is returned
					"interface pointer to implementation with hook (pointer)",
					ti3ImplPtr,
					fc3,
					nil,
					false,
				},
				{
					"default struct codec (pointer)",
					reflect.PtrTo(strct),
					pc,
					nil,
					false,
				},
				{
					"default struct codec (non-pointer)",
					strct,
					fsc,
					nil,
					false,
				},
				{
					"default array codec",
					arr,
					fslcc,
					nil,
					false,
				},
				{
					"default slice codec",
					slc,
					fslcc,
					nil,
					false,
				},
				{
					"default map",
					m,
					fmc,
					nil,
					false,
				},
				{
					"map non-string key",
					reflect.TypeOf(map[int]int{}),
					fmc,
					nil,
					false,
				},
				{
					"No Codec Registered",
					ft3,
					nil,
					ErrNoEncoder{Type: ft3},
					false,
				},
			}
			allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
			// PointerCodec is not listed in allowunexported; treat any two
			// *PointerCodec values as equal so cmp never inspects their internals.
			comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
			for _, tc := range testCases {
				t.Run(tc.name, func(t *testing.T) {
					t.Run("Encoder", func(t *testing.T) {
						gotcodec, goterr := reg.LookupEncoder(tc.t)
						if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
							t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr)
						}
						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
							t.Errorf("Codecs did not match. got %v (%T); want %v (%T)", gotcodec, gotcodec, tc.wantcodec, tc.wantcodec)
						}
					})
					t.Run("Decoder", func(t *testing.T) {
						// A case that expects ErrNoEncoder expects the equivalent
						// ErrNoDecoder from the decoder lookup.
						wanterr := tc.wanterr
						if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
							wanterr = ErrNoDecoder(ene)
						}
						gotcodec, goterr := reg.LookupDecoder(tc.t)
						if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
							t.Errorf("Errors did not match. got %v; want %v", goterr, wanterr)
						}
						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
							t.Errorf("Codecs did not match. got %v (%T); want %v (%T)", gotcodec, gotcodec, tc.wantcodec, tc.wantcodec)
						}
					})
				})
			}
			// lookup a type whose pointer implements an interface and expect that the registered hook is
			// returned
			t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
				t.Run("Encoder", func(t *testing.T) {
					gotEnc, err := reg.LookupEncoder(ti3Impl)
					assert.Nil(t, err, "LookupEncoder error: %v", err)
					cae, ok := gotEnc.(*condAddrEncoder)
					if !ok {
						// Stop here: cae is nil and the field accesses below would panic.
						t.Fatalf("Expected CondAddrEncoder, got %T", gotEnc)
					}
					if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
						t.Errorf("expected canAddrEnc %v, got %v", fc3, cae.canAddrEnc)
					}
					if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
						t.Errorf("expected elseEnc %v, got %v", fsc, cae.elseEnc)
					}
				})
				t.Run("Decoder", func(t *testing.T) {
					gotDec, err := reg.LookupDecoder(ti3Impl)
					assert.Nil(t, err, "LookupDecoder error: %v", err)
					cad, ok := gotDec.(*condAddrDecoder)
					if !ok {
						// Stop here: cad is nil and the field accesses below would panic.
						t.Fatalf("Expected CondAddrDecoder, got %T", gotDec)
					}
					if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
						t.Errorf("expected canAddrDec %v, got %v", fc3, cad.canAddrDec)
					}
					if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
						t.Errorf("expected elseDec %v, got %v", fsc, cad.elseDec)
					}
				})
			})
		})
	})
	t.Run("Type Map", func(t *testing.T) {
		reg := NewRegistryBuilder().
			RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")).
			RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))).
			Build()
		var got, want reflect.Type
		want = reflect.TypeOf("")
		got, err := reg.LookupTypeMapEntry(bsontype.String)
		noerr(t, err)
		if got != want {
			t.Errorf("Did not get expected type. got %v; want %v", got, want)
		}
		want = reflect.TypeOf(int(0))
		got, err = reg.LookupTypeMapEntry(bsontype.Int32)
		noerr(t, err)
		if got != want {
			t.Errorf("Did not get expected type. got %v; want %v", got, want)
		}
		// A BSON type with no map entry yields ErrNoTypeMapEntry and a nil type.
		want = nil
		wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
		got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
		if err != wanterr {
			t.Errorf("Did not get expected error. got %v; want %v", err, wanterr)
		}
		if got != want {
			t.Errorf("Did not get expected type. got %v; want %v", got, want)
		}
	})
}
// Placeholder types that TestRegistry registers with, or looks up in, the
// registry. fakeType5 is a function type that TestRegistry looks up without
// registering any codec for it.
type (
	fakeType1 struct{}
	fakeType2 struct{}
	fakeType4 struct{}
	fakeType5 func(string, string) string
)
// fakeCodec is a no-op codec: both methods ignore their arguments and return
// nil. Its num field does not affect encoding or decoding. It only makes
// otherwise identical fakeCodec values distinguishable when compared.
type fakeCodec struct{ num int }

// EncodeValue satisfies ValueEncoder. It writes nothing and always returns nil.
func (fakeCodec) EncodeValue(_ EncodeContext, _ bsonrw.ValueWriter, _ reflect.Value) error {
	return nil
}

// DecodeValue satisfies ValueDecoder. It reads nothing and always returns nil.
func (fakeCodec) DecodeValue(_ DecodeContext, _ bsonrw.ValueReader, _ reflect.Value) error {
	return nil
}

// Distinct codec types that embed fakeCodec. Registry lookups can then be
// checked by the dynamic type of the default (kind) codec they return.
type (
	fakeStructCodec struct{ fakeCodec }
	fakeSliceCodec  struct{ fakeCodec }
	fakeMapCodec    struct{ fakeCodec }
)
// Marker interfaces used to exercise hook and type registration. Each declares
// a single unexported method, so only types in this package can satisfy them.
type (
	testInterface1 interface{ test1() }
	testInterface2 interface{ test2() }
	testInterface3 interface{ test3() }
	testInterface4 interface{ test4() }
)

// testInterface1Impl satisfies testInterface1 with a value receiver.
type testInterface1Impl struct{}

func (testInterface1Impl) test1() {}

// testInterface2Impl satisfies testInterface2 with a value receiver.
type testInterface2Impl struct{}

func (testInterface2Impl) test2() {}

// testInterface3Impl satisfies testInterface3 only through its pointer type,
// because test3 has a pointer receiver. TestRegistry relies on this to get a
// conditional (addressable vs. non-addressable) encoder/decoder for the
// non-pointer type.
type testInterface3Impl struct{}

func (*testInterface3Impl) test3() {}

// Compile-time interface satisfaction checks.
var (
	_ testInterface1 = testInterface1Impl{}
	_ testInterface2 = testInterface2Impl{}
	_ testInterface3 = (*testInterface3Impl)(nil)
)
// typeComparer reports whether i1 and i2 are the identical reflect.Type. It is
// passed to cmp.Comparer so cmp.Equal compares reflect.Type fields by identity.
func typeComparer(i1, i2 reflect.Type) bool {
	same := i1 == i2
	return same
}