package sq

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"github.com/jmoiron/sqlx"
	"reflect"
	"strings"
)

// StructScanMode selects which StructScanner routine ScanSingle/ScanAll use
// to map a result row onto a struct value.
type StructScanMode string

// Recognised StructScanMode values; any other value is rejected by the
// scan helpers with an "unknown value for <mode>" error.
const (
	// SModeFast scans each row with StructScanner.StructScanBase.
	SModeFast StructScanMode = "FAST"
	// SModeExtended scans each row with StructScanner.StructScanExt.
	SModeExtended StructScanMode = "EXTENDED"
)

// StructScanSafety decides how strictly the scanner treats missing fields.
type StructScanSafety string

// Recognised StructScanSafety values; any other value is rejected by the
// scan helpers with an "unknown value for <sec>" error.
const (
	// Safe makes the scanner return an error for missing fields.
	Safe StructScanSafety = "SAFE"
	// Unsafe makes the scanner ignore missing fields.
	Unsafe StructScanSafety = "UNSAFE"
)

// InsertSingle inserts v as a single row into tableName and returns the
// sql.Result produced by q.Exec.
//
// Columns are taken from v's exported struct fields that carry a non-empty
// `db` tag other than "-". Each field value is bound as the named parameter
// ":_<column>" in the PP passed to q.Exec.
//
// Table and column names are wrapped in double quotes but are NOT escaped,
// so tableName and the `db` tags must come from trusted code, never from
// user input.
//
// v may be a struct or a non-nil pointer to a struct. Any other value
// returns an error instead of panicking inside reflect. A struct without
// any insertable field also returns an error instead of sending malformed
// SQL to the database.
func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string, v TData) (sql.Result, error) {
	rval := reflect.ValueOf(v)
	// Accept *T (or **T, ...) by walking down to the pointed-to struct.
	for rval.Kind() == reflect.Pointer {
		if rval.IsNil() {
			return nil, errors.New("value passed to InsertSingle is a nil pointer")
		}
		rval = rval.Elem()
	}
	if rval.Kind() != reflect.Struct {
		return nil, fmt.Errorf("value passed to InsertSingle must be a struct or pointer to struct, got %s", rval.Kind())
	}
	rtyp := rval.Type()

	columns := make([]string, 0, rtyp.NumField())
	params := make([]string, 0, rtyp.NumField())
	pp := PP{}

	for i := 0; i < rtyp.NumField(); i++ {
		rsfield := rtyp.Field(i)
		if !rsfield.IsExported() {
			continue
		}

		columnName := rsfield.Tag.Get("db")
		if columnName == "" || columnName == "-" {
			continue
		}

		paramkey := "_" + columnName

		columns = append(columns, "\""+columnName+"\"")
		params = append(params, ":"+paramkey)
		pp[paramkey] = rval.Field(i).Interface()
	}

	if len(columns) == 0 {
		// "INSERT INTO t () VALUES ()" is not valid with double-quoted
		// identifiers; fail early with a descriptive error instead.
		return nil, fmt.Errorf("no exported db-tagged fields to insert in %s", rtyp)
	}

	sqlstr := fmt.Sprintf("INSERT"+" INTO \"%s\" (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), strings.Join(params, ", "))

	sqlr, err := q.Exec(ctx, sqlstr, pp)
	if err != nil {
		return nil, err
	}

	return sqlr, nil
}

// QuerySingle runs the statement sql with parameters pp on q and scans
// exactly one result row into a TData via ScanSingle.
//
// close=true is passed to ScanSingle, making it responsible for closing the
// rows. On any error the zero TData is returned together with the error.
func QuerySingle[TData any](ctx context.Context, q Queryable, sql string, pp PP, mode StructScanMode, sec StructScanSafety) (TData, error) {
	var zero TData

	rows, queryErr := q.Query(ctx, sql, pp)
	if queryErr != nil {
		return zero, queryErr
	}

	single, scanErr := ScanSingle[TData](rows, mode, sec, true)
	if scanErr != nil {
		return zero, scanErr
	}
	return single, nil
}

// QueryAll runs the statement sql with parameters pp on q and scans every
// result row into a TData via ScanAll.
//
// close=true is passed to ScanAll, making it responsible for closing the
// rows. On any error a nil slice is returned together with the error.
func QueryAll[TData any](ctx context.Context, q Queryable, sql string, pp PP, mode StructScanMode, sec StructScanSafety) ([]TData, error) {
	rows, queryErr := q.Query(ctx, sql, pp)
	if queryErr != nil {
		return nil, queryErr
	}

	all, scanErr := ScanAll[TData](rows, mode, sec, true)
	if scanErr != nil {
		return nil, scanErr
	}
	return all, nil
}

// ScanSingle reads exactly one row from rows and maps it onto a TData.
//
// mode picks the per-row routine (SModeFast -> StructScanBase,
// SModeExtended -> StructScanExt). sec picks whether missing fields are an
// error (Safe) or ignored (Unsafe). Unknown mode or sec values are rejected
// before any row is consumed.
//
// It returns:
//   - sql.ErrNoRows if the result set is empty,
//   - an error if the result set has more than one row,
//   - any error from the scanner, from row iteration (rows.Err), or, when
//     nothing else failed, from rows.Close.
//
// When close is true, rows is closed on every return path, including error
// paths, so a failed scan does not hold on to the connection. When close is
// false, the caller keeps ownership of rows. On error the zero TData is
// returned.
func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (result TData, err error) {
	if close {
		defer func() {
			// Only surface the Close error if nothing more specific failed.
			if cerr := rows.Close(); cerr != nil && err == nil {
				result, err = *new(TData), cerr
			}
		}()
	}

	var ignoreMissing bool
	switch sec {
	case Safe:
		ignoreMissing = false
	case Unsafe:
		ignoreMissing = true
	default:
		return *new(TData), errors.New("unknown value for <sec>")
	}
	if mode != SModeFast && mode != SModeExtended {
		return *new(TData), errors.New("unknown value for <mode>")
	}

	if !rows.Next() {
		// Next also reports false when iteration failed; return that error
		// instead of masking it as "no rows".
		if iterErr := rows.Err(); iterErr != nil {
			return *new(TData), iterErr
		}
		return *new(TData), sql.ErrNoRows
	}

	strscan := NewStructScanner(rows, ignoreMissing)
	// The scanner is primed with a throwaway TData before the real scan.
	// Presumably Start inspects the destination type and columns — see
	// StructScanner.
	var probe TData
	if startErr := strscan.Start(&probe); startErr != nil {
		return *new(TData), startErr
	}

	var data TData
	switch mode {
	case SModeFast:
		if scanErr := strscan.StructScanBase(&data); scanErr != nil {
			return *new(TData), scanErr
		}
	case SModeExtended:
		if scanErr := strscan.StructScanExt(&data); scanErr != nil {
			return *new(TData), scanErr
		}
	}

	if rows.Next() {
		return *new(TData), errors.New("sql returned more than one row")
	}
	if iterErr := rows.Err(); iterErr != nil {
		return *new(TData), iterErr
	}

	return data, nil
}

// ScanAll reads every remaining row from rows and maps each onto a TData.
//
// mode picks the per-row routine (SModeFast -> StructScanBase,
// SModeExtended -> StructScanExt). sec picks whether missing fields are an
// error (Safe) or ignored (Unsafe). Unknown mode or sec values are rejected
// up front, even when the result set is empty.
//
// On success the returned slice is non-nil; it is empty when there are no
// rows. On error a nil slice is returned with the first error from the
// scanner, from row iteration (rows.Err), or, when nothing else failed,
// from rows.Close.
//
// When close is true, rows is closed on every return path, including error
// paths. When close is false, the caller keeps ownership of rows.
func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (result []TData, err error) {
	if close {
		defer func() {
			// Only surface the Close error if nothing more specific failed.
			if cerr := rows.Close(); cerr != nil && err == nil {
				result, err = nil, cerr
			}
		}()
	}

	var ignoreMissing bool
	switch sec {
	case Safe:
		ignoreMissing = false
	case Unsafe:
		ignoreMissing = true
	default:
		return nil, errors.New("unknown value for <sec>")
	}
	if mode != SModeFast && mode != SModeExtended {
		return nil, errors.New("unknown value for <mode>")
	}

	strscan := NewStructScanner(rows, ignoreMissing)
	// The scanner is primed with a throwaway TData before any row is read.
	var probe TData
	if startErr := strscan.Start(&probe); startErr != nil {
		return nil, startErr
	}

	// Non-nil so an empty result stays distinguishable from an error (e.g.
	// it encodes as [] rather than null in JSON).
	res := make([]TData, 0)
	for rows.Next() {
		var data TData
		switch mode {
		case SModeFast:
			if scanErr := strscan.StructScanBase(&data); scanErr != nil {
				return nil, scanErr
			}
		case SModeExtended:
			if scanErr := strscan.StructScanExt(&data); scanErr != nil {
				return nil, scanErr
			}
		}
		res = append(res, data)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return res, nil
}