diff --git a/tests/lib/fakestore/store/store.go b/tests/lib/fakestore/store/store.go index 81c8cfe4536..0a5daf4587f 100644 --- a/tests/lib/fakestore/store/store.go +++ b/tests/lib/fakestore/store/store.go @@ -27,6 +27,7 @@ import ( "io/ioutil" "net" "net/http" + "net/url" "path/filepath" "regexp" "strconv" @@ -739,35 +740,88 @@ func (s *Store) retrieveAssertion(bs asserts.Backstore, assertType *asserts.Asse return a, err } -func (s *Store) assertionsEndpoint(w http.ResponseWriter, req *http.Request) { - assertPath := strings.TrimPrefix(req.URL.Path, "/v2/assertions/") +func (s *Store) retrieveLatestSequenceFormingAssertion(bs asserts.Backstore, assertType *asserts.AssertionType, sequenceKey []string) (asserts.Assertion, error) { + a, err := bs.SequenceMemberAfter(assertType, sequenceKey, -1, assertType.MaxSupportedFormat()) + if errors.Is(err, &asserts.NotFoundError{}) && s.assertFallback { + return s.fallback.SeqFormingAssertion(assertType, sequenceKey, -1, nil) + } + return a, err +} - bs, err := s.collectAssertions() - if err != nil { - http.Error(w, fmt.Sprintf("internal error collecting assertions: %v", err), 500) - return +func (s *Store) sequenceFromQueryValues(values url.Values) (int, error) { + if val, ok := values["sequence"]; ok { + // special case value of 'latest', in that case + // we return -1 to indicate we want the newest + if val[0] != "latest" { + seq, err := strconv.Atoi(val[0]) + if err != nil { + return -1, fmt.Errorf("cannot parse sequence %s: %v", val[0], err) + } + + // Only positive integers and 'latest' are valid + if seq <= 0 { + return -1, fmt.Errorf("the requested sequence must be above 0") + } + return seq, nil + } } + return -1, nil +} +func (s *Store) assertTypeAndKey(urlPath string) (*asserts.AssertionType, []string, error) { + // trim the assertions prefix, and handle any query parameters + assertPath := strings.TrimPrefix(urlPath, "/v2/assertions/") comps := strings.Split(assertPath, "/") - if len(comps) == 0 { - http.Error(w, "missing assertion type", 400) - return + return nil, nil, fmt.Errorf("missing assertion type") } typ := asserts.Type(comps[0]) if typ == nil { - http.Error(w, fmt.Sprintf("unknown assertion type: %s", comps[0]), 400) + return nil, nil, fmt.Errorf("unknown assertion type: %s", comps[0]) + } + return typ, comps[1:], nil +} + +func (s *Store) retrieveAssertionWrapper(bs asserts.Backstore, assertType *asserts.AssertionType, keyParts []string, values url.Values) (asserts.Assertion, error) { + pk := keyParts + if assertType.SequenceForming() { + seq, err := s.sequenceFromQueryValues(values) + if err != nil { + return nil, err + } + + // If no sequence value was provided, or when requesting the latest sequence + // point of an assertion, we use a different method of resolving the assertion. + if seq <= 0 { + return s.retrieveLatestSequenceFormingAssertion(bs, assertType, keyParts) + } + + // Otherwise append the sequence to form the primary key and use + // the default retrieval. + pk = append(pk, strconv.Itoa(seq)) + } + + if !assertType.AcceptablePrimaryKey(pk) { + return nil, fmt.Errorf("wrong primary key length: %v", pk) + } + return s.retrieveAssertion(bs, assertType, pk) +} + +func (s *Store) assertionsEndpoint(w http.ResponseWriter, req *http.Request) { + typ, pk, err := s.assertTypeAndKey(req.URL.Path) + if err != nil { + http.Error(w, err.Error(), 400) return } - pk := comps[1:] - if !typ.AcceptablePrimaryKey(pk) { - http.Error(w, fmt.Sprintf("wrong primary key length: %v", comps), 400) + bs, err := s.collectAssertions() + if err != nil { + http.Error(w, fmt.Sprintf("internal error collecting assertions: %v", err), 500) return } - a, err := s.retrieveAssertion(bs, typ, pk) + as, err := s.retrieveAssertionWrapper(bs, typ, pk, req.URL.Query()) if errors.Is(err, &asserts.NotFoundError{}) { w.Header().Set("Content-Type", "application/problem+json") w.WriteHeader(404) @@ -775,13 +829,13 @@ func (s *Store) assertionsEndpoint(w http.ResponseWriter, req *http.Request) { return } if err != nil { - http.Error(w, fmt.Sprintf("cannot retrieve assertion %v: %v", comps, err), 400) + http.Error(w, fmt.Sprintf("cannot retrieve assertion %v: %v", pk, err), 400) return } w.Header().Set("Content-Type", asserts.MediaType) w.WriteHeader(200) - w.Write(asserts.Encode(a)) + w.Write(asserts.Encode(as)) } func addSnapIDs(bs asserts.Backstore, initial map[string]string) (map[string]string, error) { diff --git a/tests/lib/fakestore/store/store_test.go b/tests/lib/fakestore/store/store_test.go index 63c42f36975..e14be5859bb 100644 --- a/tests/lib/fakestore/store/store_test.go +++ b/tests/lib/fakestore/store/store_test.go @@ -443,6 +443,23 @@ timestamp: 2016-08-19T19:19:19Z sign-key-sha3-384: Jv8_JiHiIzJVcO9M55pPdqSDWUvuhfDIBJUS-3VW7F_idjix7Ffn5qMxB21ZQuij AXNpZw=` + exampleValidationSet = `type: validation-set +authority-id: canonical +account-id: canonical +name: base-set +sequence: 2 +revision: 1 +series: 16 +snaps: + - + id: yOqKhntON3vR7kwEbVPsILm7bUViPDzz + name: snap-b + presence: required + revision: 1 +timestamp: 2020-11-06T09:16:26Z +sign-key-sha3-384: 7bbncP0c4RcufwReeiylCe0S7IMCn-tHLNSCgeOVmV3K-7_MzpAHgJDYeOjldefE + +AXNpZw==` ) func (s *storeTestSuite) TestAssertionsEndpointPreloaded(c *C) { @@ -478,6 +495,59 @@ func (s *storeTestSuite) TestAssertionsEndpointFromAssertsDir(c *C) { c.Check(string(body), Equals, exampleSnapRev) } +func (s *storeTestSuite) TestAssertionsEndpointSequenceAssertion(c *C) { + err := ioutil.WriteFile(filepath.Join(s.store.assertDir, "base-set.validation-set"), []byte(exampleValidationSet), 0655) + c.Assert(err, IsNil) + + resp, err := s.StoreGet(`/v2/assertions/validation-set/16/canonical/base-set?sequence=2`) + c.Assert(err, IsNil) + defer resp.Body.Close() + + c.Check(resp.StatusCode, Equals, 200) + body, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) + c.Check(string(body), Equals, exampleValidationSet) +} + +func (s *storeTestSuite) TestAssertionsEndpointSequenceAssertionLatest(c *C) { + err := ioutil.WriteFile(filepath.Join(s.store.assertDir, "base-set.validation-set"), []byte(exampleValidationSet), 0655) + c.Assert(err, IsNil) + + resp, err := s.StoreGet(`/v2/assertions/validation-set/16/canonical/base-set?sequence=latest`) + c.Assert(err, IsNil) + defer resp.Body.Close() + + c.Check(resp.StatusCode, Equals, 200) + body, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) + c.Check(string(body), Equals, exampleValidationSet) +} + +func (s *storeTestSuite) TestAssertionsEndpointSequenceAssertionInvalidSequence(c *C) { + err := ioutil.WriteFile(filepath.Join(s.store.assertDir, "base-set.validation-set"), []byte(exampleValidationSet), 0655) + c.Assert(err, IsNil) + + resp, err := s.StoreGet(`/v2/assertions/validation-set/16/canonical/base-set?sequence=0`) + c.Assert(err, IsNil) + defer resp.Body.Close() + + c.Assert(resp.StatusCode, Equals, 400) + body, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) + c.Check(string(body), Equals, "cannot retrieve assertion [16 canonical base-set]: the requested sequence must be above 0\n") +} + +func (s *storeTestSuite) TestAssertionsEndpointSequenceInvalid(c *C) { + resp, err := s.StoreGet(`/v2/assertions/validation-set/16/canonical/base-set?sequence=foo`) + c.Assert(err, IsNil) + defer resp.Body.Close() + + c.Assert(resp.StatusCode, Equals, 400) + body, err := ioutil.ReadAll(resp.Body) + c.Assert(err, IsNil) + c.Check(string(body), Equals, "cannot retrieve assertion [16 canonical base-set]: cannot parse sequence foo: strconv.Atoi: parsing \"foo\": invalid syntax\n") +} + func (s *storeTestSuite) TestAssertionsEndpointNotFound(c *C) { // something not found resp, err := s.StoreGet(`/v2/assertions/account/not-an-account-id`) diff --git a/tests/lib/gendeveloper1/main.go b/tests/lib/gendeveloper1/main.go index 7ed927a96ba..24f29c43f99 100644 --- a/tests/lib/gendeveloper1/main.go +++ b/tests/lib/gendeveloper1/main.go @@ -71,10 +71,10 @@ func main() { log.Fatalf("failed to decode model headers data: %v", err) } - assertName, _ := headers["type"] + headerType := headers["type"] assertType := asserts.ModelType - if assertName == "system-user" { - assertType = asserts.SystemUserType + if assertTypeStr, ok := headerType.(string); ok { + assertType = asserts.Type(assertTypeStr) } clModel, err := devSigning.Sign(assertType, headers, nil, "")