Fix sql-preprocessor leading to deadlocks in parallel requests
This commit is contained in:
parent
0112d681ac
commit
984470b47d
@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db sq.DB
|
db sq.DB
|
||||||
|
pp *dbtools.DBPreprocessor
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDatabase(conf server.Config) (*Database, error) {
|
func NewDatabase(conf server.Config) (*Database, error) {
|
||||||
@ -38,11 +39,16 @@ func NewDatabase(conf server.Config) (*Database, error) {
|
|||||||
|
|
||||||
qqdb := sq.NewDB(xdb)
|
qqdb := sq.NewDB(xdb)
|
||||||
|
|
||||||
scndb := &Database{qqdb}
|
|
||||||
|
|
||||||
qqdb.AddListener(dbtools.DBLogger{})
|
qqdb.AddListener(dbtools.DBLogger{})
|
||||||
|
|
||||||
qqdb.AddListener(dbtools.NewDBPreprocessor(scndb.db))
|
pp, err := dbtools.NewDBPreprocessor(qqdb)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
qqdb.AddListener(pp)
|
||||||
|
|
||||||
|
scndb := &Database{db: qqdb, pp: pp}
|
||||||
|
|
||||||
return scndb, nil
|
return scndb, nil
|
||||||
}
|
}
|
||||||
@ -64,6 +70,11 @@ func (db *Database) Migrate(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = db.pp.Init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
} else if currschema == 1 {
|
} else if currschema == 1 {
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -31,19 +32,72 @@ type DBPreprocessor struct {
|
|||||||
db sq.DB
|
db sq.DB
|
||||||
|
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
cacheColumns map[string][]string
|
dbTables []string
|
||||||
|
dbColumns map[string][]string
|
||||||
cacheQuery map[string]string
|
cacheQuery map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var regexAlias = regexp.MustCompile("([A-Za-z_\\-0-9]+)\\s+AS\\s+([A-Za-z_\\-0-9]+)")
|
var regexAlias = regexp.MustCompile("([A-Za-z_\\-0-9]+)\\s+AS\\s+([A-Za-z_\\-0-9]+)")
|
||||||
|
|
||||||
func NewDBPreprocessor(db sq.DB) *DBPreprocessor {
|
func NewDBPreprocessor(db sq.DB) (*DBPreprocessor, error) {
|
||||||
return &DBPreprocessor{
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
obj := &DBPreprocessor{
|
||||||
db: db,
|
db: db,
|
||||||
lock: sync.Mutex{},
|
lock: sync.Mutex{},
|
||||||
cacheColumns: make(map[string][]string),
|
|
||||||
cacheQuery: make(map[string]string),
|
cacheQuery: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err := obj.Init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return obj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pp *DBPreprocessor) Init(ctx context.Context) error {
|
||||||
|
|
||||||
|
dbTables := make([]string, 0)
|
||||||
|
dbColumns := make(map[string][]string, 0)
|
||||||
|
|
||||||
|
type tabInfo struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
}
|
||||||
|
type colInfo struct {
|
||||||
|
Name string `db:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
rows1, err := pp.db.Query(ctx, "PRAGMA table_list;", sq.PP{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resrows1, err := sq.ScanAll[tabInfo](rows1, sq.SModeFast, sq.Unsafe, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, tab := range resrows1 {
|
||||||
|
|
||||||
|
rows2, err := pp.db.Query(ctx, fmt.Sprintf("PRAGMA table_info(\"%s\");", tab.Name), sq.PP{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resrows2, err := sq.ScanAll[colInfo](rows2, sq.SModeFast, sq.Unsafe, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
columns := langext.ArrMap(resrows2, func(v colInfo) string { return v.Name })
|
||||||
|
|
||||||
|
dbTables = append(dbTables, tab.Name)
|
||||||
|
dbColumns[tab.Name] = columns
|
||||||
|
}
|
||||||
|
|
||||||
|
pp.dbTables = dbTables
|
||||||
|
pp.dbColumns = dbColumns
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pp *DBPreprocessor) PrePing(ctx context.Context) error {
|
func (pp *DBPreprocessor) PrePing(ctx context.Context) error {
|
||||||
@ -102,9 +156,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
|
|||||||
|
|
||||||
if expr == "*" {
|
if expr == "*" {
|
||||||
|
|
||||||
columns, err := pp.getTableColumns(ctx, fromTableName)
|
columns, ok := pp.dbColumns[fromTableName]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return errors.New(fmt.Sprintf("[preprocessor]: table '%s' not found", fromTableName))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, colname := range columns {
|
for _, colname := range columns {
|
||||||
@ -117,9 +171,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
|
|||||||
|
|
||||||
if tableRealName, ok := aliasMap[tableName]; ok {
|
if tableRealName, ok := aliasMap[tableName]; ok {
|
||||||
|
|
||||||
columns, err := pp.getTableColumns(ctx, tableRealName)
|
columns, ok := pp.dbColumns[tableRealName]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableRealName))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, colname := range columns {
|
for _, colname := range columns {
|
||||||
@ -128,9 +182,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
|
|||||||
|
|
||||||
} else if tableName == fromTableName {
|
} else if tableName == fromTableName {
|
||||||
|
|
||||||
columns, err := pp.getTableColumns(ctx, tableName)
|
columns, ok := pp.dbColumns[tableName]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableName))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, colname := range columns {
|
for _, colname := range columns {
|
||||||
@ -139,9 +193,9 @@ func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *strin
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
columns, err := pp.getTableColumns(ctx, tableName)
|
columns, ok := pp.dbColumns[tableName]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return err
|
return errors.New(fmt.Sprintf("[sql-preprocessor]: table '%s' not found", tableName))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, colname := range columns {
|
for _, colname := range columns {
|
||||||
@ -195,39 +249,3 @@ func (pp *DBPreprocessor) PostQuery(txID *uint16, sqlOriginal string, sqlReal st
|
|||||||
func (pp *DBPreprocessor) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
|
func (pp *DBPreprocessor) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pp *DBPreprocessor) getTableColumns(ctx context.Context, tablename string) ([]string, error) {
|
|
||||||
pp.lock.Lock()
|
|
||||||
v, ok := pp.cacheColumns[tablename]
|
|
||||||
pp.lock.Unlock()
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type res struct {
|
|
||||||
Name string `db:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := pp.db.Query(ctx, "PRAGMA table_info('"+tablename+"');", sq.PP{})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resrows, err := sq.ScanAll[res](rows, sq.SModeFast, sq.Unsafe, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
columns := langext.ArrMap(resrows, func(v res) string { return v.Name })
|
|
||||||
|
|
||||||
if len(columns) == 0 {
|
|
||||||
return nil, errors.New("no columns in table '" + tablename + "' (table does not exist?)")
|
|
||||||
}
|
|
||||||
|
|
||||||
pp.lock.Lock()
|
|
||||||
pp.cacheColumns[tablename] = columns
|
|
||||||
pp.lock.Unlock()
|
|
||||||
|
|
||||||
return columns, nil
|
|
||||||
}
|
|
||||||
|
@ -1440,8 +1440,10 @@ func TestSendParallel(t *testing.T) {
|
|||||||
uid := int(r0["user_id"].(float64))
|
uid := int(r0["user_id"].(float64))
|
||||||
sendtok := r0["send_key"].(string)
|
sendtok := r0["send_key"].(string)
|
||||||
|
|
||||||
sem := make(chan tt.Void, 900) // semaphore pattern
|
count := 128
|
||||||
for i := 0; i < 900; i++ {
|
|
||||||
|
sem := make(chan tt.Void, count) // semaphore pattern
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
sem <- tt.Void{}
|
sem <- tt.Void{}
|
||||||
@ -1454,7 +1456,7 @@ func TestSendParallel(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
// wait for goroutines to finish
|
// wait for goroutines to finish
|
||||||
for i := 0; i < 900; i++ {
|
for i := 0; i < count; i++ {
|
||||||
<-sem
|
<-sem
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user