package cryptext

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"gogs.mikescher.com/BlackForestBytes/goext/langext"
	"gogs.mikescher.com/BlackForestBytes/goext/totpext"
	"golang.org/x/crypto/bcrypt"
	"strconv"
	"strings"
)

// LatestPassHashVersion is the newest PassHash format; HashPassword produces it
// and Upgrade treats hashes of this version as already current.
const LatestPassHashVersion = 4

// PassHash is a serialized, versioned password hash. Fields are separated by
// '|' and the first field is always the decimal version number:
//
//   - v0: "0|<plaintext>" (no hashing at all)
//   - v1: "1|<b64(sha256(plaintext))>"
//   - v2: "2|<b64(seed)>|<b64(sha256(seed+plaintext))>"
//   - v3: "3|<b64(seed)>|<b64(sha256(seed+plaintext))>|<totp>"
//   - v4: "4|<bcrypt(plaintext)>|<totp>"
//
// Here b64 is unpadded standard base64, and <totp> is either "0" (no TOTP
// configured) or the hex-encoded TOTP secret.
type PassHash string

// Valid reports whether ph parses as a PassHash of a known version.
func (ph PassHash) Valid() bool {
	if _, _, _, _, _, ok := ph.Data(); ok {
		return true
	}
	return false
}

// HasTOTP reports whether ph carries a TOTP secret. Only v3 and v4 hashes can
// carry one; unparseable hashes report false.
func (ph PassHash) HasTOTP() bool {
	if _, _, _, withTOTP, _, _ := ph.Data(); withTOTP {
		return true
	}
	return false
}

// Data parses the serialized PassHash into its components.
//
// Returns:
//   - _version:    the format version (0-4), or -1 if ph is invalid
//   - _seed:       the decoded random seed for v2/v3, nil otherwise
//   - _payload:    v0 the plaintext, v1-v3 the decoded sha256 digest,
//     v4 the bcrypt hash string as bytes
//   - _totp:       whether a TOTP secret is attached (v3/v4 only)
//   - _totpsecret: the decoded TOTP secret; an empty non-nil slice for v3/v4
//     without TOTP, nil for v0-v2
//   - _valid:      false if ph is malformed (wrong field count, bad base64/hex,
//     unknown or unparsable version)
func (ph PassHash) Data() (_version int, _seed []byte, _payload []byte, _totp bool, _totpsecret []byte, _valid bool) {
	// invalid is the canonical "could not parse" result.
	invalid := func() (int, []byte, []byte, bool, []byte, bool) {
		return -1, nil, nil, false, nil, false
	}

	// parseTOTP decodes the trailing TOTP field of v3/v4: "0" means "no TOTP",
	// anything else must be a hex-encoded secret. ok is false on bad hex.
	parseTOTP := func(field string) (bool, []byte, bool) {
		if field == "0" {
			return false, make([]byte, 0), true
		}
		secret, err := hex.DecodeString(field)
		if err != nil {
			return false, nil, false
		}
		return true, secret, true
	}

	raw := string(ph)

	// Every version has at least "<version>|<something>".
	head := strings.SplitN(raw, "|", 2)
	version, err := strconv.ParseInt(head[0], 10, 32)
	if err != nil || len(head) != 2 {
		return invalid()
	}

	// v0 stores the plaintext verbatim, and the plaintext may itself contain
	// '|', so the payload is everything after the version prefix.
	if version == 0 {
		return int(version), nil, []byte(head[1]), false, nil, true
	}

	split := strings.Split(raw, "|")

	switch version {
	case 1:
		if len(split) != 2 {
			return invalid()
		}
		payload, err := base64.RawStdEncoding.DecodeString(split[1])
		if err != nil {
			return invalid()
		}
		return int(version), nil, payload, false, nil, true

	case 2:
		if len(split) != 3 {
			return invalid()
		}
		seed, err := base64.RawStdEncoding.DecodeString(split[1])
		if err != nil {
			return invalid()
		}
		payload, err := base64.RawStdEncoding.DecodeString(split[2])
		if err != nil {
			return invalid()
		}
		return int(version), seed, payload, false, nil, true

	case 3:
		if len(split) != 4 {
			return invalid()
		}
		seed, err := base64.RawStdEncoding.DecodeString(split[1])
		if err != nil {
			return invalid()
		}
		payload, err := base64.RawStdEncoding.DecodeString(split[2])
		if err != nil {
			return invalid()
		}
		totp, totpsecret, ok := parseTOTP(split[3])
		if !ok {
			return invalid()
		}
		return int(version), seed, payload, totp, totpsecret, true

	case 4:
		if len(split) != 3 {
			return invalid()
		}
		// The TOTP field is the third (last) field in v4.
		totp, totpsecret, ok := parseTOTP(split[2])
		if !ok {
			return invalid()
		}
		return int(version), nil, []byte(split[1]), totp, totpsecret, true
	}

	return invalid()
}

