goext/mongo/bson/marshal_test.go

383 lines
10 KiB
Go
Raw Normal View History

2023-06-18 15:50:55 +02:00
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var tInt32 = reflect.TypeOf(int32(0))
func TestMarshalAppendWithRegistry(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
dst := make([]byte, 0, 1024)
var reg *bsoncodec.Registry
if tc.reg != nil {
reg = tc.reg
} else {
reg = DefaultRegistry
}
got, err := MarshalAppendWithRegistry(reg, dst, tc.val)
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}
func TestMarshalAppendWithContext(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
dst := make([]byte, 0, 1024)
var reg *bsoncodec.Registry
if tc.reg != nil {
reg = tc.reg
} else {
reg = DefaultRegistry
}
ec := bsoncodec.EncodeContext{Registry: reg}
got, err := MarshalAppendWithContext(ec, dst, tc.val)
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}
func TestMarshalWithRegistry(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *bsoncodec.Registry
if tc.reg != nil {
reg = tc.reg
} else {
reg = DefaultRegistry
}
got, err := MarshalWithRegistry(reg, tc.val)
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}
func TestMarshalWithContext(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
var reg *bsoncodec.Registry
if tc.reg != nil {
reg = tc.reg
} else {
reg = DefaultRegistry
}
ec := bsoncodec.EncodeContext{Registry: reg}
got, err := MarshalWithContext(ec, tc.val)
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}
func TestMarshalAppend(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
if tc.reg != nil {
t.Skip() // test requires custom registry
}
dst := make([]byte, 0, 1024)
got, err := MarshalAppend(dst, tc.val)
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}
func TestMarshalExtJSONAppendWithContext(t *testing.T) {
t.Run("MarshalExtJSONAppendWithContext", func(t *testing.T) {
dst := make([]byte, 0, 1024)
type teststruct struct{ Foo int }
val := teststruct{1}
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
got, err := MarshalExtJSONAppendWithContext(ec, dst, val, true, false)
noerr(t, err)
want := []byte(`{"foo":{"$numberInt":"1"}}`)
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
t.Errorf("Bytes:\n%s\n%s", got, want)
}
})
}
func TestMarshalExtJSONWithContext(t *testing.T) {
t.Run("MarshalExtJSONWithContext", func(t *testing.T) {
type teststruct struct{ Foo int }
val := teststruct{1}
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
got, err := MarshalExtJSONWithContext(ec, val, true, false)
noerr(t, err)
want := []byte(`{"foo":{"$numberInt":"1"}}`)
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
t.Errorf("Bytes:\n%s\n%s", got, want)
}
})
}
func TestMarshal_roundtripFromBytes(t *testing.T) {
before := []byte{
// length
0x1c, 0x0, 0x0, 0x0,
// --- begin array ---
// type - document
0x3,
// key - "foo"
0x66, 0x6f, 0x6f, 0x0,
// length
0x12, 0x0, 0x0, 0x0,
// type - string
0x2,
// key - "bar"
0x62, 0x61, 0x72, 0x0,
// value - string length
0x4, 0x0, 0x0, 0x0,
// value - "baz"
0x62, 0x61, 0x7a, 0x0,
// null terminator
0x0,
// --- end array ---
// null terminator
0x0,
}
var doc D
require.NoError(t, Unmarshal(before, &doc))
after, err := Marshal(doc)
require.NoError(t, err)
require.True(t, bytes.Equal(before, after))
}
func TestMarshal_roundtripFromDoc(t *testing.T) {
before := D{
{"foo", "bar"},
{"baz", int64(-27)},
{"bing", A{nil, primitive.Regex{Pattern: "word", Options: "i"}}},
}
b, err := Marshal(before)
require.NoError(t, err)
var after D
require.NoError(t, Unmarshal(b, &after))
if !cmp.Equal(after, before) {
t.Errorf("Documents to not match. got %v; want %v", after, before)
}
}
func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) {
// Encoders that have caches for recursive encoder lookup should not be shared across Registry instances. Otherwise,
// the first EncodeValue call would cache an encoder and a subsequent call would see that encoder even if a
// different Registry is used.
// Create a custom Registry that negates int32 values when encoding.
var encodeInt32 bsoncodec.ValueEncoderFunc = func(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.Int32 {
return fmt.Errorf("expected kind to be int32, got %v", val.Kind())
}
return vw.WriteInt32(int32(val.Int()) * -1)
}
customReg := NewRegistryBuilder().
RegisterTypeEncoder(tInt32, encodeInt32).
Build()
// Helper function to run the test and make assertions. The provided original value should result in the document
// {"x": {$numberInt: 1}} when marshalled with the default registry.
verifyResults := func(t *testing.T, original interface{}) {
// Marshal using the default and custom registries. Assert that the result is {x: 1} and {x: -1}, respectively.
first, err := Marshal(original)
assert.Nil(t, err, "Marshal error: %v", err)
expectedFirst := Raw(bsoncore.BuildDocumentFromElements(
nil,
bsoncore.AppendInt32Element(nil, "x", 1),
))
assert.Equal(t, expectedFirst, Raw(first), "expected document %v, got %v", expectedFirst, Raw(first))
second, err := MarshalWithRegistry(customReg, original)
assert.Nil(t, err, "Marshal error: %v", err)
expectedSecond := Raw(bsoncore.BuildDocumentFromElements(
nil,
bsoncore.AppendInt32Element(nil, "x", -1),
))
assert.Equal(t, expectedSecond, Raw(second), "expected document %v, got %v", expectedSecond, Raw(second))
}
t.Run("struct", func(t *testing.T) {
type Struct struct {
X int32
}
verifyResults(t, Struct{
X: 1,
})
})
t.Run("pointer", func(t *testing.T) {
i32 := int32(1)
verifyResults(t, M{
"x": &i32,
})
})
}
func TestNullBytes(t *testing.T) {
t.Run("element keys", func(t *testing.T) {
doc := D{{"a\x00", "foobar"}}
res, err := Marshal(doc)
want := errors.New("BSON element key cannot contain null bytes")
assert.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res))
})
t.Run("regex values", func(t *testing.T) {
wantErr := errors.New("BSON regex values cannot contain null bytes")
testCases := []struct {
name string
pattern string
options string
}{
{"null bytes in pattern", "a\x00", "i"},
{"null bytes in options", "pattern", "i\x00"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
regex := primitive.Regex{
Pattern: tc.pattern,
Options: tc.options,
}
res, err := Marshal(D{{"foo", regex}})
assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
})
}
})
t.Run("sub document field name", func(t *testing.T) {
doc := D{{"foo", D{{"foobar", D{{"a\x00", "foobar"}}}}}}
res, err := Marshal(doc)
wantErr := errors.New("BSON element key cannot contain null bytes")
assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
})
}
func TestMarshalExtJSONIndent(t *testing.T) {
type indentTestCase struct {
name string
val interface{}
expectedExtJSON string
}
// expectedExtJSON must be written as below because single-quoted
// literal strings capture undesired code formatting tabs
testCases := []indentTestCase{
{
"empty val",
struct{}{},
`{}`,
},
{
"embedded struct",
struct {
Embedded interface{} `json:"embedded"`
Foo string `json:"foo"`
}{
Embedded: struct {
Name string `json:"name"`
Word string `json:"word"`
}{
Name: "test",
Word: "word",
},
Foo: "bar",
},
"{\n\t\"embedded\": {\n\t\t\"name\": \"test\",\n\t\t\"word\": \"word\"\n\t},\n\t\"foo\": \"bar\"\n}",
},
{
"date struct",
struct {
Foo string `json:"foo"`
Date time.Time `json:"date"`
}{
Foo: "bar",
Date: time.Date(2000, time.January, 1, 12, 0, 0, 0, time.UTC),
},
"{\n\t\"foo\": \"bar\",\n\t\"date\": {\n\t\t\"$date\": {\n\t\t\t\"$numberLong\": \"946728000000\"\n\t\t}\n\t}\n}",
},
{
"float struct",
struct {
Foo string `json:"foo"`
Float float32 `json:"float"`
}{
Foo: "bar",
Float: 3.14,
},
"{\n\t\"foo\": \"bar\",\n\t\"float\": {\n\t\t\"$numberDouble\": \"3.140000104904175\"\n\t}\n}",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
extJSONBytes, err := MarshalExtJSONIndent(tc.val, true, false, "", "\t")
assert.Nil(t, err, "Marshal indent error: %v", err)
expectedExtJSONBytes := []byte(tc.expectedExtJSON)
assert.Equal(t, expectedExtJSONBytes, extJSONBytes, "expected:\n%s\ngot:\n%s", expectedExtJSONBytes, extJSONBytes)
})
}
}