Fix sql-preprocessor leading to deadlocks in parallel requests

This commit is contained in:
Mike Schwörer 2022-12-22 16:51:04 +01:00
parent 0112d681ac
commit 984470b47d
Signed by: Mikescher
GPG Key ID: D3C7172E0A70F8CF
3 changed files with 94 additions and 63 deletions

View File

@ -17,6 +17,7 @@ import (
type Database struct { type Database struct {
db sq.DB db sq.DB
pp *dbtools.DBPreprocessor
} }
func NewDatabase(conf server.Config) (*Database, error) { func NewDatabase(conf server.Config) (*Database, error) {
@ -38,11 +39,16 @@ func NewDatabase(conf server.Config) (*Database, error) {
qqdb := sq.NewDB(xdb) qqdb := sq.NewDB(xdb)
scndb := &Database{qqdb}
qqdb.AddListener(dbtools.DBLogger{}) qqdb.AddListener(dbtools.DBLogger{})
qqdb.AddListener(dbtools.NewDBPreprocessor(scndb.db)) pp, err := dbtools.NewDBPreprocessor(qqdb)
if err != nil {
return nil, err
}
qqdb.AddListener(pp)
scndb := &Database{db: qqdb, pp: pp}
return scndb, nil return scndb, nil
} }
@ -64,6 +70,11 @@ func (db *Database) Migrate(ctx context.Context) error {
return err return err
} }
err = db.pp.Init(ctx)
if err != nil {
return err
}
return nil return nil
} else if currschema == 1 { } else if currschema == 1 {

View File

@ -10,6 +10,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"time"
) )
// //
@ -31,19 +32,72 @@ type DBPreprocessor struct {
db sq.DB db sq.DB
lock sync.Mutex lock sync.Mutex
cacheColumns map[string][]string dbTables []string
dbColumns map[string][]string
cacheQuery map[string]string cacheQuery map[string]string
} }
var regexAlias = regexp.MustCompile("([A-Za-z_\\-0-9]+)\\s+AS\\s+([A-Za-z_\\-0-9]+)") var regexAlias = regexp.MustCompile("([A-Za-z_\\-0-9]+)\\s+AS\\s+([A-Za-z_\\-0-9]+)")
func NewDBPreprocessor(db sq.DB) *DBPreprocessor { func NewDBPreprocessor(db sq.DB) (*DBPreprocessor, error) {
return &DBPreprocessor{
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
obj := &DBPreprocessor{
db: db, db: db,
lock: sync.Mutex{}, lock: sync.Mutex{},
cacheColumns: make(map[string][]string),
cacheQuery: make(map[string]string), cacheQuery: make(map[string]string),
} }
err := obj.Init(ctx)
if err != nil {
return nil, err
}
return obj, nil
}
func (pp *DBPreprocessor) Init(ctx context.Context) error {
dbTables := make([]string, 0)
dbColumns := make(map[string][]string, 0)
type tabInfo struct {
Name string `db:"name"`
}
type colInfo struct {
Name string `db:"name"`
}
rows1, err := pp.db.Query(ctx, "PRAGMA table_list;", sq.PP{})
if err != nil {
return err
}
resrows1, err := sq.ScanAll[tabInfo](rows1, sq.SModeFast, sq.Unsafe, true)
if err != nil {
return err
}
for _, tab := range resrows1 {
rows2, err := pp.db.Query(ctx, fmt.Sprintf("PRAGMA table_info(\"%s\");", tab.Name), sq.PP{})
if err != nil {
return err
}
resrows2, err := sq.ScanAll[colInfo](rows2, sq.SModeFast, sq.Unsafe, true)
if err != nil {
return err
}
columns := langext.ArrMap(resrows2, func(v colInfo) string { return v.Name })
dbTables = append(dbTables, tab.Name)
dbColumns[tab.Name] = columns
}
pp.dbTables = dbTables
pp.dbColumns = dbColumns
return nil
} }
func (pp *DBPreprocessor) PrePing(ctx context.Context) error { func (pp *DBPreprocessor) PrePing(ctx context.Context) error {
@ -102,9 +156,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
if expr == "*" { if expr == "*" {
columns, err := pp.getTableColumns(ctx, fromTableName) columns, ok := pp.dbColumns[fromTableName]
if err != nil { if !ok {
return err return errors.New(fmt.Sprintf("[preprocessor]: table '%s' not found", fromTableName))
} }
for _, colname := range columns { for _, colname := range columns {
@ -117,9 +171,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
if tableRealName, ok := aliasMap[tableName]; ok { if tableRealName, ok := aliasMap[tableName]; ok {
columns, err := pp.getTableColumns(ctx, tableRealName) columns, ok := pp.dbColumns[tableRealName]
if err != nil { if !ok {
return err return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableRealName))
} }
for _, colname := range columns { for _, colname := range columns {
@ -128,9 +182,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
} else if tableName == fromTableName { } else if tableName == fromTableName {
columns, err := pp.getTableColumns(ctx, tableName) columns, ok := pp.dbColumns[tableName]
if err != nil { if !ok {
return err return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableName))
} }
for _, colname := range columns { for _, colname := range columns {
@ -139,9 +193,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
} else { } else {
columns, err := pp.getTableColumns(ctx, tableName) columns, ok := pp.dbColumns[tableName]
if err != nil { if !ok {
return err return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableName))
} }
for _, colname := range columns { for _, colname := range columns {
@ -195,39 +249,3 @@ func (pp *DBPreprocessor) PostQuery(txID *uint16, sqlOriginal string, sqlReal st
func (pp *DBPreprocessor) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) { func (pp *DBPreprocessor) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
// //
} }
func (pp *DBPreprocessor) getTableColumns(ctx context.Context, tablename string) ([]string, error) {
pp.lock.Lock()
v, ok := pp.cacheColumns[tablename]
pp.lock.Unlock()
if ok {
return v, nil
}
type res struct {
Name string `db:"name"`
}
rows, err := pp.db.Query(ctx, "PRAGMA table_info('"+tablename+"');", sq.PP{})
if err != nil {
return nil, err
}
resrows, err := sq.ScanAll[res](rows, sq.SModeFast, sq.Unsafe, true)
if err != nil {
return nil, err
}
columns := langext.ArrMap(resrows, func(v res) string { return v.Name })
if len(columns) == 0 {
return nil, errors.New("no columns in table '" + tablename + "' (table does not exist?)")
}
pp.lock.Lock()
pp.cacheColumns[tablename] = columns
pp.lock.Unlock()
return columns, nil
}

View File

@ -1440,8 +1440,10 @@ func TestSendParallel(t *testing.T) {
uid := int(r0["user_id"].(float64)) uid := int(r0["user_id"].(float64))
sendtok := r0["send_key"].(string) sendtok := r0["send_key"].(string)
sem := make(chan tt.Void, 900) // semaphore pattern count := 128
for i := 0; i < 900; i++ {
sem := make(chan tt.Void, count) // semaphore pattern
for i := 0; i < count; i++ {
go func() { go func() {
defer func() { defer func() {
sem <- tt.Void{} sem <- tt.Void{}
@ -1454,7 +1456,7 @@ func TestSendParallel(t *testing.T) {
}() }()
} }
// wait for goroutines to finish // wait for goroutines to finish
for i := 0; i < 900; i++ { for i := 0; i < count; i++ {
<-sem <-sem
} }
} }