Fix test [TestListMessagesFilterChannel]

This commit is contained in:
Mike Schwörer 2024-09-20 15:36:16 +02:00
parent 3adeadf6fb
commit d8c06e3de2
Signed by: Mikescher
GPG Key ID: D3C7172E0A70F8CF
8 changed files with 91 additions and 125 deletions

View File

@ -418,6 +418,7 @@ func (h APIHandler) ListChannelMessages(pctx ginext.PreContext) ginext.HTTPRespo
Messages []models.Message `json:"messages"`
NextPageToken string `json:"next_page_token"`
PageSize int `json:"page_size"`
TotalCount int64 `json:"total_count"`
}
var u uri
@ -457,16 +458,16 @@ func (h APIHandler) ListChannelMessages(pctx ginext.PreContext) ginext.HTTPRespo
ChannelID: langext.Ptr([]models.ChannelID{channel.ChannelID}),
}
messages, npt, err := h.database.ListMessages(ctx, filter, &pageSize, tok)
messages, npt, totalCount, err := h.database.ListMessages(ctx, filter, &pageSize, tok)
if err != nil {
return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to query messages", err)
}
if trimmed {
res := langext.ArrMap(messages, func(v models.Message) models.Message { return v.Trim() })
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize}))
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount}))
} else {
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: messages, NextPageToken: npt.Token(), PageSize: pageSize}))
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: messages, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount}))
}
})

View File

@ -48,11 +48,13 @@ func (h APIHandler) ListMessages(pctx ginext.PreContext) ginext.HTTPResponse {
TimeAfter *string `json:"after" form:"after"` // RFC3339
Priority []int `json:"priority" form:"priority"`
KeyTokens []string `json:"used_key" form:"used_key"`
HasSender *bool `json:"has_sender" form:"has_sender"`
}
type response struct {
Messages []models.Message `json:"messages"`
NextPageToken string `json:"next_page_token"`
PageSize int `json:"page_size"`
TotalCount int64 `json:"total_count"`
}
var q query
@ -114,6 +116,10 @@ func (h APIHandler) ListMessages(pctx ginext.PreContext) ginext.HTTPResponse {
filter.SenderNameCS = langext.Ptr(q.Senders)
}
if q.HasSender != nil {
filter.HasSenderName = langext.Ptr(*q.HasSender)
}
if q.TimeBefore != nil {
t0, err := time.Parse(time.RFC3339, *q.TimeBefore)
if err != nil {
@ -146,17 +152,17 @@ func (h APIHandler) ListMessages(pctx ginext.PreContext) ginext.HTTPResponse {
filter.UsedKeyID = &tids
}
messages, npt, err := h.database.ListMessages(ctx, filter, &pageSize, tok)
messages, npt, totalCount, err := h.database.ListMessages(ctx, filter, &pageSize, tok)
if err != nil {
return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to query messages", err)
}
if trimmed {
res := langext.ArrMap(messages, func(v models.Message) models.Message { return v.PreMarshal().Trim() })
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize}))
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount}))
} else {
res := langext.ArrMap(messages, func(v models.Message) models.Message { return v.PreMarshal() })
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize}))
return finishSuccess(ginext.JSON(http.StatusOK, response{Messages: res, NextPageToken: npt.Token(), PageSize: pageSize, TotalCount: totalCount}))
}
})
}

View File

@ -538,7 +538,7 @@ func (h CompatHandler) Requery(pctx ginext.PreContext) ginext.HTTPResponse {
CompatAcknowledged: langext.Ptr(false),
}
msgs, _, err := h.database.ListMessages(ctx, filter, langext.Ptr(16), ct.Start())
msgs, _, _, err := h.database.ListMessages(ctx, filter, langext.Ptr(16), ct.Start())
if err != nil {
return ginresp.CompatAPIError(0, "Failed to query user")
}

View File

