Skip to content

Commit

Permalink
fix: handled edgecase where stsClients map key could be same for diff…
Browse files Browse the repository at this point in the history
…erent clients(odps/rest/oss)
  • Loading branch information
Ayushi Sharma committed Dec 12, 2024
1 parent aa8de0b commit d094e7f
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 42 deletions.
15 changes: 10 additions & 5 deletions pkg/stsClient/stsClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func NewSTS() *Sts {
}
}

func (s *Sts) IsSTSTokenValid(ramRole string) bool {
client := s.clients[ramRole]
func (s *Sts) IsSTSTokenValid(clientIdentifier string) bool {
client := s.clients[clientIdentifier]
if client == nil {
return false
}
Expand All @@ -53,21 +53,25 @@ func NewSTSClient(userAccessKeyID, userSecretAccessKey, regionID string) (*clien
return stsClient, nil
}

func (s *Sts) GetSTSClient(ramRole, userAccessKeyID, userSecret, regionID string) (*client.Client, error) {
func (s *Sts) GetSTSClient(clientIdentifier, userAccessKeyID, userSecret, regionID string) (*client.Client, error) {
if c, ok := s.clients[clientIdentifier]; ok {
return c.client, nil
}

stsClient, err := NewSTSClient(userAccessKeyID, userSecret, regionID)
if err != nil {
return nil, err
}

s.clients[ramRole] = &StsClient{
s.clients[clientIdentifier] = &StsClient{
client: stsClient,
expiryTimeStamp: time.Now().Add(time.Duration(assumeRoleDurationHours) * time.Hour),
}

return stsClient, nil
}

func AssumeRole(stsClient *client.Client, roleArn, roleSessionName string) (*openapiV2.Config, error) {
func AssumeRole(stsClient *client.Client, roleArn, roleSessionName, regionID string) (*openapiV2.Config, error) {
durationSeconds := assumeRoleDurationHours * int64(time.Hour.Seconds())
request := client.AssumeRoleRequest{
RoleArn: &roleArn,
Expand All @@ -84,6 +88,7 @@ func AssumeRole(stsClient *client.Client, roleArn, roleSessionName string) (*ope
AccessKeyId: res.Body.Credentials.AccessKeyId,
AccessKeySecret: res.Body.Credentials.AccessKeySecret,
SecurityToken: res.Body.Credentials.SecurityToken,
RegionId: &regionID,
}

return config, nil
Expand Down
126 changes: 89 additions & 37 deletions plugins/providers/maxcompute/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,32 @@ import (
sts "github.com/goto/guardian/pkg/stsClient"
"github.com/goto/guardian/utils"
"golang.org/x/net/context"

openapiV2 "github.com/alibabacloud-go/darabonba-openapi/v2/client"
)

//go:generate mockery --name=encryptor --exported --with-expecter
type encryptor interface {
domain.Crypto
}

type ODPSClient struct {
client *odps.Odps
stsClientExist bool
}

type RestClient struct {
client *maxcompute.Client
stsClientExist bool
}

type provider struct {
pv.UnimplementedClient
pv.PermissionManager
typeName string
encryptor encryptor
restClients map[string]*maxcompute.Client
odpsClients map[string]*odps.Odps
restClients map[string]RestClient
odpsClients map[string]ODPSClient
sts *sts.Sts
logger log.Logger
mu sync.Mutex
Expand All @@ -46,8 +58,8 @@ func New(
return &provider{
typeName: typeName,
encryptor: encryptor,
restClients: make(map[string]*maxcompute.Client),
odpsClients: make(map[string]*odps.Odps),
restClients: make(map[string]RestClient),
odpsClients: make(map[string]ODPSClient),
sts: sts.NewSTS(),

logger: logger,
Expand Down Expand Up @@ -352,25 +364,34 @@ func (p *provider) getCreds(pc *domain.ProviderConfig) (*credentials, error) {
}

func (p *provider) getRestClient(pc *domain.ProviderConfig) (*maxcompute.Client, error) {
if client, ok := p.restClients[pc.URN]; ok {
if p.sts.IsSTSTokenValid(pc.URN) {
return client, nil
}
}

creds, err := p.getCreds(pc)
if err != nil {
return nil, err
}

stsClient, err := p.sts.GetSTSClient(pc.URN, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID)
if err != nil {
return nil, err
ramRole, stsClientID := p.getRamRoleAndStsClientID("rest", creds, "")
if restClient, ok := p.getCachedRestClient(ramRole, stsClientID, pc.URN); ok {
return restClient, nil
}

clientConfig, err := sts.AssumeRole(stsClient, creds.RAMRole, pc.URN)
if err != nil {
return nil, err
var clientConfig *openapiV2.Config
if creds.RAMRole != "" {
stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID)
if err != nil {
return nil, err
}

clientConfig, err = sts.AssumeRole(stsClient, creds.RAMRole, pc.URN, creds.RegionID)
if err != nil {
return nil, err
}
} else {
endpoint := fmt.Sprintf("http://service.%s.maxcompute.aliyun.com/api", creds.RegionID)
clientConfig = &openapiV2.Config{
AccessKeyId: &creds.AccessKeyID,
AccessKeySecret: &creds.AccessKeySecret,
Endpoint: &endpoint,
}
}

restClient, err := maxcompute.NewClient(clientConfig)
Expand All @@ -379,7 +400,11 @@ func (p *provider) getRestClient(pc *domain.ProviderConfig) (*maxcompute.Client,
}

p.mu.Lock()
p.restClients[pc.URN] = restClient
if creds.RAMRole != "" {
p.restClients[creds.RAMRole] = RestClient{client: restClient, stsClientExist: true}
} else {
p.restClients[pc.URN] = RestClient{client: restClient}
}
p.mu.Unlock()
return restClient, nil
}
Expand All @@ -391,33 +416,20 @@ func (p *provider) getOdpsClient(pc *domain.ProviderConfig, ramRoleFromAppeal st
}

// getting client from memory cache
var ramRole string
switch {
case ramRoleFromAppeal != "":
ramRole = ramRoleFromAppeal
if c, ok := p.odpsClients[ramRoleFromAppeal]; ok && p.sts.IsSTSTokenValid(ramRoleFromAppeal) {
return c, nil
}
case creds.RAMRole != "":
ramRole = creds.RAMRole
if c, ok := p.odpsClients[pc.URN]; ok && p.sts.IsSTSTokenValid(creds.RAMRole) {
return c, nil
}
default:
if c, ok := p.odpsClients[pc.URN]; ok {
return c, nil
}
ramRole, stsClientID := p.getRamRoleAndStsClientID("odps", creds, ramRoleFromAppeal)
if odpsClient, ok := p.getCachedOdpsClient(ramRole, stsClientID, pc.URN); ok {
return odpsClient, nil
}

// initialize new client
var acc account.Account
if ramRole != "" {
stsClient, err := p.sts.GetSTSClient(ramRoleFromAppeal, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID)
stsClient, err := p.sts.GetSTSClient(stsClientID, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID)
if err != nil {
return nil, err
}

clientConfig, err := sts.AssumeRole(stsClient, ramRoleFromAppeal, pc.URN)
clientConfig, err := sts.AssumeRole(stsClient, ramRole, pc.URN, creds.RegionID)
if err != nil {
return nil, err
}
Expand All @@ -430,15 +442,55 @@ func (p *provider) getOdpsClient(pc *domain.ProviderConfig, ramRoleFromAppeal st

p.mu.Lock()
if ramRoleFromAppeal != "" {
p.odpsClients[ramRoleFromAppeal] = client
p.odpsClients[ramRoleFromAppeal] = ODPSClient{client: client, stsClientExist: true}
} else {
p.odpsClients[pc.URN] = client
p.odpsClients[pc.URN] = ODPSClient{client: client}
}
p.mu.Unlock()

return client, nil
}

func (p *provider) getRamRoleAndStsClientID(clientType string, creds *credentials, ramRoleFromAppeal string) (string, string) {
var ramRole string
switch {
case ramRoleFromAppeal != "":
ramRole = ramRoleFromAppeal
case creds.RAMRole != "":
ramRole = creds.RAMRole
}
stsClientID := clientType + "-" + ramRole
return ramRole, stsClientID
}

func (p *provider) getCachedOdpsClient(ramRole, stsClientID, urn string) (*odps.Odps, bool) {
if c, ok := p.odpsClients[ramRole]; ok {
if c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) {
return c.client, true
}
return c.client, true
}

if c, ok := p.odpsClients[urn]; ok {
return c.client, true
}

return nil, false
}

func (p *provider) getCachedRestClient(ramRole, stsClientID, urn string) (*maxcompute.Client, bool) {
c, ok := p.restClients[ramRole]
if ok && c.stsClientExist && p.sts.IsSTSTokenValid(stsClientID) {
return c.client, true
}

if c, ok := p.restClients[urn]; ok {
return c.client, true
}

return nil, false
}

func getParametersFromGrant[T any](g domain.Grant, key string) (T, bool, error) {
var value T
if g.Appeal == nil {
Expand Down

0 comments on commit d094e7f

Please sign in to comment.