package dataext

import (
	"bytes"
	"crypto/sha256"
	"encoding/binary"
	"errors"
	"fmt"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	"hash"
	"io"
	"reflect"
	"sort"
)

// StructHashOptions configures StructHash.
//
//   - HashAlgo is the hash fed with the serialized value. StructHash calls Reset
//     on it before writing, so any previous state is discarded. When nil,
//     StructHash falls back to sha256.New(). Because the instance is mutated,
//     do not share one HashAlgo between concurrent StructHash calls.
//   - Tag, when non-nil, restricts struct serialization to fields whose struct
//     tag contains this key (checked with reflect.StructTag.Lookup; the tag's
//     value is not inspected). When nil, every struct field is included.
//   - SkipChannel makes non-nil channel values contribute only their kind and
//     type name instead of producing an error.
//   - SkipFunc makes non-nil func values contribute only their kind and type
//     name instead of producing an error.
type StructHashOptions struct {
	HashAlgo    hash.Hash
	Tag         *string
	SkipChannel bool
	SkipFunc    bool
}

// StructHash serializes dat into a canonical byte stream (via reflection) and
// returns the digest of that stream.
//
// At most one StructHashOptions may be passed; supplying more returns an error.
// If no HashAlgo is configured, SHA-256 is used. When langext.IsNil reports dat
// as nil, the result is the digest of an empty stream.
//
// Serialization errors (unsupported kinds such as complex numbers or, unless
// skipped via the options, channels and funcs) are returned as err. Any panic
// raised while reflecting over dat is recovered and reported as an error whose
// message starts with "recovered panic: "; if the panic value was itself an
// error it is wrapped, so errors.Is / errors.As can inspect it.
func StructHash(dat any, opt ...StructHashOptions) (r []byte, err error) {
	defer func() {
		if rec := recover(); rec != nil {
			r = nil
			if perr, ok := rec.(error); ok {
				err = fmt.Errorf("recovered panic: %w", perr)
			} else {
				err = fmt.Errorf("recovered panic: %v", rec)
			}
		}
	}()

	if len(opt) > 1 {
		return nil, errors.New("multiple options supplied")
	}

	shopt := StructHashOptions{}
	if len(opt) == 1 {
		shopt = opt[0]
	}

	if shopt.HashAlgo == nil {
		shopt.HashAlgo = sha256.New()
	}

	var buf bytes.Buffer

	// A nil input leaves buf empty, so it hashes like an empty stream.
	if !langext.IsNil(dat) {
		if err = binarize(&buf, reflect.ValueOf(dat), shopt); err != nil {
			return nil, err
		}
	}

	shopt.HashAlgo.Reset()
	// hash.Hash.Write is documented to never return an error.
	shopt.HashAlgo.Write(buf.Bytes())

	return shopt.HashAlgo.Sum(nil), nil
}

// writeBinarized encodes dat with encoding/binary (little-endian) and writes it
// to writer as a length-prefixed record: first the encoded size as a uint64,
// then the encoded bytes.
//
// dat must be a fixed-size value that binary.Write accepts. Encoding happens
// into a scratch buffer first, so if binary.Write rejects dat its error is
// returned and nothing at all is written to writer.
func writeBinarized(writer io.Writer, dat any) error {
	var payload bytes.Buffer
	if encErr := binary.Write(&payload, binary.LittleEndian, dat); encErr != nil {
		return encErr
	}

	var sizePrefix [8]byte
	binary.LittleEndian.PutUint64(sizePrefix[:], uint64(payload.Len()))
	if _, hdrErr := writer.Write(sizePrefix[:]); hdrErr != nil {
		return hdrErr
	}

	_, bodyErr := writer.Write(payload.Bytes())
	return bodyErr
}