@ -80,14 +80,10 @@ func (db *Database) DeleteMessage(ctx db.TxContext, messageID models.MessageID)
return nil
}
func (db *Database) ListMessages(ctx db.TxContext, filter models.MessageFilter, pageSize *int, inTok ct.CursorToken) ([]models.Message, ct.CursorToken, error) {
func (db *Database) ListMessages(ctx db.TxContext, filter models.MessageFilter, pageSize *int, inTok ct.CursorToken) ([]models.Message, ct.CursorToken, int64, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return nil, ct.CursorToken{}, err
}
if inTok.Mode == ct.CTMEnd {
return make([]models.Message, 0), ct.End(), nil
return nil, ct.CursorToken{}, 0, err
}
pageCond := "1=1"
@ -105,21 +101,39 @@ func (db *Database) ListMessages(ctx db.TxContext, filter models.MessageFilter,
orderClause = "ORDER BY COALESCE(timestamp_client, timestamp_real) DESC, message_id DESC"
}
sqlQuery := "SELECT " + "messages.*" + " FROM messages " + filterJoin + " WHERE ( " + pageCond + " ) AND ( " + filterCond + " ) " + orderClause
sqlQueryList := "SELECT " + "messages.*" + " FROM messages " + filterJoin + " WHERE ( " + pageCond + " ) AND ( " + filterCond + " ) " + orderClause
sqlQueryCount := "SELECT " + " COUNT(*) AS count FROM messages " + filterJoin + " WHERE ( " + filterCond + " ) "
prepParams["tokts"] = inTok.Timestamp
prepParams["tokid"] = inTok.Id
data, err := sq.QueryAll[models.Message](ctx, tx, sqlQuery, prepParams, sq.SModeExtended, sq.Safe)
if inTok.Mode == ct.CTMEnd {
dataCount, err := sq.QuerySingle[CountResponse](ctx, tx, sqlQueryCount, prepParams, sq.SModeFast, sq.Safe)
if err != nil {
return nil, ct.CursorToken{}, err
return nil, ct.CursorToken{}, 0, err
}
if pageSize == nil || len(data) <= *pageSize {
return data, ct.End(), nil
return make([]models.Message, 0), ct.End(), dataCount.Count, nil
}
dataList, err := sq.QueryAll[models.Message](ctx, tx, sqlQueryList, prepParams, sq.SModeExtended, sq.Safe)
if err != nil {
return nil, ct.CursorToken{}, 0, err
}
if pageSize == nil || len(dataList) <= *pageSize {
return dataList, ct.End(), int64(len(dataList)), nil
} else {
outToken := ct.Normal(data[*pageSize-1].Timestamp(), data[*pageSize-1].MessageID.String(), "DESC", filter.Hash())
return data[0:*pageSize], outToken, nil
dataCount, err := sq.QuerySingle[CountResponse](ctx, tx, sqlQueryCount, prepParams, sq.SModeFast, sq.Safe)
if err != nil {
return nil, ct.CursorToken{}, 0, err
}
outToken := ct.Normal(dataList[*pageSize-1].Timestamp(), dataList[*pageSize-1].MessageID.String(), "DESC", filter.Hash())
return dataList[0:*pageSize], outToken, dataCount.Count, nil
}
}

View File

@ -23,3 +23,7 @@ func time2DBOpt(t *time.Time) *int64 {
}
return langext.Ptr(t.UnixMilli())
}
type CountResponse struct {
Count int64 `db:"count"`
}

View File

