From d8c06e3de26d7db480abca89afc303f1b0ac381c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mike=20Schw=C3=B6rer?= Date: Fri, 20 Sep 2024 15:36:16 +0200 Subject: [PATCH] Fix test [TestListMessagesFilterChannel] --- scnserver/api/handler/apiChannel.go | 7 +- scnserver/api/handler/apiMessage.go | 12 ++- scnserver/api/handler/compat.go | 2 +- scnserver/db/impl/primary/messages.go | 42 +++++--- scnserver/db/impl/primary/utils.go | 4 + scnserver/models/messagefilter.go | 9 ++ scnserver/test/message_test.go | 138 +++++++------------------- scnserver/test/util/factory.go | 2 +- 8 files changed, 91 insertions(+), 125 deletions(-) diff --git a/scnserver/api/handler/apiChannel.go b/scnserver/api/handler/apiChannel.go index 93172d3..ca43144 100644 --- a/scnserver/api/handler/apiChannel.go +++ b/scnserver/api/handler/apiChannel.go @@ -418,6 +418,7 @@ func (h APIHandler) ListChannelMessages(pctx ginext.PreContext) ginext.HTTPRespo Messages []models.Message `json:"messages"` NextPageToken string `json:"next_page_token"` PageSize int `json:"page_size"` + TotalCount int64 `json:"total_count"` } var u uri @@ -457,16 +458,16 @@ func (h APIHandler) ListChannelMessages(pctx ginext.PreContext) ginext.HTTPRespo ChannelID: langext.Ptr([]models.ChannelID{channel.ChannelID}), } - messages, npt, err := h.database.ListMessages(ctx, filter, &pageSize, tok) + messages, npt, totalCount, err := h.database.ListMessages(ctx, filter, &pageSize, tok) if err != nil { return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to query messages", err) } if trimmed { res := langext.ArrMap(messages, func(v models.Message) models.Message { return v.Trim() }) - return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize})) + return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount})) } else { - return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: messages, NextPageToken: npt.Token(), PageSize: pageSize})) + return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: messages, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount})) } }) diff --git a/scnserver/api/handler/apiMessage.go b/scnserver/api/handler/apiMessage.go index 2631f31..5a69749 100644 --- a/scnserver/api/handler/apiMessage.go +++ b/scnserver/api/handler/apiMessage.go @@ -48,11 +48,13 @@ func (h APIHandler) ListMessages(pctx ginext.PreContext) ginext.HTTPResponse { TimeAfter *string `json:"after" form:"after"` // RFC3339 Priority []int `json:"priority" form:"priority"` KeyTokens []string `json:"used_key" form:"used_key"` + HasSender *bool `json:"has_sender" form:"has_sender"` } type response struct { Messages []models.Message `json:"messages"` NextPageToken string `json:"next_page_token"` PageSize int `json:"page_size"` + TotalCount int64 `json:"total_count"` } var q query @@ -114,6 +116,10 @@ func (h APIHandler) ListMessages(pctx ginext.PreContext) ginext.HTTPResponse { filter.SenderNameCS = langext.Ptr(q.Senders) } + if q.HasSender != nil { + filter.HasSenderName = langext.Ptr(*q.HasSender) + } + if q.TimeBefore != nil { t0, err := time.Parse(time.RFC3339, *q.TimeBefore) if err != nil { @@ -146,17 +152,17 @@ func (h APIHandler) ListMessages(pctx ginext.PreContext) ginext.HTTPResponse { filter.UsedKeyID = &tids } - messages, npt, err := h.database.ListMessages(ctx, filter, &pageSize, tok) + messages, npt, totalCount, err := h.database.ListMessages(ctx, filter, &pageSize, tok) if err != nil { return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to query messages", err) } if trimmed { res := langext.ArrMap(messages, func(v models.Message) models.Message { return v.PreMarshal().Trim() }) - return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize})) + return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount})) } else { res := langext.ArrMap(messages, func(v models.Message) models.Message { return v.PreMarshal() }) - return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize})) + return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount})) } }) } diff --git a/scnserver/api/handler/compat.go b/scnserver/api/handler/compat.go index 24a7c46..9ba439b 100644 --- a/scnserver/api/handler/compat.go +++ b/scnserver/api/handler/compat.go @@ -538,7 +538,7 @@ func (h CompatHandler) Requery(pctx ginext.PreContext) ginext.HTTPResponse { CompatAcknowledged: langext.Ptr(false), } - msgs, _, err := h.database.ListMessages(ctx, filter, langext.Ptr(16), ct.Start()) + msgs, _, _, err := h.database.ListMessages(ctx, filter, langext.Ptr(16), ct.Start()) if err != nil { return ginresp.CompatAPIError(0, "Failed to query user") } diff --git a/scnserver/db/impl/primary/messages.go b/scnserver/db/impl/primary/messages.go index 6293204..00402ac 100644 --- a/scnserver/db/impl/primary/messages.go +++ b/scnserver/db/impl/primary/messages.go @@ -80,14 +80,10 @@ func (db *Database) DeleteMessage(ctx db.TxContext, messageID models.MessageID) return nil } -func (db *Database) ListMessages(ctx db.TxContext, filter models.MessageFilter, pageSize *int, inTok ct.CursorToken) ([]models.Message, ct.CursorToken, error) { +func (db *Database) ListMessages(ctx db.TxContext, filter models.MessageFilter, pageSize *int, inTok ct.CursorToken) ([]models.Message, ct.CursorToken, int64, error) { tx, err := ctx.GetOrCreateTransaction(db) if err != nil { - return nil, ct.CursorToken{}, err - } - - if inTok.Mode == ct.CTMEnd { - return make([]models.Message, 0), ct.End(), nil + return nil, ct.CursorToken{}, 0, err } pageCond := "1=1" @@ -105,21 +101,39 @@ func (db *Database) ListMessages(ctx db.TxContext, filter models.MessageFilter, orderClause = "ORDER BY COALESCE(timestamp_client, timestamp_real) DESC, message_id DESC" } - sqlQuery := "SELECT " + "messages.*" + " FROM messages " + filterJoin + " WHERE ( " + pageCond + " ) AND ( " + filterCond + " ) " + orderClause + sqlQueryList := "SELECT " + "messages.*" + " FROM messages " + filterJoin + " WHERE ( " + pageCond + " ) AND ( " + filterCond + " ) " + orderClause + sqlQueryCount := "SELECT " + " COUNT(*) AS count FROM messages " + filterJoin + " WHERE ( " + filterCond + " ) " prepParams["tokts"] = inTok.Timestamp prepParams["tokid"] = inTok.Id - data, err := sq.QueryAll[models.Message](ctx, tx, sqlQuery, prepParams, sq.SModeExtended, sq.Safe) - if err != nil { - return nil, ct.CursorToken{}, err + if inTok.Mode == ct.CTMEnd { + + dataCount, err := sq.QuerySingle[CountResponse](ctx, tx, sqlQueryCount, prepParams, sq.SModeFast, sq.Safe) + if err != nil { + return nil, ct.CursorToken{}, 0, err + } + + return make([]models.Message, 0), ct.End(), dataCount.Count, nil } - if pageSize == nil || len(data) <= *pageSize { - return data, ct.End(), nil + dataList, err := sq.QueryAll[models.Message](ctx, tx, sqlQueryList, prepParams, sq.SModeExtended, sq.Safe) + if err != nil { + return nil, ct.CursorToken{}, 0, err + } + + if pageSize == nil || len(dataList) <= *pageSize { + return dataList, ct.End(), int64(len(dataList)), nil } else { - outToken := ct.Normal(data[*pageSize-1].Timestamp(), data[*pageSize-1].MessageID.String(), "DESC", filter.Hash()) - return data[0:*pageSize], outToken, nil + + dataCount, err := sq.QuerySingle[CountResponse](ctx, tx, sqlQueryCount, prepParams, sq.SModeFast, sq.Safe) + if err != nil { + return nil, ct.CursorToken{}, 0, err + } + + outToken := ct.Normal(dataList[*pageSize-1].Timestamp(), dataList[*pageSize-1].MessageID.String(), "DESC", filter.Hash()) + + return dataList[0:*pageSize], outToken, dataCount.Count, nil } } diff --git a/scnserver/db/impl/primary/utils.go b/scnserver/db/impl/primary/utils.go index 9c693cd..9277f1d 100644 --- a/scnserver/db/impl/primary/utils.go +++ b/scnserver/db/impl/primary/utils.go @@ -23,3 +23,7 @@ func time2DBOpt(t *time.Time) *int64 { } return langext.Ptr(t.UnixMilli()) } + +type CountResponse struct { + Count int64 `db:"count"` +} diff --git a/scnserver/models/messagefilter.go b/scnserver/models/messagefilter.go index a875111..c27894b 100644 --- a/scnserver/models/messagefilter.go +++ b/scnserver/models/messagefilter.go @@ -22,6 +22,7 @@ type MessageFilter struct { ChannelID *[]ChannelID SenderNameCS *[]string // case-sensitive SenderNameCI *[]string // case-insensitive + HasSenderName *bool SenderIP *[]string TimestampCoalesce *time.Time TimestampCoalesceAfter *time.Time @@ -123,6 +124,14 @@ func (f MessageFilter) SQL() (string, string, sq.PP, error) { sqlClauses = append(sqlClauses, "(sender_name IS NOT NULL AND ("+strings.Join(filter, " OR ")+"))") } + if f.HasSenderName != nil { + if *f.HasSenderName { + sqlClauses = append(sqlClauses, "(sender_name IS NOT NULL)") + } else { + sqlClauses = append(sqlClauses, "(sender_name IS NULL)") + } + } + if f.SenderIP != nil { filter := make([]string, 0) for i, v := range *f.SenderIP { diff --git a/scnserver/test/message_test.go b/scnserver/test/message_test.go index c6f6a1d..f550828 100644 --- a/scnserver/test/message_test.go +++ b/scnserver/test/message_test.go @@ -721,7 +721,8 @@ func TestListMessagesFilterChannel(t *testing.T) { UsrMessageId string `json:"usr_message_id"` } type mglist struct { - Messages []msg `json:"messages"` + Messages []msg `json:"messages"` + TotalCount int `json:"total_count"` } cid1 := "" @@ -752,109 +753,40 @@ func TestListMessagesFilterChannel(t *testing.T) { } } - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel=%s", "Reminders,Promotions")) - tt.AssertEqual(t, "msgList.len", 9, len(msgList.Messages)) + filterTests := []struct { + Name string + Count int + Query string + }{ + {"all", 22, fmt.Sprintf("/api/v2/messages")}, + {"channel=Reminders|Promotions", 9, fmt.Sprintf("/api/v2/messages?channel=%s&channel=%s", "Reminders", "Promotions")}, + {"channel=Reminders", 6, fmt.Sprintf("/api/v2/messages?channel=%s", "Reminders")}, + {"channel_id=1", 6, fmt.Sprintf("/api/v2/messages?channel_id=%s", cid1)}, + {"channel_id=1|2", 9, fmt.Sprintf("/api/v2/messages?channel_id=%s&channel_id=%s", cid1, cid2)}, + {"filter=unusual", 1, fmt.Sprintf("/api/v2/messages?filter=%s", "unusual")}, + {"filter=your", 6, fmt.Sprintf("/api/v2/messages?filter=%s", "your")}, + {"prio=0", 5, fmt.Sprintf("/api/v2/messages?priority=%s", "0")}, + {"prio=1", 4 + 7, fmt.Sprintf("/api/v2/messages?priority=%s", "1")}, + {"prio=2", 6, fmt.Sprintf("/api/v2/messages?priority=%s", "2")}, + {"prio=0|2", 5 + 6, fmt.Sprintf("/api/v2/messages?priority=%s&priority=%s", "0", "2")}, + {"key=a", 11, fmt.Sprintf("/api/v2/messages?used_key=%s", akey)}, + {"key=s", 11, fmt.Sprintf("/api/v2/messages?used_key=%s", skey)}, + {"key=a|s", 11 + 11, fmt.Sprintf("/api/v2/messages?used_key=%s&used_key=%s", akey, skey)}, + {"key=a&prio=0", 0, fmt.Sprintf("/api/v2/messages?used_key=%s&priority=%d", akey, 0)}, + {"key=s&prio=0", 5, fmt.Sprintf("/api/v2/messages?used_key=%s&priority=%d", skey, 0)}, + {"key=s&prio=2", 6, fmt.Sprintf("/api/v2/messages?used_key=%s&priority=%d", skey, 2)}, + {"sender=MobileMate", 4, fmt.Sprintf("/api/v2/messages?sender=%s", url.QueryEscape("Mobile Mate"))}, + {"sender=PocketPal", 3, fmt.Sprintf("/api/v2/messages?sender=%s", url.QueryEscape("Pocket Pal"))}, + {"sender=MobileMate|PocketPal", 3 + 4, fmt.Sprintf("/api/v2/messages?sender=%s&sender=%s", url.QueryEscape("Pocket Pal"), url.QueryEscape("Mobile Mate"))}, + {"sender=empty", 12, fmt.Sprintf("/api/v2/messages?has_sender=%s", "false")}, + {"sender=any", 10, fmt.Sprintf("/api/v2/messages?has_sender=%s", "true")}, + {"before=-1H", 2, fmt.Sprintf("/api/v2/messages?before=%s", url.QueryEscape(time.Now().Add(-time.Hour).Format(time.RFC3339Nano)))}, + {"after=-1H", 20, fmt.Sprintf("/api/v2/messages?after=%s", url.QueryEscape(time.Now().Add(-time.Hour).Format(time.RFC3339Nano)))}, + {"after=+5min", 3, fmt.Sprintf("/api/v2/messages?after=%s", url.QueryEscape(time.Now().Add(5*time.Minute).Format(time.RFC3339Nano)))}, } - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel=%s", "Reminders")) - tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages)) + for _, testdata := range filterTests { + msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, testdata.Query) + tt.AssertEqual(t, "msgList.filter["+testdata.Name+"].len", testdata.Count, msgList.TotalCount) } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel_id=%s", cid1)) - tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel_id=%s,%s", cid1, cid2)) - tt.AssertEqual(t, "msgList.len", 9, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?filter=%s", "unusual")) - tt.AssertEqual(t, "msgList.len", 1, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?filter=%s", "your")) - tt.AssertEqual(t, "msgList.len", 7, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "1")) - tt.AssertEqual(t, "msgList.len", 4, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "2")) - tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "0")) - tt.AssertEqual(t, "msgList.len", 5, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "0,2")) - tt.AssertEqual(t, "msgList.len", 11, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s", akey)) - tt.AssertEqual(t, "msgList.len", 11, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s", skey)) - tt.AssertEqual(t, "msgList.len", 11, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s,%s", akey, skey)) - tt.AssertEqual(t, "msgList.len", 22, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s&priority=%d", akey, 0)) - tt.AssertEqual(t, "msgList.len", 5, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s", "Mobile Mate")) - tt.AssertEqual(t, "msgList.len", 3, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s", "Pocket Pal")) - tt.AssertEqual(t, "msgList.len", 3, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s,%s", "Pocket Pal", "Mobile Mate")) - tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s", "")) - tt.AssertEqual(t, "msgList.len", 12, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?before=%s", time.Now().Add(-time.Hour))) - tt.AssertEqual(t, "msgList.len", 2, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?after=%s", time.Now().Add(-time.Hour))) - tt.AssertEqual(t, "msgList.len", 20, len(msgList.Messages)) - } - - { - msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?after=%s", time.Now().Add(5*time.Minute))) - tt.AssertEqual(t, "msgList.len", 3, len(msgList.Messages)) - } - } diff --git a/scnserver/test/util/factory.go b/scnserver/test/util/factory.go index 81d257f..c4287e6 100644 --- a/scnserver/test/util/factory.go +++ b/scnserver/test/util/factory.go @@ -450,7 +450,7 @@ func InitDefaultData(t *testing.T, ws *logic.Application) DefData { Keys []skey `json:"keys"` } r0 := RequestAuthGet[keylist](t, usr.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys", usr.UID)) - users[i].Keys = langext.ArrMap(r0.Keys, func(v skey) KeyDat { return KeyDat{KeyID: v.ID, KeyName: v.Name} }) + users[i].Keys = langext.ArrMap(r0.Keys, func(v skey) KeyDat { return KeyDat{KeyID: v.ID, Name: v.Name} }) } // list subscriptions