diff --git a/packages/core/src/network/mtproto-session.ts b/packages/core/src/network/mtproto-session.ts index ea4777d0..eae8a137 100644 --- a/packages/core/src/network/mtproto-session.ts +++ b/packages/core/src/network/mtproto-session.ts @@ -17,6 +17,7 @@ import { SortedArray, } from '../utils/index.js' import { AuthKey } from './auth-key.js' +import { ServerSaltManager } from './server-salt.js' export interface PendingRpc { method: string @@ -98,8 +99,6 @@ export class MtprotoSession { _lastMessageId = Long.ZERO _seqNo = 0 - serverSalt = Long.ZERO - /// state /// // recent msg ids recentOutgoingMsgIds = new LruSet(1000, true) @@ -137,6 +136,7 @@ export class MtprotoSession { readonly log: Logger, readonly _readerMap: TlReaderMap, readonly _writerMap: TlWriterMap, + readonly _salts: ServerSaltManager, ) { this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] ` } @@ -254,7 +254,7 @@ export class MtprotoSession { encryptMessage(message: Uint8Array): Uint8Array { const key = this._authKeyTemp.ready ? this._authKeyTemp : this._authKey - return key.encryptMessage(message, this.serverSalt, this._sessionId) + return key.encryptMessage(message, this._salts.currentSalt, this._sessionId) } /** Decrypt a single MTProto message using session's keys */ diff --git a/packages/core/src/network/multi-session-connection.ts b/packages/core/src/network/multi-session-connection.ts index 6df93cf8..f0e4a6f9 100644 --- a/packages/core/src/network/multi-session-connection.ts +++ b/packages/core/src/network/multi-session-connection.ts @@ -88,6 +88,7 @@ export class MultiSessionConnection extends EventEmitter { this._log.create('session'), this.params.readerMap, this.params.writerMap, + this.params.salts, ) // brvh diff --git a/packages/core/src/network/network-manager.ts b/packages/core/src/network/network-manager.ts index a3aaaa65..07b53292 100644 --- a/packages/core/src/network/network-manager.ts +++ b/packages/core/src/network/network-manager.ts @@ -9,6 +9,7 @@ import { ConfigManager } from './config-manager.js' import { MultiSessionConnection } from './multi-session-connection.js' import { PersistentConnectionParams } from './persistent-connection.js' import { defaultReconnectionStrategy, ReconnectionStrategy } from './reconnection.js' +import { ServerSaltManager } from './server-salt.js' import { SessionConnection, SessionConnectionParams } from './session-connection.js' import { defaultTransportFactory, TransportFactory } from './transports/index.js' @@ -170,6 +171,7 @@ export interface RpcCallOptions { * Wrapper over all connection pools for a single DC. */ export class DcConnectionManager { + private _salts = new ServerSaltManager() private __baseConnectionParams = (): SessionConnectionParams => ({ crypto: this.manager.params.crypto, initConnection: this.manager._initConnectionParams, @@ -186,6 +188,7 @@ export class DcConnectionManager { isMainDcConnection: this.isPrimary, inactivityTimeout: this.manager.params.inactivityTimeout ?? 60_000, enableErrorReporting: this.manager.params.enableErrorReporting, + salts: this._salts, }) private _log = this.manager._log.create('dc-manager') @@ -379,6 +382,14 @@ export class DcConnectionManager { return true } + + destroy() { + this.main.destroy() + this.upload.destroy() + this.download.destroy() + this.downloadSmall.destroy() + this._salts.destroy() + } } /** @@ -812,10 +823,7 @@ export class NetworkManager { destroy(): void { for (const dc of this._dcConnections.values()) { - dc.main.destroy() - dc.upload.destroy() - dc.download.destroy() - dc.downloadSmall.destroy() + dc.destroy() } if (this._keepAliveInterval) clearInterval(this._keepAliveInterval) this.config.offConfigUpdate(this._onConfigChanged) diff --git a/packages/core/src/network/server-salt.ts b/packages/core/src/network/server-salt.ts new file mode 100644 index 00000000..9d8e1788 --- /dev/null +++ b/packages/core/src/network/server-salt.ts @@ -0,0 +1,48 @@ +import EventEmitter from 'events' +import Long from 'long' + +import { mtp } from '@mtcute/tl' + +export class ServerSaltManager extends EventEmitter { + private _futureSalts: mtp.RawMt_future_salt[] = [] + + currentSalt = Long.ZERO + + isFetching = false + + shouldFetchSalts(): boolean { + return !this.isFetching && !this.currentSalt.isZero() && this._futureSalts.length < 2 + } + + setFutureSalts(salts: mtp.RawMt_future_salt[]): void { + this._futureSalts = salts + + if (Date.now() > salts[0].validSince * 1000) { + this.currentSalt = salts[0].salt + this._futureSalts.shift() + } + + this._scheduleNext() + } + + private _timer?: NodeJS.Timeout + + private _scheduleNext(): void { + if (this._timer) clearTimeout(this._timer) + if (this._futureSalts.length === 0) return + + const next = this._futureSalts.shift()! + + this._timer = setTimeout( + () => { + this.currentSalt = next.salt + this._scheduleNext() + }, + next.validSince * 1000 - Date.now(), + ) + } + + destroy(): void { + clearTimeout(this._timer) + } +} diff --git a/packages/core/src/network/session-connection.ts b/packages/core/src/network/session-connection.ts index d8aa2b6a..af22771c 100644 --- a/packages/core/src/network/session-connection.ts +++ b/packages/core/src/network/session-connection.ts @@ -21,6 +21,7 @@ import { reportUnknownError } from '../utils/platform/error-reporting.js' import { doAuthorization } from './authorization.js' import { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js' import { PersistentConnection, PersistentConnectionParams } from './persistent-connection.js' +import { ServerSaltManager } from './server-salt.js' import { TransportError } from './transports/abstract.js' export interface SessionConnectionParams extends PersistentConnectionParams { @@ -35,6 +36,8 @@ export interface SessionConnectionParams extends PersistentConnectionParams { isMainDcConnection: boolean usePfs?: boolean + salts: ServerSaltManager + readerMap: TlReaderMap writerMap: TlWriterMap } @@ -84,6 +87,7 @@ export class SessionConnection extends PersistentConnection { private _readerMap: TlReaderMap private _writerMap: TlWriterMap private _crypto: ICryptoProvider + private _salts: ServerSaltManager constructor( params: SessionConnectionParams, @@ -95,6 +99,7 @@ export class SessionConnection extends PersistentConnection { this._readerMap = params.readerMap this._writerMap = params.writerMap this._crypto = params.crypto + this._salts = params.salts this._handleRawMessage = this._handleRawMessage.bind(this) } @@ -143,6 +148,7 @@ export class SessionConnection extends PersistentConnection { reset(forever = false): void { this._session.initConnectionCalled = false this._flushTimer.reset() + this._salts.isFetching = false if (forever) { this.removeAllListeners() @@ -273,7 +279,7 @@ export class SessionConnection extends PersistentConnection { doAuthorization(this, this._crypto) .then(([authKey, serverSalt, timeOffset]) => { this._session._authKey.setup(authKey) - this._session.serverSalt = serverSalt + this._salts.currentSalt = serverSalt this._session._timeOffset = timeOffset this._session.authorizationPending = false @@ -430,7 +436,7 @@ export class SessionConnection extends PersistentConnection { this._session._authKeyTempSecondary = this._session._authKeyTemp this._session._authKeyTemp = tempKey - this._session.serverSalt = tempServerSalt + this._salts.currentSalt = tempServerSalt this.log.debug('temp key has been bound, exp = %d', inner.expiresAt) @@ -608,7 +614,7 @@ export class SessionConnection extends PersistentConnection { this._onMsgsStateInfo(message) break case 'mt_future_salts': - // todo + this._onFutureSalts(message) break case 'mt_msgs_state_req': case 'mt_msg_resend_req': @@ -1022,6 +1028,12 @@ export class SessionConnection extends PersistentConnection { case 'bind': this.log.debug('temp key binding request %l failed because of %s, retrying', msgId, reason) msgInfo.promise.reject(Error(reason)) + break + case 'future_salts': + this.log.debug('future_salts request %l failed because of %s, will retry', msgId, reason) + this._salts.isFetching = false + // this is enough to make it retry on the next flush. + break } this._session.pendingMessages.delete(msgId) @@ -1064,7 +1076,7 @@ export class SessionConnection extends PersistentConnection { } private _onBadServerSalt(msg: mtp.RawMt_bad_server_salt): void { - this._session.serverSalt = msg.newServerSalt + this._salts.currentSalt = msg.newServerSalt this._onMessageFailed(msg.badMsgId, 'bad_server_salt') } @@ -1114,7 +1126,7 @@ export class SessionConnection extends PersistentConnection { this.emit('update', { _: 'updatesTooLong' }) } - this._session.serverSalt = serverSalt + this._salts.currentSalt = serverSalt this.log.debug('received new_session_created, uid = %l, first msg_id = %l', uniqueId, firstMsgId) @@ -1230,6 +1242,25 @@ export class SessionConnection extends PersistentConnection { this._onMessagesInfo(info.msgIds, msg.info) } + private _onFutureSalts(msg: mtp.RawMt_future_salts): void { + const info = this._session.pendingMessages.get(msg.reqMsgId) + + if (!info) { + this.log.warn('received future_salts to unknown request %l', msg.reqMsgId) + + return + } + + if (info._ !== 'future_salts') { + this.log.warn('received future_salts to %s query %l', info._, msg.reqMsgId) + + return + } + + this._salts.isFetching = false + this._salts.setFutureSalts(msg.salts) + } + private _onDestroySessionResult(msg: mtp.TypeDestroySessionRes): void { const reqMsgId = this._session.destroySessionIdToMsgId.get(msg.sessionId) @@ -1504,6 +1535,9 @@ export class SessionConnection extends PersistentConnection { let ackRequest: Uint8Array | null = null let ackMsgIds: Long[] | null = null + let getFutureSaltsRequest: Uint8Array | null = null + let getFutureSaltsMsgId: Long | null = null + let pingRequest: Uint8Array | null = null let pingId: Long | null = null let pingMsgId: Long | null = null @@ -1543,8 +1577,6 @@ export class SessionConnection extends PersistentConnection { messageCount += 1 } - const getStateTime = now + GET_STATE_INTERVAL - if (now - this._session.lastPingTime > PING_INTERVAL) { if (!this._session.lastPingMsgId.isZero()) { this.log.warn("didn't receive pong for previous ping (msg_id = %l)", this._session.lastPingMsgId) @@ -1632,6 +1664,18 @@ export class SessionConnection extends PersistentConnection { this._queuedDestroySession = [] } + if (this._salts.shouldFetchSalts()) { + const obj: mtp.RawMt_get_future_salts = { + _: 'mt_get_future_salts', + num: 64, + } + + getFutureSaltsRequest = TlBinaryWriter.serializeObject(this._writerMap, obj) + containerSize += getFutureSaltsRequest.length + 16 + containerMessageCount += 1 + this._salts.isFetching = true + } + let forceContainer = false const rpcToSend: PendingRpc[] = [] @@ -1765,6 +1809,18 @@ export class SessionConnection extends PersistentConnection { }) } + if (getFutureSaltsRequest) { + getFutureSaltsMsgId = this._registerOutgoingMsgId(this._session.writeMessage(writer, getFutureSaltsRequest)) + const pending: PendingMessage = { + _: 'future_salts', + containerId: getFutureSaltsMsgId, + } + this._session.pendingMessages.set(getFutureSaltsMsgId, pending) + otherPendings.push(pending) + } + + const getStateTime = now + GET_STATE_INTERVAL + for (let i = 0; i < rpcToSend.length; i++) { const msg = rpcToSend[i] // not using writeMessage here because we also need seqNo, and @@ -1867,7 +1923,7 @@ export class SessionConnection extends PersistentConnection { const result = writer.result() this.log.debug( - 'sending %d messages: size = %db, acks = %L, ping = %b (msg_id = %l), state_req = %L (msg_id = %l), resend = %L (msg_id = %l), cancels = %L (msg_id = %l), rpc = %s, container = %b, root msg_id = %l', + 'sending %d messages: size = %db, acks = %L, ping = %b (msg_id = %l), state_req = %L (msg_id = %l), resend = %L (msg_id = %l), cancels = %L (msg_id = %l), salts_req = %b (msg_id = %l), rpc = %s, container = %b, root msg_id = %l', messageCount, packetSize, ackMsgIds, @@ -1879,6 +1935,8 @@ export class SessionConnection extends PersistentConnection { cancelRpcs, cancelRpcs, resendMsgId, + getFutureSaltsRequest, + getFutureSaltsMsgId, rpcToSend.map((it) => it.method), useContainer, rootMsgId,