diff --git a/wmo/reflection.go b/wmo/reflection.go index 6646332..f1cd280 100644 --- a/wmo/reflection.go +++ b/wmo/reflection.go @@ -25,7 +25,7 @@ func (c *Coll[TData]) EnsureInitializedReflection(v TData) { m := make(map[string]fullTypeRef) - c.initFields("", rval.Type(), m, make([]int, 0)) + c.initFields("", rval.Type(), m, make([]int, 0), make([]reflect.Type, 0)) c.implDataTypeMap[rval.Type()] = m } @@ -50,14 +50,13 @@ func (c *Coll[TData]) init() { c.implDataTypeMap = make(map[reflect.Type]map[string]fullTypeRef) v := reflect.ValueOf(example) - c.initFields("", v.Type(), c.dataTypeMap, make([]int, 0)) + c.initFields("", v.Type(), c.dataTypeMap, make([]int, 0), make([]reflect.Type, 0)) } } -func (c *Coll[TData]) initFields(prefix string, rtyp reflect.Type, m map[string]fullTypeRef, idxarr []int) { - +func (c *Coll[TData]) initFields(prefix string, rtyp reflect.Type, m map[string]fullTypeRef, idxarr []int, typesInPath []reflect.Type) { for i := 0; i < rtyp.NumField(); i++ { rsfield := rtyp.Field(i) @@ -91,7 +90,7 @@ func (c *Coll[TData]) initFields(prefix string, rtyp reflect.Type, m map[string] if langext.InArray("inline", bsontags) && rsfield.Type.Kind() == reflect.Struct { // pass-through field - c.initFields(prefix, rsfield.Type, m, newIdxArr) + c.initFields(prefix, rsfield.Type, m, newIdxArr, typesInPath) } else { @@ -122,11 +121,25 @@ func (c *Coll[TData]) initFields(prefix string, rtyp reflect.Type, m map[string] } if rsfield.Type.Kind() == reflect.Struct { - c.initFields(fullKey+".", rsfield.Type, m, newIdxArr) + c.initFields(fullKey+".", rsfield.Type, m, newIdxArr, typesInPath) } if rsfield.Type.Kind() == reflect.Pointer && rsfield.Type.Elem().Kind() == reflect.Struct { - c.initFields(fullKey+".", rsfield.Type.Elem(), m, newIdxArr) + innerType := rsfield.Type.Elem() + + // check if there is recursion + recursion := false + for _, typ := range typesInPath { + recursion = recursion || (typ == innerType) + } + + if !recursion { + // Store all seen types before that deref a pointer to prevent endless recursion + newTypesInPath := make([]reflect.Type, len(typesInPath)) + copy(newTypesInPath, typesInPath) + newTypesInPath = append(newTypesInPath, rtyp) + c.initFields(fullKey+".", innerType, m, newIdxArr, newTypesInPath) + } } } diff --git a/wmo/reflection_test.go b/wmo/reflection_test.go index 0a1de71..e5764c2 100644 --- a/wmo/reflection_test.go +++ b/wmo/reflection_test.go @@ -113,19 +113,25 @@ func TestReflectionGetTokenValueAsMongoType(t *testing.T) { type IDType string + type RecurseiveType struct { + Other int `bson:"other"` + Inner *RecurseiveType `bson:"inner"` + } + type TestData struct { ID IDType `bson:"_id"` CDate time.Time `bson:"cdate"` Sub struct { A string `bson:"a"` } `bson:"sub"` - SubPtr struct { + SubPtr *struct { A string `bson:"a"` } `bson:"subPtr"` Str string `bson:"str"` Ptr *int `bson:"ptr"` Num int `bson:"num"` MDate rfctime.RFC3339NanoTime `bson:"mdate"` + Rec RecurseiveType `bson:"rec"` } coll := W[TestData](&mongo.Collection{}) @@ -149,6 +155,7 @@ func TestReflectionGetTokenValueAsMongoType(t *testing.T) { tst.AssertEqual(t, gtvasmt("hello", "str").(string), "hello") tst.AssertEqual(t, gtvasmt("hello", "sub.a").(string), "hello") tst.AssertEqual(t, gtvasmt("hello", "subPtr.a").(string), "hello") + tst.AssertEqual(t, gtvasmt("4", "rec.other").(int), 4) tst.AssertEqual(t, gtvasmt("4", "num").(int), 4) tst.AssertEqual(t, gtvasmt("asdf", "_id").(IDType), "asdf") tst.AssertEqual(t, gtvasmt("", "ptr").(*int), nil)