feat(core): handle future salts

This commit is contained in:
alina 🌸 2023-12-11 06:15:31 +03:00
parent ee9e2e35c4
commit 987de6571a
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
5 changed files with 130 additions and 15 deletions

View file

@ -17,6 +17,7 @@ import {
SortedArray, SortedArray,
} from '../utils/index.js' } from '../utils/index.js'
import { AuthKey } from './auth-key.js' import { AuthKey } from './auth-key.js'
import { ServerSaltManager } from './server-salt.js'
export interface PendingRpc { export interface PendingRpc {
method: string method: string
@ -98,8 +99,6 @@ export class MtprotoSession {
_lastMessageId = Long.ZERO _lastMessageId = Long.ZERO
_seqNo = 0 _seqNo = 0
serverSalt = Long.ZERO
/// state /// /// state ///
// recent msg ids // recent msg ids
recentOutgoingMsgIds = new LruSet<Long>(1000, true) recentOutgoingMsgIds = new LruSet<Long>(1000, true)
@ -137,6 +136,7 @@ export class MtprotoSession {
readonly log: Logger, readonly log: Logger,
readonly _readerMap: TlReaderMap, readonly _readerMap: TlReaderMap,
readonly _writerMap: TlWriterMap, readonly _writerMap: TlWriterMap,
readonly _salts: ServerSaltManager,
) { ) {
this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] ` this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] `
} }
@ -254,7 +254,7 @@ export class MtprotoSession {
encryptMessage(message: Uint8Array): Uint8Array { encryptMessage(message: Uint8Array): Uint8Array {
const key = this._authKeyTemp.ready ? this._authKeyTemp : this._authKey const key = this._authKeyTemp.ready ? this._authKeyTemp : this._authKey
return key.encryptMessage(message, this.serverSalt, this._sessionId) return key.encryptMessage(message, this._salts.currentSalt, this._sessionId)
} }
/** Decrypt a single MTProto message using session's keys */ /** Decrypt a single MTProto message using session's keys */

View file

@ -88,6 +88,7 @@ export class MultiSessionConnection extends EventEmitter {
this._log.create('session'), this._log.create('session'),
this.params.readerMap, this.params.readerMap,
this.params.writerMap, this.params.writerMap,
this.params.salts,
) )
// brvh // brvh

View file

@ -9,6 +9,7 @@ import { ConfigManager } from './config-manager.js'
import { MultiSessionConnection } from './multi-session-connection.js' import { MultiSessionConnection } from './multi-session-connection.js'
import { PersistentConnectionParams } from './persistent-connection.js' import { PersistentConnectionParams } from './persistent-connection.js'
import { defaultReconnectionStrategy, ReconnectionStrategy } from './reconnection.js' import { defaultReconnectionStrategy, ReconnectionStrategy } from './reconnection.js'
import { ServerSaltManager } from './server-salt.js'
import { SessionConnection, SessionConnectionParams } from './session-connection.js' import { SessionConnection, SessionConnectionParams } from './session-connection.js'
import { defaultTransportFactory, TransportFactory } from './transports/index.js' import { defaultTransportFactory, TransportFactory } from './transports/index.js'
@ -170,6 +171,7 @@ export interface RpcCallOptions {
* Wrapper over all connection pools for a single DC. * Wrapper over all connection pools for a single DC.
*/ */
export class DcConnectionManager { export class DcConnectionManager {
private _salts = new ServerSaltManager()
private __baseConnectionParams = (): SessionConnectionParams => ({ private __baseConnectionParams = (): SessionConnectionParams => ({
crypto: this.manager.params.crypto, crypto: this.manager.params.crypto,
initConnection: this.manager._initConnectionParams, initConnection: this.manager._initConnectionParams,
@ -186,6 +188,7 @@ export class DcConnectionManager {
isMainDcConnection: this.isPrimary, isMainDcConnection: this.isPrimary,
inactivityTimeout: this.manager.params.inactivityTimeout ?? 60_000, inactivityTimeout: this.manager.params.inactivityTimeout ?? 60_000,
enableErrorReporting: this.manager.params.enableErrorReporting, enableErrorReporting: this.manager.params.enableErrorReporting,
salts: this._salts,
}) })
private _log = this.manager._log.create('dc-manager') private _log = this.manager._log.create('dc-manager')
@ -379,6 +382,14 @@ export class DcConnectionManager {
return true return true
} }
destroy() {
this.main.destroy()
this.upload.destroy()
this.download.destroy()
this.downloadSmall.destroy()
this._salts.destroy()
}
} }
/** /**
@ -812,10 +823,7 @@ export class NetworkManager {
destroy(): void { destroy(): void {
for (const dc of this._dcConnections.values()) { for (const dc of this._dcConnections.values()) {
dc.main.destroy() dc.destroy()
dc.upload.destroy()
dc.download.destroy()
dc.downloadSmall.destroy()
} }
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval) if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
this.config.offConfigUpdate(this._onConfigChanged) this.config.offConfigUpdate(this._onConfigChanged)

View file

@ -0,0 +1,48 @@
import EventEmitter from 'events'
import Long from 'long'
import { mtp } from '@mtcute/tl'
export class ServerSaltManager extends EventEmitter {
private _futureSalts: mtp.RawMt_future_salt[] = []
currentSalt = Long.ZERO
isFetching = false
shouldFetchSalts(): boolean {
return !this.isFetching && !this.currentSalt.isZero() && this._futureSalts.length < 2
}
setFutureSalts(salts: mtp.RawMt_future_salt[]): void {
this._futureSalts = salts
if (Date.now() > salts[0].validSince * 1000) {
this.currentSalt = salts[0].salt
this._futureSalts.shift()
}
this._scheduleNext()
}
private _timer?: NodeJS.Timeout
private _scheduleNext(): void {
if (this._timer) clearTimeout(this._timer)
if (this._futureSalts.length === 0) return
const next = this._futureSalts.shift()!
this._timer = setTimeout(
() => {
this.currentSalt = next.salt
this._scheduleNext()
},
next.validSince * 1000 - Date.now(),
)
}
destroy(): void {
clearTimeout(this._timer)
}
}

View file

@ -21,6 +21,7 @@ import { reportUnknownError } from '../utils/platform/error-reporting.js'
import { doAuthorization } from './authorization.js' import { doAuthorization } from './authorization.js'
import { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js' import { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js'
import { PersistentConnection, PersistentConnectionParams } from './persistent-connection.js' import { PersistentConnection, PersistentConnectionParams } from './persistent-connection.js'
import { ServerSaltManager } from './server-salt.js'
import { TransportError } from './transports/abstract.js' import { TransportError } from './transports/abstract.js'
export interface SessionConnectionParams extends PersistentConnectionParams { export interface SessionConnectionParams extends PersistentConnectionParams {
@ -35,6 +36,8 @@ export interface SessionConnectionParams extends PersistentConnectionParams {
isMainDcConnection: boolean isMainDcConnection: boolean
usePfs?: boolean usePfs?: boolean
salts: ServerSaltManager
readerMap: TlReaderMap readerMap: TlReaderMap
writerMap: TlWriterMap writerMap: TlWriterMap
} }
@ -84,6 +87,7 @@ export class SessionConnection extends PersistentConnection {
private _readerMap: TlReaderMap private _readerMap: TlReaderMap
private _writerMap: TlWriterMap private _writerMap: TlWriterMap
private _crypto: ICryptoProvider private _crypto: ICryptoProvider
private _salts: ServerSaltManager
constructor( constructor(
params: SessionConnectionParams, params: SessionConnectionParams,
@ -95,6 +99,7 @@ export class SessionConnection extends PersistentConnection {
this._readerMap = params.readerMap this._readerMap = params.readerMap
this._writerMap = params.writerMap this._writerMap = params.writerMap
this._crypto = params.crypto this._crypto = params.crypto
this._salts = params.salts
this._handleRawMessage = this._handleRawMessage.bind(this) this._handleRawMessage = this._handleRawMessage.bind(this)
} }
@ -143,6 +148,7 @@ export class SessionConnection extends PersistentConnection {
reset(forever = false): void { reset(forever = false): void {
this._session.initConnectionCalled = false this._session.initConnectionCalled = false
this._flushTimer.reset() this._flushTimer.reset()
this._salts.isFetching = false
if (forever) { if (forever) {
this.removeAllListeners() this.removeAllListeners()
@ -273,7 +279,7 @@ export class SessionConnection extends PersistentConnection {
doAuthorization(this, this._crypto) doAuthorization(this, this._crypto)
.then(([authKey, serverSalt, timeOffset]) => { .then(([authKey, serverSalt, timeOffset]) => {
this._session._authKey.setup(authKey) this._session._authKey.setup(authKey)
this._session.serverSalt = serverSalt this._salts.currentSalt = serverSalt
this._session._timeOffset = timeOffset this._session._timeOffset = timeOffset
this._session.authorizationPending = false this._session.authorizationPending = false
@ -430,7 +436,7 @@ export class SessionConnection extends PersistentConnection {
this._session._authKeyTempSecondary = this._session._authKeyTemp this._session._authKeyTempSecondary = this._session._authKeyTemp
this._session._authKeyTemp = tempKey this._session._authKeyTemp = tempKey
this._session.serverSalt = tempServerSalt this._salts.currentSalt = tempServerSalt
this.log.debug('temp key has been bound, exp = %d', inner.expiresAt) this.log.debug('temp key has been bound, exp = %d', inner.expiresAt)
@ -608,7 +614,7 @@ export class SessionConnection extends PersistentConnection {
this._onMsgsStateInfo(message) this._onMsgsStateInfo(message)
break break
case 'mt_future_salts': case 'mt_future_salts':
// todo this._onFutureSalts(message)
break break
case 'mt_msgs_state_req': case 'mt_msgs_state_req':
case 'mt_msg_resend_req': case 'mt_msg_resend_req':
@ -1022,6 +1028,12 @@ export class SessionConnection extends PersistentConnection {
case 'bind': case 'bind':
this.log.debug('temp key binding request %l failed because of %s, retrying', msgId, reason) this.log.debug('temp key binding request %l failed because of %s, retrying', msgId, reason)
msgInfo.promise.reject(Error(reason)) msgInfo.promise.reject(Error(reason))
break
case 'future_salts':
this.log.debug('future_salts request %l failed because of %s, will retry', msgId, reason)
this._salts.isFetching = false
// this is enough to make it retry on the next flush.
break
} }
this._session.pendingMessages.delete(msgId) this._session.pendingMessages.delete(msgId)
@ -1064,7 +1076,7 @@ export class SessionConnection extends PersistentConnection {
} }
private _onBadServerSalt(msg: mtp.RawMt_bad_server_salt): void { private _onBadServerSalt(msg: mtp.RawMt_bad_server_salt): void {
this._session.serverSalt = msg.newServerSalt this._salts.currentSalt = msg.newServerSalt
this._onMessageFailed(msg.badMsgId, 'bad_server_salt') this._onMessageFailed(msg.badMsgId, 'bad_server_salt')
} }
@ -1114,7 +1126,7 @@ export class SessionConnection extends PersistentConnection {
this.emit('update', { _: 'updatesTooLong' }) this.emit('update', { _: 'updatesTooLong' })
} }
this._session.serverSalt = serverSalt this._salts.currentSalt = serverSalt
this.log.debug('received new_session_created, uid = %l, first msg_id = %l', uniqueId, firstMsgId) this.log.debug('received new_session_created, uid = %l, first msg_id = %l', uniqueId, firstMsgId)
@ -1230,6 +1242,25 @@ export class SessionConnection extends PersistentConnection {
this._onMessagesInfo(info.msgIds, msg.info) this._onMessagesInfo(info.msgIds, msg.info)
} }
private _onFutureSalts(msg: mtp.RawMt_future_salts): void {
const info = this._session.pendingMessages.get(msg.reqMsgId)
if (!info) {
this.log.warn('received future_salts to unknown request %l', msg.reqMsgId)
return
}
if (info._ !== 'future_salts') {
this.log.warn('received future_salts to %s query %l', info._, msg.reqMsgId)
return
}
this._salts.isFetching = false
this._salts.setFutureSalts(msg.salts)
}
private _onDestroySessionResult(msg: mtp.TypeDestroySessionRes): void { private _onDestroySessionResult(msg: mtp.TypeDestroySessionRes): void {
const reqMsgId = this._session.destroySessionIdToMsgId.get(msg.sessionId) const reqMsgId = this._session.destroySessionIdToMsgId.get(msg.sessionId)
@ -1504,6 +1535,9 @@ export class SessionConnection extends PersistentConnection {
let ackRequest: Uint8Array | null = null let ackRequest: Uint8Array | null = null
let ackMsgIds: Long[] | null = null let ackMsgIds: Long[] | null = null
let getFutureSaltsRequest: Uint8Array | null = null
let getFutureSaltsMsgId: Long | null = null
let pingRequest: Uint8Array | null = null let pingRequest: Uint8Array | null = null
let pingId: Long | null = null let pingId: Long | null = null
let pingMsgId: Long | null = null let pingMsgId: Long | null = null
@ -1543,8 +1577,6 @@ export class SessionConnection extends PersistentConnection {
messageCount += 1 messageCount += 1
} }
const getStateTime = now + GET_STATE_INTERVAL
if (now - this._session.lastPingTime > PING_INTERVAL) { if (now - this._session.lastPingTime > PING_INTERVAL) {
if (!this._session.lastPingMsgId.isZero()) { if (!this._session.lastPingMsgId.isZero()) {
this.log.warn("didn't receive pong for previous ping (msg_id = %l)", this._session.lastPingMsgId) this.log.warn("didn't receive pong for previous ping (msg_id = %l)", this._session.lastPingMsgId)
@ -1632,6 +1664,18 @@ export class SessionConnection extends PersistentConnection {
this._queuedDestroySession = [] this._queuedDestroySession = []
} }
if (this._salts.shouldFetchSalts()) {
const obj: mtp.RawMt_get_future_salts = {
_: 'mt_get_future_salts',
num: 64,
}
getFutureSaltsRequest = TlBinaryWriter.serializeObject(this._writerMap, obj)
containerSize += getFutureSaltsRequest.length + 16
containerMessageCount += 1
this._salts.isFetching = true
}
let forceContainer = false let forceContainer = false
const rpcToSend: PendingRpc[] = [] const rpcToSend: PendingRpc[] = []
@ -1765,6 +1809,18 @@ export class SessionConnection extends PersistentConnection {
}) })
} }
if (getFutureSaltsRequest) {
getFutureSaltsMsgId = this._registerOutgoingMsgId(this._session.writeMessage(writer, getFutureSaltsRequest))
const pending: PendingMessage = {
_: 'future_salts',
containerId: getFutureSaltsMsgId,
}
this._session.pendingMessages.set(getFutureSaltsMsgId, pending)
otherPendings.push(pending)
}
const getStateTime = now + GET_STATE_INTERVAL
for (let i = 0; i < rpcToSend.length; i++) { for (let i = 0; i < rpcToSend.length; i++) {
const msg = rpcToSend[i] const msg = rpcToSend[i]
// not using writeMessage here because we also need seqNo, and // not using writeMessage here because we also need seqNo, and
@ -1867,7 +1923,7 @@ export class SessionConnection extends PersistentConnection {
const result = writer.result() const result = writer.result()
this.log.debug( this.log.debug(
'sending %d messages: size = %db, acks = %L, ping = %b (msg_id = %l), state_req = %L (msg_id = %l), resend = %L (msg_id = %l), cancels = %L (msg_id = %l), rpc = %s, container = %b, root msg_id = %l', 'sending %d messages: size = %db, acks = %L, ping = %b (msg_id = %l), state_req = %L (msg_id = %l), resend = %L (msg_id = %l), cancels = %L (msg_id = %l), salts_req = %b (msg_id = %l), rpc = %s, container = %b, root msg_id = %l',
messageCount, messageCount,
packetSize, packetSize,
ackMsgIds, ackMsgIds,
@ -1879,6 +1935,8 @@ export class SessionConnection extends PersistentConnection {
cancelRpcs, cancelRpcs,
cancelRpcs, cancelRpcs,
resendMsgId, resendMsgId,
getFutureSaltsRequest,
getFutureSaltsMsgId,
rpcToSend.map((it) => it.method), rpcToSend.map((it) => it.method),
useContainer, useContainer,
rootMsgId, rootMsgId,