// Verify reports whether plainpass matches ph. When ph has a TOTP secret
// attached, *totp must also be a valid code for that secret.
//
// It returns false when ph cannot be parsed, when TOTP is required but totp is
// nil, or when the version is unknown. The TOTP code is only checked after the
// password itself has matched.
//
// NOTE(review): versions 0-3 compare via langext.ArrEqualsExact; if that is not
// constant-time, consider crypto/subtle.ConstantTimeCompare to avoid timing leaks.
func (ph PassHash) Verify(plainpass string, totp *string) bool {
	version, seed, payload, hastotp, totpsecret, valid := ph.Data()
	if !valid || (hastotp && totp == nil) {
		return false
	}

	switch version {
	case 0:
		// v0 stores the plaintext itself.
		return langext.ArrEqualsExact([]byte(plainpass), payload)
	case 1:
		return langext.ArrEqualsExact(hash256(plainpass), payload)
	case 2:
		return langext.ArrEqualsExact(hash256Seeded(plainpass, seed), payload)
	case 3:
		passwordOK := langext.ArrEqualsExact(hash256Seeded(plainpass, seed), payload)
		if !hastotp {
			return passwordOK
		}
		return passwordOK && totpext.Validate(totpsecret, *totp)
	case 4:
		passwordOK := bcrypt.CompareHashAndPassword(payload, []byte(plainpass)) == nil
		if !hastotp {
			return passwordOK
		}
		return passwordOK && totpext.Validate(totpsecret, *totp)
	}

	return false
}

// NeedsPasswordUpgrade reports whether ph is a valid hash whose version is
// older than LatestPassHashVersion. Callers can re-hash such a hash with
// Upgrade once the plaintext password is available. Invalid hashes report false.
func (ph PassHash) NeedsPasswordUpgrade() bool {
	version, _, _, _, _, ok := ph.Data()
	if !ok {
		return false
	}
	return version < LatestPassHashVersion
}

// Upgrade re-hashes plainpass with the latest format (via HashPassword),
// carrying over the TOTP secret if ph has one. A hash already at
// LatestPassHashVersion is returned unchanged.
//
// NOTE: Upgrade does not check that plainpass matches ph; call Verify first.
//
// It returns an error if ph cannot be parsed.
func (ph PassHash) Upgrade(plainpass string) (PassHash, error) {
	version, _, _, hastotp, totpsecret, valid := ph.Data()
	if !valid {
		// The stored hash is malformed, not the password; this matches the
		// wording used by the other PassHash methods.
		return "", errors.New("invalid PassHash")
	}
	if version == LatestPassHashVersion {
		return ph, nil
	}

	// Without TOTP, Data may return an empty non-nil secret (v3). Pass nil
	// explicitly so the new hash records "no TOTP" instead of an empty secret.
	if !hastotp {
		return HashPassword(plainpass, nil)
	}
	return HashPassword(plainpass, totpsecret)
}

