diff --git a/examples/rp/config.go b/examples/rp/config.go index 16c29f3..e29e7d9 100644 --- a/examples/rp/config.go +++ b/examples/rp/config.go @@ -10,15 +10,15 @@ import ( ) type config struct { - EntityID string `yaml:"entity_id"` - TrustAnchors pkg.TrustAnchors `yaml:"trust_anchors"` - AuthorityHints []string `yaml:"authority_hints"` - OrganisationName string `yaml:"organisation_name"` - ServerAddr string `yaml:"server_addr"` - KeyStorage string `yaml:"key_storage"` - OnlyAutomaticOPs bool `yaml:"filter_to_automatic_ops"` - EnableDebugLog bool `yaml:"enable_debug_log"` - TrustMarks []pkg.TrustMarkInfo `yaml:"trust_marks"` + EntityID string `yaml:"entity_id"` + TrustAnchors pkg.TrustAnchors `yaml:"trust_anchors"` + AuthorityHints []string `yaml:"authority_hints"` + OrganisationName string `yaml:"organisation_name"` + ServerAddr string `yaml:"server_addr"` + KeyStorage string `yaml:"key_storage"` + OnlyAutomaticOPs bool `yaml:"filter_to_automatic_ops"` + EnableDebugLog bool `yaml:"enable_debug_log"` + TrustMarks []*pkg.EntityConfigurationTrustMarkConfig `yaml:"trust_marks"` } var conf *config @@ -45,4 +45,9 @@ func mustLoadConfig() { if conf.EnableDebugLog { pkg.EnableDebugLogging() } + for _, c := range conf.TrustMarks { + if err = c.Verify(conf.EntityID); err != nil { + log.Fatal(err) + } + } } diff --git a/examples/ta/config/config.go b/examples/ta/config/config.go index 55cb3e9..fc30811 100644 --- a/examples/ta/config/config.go +++ b/examples/ta/config/config.go @@ -16,21 +16,21 @@ import ( // Config holds configuration for the entity type Config struct { - ServerPort int `yaml:"server_port"` - EntityID string `yaml:"entity_id"` - AuthorityHints []string `yaml:"authority_hints"` - MetadataPolicyFile string `yaml:"metadata_policy_file"` - MetadataPolicy *pkg.MetadataPolicies `yaml:"-"` - SigningKeyFile string `yaml:"signing_key_file"` - ConfigurationLifetime int64 `yaml:"configuration_lifetime"` - OrganizationName string `yaml:"organization_name"` - DataLocation string `yaml:"data_location"` - ReadableStorage bool `yaml:"human_readable_storage"` - Endpoints Endpoints `yaml:"endpoints"` - TrustMarkSpecs []extendedTrustMarkSpec `yaml:"trust_mark_specs"` - TrustMarks []pkg.TrustMarkInfo `yaml:"trust_marks"` - TrustMarkIssuers pkg.AllowedTrustMarkIssuers `yaml:"trust_mark_issuers"` - TrustMarkOwners pkg.TrustMarkOwners `yaml:"trust_mark_owners"` + ServerPort int `yaml:"server_port"` + EntityID string `yaml:"entity_id"` + AuthorityHints []string `yaml:"authority_hints"` + MetadataPolicyFile string `yaml:"metadata_policy_file"` + MetadataPolicy *pkg.MetadataPolicies `yaml:"-"` + SigningKeyFile string `yaml:"signing_key_file"` + ConfigurationLifetime int64 `yaml:"configuration_lifetime"` + OrganizationName string `yaml:"organization_name"` + DataLocation string `yaml:"data_location"` + ReadableStorage bool `yaml:"human_readable_storage"` + Endpoints Endpoints `yaml:"endpoints"` + TrustMarkSpecs []extendedTrustMarkSpec `yaml:"trust_mark_specs"` + TrustMarks []*pkg.EntityConfigurationTrustMarkConfig `yaml:"trust_marks"` + TrustMarkIssuers pkg.AllowedTrustMarkIssuers `yaml:"trust_mark_issuers"` + TrustMarkOwners pkg.TrustMarkOwners `yaml:"trust_mark_owners"` } type extendedTrustMarkSpec struct { @@ -132,4 +132,9 @@ func Load(filename string) { log.Fatal(err) } } + for _, tmc := range c.TrustMarks { + if err = tmc.Verify(c.EntityID); err != nil { + log.Fatal(err) + } + } } diff --git a/go.mod b/go.mod index aa48c54..891eb3b 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/adam-hanna/arrayOperations v1.0.1 github.com/dgraph-io/badger/v4 v4.5.0 github.com/fatih/structs v1.1.0 + github.com/go-resty/resty/v2 v2.16.2 github.com/gofiber/fiber/v2 v2.52.5 github.com/google/uuid v1.6.0 + github.com/jarcoal/httpmock v1.3.1 github.com/lestrrat-go/jwx v1.2.30 github.com/luci/go-render v0.0.0-20160219211803-9a04cc21af0f github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 15bffac..0ff3158 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/go-resty/resty/v2 v2.16.2 h1:CpRqTjIzq/rweXUt9+GxzzQdlkqMdt8Lm/fuK/CAbAg= +github.com/go-resty/resty/v2 v2.16.2/go.mod h1:0fHAoK7JoBy/Ch36N8VFeMsK7xQOHhvWaC3iOktwmIU= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/gofiber/fiber/v2 v2.52.5 h1:tWoP1MJQjGEe4GB5TUGOi7P2E0ZMMRx5ZTG4rT+yGMo= @@ -66,6 +68,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A= @@ -90,6 +94,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -152,6 +158,8 @@ golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/go.work.sum b/go.work.sum index 157c62b..68511d1 100644 --- a/go.work.sum +++ b/go.work.sum @@ -669,6 +669,8 @@ github.com/lyft/protoc-gen-star v0.6.1 h1:erE0rdztuaDq3bpGifD95wfoPrSZc95nGA6tbi github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= github.com/lyft/protoc-gen-star/v2 v2.0.1 h1:keaAo8hRuAT0O3DfJ/wM3rufbAjGeJ1lAtWZHDjKGB0= github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WVYF7/opLeUgcQs/o= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= github.com/onsi/ginkgo v1.14.1 h1:jMU0WaQrP0a/YAEq8eJmJKjBoMs+pClEr1vDMlM/Do4= github.com/onsi/gomega v1.10.2 h1:aY/nuoWlKJud2J6U0E3NWsjlg+0GtwXxgEqthRdzlcs= @@ -903,6 +905,8 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..9df3297 --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,4 @@ +package constants + +// FederationSuffix is the well-known openid-federation suffix +const FederationSuffix = "/.well-known/openid-federation" diff --git a/internal/http.go b/internal/http.go deleted file mode 100644 index af8710e..0000000 --- a/internal/http.go +++ /dev/null @@ -1,74 +0,0 @@ -package internal - -import ( - "io" - "net/http" - "net/url" - "strings" - - "github.com/pkg/errors" -) - -const federationSuffix = "/.well-known/openid-federation" - -// EntityStatementObtainer is interface for a type obtaining entity configurations and entity statements -type EntityStatementObtainer interface { - GetEntityConfiguration(entityID string) ([]byte, error) - FetchEntityStatement(fetchEndpoint, subID, issID string) ([]byte, error) - ListEntities(listEndpoint, entityType string) ([]byte, error) -} - -type defaultHttpEntityStatementObtainer struct{} - -// DefaultHttpEntityStatementObtainer is the default EntityStatementObtainer to obtain entity statements through http -var DefaultHttpEntityStatementObtainer defaultHttpEntityStatementObtainer - -// GetEntityConfiguration implements the EntityStatementObtainer interface -// It returns the decoded entity configuration for a given entityID -func (defaultHttpEntityStatementObtainer) GetEntityConfiguration(entityID string) ([]byte, error) { - uri := strings.TrimSuffix(entityID, "/") + federationSuffix - Logf("Obtaining entity configuration from %+q", uri) - res, err := http.Get(uri) - if err != nil { - return nil, err - } - if status := res.StatusCode; status >= 300 { - return nil, errors.Errorf("could not obtain entity statement, received status code %d", status) - } - return io.ReadAll(res.Body) -} - -// FetchEntityStatement implements the EntityStatementObtainer interface -// It fetches and returns the decoded entity statement about a given entityID issued by issID -func (defaultHttpEntityStatementObtainer) FetchEntityStatement(fetchEndpoint, subID, issID string) ([]byte, error) { - uri := fetchEndpoint - params := url.Values{} - params.Add("sub", subID) - params.Add("iss", issID) - uri += "?" + params.Encode() - res, err := http.Get(uri) - if err != nil { - return nil, err - } - if status := res.StatusCode; status >= 300 { - return nil, errors.Errorf("could not obtain entity statement, received status code %d", status) - } - return io.ReadAll(res.Body) -} - -// ListEntities implements the EntityStatementObtainer interface -// It fetches and returns the entity list from the passed listendpoint -func (defaultHttpEntityStatementObtainer) ListEntities(listEndpoint, entityType string) ([]byte, error) { - uri := listEndpoint - params := url.Values{} - params.Add("entity_type", entityType) - uri += "?" + params.Encode() - res, err := http.Get(uri) - if err != nil { - return nil, err - } - if status := res.StatusCode; status >= 300 { - return nil, errors.Errorf("could not obtain entity statement, received status code %d", status) - } - return io.ReadAll(res.Body) -} diff --git a/internal/http/http.go b/internal/http/http.go new file mode 100644 index 0000000..8d3c404 --- /dev/null +++ b/internal/http/http.go @@ -0,0 +1,69 @@ +package http + +import ( + "fmt" + "net/url" + "time" + + "github.com/go-resty/resty/v2" + "github.com/pkg/errors" +) + +var client *resty.Client + +func init() { + client = resty.New() + client.SetCookieJar(nil) + // client.SetDisableWarn(true) + client.SetRetryCount(2) + client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(10)) + client.SetTimeout(20 * time.Second) +} + +// HttpError is a type for returning the server's error response including its status code +type HttpError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + Status int +} + +// Err returns an error including the server's error response +func (e *HttpError) Err() error { + errStr := fmt.Sprintf("http error response: %d: %s", e.Status, e.Error) + if e.ErrorDescription != "" { + errStr += ": " + e.ErrorDescription + } + return errors.New(errStr) + +} + +// Do returns the client, so it can be used to do requests +func Do() *resty.Client { + return client +} + +// Get performs a http GET request and parses the response into the given interface{} +func Get(url string, params url.Values, res interface{}) (*resty.Response, *HttpError, error) { + resp, err := client.R().SetQueryParamsFromValues(params).SetError(&HttpError{}).SetResult(res).Get(url) + if err != nil { + return nil, nil, errors.WithStack(err) + } + if errRes, ok := resp.Error().(*HttpError); ok && errRes != nil && errRes.Error != "" { + errRes.Status = resp.RawResponse.StatusCode + return nil, errRes, nil + } + return resp, nil, nil +} + +// Post performs a http POST request and parses the response into the given interface{} +func Post(url string, req interface{}, res interface{}) (*resty.Response, *HttpError, error) { + resp, err := client.R().SetBody(req).SetError(&HttpError{}).SetResult(res).Post(url) + if err != nil { + return nil, nil, errors.WithStack(err) + } + if errRes, ok := resp.Error().(*HttpError); ok && errRes != nil && errRes.Error != "" { + errRes.Status = resp.RawResponse.StatusCode + return nil, errRes, nil + } + return resp, nil, nil +} diff --git a/internal/jwx/jws.go b/internal/jwx/jws.go index 405e5cd..5fe126d 100644 --- a/internal/jwx/jws.go +++ b/internal/jwx/jws.go @@ -6,11 +6,13 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jws" + "github.com/lestrrat-go/jwx/jwt" "github.com/pkg/errors" "github.com/vmihailenco/msgpack/v5" "github.com/zachmann/go-oidfed/internal/utils" myjwk "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) // ParsedJWT is a type extending jws.Message by holding the original jwt @@ -113,3 +115,13 @@ func SignPayload(payload []byte, signingAlg jwa.SignatureAlgorithm, key crypto.S } return jws.Sign(payload, signingAlg, key, jws.WithHeaders(headers)) } + +// GetExp returns the expiration of a jwt +func GetExp(bytes []byte) (exp unixtime.Unixtime, err error) { + parsed, err := jwt.Parse(bytes) + if err != nil { + err = errors.WithStack(err) + return + } + return unixtime.Unixtime{Time: parsed.Expiration()}, nil +} diff --git a/pkg/discovery.go b/pkg/discovery.go index bfe3822..9662b84 100644 --- a/pkg/discovery.go +++ b/pkg/discovery.go @@ -1,9 +1,12 @@ package pkg import ( - "encoding/json" + "net/url" + + "github.com/pkg/errors" "github.com/zachmann/go-oidfed/internal" + "github.com/zachmann/go-oidfed/internal/http" "github.com/zachmann/go-oidfed/internal/utils" ) @@ -183,15 +186,20 @@ var OPDiscoveryFilterExplicitRegistration opDiscoveryFilterExplicitRegistration var OPDiscoveryFilterAutomaticRegistration opDiscoveryFilterAutomaticRegistration func fetchList(listEndpoint, entityType string) ([]string, error) { - body, err := entityStatementObtainer.ListEntities(listEndpoint, entityType) + params := url.Values{} + params.Add("entity_type", entityType) + resp, errRes, err := http.Get(listEndpoint, params, &[]string{}) if err != nil { return nil, err } - var entities []string - if err = json.Unmarshal(body, &entities); err != nil { - return nil, err + if errRes != nil { + return nil, errRes.Err() + } + entities, ok := resp.Result().(*[]string) + if !ok || entities == nil { + return nil, errors.New("unexpected response type") } - return entities, nil + return *entities, nil } // OPDiscoveryFilterSupportedGrantTypesIncludes returns an OPDiscoveryFilter that filters to OPs that support the diff --git a/pkg/entitystatement.go b/pkg/entitystatement.go index e436c9c..44b3edd 100644 --- a/pkg/entitystatement.go +++ b/pkg/entitystatement.go @@ -11,6 +11,7 @@ import ( "github.com/zachmann/go-oidfed/internal/jwx" "github.com/zachmann/go-oidfed/internal/utils" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" "github.com/fatih/structs" ) @@ -66,8 +67,8 @@ func (e *EntityStatement) UnmarshalMsgpack(data []byte) error { type EntityStatementPayload struct { Issuer string `json:"iss"` Subject string `json:"sub"` - IssuedAt Unixtime `json:"iat"` - ExpiresAt Unixtime `json:"exp"` + IssuedAt unixtime.Unixtime `json:"iat"` + ExpiresAt unixtime.Unixtime `json:"exp"` JWKS jwk.JWKS `json:"jwks"` Audience string `json:"aud,omitempty"` AuthorityHints []string `json:"authority_hints,omitempty"` @@ -86,7 +87,7 @@ type EntityStatementPayload struct { // TimeValid checks if the EntityStatementPayload is already valid and not yet expired. func (e EntityStatementPayload) TimeValid() bool { - return verifyTime(&e.IssuedAt, &e.ExpiresAt) == nil + return unixtime.VerifyTime(&e.IssuedAt, &e.ExpiresAt) == nil } func extraMarshalHelper(explicitFields []byte, extra map[string]interface{}) ([]byte, error) { diff --git a/pkg/entitystatement_test.go b/pkg/entitystatement_test.go index 377bf85..3c76224 100644 --- a/pkg/entitystatement_test.go +++ b/pkg/entitystatement_test.go @@ -15,6 +15,7 @@ import ( "github.com/vmihailenco/msgpack/v5" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type marshalData struct { @@ -32,8 +33,8 @@ var entitystatementMarshalData = map[string]marshalData{ Object: EntityStatementPayload{ Issuer: "issuer", Subject: "subject", - IssuedAt: Unixtime{time.Unix(100, 0)}, - ExpiresAt: Unixtime{time.Unix(200, 0)}, + IssuedAt: unixtime.Unixtime{Time: time.Unix(100, 0)}, + ExpiresAt: unixtime.Unixtime{Time: time.Unix(200, 0)}, Audience: "aud", AuthorityHints: []string{ "hint1", @@ -80,8 +81,8 @@ var entitystatementMarshalData = map[string]marshalData{ "extra fields": { Data: []byte(`{"exp":200,"extra-field":"value","foo":["bar"],"iat":100,"iss":"issuer","jwks":null,"sub":"subject"}`), Object: EntityStatementPayload{ - IssuedAt: Unixtime{time.Unix(100, 0)}, - ExpiresAt: Unixtime{time.Unix(200, 0)}, + IssuedAt: unixtime.Unixtime{Time: time.Unix(100, 0)}, + ExpiresAt: unixtime.Unixtime{Time: time.Unix(200, 0)}, Issuer: "issuer", Subject: "subject", Extra: map[string]interface{}{ diff --git a/pkg/fedentities/fedentity.go b/pkg/fedentities/fedentity.go index 0e2d5cc..1f9e89e 100644 --- a/pkg/fedentities/fedentity.go +++ b/pkg/fedentities/fedentity.go @@ -19,6 +19,7 @@ import ( "github.com/zachmann/go-oidfed/pkg/cache" "github.com/zachmann/go-oidfed/pkg/constants" "github.com/zachmann/go-oidfed/pkg/fedentities/storage" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) const entityConfigurationCachePeriod = 5 * time.Second @@ -141,8 +142,8 @@ func (fed FedEntity) CreateSubordinateStatement(subordinate *storage.Subordinate return pkg.EntityStatementPayload{ Issuer: fed.FederationEntity.EntityID, Subject: subordinate.EntityID, - IssuedAt: pkg.Unixtime{Time: now}, - ExpiresAt: pkg.Unixtime{Time: now.Add(time.Duration(fed.SubordinateStatementLifetime) * time.Second)}, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Duration(fed.SubordinateStatementLifetime) * time.Second)}, SourceEndpoint: fed.Metadata.FederationEntity.FederationFetchEndpoint, JWKS: subordinate.JWKS, Metadata: subordinate.Metadata, diff --git a/pkg/fedentities/resolve.go b/pkg/fedentities/resolve.go index 6182a3c..95bae9a 100644 --- a/pkg/fedentities/resolve.go +++ b/pkg/fedentities/resolve.go @@ -7,6 +7,7 @@ import ( "github.com/zachmann/go-oidfed/pkg" "github.com/zachmann/go-oidfed/pkg/constants" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type resolveRequest struct { @@ -60,7 +61,7 @@ func (fed *FedEntity) AddResolveEndpoint(endpoint EndpointConf) { res := pkg.ResolveResponse{ Issuer: fed.FederationEntity.EntityID, Subject: req.Subject, - IssuedAt: pkg.Unixtime{Time: time.Now()}, + IssuedAt: unixtime.Unixtime{Time: time.Now()}, ExpiresAt: selectedChain.ExpiresAt(), Metadata: metadata, TrustMarks: verifiedTMs, diff --git a/pkg/federation.go b/pkg/federation.go index 509ef29..a0781f1 100644 --- a/pkg/federation.go +++ b/pkg/federation.go @@ -10,6 +10,7 @@ import ( "github.com/zachmann/go-oidfed/internal" "github.com/zachmann/go-oidfed/pkg/cache" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) // FederationEntity is a type for an entity participating in federations. @@ -22,7 +23,7 @@ type FederationEntity struct { ConfigurationLifetime int64 *EntityStatementSigner jwks jwk.JWKS - TrustMarks []TrustMarkInfo + TrustMarks []*EntityConfigurationTrustMarkConfig TrustMarkIssuers AllowedTrustMarkIssuers TrustMarkOwners TrustMarkOwners } @@ -75,15 +76,29 @@ func NewFederationLeaf( // EntityConfigurationPayload returns an EntityStatementPayload for this FederationEntity func (f FederationEntity) EntityConfigurationPayload() *EntityStatementPayload { now := time.Now() + var tms []TrustMarkInfo + for _, tmc := range f.TrustMarks { + tm, err := tmc.TrustMarkJWT() + if err != nil { + internal.Log(err.Error()) + continue + } + tms = append( + tms, TrustMarkInfo{ + ID: tmc.TrustMarkID, + TrustMarkJWT: tm, + }, + ) + } return &EntityStatementPayload{ Issuer: f.EntityID, Subject: f.EntityID, - IssuedAt: Unixtime{now}, - ExpiresAt: Unixtime{now.Add(time.Second * time.Duration(f.ConfigurationLifetime))}, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(f.ConfigurationLifetime))}, JWKS: f.jwks, AuthorityHints: f.AuthorityHints, Metadata: f.Metadata, - TrustMarks: f.TrustMarks, + TrustMarks: tms, TrustMarkIssuers: f.TrustMarkIssuers, TrustMarkOwners: f.TrustMarkOwners, } diff --git a/pkg/mock_authority.go b/pkg/mock_authority.go index 01aa7e1..0db767a 100644 --- a/pkg/mock_authority.go +++ b/pkg/mock_authority.go @@ -11,6 +11,7 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type mockAuthority struct { @@ -22,6 +23,22 @@ type mockAuthority struct { subordinates []mockSubordinateInfo } +func (a mockAuthority) EntityConfigurationJWT() ([]byte, error) { + return a.EntityStatementSigner.JWT(a.EntityStatementPayload()) +} + +func (a mockAuthority) FetchResponse(sub string) ([]byte, error) { + pay := a.SubordinateEntityStatementPayload(sub) + return a.EntityStatementSigner.JWT(pay) +} + +func (a mockAuthority) Subordinates(_ string) (subordinates []string, err error) { + for _, sub := range a.subordinates { + subordinates = append(subordinates, sub.entityID) + } + return +} + type mockSubordinateInfo struct { entityID string jwks jwk.JWKS @@ -32,7 +49,7 @@ type mockSubordinate interface { AddAuthority(authorityID string) } -func newMockAuthority(entityID string, data EntityStatementPayload) mockAuthority { +func newMockAuthority(entityID string, data EntityStatementPayload) *mockAuthority { sk, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { panic(err) @@ -40,7 +57,7 @@ func newMockAuthority(entityID string, data EntityStatementPayload) mockAuthorit data.JWKS = jwk.KeyToJWKS(sk.Public(), jwa.ES512) data.Issuer = entityID data.Subject = entityID - a := mockAuthority{ + a := &mockAuthority{ EntityID: entityID, FetchEndpoint: fmt.Sprintf("%s/fetch", entityID), ListEndpoint: fmt.Sprintf("%s/list", entityID), @@ -56,14 +73,19 @@ func newMockAuthority(entityID string, data EntityStatementPayload) mockAuthorit a.data.Metadata.FederationEntity.OrganizationName = fmt.Sprintf("Organization %d", mathrand.Int()%100) a.data.Metadata.FederationEntity.FederationFetchEndpoint = a.FetchEndpoint a.data.Metadata.FederationEntity.FederationListEndpoint = a.ListEndpoint + + mockEntityConfiguration(a.EntityID, a) + mockFetchEndpoint(a.FetchEndpoint, a) + mockListEndpoint(a.ListEndpoint, a) + return a } func (a mockAuthority) EntityStatementPayload() *EntityStatementPayload { now := time.Now() payload := a.data - payload.IssuedAt = Unixtime{now} - payload.ExpiresAt = Unixtime{now.Add(time.Second * time.Duration(mockStmtLifetime))} + payload.IssuedAt = unixtime.Unixtime{Time: now} + payload.ExpiresAt = unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(mockStmtLifetime))} return &payload } @@ -78,8 +100,8 @@ func (a mockAuthority) SubordinateEntityStatementPayload(subID string) EntitySta payload := EntityStatementPayload{ Issuer: a.EntityID, Subject: subID, - IssuedAt: Unixtime{now}, - ExpiresAt: Unixtime{now.Add(time.Second * time.Duration(mockStmtLifetime))}, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(mockStmtLifetime))}, JWKS: jwks, MetadataPolicy: a.data.MetadataPolicy, MetadataPolicyCrit: a.data.MetadataPolicyCrit, diff --git a/pkg/mock_http.go b/pkg/mock_http.go index 9283299..a23ddcb 100644 --- a/pkg/mock_http.go +++ b/pkg/mock_http.go @@ -1,169 +1,59 @@ package pkg import ( - "encoding/json" - "time" + "net/http" + "strings" - "github.com/pkg/errors" -) - -var mockupData mockHttp + "github.com/jarcoal/httpmock" -type mockHttp struct { - entityConfigurations map[string]func() []byte - entityListings map[string]mockList - entityStatements map[string]func(iss, sub string) []byte -} -type mockList []struct { - EntityID string - EntityType string -} - -func (d *mockHttp) addEntityConfiguration(entityid string, fnc func() []byte) { - if d.entityConfigurations == nil { - d.entityConfigurations = make(map[string]func() []byte) - } - d.entityConfigurations[entityid] = fnc -} -func (d *mockHttp) addEntityStatement(fetchEndpoint string, fnc func(iss, sub string) []byte) { - if d.entityStatements == nil { - d.entityStatements = make(map[string]func(iss, sub string) []byte) - } - d.entityStatements[fetchEndpoint] = fnc -} -func (d *mockHttp) addToListEndpoint(listEndpoint, entityID, entityType string) { - if d.entityListings == nil { - d.entityListings = make(map[string]mockList) - } - listing := d.entityListings[listEndpoint] - for _, l := range listing { - if l.EntityID == entityID { - return - } - } - listing = append( - listing, struct { - EntityID string - EntityType string - }{ - EntityID: entityID, - EntityType: entityType, - }, - ) - d.entityListings[listEndpoint] = listing -} + "github.com/zachmann/go-oidfed/internal/constants" +) -func (d mockHttp) GetEntityConfiguration(entityID string) ([]byte, error) { - fnc, ok := d.entityConfigurations[entityID] - if !ok { - return nil, errors.New("entity configuration not found") - } - return fnc(), nil +type mockedEntityConfigurationSigner interface { + EntityConfigurationJWT() ([]byte, error) } -func (d mockHttp) FetchEntityStatement(fetchEndpoint, subID, issID string) ([]byte, error) { - fetch, ok := d.entityStatements[fetchEndpoint] - if !ok { - return nil, errors.New("entity statement not found") - } - return fetch(issID, subID), nil +type mockedFetchResponder interface { + FetchResponse(sub string) ([]byte, error) } - -func (d mockHttp) ListEntities(listEndpoint, entityType string) ([]byte, error) { - var entities []string - listing := d.entityListings[listEndpoint] - for _, l := range listing { - if entityType == "" || l.EntityType == "" || entityType == l.EntityType { - entities = append(entities, l.EntityID) - } - } - return json.Marshal(entities) +type mockedSubordinateLister interface { + Subordinates(entityType string) ([]string, error) } -func (d *mockHttp) AddRP(r mockRP) { - d.addEntityConfiguration( - r.EntityID, func() []byte { - data, err := r.JWT(r.EntityStatementPayload()) - if err != nil { - panic(err) - } - return data - }, - ) -} -func (d *mockHttp) AddOP(o mockOP) { - d.addEntityConfiguration( - o.EntityID, func() []byte { - data, err := o.JWT(o.EntityStatementPayload()) - if err != nil { - panic(err) - } - return data - }, - ) -} -func (d *mockHttp) AddProxy(p mockProxy) { - d.addEntityConfiguration( - p.EntityID, func() []byte { - data, err := p.JWT(p.EntityStatementPayload()) +func mockEntityConfiguration(entityID string, signer mockedEntityConfigurationSigner) { + uri := strings.TrimSuffix(entityID, "/") + constants.FederationSuffix + httpmock.RegisterResponder( + "GET", uri, func(_ *http.Request) (*http.Response, error) { + res, err := signer.EntityConfigurationJWT() if err != nil { - panic(err) + return nil, err } - return data + return httpmock.NewBytesResponse(200, res), nil }, ) } -func (d *mockHttp) AddTMI(tmi mockTMI) { - d.addEntityConfiguration( - tmi.EntityID, func() []byte { - now := time.Now() - payload := EntityStatementPayload{ - Issuer: tmi.EntityID, - Subject: tmi.EntityID, - AuthorityHints: tmi.authorities, - IssuedAt: Unixtime{ - Time: now, - }, - ExpiresAt: Unixtime{ - Time: now.Add(defaultEntityConfigurationLifetime), - }, - JWKS: tmi.jwks, - Metadata: &Metadata{ - FederationEntity: &FederationEntityMetadata{ - FederationTrustMarkStatusEndpoint: "TODO", //TODO - OrganizationName: "TMI Organization", - }, - }, - } - data, err := tmi.TrustMarkSigner.JWT(payload) + +func mockFetchEndpoint(fetchEndpoint string, mocker mockedFetchResponder) { + httpmock.RegisterResponder( + "GET", fetchEndpoint, func(request *http.Request) (*http.Response, error) { + sub := request.URL.Query().Get("sub") + res, err := mocker.FetchResponse(sub) if err != nil { - panic(err) + return nil, err } - return data + return httpmock.NewBytesResponse(200, res), nil }, ) } -func (d *mockHttp) AddAuthority(a mockAuthority) { - d.addEntityConfiguration( - a.EntityID, func() []byte { - data, err := a.EntityStatementSigner.JWT(a.EntityStatementPayload()) - if err != nil { - panic(err) - } - return data - }, - ) - d.addEntityStatement( - a.FetchEndpoint, func(iss, sub string) []byte { - pay := a.SubordinateEntityStatementPayload(sub) - data, err := a.EntityStatementSigner.JWT(pay) +func mockListEndpoint(listEndpoint string, mocker mockedSubordinateLister) { + httpmock.RegisterResponder( + "GET", listEndpoint, func(request *http.Request) (*http.Response, error) { + entityType := request.URL.Query().Get("entity_type") + entities, err := mocker.Subordinates(entityType) if err != nil { - panic(err) + return nil, err } - return data + return httpmock.NewJsonResponse(200, entities) }, ) - for _, sub := range a.subordinates { - d.addToListEndpoint(a.ListEndpoint, sub.entityID, "") - } } diff --git a/pkg/mock_op.go b/pkg/mock_op.go index 2f0bbdb..97d0cff 100644 --- a/pkg/mock_op.go +++ b/pkg/mock_op.go @@ -11,6 +11,7 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type mockOP struct { @@ -21,18 +22,23 @@ type mockOP struct { metadata *OpenIDProviderMetadata } -func newMockOP(entityID string, metadata *OpenIDProviderMetadata) mockOP { +func (op mockOP) EntityConfigurationJWT() ([]byte, error) { + return op.EntityStatementSigner.JWT(op.EntityStatementPayload()) +} + +func newMockOP(entityID string, metadata *OpenIDProviderMetadata) *mockOP { sk, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { panic(err) } metadata.Issuer = entityID - o := mockOP{ + o := &mockOP{ EntityID: entityID, metadata: metadata, EntityStatementSigner: NewEntityStatementSigner(sk, jwa.ES512), jwks: jwk.KeyToJWKS(sk.Public(), jwa.ES512), } + mockEntityConfiguration(o.EntityID, o) return o } @@ -42,8 +48,8 @@ func (op mockOP) EntityStatementPayload() EntityStatementPayload { payload := EntityStatementPayload{ Issuer: op.EntityID, Subject: op.EntityID, - IssuedAt: Unixtime{now}, - ExpiresAt: Unixtime{now.Add(time.Second * time.Duration(mockStmtLifetime))}, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(mockStmtLifetime))}, JWKS: op.jwks, Audience: "", AuthorityHints: op.authorities, diff --git a/pkg/mock_proxy.go b/pkg/mock_proxy.go index 661ddbe..e978449 100644 --- a/pkg/mock_proxy.go +++ b/pkg/mock_proxy.go @@ -11,6 +11,7 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type mockProxy struct { @@ -25,22 +26,27 @@ type mockProxy struct { func newMockProxy( entityID string, rp *OpenIDRelyingPartyMetadata, op *OpenIDProviderMetadata, -) mockProxy { +) *mockProxy { sk, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { panic(err) } op.Issuer = entityID - p := mockProxy{ + p := &mockProxy{ EntityID: entityID, rpMetadata: rp, opMetadata: op, EntityStatementSigner: NewEntityStatementSigner(sk, jwa.ES512), jwks: jwk.KeyToJWKS(sk.Public(), jwa.ES512), } + mockEntityConfiguration(p.EntityID, p) return p } +func (proxy mockProxy) EntityConfigurationJWT() ([]byte, error) { + return proxy.EntityStatementSigner.JWT(proxy.EntityStatementPayload()) +} + func (proxy mockProxy) EntityStatementPayload() EntityStatementPayload { now := time.Now() orgID := fmt.Sprintf("%x", md5.Sum([]byte(proxy.EntityID))) @@ -50,8 +56,8 @@ func (proxy mockProxy) EntityStatementPayload() EntityStatementPayload { payload := EntityStatementPayload{ Issuer: proxy.EntityID, Subject: proxy.EntityID, - IssuedAt: Unixtime{now}, - ExpiresAt: Unixtime{now.Add(time.Second * time.Duration(mockStmtLifetime))}, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(mockStmtLifetime))}, JWKS: proxy.jwks, Audience: "", AuthorityHints: proxy.authorities, diff --git a/pkg/mock_rp.go b/pkg/mock_rp.go index 62bcd0c..4e4ba6d 100644 --- a/pkg/mock_rp.go +++ b/pkg/mock_rp.go @@ -11,6 +11,7 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type mockRP struct { @@ -21,28 +22,33 @@ type mockRP struct { metadata *OpenIDRelyingPartyMetadata } -func newMockRP(entityID string, metadata *OpenIDRelyingPartyMetadata) mockRP { +func newMockRP(entityID string, metadata *OpenIDRelyingPartyMetadata) *mockRP { sk, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { panic(err) } - r := mockRP{ + r := &mockRP{ EntityID: entityID, metadata: metadata, EntityStatementSigner: NewEntityStatementSigner(sk, jwa.ES512), jwks: jwk.KeyToJWKS(sk.Public(), jwa.ES512), } + mockEntityConfiguration(r.EntityID, r) return r } +func (rp mockRP) EntityConfigurationJWT() ([]byte, error) { + return rp.EntityStatementSigner.JWT(rp.EntityStatementPayload()) +} + func (rp mockRP) EntityStatementPayload() EntityStatementPayload { now := time.Now() orgID := fmt.Sprintf("%x", md5.Sum([]byte(rp.EntityID))) payload := EntityStatementPayload{ Issuer: rp.EntityID, Subject: rp.EntityID, - IssuedAt: Unixtime{now}, - ExpiresAt: Unixtime{now.Add(time.Second * time.Duration(mockStmtLifetime))}, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(mockStmtLifetime))}, JWKS: rp.jwks, Audience: "", AuthorityHints: rp.authorities, diff --git a/pkg/mock_tm.go b/pkg/mock_tm.go index 6cdb930..f47d7e1 100644 --- a/pkg/mock_tm.go +++ b/pkg/mock_tm.go @@ -3,11 +3,15 @@ package pkg import ( "crypto/ecdsa" "crypto/elliptic" + "crypto/md5" "crypto/rand" + "fmt" + "time" "github.com/lestrrat-go/jwx/jwa" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) type mockTMI struct { @@ -16,6 +20,30 @@ type mockTMI struct { jwks jwk.JWKS } +func (tmi mockTMI) EntityConfigurationJWT() ([]byte, error) { + return tmi.GeneralJWTSigner.EntityStatementSigner().JWT(tmi.EntityStatementPayload()) +} + +func (tmi mockTMI) EntityStatementPayload() EntityStatementPayload { + now := time.Now() + orgID := fmt.Sprintf("%x", md5.Sum([]byte(tmi.EntityID))) + payload := EntityStatementPayload{ + Issuer: tmi.EntityID, + Subject: tmi.EntityID, + AuthorityHints: tmi.authorities, + IssuedAt: unixtime.Unixtime{Time: now}, + ExpiresAt: unixtime.Unixtime{Time: now.Add(time.Second * time.Duration(mockStmtLifetime))}, + JWKS: tmi.jwks, + Metadata: &Metadata{ + FederationEntity: &FederationEntityMetadata{ + FederationTrustMarkStatusEndpoint: "TODO", //TODO + OrganizationName: fmt.Sprintf("Organization: %s", orgID[:8]), + }, + }, + } + return payload +} + func (tmi *mockTMI) AddAuthority(authorityID string) { tmi.authorities = append(tmi.authorities, authorityID) } @@ -28,16 +56,18 @@ func newMockTrustMarkOwner(entityID string, ownedTrustMarks []OwnedTrustMark) *T return NewTrustMarkOwner(entityID, NewTrustMarkDelegationSigner(sk, jwa.ES512), ownedTrustMarks) } -func newMockTrustMarkIssuer(entityID string, trustMarkSpecs []TrustMarkSpec) mockTMI { +func newMockTrustMarkIssuer(entityID string, trustMarkSpecs []TrustMarkSpec) *mockTMI { sk, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { panic(err) } tmi := NewTrustMarkIssuer(entityID, NewTrustMarkSigner(sk, jwa.ES512), trustMarkSpecs) - return mockTMI{ + mock := &mockTMI{ TrustMarkIssuer: *tmi, jwks: jwk.KeyToJWKS(tmi.key.Public(), tmi.alg), } + mockEntityConfiguration(mock.EntityID, mock) + return mock } func (tmi mockTMI) GetSubordinateInfo() mockSubordinateInfo { diff --git a/pkg/trustchain.go b/pkg/trustchain.go index 3045ce3..40b0145 100644 --- a/pkg/trustchain.go +++ b/pkg/trustchain.go @@ -9,6 +9,7 @@ import ( "github.com/zachmann/go-oidfed/internal" "github.com/zachmann/go-oidfed/internal/utils" "github.com/zachmann/go-oidfed/pkg/cache" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) // TrustChain is a slice of *EntityStatements @@ -24,9 +25,9 @@ func (c TrustChain) hash() ([]byte, error) { } // ExpiresAt returns the expiration time of the TrustChain as a UNIX time stamp -func (c TrustChain) ExpiresAt() Unixtime { +func (c TrustChain) ExpiresAt() unixtime.Unixtime { if len(c) == 0 { - return Unixtime{} + return unixtime.Unixtime{} } exp := c[0].ExpiresAt for i := 1; i < len(c); i++ { @@ -114,6 +115,6 @@ func (c TrustChain) cacheSetMetadata(metadata *Metadata) error { } return cache.Set( cache.Key(cache.KeyTrustChainResolvedMetadata, string(hash)), metadata, - Until(c.ExpiresAt()), + unixtime.Until(c.ExpiresAt()), ) } diff --git a/pkg/trustchain_test.go b/pkg/trustchain_test.go index 72cff30..5f5dd2b 100644 --- a/pkg/trustchain_test.go +++ b/pkg/trustchain_test.go @@ -4,59 +4,85 @@ import ( "reflect" "testing" "time" + + "github.com/zachmann/go-oidfed/pkg/unixtime" ) func TestTrustChains_ExpiresAt(t *testing.T) { tests := []struct { name string chain TrustChain - expiresExpected Unixtime + expiresExpected unixtime.Unixtime }{ { name: "emtpy", chain: TrustChain{}, - expiresExpected: Unixtime{}, + expiresExpected: unixtime.Unixtime{}, }, { name: "single", chain: TrustChain{ - &EntityStatement{EntityStatementPayload: EntityStatementPayload{ExpiresAt: Unixtime{time.Unix(5, 0)}}}, + &EntityStatement{ + EntityStatementPayload: EntityStatementPayload{ + ExpiresAt: unixtime.Unixtime{ + Time: time.Unix( + 5, 0, + ), + }, + }, + }, }, - expiresExpected: Unixtime{time.Unix(5, 0)}, + expiresExpected: unixtime.Unixtime{Time: time.Unix(5, 0)}, }, { name: "first min", chain: TrustChain{ - &EntityStatement{EntityStatementPayload: EntityStatementPayload{ExpiresAt: Unixtime{time.Unix(5, 0)}}}, &EntityStatement{ EntityStatementPayload: EntityStatementPayload{ - ExpiresAt: Unixtime{time.Unix(10, 0)}, + ExpiresAt: unixtime.Unixtime{ + Time: time.Unix( + 5, 0, + ), + }, }, }, &EntityStatement{ EntityStatementPayload: EntityStatementPayload{ - ExpiresAt: Unixtime{time.Unix(100, 0)}, + ExpiresAt: unixtime.Unixtime{Time: time.Unix(10, 0)}, + }, + }, + &EntityStatement{ + EntityStatementPayload: EntityStatementPayload{ + ExpiresAt: unixtime.Unixtime{Time: time.Unix(100, 0)}, }, }, }, - expiresExpected: Unixtime{time.Unix(5, 0)}, + expiresExpected: unixtime.Unixtime{Time: time.Unix(5, 0)}, }, { name: "other min", chain: TrustChain{ &EntityStatement{ EntityStatementPayload: EntityStatementPayload{ - ExpiresAt: Unixtime{time.Unix(10, 0)}, + ExpiresAt: unixtime.Unixtime{Time: time.Unix(10, 0)}, + }, + }, + &EntityStatement{ + EntityStatementPayload: EntityStatementPayload{ + ExpiresAt: unixtime.Unixtime{ + Time: time.Unix( + 5, 0, + ), + }, }, }, - &EntityStatement{EntityStatementPayload: EntityStatementPayload{ExpiresAt: Unixtime{time.Unix(5, 0)}}}, &EntityStatement{ EntityStatementPayload: EntityStatementPayload{ - ExpiresAt: Unixtime{time.Unix(100, 0)}, + ExpiresAt: unixtime.Unixtime{Time: time.Unix(100, 0)}, }, }, }, - expiresExpected: Unixtime{time.Unix(5, 0)}, + expiresExpected: unixtime.Unixtime{Time: time.Unix(5, 0)}, }, } for _, test := range tests { diff --git a/pkg/trustchainfilter_test.go b/pkg/trustchainfilter_test.go index 99aaae4..fba4584 100644 --- a/pkg/trustchainfilter_test.go +++ b/pkg/trustchainfilter_test.go @@ -122,29 +122,20 @@ var ta2WithRemoveCrit = newMockAuthority( ) func init() { - ia1.RegisterSubordinate(&rp1) - ia2.RegisterSubordinate(&rp1) - ia1.RegisterSubordinate(&op1) - ia2.RegisterSubordinate(&op1) - ia1.RegisterSubordinate(&op3) - ia1.RegisterSubordinate(&proxy) - ia2.RegisterSubordinate(&op2) - ia2.RegisterSubordinate(&ia1) - ta1.RegisterSubordinate(&ia1) - ta1.RegisterSubordinate(&ia2) - ta2.RegisterSubordinate(&ia2) - ta2WithRemove.RegisterSubordinate(&ia2) - ta2WithRemoveCrit.RegisterSubordinate(&ia2) + ia1.RegisterSubordinate(rp1) + ia2.RegisterSubordinate(rp1) + ia1.RegisterSubordinate(op1) + ia2.RegisterSubordinate(op1) + ia1.RegisterSubordinate(op3) + ia1.RegisterSubordinate(proxy) + ia2.RegisterSubordinate(op2) + ia2.RegisterSubordinate(ia1) + ta1.RegisterSubordinate(ia1) + ta1.RegisterSubordinate(ia2) + ta2.RegisterSubordinate(ia2) + ta2WithRemove.RegisterSubordinate(ia2) + ta2WithRemoveCrit.RegisterSubordinate(ia2) - mockupData.AddRP(rp1) - mockupData.AddOP(op1) - mockupData.AddOP(op2) - mockupData.AddOP(op3) - mockupData.AddProxy(proxy) - mockupData.AddAuthority(ia1) - mockupData.AddAuthority(ia2) - mockupData.AddAuthority(ta1) - mockupData.AddAuthority(ta2) } var chainRPIA1TA1 = TrustChain{ diff --git a/pkg/trustmark.go b/pkg/trustmark.go index 63b154f..9996d60 100644 --- a/pkg/trustmark.go +++ b/pkg/trustmark.go @@ -10,6 +10,7 @@ import ( "github.com/zachmann/go-oidfed/internal/jwx" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) // TrustMarkInfo is a type for holding a trust mark as represented in an EntityConfiguration @@ -31,6 +32,19 @@ func (tm TrustMarkInfo) MarshalJSON() ([]byte, error) { return extraMarshalHelper(explicitFields, tm.Extra) } +// ParseTrustMark parses a trust mark jwt into a TrustMark +func ParseTrustMark(data []byte) (*TrustMark, error) { + m, err := jwx.Parse(data) + if err != nil { + return nil, err + } + t := &TrustMark{jwtMsg: m} + if err = json.Unmarshal(m.Payload(), t); err != nil { + return nil, err + } + return t, nil +} + // UnmarshalJSON implements the json.Unmarshaler interface. // It also unmarshalls additional fields into the Extra claim. func (tm *TrustMarkInfo) UnmarshalJSON(data []byte) error { @@ -48,14 +62,10 @@ func (tm *TrustMarkInfo) UnmarshalJSON(data []byte) error { // TrustMark returns the TrustMark for this TrustMarkInfo func (tm *TrustMarkInfo) TrustMark() (*TrustMark, error) { if tm.trustmark == nil || tm.trustmark.jwtMsg == nil { - m, err := jwx.Parse([]byte(tm.TrustMarkJWT)) + t, err := ParseTrustMark([]byte(tm.TrustMarkJWT)) if err != nil { return nil, err } - t := &TrustMark{jwtMsg: m} - if err = json.Unmarshal(m.Payload(), t); err != nil { - return nil, err - } tm.trustmark = t } return tm.trustmark, nil @@ -91,9 +101,9 @@ type TrustMark struct { Issuer string `json:"iss"` Subject string `json:"sub"` ID string `json:"id"` - IssuedAt Unixtime `json:"iat"` + IssuedAt unixtime.Unixtime `json:"iat"` LogoURI string `json:"logo_uri,omitempty"` - ExpiresAt *Unixtime `json:"exp,omitempty"` + ExpiresAt *unixtime.Unixtime `json:"exp,omitempty"` Ref string `json:"ref,omitempty"` DelegationJWT string `json:"delegation,omitempty"` Extra map[string]interface{} `json:"-"` @@ -181,7 +191,7 @@ func (tm *TrustMark) VerifyFederation(ta *EntityStatementPayload) error { // VerifyExternal verifies the TrustMark by using the passed trust mark issuer jwks and optionally the passed // trust mark owner jwks func (tm *TrustMark) VerifyExternal(jwks jwk.JWKS, tmo ...TrustMarkOwnerSpec) error { - if err := verifyTime(&tm.IssuedAt, tm.ExpiresAt); err != nil { + if err := unixtime.VerifyTime(&tm.IssuedAt, tm.ExpiresAt); err != nil { return err } if _, err := jwx.VerifyWithSet(tm.jwtMsg, jwks); err != nil { @@ -216,8 +226,8 @@ type DelegationJWT struct { Issuer string `json:"iss"` Subject string `json:"sub"` ID string `json:"id"` - IssuedAt Unixtime `json:"iat"` - ExpiresAt *Unixtime `json:"exp,omitempty"` + IssuedAt unixtime.Unixtime `json:"iat"` + ExpiresAt *unixtime.Unixtime `json:"exp,omitempty"` Ref string `json:"ref,omitempty"` Extra map[string]interface{} `json:"-"` jwtMsg *jwx.ParsedJWT @@ -250,7 +260,7 @@ func (djwt *DelegationJWT) UnmarshalJSON(data []byte) error { // VerifyFederation verifies the DelegationJWT by using the passed trust anchor func (djwt DelegationJWT) VerifyFederation(ta *EntityStatementPayload) error { - if err := verifyTime(&djwt.IssuedAt, djwt.ExpiresAt); err != nil { + if err := unixtime.VerifyTime(&djwt.IssuedAt, djwt.ExpiresAt); err != nil { return errors.Wrap(err, "verify delegation jwt") } owner, ok := ta.TrustMarkOwners[djwt.ID] @@ -263,7 +273,7 @@ func (djwt DelegationJWT) VerifyFederation(ta *EntityStatementPayload) error { // VerifyExternal verifies the DelegationJWT by using the passed trust mark owner jwks func (djwt DelegationJWT) VerifyExternal(jwks jwk.JWKS) error { - if err := verifyTime(&djwt.IssuedAt, djwt.ExpiresAt); err != nil { + if err := unixtime.VerifyTime(&djwt.IssuedAt, djwt.ExpiresAt); err != nil { return errors.Wrap(err, "verify delegation jwt") } _, err := jwx.VerifyWithSet(djwt.jwtMsg, jwks) @@ -279,13 +289,13 @@ type TrustMarkIssuer struct { // TrustMarkSpec describes a TrustMark for a TrustMarkIssuer type TrustMarkSpec struct { - ID string `json:"trust_mark_id" yaml:"trust_mark_id"` - Lifetime DurationInSeconds `json:"lifetime" yaml:"lifetime"` - Ref string `json:"ref" yaml:"ref"` - LogoURI string `json:"logo_uri" yaml:"logo_uri"` - Extra map[string]any `json:"-" yaml:"-"` - IncludeExtraClaimsInInfo bool `json:"include_extra_claims_in_info" yaml:"include_extra_claims_in_info"` - DelegationJWT string `json:"delegation_jwt" yaml:"delegation_jwt"` + ID string `json:"trust_mark_id" yaml:"trust_mark_id"` + Lifetime unixtime.DurationInSeconds `json:"lifetime" yaml:"lifetime"` + Ref string `json:"ref" yaml:"ref"` + LogoURI string `json:"logo_uri" yaml:"logo_uri"` + Extra map[string]any `json:"-" yaml:"-"` + IncludeExtraClaimsInInfo bool `json:"include_extra_claims_in_info" yaml:"include_extra_claims_in_info"` + DelegationJWT string `json:"delegation_jwt" yaml:"delegation_jwt"` } // MarshalJSON implements the json.Marshaler interface @@ -377,7 +387,7 @@ func (tmi TrustMarkIssuer) IssueTrustMark(trustMarkID, sub string, lifetime ...t Issuer: tmi.EntityID, Subject: sub, ID: spec.ID, - IssuedAt: Unixtime{now}, + IssuedAt: unixtime.Unixtime{Time: now}, LogoURI: spec.LogoURI, Ref: spec.Ref, DelegationJWT: spec.DelegationJWT, @@ -388,7 +398,7 @@ func (tmi TrustMarkIssuer) IssueTrustMark(trustMarkID, sub string, lifetime ...t lf = lifetime[0] } if lf != 0 { - tm.ExpiresAt = &Unixtime{now.Add(lf)} + tm.ExpiresAt = &unixtime.Unixtime{Time: now.Add(lf)} } jwt, err := tmi.TrustMarkSigner.JWT(tm) if err != nil { @@ -453,7 +463,7 @@ func (tmo TrustMarkOwner) DelegationJWT(trustMarkID, sub string, lifetime ...tim Issuer: tmo.EntityID, Subject: sub, ID: spec.ID, - IssuedAt: Unixtime{now}, + IssuedAt: unixtime.Unixtime{Time: now}, Ref: spec.Ref, Extra: spec.Extra, } @@ -462,7 +472,7 @@ func (tmo TrustMarkOwner) DelegationJWT(trustMarkID, sub string, lifetime ...tim lf = lifetime[0] } if spec.DelegationLifetime != 0 { - delegation.ExpiresAt = &Unixtime{now.Add(lf)} + delegation.ExpiresAt = &unixtime.Unixtime{Time: now.Add(lf)} } return tmo.TrustMarkDelegationSigner.JWT(delegation) } diff --git a/pkg/trustmark_refresher.go b/pkg/trustmark_refresher.go new file mode 100644 index 0000000..498a875 --- /dev/null +++ b/pkg/trustmark_refresher.go @@ -0,0 +1,114 @@ +package pkg + +import ( + "net/url" + "time" + + "github.com/lestrrat-go/jwx/jwt" + "github.com/pkg/errors" + + "github.com/zachmann/go-oidfed/internal" + "github.com/zachmann/go-oidfed/internal/http" + "github.com/zachmann/go-oidfed/pkg/unixtime" +) + +// EntityConfigurationTrustMarkConfig is a type for specifying the configuration of a TrustMark that should be +// included in an EntityConfiguration +type EntityConfigurationTrustMarkConfig struct { + TrustMarkID string `yaml:"trust_mark_id"` + TrustMarkIssuer string `yaml:"trust_mark_issuer"` + JWT string `yaml:"trust_mark_jwt"` + Refresh bool `yaml:"refresh"` + MinLifetime unixtime.DurationInSeconds `yaml:"min_lifetime"` + RefreshGracePeriod unixtime.DurationInSeconds `yaml:"refresh_grace_period"` + expiration unixtime.Unixtime + sub string +} + +// Verify verifies that the EntityConfigurationTrustMarkConfig is correct and also extracts trust mark id and issuer +// if a trust mark jwt is given as well as sets default values +func (c *EntityConfigurationTrustMarkConfig) Verify(sub string) error { + c.sub = sub + if c.MinLifetime.Duration == 0 { + c.MinLifetime = unixtime.NewDurationInSeconds(10) + } + if c.RefreshGracePeriod.Duration == 0 { + c.RefreshGracePeriod.Duration = time.Hour + } + + if c.JWT != "" { + parsed, err := jwt.Parse([]byte(c.JWT)) + if err != nil { + return err + } + c.expiration = unixtime.Unixtime{Time: parsed.Expiration()} + c.TrustMarkIssuer = parsed.Issuer() + internal.Logf("Extracted trust mark issuer: %s", c.TrustMarkIssuer) + tmi, set := parsed.Get("id") + if !set { + return errors.New("trustmark id not found in JWT") + } + tmiS, ok := tmi.(string) + if !ok { + return errors.New("trustmark id in JWT not a string") + } + c.TrustMarkID = tmiS + internal.Logf("Extracted trust mark id: %s\n", c.TrustMarkID) + return nil + } + c.Refresh = true + if c.TrustMarkID == "" || c.TrustMarkIssuer == "" { + return errors.New("either trust_mark_jwt or trust_mark_issuer and trust_mark_id must be specified") + } + return nil +} + +// TrustMarkJWT returns a trust mark jwt for the linked trust mark, +// if needed the trust mark is refreshed using the trust mark issuer's trust mark endpoint +func (c *EntityConfigurationTrustMarkConfig) TrustMarkJWT() (string, error) { + if !c.Refresh { + return c.JWT, nil + } + if c.JWT != "" && unixtime.Until(c.expiration) > c.MinLifetime.Duration { + if unixtime.Until(c.expiration) < c.RefreshGracePeriod.Duration { + go c.refresh() + } + return c.JWT, nil + } + err := c.refresh() + return c.JWT, err +} + +// refresh refreshes the trust mark at the trust mark issuer's trust mark endpoint +func (c *EntityConfigurationTrustMarkConfig) refresh() error { + tmi, err := GetEntityConfiguration(c.TrustMarkIssuer) + if err != nil { + return err + } + if tmi.Metadata == nil || tmi.Metadata.FederationEntity == nil || tmi.Metadata. + FederationEntity.FederationTrustMarkEndpoint == "" { + return errors.New("could not obtain trust mark endpoint of trust mark issuer") + } + endpoint := tmi.Metadata.FederationEntity.FederationTrustMarkEndpoint + params := url.Values{} + params.Add("trust_mark_id", c.TrustMarkID) + params.Add("sub", c.sub) + res, errRes, err := http.Get(endpoint, params, nil) + if err != nil { + return err + } + if errRes != nil { + return errRes.Err() + } + tm, err := ParseTrustMark(res.Body()) + if err != nil { + return err + } + c.JWT = string(tm.jwtMsg.RawJWT) + if tm.ExpiresAt != nil { + c.expiration = *tm.ExpiresAt + } else { + c.expiration = unixtime.Unixtime{} + } + return nil +} diff --git a/pkg/trustmark_test.go b/pkg/trustmark_test.go index 9ecc68e..501bd25 100644 --- a/pkg/trustmark_test.go +++ b/pkg/trustmark_test.go @@ -8,31 +8,32 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/zachmann/go-oidfed/pkg/jwk" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) var tmi1 = newMockTrustMarkIssuer( "https://tmi.example.org", []TrustMarkSpec{ { ID: "https://trustmarks.org/tm1", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, }, { ID: "https://trustmarks.org/tm2", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, Ref: "https://trustmarks.org/tm2/info", LogoURI: "https://trustmarks.org/tm2/logo", }, { ID: "https://trustmarks.org/tm3", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, }, { ID: "https://trustmarks.org/tm4", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, }, { ID: "https://trustmarks.org/tm-delegated", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, }, }, ) @@ -58,7 +59,7 @@ var tmi2 = newMockTrustMarkIssuer( }, { ID: "https://trustmarks.org/tm-delegated", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, }, }, ) @@ -126,16 +127,13 @@ func init() { tmi1.AddTrustMark( TrustMarkSpec{ ID: "https://trustmarks.org/test", - Lifetime: DurationInSeconds{time.Hour}, + Lifetime: unixtime.DurationInSeconds{Duration: time.Hour}, DelegationJWT: string(delegation), }, ) - taWithTmo.RegisterSubordinate(&tmi1) - taWithTmo.RegisterSubordinate(&tmi2) - mockupData.AddTMI(tmi1) - mockupData.AddTMI(tmi2) - mockupData.AddAuthority(taWithTmo) + taWithTmo.RegisterSubordinate(tmi1) + taWithTmo.RegisterSubordinate(tmi2) } func TestTrustMarkOwner_DelegationJWT(t *testing.T) { diff --git a/pkg/trustresolver.go b/pkg/trustresolver.go index ae7f2d2..119752c 100644 --- a/pkg/trustresolver.go +++ b/pkg/trustresolver.go @@ -2,15 +2,20 @@ package pkg import ( "encoding/json" + "net/url" + "strings" "time" "github.com/vmihailenco/msgpack/v5" "golang.org/x/crypto/sha3" "github.com/zachmann/go-oidfed/internal" + "github.com/zachmann/go-oidfed/internal/constants" + "github.com/zachmann/go-oidfed/internal/http" "github.com/zachmann/go-oidfed/internal/jwx" "github.com/zachmann/go-oidfed/internal/utils" "github.com/zachmann/go-oidfed/pkg/cache" + "github.com/zachmann/go-oidfed/pkg/unixtime" ) const cacheGracePeriod = time.Hour @@ -19,8 +24,8 @@ const cacheGracePeriod = time.Hour type ResolveResponse struct { Issuer string `json:"iss"` Subject string `json:"sub"` - IssuedAt Unixtime `json:"iat"` - ExpiresAt Unixtime `json:"exp"` + IssuedAt unixtime.Unixtime `json:"iat"` + ExpiresAt unixtime.Unixtime `json:"exp"` Audience string `json:"aud,omitempty"` Metadata *Metadata `json:"metadata,omitempty"` TrustMarks []TrustMarkInfo `json:"trust_marks,omitempty"` @@ -183,7 +188,7 @@ func (r TrustResolver) cacheSetTrustChains(chains TrustChains) error { } return cache.Set( cache.Key(cache.KeyTrustTreeChains, string(hash)), chains, - Until(r.trustTree.expiresAt), + unixtime.Until(r.trustTree.expiresAt), ) } @@ -206,7 +211,7 @@ func (r TrustResolver) cacheSetTrustTree() error { } return cache.Set( cache.Key(cache.KeyTrustTree, string(hash)), r.trustTree, - Until(r.trustTree.expiresAt), + unixtime.Until(r.trustTree.expiresAt), ) } @@ -216,7 +221,7 @@ type trustTree struct { Subordinate *EntityStatement Authorities []trustTree signaturesVerified bool - expiresAt Unixtime + expiresAt unixtime.Unixtime } func (t *trustTree) resolve(anchors TrustAnchors) { @@ -322,12 +327,6 @@ func (t trustTree) chains() (chains []TrustChain) { return } -var entityStatementObtainer internal.EntityStatementObtainer - -func init() { - entityStatementObtainer = internal.DefaultHttpEntityStatementObtainer -} - func entityStmtCacheSet(subID, issID string, stmt *EntityStatement) { if err := cache.Set( cache.EntityStmtCacheKey(subID, issID), stmt, time.Until(stmt.ExpiresAt.Time), @@ -352,15 +351,14 @@ func entityStmtCacheGet(subID, issID string) *EntityStatement { // EntityStatement func GetEntityConfiguration(entityID string) (*EntityStatement, error) { return getEntityStatementOrConfiguration( - entityID, entityID, func() ([]byte, error) { - return entityStatementObtainer.GetEntityConfiguration(entityID) + entityID, entityID, func() (*EntityStatement, error) { + return httpGetEntityConfiguration(entityID) }, ) } func getEntityStatementOrConfiguration( - subID, issID string, - obtainerFnc func() ([]byte, error), + subID, issID string, obtainerFnc func() (*EntityStatement, error), ) (*EntityStatement, error) { if stmt := entityStmtCacheGet(subID, issID); stmt != nil { @@ -383,32 +381,52 @@ func getEntityStatementOrConfiguration( } func obtainAndSetEntityStatementOrConfiguration( - subID, issID string, - obtainerFnc func() ([]byte, error), + subID, issID string, obtainerFnc func() (*EntityStatement, error), ) (*EntityStatement, error) { - body, err := obtainerFnc() + stmt, err := obtainerFnc() if err != nil { internal.Log(err) return nil, err } internal.Log("Obtained entity statement from http") - stmt, err := ParseEntityStatement(body) + entityStmtCacheSet(subID, issID, stmt) + return stmt, nil +} + +func httpGetEntityConfiguration( + entityID string, +) (*EntityStatement, error) { + uri := strings.TrimSuffix(entityID, "/") + constants.FederationSuffix + internal.Logf("Obtaining entity configuration from %+q", uri) + res, errRes, err := http.Get(uri, nil, nil) if err != nil { - internal.Log(err) return nil, err } - entityStmtCacheSet(subID, issID, stmt) - return stmt, nil + if errRes != nil { + return nil, errRes.Err() + } + return ParseEntityStatement(res.Body()) } // FetchEntityStatement fetches an EntityStatement from a fetch endpoint func FetchEntityStatement(fetchEndpoint, subID, issID string) (*EntityStatement, error) { return getEntityStatementOrConfiguration( - subID, issID, func() ([]byte, error) { - return entityStatementObtainer.FetchEntityStatement( - fetchEndpoint, - subID, issID, - ) + subID, issID, func() (*EntityStatement, error) { + return httpFetchEntityStatement(fetchEndpoint, subID, issID) }, ) } +func httpFetchEntityStatement(fetchEndpoint, subID, issID string) (*EntityStatement, error) { + uri := fetchEndpoint + params := url.Values{} + params.Add("sub", subID) + params.Add("iss", issID) + res, errRes, err := http.Get(uri, params, nil) + if err != nil { + return nil, err + } + if errRes != nil { + return nil, errRes.Err() + } + return ParseEntityStatement(res.Body()) +} diff --git a/pkg/trustresolver_test.go b/pkg/trustresolver_test.go index 251d51f..706a79d 100644 --- a/pkg/trustresolver_test.go +++ b/pkg/trustresolver_test.go @@ -8,11 +8,14 @@ import ( "strings" "testing" + "github.com/jarcoal/httpmock" + "github.com/zachmann/go-oidfed/internal" + "github.com/zachmann/go-oidfed/internal/http" ) func setup() { - entityStatementObtainer = mockupData + httpmock.ActivateNonDefault(http.Do().GetClient()) internal.EnableDebugLogging() // cache.UseRedisCache(&redis.Options{Addr: "localhost:6379"}) } diff --git a/pkg/unixtime.go b/pkg/unixtime/unixtime.go similarity index 94% rename from pkg/unixtime.go rename to pkg/unixtime/unixtime.go index 58a9bc0..7c87a19 100644 --- a/pkg/unixtime.go +++ b/pkg/unixtime/unixtime.go @@ -1,4 +1,4 @@ -package pkg +package unixtime import ( "encoding/json" @@ -38,7 +38,8 @@ func Until(u Unixtime) time.Duration { return time.Until(u.Time) } -func verifyTime(iat, exp *Unixtime) error { +// VerifyTime verifies the iat and exp times with regard to the current time +func VerifyTime(iat, exp *Unixtime) error { now := time.Now() if iat != nil && !iat.IsZero() && iat.After(now) { return errors.New("not yet valid")