diff --git a/src/matrix/SessionContainer.js b/src/matrix/SessionContainer.js index ebb9e07e45..5aea794680 100644 --- a/src/matrix/SessionContainer.js +++ b/src/matrix/SessionContainer.js @@ -21,6 +21,7 @@ import {Reconnector, ConnectionStatus} from "./net/Reconnector.js"; import {ExponentialRetryDelay} from "./net/ExponentialRetryDelay.js"; import {MediaRepository} from "./net/MediaRepository.js"; import {RequestScheduler} from "./net/RequestScheduler.js"; +import {TokenRefresher} from "./net/TokenRefresher.js"; import {HomeServerError, ConnectionError, AbortError} from "./error.js"; import {Sync, SyncStatus} from "./Sync.js"; import {Session} from "./Session.js"; @@ -105,7 +106,13 @@ export class SessionContainer { accessToken: loginData.access_token, lastUsed: clock.now() }; - await this._platform.sessionInfoStorage.add(sessionInfo); + + if (loginData.refresh_token) { + sessionInfo.accessTokenExpiresAt = clock.now() + loginData.expires_in * 1000; + sessionInfo.refreshToken = loginData.refresh_token; + } + + await this._platform.sessionInfoStorage.add(sessionInfo); } catch (err) { this._error = err; if (err instanceof HomeServerError) { @@ -143,13 +150,31 @@ export class SessionContainer { retryDelay: new ExponentialRetryDelay(clock.createTimeout), createMeasure: clock.createMeasure }); + + let accessToken; + if (sessionInfo.refreshToken) { + this._tokenRefresher = new TokenRefresher({ + accessToken: sessionInfo.accessToken, + accessTokenExpiresAt: sessionInfo.accessTokenExpiresAt, + refreshToken: sessionInfo.refreshToken, + anticipation: 10 * 1000, // Refresh 10 seconds before the expiration + clock, + }); + accessToken = this._tokenRefresher.accessToken; + } else { + accessToken = new ObservableValue(sessionInfo.accessToken); + } + const hsApi = new HomeServerApi({ homeServer: sessionInfo.homeServer, - accessToken: sessionInfo.accessToken, + accessToken, request: this._platform.request, reconnector: this._reconnector, createTimeout: clock.createTimeout }); + if (this._tokenRefresher) { + await this._tokenRefresher.start(hsApi); + } this._sessionId = sessionInfo.id; this._storage = await this._platform.storageFactory.create(sessionInfo.id); // no need to pass access token to session diff --git a/src/matrix/net/HomeServerApi.js b/src/matrix/net/HomeServerApi.js index fa560db62b..3d036aefef 100644 --- a/src/matrix/net/HomeServerApi.js +++ b/src/matrix/net/HomeServerApi.js @@ -101,6 +101,10 @@ export class HomeServerApi { return `${this._homeserver}/_matrix/client/r0${csPath}`; } + _unstableUrl(feature, csPath) { + return `${this._homeserver}/_matrix/client/unstable/${feature}${csPath}`; + } + _baseRequest(method, url, queryParams, body, options, accessToken) { const queryString = encodeQueryParams(queryParams); url = `${url}?${queryString}`; @@ -157,7 +161,7 @@ export class HomeServerApi { } _authedRequest(method, url, queryParams, body, options) { - return this._baseRequest(method, url, queryParams, body, options, this._accessToken); + return this._baseRequest(method, url, queryParams, body, options, this._accessToken.get()); } _post(csPath, queryParams, body, options) { @@ -196,7 +200,9 @@ export class HomeServerApi { } passwordLogin(username, password, initialDeviceDisplayName, options = null) { - return this._unauthedRequest("POST", this._url("/login"), null, { + return this._unauthedRequest("POST", this._url("/login"), { + "org.matrix.msc2918.refresh_token": "true" + }, { "type": "m.login.password", "identifier": { "type": "m.id.user", @@ -207,6 +213,12 @@ export class HomeServerApi { }, options); } + refreshToken(token, options = null) { + return this._unauthedRequest("POST", this._unstableUrl("org.matrix.msc2918.refresh_token", "/refresh"), null, { + "refresh_token": token + }, options); + } + createFilter(userId, filter, options = null) { return this._post(`/user/${encodeURIComponent(userId)}/filter`, null, filter, options); } diff --git a/src/matrix/net/TokenRefresher.js b/src/matrix/net/TokenRefresher.js new file mode 100644 index 0000000000..e265065641 --- /dev/null +++ b/src/matrix/net/TokenRefresher.js @@ -0,0 +1,70 @@ +import { ObservableValue } from "../../observable/ObservableValue.js"; + +export class TokenRefresher { + constructor({ + refreshToken, + accessToken, + accessTokenExpiresAt, + anticipation, + clock, + }) { + this._refreshToken = new ObservableValue(refreshToken); + this._accessToken = new ObservableValue(accessToken); + this._accessTokenExpiresAt = new ObservableValue(accessTokenExpiresAt); + this._anticipation = anticipation; + this._clock = clock; + } + + async start(hsApi) { + this._hsApi = hsApi; + if (this.needsRenewing) { + await this.renew(); + } + + this._renewingLoop(); + } + + get needsRenewing() { + const remaining = this._accessTokenExpiresAt.get() - this._clock.now(); + const anticipated = remaining - this._anticipation; + return anticipated < 0; + } + + async _renewingLoop() { + while (true) { + const remaining = + this._accessTokenExpiresAt.get() - this._clock.now(); + const anticipated = remaining - this._anticipation; + + if (anticipated > 0) { + this._timeout = this._clock.createTimeout(anticipated); + await this._timeout.elapsed(); + } + + await this.renew(); + } + } + + async renew() { + const response = await this._hsApi + .refreshToken(this._refreshToken.get()) + .response(); + + if (response["refresh_token"]) { + this._refreshToken.set(response["refresh_token"]); + } + + this._accessToken.set(response["access_token"]); + this._accessTokenExpiresAt.set( + this._clock.now() + response["expires_in"] * 1000 + ); + } + + get accessToken() { + return this._accessToken; + } + + get refreshToken() { + return this._refreshToken; + } +}