// binarize writes a canonical, self-describing encoding of dat to writer.
//
// Layout (all integers little-endian):
//   - one byte holding dat.Kind();
//   - for a nil pointer, map, channel, slice or interface: a uint64 zero and
//     nothing more;
//   - otherwise the uint64 length and bytes of dat.Type().String(), followed
//     by a kind-specific payload (see the second switch below).
//
// Values of kind Invalid, Uintptr, Complex64, Complex128 and UnsafePointer are
// rejected with an error, as are non-nil channels and funcs unless
// opt.SkipChannel / opt.SkipFunc is set (then they contribute no payload).
//
// NOTE(review): there is no cycle detection; a pointer graph that refers back
// to itself recurses without bound.
func binarize(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
	kind := dat.Kind()

	if err := binary.Write(writer, binary.LittleEndian, uint8(kind)); err != nil {
		return err
	}

	switch kind {
	case reflect.Invalid:
		// An invalid Value has no type; dat.Type() below would panic.
		return errors.New("cannot binarize value of kind <Invalid>")
	case reflect.Pointer, reflect.Map, reflect.Chan, reflect.Slice, reflect.Interface:
		// reflect.Array must not be listed here: Value.IsNil panics for arrays.
		if dat.IsNil() {
			return binary.Write(writer, binary.LittleEndian, uint64(0))
		}
	}

	typeName := dat.Type().String()
	if err := binary.Write(writer, binary.LittleEndian, uint64(len(typeName))); err != nil {
		return err
	}
	if _, err := writer.Write([]byte(typeName)); err != nil {
		return err
	}

	// Numeric values are read through Int/Uint/Float and narrowed back to their
	// exact width rather than via dat.Interface(): Interface panics for values
	// reached through unexported struct fields, while the narrowed value
	// encodes to the same bytes.
	switch kind {
	case reflect.Bool:
		return writeBinarized(writer, dat.Bool())
	case reflect.Int, reflect.Int64:
		return writeBinarized(writer, dat.Int())
	case reflect.Int8:
		return writeBinarized(writer, int8(dat.Int()))
	case reflect.Int16:
		return writeBinarized(writer, int16(dat.Int()))
	case reflect.Int32:
		return writeBinarized(writer, int32(dat.Int()))
	case reflect.Uint, reflect.Uint64:
		return writeBinarized(writer, dat.Uint())
	case reflect.Uint8:
		return writeBinarized(writer, uint8(dat.Uint()))
	case reflect.Uint16:
		return writeBinarized(writer, uint16(dat.Uint()))
	case reflect.Uint32:
		return writeBinarized(writer, uint32(dat.Uint()))
	case reflect.Uintptr:
		return errors.New("cannot binarize value of kind <Uintptr>")
	case reflect.Float32:
		return writeBinarized(writer, float32(dat.Float()))
	case reflect.Float64:
		return writeBinarized(writer, dat.Float())
	case reflect.Complex64:
		return errors.New("cannot binarize value of kind <Complex64>")
	case reflect.Complex128:
		return errors.New("cannot binarize value of kind <Complex128>")
	case reflect.Slice, reflect.Array:
		return binarizeArrayOrSlice(writer, dat, opt)
	case reflect.Chan:
		if opt.SkipChannel {
			return nil
		}
		return errors.New("cannot binarize value of kind <Chan>")
	case reflect.Func:
		if opt.SkipFunc {
			return nil
		}
		return errors.New("cannot binarize value of kind <Func>")
	case reflect.Interface, reflect.Pointer:
		// Non-nil here (checked above): encode the value it refers to.
		return binarize(writer, dat.Elem(), opt)
	case reflect.Map:
		return binarizeMap(writer, dat, opt)
	case reflect.String:
		s := dat.String()
		if err := binary.Write(writer, binary.LittleEndian, uint64(len(s))); err != nil {
			return err
		}
		_, err := writer.Write([]byte(s))
		return err
	case reflect.Struct:
		return binarizeStruct(writer, dat, opt)
	case reflect.UnsafePointer:
		return errors.New("cannot binarize value of kind <UnsafePointer>")
	default:
		return errors.New("cannot binarize value of unknown kind <" + kind.String() + ">")
	}
}

// binarizeStruct encodes the struct value dat.
//
// It writes the struct's total field count (uint64, little-endian) and then,
// for every field selected by opt.Tag (all fields when opt.Tag is nil), the
// field name as a uint64 length plus bytes, followed by the binarize encoding
// of the field's value. Fields are visited in declaration order.
//
// NOTE(review): the count written is NumField(), which also counts fields
// skipped by the tag filter, so adding an untagged field still changes the
// output. Left as-is to keep existing hashes stable.
func binarizeStruct(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
	structType := dat.Type()
	fieldCount := structType.NumField()

	if err := binary.Write(writer, binary.LittleEndian, uint64(fieldCount)); err != nil {
		return err
	}

	for idx := 0; idx < fieldCount; idx++ {
		field := structType.Field(idx)

		if opt.Tag != nil {
			if _, tagged := field.Tag.Lookup(*opt.Tag); !tagged {
				continue
			}
		}

		if err := binary.Write(writer, binary.LittleEndian, uint64(len(field.Name))); err != nil {
			return err
		}
		if _, err := writer.Write([]byte(field.Name)); err != nil {
			return err
		}
		if err := binarize(writer, dat.Field(idx), opt); err != nil {
			return err
		}
	}

	return nil
}

// binarizeArrayOrSlice encodes an array or slice value as its element count
// (uint64, little-endian) followed by the binarize encoding of each element in
// index order. The first element error aborts encoding and is returned.
func binarizeArrayOrSlice(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
	length := dat.Len()

	if err := binary.Write(writer, binary.LittleEndian, uint64(length)); err != nil {
		return err
	}

	for pos := 0; pos < length; pos++ {
		if err := binarize(writer, dat.Index(pos), opt); err != nil {
			return err
		}
	}

	return nil
}

// binarizeMap encodes a map value as its entry count (uint64, little-endian)
// followed by one record per entry. Each record is the binarize encoding of the
// entry's key immediately followed by the binarize encoding of its value.
//
// Records are sorted bytewise before being written, so the output does not
// depend on Go's randomized map iteration order.
//
// Keys are part of each record; without them, maps with the same multiset of
// values but different keys (or swapped key/value pairings) would encode, and
// therefore hash, identically.
func binarizeMap(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
	if err := binary.Write(writer, binary.LittleEndian, uint64(dat.Len())); err != nil {
		return err
	}

	records := make([][]byte, 0, dat.Len())

	// MapRange (unlike MapKeys+MapIndex) also yields entries whose key is NaN,
	// for which MapIndex would return an invalid Value.
	iter := dat.MapRange()
	for iter.Next() {
		var record bytes.Buffer
		if err := binarize(&record, iter.Key(), opt); err != nil {
			return err
		}
		if err := binarize(&record, iter.Value(), opt); err != nil {
			return err
		}
		records = append(records, record.Bytes())
	}

	sort.Slice(records, func(i, j int) bool { return bytes.Compare(records[i], records[j]) < 0 })

	for _, rec := range records {
		if _, err := writer.Write(rec); err != nil {
			return err
		}
	}

	return nil
}