diff --git a/server/api/handler/api.go b/server/api/handler/api.go index 226056b..636cca0 100644 --- a/server/api/handler/api.go +++ b/server/api/handler/api.go @@ -775,7 +775,11 @@ func (h APIHandler) ListChannelMessages(g *gin.Context) ginresp.HTTPResponse { return ginresp.APIError(g, 500, apierr.PAGETOKEN_ERROR, "Failed to decode next_page_token", err) } - messages, npt, err := h.database.ListChannelMessages(ctx, channel.ChannelID, pageSize, tok) + filter := models.MessageFilter{ + ChannelID: langext.Ptr([]models.ChannelID{channel.ChannelID}), + } + + messages, npt, err := h.database.ListMessages(ctx, filter, pageSize, tok) if err != nil { return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to query messages", err) } @@ -1176,7 +1180,11 @@ func (h APIHandler) ListMessages(g *gin.Context) ginresp.HTTPResponse { return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to update last-read", err) } - messages, npt, err := h.database.ListMessages(ctx, userid, pageSize, tok) + filter := models.MessageFilter{ + ConfirmedSubscriptionBy: langext.Ptr(userid), + } + + messages, npt, err := h.database.ListMessages(ctx, filter, pageSize, tok) if err != nil { return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to query messages", err) } diff --git a/server/db/cursortoken/token.go b/server/db/cursortoken/token.go index 0ed5d4a..113654a 100644 --- a/server/db/cursortoken/token.go +++ b/server/db/cursortoken/token.go @@ -17,42 +17,47 @@ const ( ) type CursorToken struct { - Mode Mode - Timestamp int64 - Id int64 - Direction string + Mode Mode + Timestamp int64 + Id int64 + Direction string + FilterHash string } type cursorTokenSerialize struct { - Timestamp *int64 `json:"ts,omitempty"` - Id *int64 `json:"id,omitempty"` - Direction *string `json:"dir,omitempty"` + Timestamp *int64 `json:"ts,omitempty"` + Id *int64 `json:"id,omitempty"` + Direction *string `json:"dir,omitempty"` + FilterHash *string `json:"f,omitempty"` } func Start() CursorToken { return CursorToken{ - Mode: CTMStart, - Timestamp: 0, - Id: 0, - Direction: "", + Mode: CTMStart, + Timestamp: 0, + Id: 0, + Direction: "", + FilterHash: "", } } func End() CursorToken { return CursorToken{ - Mode: CTMEnd, - Timestamp: 0, - Id: 0, - Direction: "", + Mode: CTMEnd, + Timestamp: 0, + Id: 0, + Direction: "", + FilterHash: "", } } -func Normal(ts time.Time, id int64, dir string) CursorToken { +func Normal(ts time.Time, id int64, dir string, filter string) CursorToken { return CursorToken{ - Mode: CTMNormal, - Timestamp: ts.UnixMilli(), - Id: id, - Direction: dir, + Mode: CTMNormal, + Timestamp: ts.UnixMilli(), + Id: id, + Direction: dir, + FilterHash: filter, } } @@ -83,6 +88,10 @@ func (c *CursorToken) Token() string { sertok.Direction = &c.Direction } + if c.FilterHash != "" { + sertok.FilterHash = &c.FilterHash + } + body, err := json.Marshal(sertok) if err != nil { panic(err) @@ -128,6 +137,9 @@ func Decode(tok string) (CursorToken, error) { if tokenDeserialize.Direction != nil { token.Direction = *tokenDeserialize.Direction } + if tokenDeserialize.FilterHash != nil { + token.FilterHash = *tokenDeserialize.FilterHash + } return token, nil } diff --git a/server/db/messages.go b/server/db/messages.go index ba4185b..4c6240e 100644 --- a/server/db/messages.go +++ b/server/db/messages.go @@ -4,7 +4,6 @@ import ( "blackforestbytes.com/simplecloudnotifier/db/cursortoken" "blackforestbytes.com/simplecloudnotifier/models" "database/sql" - "fmt" "gogs.mikescher.com/BlackForestBytes/goext/sq" "time" ) @@ -112,7 +111,7 @@ func (db *Database) DeleteMessage(ctx TxContext, scnMessageID models.SCNMessageI return nil } -func (db *Database) ListMessages(ctx TxContext, userid models.UserID, pageSize int, inTok cursortoken.CursorToken) ([]models.Message, cursortoken.CursorToken, error) { +func (db *Database) ListMessages(ctx TxContext, filter models.MessageFilter, pageSize int, inTok cursortoken.CursorToken) ([]models.Message, cursortoken.CursorToken, error) { tx, err := ctx.GetOrCreateTransaction(db) if err != nil { return nil, cursortoken.CursorToken{}, err @@ -124,13 +123,20 @@ func (db *Database) ListMessages(ctx TxContext, userid models.UserID, pageSize i pageCond := "" if inTok.Mode == cursortoken.CTMNormal { - pageCond = fmt.Sprintf("AND ( timestamp_real < %d OR (timestamp_real = %d AND scn_message_id < %d ) )", inTok.Timestamp, inTok.Timestamp, inTok.Id) + pageCond = "timestamp_real < :tokts OR (timestamp_real = :tokts AND scn_message_id < :tokid )" } - rows, err := tx.Query(ctx, "SELECT messages.* FROM messages LEFT JOIN subscriptions subs on messages.channel_id = subs.channel_id WHERE subs.subscriber_user_id = :uid AND subs.confirmed = 1 "+pageCond+" ORDER BY timestamp_real DESC LIMIT :lim", sq.PP{ - "uid": userid, - "lim": pageSize + 1, - }) + filterCond, filterJoin, prepParams, err := filter.SQL() + + orderClause := "ORDER BY COALESCE(timestamp_client, timestamp_real) DESC LIMIT :lim" + + sqlQuery := "SELECT " + "messages.*" + " FROM messages " + filterJoin + " WHERE ( " + filterCond + " ) AND ( " + pageCond + " ) " + orderClause + + prepParams["lim"] = pageSize + 1 + prepParams["tokts"] = inTok.Timestamp + prepParams["tokid"] = inTok.Id + + rows, err := tx.Query(ctx, sqlQuery, prepParams) if err != nil { return nil, cursortoken.CursorToken{}, err } @@ -143,45 +149,7 @@ func (db *Database) ListMessages(ctx TxContext, userid models.UserID, pageSize i if len(data) <= pageSize { return data, cursortoken.End(), nil } else { - outToken := cursortoken.Normal(data[pageSize-1].TimestampReal, data[pageSize-1].SCNMessageID.IntID(), "DESC") - return data[0:pageSize], outToken, nil - } -} - -func (db *Database) ListChannelMessages(ctx TxContext, channelid models.ChannelID, pageSize int, inTok cursortoken.CursorToken) ([]models.Message, cursortoken.CursorToken, error) { - tx, err := ctx.GetOrCreateTransaction(db) - if err != nil { - return nil, cursortoken.CursorToken{}, err - } - - if inTok.Mode == cursortoken.CTMEnd { - return make([]models.Message, 0), cursortoken.End(), nil - } - - pageCond := "" - if inTok.Mode == cursortoken.CTMNormal { - pageCond = "AND ( timestamp_real < :tokts OR (timestamp_real = :tokts AND scn_message_id < :tokid ) )" - } - - rows, err := tx.Query(ctx, "SELECT * FROM messages WHERE channel_id = :cid "+pageCond+" ORDER BY timestamp_real DESC LIMIT :lim", sq.PP{ - "cid": channelid, - "lim": pageSize + 1, - "tokts": inTok.Timestamp, - "tokid": inTok.Timestamp, - }) - if err != nil { - return nil, cursortoken.CursorToken{}, err - } - - data, err := models.DecodeMessages(rows) - if err != nil { - return nil, cursortoken.CursorToken{}, err - } - - if len(data) <= pageSize { - return data, cursortoken.End(), nil - } else { - outToken := cursortoken.Normal(data[pageSize-1].TimestampReal, data[pageSize-1].SCNMessageID.IntID(), "DESC") + outToken := cursortoken.Normal(data[pageSize-1].Timestamp(), data[pageSize-1].SCNMessageID.IntID(), "DESC", filter.Hash()) return data[0:pageSize], outToken, nil } } diff --git a/server/db/schema/schema_3.ddl b/server/db/schema/schema_3.ddl index 0cc2360..8f7af68 100644 --- a/server/db/schema/schema_3.ddl +++ b/server/db/schema/schema_3.ddl @@ -92,10 +92,16 @@ CREATE TABLE messages priority INTEGER CHECK(priority IN (0, 1, 2)) NOT NULL, usr_message_id TEXT NULL ) STRICT; -CREATE INDEX "idx_messages_channel" ON messages (owner_user_id, channel_name); -CREATE UNIQUE INDEX "idx_messages_idempotency" ON messages (owner_user_id, usr_message_id); -CREATE INDEX "idx_messages_senderip" ON messages (sender_ip); -CREATE INDEX "idx_messages_sendername" ON messages (sender_name); +CREATE INDEX "idx_messages_owner_channel" ON messages (owner_user_id, channel_name COLLATE BINARY); +CREATE INDEX "idx_messages_owner_channel_nc" ON messages (owner_user_id, channel_name COLLATE NOCASE); +CREATE INDEX "idx_messages_channel" ON messages (channel_name COLLATE BINARY); +CREATE INDEX "idx_messages_channel_nc" ON messages (channel_name COLLATE NOCASE); +CREATE UNIQUE INDEX "idx_messages_idempotency" ON messages (owner_user_id, usr_message_id COLLATE BINARY); +CREATE INDEX "idx_messages_senderip" ON messages (sender_ip COLLATE BINARY); +CREATE INDEX "idx_messages_sendername" ON messages (sender_name COLLATE BINARY); +CREATE INDEX "idx_messages_sendername_nc" ON messages (sender_name COLLATE NOCASE); +CREATE INDEX "idx_messages_title" ON messages (title COLLATE BINARY); +CREATE INDEX "idx_messages_title_nc" ON messages (title COLLATE NOCASE); CREATE VIRTUAL TABLE messages_fts USING fts5 diff --git a/server/go.mod b/server/go.mod index e580ed6..11201e1 100644 --- a/server/go.mod +++ b/server/go.mod @@ -8,7 +8,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.16 github.com/rs/zerolog v1.28.0 github.com/swaggo/swag v1.8.7 - gogs.mikescher.com/BlackForestBytes/goext v0.0.31 + gogs.mikescher.com/BlackForestBytes/goext v0.0.32 ) require ( diff --git a/server/go.sum b/server/go.sum index 3ae06d7..0ff9e69 100644 --- a/server/go.sum +++ b/server/go.sum @@ -100,6 +100,8 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0 github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= gogs.mikescher.com/BlackForestBytes/goext v0.0.31 h1:DC2RZe7/tSDDbPRbjDcYa+BLRlY0SgLTAkI2DPw5WJQ= gogs.mikescher.com/BlackForestBytes/goext v0.0.31/go.mod h1:/u9JtMwCP68ix4R9BJ/MT0Lm+QScmqIoyYZFKBGzv9g= +gogs.mikescher.com/BlackForestBytes/goext v0.0.32 h1:DJoRBNhq4rrOBXA/nD6WEm7L3vylLkMifU9/sWEiF7M= +gogs.mikescher.com/BlackForestBytes/goext v0.0.32/go.mod h1:/u9JtMwCP68ix4R9BJ/MT0Lm+QScmqIoyYZFKBGzv9g= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= diff --git a/server/models/messagefilter.go b/server/models/messagefilter.go new file mode 100644 index 0000000..eb1a50a --- /dev/null +++ b/server/models/messagefilter.go @@ -0,0 +1,227 @@ +package models + +import ( + "crypto/sha512" + "encoding/hex" + "fmt" + "gogs.mikescher.com/BlackForestBytes/goext/dataext" + "gogs.mikescher.com/BlackForestBytes/goext/langext" + "gogs.mikescher.com/BlackForestBytes/goext/mathext" + "gogs.mikescher.com/BlackForestBytes/goext/sq" + "strconv" + "strings" + "time" +) + +type MessageFilter struct { + ConfirmedSubscriptionBy *UserID + SearchString *[]string + Sender *[]UserID + Owner *[]UserID + ChannelNameCS *[]string // case-sensitive + ChannelNameCI *[]string // case-insensitive + ChannelID *[]ChannelID + SenderNameCS *[]string // case-sensitive + SenderNameCI *[]string // case-insensitive + SenderIP *[]string + TimestampCoalesce *time.Time + TimestampCoalesceAfter *time.Time + TimestampCoalesceBefore *time.Time + TimestampReal *time.Time + TimestampRealAfter *time.Time + TimestampRealBefore *time.Time + TimestampClient *time.Time + TimestampClientAfter *time.Time + TimestampClientBefore *time.Time + TitleCS *string // case-sensitive + TitleCI *string // case-insensitive + Priority *[]int + UserMessageID *[]string +} + +func (f MessageFilter) SQL() (string, string, sq.PP, error) { + + joinClause := "" + if f.ConfirmedSubscriptionBy != nil { + joinClause += " LEFT JOIN subscriptions subs on messages.channel_id = subs.channel_id " + } + if f.SearchString != nil { + joinClause += " JOIN messages_fts mfts on (mfts.rowid = a.scn_message_id) " + } + + sqlClauses := make([]string, 0) + + params := sq.PP{} + + if f.ConfirmedSubscriptionBy != nil { + sqlClauses = append(sqlClauses, "(subs.subscriber_user_id = :sub_uid AND subs.confirmed = 1)") + params["sub_uid"] = *f.ConfirmedSubscriptionBy + } + + if f.SearchString != nil { + filter := make([]string, 0) + for i, v := range *f.SearchString { + filter = append(filter, fmt.Sprintf("(messages_fts match :searchstring_%d)", i)) + params[fmt.Sprintf("searchstring_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.Sender != nil { + filter := make([]string, 0) + for i, v := range *f.Sender { + filter = append(filter, fmt.Sprintf("(sender_user_id = :sender_%d)", i)) + params[fmt.Sprintf("sender_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.Owner != nil { + filter := make([]string, 0) + for i, v := range *f.Sender { + filter = append(filter, fmt.Sprintf("(owner_user_id = :owner_%d)", i)) + params[fmt.Sprintf("owner_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.ChannelNameCI != nil { + filter := make([]string, 0) + for i, v := range *f.ChannelNameCI { + filter = append(filter, fmt.Sprintf("(channel_name = :channelnameci_%d COLLATE NOCASE)", i)) + params[fmt.Sprintf("channelnameci_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.ChannelNameCS != nil { + filter := make([]string, 0) + for i, v := range *f.ChannelNameCS { + filter = append(filter, fmt.Sprintf("(channel_name = :channelnamecs_%d COLLATE BINARY)", i)) + params[fmt.Sprintf("channelnamecs_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.ChannelID != nil { + filter := make([]string, 0) + for i, v := range *f.ChannelID { + filter = append(filter, fmt.Sprintf("(channel_id = :channelid_%d)", i)) + params[fmt.Sprintf("channelid_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.SenderNameCI != nil { + filter := make([]string, 0) + for i, v := range *f.ChannelNameCI { + filter = append(filter, fmt.Sprintf("(sender_name = :sendernameci_%d COLLATE NOCASE)", i)) + params[fmt.Sprintf("sendernameci_%d", i)] = v + } + sqlClauses = append(sqlClauses, "(sender_name IS NOT NULL AND ("+strings.Join(filter, " OR ")+"))") + } + + if f.SenderNameCS != nil { + filter := make([]string, 0) + for i, v := range *f.ChannelNameCS { + filter = append(filter, fmt.Sprintf("(sender_name = :sendernamecs_%d COLLATE BINARY)", i)) + params[fmt.Sprintf("sendernamecs_%d", i)] = v + } + sqlClauses = append(sqlClauses, "(sender_name IS NOT NULL AND ("+strings.Join(filter, " OR ")+"))") + } + + if f.SenderIP != nil { + filter := make([]string, 0) + for i, v := range *f.SenderIP { + filter = append(filter, fmt.Sprintf("(sender_ip = :senderip_%d)", i)) + params[fmt.Sprintf("senderip_%d", i)] = v + } + sqlClauses = append(sqlClauses, "("+strings.Join(filter, " OR ")+")") + } + + if f.TimestampCoalesce != nil { + sqlClauses = append(sqlClauses, "(COALESCE(timestamp_client, timestamp_real) = :ts_equals)") + params["ts_equals"] = (*f.TimestampCoalesce).UnixMilli() + } + + if f.TimestampCoalesceAfter != nil { + sqlClauses = append(sqlClauses, "(COALESCE(timestamp_client, timestamp_real) > :ts_after)") + params["ts_after"] = (*f.TimestampCoalesceAfter).UnixMilli() + } + + if f.TimestampCoalesceBefore != nil { + sqlClauses = append(sqlClauses, "(COALESCE(timestamp_client, timestamp_real) < :ts_before)") + params["ts_before"] = (*f.TimestampCoalesceBefore).UnixMilli() + } + + if f.TimestampReal != nil { + sqlClauses = append(sqlClauses, "(timestamp_real = :ts_real_equals)") + params["ts_real_equals"] = (*f.TimestampRealAfter).UnixMilli() + } + + if f.TimestampRealAfter != nil { + sqlClauses = append(sqlClauses, "(timestamp_real > :ts_real_after)") + params["ts_real_after"] = (*f.TimestampRealAfter).UnixMilli() + } + + if f.TimestampRealBefore != nil { + sqlClauses = append(sqlClauses, "(timestamp_real < :ts_real_before)") + params["ts_real_before"] = (*f.TimestampRealAfter).UnixMilli() + } + + if f.TimestampClient != nil { + sqlClauses = append(sqlClauses, "(timestamp_client IS NOT NULL AND timestamp_client = :ts_client_equals)") + params["ts_client_equals"] = (*f.TimestampClient).UnixMilli() + } + + if f.TimestampClientAfter != nil { + sqlClauses = append(sqlClauses, "(timestamp_client IS NOT NULL AND timestamp_client > :ts_client_after)") + params["ts_client_after"] = (*f.TimestampClientAfter).UnixMilli() + } + + if f.TimestampClientBefore != nil { + sqlClauses = append(sqlClauses, "(timestamp_client IS NOT NULL AND timestamp_client < :ts_client_before)") + params["ts_client_before"] = (*f.TimestampClientBefore).UnixMilli() + } + + if f.TitleCI != nil { + sqlClauses = append(sqlClauses, "(title = :titleci COLLATE NOCASE)") + params["titleci"] = *f.TitleCI + } + + if f.TitleCS != nil { + sqlClauses = append(sqlClauses, "(title = :titleci COLLATE BINARY)") + params["titleci"] = *f.TitleCI + } + + if f.Priority != nil { + prioList := "(" + strings.Join(langext.ArrMap(*f.Priority, func(p int) string { return strconv.Itoa(p) }), ", ") + ")" + sqlClauses = append(sqlClauses, "(priority IN "+prioList+")") + } + + if f.UserMessageID != nil { + filter := make([]string, 0) + for i, v := range *f.UserMessageID { + filter = append(filter, fmt.Sprintf("(usr_message_id = :usermessageid_%d)", i)) + params[fmt.Sprintf("usermessageid_%d", i)] = v + } + sqlClauses = append(sqlClauses, "(usr_message_id IS NOT NULL AND ("+strings.Join(filter, " OR ")+"))") + } + + sqlClause := "" + if len(sqlClauses) > 0 { + sqlClause = strings.Join(sqlClauses, " AND ") + } + + return sqlClause, joinClause, params, nil +} + +func (f MessageFilter) Hash() string { + bh, err := dataext.StructHash(f, dataext.StructHashOptions{HashAlgo: sha512.New()}) + if err != nil { + return "00000000" + } + + str := hex.EncodeToString(bh) + return str[0:mathext.Min(8, len(bh))] +} diff --git a/server/models/utils.go b/server/models/utils.go index e0f06bd..af6dbd8 100644 --- a/server/models/utils.go +++ b/server/models/utils.go @@ -63,3 +63,7 @@ func scanAll[TData any](rows *sqlx.Rows) ([]TData, error) { } //TODO move scanAll+scanSingle into sq package (?) + +//TODO als add convenient methods: +// - QueryScanSingle[T any](..) +// - QueryScanMulti[T any](..)