goext/mongo/bson/encoder_test.go

136 lines
3.2 KiB
Go

// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
)
// TestBasicEncode encodes every value in marshalingTestCases by looking up its
// encoder in DefaultRegistry and calling EncodeValue directly (bypassing
// Encoder), then checks the bytes written to a SliceWriter against tc.want.
func TestBasicEncode(t *testing.T) {
	for _, tc := range marshalingTestCases {
		t.Run(tc.name, func(t *testing.T) {
			got := make(bsonrw.SliceWriter, 0, 1024)
			vw, err := bsonrw.NewBSONValueWriter(&got)
			noerr(t, err)

			reg := DefaultRegistry
			encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val))
			noerr(t, err)

			err = encoder.EncodeValue(bsoncodec.EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val))
			noerr(t, err)

			// Report a mismatch exactly once, with got/want on separate lines
			// so long byte slices are easy to compare.
			if !bytes.Equal(got, tc.want) {
				t.Errorf("Bytes are not equal.\ngot:  %v\nwant: %v", got, tc.want)
			}
		})
	}
}
// TestEncoderEncode runs every value in marshalingTestCases through
// Encoder.Encode and checks the written bytes against tc.want. Its "Marshaler"
// subtest encodes a testMarshaler and checks three things:
//   - MarshalBSON's error is returned by Encode;
//   - an error injected by the value writer is surfaced;
//   - on success, the MarshalBSON bytes end up in the output buffer.
func TestEncoderEncode(t *testing.T) {
	for _, tc := range marshalingTestCases {
		t.Run(tc.name, func(t *testing.T) {
			got := make(bsonrw.SliceWriter, 0, 1024)
			vw, err := bsonrw.NewBSONValueWriter(&got)
			noerr(t, err)

			enc, err := NewEncoder(vw)
			noerr(t, err)
			err = enc.Encode(tc.val)
			noerr(t, err)

			// Report a mismatch exactly once, with got/want on separate lines.
			if !bytes.Equal(got, tc.want) {
				t.Errorf("Bytes are not equal.\ngot:  %v\nwant: %v", got, tc.want)
			}
		})
	}
	t.Run("Marshaler", func(t *testing.T) {
		testCases := []struct {
			name    string
			buf     []byte
			err     error
			wanterr error
			vw      bsonrw.ValueWriter // nil: encode into a fresh SliceWriter and compare its bytes to buf
		}{
			{
				"error",
				nil,
				errors.New("Marshaler error"),
				errors.New("Marshaler error"),
				&bsonrwtest.ValueReaderWriter{},
			},
			{
				"copy error",
				[]byte{0x05, 0x00, 0x00, 0x00, 0x00},
				nil,
				errors.New("copy error"),
				&bsonrwtest.ValueReaderWriter{Err: errors.New("copy error"), ErrAfter: bsonrwtest.WriteDocument},
			},
			{
				"success",
				[]byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00},
				nil,
				nil,
				nil,
			},
		}
		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				marshaler := testMarshaler{buf: tc.buf, err: tc.err}

				b := make(bsonrw.SliceWriter, 0, 100)
				vw := tc.vw
				if vw == nil {
					var err error
					vw, err = bsonrw.NewBSONValueWriter(&b)
					noerr(t, err)
				}

				enc, err := NewEncoder(vw)
				noerr(t, err)

				if got := enc.Encode(marshaler); !compareErrors(got, tc.wanterr) {
					t.Errorf("Did not receive expected error. got %v; want %v", got, tc.wanterr)
				}

				// b is only wired to a writer when the case supplied none, so
				// its contents are meaningful only in that case.
				if tc.vw == nil && !bytes.Equal(b, tc.buf) {
					t.Errorf("Copied bytes do not match. got %v; want %v", b, tc.buf)
				}
			})
		}
	})
}
// testMarshaler is a stub whose MarshalBSON hands back its configured fields
// unchanged, letting tests control both the returned bytes and the error.
type testMarshaler struct {
	buf []byte // returned as-is by MarshalBSON
	err error  // returned as-is by MarshalBSON
}

// MarshalBSON returns m.buf and m.err exactly as they were set.
func (m testMarshaler) MarshalBSON() ([]byte, error) {
	out, failure := m.buf, m.err
	return out, failure
}
// docToBytes marshals d with this package's Marshal and returns the bytes,
// panicking with the marshal error if it fails. Presumably used to build
// expected-bytes fixtures in test tables, where a failure means a broken
// fixture; verify against callers.
func docToBytes(d interface{}) []byte {
	out, marshalErr := Marshal(d)
	if marshalErr == nil {
		return out
	}
	panic(marshalErr)
}