Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance and bug-fix backports #443

Merged
merged 3 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install Go
uses: actions/setup-go@v3
with:
go-version: 1.21
go-version: 1.22
- name: golangci-lint
uses: golangci/golangci-lint-action@v3

Expand All @@ -33,7 +33,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: ["1.20", "1.21"]
go: ["stable", "1.22"]
steps:
- uses: actions/checkout@v3
- name: Setup go
Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go: ["1.21"]
go: ["stable", "1.22"]
steps:
- uses: actions/checkout@v3
with:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ gomatrixserverlib

[![GoDoc](https://godoc.org/github.com/matrix-org/gomatrixserverlib?status.svg)](https://godoc.org/github.com/matrix-org/gomatrixserverlib)

Go library for common functions needed by matrix servers. This library assumes Go 1.18+.
Go library for common functions needed by matrix servers. This library assumes Go 1.22+.
8 changes: 4 additions & 4 deletions authstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func checkAllowedByAuthEvents(
event PDU, eventsByID map[string]PDU,
missingAuth EventProvider, userIDForSender spec.UserIDForSender,
) error {
authEvents := NewAuthEvents(nil)
authEvents, _ := NewAuthEvents(nil)

for _, ae := range event.AuthEventIDs() {
retryEvent:
Expand Down Expand Up @@ -214,7 +214,7 @@ func checkAllowedByAuthEvents(

// If we made it this far then we've successfully got as many of the auth events as
// as described by AuthEventIDs(). Check if they allow the event.
if err := Allowed(event, &authEvents, userIDForSender); err != nil {
if err := Allowed(event, authEvents, userIDForSender); err != nil {
return fmt.Errorf(
"gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s",
event.EventID(), err.Error(),
Expand Down Expand Up @@ -335,7 +335,7 @@ func CheckSendJoinResponse(
}

eventsByID := map[string]PDU{}
authEventProvider := NewAuthEvents(nil)
authEventProvider, _ := NewAuthEvents(nil)

// Since checkAllowedByAuthEvents needs to be able to look up any of the
// auth events by ID only, we will build a map which contains references
Expand Down Expand Up @@ -369,7 +369,7 @@ func CheckSendJoinResponse(
}

// Now check that the join event is valid against the supplied state.
if err := Allowed(joinEvent, &authEventProvider, userIDForSender); err != nil {
if err := Allowed(joinEvent, authEventProvider, userIDForSender); err != nil {
return nil, fmt.Errorf(
"gomatrixserverlib: event with ID %q is not allowed by the current room state: %w",
joinEvent.EventID(), err,
Expand Down
3 changes: 2 additions & 1 deletion backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ func RequestBackfill(ctx context.Context, origin spec.ServerName, b BackfillRequ
}
}

return result, lastErr
// Since we pulled in results from multiple servers we need to sort again...
return ReverseTopologicalOrdering(result, TopologicalOrderByPrevEvents), lastErr
}

/*
Expand Down
3 changes: 1 addition & 2 deletions eventV1.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ func (e *eventV1) SetUnsignedField(path string, value interface{}) error {
eventJSON = CanonicalJSONAssumeValid(eventJSON)

res := gjson.GetBytes(eventJSON, "unsigned")
unsigned := RawJSONFromResult(res, eventJSON)
e.eventFields.Unsigned = unsigned
e.eventFields.Unsigned = []byte(res.Raw)

e.eventJSON = eventJSON

Expand Down
8 changes: 5 additions & 3 deletions eventauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,17 @@ func (a *AuthEvents) Clear() {

// NewAuthEvents returns an AuthEventProvider backed by the given events. New events can be added by
// calling AddEvent().
func NewAuthEvents(events []PDU) AuthEvents {
func NewAuthEvents(events []PDU) (*AuthEvents, error) {
a := AuthEvents{
events: make(map[StateKeyTuple]PDU, len(events)),
roomIDs: make(map[string]struct{}),
}
for _, e := range events {
a.AddEvent(e) // nolint: errcheck
if err := a.AddEvent(e); err != nil {
return nil, err
}
}
return a
return &a, nil
}

// A NotAllowed error is returned if an event does not pass the auth checks.
Expand Down
14 changes: 7 additions & 7 deletions eventauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestStateNeededForMessage(t *testing.T) {
// Message events need the create event, the sender and the power_levels.
testStateNeededForAuth(t, `[{
"type": "m.room.message",
"sender": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a"
}]`, &ProtoEvent{
Type: "m.room.message",
Expand Down Expand Up @@ -139,7 +139,7 @@ func TestStateNeededForJoin(t *testing.T) {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"content": {"membership": "join"},
"content": {"membership": "join"},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
Expand All @@ -163,7 +163,7 @@ func TestStateNeededForInvite(t *testing.T) {
"type": "m.room.member",
"state_key": "@u2:b",
"sender": "@u1:a",
"content": {"membership": "invite"},
"content": {"membership": "invite"},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestStateNeededForInvite3PID(t *testing.T) {
"token": "my_token"
}
}
},
},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
Expand Down Expand Up @@ -1035,7 +1035,7 @@ func TestAuthEvents(t *testing.T) {
if err != nil {
t.Fatalf("TestAuthEvents: failed to create power_levels event: %s", err)
}
a := NewAuthEvents([]PDU{power})
a, _ := NewAuthEvents([]PDU{power})
var e PDU
if e, err = a.PowerLevels(); err != nil || e != power {
t.Errorf("TestAuthEvents: failed to get same power_levels event")
Expand Down Expand Up @@ -1685,15 +1685,15 @@ func TestMembershipBanned(t *testing.T) {
"state_key": "@u3:a",
"event_id": "$e4:a",
"content": {"membership": "ban"}
},
},
{
"type": "m.room.member",
"sender": "@u2:a",
"room_id": "!r1:a",
"state_key": "@u3:a",
"event_id": "$e4:a",
"content": {"membership": "ban"}
},
},
{
"type": "m.room.member",
"sender": "@u2:a",
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.18
go 1.22
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀

8 changes: 5 additions & 3 deletions handleinvite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv
return nil, fmt.Errorf("failed getting auth provider")
}

eventProvider := AuthEvents{}
eventProvider, _ := NewAuthEvents(nil)
if r.createEvent != nil {
eventProvider = NewAuthEvents([]PDU{r.createEvent})
if err := eventProvider.AddEvent(r.createEvent); err != nil {
return nil, err
}
if r.inviterMemberEvent != nil {
err := eventProvider.AddEvent(r.inviterMemberEvent)
if err != nil {
return nil, err
}
}
}
return &eventProvider, nil
return eventProvider, nil
}

func (r *TestStateQuerier) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []StateKeyTuple) ([]PDU, error) {
Expand Down
7 changes: 5 additions & 2 deletions handlejoin.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ func HandleMakeJoin(input HandleMakeJoinInput) (*HandleMakeJoinResponse, error)
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected join event from template builder. got: %s", event.Type())}
}

provider := NewAuthEvents(state)
if err = Allowed(event, &provider, input.UserIDQuerier); err != nil {
provider, err := NewAuthEvents(state)
if err != nil {
return nil, spec.Forbidden(err.Error())
}
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
return nil, spec.Forbidden(err.Error())
}

Expand Down
1 change: 1 addition & 0 deletions handlejoin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ func TestHandleMakeJoinNilContext(t *testing.T) {
})
}

//nolint:unparam
func createMemberEventBuilder(roomVersion RoomVersion, sender string, roomID string, stateKey *string, content spec.RawJSON) *EventBuilder {
return MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&ProtoEvent{
SenderID: sender,
Expand Down
7 changes: 5 additions & 2 deletions handleleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())}
}

provider := NewAuthEvents(stateEvents)
if err := Allowed(event, &provider, input.UserIDQuerier); err != nil {
provider, err := NewAuthEvents(stateEvents)
if err != nil {
return nil, spec.Forbidden(err.Error())
}
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
return nil, spec.Forbidden(err.Error())
}

Expand Down
63 changes: 24 additions & 39 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package gomatrixserverlib
import (
"encoding/binary"
"errors"
"sort"
"slices"
"strings"
"unicode/utf16"
"unicode/utf8"
Expand Down Expand Up @@ -159,40 +159,33 @@ func CanonicalJSONAssumeValid(input []byte) []byte {
// by codepoint. The input must be valid JSON.
func SortJSON(input, output []byte) []byte {
result := gjson.ParseBytes(input)

RawJSON := RawJSONFromResult(result, input)
return sortJSONValue(result, RawJSON, output)
return sortJSONValue(result, output)
}

// sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the
// raw JSON bytes that gjson.Result points to.
func sortJSONValue(input gjson.Result, inputJSON, output []byte) []byte {
func sortJSONValue(input gjson.Result, output []byte) []byte {
if input.IsArray() {
return sortJSONArray(input, inputJSON, output)
return sortJSONArray(input, output)
}

if input.IsObject() {
return sortJSONObject(input, inputJSON, output)
return sortJSONObject(input, output)
}

// If its neither an object nor an array then there is no sub structure
// to sort, so just append the raw bytes.
return append(output, inputJSON...)
return append(output, input.Raw...)
}

// sortJSONArray takes a gjson.Result and sorts it, assuming its an array.
// inputJSON must be the raw JSON bytes that gjson.Result points to.
func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {
func sortJSONArray(input gjson.Result, output []byte) []byte {
sep := byte('[')

// Iterate over each value in the array and sort it.
input.ForEach(func(_, value gjson.Result) bool {
output = append(output, sep)
sep = ','

RawJSON := RawJSONFromResult(value, inputJSON)
output = sortJSONValue(value, RawJSON, output)

output = sortJSONValue(value, output)
return true // keep iterating
})

Expand All @@ -209,29 +202,30 @@ func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {

// sortJSONObject takes a gjson.Result and sorts it, assuming its an object.
// inputJSON must be the raw JSON bytes that gjson.Result points to.
func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
func sortJSONObject(input gjson.Result, output []byte) []byte {
type entry struct {
key string // The parsed key string
rawKey []byte // The raw, unparsed key JSON string
value gjson.Result
key string // The parsed key string
value gjson.Result
}

var entries []entry
// Try to stay on the stack here if we can.
var _entries [128]entry
entries := _entries[:0]

// Iterate over each key/value pair and add it to a slice
// that we can sort
input.ForEach(func(key, value gjson.Result) bool {
entries = append(entries, entry{
key: key.String(),
rawKey: RawJSONFromResult(key, inputJSON),
value: value,
key: key.String(),
value: value,
})
return true // keep iterating
})

// Sort the slice based on the *parsed* key
sort.Slice(entries, func(a, b int) bool {
return entries[a].key < entries[b].key
// Using slices.SortFunc here instead of sort.Slice avoids
// heap escapes due to reflection.
slices.SortFunc(entries, func(a, b entry) int {
return strings.Compare(a.key, b.key)
})

sep := byte('{')
Expand All @@ -241,12 +235,10 @@ func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
sep = ','

// Append the raw unparsed JSON key, *not* the parsed key
output = append(output, entry.rawKey...)
output = append(output, ':')

RawJSON := RawJSONFromResult(entry.value, inputJSON)

output = sortJSONValue(entry.value, RawJSON, output)
output = append(output, '"')
output = append(output, entry.key...)
output = append(output, '"', ':')
output = sortJSONValue(entry.value, output)
}
if sep == '{' {
// If sep is still '{' then the object was empty and we never wrote the
Expand Down Expand Up @@ -375,10 +367,3 @@ func readHexDigits(input []byte) rune {
hex |= hex >> 8
return rune(hex & 0xFFFF)
}

// RawJSONFromResult extracts the raw JSON bytes pointed to by result.
// input must be the json bytes that were used to generate result
// TODO: Why do we do this?
func RawJSONFromResult(result gjson.Result, _ []byte) []byte {
return []byte(result.Raw)
}
4 changes: 2 additions & 2 deletions performinvite.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede

input.EventTemplate.Depth = latestEvents.Depth

authEvents := NewAuthEvents(nil)
authEvents, _ := NewAuthEvents(nil)

for _, event := range latestEvents.StateEvents {
err := authEvents.AddEvent(event)
Expand All @@ -132,7 +132,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede
}
}

refs, err := stateNeeded.AuthEventReferences(&authEvents)
refs, err := stateNeeded.AuthEventReferences(authEvents)
if err != nil {
return nil, fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions performjoin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
ed255192 "golang.org/x/crypto/ed25519"
)

type TestMakeJoinResponse struct {
Expand Down Expand Up @@ -388,7 +387,7 @@ func TestPerformJoinPseudoID(t *testing.T) {
return res, nil
}

idCreator := func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed255192.PrivateKey, error) {
idCreator := func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
return spec.SenderIDFromPseudoIDKey(userPriv), userPriv, nil
}

Expand Down
Loading
Loading