diff --git a/.idea/.gitignore b/.idea/.gitignore index 13566b8..a9d7db9 100644 --- a/.idea/.gitignore +++ b/.idea/.gitignore @@ -6,3 +6,5 @@ # Datasource local storage ignored files /dataSources/ /dataSources.local.xml +# GitHub Copilot persisted chat sessions +/copilot/chatSessions diff --git a/goextVersion.go b/goextVersion.go index 4a4d2b0..a3b6178 100644 --- a/goextVersion.go +++ b/goextVersion.go @@ -1,5 +1,5 @@ package goext -const GoextVersion = "0.0.399" +const GoextVersion = "0.0.400" -const GoextVersionTimestamp = "2024-03-09T14:16:35+0100" +const GoextVersionTimestamp = "2024-03-09T14:59:32+0100" diff --git a/sq/builder_test.go b/sq/builder_test.go index 55b0735..2e3bd7c 100644 --- a/sq/builder_test.go +++ b/sq/builder_test.go @@ -52,8 +52,7 @@ func TestCreateUpdateStatement(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{}) tst.AssertNoErr(t, err) diff --git a/sq/commentTrimmer.go b/sq/commentTrimmer.go new file mode 100644 index 0000000..b612ff2 --- /dev/null +++ b/sq/commentTrimmer.go @@ -0,0 +1,32 @@ +package sq + +import ( + "context" + "strings" +) + +var CommentTrimmer = NewPreListener(fnTrimComments) + +func fnTrimComments(ctx context.Context, cmdtype string, id *uint16, sql *string, params *PP) error { + + res := make([]string, 0) + + for _, s := range strings.Split(*sql, "\n") { + if strings.HasPrefix(strings.TrimSpace(s), "--") { + continue + } + + idx := strings.Index(s, "--") + if idx != -1 { + s = s[:idx] + } + + s = strings.TrimRight(s, " \t\r\n") + + res = append(res, s) + } + + *sql = strings.Join(res, "\n") + + return nil +} diff --git a/sq/database.go b/sq/database.go index 127a760..3d18ae4 100644 --- a/sq/database.go +++ b/sq/database.go @@ -17,7 +17,11 @@ type DB interface { AddListener(listener Listener) Exit() error RegisterConverter(DBTypeConverter) - RegisterDefaultConverter() +} + +type DBOptions struct { + RegisterDefaultConverter *bool + RegisterCommentTrimmer *bool } type database struct { @@ -28,13 +32,23 @@ type database struct { conv []DBTypeConverter } -func NewDB(db *sqlx.DB) DB { - return &database{ +func NewDB(db *sqlx.DB, opt DBOptions) DB { + sqdb := &database{ db: db, txctr: 0, lock: sync.Mutex{}, lstr: make([]Listener, 0), } + + if langext.Coalesce(opt.RegisterDefaultConverter, true) { + sqdb.registerDefaultConverter() + } + + if langext.Coalesce(opt.RegisterCommentTrimmer, true) { + sqdb.AddListener(CommentTrimmer) + } + + return sqdb } func (db *database) AddListener(listener Listener) { @@ -141,7 +155,7 @@ func (db *database) RegisterConverter(conv DBTypeConverter) { db.conv = append(db.conv, conv) } -func (db *database) RegisterDefaultConverter() { +func (db *database) registerDefaultConverter() { db.RegisterConverter(ConverterBoolToBit) db.RegisterConverter(ConverterTimeToUnixMillis) diff --git a/sq/hasher.go b/sq/hasher.go index 678e79f..e483490 100644 --- a/sq/hasher.go +++ b/sq/hasher.go @@ -31,7 +31,7 @@ func HashMattnSqliteSchema(ctx context.Context, schemaStr string) (string, error return "", err } - db := NewDB(xdb) + db := NewDB(xdb, DBOptions{}) _, err = db.Exec(ctx, schemaStr, PP{}) if err != nil { @@ -59,7 +59,7 @@ func HashGoSqliteSchema(ctx context.Context, schemaStr string) (string, error) { return "", err } - db := NewDB(xdb) + db := NewDB(xdb, DBOptions{}) _, err = db.Exec(ctx, schemaStr, PP{}) if err != nil { diff --git a/sq/listener.go b/sq/listener.go index 6b8158b..deb5d0d 100644 --- a/sq/listener.go +++ b/sq/listener.go @@ -17,3 +17,172 @@ type Listener interface { PostQuery(txID *uint16, sqlOriginal string, sqlReal string, params PP) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params PP) } + +type genListener struct { + prePing func(ctx context.Context) error + preTxBegin func(ctx context.Context, txid uint16) error + preTxCommit func(txid uint16) error + preTxRollback func(txid uint16) error + preQuery func(ctx context.Context, txID *uint16, sql *string, params *PP) error + preExec func(ctx context.Context, txID *uint16, sql *string, params *PP) error + postPing func(result error) + postTxBegin func(txid uint16, result error) + postTxCommit func(txid uint16, result error) + postTxRollback func(txid uint16, result error) + postQuery func(txID *uint16, sqlOriginal string, sqlReal string, params PP) + postExec func(txID *uint16, sqlOriginal string, sqlReal string, params PP) +} + +func (g genListener) PrePing(ctx context.Context) error { + if g.prePing == nil { + return g.prePing(ctx) + } else { + return nil + } +} + +func (g genListener) PreTxBegin(ctx context.Context, txid uint16) error { + if g.preTxBegin == nil { + return g.preTxBegin(ctx, txid) + } else { + return nil + } +} + +func (g genListener) PreTxCommit(txid uint16) error { + if g.preTxCommit == nil { + return g.preTxCommit(txid) + } else { + return nil + } +} + +func (g genListener) PreTxRollback(txid uint16) error { + if g.preTxRollback == nil { + return g.preTxRollback(txid) + } else { + return nil + } +} + +func (g genListener) PreQuery(ctx context.Context, txID *uint16, sql *string, params *PP) error { + if g.preQuery == nil { + return g.preQuery(ctx, txID, sql, params) + } else { + return nil + } +} + +func (g genListener) PreExec(ctx context.Context, txID *uint16, sql *string, params *PP) error { + if g.preExec == nil { + return g.preExec(ctx, txID, sql, params) + } else { + return nil + } +} + +func (g genListener) PostPing(result error) { + if g.postPing != nil { + g.postPing(result) + } +} + +func (g genListener) PostTxBegin(txid uint16, result error) { + if g.postTxBegin != nil { + g.postTxBegin(txid, result) + } +} + +func (g genListener) PostTxCommit(txid uint16, result error) { + if g.postTxCommit != nil { + g.postTxCommit(txid, result) + } +} + +func (g genListener) PostTxRollback(txid uint16, result error) { + if g.postTxRollback != nil { + g.postTxRollback(txid, result) + } +} + +func (g genListener) PostQuery(txID *uint16, sqlOriginal string, sqlReal string, params PP) { + if g.postQuery != nil { + g.postQuery(txID, sqlOriginal, sqlReal, params) + } +} + +func (g genListener) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params PP) { + if g.postExec != nil { + g.postExec(txID, sqlOriginal, sqlReal, params) + } +} + +func NewPrePingListener(f func(ctx context.Context) error) Listener { + return genListener{prePing: f} +} + +func NewPreTxBeginListener(f func(ctx context.Context, txid uint16) error) Listener { + return genListener{preTxBegin: f} +} + +func NewPreTxCommitListener(f func(txid uint16) error) Listener { + return genListener{preTxCommit: f} +} + +func NewPreTxRollbackListener(f func(txid uint16) error) Listener { + return genListener{preTxRollback: f} +} + +func NewPreQueryListener(f func(ctx context.Context, txID *uint16, sql *string, params *PP) error) Listener { + return genListener{preQuery: f} +} + +func NewPreExecListener(f func(ctx context.Context, txID *uint16, sql *string, params *PP) error) Listener { + return genListener{preExec: f} +} + +func NewPreListener(f func(ctx context.Context, cmdtype string, txID *uint16, sql *string, params *PP) error) Listener { + return genListener{ + preExec: func(ctx context.Context, txID *uint16, sql *string, params *PP) error { + return f(ctx, "EXEC", txID, sql, params) + }, + preQuery: func(ctx context.Context, txID *uint16, sql *string, params *PP) error { + return f(ctx, "QUERY", txID, sql, params) + }, + } +} + +func NewPostPingListener(f func(result error)) Listener { + return genListener{postPing: f} +} + +func NewPostTxBeginListener(f func(txid uint16, result error)) Listener { + return genListener{postTxBegin: f} +} + +func NewPostTxCommitListener(f func(txid uint16, result error)) Listener { + return genListener{postTxCommit: f} +} + +func NewPostTxRollbackListener(f func(txid uint16, result error)) Listener { + return genListener{postTxRollback: f} +} + +func NewPostQueryListener(f func(txID *uint16, sqlOriginal string, sqlReal string, params PP)) Listener { + return genListener{postQuery: f} +} + +func NewPostExecListener(f func(txID *uint16, sqlOriginal string, sqlReal string, params PP)) Listener { + return genListener{postExec: f} +} + +func NewPostListener(f func(cmdtype string, txID *uint16, sqlOriginal string, sqlReal string, params PP)) Listener { + return genListener{ + postExec: func(txID *uint16, sqlOriginal string, sqlReal string, params PP) { + f("EXEC", txID, sqlOriginal, sqlReal, params) + }, + postQuery: func(txID *uint16, sqlOriginal string, sqlReal string, params PP) { + f("QUERY", txID, sqlOriginal, sqlReal, params) + }, + } +} diff --git a/sq/scanner_test.go b/sq/scanner_test.go index 7ec7d1c..7b622a1 100644 --- a/sq/scanner_test.go +++ b/sq/scanner_test.go @@ -36,8 +36,7 @@ func TestInsertSingle(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, ` CREATE TABLE requests ( @@ -90,8 +89,7 @@ func TestUpdateSingle(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, ` CREATE TABLE requests ( @@ -176,8 +174,7 @@ func TestInsertMultiple(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, ` CREATE TABLE requests ( diff --git a/sq/sq_test.go b/sq/sq_test.go index bc36bb3..217fb64 100644 --- a/sq/sq_test.go +++ b/sq/sq_test.go @@ -36,8 +36,7 @@ func TestTypeConverter1(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{}) tst.AssertNoErr(t, err) @@ -71,8 +70,7 @@ func TestTypeConverter2(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{}) tst.AssertNoErr(t, err) @@ -116,8 +114,7 @@ func TestTypeConverter3(t *testing.T) { xdb := tst.Must(sqlx.Open("sqlite", url))(t) - db := NewDB(xdb) - db.RegisterDefaultConverter() + db := NewDB(xdb, DBOptions{RegisterDefaultConverter: langext.PTrue}) _, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NULL, PRIMARY KEY (id) ) STRICT", PP{}) tst.AssertNoErr(t, err)