diff --git a/scnserver/api/handler/compat.go b/scnserver/api/handler/compat.go index 562483b..d6bc775 100644 --- a/scnserver/api/handler/compat.go +++ b/scnserver/api/handler/compat.go @@ -257,7 +257,7 @@ func (h CompatHandler) Info(g *gin.Context) ginresp.HTTPResponse { QuotaMax int `json:"quota_max"` IsPro int `json:"is_pro"` FCMSet bool `json:"fcm_token_set"` - UnackCount int `json:"unack_count"` + UnackCount int64 `json:"unack_count"` } var datq query @@ -309,6 +309,16 @@ func (h CompatHandler) Info(g *gin.Context) ginresp.HTTPResponse { return ginresp.CompatAPIError(0, "Failed to query clients") } + filter := models.MessageFilter{ + Owner: langext.Ptr([]models.UserID{user.UserID}), + CompatAcknowledged: langext.Ptr(false), + } + + unackCount, err := h.database.CountMessages(ctx, filter) + if err != nil { + return ginresp.CompatAPIError(0, "Failed to query user") + } + return ctx.FinishSuccess(ginresp.JSON(http.StatusOK, response{ Success: true, Message: "ok", @@ -318,7 +328,7 @@ func (h CompatHandler) Info(g *gin.Context) ginresp.HTTPResponse { QuotaMax: user.QuotaPerDay(), IsPro: langext.Conditional(user.IsPro, 1, 0), FCMSet: len(clients) > 0, - UnackCount: 0, + UnackCount: unackCount, })) } diff --git a/scnserver/db/impl/primary/messages.go b/scnserver/db/impl/primary/messages.go index 96c5902..c234828 100644 --- a/scnserver/db/impl/primary/messages.go +++ b/scnserver/db/impl/primary/messages.go @@ -4,6 +4,7 @@ import ( ct "blackforestbytes.com/simplecloudnotifier/db/cursortoken" "blackforestbytes.com/simplecloudnotifier/models" "database/sql" + "errors" "gogs.mikescher.com/BlackForestBytes/goext/sq" "time" ) @@ -149,3 +150,31 @@ func (db *Database) ListMessages(ctx TxContext, filter models.MessageFilter, pag return data[0:*pageSize], outToken, nil } } + +func (db *Database) CountMessages(ctx TxContext, filter models.MessageFilter) (int64, error) { + tx, err := ctx.GetOrCreateTransaction(db) + if err != nil { + return 0, err + } + + filterCond, filterJoin, prepParams, err := filter.SQL() + + sqlQuery := "SELECT " + "COUNT(*)" + " FROM messages " + filterJoin + " WHERE ( " + filterCond + " ) " + + rows, err := tx.Query(ctx, sqlQuery, prepParams) + if err != nil { + return 0, err + } + + if !rows.Next() { + return 0, errors.New("COUNT query returned no results") + } + + var countRes int64 + err = rows.Scan(&countRes) + if err != nil { + return 0, err + } + + return countRes, nil +} diff --git a/scnserver/test/compat_test.go b/scnserver/test/compat_test.go index ded0e71..db5ba9e 100644 --- a/scnserver/test/compat_test.go +++ b/scnserver/test/compat_test.go @@ -732,3 +732,90 @@ func TestCompatTitlePatch(t *testing.T) { tt.AssertStrRepEqual(t, "msg.ovrTitle", "[TestChan] HelloWorld_001", pusher.Last().CompatTitleOverride) } + +func TestCompatAckCount(t *testing.T) { + _, baseUrl, stop := tt.StartSimpleWebserver(t) + defer stop() + + r0 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/register.php?fcm_token=%s&pro=%s&pro_token=%s", "DUMMY_FCM", "0", "")) + tt.AssertEqual(t, "success", true, r0["success"]) + + userid := int64(r0["user_id"].(float64)) + userkey := r0["user_key"].(string) + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 0, ri1["unack_count"]) + } + + r1 := tt.RequestPost[gin.H](t, baseUrl, "/send.php", tt.FormData{ + "user_id": fmt.Sprintf("%d", userid), + "user_key": userkey, + "title": "my title 11 & x", + }) + tt.AssertEqual(t, "success", true, r1["success"]) + r1scnid := int64(r1["scn_msg_id"].(float64)) + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 1, ri1["unack_count"]) + } + + r2 := tt.RequestPost[gin.H](t, baseUrl, "/send.php", tt.FormData{ + "user_id": fmt.Sprintf("%d", userid), + "user_key": userkey, + "title": "my title 11 & x", + }) + tt.AssertEqual(t, "success", true, r2["success"]) + r2scnid := int64(r2["scn_msg_id"].(float64)) + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 2, ri1["unack_count"]) + } + + { + ack := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/ack.php?user_id=%d&user_key=%s&scn_msg_id=%d", userid, userkey, r1scnid)) + tt.AssertEqual(t, "success", true, ack["success"]) + tt.AssertEqual(t, "prev_ack", 0, ack["prev_ack"]) + tt.AssertEqual(t, "new_ack", 1, ack["new_ack"]) + tt.AssertEqual(t, "message", "ok", ack["message"]) + } + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 1, ri1["unack_count"]) + } + + { + ack := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/ack.php?user_id=%d&user_key=%s&scn_msg_id=%d", userid, userkey, r1scnid)) + tt.AssertEqual(t, "success", true, ack["success"]) + tt.AssertEqual(t, "prev_ack", 1, ack["prev_ack"]) + tt.AssertEqual(t, "new_ack", 1, ack["new_ack"]) + tt.AssertEqual(t, "message", "ok", ack["message"]) + } + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 1, ri1["unack_count"]) + } + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 1, ri1["unack_count"]) + } + + { + ack := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/ack.php?user_id=%d&user_key=%s&scn_msg_id=%d", userid, userkey, r2scnid)) + tt.AssertEqual(t, "success", true, ack["success"]) + tt.AssertEqual(t, "prev_ack", 0, ack["prev_ack"]) + tt.AssertEqual(t, "new_ack", 1, ack["new_ack"]) + tt.AssertEqual(t, "message", "ok", ack["message"]) + } + + { + ri1 := tt.RequestGet[gin.H](t, baseUrl, fmt.Sprintf("/api/info.php?user_id=%d&user_key=%s", userid, userkey)) + tt.AssertEqual(t, "unack_count", 0, ri1["unack_count"]) + } + +}