package langext

import (
	"bytes"
	"errors"
	"math/big"
)

// Base58 implementation adapted from btcsuite: https://github.com/btcsuite/

// B58Encoding is a base58 codec for one particular 58-character alphabet.
// Use one of the predefined Base58*Encoding values; the zero value is not usable.
type B58Encoding struct {
	bigRadix     [11]*big.Int // bigRadix[i] == 58^i for i >= 1; bigRadix[0] is 0 and unused
	bigRadix10   *big.Int     // 58^10, the chunk size; same pointer as bigRadix[10]
	alphabet     string       // digit value -> character
	alphabetIdx0 byte         // zero digit (alphabet[0]); one per leading zero byte
	b58          [256]byte    // character -> digit value, or 255 if not in the alphabet
}

var (
	// Base58DefaultEncoding uses the same alphabet as Base58BitcoinEncoding.
	Base58DefaultEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")

	// Base58FlickrEncoding uses the Flickr variant of the alphabet (lowercase
	// letters before uppercase).
	Base58FlickrEncoding = newBase58Encoding("123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ")

	// Base58RippleEncoding uses the Ripple alphabet, whose zero digit is 'r'.
	Base58RippleEncoding = newBase58Encoding("rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz")

	// Base58BitcoinEncoding uses the Bitcoin alphabet.
	Base58BitcoinEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
)

// newBase58Encoding returns a B58Encoding for alphabet. The alphabet must consist
// of exactly 58 distinct bytes listed in digit order, so alphabet[0] is digit 0.
//
// It panics if alphabet is not 58 bytes long. It is only meant to be called with
// constant alphabets during package initialization, so a bad alphabet is a
// programming error. Encode would otherwise index out of range later.
func newBase58Encoding(alphabet string) *B58Encoding {
	if len(alphabet) != 58 {
		panic("langext: base58 alphabet must be exactly 58 bytes long")
	}

	// The zero digit encodes each leading zero byte, and Decode counts it to
	// restore them. It must come from the alphabet itself. '1' is the zero
	// digit only for the Bitcoin/Flickr alphabets; for Ripple it is 'r'.
	enc := &B58Encoding{
		alphabet:     alphabet,
		alphabetIdx0: alphabet[0],
	}

	// bigRadix[i] = 58^i for i in 1..10. Index 0 holds 0 and is never used,
	// because Decode only multiplies by the power for a non-empty chunk.
	enc.bigRadix[0] = big.NewInt(0)
	pow := int64(1)
	for i := 1; i < len(enc.bigRadix); i++ {
		pow *= 58
		enc.bigRadix[i] = big.NewInt(pow)
	}
	enc.bigRadix10 = enc.bigRadix[10]

	// Reverse lookup: b58[c] is the digit value of byte c, or 255 when c is not
	// in the alphabet. IndexByte picks the first occurrence of a byte. The
	// conversion is hoisted so it happens once, not once per table entry.
	alpha := []byte(alphabet)
	for c := range enc.b58 {
		if idx := bytes.IndexByte(alpha, byte(c)); idx >= 0 {
			enc.b58[c] = byte(idx)
		} else {
			enc.b58[c] = 255
		}
	}

	return enc
}

// EncodeString base58-encodes the raw bytes of src and returns the encoded text
// as a string. If Encode fails, it returns "" together with that error.
func (enc *B58Encoding) EncodeString(src string) (string, error) {
	raw := []byte(src)
	encoded, err := enc.Encode(raw)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}

// Encode returns the base58 encoding of src using enc's alphabet.
//
// Each leading zero byte of src becomes one zero digit (enc.alphabetIdx0) at
// the start of the output. The remaining bytes are interpreted as a big-endian
// unsigned integer and written in base 58. The returned error is always nil.
func (enc *B58Encoding) Encode(src []byte) ([]byte, error) {
	num := new(big.Int).SetBytes(src)

	// A byte carries log(256)/log(58) ≈ 1.3657 base58 digits, which gives a
	// capacity hint for the significant digits.
	out := make([]byte, 0, int(float64(len(src))*1.365658237309761)+1)

	// Digits are produced least-significant first and reversed at the end.
	// Dividing by 58^10 per step keeps most of the work in int64 arithmetic
	// instead of doing one big.Int division per output digit.
	rem := new(big.Int)
	for num.Sign() > 0 {
		num.DivMod(num, enc.bigRadix10, rem)
		chunk := rem.Int64()
		last := num.Sign() == 0
		for i := 0; i < 10; i++ {
			// In the most significant chunk, stop once nothing is left so no
			// spurious leading zero digits are emitted. Inner chunks always
			// contribute exactly ten digits.
			if last && chunk == 0 {
				break
			}
			out = append(out, enc.alphabet[chunk%58])
			chunk /= 58
		}
	}

	// Leading zero bytes vanish in the integer conversion; restore them.
	for _, b := range src {
		if b != 0 {
			break
		}
		out = append(out, enc.alphabetIdx0)
	}

	// Put the most significant digit first.
	for lo, hi := 0, len(out)-1; lo < hi; lo, hi = lo+1, hi-1 {
		out[lo], out[hi] = out[hi], out[lo]
	}

	return out, nil
}

// DecodeString decodes the base58 text in src and returns the decoded bytes
// converted to a string (the bytes are not checked for valid UTF-8). If Decode
// fails, it returns "" together with that error.
func (enc *B58Encoding) DecodeString(src string) (string, error) {
	raw := []byte(src)
	decoded, err := enc.Decode(raw)
	if err != nil {
		return "", err
	}
	return string(decoded), nil
}

// Decode parses src as base58 text in enc's alphabet and returns the bytes it
// represents.
//
// Each leading zero digit (enc.alphabetIdx0) in src yields one leading zero byte
// in the result, and an empty src decodes to an empty slice. If src contains a
// byte outside the alphabet, Decode returns an empty slice and an error.
func (enc *B58Encoding) Decode(src []byte) ([]byte, error) {
	answer := big.NewInt(0)
	scratch := new(big.Int)

	// Consume up to 10 digits at a time. Since 58^10 < 2^64, a chunk's value
	// fits in a uint64, and only one big.Int multiply-add is needed per chunk.
	for t := src; len(t) > 0; {
		n := len(t)
		if n > 10 {
			n = 10
		}

		total := uint64(0)
		for _, c := range t[:n] {
			// b58 has an entry for every possible byte value, so no range check
			// is needed. 255 marks bytes that are not in the alphabet.
			digit := enc.b58[c]
			if digit == 255 {
				return []byte{}, errors.New("invalid char in input")
			}
			total = total*58 + uint64(digit)
		}

		answer.Mul(answer, enc.bigRadix[n])
		scratch.SetUint64(total)
		answer.Add(answer, scratch)

		t = t[n:]
	}

	// Leading zero digits add nothing to the integer value, so restore them
	// explicitly as leading zero bytes.
	numZeros := 0
	for numZeros < len(src) && src[numZeros] == enc.alphabetIdx0 {
		numZeros++
	}

	magnitude := answer.Bytes()
	val := make([]byte, numZeros+len(magnitude))
	copy(val[numZeros:], magnitude)

	return val, nil
}