package cryptext

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base32"
	"encoding/json"
	"errors"
	"golang.org/x/crypto/scrypt"
	"io"
)

// Reference: https://stackoverflow.com/a/18819040/1761622

// aesPayload is the JSON envelope written by EncryptAESSimple and read back
// by DecryptAESSimple. The marshaled JSON is base32-encoded (unpadded) to
// form the final encrypted string. The []byte fields are marshaled by
// encoding/json as base64 strings.
type aesPayload struct {
	Salt    []byte `json:"s"` // random salt fed to scrypt for key derivation
	IV      []byte `json:"i"` // AES-CFB initialization vector (aes.BlockSize bytes when produced by EncryptAESSimple)
	Data    []byte `json:"d"` // AES-CFB ciphertext of sha256(plaintext) || plaintext
	Rounds  int    `json:"r"` // scrypt cost parameter N used to derive the key
	Version uint   `json:"v"` // payload format version; DecryptAESSimple only accepts 1
}

// EncryptAESSimple encrypts data under a key derived from password and
// returns the result as an unpadded base32 string that DecryptAESSimple can
// decode.
//
// The 32-byte (AES-256) key is derived with scrypt (N = rounds, r = 8, p = 1)
// from a fresh random salt. The plaintext is prefixed with its SHA-256
// checksum and encrypted with AES in CFB mode under a fresh random IV. Salt,
// IV, ciphertext, rounds and the format version are wrapped in a JSON
// envelope (aesPayload), which is then base32-encoded.
//
// rounds is passed to scrypt as its cost parameter N. scrypt returns an error
// unless it is a power of two greater than 1, and that error is returned
// unchanged.
//
// NOTE: the embedded checksum is an unkeyed hash, not a MAC, so this scheme is
// not authenticated encryption. The checksum only lets the decrypter detect a
// wrong password or accidental corruption. crypto/cipher also marks CFB as
// deprecated in recent Go releases. Moving to an AEAD such as GCM would need a
// new payload Version.
func EncryptAESSimple(password []byte, data []byte, rounds int) (string, error) {
	// 128-bit salt. The salt is stored in the payload and read back on
	// decryption, so its length is not fixed by the format.
	salt := make([]byte, 16)
	if _, err := io.ReadFull(rand.Reader, salt); err != nil {
		return "", err
	}

	key, err := scrypt.Key(password, salt, rounds, 8, 1, 32)
	if err != nil {
		return "", err
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}

	iv := make([]byte, aes.BlockSize)
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return "", err
	}

	// Plaintext layout: sha256(data) (32 bytes) || data.
	checksum := sha256.Sum256(data)
	plaintext := make([]byte, 0, len(checksum)+len(data))
	plaintext = append(plaintext, checksum[:]...)
	plaintext = append(plaintext, data...)

	ciphertext := make([]byte, len(plaintext))
	cipher.NewCFBEncrypter(block, iv).XORKeyStream(ciphertext, plaintext)

	jbin, err := json.Marshal(aesPayload{
		Salt:    salt,
		IV:      iv,
		Data:    ciphertext,
		Version: 1,
		Rounds:  rounds,
	})
	if err != nil {
		return "", err
	}

	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(jbin), nil
}

// DecryptAESSimple reverses EncryptAESSimple. It proceeds in these steps:
//   - decode the unpadded base32 string encText;
//   - parse the JSON envelope (aesPayload);
//   - re-derive the key from password with scrypt, using the stored salt and rounds;
//   - decrypt the AES-CFB ciphertext;
//   - verify the embedded SHA-256 checksum.
//
// It returns the original plaintext. It returns an error in these cases:
//   - encText is not valid base32 or JSON;
//   - the payload version is not 1;
//   - the IV has the wrong size;
//   - the payload is too short to hold a checksum;
//   - scrypt rejects the stored parameters;
//   - the checksum does not match, which is what a wrong password produces.
//
// NOTE(review): pl.Rounds comes straight from encText and is passed to scrypt
// without an upper bound. scrypt's memory use grows linearly with it, so
// consider capping it if encText can come from an untrusted source.
func DecryptAESSimple(password []byte, encText string) ([]byte, error) {
	jbin, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encText)
	if err != nil {
		return nil, err
	}

	var pl aesPayload
	if err := json.Unmarshal(jbin, &pl); err != nil {
		return nil, err
	}

	if pl.Version != 1 {
		return nil, errors.New("unsupported version")
	}

	// cipher.NewCFBDecrypter panics if the IV length differs from the block
	// size. The IV comes from the input, so reject a bad one explicitly.
	if len(pl.IV) != aes.BlockSize {
		return nil, errors.New("invalid iv size")
	}

	// The decrypted data must at least contain the 32-byte checksum. Reject
	// short payloads before paying for the scrypt key derivation.
	if len(pl.Data) < sha256.Size {
		return nil, errors.New("payload too small")
	}

	key, err := scrypt.Key(password, pl.Salt, pl.Rounds, 8, 1, 32)
	if err != nil {
		return nil, err
	}

	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	dest := make([]byte, len(pl.Data))
	cipher.NewCFBDecrypter(block, pl.IV).XORKeyStream(dest, pl.Data)

	// Layout written by EncryptAESSimple: sha256(data) || data.
	chck, data := dest[:sha256.Size], dest[sha256.Size:]

	sum := sha256.Sum256(data)
	if !bytes.Equal(chck, sum[:]) {
		return nil, errors.New("checksum mismatch")
	}

	return data, nil
}