@ -22,6 +22,7 @@ type MessageFilter struct {
ChannelID *[]ChannelID
SenderNameCS *[]string // case-sensitive
SenderNameCI *[]string // case-insensitive
HasSenderName *bool
SenderIP *[]string
TimestampCoalesce *time.Time
TimestampCoalesceAfter *time.Time
@ -123,6 +124,14 @@ func (f MessageFilter) SQL() (string, string, sq.PP, error) {
sqlClauses = append(sqlClauses, "(sender_name IS NOT NULL AND ("+strings.Join(filter, " OR ")+"))")
}
if f.HasSenderName != nil {
if *f.HasSenderName {
sqlClauses = append(sqlClauses, "(sender_name IS NOT NULL)")
} else {
sqlClauses = append(sqlClauses, "(sender_name IS NULL)")
}
}
if f.SenderIP != nil {
filter := make([]string, 0)
for i, v := range *f.SenderIP {

View File

@ -722,6 +722,7 @@ func TestListMessagesFilterChannel(t *testing.T) {
}
type mglist struct {
Messages []msg `json:"messages"`
TotalCount int `json:"total_count"`
}
cid1 := ""
@ -752,109 +753,40 @@ func TestListMessagesFilterChannel(t *testing.T) {
}
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel=%s", "Reminders,Promotions"))
tt.AssertEqual(t, "msgList.len", 9, len(msgList.Messages))
filterTests := []struct {
Name string
Count int
Query string
}{
{"all", 22, fmt.Sprintf("/api/v2/messages")},
{"channel=Reminders|Promotions", 9, fmt.Sprintf("/api/v2/messages?channel=%s&channel=%s", "Reminders", "Promotions")},
{"channel=Reminders", 6, fmt.Sprintf("/api/v2/messages?channel=%s", "Reminders")},
{"channel_id=1", 6, fmt.Sprintf("/api/v2/messages?channel_id=%s", cid1)},
{"channel_id=1|2", 9, fmt.Sprintf("/api/v2/messages?channel_id=%s&channel_id=%s", cid1, cid2)},
{"filter=unusual", 1, fmt.Sprintf("/api/v2/messages?filter=%s", "unusual")},
{"filter=your", 6, fmt.Sprintf("/api/v2/messages?filter=%s", "your")},
{"prio=0", 5, fmt.Sprintf("/api/v2/messages?priority=%s", "0")},
{"prio=1", 4 + 7, fmt.Sprintf("/api/v2/messages?priority=%s", "1")},
{"prio=2", 6, fmt.Sprintf("/api/v2/messages?priority=%s", "2")},
{"prio=0|2", 5 + 6, fmt.Sprintf("/api/v2/messages?priority=%s&priority=%s", "0", "2")},
{"key=a", 11, fmt.Sprintf("/api/v2/messages?used_key=%s", akey)},
{"key=s", 11, fmt.Sprintf("/api/v2/messages?used_key=%s", skey)},
{"key=a|s", 11 + 11, fmt.Sprintf("/api/v2/messages?used_key=%s&used_key=%s", akey, skey)},
{"key=a&prio=0", 0, fmt.Sprintf("/api/v2/messages?used_key=%s&priority=%d", akey, 0)},
{"key=s&prio=0", 5, fmt.Sprintf("/api/v2/messages?used_key=%s&priority=%d", skey, 0)},
{"key=s&prio=2", 6, fmt.Sprintf("/api/v2/messages?used_key=%s&priority=%d", skey, 2)},
{"sender=MobileMate", 4, fmt.Sprintf("/api/v2/messages?sender=%s", url.QueryEscape("Mobile Mate"))},
{"sender=PocketPal", 3, fmt.Sprintf("/api/v2/messages?sender=%s", url.QueryEscape("Pocket Pal"))},
{"sender=MobileMate|PocketPal", 3 + 4, fmt.Sprintf("/api/v2/messages?sender=%s&sender=%s", url.QueryEscape("Pocket Pal"), url.QueryEscape("Mobile Mate"))},
{"sender=empty", 12, fmt.Sprintf("/api/v2/messages?has_sender=%s", "false")},
{"sender=any", 10, fmt.Sprintf("/api/v2/messages?has_sender=%s", "true")},
{"before=-1H", 2, fmt.Sprintf("/api/v2/messages?before=%s", url.QueryEscape(time.Now().Add(-time.Hour).Format(time.RFC3339Nano)))},
{"after=-1H", 20, fmt.Sprintf("/api/v2/messages?after=%s", url.QueryEscape(time.Now().Add(-time.Hour).Format(time.RFC3339Nano)))},
{"after=+5min", 3, fmt.Sprintf("/api/v2/messages?after=%s", url.QueryEscape(time.Now().Add(5*time.Minute).Format(time.RFC3339Nano)))},
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel=%s", "Reminders"))
tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages))
for _, testdata := range filterTests {
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, testdata.Query)
tt.AssertEqual(t, "msgList.filter["+testdata.Name+"].len", testdata.Count, msgList.TotalCount)
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel_id=%s", cid1))
tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?channel_id=%s,%s", cid1, cid2))
tt.AssertEqual(t, "msgList.len", 9, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?filter=%s", "unusual"))
tt.AssertEqual(t, "msgList.len", 1, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?filter=%s", "your"))
tt.AssertEqual(t, "msgList.len", 7, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "1"))
tt.AssertEqual(t, "msgList.len", 4, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "2"))
tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "0"))
tt.AssertEqual(t, "msgList.len", 5, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?priority=%s", "0,2"))
tt.AssertEqual(t, "msgList.len", 11, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s", akey))
tt.AssertEqual(t, "msgList.len", 11, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s", skey))
tt.AssertEqual(t, "msgList.len", 11, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s,%s", akey, skey))
tt.AssertEqual(t, "msgList.len", 22, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?used_key_id=%s&priority=%d", akey, 0))
tt.AssertEqual(t, "msgList.len", 5, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s", "Mobile Mate"))
tt.AssertEqual(t, "msgList.len", 3, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s", "Pocket Pal"))
tt.AssertEqual(t, "msgList.len", 3, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s,%s", "Pocket Pal", "Mobile Mate"))
tt.AssertEqual(t, "msgList.len", 6, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?sender=%s", ""))
tt.AssertEqual(t, "msgList.len", 12, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?before=%s", time.Now().Add(-time.Hour)))
tt.AssertEqual(t, "msgList.len", 2, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?after=%s", time.Now().Add(-time.Hour)))
tt.AssertEqual(t, "msgList.len", 20, len(msgList.Messages))
}
{
msgList := tt.RequestAuthGet[mglist](t, data.User[0].AdminKey, baseUrl, fmt.Sprintf("/api/v2/messages?after=%s", time.Now().Add(5*time.Minute)))
tt.AssertEqual(t, "msgList.len", 3, len(msgList.Messages))
}
}

View File

@ -450,7 +450,7 @@ func InitDefaultData(t *testing.T, ws *logic.Application) DefData {
Keys []skey `json:"keys"`
}
r0 := RequestAuthGet[keylist](t, usr.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys", usr.UID))
users[i].Keys = langext.ArrMap(r0.Keys, func(v skey) KeyDat { return KeyDat{KeyID: v.ID, KeyName: v.Name} })
users[i].Keys = langext.ArrMap(r0.Keys, func(v skey) KeyDat { return KeyDat{KeyID: v.ID, Name: v.Name} })
}
// list subscriptions