Skip to content
This repository has been archived by the owner on Jan 31, 2024. It is now read-only.

Commit

Permalink
add callbacks to store and restore app data along a session state
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Mar 18, 2023
1 parent 73f8bcb commit 97fbf25
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 26 deletions.
27 changes: 19 additions & 8 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,19 +793,30 @@ type ExtraConfig struct {
//
// It has no meaning to the server.
Enable0RTT bool

// Is called when the client saves a session ticket to the session ticket.
// This gives the application the opportunity to save some data along with the ticket,
// which can be restored when the session ticket is used.
GetAppDataForSessionState func() []byte

// Is called when the client uses a session ticket.
// Restores the application data that was saved earlier on GetAppDataForSessionTicket.
SetAppDataFromSessionState func([]byte)
}

// Clone clones.
func (c *ExtraConfig) Clone() *ExtraConfig {
return &ExtraConfig{
GetExtensions: c.GetExtensions,
ReceivedExtensions: c.ReceivedExtensions,
AlternativeRecordLayer: c.AlternativeRecordLayer,
EnforceNextProtoSelection: c.EnforceNextProtoSelection,
MaxEarlyData: c.MaxEarlyData,
Enable0RTT: c.Enable0RTT,
Accept0RTT: c.Accept0RTT,
Rejected0RTT: c.Rejected0RTT,
GetExtensions: c.GetExtensions,
ReceivedExtensions: c.ReceivedExtensions,
AlternativeRecordLayer: c.AlternativeRecordLayer,
EnforceNextProtoSelection: c.EnforceNextProtoSelection,
MaxEarlyData: c.MaxEarlyData,
Enable0RTT: c.Enable0RTT,
Accept0RTT: c.Accept0RTT,
Rejected0RTT: c.Rejected0RTT,
GetAppDataForSessionState: c.GetAppDataForSessionState,
SetAppDataFromSessionState: c.SetAppDataFromSessionState,
}
}

Expand Down
54 changes: 45 additions & 9 deletions handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@ import (
"crypto/rsa"
"crypto/subtle"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
"net"
"strings"
"time"

"golang.org/x/crypto/cryptobyte"
)

const clientSessionStateVersion = 1

type clientHandshakeState struct {
c *Conn
ctx context.Context
Expand Down Expand Up @@ -290,6 +293,33 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
return nil
}

// extract the app data saved in the session.nonce,
// and set the session.nonce to the actual nonce value
func (c *Conn) decodeSessionState(session *clientSessionState) (uint32 /* max early data */, []byte /* app data */, bool /* ok */) {
s := cryptobyte.String(session.nonce)
var version uint16
if !s.ReadUint16(&version) {
return 0, nil, false
}
if version != clientSessionStateVersion {
return 0, nil, false
}
var maxEarlyData uint32
if !s.ReadUint32(&maxEarlyData) {
return 0, nil, false
}
var appData []byte
if !readUint16LengthPrefixed(&s, &appData) {
return 0, nil, false
}
var nonce []byte
if !readUint16LengthPrefixed(&s, &nonce) {
return 0, nil, false
}
session.nonce = nonce
return maxEarlyData, appData, true
}

func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
session *clientSessionState, earlySecret, binderKey []byte, err error) {
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
Expand Down Expand Up @@ -319,6 +349,17 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
}
session = fromClientSessionState(sess)

var appData []byte
var maxEarlyData uint32
if session.vers == VersionTLS13 {
var ok bool
maxEarlyData, appData, ok = c.decodeSessionState(session)
if !ok { // delete it, if parsing failed
c.config.ClientSessionCache.Put(cacheKey, nil)
return cacheKey, nil, nil, nil, nil
}
}

// Check that version used for the previous session is still valid.
versOk := false
for _, v := range hello.supportedVersions {
Expand Down Expand Up @@ -361,14 +402,6 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
return
}

// In TLS 1.3, we abuse the nonce field to save the max_early_data_size.
// See Conn.handleNewSessionTicket for an explanation of this hack.
if len(session.nonce) < 4 {
return cacheKey, nil, nil, nil, nil
}
maxEarlyData := binary.BigEndian.Uint32(session.nonce[:4])
session.nonce = session.nonce[4:]

// Check that the session ticket is not expired.
if c.config.time().After(session.useBy) {
c.config.ClientSessionCache.Put(cacheKey, nil)
Expand Down Expand Up @@ -421,6 +454,9 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
return "", nil, nil, nil, err
}

if session.vers == VersionTLS13 && c.extraConfig != nil && c.extraConfig.SetAppDataFromSessionState != nil {
c.extraConfig.SetAppDataFromSessionState(appData)
}
return
}

Expand Down
34 changes: 28 additions & 6 deletions handshake_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,12 +874,14 @@ func TestClientKeyUpdate(t *testing.T) {
}

func TestResumption(t *testing.T) {
t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
t.Run("TLSv13 with 0-RTT", testResumption0RTT)
t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12, false) })
t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13, false) })
t.Run("TLSv13, saving app data", func(t *testing.T) { testResumption(t, VersionTLS13, true) })
t.Run("TLSv13, with 0-RTT", func(t *testing.T) { testResumption0RTT(t, false) })
t.Run("TLSv13, with 0-RTT, saving app data", func(t *testing.T) { testResumption0RTT(t, true) })
}

