383 lines
10 KiB
Go
383 lines
10 KiB
Go
|
// Copyright (C) MongoDB, Inc. 2017-present.
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||
|
// not use this file except in compliance with the License. You may obtain
|
||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
package bson
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/google/go-cmp/cmp"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||
|
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
||
|
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||
|
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||
|
)
|
||
|
|
||
|
// tInt32 is the reflect.Type of int32, used to register type-specific encoders in tests.
var tInt32 = reflect.TypeOf((*int32)(nil)).Elem()
|
||
|
|
||
|
func TestMarshalAppendWithRegistry(t *testing.T) {
|
||
|
for _, tc := range marshalingTestCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
dst := make([]byte, 0, 1024)
|
||
|
var reg *bsoncodec.Registry
|
||
|
if tc.reg != nil {
|
||
|
reg = tc.reg
|
||
|
} else {
|
||
|
reg = DefaultRegistry
|
||
|
}
|
||
|
got, err := MarshalAppendWithRegistry(reg, dst, tc.val)
|
||
|
noerr(t, err)
|
||
|
|
||
|
if !bytes.Equal(got, tc.want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||
|
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMarshalAppendWithContext(t *testing.T) {
|
||
|
for _, tc := range marshalingTestCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
dst := make([]byte, 0, 1024)
|
||
|
var reg *bsoncodec.Registry
|
||
|
if tc.reg != nil {
|
||
|
reg = tc.reg
|
||
|
} else {
|
||
|
reg = DefaultRegistry
|
||
|
}
|
||
|
ec := bsoncodec.EncodeContext{Registry: reg}
|
||
|
got, err := MarshalAppendWithContext(ec, dst, tc.val)
|
||
|
noerr(t, err)
|
||
|
|
||
|
if !bytes.Equal(got, tc.want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||
|
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMarshalWithRegistry(t *testing.T) {
|
||
|
for _, tc := range marshalingTestCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
var reg *bsoncodec.Registry
|
||
|
if tc.reg != nil {
|
||
|
reg = tc.reg
|
||
|
} else {
|
||
|
reg = DefaultRegistry
|
||
|
}
|
||
|
got, err := MarshalWithRegistry(reg, tc.val)
|
||
|
noerr(t, err)
|
||
|
|
||
|
if !bytes.Equal(got, tc.want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||
|
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMarshalWithContext(t *testing.T) {
|
||
|
for _, tc := range marshalingTestCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
var reg *bsoncodec.Registry
|
||
|
if tc.reg != nil {
|
||
|
reg = tc.reg
|
||
|
} else {
|
||
|
reg = DefaultRegistry
|
||
|
}
|
||
|
ec := bsoncodec.EncodeContext{Registry: reg}
|
||
|
got, err := MarshalWithContext(ec, tc.val)
|
||
|
noerr(t, err)
|
||
|
|
||
|
if !bytes.Equal(got, tc.want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||
|
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMarshalAppend(t *testing.T) {
|
||
|
for _, tc := range marshalingTestCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
if tc.reg != nil {
|
||
|
t.Skip() // test requires custom registry
|
||
|
}
|
||
|
dst := make([]byte, 0, 1024)
|
||
|
got, err := MarshalAppend(dst, tc.val)
|
||
|
noerr(t, err)
|
||
|
|
||
|
if !bytes.Equal(got, tc.want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||
|
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMarshalExtJSONAppendWithContext(t *testing.T) {
|
||
|
t.Run("MarshalExtJSONAppendWithContext", func(t *testing.T) {
|
||
|
dst := make([]byte, 0, 1024)
|
||
|
type teststruct struct{ Foo int }
|
||
|
val := teststruct{1}
|
||
|
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
|
||
|
got, err := MarshalExtJSONAppendWithContext(ec, dst, val, true, false)
|
||
|
noerr(t, err)
|
||
|
want := []byte(`{"foo":{"$numberInt":"1"}}`)
|
||
|
if !bytes.Equal(got, want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
|
||
|
t.Errorf("Bytes:\n%s\n%s", got, want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestMarshalExtJSONWithContext(t *testing.T) {
|
||
|
t.Run("MarshalExtJSONWithContext", func(t *testing.T) {
|
||
|
type teststruct struct{ Foo int }
|
||
|
val := teststruct{1}
|
||
|
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
|
||
|
got, err := MarshalExtJSONWithContext(ec, val, true, false)
|
||
|
noerr(t, err)
|
||
|
want := []byte(`{"foo":{"$numberInt":"1"}}`)
|
||
|
if !bytes.Equal(got, want) {
|
||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
|
||
|
t.Errorf("Bytes:\n%s\n%s", got, want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestMarshal_roundtripFromBytes(t *testing.T) {
|
||
|
before := []byte{
|
||
|
// length
|
||
|
0x1c, 0x0, 0x0, 0x0,
|
||
|
|
||
|
// --- begin array ---
|
||
|
|
||
|
// type - document
|
||
|
0x3,
|
||
|
// key - "foo"
|
||
|
0x66, 0x6f, 0x6f, 0x0,
|
||
|
|
||
|
// length
|
||
|
0x12, 0x0, 0x0, 0x0,
|
||
|
// type - string
|
||
|
0x2,
|
||
|
// key - "bar"
|
||
|
0x62, 0x61, 0x72, 0x0,
|
||
|
// value - string length
|
||
|
0x4, 0x0, 0x0, 0x0,
|
||
|
// value - "baz"
|
||
|
0x62, 0x61, 0x7a, 0x0,
|
||
|
|
||
|
// null terminator
|
||
|
0x0,
|
||
|
|
||
|
// --- end array ---
|
||
|
|
||
|
// null terminator
|
||
|
0x0,
|
||
|
}
|
||
|
|
||
|
var doc D
|
||
|
require.NoError(t, Unmarshal(before, &doc))
|
||
|
|
||
|
after, err := Marshal(doc)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
require.True(t, bytes.Equal(before, after))
|
||
|
}
|
||
|
|
||
|
func TestMarshal_roundtripFromDoc(t *testing.T) {
|
||
|
before := D{
|
||
|
{"foo", "bar"},
|
||
|
{"baz", int64(-27)},
|
||
|
{"bing", A{nil, primitive.Regex{Pattern: "word", Options: "i"}}},
|
||
|
}
|
||
|
|
||
|
b, err := Marshal(before)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
var after D
|
||
|
require.NoError(t, Unmarshal(b, &after))
|
||
|
|
||
|
if !cmp.Equal(after, before) {
|
||
|
t.Errorf("Documents to not match. got %v; want %v", after, before)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) {
|
||
|
// Encoders that have caches for recursive encoder lookup should not be shared across Registry instances. Otherwise,
|
||
|
// the first EncodeValue call would cache an encoder and a subsequent call would see that encoder even if a
|
||
|
// different Registry is used.
|
||
|
|
||
|
// Create a custom Registry that negates int32 values when encoding.
|
||
|
var encodeInt32 bsoncodec.ValueEncoderFunc = func(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
|
||
|
if val.Kind() != reflect.Int32 {
|
||
|
return fmt.Errorf("expected kind to be int32, got %v", val.Kind())
|
||
|
}
|
||
|
|
||
|
return vw.WriteInt32(int32(val.Int()) * -1)
|
||
|
}
|
||
|
customReg := NewRegistryBuilder().
|
||
|
RegisterTypeEncoder(tInt32, encodeInt32).
|
||
|
Build()
|
||
|
|
||
|
// Helper function to run the test and make assertions. The provided original value should result in the document
|
||
|
// {"x": {$numberInt: 1}} when marshalled with the default registry.
|
||
|
verifyResults := func(t *testing.T, original interface{}) {
|
||
|
// Marshal using the default and custom registries. Assert that the result is {x: 1} and {x: -1}, respectively.
|
||
|
|
||
|
first, err := Marshal(original)
|
||
|
assert.Nil(t, err, "Marshal error: %v", err)
|
||
|
expectedFirst := Raw(bsoncore.BuildDocumentFromElements(
|
||
|
nil,
|
||
|
bsoncore.AppendInt32Element(nil, "x", 1),
|
||
|
))
|
||
|
assert.Equal(t, expectedFirst, Raw(first), "expected document %v, got %v", expectedFirst, Raw(first))
|
||
|
|
||
|
second, err := MarshalWithRegistry(customReg, original)
|
||
|
assert.Nil(t, err, "Marshal error: %v", err)
|
||
|
expectedSecond := Raw(bsoncore.BuildDocumentFromElements(
|
||
|
nil,
|
||
|
bsoncore.AppendInt32Element(nil, "x", -1),
|
||
|
))
|
||
|
assert.Equal(t, expectedSecond, Raw(second), "expected document %v, got %v", expectedSecond, Raw(second))
|
||
|
}
|
||
|
|
||
|
t.Run("struct", func(t *testing.T) {
|
||
|
type Struct struct {
|
||
|
X int32
|
||
|
}
|
||
|
verifyResults(t, Struct{
|
||
|
X: 1,
|
||
|
})
|
||
|
})
|
||
|
t.Run("pointer", func(t *testing.T) {
|
||
|
i32 := int32(1)
|
||
|
verifyResults(t, M{
|
||
|
"x": &i32,
|
||
|
})
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestNullBytes(t *testing.T) {
|
||
|
t.Run("element keys", func(t *testing.T) {
|
||
|
doc := D{{"a\x00", "foobar"}}
|
||
|
res, err := Marshal(doc)
|
||
|
want := errors.New("BSON element key cannot contain null bytes")
|
||
|
assert.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res))
|
||
|
})
|
||
|
|
||
|
t.Run("regex values", func(t *testing.T) {
|
||
|
wantErr := errors.New("BSON regex values cannot contain null bytes")
|
||
|
|
||
|
testCases := []struct {
|
||
|
name string
|
||
|
pattern string
|
||
|
options string
|
||
|
}{
|
||
|
{"null bytes in pattern", "a\x00", "i"},
|
||
|
{"null bytes in options", "pattern", "i\x00"},
|
||
|
}
|
||
|
for _, tc := range testCases {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
regex := primitive.Regex{
|
||
|
Pattern: tc.pattern,
|
||
|
Options: tc.options,
|
||
|
}
|
||
|
res, err := Marshal(D{{"foo", regex}})
|
||
|
assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
|
||
|
})
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("sub document field name", func(t *testing.T) {
|
||
|
doc := D{{"foo", D{{"foobar", D{{"a\x00", "foobar"}}}}}}
|
||
|
res, err := Marshal(doc)
|
||
|
wantErr := errors.New("BSON element key cannot contain null bytes")
|
||
|
assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestMarshalExtJSONIndent(t *testing.T) {
|
||
|
type indentTestCase struct {
|
||
|
name string
|
||
|
val interface{}
|
||
|
expectedExtJSON string
|
||
|
}
|
||
|
|
||
|
// expectedExtJSON must be written as below because single-quoted
|
||
|
// literal strings capture undesired code formatting tabs
|
||
|
testCases := []indentTestCase{
|
||
|
{
|
||
|
"empty val",
|
||
|
struct{}{},
|
||
|
`{}`,
|
||
|
},
|
||
|
{
|
||
|
"embedded struct",
|
||
|
struct {
|
||
|
Embedded interface{} `json:"embedded"`
|
||
|
Foo string `json:"foo"`
|
||
|
}{
|
||
|
Embedded: struct {
|
||
|
Name string `json:"name"`
|
||
|
Word string `json:"word"`
|
||
|
}{
|
||
|
Name: "test",
|
||
|
Word: "word",
|
||
|
},
|
||
|
Foo: "bar",
|
||
|
},
|
||
|
"{\n\t\"embedded\": {\n\t\t\"name\": \"test\",\n\t\t\"word\": \"word\"\n\t},\n\t\"foo\": \"bar\"\n}",
|
||
|
},
|
||
|
{
|
||
|
"date struct",
|
||
|
struct {
|
||
|
Foo string `json:"foo"`
|
||
|
Date time.Time `json:"date"`
|
||
|
}{
|
||
|
Foo: "bar",
|
||
|
Date: time.Date(2000, time.January, 1, 12, 0, 0, 0, time.UTC),
|
||
|
},
|
||
|
"{\n\t\"foo\": \"bar\",\n\t\"date\": {\n\t\t\"$date\": {\n\t\t\t\"$numberLong\": \"946728000000\"\n\t\t}\n\t}\n}",
|
||
|
},
|
||
|
{
|
||
|
"float struct",
|
||
|
struct {
|
||
|
Foo string `json:"foo"`
|
||
|
Float float32 `json:"float"`
|
||
|
}{
|
||
|
Foo: "bar",
|
||
|
Float: 3.14,
|
||
|
},
|
||
|
"{\n\t\"foo\": \"bar\",\n\t\"float\": {\n\t\t\"$numberDouble\": \"3.140000104904175\"\n\t}\n}",
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
tc := tc
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
extJSONBytes, err := MarshalExtJSONIndent(tc.val, true, false, "", "\t")
|
||
|
assert.Nil(t, err, "Marshal indent error: %v", err)
|
||
|
|
||
|
expectedExtJSONBytes := []byte(tc.expectedExtJSON)
|
||
|
|
||
|
assert.Equal(t, expectedExtJSONBytes, extJSONBytes, "expected:\n%s\ngot:\n%s", expectedExtJSONBytes, extJSONBytes)
|
||
|
})
|
||
|
}
|
||
|
}
|