// ClearTOTP returns a copy of ph with any TOTP secret removed.
//
// Versions 0-2 have no TOTP field and are returned unchanged. For v3 and v4,
// the TOTP field (the 4th and 3rd field respectively) is replaced by "0".
// It returns an error if ph cannot be parsed.
func (ph PassHash) ClearTOTP() (PassHash, error) {
	version, _, _, _, _, valid := ph.Data()
	if !valid {
		return "", errors.New("invalid PassHash")
	}

	var totpField int
	switch version {
	case 0, 1, 2:
		return ph, nil
	case 3:
		totpField = 3
	case 4:
		totpField = 2
	default:
		return "", errors.New("unknown version")
	}

	fields := strings.Split(string(ph), "|")
	fields[totpField] = "0"
	return PassHash(strings.Join(fields, "|")), nil
}

// WithTOTP returns a copy of ph whose TOTP field is set to hex(totpSecret).
//
// Only v3 and v4 hashes have a TOTP field. Versions 0-2 return an error and
// must be upgraded first. It also returns an error if ph cannot be parsed.
//
// A nil or empty totpSecret is stored as "0" (no TOTP). This matches how
// HashPasswordV3/HashPasswordV4 encode a nil secret. Writing "" instead would
// produce a hash that Data reports as TOTP-enabled with an empty secret.
func (ph PassHash) WithTOTP(totpSecret []byte) (PassHash, error) {
	version, _, _, _, _, valid := ph.Data()
	if !valid {
		return "", errors.New("invalid PassHash")
	}

	var totpField int
	switch version {
	case 0, 1, 2:
		return "", errors.New("version does not support totp, needs upgrade")
	case 3:
		totpField = 3
	case 4:
		totpField = 2
	default:
		return "", errors.New("unknown version")
	}

	encoded := "0"
	if len(totpSecret) > 0 {
		encoded = hex.EncodeToString(totpSecret)
	}

	fields := strings.Split(string(ph), "|")
	fields[totpField] = encoded
	return PassHash(strings.Join(fields, "|")), nil
}

// Change re-hashes newPlainPass using the same version as ph (it does not
// upgrade to LatestPassHashVersion) and keeps ph's TOTP secret, if any.
// It returns an error if ph cannot be parsed.
func (ph PassHash) Change(newPlainPass string) (PassHash, error) {
	version, _, _, hastotp, totpsecret, valid := ph.Data()
	if !valid {
		return "", errors.New("invalid PassHash")
	}

	// Forward the secret only when TOTP is enabled; otherwise pass nil so the
	// new hash is written with the "0" (no TOTP) marker.
	var keptSecret []byte
	if hastotp {
		keptSecret = totpsecret
	}

	switch version {
	case 0:
		return HashPasswordV0(newPlainPass)
	case 1:
		return HashPasswordV1(newPlainPass)
	case 2:
		return HashPasswordV2(newPlainPass)
	case 3:
		return HashPasswordV3(newPlainPass, keptSecret)
	case 4:
		return HashPasswordV4(newPlainPass, keptSecret)
	default:
		return "", errors.New("unknown version")
	}
}

// String returns the serialized form of ph (implements fmt.Stringer).
//
// Beware: the serialized form contains the TOTP secret (v3/v4) and, for v0,
// the plaintext password, so avoid logging it.
func (ph PassHash) String() string {
	serialized := string(ph)
	return serialized
}

// HashPassword hashes plainpass in the current default format (v4, bcrypt),
// attaching totpSecret as the TOTP secret. Pass nil for no TOTP.
func HashPassword(plainpass string, totpSecret []byte) (PassHash, error) {
	hashed, err := HashPasswordV4(plainpass, totpSecret)
	return hashed, err
}