func testResumption(t *testing.T, version uint16) {
func testResumption(t *testing.T, version uint16, saveAppData bool) {
if testing.Short() {
t.Skip("skipping in -short mode")
}
Expand All @@ -897,16 +899,22 @@ func testResumption(t *testing.T, version uint16) {
rootCAs := x509.NewCertPool()
rootCAs.AddCert(issuer)

var restoredAppData []byte
clientConfig := &Config{
MaxVersion: version,
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
ClientSessionCache: NewLRUClientSessionCache(32),
RootCAs: rootCAs,
ServerName: "example.golang",
}
clientExtraConfig := &ExtraConfig{}
if saveAppData {
clientExtraConfig.GetAppDataForSessionState = func() []byte { return []byte("foobar") }
clientExtraConfig.SetAppDataFromSessionState = func(data []byte) { restoredAppData = data }
}

testResumeState := func(test string, didResume bool) {
_, hs, err := testHandshake(t, clientConfig, serverConfig)
_, hs, err := testHandshakeWithExtraConfig(t, clientConfig, clientExtraConfig, serverConfig, nil)
if err != nil {
t.Fatalf("%s: handshake failed: %s", test, err)
}
Expand All @@ -919,6 +927,12 @@ func testResumption(t *testing.T, version uint16) {
if got, want := hs.ServerName, clientConfig.ServerName; got != want {
t.Errorf("%s: server name %s, want %s", test, got, want)
}
if didResume && saveAppData {
if !bytes.Equal(restoredAppData, []byte("foobar")) {
t.Fatalf("Expected to restore app data saved with the session state. Got: %#v", restoredAppData)
}
restoredAppData = nil
}
}

getTicket := func() []byte {
Expand Down Expand Up @@ -1057,7 +1071,7 @@ func testResumption(t *testing.T, version uint16) {
testResumeState("WithoutSessionCache", false)
}

func testResumption0RTT(t *testing.T) {
func testResumption0RTT(t *testing.T, saveAppData bool) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

Expand All @@ -1071,6 +1085,11 @@ func testResumption0RTT(t *testing.T) {
clientConfig := testConfig.Clone()
clientConfig.ClientSessionCache = cache
clientExtraConfig := &ExtraConfig{Enable0RTT: true}
var restoredAppData []byte
if saveAppData {
clientExtraConfig.GetAppDataForSessionState = func() []byte { return []byte("foobar") }
clientExtraConfig.SetAppDataFromSessionState = func(data []byte) { restoredAppData = data }
}

// check that the ticket is deleted when 0-RTT is used
var state *ClientSessionState
Expand Down Expand Up @@ -1122,6 +1141,9 @@ func testResumption0RTT(t *testing.T) {
if hs.Used0RTT {
t.Fatal("should not have used 0-RTT during the second handshake")
}
if saveAppData && !bytes.Equal(restoredAppData, []byte("foobar")) {
t.Fatalf("expected app data to be restored. Got: %#v", restoredAppData)
}
}

func TestLRUClientSessionCache(t *testing.T) {
Expand Down
18 changes: 17 additions & 1 deletion handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"errors"
"hash"
"time"

"golang.org/x/crypto/cryptobyte"
)

type clientHandshakeStateTLS13 struct {
Expand Down Expand Up @@ -716,6 +718,20 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
binary.BigEndian.PutUint32(nonceWithEarlyData, msg.maxEarlyData)
copy(nonceWithEarlyData[4:], msg.nonce)

var appData []byte
if c.extraConfig != nil && c.extraConfig.GetAppDataForSessionState != nil {
appData = c.extraConfig.GetAppDataForSessionState()
}
var b cryptobyte.Builder
b.AddUint16(clientSessionStateVersion) // revision
b.AddUint32(msg.maxEarlyData)
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(appData)
})
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(msg.nonce)
})

// Save the resumption_master_secret and nonce instead of deriving the PSK
// to do the least amount of work on NewSessionTicket messages before we
// know if the ticket will be used. Forward secrecy of resumed connections
Expand All @@ -728,7 +744,7 @@ func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
serverCertificates: c.peerCertificates,
verifiedChains: c.verifiedChains,
receivedAt: c.config.time(),
nonce: nonceWithEarlyData,
nonce: b.BytesOrPanic(),
useBy: c.config.time().Add(lifetime),
ageAdd: msg.ageAdd,
ocspResponse: c.ocspResponse,
Expand Down
13 changes: 11 additions & 2 deletions tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ func TestCloneNilConfig(t *testing.T) {
}

func TestExtraConfigCloneFuncField(t *testing.T) {
const expectedCount = 4
const expectedCount = 6
called := 0

c1 := ExtraConfig{
Expand All @@ -873,13 +873,22 @@ func TestExtraConfigCloneFuncField(t *testing.T) {
Rejected0RTT: func() {
called |= 1 << 3
},
GetAppDataForSessionState: func() []byte {
called |= 1 << 4
return nil
},
SetAppDataFromSessionState: func([]byte) {
called |= 1 << 5
},
}

c2 := c1.Clone()
c2.GetExtensions(0)
c2.ReceivedExtensions(0, nil)
c2.Accept0RTT(nil)
c2.Rejected0RTT()
c2.GetAppDataForSessionState()
c2.SetAppDataFromSessionState(nil)
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
}
Expand All @@ -895,7 +904,7 @@ func TestExtraConfigCloneNonFuncFields(t *testing.T) {
// testing/quick can't handle functions or interfaces and so
// isn't used here.
switch fn := typ.Field(i).Name; fn {
case "GetExtensions", "ReceivedExtensions", "Accept0RTT", "Rejected0RTT":
case "GetExtensions", "ReceivedExtensions", "Accept0RTT", "Rejected0RTT", "GetAppDataForSessionState", "SetAppDataFromSessionState":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
Expand Down

0 comments on commit 97fbf25

Please sign in to comment.