package reflectext

import (
	"reflect"
)

// reflectBasicTypes maps every predeclared basic kind (bool, the integer,
// float and complex kinds, and string) to the corresponding unnamed
// predeclared type, e.g. reflect.Int -> int. Underlying consults it to strip
// the name from defined types such as `type MyInt int`.
var reflectBasicTypes = func() map[reflect.Kind]reflect.Type {
	samples := []any{
		false,
		int(0), int8(0), int16(0), int32(0), int64(0),
		uint(0), uint8(0), uint16(0), uint32(0), uint64(0), uintptr(0),
		float32(0), float64(0),
		complex64(0), complex128(0),
		"",
	}
	byKind := make(map[reflect.Kind]reflect.Type, len(samples))
	for _, sample := range samples {
		rt := reflect.TypeOf(sample)
		byKind[rt.Kind()] = rt
	}
	return byKind
}()

// Underlying returns the underlying type of t. For a defined (named) type such
// as `type MyInt int` that is the unnamed type it was declared with (here
// int). Type aliases (`type A = B`) are not distinct types at runtime, so
// reflect never observes them and they need no special handling.
//
// Unnamed types (e.g. []string, map[string]int, *T) are their own underlying
// type and are returned unchanged. When the underlying type cannot be built
// with reflect, t itself is returned as a best effort. This covers named
// interface types, unsafe.Pointer, and structs that reflect.StructOf rejects
// (for instance because of unexported fields). A nil t yields nil.
//
// https://github.com/golang/go/issues/39574#issuecomment-655664772
func Underlying(t reflect.Type) reflect.Type {
	if t == nil {
		// reflect.TypeOf(nil) and reflect.TypeOf of a nil interface are nil;
		// calling any method on them would panic.
		return nil
	}
	if t.Name() == "" {
		// t is an unnamed type: the underlying type is t itself.
		return t
	}
	kind := t.Kind()
	if basic, ok := reflectBasicTypes[kind]; ok {
		return basic
	}
	switch kind {
	case reflect.Array:
		return reflect.ArrayOf(t.Len(), t.Elem())
	case reflect.Chan:
		return reflect.ChanOf(t.ChanDir(), t.Elem())
	case reflect.Map:
		return reflect.MapOf(t.Key(), t.Elem())
	case reflect.Func:
		in := make([]reflect.Type, t.NumIn())
		for i := range in {
			in[i] = t.In(i)
		}
		out := make([]reflect.Type, t.NumOut())
		for i := range out {
			out[i] = t.Out(i)
		}
		return reflect.FuncOf(in, out, t.IsVariadic())
	case reflect.Pointer:
		return reflect.PointerTo(t.Elem())
	case reflect.Slice:
		return reflect.SliceOf(t.Elem())
	case reflect.Struct:
		return underlyingStruct(t)
	default:
		// reflect offers no constructor for interface types (nor for other
		// remaining kinds such as unsafe.Pointer), so fall back to t.
		return t
	}
}

// underlyingStruct rebuilds the unnamed struct type with t's fields, returning
// t unchanged if reflect.StructOf panics (e.g. on unexported or unsupported
// embedded fields).
func underlyingStruct(t reflect.Type) (ret reflect.Type) {
	defer func() {
		if recover() != nil {
			ret = t
		}
	}()
	fields := make([]reflect.StructField, t.NumField())
	for i := range fields {
		fields[i] = t.Field(i)
	}
	return reflect.StructOf(fields)
}

// TryCast works like the type assertion `v2, ok := v.(T)`. It additionally
// succeeds when v's dynamic type and T are distinct types sharing the same
// underlying type (e.g. `type MyInt int` and int); in that case v is converted
// to T.
//
// If T is an interface type, only the plain type assertion is attempted. A nil
// v never succeeds. On failure the zero value of T and false are returned.
func TryCast[T any](v any) (T, bool) {
	// Exact type match, or T is an interface that v implements.
	if out, ok := v.(T); ok {
		return out, true
	}

	var zero T
	if v == nil {
		return zero, false
	}

	// reflect.TypeOf(zero) is nil when T is an interface type, so obtain T's
	// type through a pointer instead.
	target := reflect.TypeOf((*T)(nil)).Elem()
	if target.Kind() == reflect.Interface {
		// The assertion above already failed; there is nothing to convert.
		return zero, false
	}

	src := reflect.ValueOf(v)
	if Underlying(src.Type()) != Underlying(target) || !src.CanConvert(target) {
		return zero, false
	}

	// Convert straight to T (not to its underlying type), so the resulting
	// dynamic type is T and the assertion below holds even for named T.
	out, ok := src.Convert(target).Interface().(T)
	return out, ok
}

// TryCastType is the reflect.Type counterpart of TryCast. It reports whether v
// can be viewed as a value of type dest and, if so, returns that value boxed
// in an any.
//
// The cast succeeds when:
//   - v's dynamic type is dest;
//   - dest is an interface type implemented by v's dynamic type (v is returned
//     unchanged);
//   - v's dynamic type and dest share the same underlying type (v is converted
//     to dest).
//
// A nil v or a nil dest never succeeds and yields (nil, false).
func TryCastType(v any, dest reflect.Type) (any, bool) {
	if v == nil || dest == nil {
		return nil, false
	}

	src := reflect.TypeOf(v)
	if src == dest {
		return v, true
	}

	if dest.Kind() == reflect.Interface {
		// Mirrors v.(I): storing v in an interface keeps its dynamic value.
		if src.Implements(dest) {
			return v, true
		}
		return nil, false
	}

	if Underlying(src) != Underlying(dest) {
		return nil, false
	}

	rv := reflect.ValueOf(v)
	if !rv.CanConvert(dest) {
		return nil, false
	}
	return rv.Convert(dest).Interface(), true
}