453 lines
14 KiB
Go
453 lines
14 KiB
Go
|
// Copyright (C) MongoDB, Inc. 2017-present.
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||
|
// not use this file except in compliance with the License. You may obtain
|
||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
package bsoncodec
|
||
|
|
||
|
import (
|
||
|
"reflect"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/google/go-cmp/cmp"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||
|
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||
|
)
|
||
|
|
||
|
func TestRegistry(t *testing.T) {
|
||
|
t.Run("Register", func(t *testing.T) {
|
||
|
fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
|
||
|
t.Run("interface", func(t *testing.T) {
|
||
|
var t1f *testInterface1
|
||
|
var t2f *testInterface2
|
||
|
var t4f *testInterface4
|
||
|
ips := []interfaceValueEncoder{
|
||
|
{i: reflect.TypeOf(t1f).Elem(), ve: fc1},
|
||
|
{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
|
||
|
{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
|
||
|
{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
|
||
|
}
|
||
|
want := []interfaceValueEncoder{
|
||
|
{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
|
||
|
{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
|
||
|
{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
|
||
|
}
|
||
|
rb := NewRegistryBuilder()
|
||
|
for _, ip := range ips {
|
||
|
rb.RegisterHookEncoder(ip.i, ip.ve)
|
||
|
}
|
||
|
got := rb.interfaceEncoders
|
||
|
if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
|
||
|
t.Errorf("The registered interfaces are not correct. got %v; want %v", got, want)
|
||
|
}
|
||
|
})
|
||
|
t.Run("type", func(t *testing.T) {
|
||
|
ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
|
||
|
rb := NewRegistryBuilder().
|
||
|
RegisterTypeEncoder(reflect.TypeOf(ft1), fc1).
|
||
|
RegisterTypeEncoder(reflect.TypeOf(ft2), fc2).
|
||
|
RegisterTypeEncoder(reflect.TypeOf(ft1), fc3).
|
||
|
RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
|
||
|
want := []struct {
|
||
|
t reflect.Type
|
||
|
c ValueEncoder
|
||
|
}{
|
||
|
{reflect.TypeOf(ft1), fc3},
|
||
|
{reflect.TypeOf(ft2), fc2},
|
||
|
{reflect.TypeOf(ft4), fc4},
|
||
|
}
|
||
|
got := rb.typeEncoders
|
||
|
for _, s := range want {
|
||
|
wantT, wantC := s.t, s.c
|
||
|
gotC, exists := got[wantT]
|
||
|
if !exists {
|
||
|
t.Errorf("Did not find type in the type registry: %v", wantT)
|
||
|
}
|
||
|
if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
|
||
|
t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
t.Run("kind", func(t *testing.T) {
|
||
|
k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
|
||
|
rb := NewRegistryBuilder().
|
||
|
RegisterDefaultEncoder(k1, fc1).
|
||
|
RegisterDefaultEncoder(k2, fc2).
|
||
|
RegisterDefaultEncoder(k1, fc3).
|
||
|
RegisterDefaultEncoder(k4, fc4)
|
||
|
want := []struct {
|
||
|
k reflect.Kind
|
||
|
c ValueEncoder
|
||
|
}{
|
||
|
{k1, fc3},
|
||
|
{k2, fc2},
|
||
|
{k4, fc4},
|
||
|
}
|
||
|
got := rb.kindEncoders
|
||
|
for _, s := range want {
|
||
|
wantK, wantC := s.k, s.c
|
||
|
gotC, exists := got[wantK]
|
||
|
if !exists {
|
||
|
t.Errorf("Did not find kind in the kind registry: %v", wantK)
|
||
|
}
|
||
|
if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
|
||
|
t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
t.Run("RegisterDefault", func(t *testing.T) {
|
||
|
t.Run("MapCodec", func(t *testing.T) {
|
||
|
codec := fakeCodec{num: 1}
|
||
|
codec2 := fakeCodec{num: 2}
|
||
|
rb := NewRegistryBuilder()
|
||
|
rb.RegisterDefaultEncoder(reflect.Map, codec)
|
||
|
if rb.kindEncoders[reflect.Map] != codec {
|
||
|
t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec)
|
||
|
}
|
||
|
rb.RegisterDefaultEncoder(reflect.Map, codec2)
|
||
|
if rb.kindEncoders[reflect.Map] != codec2 {
|
||
|
t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec2)
|
||
|
}
|
||
|
})
|
||
|
t.Run("StructCodec", func(t *testing.T) {
|
||
|
codec := fakeCodec{num: 1}
|
||
|
codec2 := fakeCodec{num: 2}
|
||
|
rb := NewRegistryBuilder()
|
||
|
rb.RegisterDefaultEncoder(reflect.Struct, codec)
|
||
|
if rb.kindEncoders[reflect.Struct] != codec {
|
||
|
t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec)
|
||
|
}
|
||
|
rb.RegisterDefaultEncoder(reflect.Struct, codec2)
|
||
|
if rb.kindEncoders[reflect.Struct] != codec2 {
|
||
|
t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec2)
|
||
|
}
|
||
|
})
|
||
|
t.Run("SliceCodec", func(t *testing.T) {
|
||
|
codec := fakeCodec{num: 1}
|
||
|
codec2 := fakeCodec{num: 2}
|
||
|
rb := NewRegistryBuilder()
|
||
|
rb.RegisterDefaultEncoder(reflect.Slice, codec)
|
||
|
if rb.kindEncoders[reflect.Slice] != codec {
|
||
|
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec)
|
||
|
}
|
||
|
rb.RegisterDefaultEncoder(reflect.Slice, codec2)
|
||
|
if rb.kindEncoders[reflect.Slice] != codec2 {
|
||
|
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec2)
|
||
|
}
|
||
|
})
|
||
|
t.Run("ArrayCodec", func(t *testing.T) {
|
||
|
codec := fakeCodec{num: 1}
|
||
|
codec2 := fakeCodec{num: 2}
|
||
|
rb := NewRegistryBuilder()
|
||
|
rb.RegisterDefaultEncoder(reflect.Array, codec)
|
||
|
if rb.kindEncoders[reflect.Array] != codec {
|
||
|
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec)
|
||
|
}
|
||
|
rb.RegisterDefaultEncoder(reflect.Array, codec2)
|
||
|
if rb.kindEncoders[reflect.Array] != codec2 {
|
||
|
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec2)
|
||
|
}
|
||
|
})
|
||
|
})
|
||
|
t.Run("Lookup", func(t *testing.T) {
|
||
|
type Codec interface {
|
||
|
ValueEncoder
|
||
|
ValueDecoder
|
||
|
}
|
||
|
|
||
|
var arrinstance [12]int
|
||
|
arr := reflect.TypeOf(arrinstance)
|
||
|
slc := reflect.TypeOf(make([]int, 12))
|
||
|
m := reflect.TypeOf(make(map[string]int))
|
||
|
strct := reflect.TypeOf(struct{ Foo string }{})
|
||
|
ft1 := reflect.PtrTo(reflect.TypeOf(fakeType1{}))
|
||
|
ft2 := reflect.TypeOf(fakeType2{})
|
||
|
ft3 := reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
|
||
|
ti1 := reflect.TypeOf((*testInterface1)(nil)).Elem()
|
||
|
ti2 := reflect.TypeOf((*testInterface2)(nil)).Elem()
|
||
|
ti1Impl := reflect.TypeOf(testInterface1Impl{})
|
||
|
ti2Impl := reflect.TypeOf(testInterface2Impl{})
|
||
|
ti3 := reflect.TypeOf((*testInterface3)(nil)).Elem()
|
||
|
ti3Impl := reflect.TypeOf(testInterface3Impl{})
|
||
|
ti3ImplPtr := reflect.TypeOf((*testInterface3Impl)(nil))
|
||
|
fc1, fc2 := fakeCodec{num: 1}, fakeCodec{num: 2}
|
||
|
fsc, fslcc, fmc := new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
|
||
|
pc := NewPointerCodec()
|
||
|
|
||
|
reg := NewRegistryBuilder().
|
||
|
RegisterTypeEncoder(ft1, fc1).
|
||
|
RegisterTypeEncoder(ft2, fc2).
|
||
|
RegisterTypeEncoder(ti1, fc1).
|
||
|
RegisterDefaultEncoder(reflect.Struct, fsc).
|
||
|
RegisterDefaultEncoder(reflect.Slice, fslcc).
|
||
|
RegisterDefaultEncoder(reflect.Array, fslcc).
|
||
|
RegisterDefaultEncoder(reflect.Map, fmc).
|
||
|
RegisterDefaultEncoder(reflect.Ptr, pc).
|
||
|
RegisterTypeDecoder(ft1, fc1).
|
||
|
RegisterTypeDecoder(ft2, fc2).
|
||
|
RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder
|
||
|
RegisterDefaultDecoder(reflect.Struct, fsc).
|
||
|
RegisterDefaultDecoder(reflect.Slice, fslcc).
|
||
|
RegisterDefaultDecoder(reflect.Array, fslcc).
|
||
|
RegisterDefaultDecoder(reflect.Map, fmc).
|
||
|
RegisterDefaultDecoder(reflect.Ptr, pc).
|
||
|
RegisterHookEncoder(ti2, fc2).
|
||
|
RegisterHookDecoder(ti2, fc2).
|
||
|
RegisterHookEncoder(ti3, fc3).
|
||
|
RegisterHookDecoder(ti3, fc3).
|
||
|
Build()
|
||
|
|
||
|
testCases := []struct {
|
||
|
name string
|
||
|
t reflect.Type
|
||
|
wantcodec Codec
|
||
|
wanterr error
|
||
|
testcache bool
|
||
|
}{
|
||
|
{
|
||
|
"type registry (pointer)",
|
||
|
ft1,
|
||
|
fc1,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"type registry (non-pointer)",
|
||
|
ft2,
|
||
|
fc2,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
// lookup an interface type and expect that the registered encoder is returned
|
||
|
"interface with type encoder",
|
||
|
ti1,
|
||
|
fc1,
|
||
|
nil,
|
||
|
true,
|
||
|
},
|
||
|
{
|
||
|
// lookup a type that implements an interface and expect that the default struct codec is returned
|
||
|
"interface implementation with type encoder",
|
||
|
ti1Impl,
|
||
|
fsc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
// lookup an interface type and expect that the registered hook is returned
|
||
|
"interface with hook",
|
||
|
ti2,
|
||
|
fc2,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
// lookup a type that implements an interface and expect that the registered hook is returned
|
||
|
"interface implementation with hook",
|
||
|
ti2Impl,
|
||
|
fc2,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
// lookup a pointer to a type where the pointer implements an interface and expect that the
|
||
|
// registered hook is returned
|
||
|
"interface pointer to implementation with hook (pointer)",
|
||
|
ti3ImplPtr,
|
||
|
fc3,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"default struct codec (pointer)",
|
||
|
reflect.PtrTo(strct),
|
||
|
pc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"default struct codec (non-pointer)",
|
||
|
strct,
|
||
|
fsc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"default array codec",
|
||
|
arr,
|
||
|
fslcc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"default slice codec",
|
||
|
slc,
|
||
|
fslcc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"default map",
|
||
|
m,
|
||
|
fmc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"map non-string key",
|
||
|
reflect.TypeOf(map[int]int{}),
|
||
|
fmc,
|
||
|
nil,
|
||
|
false,
|
||
|
},
|
||
|
{
|
||
|
"No Codec Registered",
|
||
|
ft3,
|
||
|
nil,
|
||
|
ErrNoEncoder{Type: ft3},
|
||
|
false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
|
||
|
comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
t.Run("Encoder", func(t *testing.T) {
|
||
|
gotcodec, goterr := reg.LookupEncoder(tc.t)
|
||
|
if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
|
||
|
t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr)
|
||
|
}
|
||
|
if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
|
||
|
t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec)
|
||
|
}
|
||
|
})
|
||
|
t.Run("Decoder", func(t *testing.T) {
|
||
|
var wanterr error
|
||
|
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
|
||
|
wanterr = ErrNoDecoder(ene)
|
||
|
} else {
|
||
|
wanterr = tc.wanterr
|
||
|
}
|
||
|
gotcodec, goterr := reg.LookupDecoder(tc.t)
|
||
|
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
|
||
|
t.Errorf("Errors did not match. got %v; want %v", goterr, wanterr)
|
||
|
}
|
||
|
if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
|
||
|
t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec)
|
||
|
t.Errorf("Codecs did not match. got %T; want %T", gotcodec, tc.wantcodec)
|
||
|
}
|
||
|
})
|
||
|
})
|
||
|
}
|
||
|
// lookup a type whose pointer implements an interface and expect that the registered hook is
|
||
|
// returned
|
||
|
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
|
||
|
t.Run("Encoder", func(t *testing.T) {
|
||
|
gotEnc, err := reg.LookupEncoder(ti3Impl)
|
||
|
assert.Nil(t, err, "LookupEncoder error: %v", err)
|
||
|
|
||
|
cae, ok := gotEnc.(*condAddrEncoder)
|
||
|
assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
|
||
|
if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
|
||
|
t.Errorf("expected canAddrEnc %v, got %v", cae.canAddrEnc, fc3)
|
||
|
}
|
||
|
if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
|
||
|
t.Errorf("expected elseEnc %v, got %v", cae.elseEnc, fsc)
|
||
|
}
|
||
|
})
|
||
|
t.Run("Decoder", func(t *testing.T) {
|
||
|
gotDec, err := reg.LookupDecoder(ti3Impl)
|
||
|
assert.Nil(t, err, "LookupDecoder error: %v", err)
|
||
|
|
||
|
cad, ok := gotDec.(*condAddrDecoder)
|
||
|
assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
|
||
|
if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
|
||
|
t.Errorf("expected canAddrDec %v, got %v", cad.canAddrDec, fc3)
|
||
|
}
|
||
|
if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
|
||
|
t.Errorf("expected elseDec %v, got %v", cad.elseDec, fsc)
|
||
|
}
|
||
|
})
|
||
|
})
|
||
|
})
|
||
|
})
|
||
|
t.Run("Type Map", func(t *testing.T) {
|
||
|
reg := NewRegistryBuilder().
|
||
|
RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")).
|
||
|
RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))).
|
||
|
Build()
|
||
|
|
||
|
var got, want reflect.Type
|
||
|
|
||
|
want = reflect.TypeOf("")
|
||
|
got, err := reg.LookupTypeMapEntry(bsontype.String)
|
||
|
noerr(t, err)
|
||
|
if got != want {
|
||
|
t.Errorf("Did not get expected type. got %v; want %v", got, want)
|
||
|
}
|
||
|
|
||
|
want = reflect.TypeOf(int(0))
|
||
|
got, err = reg.LookupTypeMapEntry(bsontype.Int32)
|
||
|
noerr(t, err)
|
||
|
if got != want {
|
||
|
t.Errorf("Did not get expected type. got %v; want %v", got, want)
|
||
|
}
|
||
|
|
||
|
want = nil
|
||
|
wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
|
||
|
got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
|
||
|
if err != wanterr {
|
||
|
t.Errorf("Did not get expected error. got %v; want %v", err, wanterr)
|
||
|
}
|
||
|
if got != want {
|
||
|
t.Errorf("Did not get expected type. got %v; want %v", got, want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
type fakeType1 struct{}
|
||
|
type fakeType2 struct{}
|
||
|
type fakeType4 struct{}
|
||
|
type fakeType5 func(string, string) string
|
||
|
type fakeStructCodec struct{ fakeCodec }
|
||
|
type fakeSliceCodec struct{ fakeCodec }
|
||
|
type fakeMapCodec struct{ fakeCodec }
|
||
|
|
||
|
type fakeCodec struct{ num int }
|
||
|
|
||
|
func (fc fakeCodec) EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
|
||
|
return nil
|
||
|
}
|
||
|
func (fc fakeCodec) DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
type testInterface1 interface{ test1() }
|
||
|
type testInterface2 interface{ test2() }
|
||
|
type testInterface3 interface{ test3() }
|
||
|
type testInterface4 interface{ test4() }
|
||
|
|
||
|
type testInterface1Impl struct{}
|
||
|
|
||
|
var _ testInterface1 = testInterface1Impl{}
|
||
|
|
||
|
func (testInterface1Impl) test1() {}
|
||
|
|
||
|
type testInterface2Impl struct{}
|
||
|
|
||
|
var _ testInterface2 = testInterface2Impl{}
|
||
|
|
||
|
func (testInterface2Impl) test2() {}
|
||
|
|
||
|
type testInterface3Impl struct{}
|
||
|
|
||
|
var _ testInterface3 = (*testInterface3Impl)(nil)
|
||
|
|
||
|
func (*testInterface3Impl) test3() {}
|
||
|
|
||
|
func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 }
|