// HashPasswordV4 hashes plainpass with bcrypt and returns a v4 PassHash of
// the form "4|<bcrypt hash>|<totp>". Here <totp> is the hex-encoded totpSecret,
// or "0" when totpSecret is nil or empty.
//
// Errors from bcrypt.GenerateFromPassword are returned unchanged.
func HashPasswordV4(plainpass string, totpSecret []byte) (PassHash, error) {
	// An empty secret must also become "0". Encoding it as "" would make Data
	// report TOTP as enabled with an empty secret.
	strtotp := "0"
	if len(totpSecret) > 0 {
		strtotp = hex.EncodeToString(totpSecret)
	}

	// Use DefaultCost rather than MinCost (4), which is far too cheap to slow
	// down offline guessing. bcrypt embeds the cost in each hash, so hashes
	// created earlier with MinCost still verify.
	payload, err := bcrypt.GenerateFromPassword([]byte(plainpass), bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}

	return PassHash(fmt.Sprintf("4|%s|%s", string(payload), strtotp)), nil
}

// HashPasswordV3 returns a v3 PassHash
// "3|<b64(seed)>|<b64(sha256(seed+plainpass))>|<totp>" using a fresh random
// seed. Here <totp> is the hex-encoded totpSecret, or "0" when totpSecret is
// nil or empty.
//
// It returns an error only if generating the random seed fails.
func HashPasswordV3(plainpass string, totpSecret []byte) (PassHash, error) {
	// An empty secret must also become "0". Encoding it as "" would make Data
	// report TOTP as enabled with an empty secret.
	strtotp := "0"
	if len(totpSecret) > 0 {
		strtotp = hex.EncodeToString(totpSecret)
	}

	seed, err := newSeed()
	if err != nil {
		return "", err
	}

	checksum := hash256Seeded(plainpass, seed)

	return PassHash(fmt.Sprintf("3|%s|%s|%s",
		base64.RawStdEncoding.EncodeToString(seed),
		base64.RawStdEncoding.EncodeToString(checksum),
		strtotp)), nil
}

// HashPasswordV2 returns a v2 PassHash
// "2|<b64(seed)>|<b64(sha256(seed+plainpass))>" using a fresh random seed
// from newSeed. It returns an error only if seed generation fails.
func HashPasswordV2(plainpass string) (PassHash, error) {
	seed, err := newSeed()
	if err != nil {
		return "", err
	}

	enc := base64.RawStdEncoding
	encodedSeed := enc.EncodeToString(seed)
	encodedSum := enc.EncodeToString(hash256Seeded(plainpass, seed))

	return PassHash(fmt.Sprintf("2|%s|%s", encodedSeed, encodedSum)), nil
}

// HashPasswordV1 returns a v1 PassHash "1|<b64(sha256(plainpass))>".
// The digest is unsalted; new hashes should use HashPassword instead.
// It never returns an error.
func HashPasswordV1(plainpass string) (PassHash, error) {
	digest := hash256(plainpass)
	encoded := base64.RawStdEncoding.EncodeToString(digest)
	return PassHash(fmt.Sprintf("1|%s", encoded)), nil
}

// HashPasswordV0 returns a v0 PassHash "0|<plainpass>". The password is stored
// verbatim with no hashing at all, and the function never returns an error.
func HashPasswordV0(plainpass string) (PassHash, error) {
	encoded := fmt.Sprintf("0|%s", plainpass)
	return PassHash(encoded), nil
}

// hash256 returns the 32-byte SHA-256 digest of s.
func hash256(s string) []byte {
	digest := sha256.Sum256([]byte(s))
	return digest[:]
}

// hash256Seeded returns SHA-256(seed || s): the 32-byte digest of the seed
// bytes immediately followed by the bytes of s.
func hash256Seeded(s string, seed []byte) []byte {
	input := make([]byte, 0, len(seed)+len(s))
	input = append(input, seed...)
	input = append(input, s...)
	digest := sha256.Sum256(input)
	return digest[:]
}

// newSeed returns 32 bytes read from crypto/rand, or the error rand.Read
// reported.
func newSeed() ([]byte, error) {
	seed := make([]byte, 32)
	if _, err := rand.Read(seed); err != nil {
		return nil, err
	}
	return seed, nil
}