From 984470b47d89c858db73b01649e88f728bb129f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20Schw=C3=B6rer?= Date: Thu, 22 Dec 2022 16:51:04 +0100 Subject: [PATCH] Fix sql-preprocessor leading to deadlocks in parallel requests --- scnserver/db/database.go | 17 +++- scnserver/db/dbtools/preprocessor.go | 132 +++++++++++++++------------ scnserver/test/send_test.go | 8 +- 3 files changed, 94 insertions(+), 63 deletions(-) diff --git a/scnserver/db/database.go b/scnserver/db/database.go index f6ca5a4..bfe9f2b 100644 --- a/scnserver/db/database.go +++ b/scnserver/db/database.go @@ -17,6 +17,7 @@ import ( type Database struct { db sq.DB + pp *dbtools.DBPreprocessor } func NewDatabase(conf server.Config) (*Database, error) { @@ -38,11 +39,16 @@ func NewDatabase(conf server.Config) (*Database, error) { qqdb := sq.NewDB(xdb) - scndb := &Database{qqdb} - qqdb.AddListener(dbtools.DBLogger{}) - qqdb.AddListener(dbtools.NewDBPreprocessor(scndb.db)) + pp, err := dbtools.NewDBPreprocessor(qqdb) + if err != nil { + return nil, err + } + + qqdb.AddListener(pp) + + scndb := &Database{db: qqdb, pp: pp} return scndb, nil } @@ -64,6 +70,11 @@ func (db *Database) Migrate(ctx context.Context) error { return err } + err = db.pp.Init(ctx) + if err != nil { + return err + } + return nil } else if currschema == 1 { diff --git a/scnserver/db/dbtools/preprocessor.go b/scnserver/db/dbtools/preprocessor.go index 92776ec..5f500ea 100644 --- a/scnserver/db/dbtools/preprocessor.go +++ b/scnserver/db/dbtools/preprocessor.go @@ -10,6 +10,7 @@ import ( "regexp" "strings" "sync" + "time" ) // @@ -30,20 +31,73 @@ import ( type DBPreprocessor struct { db sq.DB - lock sync.Mutex - cacheColumns map[string][]string - cacheQuery map[string]string + lock sync.Mutex + dbTables []string + dbColumns map[string][]string + cacheQuery map[string]string } var regexAlias = regexp.MustCompile("([A-Za-z_\\-0-9]+)\\s+AS\\s+([A-Za-z_\\-0-9]+)") -func NewDBPreprocessor(db sq.DB) *DBPreprocessor { - return &DBPreprocessor{ - db: db, - lock: sync.Mutex{}, - cacheColumns: make(map[string][]string), - cacheQuery: make(map[string]string), +func NewDBPreprocessor(db sq.DB) (*DBPreprocessor, error) { + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + obj := &DBPreprocessor{ + db: db, + lock: sync.Mutex{}, + cacheQuery: make(map[string]string), } + + err := obj.Init(ctx) + if err != nil { + return nil, err + } + + return obj, nil +} + +func (pp *DBPreprocessor) Init(ctx context.Context) error { + + dbTables := make([]string, 0) + dbColumns := make(map[string][]string, 0) + + type tabInfo struct { + Name string `db:"name"` + } + type colInfo struct { + Name string `db:"name"` + } + + rows1, err := pp.db.Query(ctx, "PRAGMA table_list;", sq.PP{}) + if err != nil { + return err + } + resrows1, err := sq.ScanAll[tabInfo](rows1, sq.SModeFast, sq.Unsafe, true) + if err != nil { + return err + } + for _, tab := range resrows1 { + + rows2, err := pp.db.Query(ctx, fmt.Sprintf("PRAGMA table_info(\"%s\");", tab.Name), sq.PP{}) + if err != nil { + return err + } + resrows2, err := sq.ScanAll[colInfo](rows2, sq.SModeFast, sq.Unsafe, true) + if err != nil { + return err + } + columns := langext.ArrMap(resrows2, func(v colInfo) string { return v.Name }) + + dbTables = append(dbTables, tab.Name) + dbColumns[tab.Name] = columns + } + + pp.dbTables = dbTables + pp.dbColumns = dbColumns + + return nil } func (pp *DBPreprocessor) PrePing(ctx context.Context) error { @@ -102,9 +156,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin if expr == "*" { - columns, err := pp.getTableColumns(ctx, fromTableName) - if err != nil { - return err + columns, ok := pp.dbColumns[fromTableName] + if !ok { + return errors.New(fmt.Sprintf("[preprocessor]: table '%s' not found", fromTableName)) } for _, colname := range columns { @@ -117,9 +171,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin if tableRealName, ok := aliasMap[tableName]; ok { - columns, err := pp.getTableColumns(ctx, tableRealName) - if err != nil { - return err + columns, ok := pp.dbColumns[tableRealName] + if !ok { + return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableRealName)) } for _, colname := range columns { @@ -128,9 +182,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin } else if tableName == fromTableName { - columns, err := pp.getTableColumns(ctx, tableName) - if err != nil { - return err + columns, ok := pp.dbColumns[tableName] + if !ok { + return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableName)) } for _, colname := range columns { @@ -139,9 +193,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin } else { - columns, err := pp.getTableColumns(ctx, tableName) - if err != nil { - return err + columns, ok := pp.dbColumns[tableName] + if !ok { + return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableName)) } for _, colname := range columns { @@ -195,39 +249,3 @@ func (pp *DBPreprocessor) PostQuery(txID *uint16, sqlOriginal string, sqlReal st func (pp *DBPreprocessor) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) { // } - -func (pp *DBPreprocessor) getTableColumns(ctx context.Context, tablename string) ([]string, error) { - pp.lock.Lock() - v, ok := pp.cacheColumns[tablename] - pp.lock.Unlock() - - if ok { - return v, nil - } - - type res struct { - Name string `db:"name"` - } - - rows, err := pp.db.Query(ctx, "PRAGMA table_info('"+tablename+"');", sq.PP{}) - if err != nil { - return nil, err - } - - resrows, err := sq.ScanAll[res](rows, sq.SModeFast, sq.Unsafe, true) - if err != nil { - return nil, err - } - - columns := langext.ArrMap(resrows, func(v res) string { return v.Name }) - - if len(columns) == 0 { - return nil, errors.New("no columns in table '" + tablename + "' (table does not exist?)") - } - - pp.lock.Lock() - pp.cacheColumns[tablename] = columns - pp.lock.Unlock() - - return columns, nil -} diff --git a/scnserver/test/send_test.go b/scnserver/test/send_test.go index 22a3861..861fbeb 100644 --- a/scnserver/test/send_test.go +++ b/scnserver/test/send_test.go @@ -1440,8 +1440,10 @@ func TestSendParallel(t *testing.T) { uid := int(r0["user_id"].(float64)) sendtok := r0["send_key"].(string) - sem := make(chan tt.Void, 900) // semaphore pattern - for i := 0; i < 900; i++ { + count := 128 + + sem := make(chan tt.Void, count) // semaphore pattern + for i := 0; i < count; i++ { go func() { defer func() { sem <- tt.Void{} @@ -1454,7 +1456,7 @@ func TestSendParallel(t *testing.T) { }() } // wait for goroutines to finish - for i := 0; i < 900; i++ { + for i := 0; i < count; i++ { <-sem } }