diff --git a/plugins/providers/oss/provider.go b/plugins/providers/oss/provider.go index 80f49c0e..ecfc75bd 100644 --- a/plugins/providers/oss/provider.go +++ b/plugins/providers/oss/provider.go @@ -258,7 +258,7 @@ func revokePermissionsFromPolicy(policyString string, g domain.Grant) (string, e return "", err } - principalAccountID, err := getPrincipalFromAccountID(g.AccountID) + principalAccountID, err := getPrincipalFromAccountID(g.AccountID, g.AccountType) if err != nil { return "", err } @@ -310,7 +310,7 @@ func updatePolicyToGrantPermissions(policy string, g domain.Grant) (string, erro return "", err } - principalAccountID, err := getPrincipalFromAccountID(g.AccountID) + principalAccountID, err := getPrincipalFromAccountID(g.AccountID, g.AccountType) if err != nil { return "", err } @@ -462,18 +462,42 @@ func getAccountIDFromResource(resource *domain.Resource) (string, error) { return urnParts[2], nil } -func getPrincipalFromAccountID(accountID string) (string, error) { - accountIDParts := strings.Split(accountID, "$") - if len(accountIDParts) < 2 { - return "", fmt.Errorf("invalid accountID format") - } +func getPrincipalFromAccountID(accountID, accountType string) (string, error) { + // AccountTypeRAMUser = RAM$: + // AccountTypeRAMRole = acs:ram:::role/ + if accountType == AccountTypeRAMUser { + accountIDParts := strings.Split(accountID, "$") + if len(accountIDParts) < 2 { + return "", fmt.Errorf("invalid accountID format: %q", accountID) + } + + subParts := strings.Split(accountIDParts[1], ":") + if len(subParts) < 2 { + return "", fmt.Errorf("invalid accountID format: %q", accountID) + } + + return subParts[1], nil + } else if accountType == AccountTypeRAMRole { - subParts := strings.Split(accountIDParts[1], ":") - if len(subParts) < 2 { - return "", fmt.Errorf("invalid accountID format") + accountIDParts := strings.Split(accountID, ":") + if len(accountIDParts) < 5 { + return "", fmt.Errorf("invalid accountID format: %q", accountID) + } + + mainAccountID := accountIDParts[3] + roleNameParts := strings.Split(accountIDParts[4], "/") + if len(roleNameParts) < 2 { + return "", fmt.Errorf("invalid accountID format: %q", accountID) + } + + roleName := roleNameParts[1] + + // STS ARN - arn:sts:::assumed-role//* + return fmt.Sprintf("arn:sts::%s:assumed-role/%s/*", mainAccountID, roleName), nil } - return subParts[1], nil + return "", fmt.Errorf("invalid account type: %q", accountType) + } func unmarshalPolicy(policy string) (Policy, error) {