package sq

import (
	"database/sql"
	"errors"
	"github.com/jmoiron/sqlx"
)

// StructScanMode selects which StructScanner routine ScanSingle and ScanAll
// use to decode each row.
type StructScanMode string

const (
	// SModeFast decodes rows with StructScanner.StructScanBase.
	SModeFast = StructScanMode("FAST")
	// SModeExtended decodes rows with StructScanner.StructScanExt.
	SModeExtended = StructScanMode("EXTENDED")
)

// StructScanSafety chooses how ScanSingle and ScanAll construct their
// StructScanner. Safe passes false and Unsafe passes true as the second
// argument of NewStructScanner.
// NOTE(review): the exact meaning of that flag lives in NewStructScanner —
// presumably it mirrors sqlx's "unsafe" (tolerate unmapped columns); confirm.
type StructScanSafety string

const (
	// Safe builds the scanner with NewStructScanner(rows, false).
	Safe = StructScanSafety("SAFE")
	// Unsafe builds the scanner with NewStructScanner(rows, true).
	Unsafe = StructScanSafety("UNSAFE")
)

// ScanSingle reads exactly one row from rows and decodes it into a TData.
//
// sec chooses how the StructScanner is built (Safe or Unsafe) and mode chooses
// the decode routine (SModeFast -> StructScanBase, SModeExtended ->
// StructScanExt). Both are validated before any row is read.
//
// When close is true, rows is closed before ScanSingle returns on every path,
// including error paths. A Close failure is reported only when no other error
// occurred, so it never hides the primary cause.
//
// Errors returned:
//   - "unknown value for <sec>" / "unknown value for <mode>" for invalid options
//   - the error from rows.Err() if iteration failed
//   - sql.ErrNoRows if the result set is empty
//   - any error from StructScanner.Start or the selected decode routine
//   - "sql returned more than one row" if a second row exists
//   - the error from rows.Close() (only when close is true and nothing else failed)
//
// On any error the returned TData is its zero value.
func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (result TData, err error) {
	if close {
		defer func() {
			if closeErr := rows.Close(); closeErr != nil && err == nil {
				result, err = *new(TData), closeErr
			}
		}()
	}

	var unsafeMode bool
	switch sec {
	case Safe:
		unsafeMode = false
	case Unsafe:
		unsafeMode = true
	default:
		return *new(TData), errors.New("unknown value for <sec>")
	}
	if mode != SModeFast && mode != SModeExtended {
		return *new(TData), errors.New("unknown value for <mode>")
	}

	if !rows.Next() {
		// Next also returns false when iteration failed; surface that error
		// instead of reporting an empty result set.
		if iterErr := rows.Err(); iterErr != nil {
			return *new(TData), iterErr
		}
		return *new(TData), sql.ErrNoRows
	}

	strscan := NewStructScanner(rows, unsafeMode)
	// Start is given its own throwaway value, as in the original code, so
	// whatever it does to its argument cannot leak into the decoded result.
	var probe TData
	if startErr := strscan.Start(&probe); startErr != nil {
		return *new(TData), startErr
	}

	var data TData
	var scanErr error
	if mode == SModeFast {
		scanErr = strscan.StructScanBase(&data)
	} else {
		scanErr = strscan.StructScanExt(&data)
	}
	if scanErr != nil {
		return *new(TData), scanErr
	}

	if rows.Next() {
		return *new(TData), errors.New("sql returned more than one row")
	}
	if iterErr := rows.Err(); iterErr != nil {
		return *new(TData), iterErr
	}
	return data, nil
}

// ScanAll decodes every remaining row of rows into a slice of TData.
//
// sec chooses how the StructScanner is built (Safe or Unsafe) and mode chooses
// the decode routine (SModeFast -> StructScanBase, SModeExtended ->
// StructScanExt). Both are validated up front, so an invalid option is
// reported even when the result set is empty.
//
// An empty result set yields a non-nil, zero-length slice and a nil error.
//
// When close is true, rows is closed before ScanAll returns on every path,
// including error paths. A Close failure is reported only when no other error
// occurred, so it never hides the primary cause.
//
// Errors returned:
//   - "unknown value for <sec>" / "unknown value for <mode>" for invalid options
//   - any error from StructScanner.Start or the selected decode routine
//   - the error from rows.Err() if iteration failed
//   - the error from rows.Close() (only when close is true and nothing else failed)
//
// On any error the returned slice is nil.
func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (result []TData, err error) {
	if close {
		defer func() {
			if closeErr := rows.Close(); closeErr != nil && err == nil {
				result, err = nil, closeErr
			}
		}()
	}

	var unsafeMode bool
	switch sec {
	case Safe:
		unsafeMode = false
	case Unsafe:
		unsafeMode = true
	default:
		return nil, errors.New("unknown value for <sec>")
	}
	if mode != SModeFast && mode != SModeExtended {
		return nil, errors.New("unknown value for <mode>")
	}

	strscan := NewStructScanner(rows, unsafeMode)
	// Start is given its own throwaway value, as in the original code, so it
	// never touches any element that ends up in the result.
	var probe TData
	if startErr := strscan.Start(&probe); startErr != nil {
		return nil, startErr
	}

	// Keep the result non-nil even when no rows are returned (callers may
	// serialize it, where nil and empty differ).
	res := make([]TData, 0)
	for rows.Next() {
		var data TData
		var scanErr error
		if mode == SModeFast {
			scanErr = strscan.StructScanBase(&data)
		} else {
			scanErr = strscan.StructScanExt(&data)
		}
		if scanErr != nil {
			return nil, scanErr
		}
		res = append(res, data)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return res, nil
}