diff --git a/scnserver/api/handler/apiKeyToken.go b/scnserver/api/handler/apiKeyToken.go index 22d5d51..0446123 100644 --- a/scnserver/api/handler/apiKeyToken.go +++ b/scnserver/api/handler/apiKeyToken.go @@ -221,9 +221,9 @@ func (h APIHandler) CreateUserKey(g *gin.Context) ginresp.HTTPResponse { } type body struct { Name string `json:"name" binding:"required"` - AllChannels *bool `json:"all_channels" binding:"required"` - Channels *[]models.ChannelID `json:"channels" binding:"required"` - Permissions *string `json:"permissions" binding:"required"` + Permissions string `json:"permissions" binding:"required"` + AllChannels *bool `json:"all_channels"` + Channels *[]models.ChannelID `json:"channels"` } var u uri @@ -234,7 +234,18 @@ func (h APIHandler) CreateUserKey(g *gin.Context) ginresp.HTTPResponse { } defer ctx.Cancel() - for _, c := range *b.Channels { + channels := langext.Coalesce(b.Channels, make([]models.ChannelID, 0)) + + var allChan bool + if b.AllChannels == nil && b.Channels != nil { + allChan = false + } else if b.AllChannels == nil && b.Channels == nil { + allChan = true + } else { + allChan = *b.AllChannels + } + + for _, c := range channels { if err := c.Valid(); err != nil { return ginresp.APIError(g, 400, apierr.INVALID_BODY_PARAM, "Invalid ChannelID", err) } @@ -246,9 +257,9 @@ func (h APIHandler) CreateUserKey(g *gin.Context) ginresp.HTTPResponse { token := h.app.GenerateRandomAuthKey() - perms := models.ParseTokenPermissionList(*b.Permissions) + perms := models.ParseTokenPermissionList(b.Permissions) - keytok, err := h.database.CreateKeyToken(ctx, b.Name, *ctx.GetPermissionUserID(), *b.AllChannels, *b.Channels, perms, token) + keytok, err := h.database.CreateKeyToken(ctx, b.Name, *ctx.GetPermissionUserID(), allChan, channels, perms, token) if err != nil { return ginresp.APIError(g, 500, apierr.DATABASE_ERROR, "Failed to create keytoken in db", err) } diff --git a/scnserver/test/keytoken_test.go b/scnserver/test/keytoken_test.go index b03a280..48ae2ca 100644 --- a/scnserver/test/keytoken_test.go +++ b/scnserver/test/keytoken_test.go @@ -383,6 +383,11 @@ func TestTokenKeysDowngradeSelf(t *testing.T) { data := tt.InitSingleData(t, ws) + chan0 := tt.RequestAuthPost[gin.H](t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/channels", data.UID), gin.H{ + "name": "testchan1", + }) + chanid := fmt.Sprintf("%v", chan0["channel_id"]) + type keyobj struct { AllChannels bool `json:"all_channels"` Channels []string `json:"channels"` @@ -415,7 +420,7 @@ func TestTokenKeysDowngradeSelf(t *testing.T) { }, 400, apierr.CANNOT_SELFUPDATE_KEY) tt.RequestAuthPatchShouldFail(t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys/%s", data.UID, ak), gin.H{ - "channels": []string{"main"}, + "channels": []string{chanid}, }, 400, apierr.CANNOT_SELFUPDATE_KEY) tt.RequestAuthPatch[tt.Void](t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys/%s", data.UID, ak), gin.H{ @@ -606,3 +611,64 @@ func TestTokenKeysMessageCounter(t *testing.T) { assertCounter(4, 4, 0) } + +func TestTokenKeysCreateDefaultParam(t *testing.T) { + ws, baseUrl, stop := tt.StartSimpleWebserver(t) + defer stop() + + data := tt.InitSingleData(t, ws) + + type keyobj struct { + AllChannels bool `json:"all_channels"` + Channels []string `json:"channels"` + KeytokenId string `json:"keytoken_id"` + MessagesSent int `json:"messages_sent"` + Name string `json:"name"` + OwnerUserId string `json:"owner_user_id"` + Permissions string `json:"permissions"` + Token string `json:"token"` // only in create + } + + chan0 := tt.RequestAuthPost[gin.H](t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/channels", data.UID), gin.H{ + "name": "testchan1", + }) + chanid := fmt.Sprintf("%v", chan0["channel_id"]) + + { + key2 := tt.RequestAuthPost[keyobj](t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys", data.UID), gin.H{ + "name": "K2", + "permissions": "CS", + }) + + tt.AssertEqual(t, "Name", "K2", key2.Name) + tt.AssertEqual(t, "Permissions", "CS", key2.Permissions) + tt.AssertEqual(t, "AllChannels", true, key2.AllChannels) + tt.AssertEqual(t, "Channels.Len", 0, len(key2.Channels)) + } + + { + key2 := tt.RequestAuthPost[keyobj](t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys", data.UID), gin.H{ + "name": "K3", + "permissions": "CS", + "channels": []string{chanid}, + }) + + tt.AssertEqual(t, "Name", "K3", key2.Name) + tt.AssertEqual(t, "Permissions", "CS", key2.Permissions) + tt.AssertEqual(t, "AllChannels", false, key2.AllChannels) + tt.AssertEqual(t, "Channels.Len", 1, len(key2.Channels)) + } + + { + key2 := tt.RequestAuthPost[keyobj](t, data.AdminKey, baseUrl, fmt.Sprintf("/api/v2/users/%s/keys", data.UID), gin.H{ + "name": "K4", + "permissions": "CS", + "all_channels": false, + }) + + tt.AssertEqual(t, "Name", "K4", key2.Name) + tt.AssertEqual(t, "Permissions", "CS", key2.Permissions) + tt.AssertEqual(t, "AllChannels", false, key2.AllChannels) + tt.AssertEqual(t, "Channels.Len", 0, len(key2.Channels)) + } +}