refactor: (initial) extracted connection management to NetworkManager

This commit is contained in:
teidesu 2022-11-05 03:03:21 +03:00 committed by Alina Sireneva
parent 76639d2993
commit 4f834afc6a
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
9 changed files with 860 additions and 472 deletions

View file

@ -78,8 +78,7 @@ export async function startTest(
if (!availableDcs.find((dc) => dc.id === id)) { throw new MtArgumentError(`${phone} has invalid DC ID (${id})`) } if (!availableDcs.find((dc) => dc.id === id)) { throw new MtArgumentError(`${phone} has invalid DC ID (${id})`) }
} else { } else {
let dcId = this._primaryDc.id let dcId = this._defaultDc.id
if (params.dcId) { if (params.dcId) {
if (!availableDcs.find((dc) => dc.id === params!.dcId)) { throw new MtArgumentError(`DC ID is invalid (${dcId})`) } if (!availableDcs.find((dc) => dc.id === params!.dcId)) { throw new MtArgumentError(`DC ID is invalid (${dcId})`) }
dcId = params.dcId dcId = params.dcId

View file

@ -76,7 +76,7 @@ export async function* downloadAsIterable(
const isWeb = tl.isAnyInputWebFileLocation(location) const isWeb = tl.isAnyInputWebFileLocation(location)
// we will receive a FileMigrateError in case this is invalid // we will receive a FileMigrateError in case this is invalid
if (!dcId) dcId = this._primaryDc.id if (!dcId) dcId = this._defaultDc.id
const chunkSize = partSizeKb * 1024 const chunkSize = partSizeKb * 1024

View file

@ -7,6 +7,9 @@ import defaultReaderMap from '@mtcute/tl/binary/reader'
import defaultWriterMap from '@mtcute/tl/binary/writer' import defaultWriterMap from '@mtcute/tl/binary/writer'
import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime' import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
import defaultReaderMap from '@mtcute/tl/binary/reader'
import defaultWriterMap from '@mtcute/tl/binary/writer'
import { import {
defaultReconnectionStrategy, defaultReconnectionStrategy,
defaultTransportFactory, defaultTransportFactory,
@ -22,6 +25,10 @@ import {
createControllablePromise, createControllablePromise,
CryptoProviderFactory, CryptoProviderFactory,
defaultCryptoProviderFactory, defaultCryptoProviderFactory,
sleep,
getAllPeersFrom,
LogManager,
toggleChannelIdMark,
defaultProductionDc, defaultProductionDc,
defaultProductionIpv6Dc, defaultProductionIpv6Dc,
defaultTestDc, defaultTestDc,
@ -31,11 +38,26 @@ import {
LogManager, LogManager,
sleep, sleep,
toggleChannelIdMark, toggleChannelIdMark,
ControllablePromise,
createControllablePromise,
readStringSession,
writeStringSession
} from './utils' } from './utils'
import { addPublicKey } from './utils/crypto/keys' import { addPublicKey } from './utils/crypto/keys'
import { readStringSession, writeStringSession } from './utils/string-session' import { readStringSession, writeStringSession } from './utils/string-session'
import {
TransportFactory,
defaultReconnectionStrategy,
ReconnectionStrategy,
defaultTransportFactory,
SessionConnection,
} from './network'
import { PersistentConnectionParams } from './network/persistent-connection'
import { ITelegramStorage, MemoryStorage } from './storage'
import { ConfigManager } from './network/config-manager' import { ConfigManager } from './network/config-manager'
import { NetworkManager } from "./network/network-manager";
export interface BaseTelegramClientOptions { export interface BaseTelegramClientOptions {
/** /**
@ -66,27 +88,27 @@ export interface BaseTelegramClientOptions {
*/ */
useIpv6?: boolean useIpv6?: boolean
/** /**
* Primary DC to use for initial connection. * Primary DC to use for initial connection.
* This does not mean this will be the only DC used, * This does not mean this will be the only DC used,
* nor that this DC will actually be primary, this only * nor that this DC will actually be primary, this only
* determines the first DC the library will try to connect to. * determines the first DC the library will try to connect to.
* Can be used to connect to other networks (like test DCs). * Can be used to connect to other networks (like test DCs).
* *
* When session already contains primary DC, this parameter is ignored. * When session already contains primary DC, this parameter is ignored.
* Defaults to Production DC 2. * Defaults to Production DC 2.
*/ */
primaryDc?: tl.RawDcOption defaultDc?: tl.RawDcOption
/** /**
* Whether to connect to test servers. * Whether to connect to test servers.
* *
* If passed, {@link primaryDc} defaults to Test DC 2. * If passed, {@link defaultDc} defaults to Test DC 2.
* *
* **Must** be passed if using test servers, even if * **Must** be passed if using test servers, even if
* you passed custom {@link primaryDc} * you passed custom {@link defaultDc}
*/ */
testMode?: boolean testMode?: boolean
/** /**
* Additional options for initConnection call. * Additional options for initConnection call.
@ -179,21 +201,11 @@ export interface BaseTelegramClientOptions {
} }
export class BaseTelegramClient extends EventEmitter { export class BaseTelegramClient extends EventEmitter {
/**
* `initConnection` params taken from {@link BaseTelegramClient.Options.initConnectionOptions}.
*/
protected readonly _initConnectionParams: tl.RawInitConnectionRequest
/** /**
* Crypto provider taken from {@link BaseTelegramClient.Options.crypto} * Crypto provider taken from {@link BaseTelegramClient.Options.crypto}
*/ */
protected readonly _crypto: ICryptoProvider protected readonly _crypto: ICryptoProvider
/**
* Transport factory taken from {@link BaseTelegramClient.Options.transport}
*/
protected readonly _transportFactory: TransportFactory
/** /**
* Telegram storage taken from {@link BaseTelegramClient.Options.storage} * Telegram storage taken from {@link BaseTelegramClient.Options.storage}
*/ */
@ -214,11 +226,6 @@ export class BaseTelegramClient extends EventEmitter {
*/ */
protected readonly _testMode: boolean protected readonly _testMode: boolean
/**
* Reconnection strategy taken from {@link BaseTelegramClient.Options.reconnectionStrategy}
*/
protected readonly _reconnectionStrategy: ReconnectionStrategy<PersistentConnectionParams>
/** /**
* Flood sleep threshold taken from {@link BaseTelegramClient.Options.floodSleepThreshold} * Flood sleep threshold taken from {@link BaseTelegramClient.Options.floodSleepThreshold}
*/ */
@ -230,22 +237,16 @@ export class BaseTelegramClient extends EventEmitter {
protected readonly _rpcRetryCount: number protected readonly _rpcRetryCount: number
/** /**
* "Disable updates" taken from {@link BaseTelegramClient.Options.disableUpdates} * Primary DC taken from {@link BaseTelegramClient.Options.defaultDc},
*/
protected readonly _disableUpdates: boolean
/**
* Primary DC taken from {@link BaseTelegramClient.Options.primaryDc},
* loaded from session or changed by other means (like redirecting). * loaded from session or changed by other means (like redirecting).
*/ */
protected _primaryDc: tl.RawDcOption protected _defaultDc: tl.RawDcOption
private _niceStacks: boolean private _niceStacks: boolean
readonly _layer: number readonly _layer: number
readonly _readerMap: TlReaderMap readonly _readerMap: TlReaderMap
readonly _writerMap: TlWriterMap readonly _writerMap: TlWriterMap
private _keepAliveInterval?: NodeJS.Timeout
protected _lastUpdateTime = 0 protected _lastUpdateTime = 0
private _floodWaitedRequests: Record<string, number> = {} private _floodWaitedRequests: Record<string, number> = {}
@ -260,14 +261,6 @@ export class BaseTelegramClient extends EventEmitter {
private _onError?: (err: unknown, connection?: SessionConnection) => void private _onError?: (err: unknown, connection?: SessionConnection) => void
/**
* The primary {@link SessionConnection} that is used for
* most of the communication with Telegram.
*
* Methods for downloading/uploading files may create additional connections as needed.
*/
primaryConnection!: SessionConnection
private _importFrom?: string private _importFrom?: string
private _importForce?: boolean private _importForce?: boolean
@ -282,6 +275,7 @@ export class BaseTelegramClient extends EventEmitter {
protected _handleUpdate(update: tl.TypeUpdates): void {} protected _handleUpdate(update: tl.TypeUpdates): void {}
readonly log = new LogManager() readonly log = new LogManager()
readonly network: NetworkManager
constructor(opts: BaseTelegramClientOptions) { constructor(opts: BaseTelegramClientOptions) {
super() super()
@ -293,14 +287,13 @@ export class BaseTelegramClient extends EventEmitter {
throw new Error('apiId must be a number or a numeric string!') throw new Error('apiId must be a number or a numeric string!')
} }
this._transportFactory = opts.transport ?? defaultTransportFactory
this._crypto = (opts.crypto ?? defaultCryptoProviderFactory)() this._crypto = (opts.crypto ?? defaultCryptoProviderFactory)()
this.storage = opts.storage ?? new MemoryStorage() this.storage = opts.storage ?? new MemoryStorage()
this._apiHash = opts.apiHash this._apiHash = opts.apiHash
this._useIpv6 = Boolean(opts.useIpv6) this._useIpv6 = Boolean(opts.useIpv6)
this._testMode = Boolean(opts.testMode) this._testMode = Boolean(opts.testMode)
let dc = opts.primaryDc let dc = opts.defaultDc
if (!dc) { if (!dc) {
if (this._testMode) { if (this._testMode) {
@ -312,42 +305,33 @@ export class BaseTelegramClient extends EventEmitter {
} }
} }
this._primaryDc = dc this._defaultDc = dc
this._reconnectionStrategy = this._reconnectionStrategy =
opts.reconnectionStrategy ?? defaultReconnectionStrategy opts.reconnectionStrategy ?? defaultReconnectionStrategy
this._floodSleepThreshold = opts.floodSleepThreshold ?? 10000 this._floodSleepThreshold = opts.floodSleepThreshold ?? 10000
this._rpcRetryCount = opts.rpcRetryCount ?? 5 this._rpcRetryCount = opts.rpcRetryCount ?? 5
this._disableUpdates = opts.disableUpdates ?? false
this._niceStacks = opts.niceStacks ?? true this._niceStacks = opts.niceStacks ?? true
this._layer = opts.overrideLayer ?? tl.LAYER this._layer = opts.overrideLayer ?? tl.LAYER
this._readerMap = opts.readerMap ?? defaultReaderMap this._readerMap = opts.readerMap ?? defaultReaderMap
this._writerMap = opts.writerMap ?? defaultWriterMap this._writerMap = opts.writerMap ?? defaultWriterMap
this.storage.setup?.(this.log, this._readerMap, this._writerMap) this.network = new NetworkManager({
let deviceModel = 'mtcute on '
if (typeof process !== 'undefined' && typeof require !== 'undefined') {
// eslint-disable-next-line @typescript-eslint/no-var-requires
const os = require('os')
deviceModel += `${os.type()} ${os.arch()} ${os.release()}`
} else if (typeof navigator !== 'undefined') {
deviceModel += navigator.userAgent
} else deviceModel += 'unknown'
this._initConnectionParams = {
_: 'initConnection',
deviceModel,
systemVersion: '1.0',
appVersion: '1.0.0',
systemLangCode: 'en',
langPack: '', // "langPacks are for official apps only"
langCode: 'en',
...(opts.initConnectionOptions ?? {}),
apiId, apiId,
crypto: this._crypto,
disableUpdates: opts.disableUpdates ?? false,
initConnectionOptions: opts.initConnectionOptions,
layer: this._layer,
log: this.log,
readerMap: this._readerMap,
writerMap: this._writerMap,
reconnectionStrategy: opts.reconnectionStrategy,
storage: this.storage,
testMode: this._testMode,
transport: opts.transport,
}, this._config)
query: null as any, this.storage.setup?.(this.log, this._readerMap, this._writerMap)
}
} }
protected async _loadStorage(): Promise<void> { protected async _loadStorage(): Promise<void> {
@ -433,6 +417,7 @@ export class BaseTelegramClient extends EventEmitter {
*/ */
async connect(): Promise<void> { async connect(): Promise<void> {
if (this._connected) { if (this._connected) {
// avoid double-connect
await this._connected await this._connected
return return
@ -442,16 +427,14 @@ export class BaseTelegramClient extends EventEmitter {
await this._loadStorage() await this._loadStorage()
const primaryDc = await this.storage.getDefaultDc() const primaryDc = await this.storage.getDefaultDc()
if (primaryDc !== null) this._primaryDc = primaryDc if (primaryDc !== null) this._defaultDc = primaryDc
this._setupPrimaryConnection() const defaultDcAuthKey = await this.storage.getAuthKeyFor(this._defaultDc.id)
await this.primaryConnection.setupKeys( // await this.primaryConnection.setupKeys()
await this.storage.getAuthKeyFor(this._primaryDc.id),
)
if ( if (
(this._importForce || !this.primaryConnection.getAuthKey()) && (this._importForce || !defaultDcAuthKey) &&
this._importFrom this._importFrom
) { ) {
const data = readStringSession(this._readerMap, this._importFrom) const data = readStringSession(this._readerMap, this._importFrom)
@ -462,23 +445,23 @@ export class BaseTelegramClient extends EventEmitter {
) )
} }
this._primaryDc = this.primaryConnection.params.dc = data.primaryDc this._defaultDc = data.primaryDc
await this.storage.setDefaultDc(data.primaryDc) await this.storage.setDefaultDc(data.primaryDc)
if (data.self) { if (data.self) {
await this.storage.setSelf(data.self) await this.storage.setSelf(data.self)
} }
await this.primaryConnection.setupKeys(data.authKey) // await this.primaryConnection.setupKeys(data.authKey)
await this.storage.setAuthKeyFor(data.primaryDc.id, data.authKey) await this.storage.setAuthKeyFor(data.primaryDc.id, data.authKey)
await this._saveStorage(true) await this._saveStorage(true)
} }
this.network.connect(this._defaultDc)
this._connected.resolve() this._connected.resolve()
this._connected = true this._connected = true
this.primaryConnection.connect()
} }
/** /**
@ -486,7 +469,8 @@ export class BaseTelegramClient extends EventEmitter {
*/ */
async waitUntilUsable(): Promise<void> { async waitUntilUsable(): Promise<void> {
return new Promise((resolve) => { return new Promise((resolve) => {
this.primaryConnection.once('usable', resolve) // todo
// this.primaryConnection.once('usable', resolve)
}) })
} }
@ -503,8 +487,8 @@ export class BaseTelegramClient extends EventEmitter {
await this._onClose() await this._onClose()
this._config.destroy() this._config.destroy()
this.network.destroy()
this._cleanupPrimaryConnection(true)
// close additional connections // close additional connections
this._additionalConnections.forEach((conn) => conn.destroy()) this._additionalConnections.forEach((conn) => conn.destroy())
@ -528,10 +512,11 @@ export class BaseTelegramClient extends EventEmitter {
newDc = res newDc = res
} }
this._primaryDc = newDc this._defaultDc = newDc
await this.storage.setDefaultDc(newDc) await this.storage.setDefaultDc(newDc)
await this._saveStorage() await this._saveStorage()
await this.primaryConnection.changeDc(newDc) // todo
// await this.primaryConnection.changeDc(newDc)
} }
/** /**
@ -555,6 +540,7 @@ export class BaseTelegramClient extends EventEmitter {
timeout?: number timeout?: number
}, },
): Promise<tl.RpcCallReturn[T['_']]> { ): Promise<tl.RpcCallReturn[T['_']]> {
// todo move to network manager
if (this._connected !== true) { if (this._connected !== true) {
await this.connect() await this.connect()
} }
@ -574,14 +560,12 @@ export class BaseTelegramClient extends EventEmitter {
} }
} }
const connection = params?.connection ?? this.primaryConnection
let lastError: Error | null = null let lastError: Error | null = null
const stack = this._niceStacks ? new Error().stack : undefined const stack = this._niceStacks ? new Error().stack : undefined
for (let i = 0; i < this._rpcRetryCount; i++) { for (let i = 0; i < this._rpcRetryCount; i++) {
try { try {
const res = await connection.sendRpc( const res = await this.network['_primaryDc']!.mainConnection.sendRpc(
message, message,
stack, stack,
params?.timeout, params?.timeout,
@ -629,35 +613,35 @@ export class BaseTelegramClient extends EventEmitter {
} }
} }
if (connection.params.dc.id === this._primaryDc.id) { // if (connection.params.dc.id === this._defaultDc.id) {
if ( // if (
e.constructor === tl.errors.PhoneMigrateXError || // e.constructor === tl.errors.PhoneMigrateXError ||
e.constructor === tl.errors.UserMigrateXError || // e.constructor === tl.errors.UserMigrateXError ||
e.constructor === tl.errors.NetworkMigrateXError // e.constructor === tl.errors.NetworkMigrateXError
) { // ) {
this.log.info('Migrate error, new dc = %d', e.new_dc) // this.log.info('Migrate error, new dc = %d', e.new_dc)
await this.changeDc(e.new_dc) // await this.changeDc(e.new_dc)
continue // continue
} // }
} else if ( // } else {
e.constructor === tl.errors.AuthKeyUnregisteredError // if (e.constructor === tl.errors.AuthKeyUnregisteredError) {
) { // // we can try re-exporting auth from the primary connection
// we can try re-exporting auth from the primary connection // this.log.warn('exported auth key error, re-exporting..')
this.log.warn('exported auth key error, re-exporting..') //
// const auth = await this.call({
const auth = await this.call({ // _: 'auth.exportAuthorization',
_: 'auth.exportAuthorization', // dcId: connection.params.dc.id,
dcId: connection.params.dc.id, // })
}) //
// await connection.sendRpc({
await connection.sendRpc({ // _: 'auth.importAuthorization',
_: 'auth.importAuthorization', // id: auth.id,
id: auth.id, // bytes: auth.bytes,
bytes: auth.bytes, // })
}) //
// continue
continue // }
} // }
throw e throw e
} }
@ -666,100 +650,100 @@ export class BaseTelegramClient extends EventEmitter {
throw lastError throw lastError
} }
/** // /**
* Creates an additional connection to a given DC. // * Creates an additional connection to a given DC.
* This will use auth key for that DC that was already stored // * This will use auth key for that DC that was already stored
* in the session, or generate a new auth key by exporting // * in the session, or generate a new auth key by exporting
* authorization from primary DC and importing it to the new DC. // * authorization from primary DC and importing it to the new DC.
* New connection will use the same crypto provider, `initConnection`, // * New connection will use the same crypto provider, `initConnection`,
* transport and reconnection strategy as the primary connection // * transport and reconnection strategy as the primary connection
* // *
* This method is quite low-level and you shouldn't usually care about this // * This method is quite low-level and you shouldn't usually care about this
* when using high-level API provided by `@mtcute/client`. // * when using high-level API provided by `@mtcute/client`.
* // *
* @param dcId DC id, to which the connection will be created // * @param dcId DC id, to which the connection will be created
* @param cdn Whether that DC is a CDN DC // * @param cdn Whether that DC is a CDN DC
* @param inactivityTimeout // * @param inactivityTimeout
* Inactivity timeout for the connection (in ms), after which the transport will be closed. // * Inactivity timeout for the connection (in ms), after which the transport will be closed.
* Note that connection can still be used normally, it's just the transport which is closed. // * Note that connection can still be used normally, it's just the transport which is closed.
* Defaults to 5 min // * Defaults to 5 min
*/ // */
async createAdditionalConnection( // async createAdditionalConnection(
dcId: number, // dcId: number,
params?: { // params?: {
// todo proper docs // // todo proper docs
// default = false // // default = false
media?: boolean // media?: boolean
// default = fa;se // // default = fa;se
cdn?: boolean // cdn?: boolean
// default = 300_000 // // default = 300_000
inactivityTimeout?: number // inactivityTimeout?: number
// default = false // // default = false
disableUpdates?: boolean // disableUpdates?: boolean
}, // }
): Promise<SessionConnection> { // ): Promise<SessionConnection> {
const dc = await this._config.findOption({ // const dc = await this._config.findOption({
dcId, // dcId,
preferMedia: params?.media, // preferMedia: params?.media,
cdn: params?.cdn, // cdn: params?.cdn,
allowIpv6: this._useIpv6, // allowIpv6: this._useIpv6,
}) // })
if (!dc) throw new Error('DC not found') // if (!dc) throw new Error('DC not found')
const connection = new SessionConnection( // const connection = new SessionConnection(
{ // {
dc, // dc,
testMode: this._testMode, // testMode: this._testMode,
crypto: this._crypto, // crypto: this._crypto,
initConnection: this._initConnectionParams, // initConnection: this._initConnectionParams,
transportFactory: this._transportFactory, // transportFactory: this._transportFactory,
reconnectionStrategy: this._reconnectionStrategy, // reconnectionStrategy: this._reconnectionStrategy,
inactivityTimeout: params?.inactivityTimeout ?? 300_000, // inactivityTimeout: params?.inactivityTimeout ?? 300_000,
layer: this._layer, // layer: this._layer,
disableUpdates: params?.disableUpdates, // disableUpdates: params?.disableUpdates,
readerMap: this._readerMap, // readerMap: this._readerMap,
writerMap: this._writerMap, // writerMap: this._writerMap,
}, // },
this.log.create('connection'), // this.log.create('connection')
) // )
//
connection.on('error', (err) => this._emitError(err, connection)) // connection.on('error', (err) => this._emitError(err, connection))
await connection.setupKeys(await this.storage.getAuthKeyFor(dc.id)) // await connection.setupKeys(await this.storage.getAuthKeyFor(dc.id))
connection.connect() // connection.connect()
//
if (!connection.getAuthKey()) { // if (!connection.getAuthKey()) {
this.log.info('exporting auth to DC %d', dcId) // this.log.info('exporting auth to DC %d', dcId)
const auth = await this.call({ // const auth = await this.call({
_: 'auth.exportAuthorization', // _: 'auth.exportAuthorization',
dcId, // dcId,
}) // })
await connection.sendRpc({ // await connection.sendRpc({
_: 'auth.importAuthorization', // _: 'auth.importAuthorization',
id: auth.id, // id: auth.id,
bytes: auth.bytes, // bytes: auth.bytes,
}) // })
//
// connection.authKey was already generated at this point // // connection.authKey was already generated at this point
this.storage.setAuthKeyFor(dc.id, connection.getAuthKey()!) // this.storage.setAuthKeyFor(dc.id, connection.getAuthKey()!)
await this._saveStorage() // await this._saveStorage()
} else { // } else {
// in case the auth key is invalid // // in case the auth key is invalid
const dcId = dc.id // const dcId = dc.id
connection.on('key-change', async (key) => { // connection.on('key-change', async (key) => {
// we don't need to export, it will be done by `.call()` // // we don't need to export, it will be done by `.call()`
// in case this error is returned // // in case this error is returned
// // //
// even worse, exporting here will lead to a race condition, // // even worse, exporting here will lead to a race condition,
// and may result in redundant re-exports. // // and may result in redundant re-exports.
//
this.storage.setAuthKeyFor(dcId, key) // this.storage.setAuthKeyFor(dcId, key)
await this._saveStorage() // await this._saveStorage()
}) // })
} // }
//
this._additionalConnections.push(connection) // this._additionalConnections.push(connection)
//
return connection // return connection
} // }
/** /**
* Destroy a connection that was previously created using * Destroy a connection that was previously created using
@ -789,7 +773,8 @@ export class BaseTelegramClient extends EventEmitter {
* @param factory New transport factory * @param factory New transport factory
*/ */
changeTransport(factory: TransportFactory): void { changeTransport(factory: TransportFactory): void {
this.primaryConnection.changeTransport(factory) // todo
// this.primaryConnection.changeTransport(factory)
this._additionalConnections.forEach((conn) => this._additionalConnections.forEach((conn) =>
conn.changeTransport(factory), conn.changeTransport(factory),
@ -915,16 +900,16 @@ export class BaseTelegramClient extends EventEmitter {
* > with [@BotFather](//t.me/botfather) * > with [@BotFather](//t.me/botfather)
*/ */
async exportSession(): Promise<string> { async exportSession(): Promise<string> {
if (!this.primaryConnection.getAuthKey()) { // todo
throw new Error('Auth key is not generated yet') // if (!this.primaryConnection.getAuthKey())
} // throw new Error('Auth key is not generated yet')
return writeStringSession(this._writerMap, { return writeStringSession(this._writerMap, {
version: 1, version: 1,
self: await this.storage.getSelf(), self: await this.storage.getSelf(),
testMode: this._testMode, testMode: this._testMode,
primaryDc: this._primaryDc, primaryDc: this._defaultDc,
authKey: this.primaryConnection.getAuthKey()!, authKey: Buffer.from([]) //this.primaryConnection.getAuthKey()!,
}) })
} }

View file

@ -11,8 +11,76 @@ import {
import { getRandomInt, ICryptoProvider, Logger, randomLong } from '../utils' import { getRandomInt, ICryptoProvider, Logger, randomLong } from '../utils'
import { buffersEqual, randomBytes } from '../utils/buffer-utils' import { buffersEqual, randomBytes } from '../utils/buffer-utils'
import {
ICryptoProvider,
Logger,
getRandomInt,
randomLong,
ControllablePromise,
LruSet,
Deque,
SortedArray,
LongMap,
} from '../utils'
import { createAesIgeForMessage } from '../utils/crypto/mtproto' import { createAesIgeForMessage } from '../utils/crypto/mtproto'
export interface PendingRpc {
method: string
data: Buffer
promise: ControllablePromise
stack?: string
gzipOverhead?: number
sent?: boolean
msgId?: Long
seqNo?: number
containerId?: Long
acked?: boolean
initConn?: boolean
getState?: number
cancelled?: boolean
timeout?: number
}
export type PendingMessage =
| {
_: 'rpc'
rpc: PendingRpc
}
| {
_: 'container'
msgIds: Long[]
}
| {
_: 'state'
msgIds: Long[]
containerId: Long
}
| {
_: 'resend'
msgIds: Long[]
containerId: Long
}
| {
_: 'ping'
pingId: Long
containerId: Long
}
| {
_: 'destroy_session'
sessionId: Long
containerId: Long
}
| {
_: 'cancel'
msgId: Long
containerId: Long
}
| {
_: 'future_salts'
containerId: Long
}
/** /**
* Class encapsulating a single MTProto session. * Class encapsulating a single MTProto session.
* Provides means to en-/decrypt messages * Provides means to en-/decrypt messages
@ -33,6 +101,27 @@ export class MtprotoSession {
serverSalt = Long.ZERO serverSalt = Long.ZERO
/// state ///
// recent msg ids
recentOutgoingMsgIds = new LruSet<Long>(1000, false, true)
recentIncomingMsgIds = new LruSet<Long>(1000, false, true)
// queues
queuedRpc = new Deque<PendingRpc>()
queuedAcks: Long[] = []
queuedStateReq: Long[] = []
queuedResendReq: Long[] = []
queuedCancelReq: Long[] = []
getStateSchedule = new SortedArray<PendingRpc>(
[],
(a, b) => a.getState! - b.getState!
)
// requests info
pendingMessages = new LongMap<PendingMessage>()
initConnectionCalled = false
constructor( constructor(
crypto: ICryptoProvider, crypto: ICryptoProvider,
readonly log: Logger, readonly log: Logger,
@ -40,6 +129,7 @@ export class MtprotoSession {
readonly _writerMap: TlWriterMap, readonly _writerMap: TlWriterMap,
) { ) {
this._crypto = crypto this._crypto = crypto
this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] `
} }
/** Whether session contains authKey */ /** Whether session contains authKey */
@ -50,11 +140,13 @@ export class MtprotoSession {
/** Setup keys based on authKey */ /** Setup keys based on authKey */
async setupKeys(authKey?: Buffer | null): Promise<void> { async setupKeys(authKey?: Buffer | null): Promise<void> {
if (authKey) { if (authKey) {
this.log.debug('setting up keys')
this._authKey = authKey this._authKey = authKey
this._authKeyClientSalt = authKey.slice(88, 120) this._authKeyClientSalt = authKey.slice(88, 120)
this._authKeyServerSalt = authKey.slice(96, 128) this._authKeyServerSalt = authKey.slice(96, 128)
this._authKeyId = (await this._crypto.sha1(this._authKey)).slice(-8) this._authKeyId = (await this._crypto.sha1(this._authKey)).slice(-8)
} else { } else {
this.log.debug('resetting keys')
this._authKey = undefined this._authKey = undefined
this._authKeyClientSalt = undefined this._authKeyClientSalt = undefined
this._authKeyServerSalt = undefined this._authKeyServerSalt = undefined
@ -62,22 +154,76 @@ export class MtprotoSession {
} }
} }
/** Reset session by removing authKey and values derived from it */ /**
* Reset session by removing authKey and values derived from it,
* as well as resetting session state
*/
reset(): void { reset(): void {
this._lastMessageId = Long.ZERO
this._seqNo = 0
this._authKey = undefined this._authKey = undefined
this._authKeyClientSalt = undefined this._authKeyClientSalt = undefined
this._authKeyServerSalt = undefined this._authKeyServerSalt = undefined
this._authKeyId = undefined this._authKeyId = undefined
this._sessionId = randomLong()
// no need to reset server salt this.resetState()
} }
changeSessionId(): void { /**
this._sessionId = randomLong() * Reset session state and generate a new session ID.
*
* By default, also cancels any pending RPC requests.
* If `keepPending` is set to `true`, pending requests will be kept
*/
resetState(keepPending = false): void {
this._lastMessageId = Long.ZERO
this._seqNo = 0 this._seqNo = 0
this._sessionId = randomLong()
this.log.debug('session reset, new sid = %l', this._sessionId)
this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] `
// reset session state
if (!keepPending) {
for (const info of this.pendingMessages.values()) {
if (info._ === 'rpc') {
info.rpc.promise.reject(new Error('Session is reset'))
}
}
this.pendingMessages.clear()
}
this.recentOutgoingMsgIds.clear()
this.recentIncomingMsgIds.clear()
if (!keepPending) {
while (this.queuedRpc.length) {
const rpc = this.queuedRpc.popFront()!
if (rpc.sent === false) {
rpc.promise.reject(new Error('Session is reset'))
}
}
}
this.queuedAcks.length = 0
this.queuedStateReq.length = 0
this.queuedResendReq.length = 0
this.getStateSchedule.clear()
}
enqueueRpc(rpc: PendingRpc, force?: boolean): boolean {
// already queued or cancelled
if ((!force && !rpc.sent) || rpc.cancelled) return false
rpc.sent = false
rpc.containerId = undefined
this.log.debug(
'enqueued %s for sending (msg_id = %s)',
rpc.method,
rpc.msgId || 'n/a'
)
this.queuedRpc.pushBack(rpc)
return true
} }
/** Encrypt a single MTProto message using session's keys */ /** Encrypt a single MTProto message using session's keys */

View file

@ -0,0 +1,160 @@
import EventEmitter from 'events'
import { tl } from '@mtcute/tl'
import { Logger } from '../utils'
import {
SessionConnection,
SessionConnectionParams,
} from './session-connection'
import { MtprotoSession } from './mtproto-session'
export class MultiSessionConnection extends EventEmitter {
private _log: Logger
private _sessions: MtprotoSession[]
constructor(
readonly params: SessionConnectionParams,
private _count: number,
log: Logger
) {
super()
this._log = log.create('multi')
this._sessions = []
this._updateConnections()
}
protected _connections: SessionConnection[] = []
setCount(count: number, doUpdate = true): void {
this._count = count
if (doUpdate) this._updateConnections()
}
private _updateSessions(): void {
// there are two cases
// 1. this msc is main, in which case every connection should have its own session
// 2. this msc is not main, in which case all connections should share the same session
if (!this.params.isMainConnection) {
// case 2
if (this._sessions.length === 0) {
this._sessions.push(
new MtprotoSession(
this.params.crypto,
this._log.create('session'),
this.params.readerMap,
this.params.writerMap
)
)
}
// shouldn't happen, but just in case
while (this._sessions.length > 1) {
this._sessions.pop()!.reset()
}
return
}
// case 1
if (this._sessions.length === this._count) return
if (this._sessions.length > this._count) {
// destroy extra sessions
for (let i = this._sessions.length - 1; i >= this._count; i--) {
this._sessions[i].reset()
}
this._sessions.splice(this._count)
return
}
while (this._sessions.length < this._count) {
this._sessions.push(
new MtprotoSession(
this.params.crypto,
this._log.create('session'),
this.params.readerMap,
this.params.writerMap
)
)
}
}
private _updateConnections(): void {
this._updateSessions()
if (this._connections.length === this._count) return
if (this._connections.length > this._count) {
// destroy extra connections
for (let i = this._connections.length - 1; i >= this._count; i--) {
this._connections[i].destroy()
}
this._connections.splice(this._count)
return
}
// create new connections
for (let i = this._connections.length; i < this._count; i++) {
const session = this.params.isMainConnection ? this._sessions[i] : this._sessions[0]
const conn = new SessionConnection(this.params, session)
conn.on('update', (update) => this.emit('update', update))
this._connections.push(conn)
}
}
destroy(): void {
for (const conn of this._connections) {
conn.destroy()
}
for (const session of this._sessions) {
session.reset()
}
}
private _nextConnection = 0
sendRpc<T extends tl.RpcMethod>(
request: T,
stack?: string,
timeout?: number
): Promise<tl.RpcCallReturn[T['_']]> {
if (this.params.isMainConnection) {
// find the least loaded connection
let min = Infinity
let minIdx = 0
for (let i = 0; i < this._connections.length; i++) {
const conn = this._connections[i]
const total = conn._session.queuedRpc.length + conn._session.pendingMessages.size()
if (total < min) {
min = total
minIdx = i
}
}
return this._connections[minIdx].sendRpc(request, stack, timeout)
}
// round-robin connections
// since they all share the same session, it doesn't matter which one we use
return this._connections[
this._nextConnection++ % this._connections.length
].sendRpc(request, stack, timeout)
}
async changeDc(dc: tl.RawDcOption, authKey?: Buffer | null): Promise<void> {
await Promise.all(
this._connections.map((conn) => conn.changeDc(dc, authKey))
)
}
connect(): void {
for (const conn of this._connections) {
conn.connect()
}
}
}

View file

@ -0,0 +1,202 @@
import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
import { tl } from '@mtcute/tl'
import { createControllablePromise, ICryptoProvider, Logger } from '../utils'
import { defaultTransportFactory, TransportFactory } from './transports'
import {
defaultReconnectionStrategy,
ReconnectionStrategy,
} from './reconnection'
import { PersistentConnectionParams } from './persistent-connection'
import { ConfigManager } from './config-manager'
import { MultiSessionConnection } from './multi-session-connection'
import { SessionConnectionParams } from './session-connection'
import { ITelegramStorage } from '../storage'
export class DcConnectionManager {
constructor(
readonly manager: NetworkManager,
readonly dcId: number,
private _dc: tl.RawDcOption
) {}
private __baseConnectionParams = (): SessionConnectionParams => ({
crypto: this.manager.params.crypto,
initConnection: this.manager._initConnectionParams,
transportFactory: this.manager._transportFactory,
dc: this._dc,
testMode: this.manager.params.testMode,
reconnectionStrategy: this.manager._reconnectionStrategy,
layer: this.manager.params.layer,
disableUpdates: this.manager.params.disableUpdates,
readerMap: this.manager.params.readerMap,
writerMap: this.manager.params.writerMap,
isMainConnection: false,
})
mainConnection = new MultiSessionConnection(
{
...this.__baseConnectionParams(),
isMainConnection: true,
},
1,
this.manager._log
)
}
export interface NetworkManagerParams {
storage: ITelegramStorage
crypto: ICryptoProvider
log: Logger
apiId: number
initConnectionOptions?: Partial<
Omit<tl.RawInitConnectionRequest, 'apiId' | 'query'>
>
transport?: TransportFactory
reconnectionStrategy?: ReconnectionStrategy<PersistentConnectionParams>
disableUpdates?: boolean
testMode: boolean
layer: number
readerMap: TlReaderMap
writerMap: TlWriterMap
}
export class NetworkManager {
readonly _log = this.params.log.create('network')
readonly _initConnectionParams: tl.RawInitConnectionRequest
readonly _transportFactory: TransportFactory
readonly _reconnectionStrategy: ReconnectionStrategy<PersistentConnectionParams>
protected readonly _dcConnections: Record<number, DcConnectionManager> = {}
protected _primaryDc?: DcConnectionManager
private _keepAliveInterval?: NodeJS.Timeout
private _keepAliveAction = this._defaultKeepAliveAction.bind(this)
private _defaultKeepAliveAction(): void {
if (this._keepAliveInterval) return
// todo
// telegram asks to fetch pending updates
// if there are no updates for 15 minutes.
// core does not have update handling,
// so we just use getState so the server knows
// we still do need updates
// this.call({ _: 'updates.getState' }).catch((e) => {
// if (!(e instanceof tl.errors.RpcError)) {
// this.primaryConnection.reconnect()
// }
// })
}
constructor(
readonly params: NetworkManagerParams,
readonly config: ConfigManager
) {
let deviceModel = 'mtcute on '
let appVersion = 'unknown'
if (typeof process !== 'undefined' && typeof require !== 'undefined') {
const os = require('os')
deviceModel += `${os.type()} ${os.arch()} ${os.release()}`
try {
// for production builds
appVersion = require('../package.json').version
} catch (e) {
try {
// for development builds (additional /src/ in path)
appVersion = require('../../package.json').version
} catch (e) {}
}
} else if (typeof navigator !== 'undefined') {
deviceModel += navigator.userAgent
} else deviceModel += 'unknown'
this._initConnectionParams = {
_: 'initConnection',
deviceModel,
systemVersion: '1.0',
appVersion,
systemLangCode: 'en',
langPack: '', // "langPacks are for official apps only"
langCode: 'en',
...(params.initConnectionOptions ?? {}),
apiId: params.apiId,
query: null as any,
}
this._transportFactory = params.transport ?? defaultTransportFactory
this._reconnectionStrategy =
params.reconnectionStrategy ?? defaultReconnectionStrategy
// this._dcConnections[params.defaultDc?.id ?? 2] =
// new DcConnectionManager(this, params.defaultDc?.id ?? 2)
}
private _switchPrimaryDc(dc: DcConnectionManager) {
if (this._primaryDc && this._primaryDc !== dc) {
// todo clean up
return
}
this._primaryDc = dc
// todo add handlers
/*
this.primaryConnection.on('usable', () => {
this._lastUpdateTime = Date.now()
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
this._keepAliveInterval = setInterval(async () => {
if (Date.now() - this._lastUpdateTime > 900_000) {
this._keepAliveAction()
this._lastUpdateTime = Date.now()
}
}, 60_000)
})
this.primaryConnection.on('update', (update) => {
this._lastUpdateTime = Date.now()
this._handleUpdate(update)
})
this.primaryConnection.on('wait', () =>
this._cleanupPrimaryConnection()
)
this.primaryConnection.on('key-change', async (key) => {
this.storage.setAuthKeyFor(this._defaultDc.id, key)
await this._saveStorage()
})
this.primaryConnection.on('error', (err) =>
this._emitError(err, this.primaryConnection)
)
*/
dc.mainConnection.connect()
}
/**
* Perform initial connection to the default DC
*
* @param defaultDc Default DC to connect to
*/
connect(defaultDc: tl.RawDcOption): void {
if (this._dcConnections[defaultDc.id]) {
// shouldn't happen
throw new Error('DC manager already exists')
}
this._dcConnections[defaultDc.id] = new DcConnectionManager(
this,
defaultDc.id,
defaultDc
)
this._switchPrimaryDc(this._dcConnections[defaultDc.id])
}
destroy(): void {
for (const dc of Object.values(this._dcConnections)) {
dc.mainConnection.destroy()
}
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
}
}

View file

@ -6,7 +6,9 @@ import { ICryptoProvider, Logger } from '../utils'
import { import {
ControllablePromise, ControllablePromise,
createControllablePromise, createControllablePromise,
} from '../utils/controllable-promise' ICryptoProvider,
Logger,
} from '../utils'
import { ReconnectionStrategy } from './reconnection' import { ReconnectionStrategy } from './reconnection'
import { import {
ITelegramTransport, ITelegramTransport,
@ -23,11 +25,15 @@ export interface PersistentConnectionParams {
inactivityTimeout?: number inactivityTimeout?: number
} }
let nextConnectionUid = 0
/** /**
* Base class for persistent connections. * Base class for persistent connections.
* Only used for {@link PersistentConnection} and used as a mean of code splitting. * Only used for {@link PersistentConnection} and used as a mean of code splitting.
*/ */
export abstract class PersistentConnection extends EventEmitter { export abstract class PersistentConnection extends EventEmitter {
private _uid = nextConnectionUid++
readonly params: PersistentConnectionParams readonly params: PersistentConnectionParams
private _transport!: ITelegramTransport private _transport!: ITelegramTransport
@ -64,6 +70,18 @@ export abstract class PersistentConnection extends EventEmitter {
this.changeTransport(params.transportFactory) this.changeTransport(params.transportFactory)
} }
private _updateLogPrefix() {
this.log.prefix = `[UID ${this._uid}, DC ${this.params.dc.id}] `
}
async changeDc(dc: tl.RawDcOption): Promise<void> {
this.log.debug('dc changed to: %j', dc)
this.params.dc = dc
this._updateLogPrefix()
this.reconnect()
}
changeTransport(factory: TransportFactory): void { changeTransport(factory: TransportFactory): void {
if (this._transport) { if (this._transport) {
this._transport.close() this._transport.close()

View file

@ -25,6 +25,7 @@ import {
removeFromLongArray, removeFromLongArray,
SortedArray, SortedArray,
} from '../utils' } from '../utils'
import { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session'
import { doAuthorization } from './authorization' import { doAuthorization } from './authorization'
import { MtprotoSession } from './mtproto-session' import { MtprotoSession } from './mtproto-session'
import { import {
@ -39,68 +40,12 @@ export interface SessionConnectionParams extends PersistentConnectionParams {
niceStacks?: boolean niceStacks?: boolean
layer: number layer: number
disableUpdates?: boolean disableUpdates?: boolean
isMainConnection: boolean
readerMap: TlReaderMap readerMap: TlReaderMap
writerMap: TlWriterMap writerMap: TlWriterMap
} }
interface PendingRpc {
method: string
data: Buffer
promise: ControllablePromise
stack?: string
gzipOverhead?: number
sent?: boolean
msgId?: Long
seqNo?: number
containerId?: Long
acked?: boolean
initConn?: boolean
getState?: number
cancelled?: boolean
timeout?: NodeJS.Timeout
}
type PendingMessage =
| {
_: 'rpc'
rpc: PendingRpc
}
| {
_: 'container'
msgIds: Long[]
}
| {
_: 'state'
msgIds: Long[]
containerId: Long
}
| {
_: 'resend'
msgIds: Long[]
containerId: Long
}
| {
_: 'ping'
pingId: Long
containerId: Long
}
| {
_: 'destroy_session'
sessionId: Long
containerId: Long
}
| {
_: 'cancel'
msgId: Long
containerId: Long
}
| {
_: 'future_salts'
containerId: Long
}
// destroy_session#e7512126 session_id:long // destroy_session#e7512126 session_id:long
// todo // todo
const DESTROY_SESSION_ID = Buffer.from('262151e7', 'hex') const DESTROY_SESSION_ID = Buffer.from('262151e7', 'hex')
@ -115,61 +60,31 @@ function makeNiceStack(
}\n at ${method}\n${stack.split('\n').slice(2).join('\n')}` }\n at ${method}\n${stack.split('\n').slice(2).join('\n')}`
} }
let nextConnectionUid = 0
/** /**
* A connection to a single DC. * A connection to a single DC.
*/ */
export class SessionConnection extends PersistentConnection { export class SessionConnection extends PersistentConnection {
readonly params!: SessionConnectionParams readonly params!: SessionConnectionParams
private _uid = nextConnectionUid++
private _session: MtprotoSession
private _flushTimer = new EarlyTimer() private _flushTimer = new EarlyTimer()
/// internal state ///
// recent msg ids
private _recentOutgoingMsgIds = new LruSet<Long>(1000, false, true)
private _recentIncomingMsgIds = new LruSet<Long>(1000, false, true)
// queues
private _queuedRpc = new Deque<PendingRpc>()
private _queuedAcks: Long[] = []
private _queuedStateReq: Long[] = []
private _queuedResendReq: Long[] = []
private _queuedCancelReq: Long[] = []
private _queuedDestroySession: Long[] = [] private _queuedDestroySession: Long[] = []
private _getStateSchedule = new SortedArray<PendingRpc>(
[],
(a, b) => a.getState! - b.getState!,
)
// requests info private _next429Timeout = 1000
private _pendingMessages = new LongMap<PendingMessage>() private _current429Timeout?: NodeJS.Timeout
private _initConnectionCalled = false
private _lastPingRtt = NaN private _lastPingRtt = NaN
private _lastPingTime = 0 private _lastPingTime = 0
private _lastPingMsgId = Long.ZERO private _lastPingMsgId = Long.ZERO
private _lastSessionCreatedUid = Long.ZERO private _lastSessionCreatedUid = Long.ZERO
private _next429Timeout = 1000
private _current429Timeout?: NodeJS.Timeout
private _readerMap: TlReaderMap private _readerMap: TlReaderMap
private _writerMap: TlWriterMap private _writerMap: TlWriterMap
constructor(params: SessionConnectionParams, log: Logger) { constructor(
super(params, log.create('conn')) params: SessionConnectionParams,
this._updateLogPrefix() readonly _session: MtprotoSession
this._session = new MtprotoSession( ) {
params.crypto, super(params, _session.log.create('conn'))
log.create('session'),
params.readerMap,
params.writerMap,
)
this._flushTimer.onTimeout(this._flush.bind(this)) this._flushTimer.onTimeout(this._flush.bind(this))
this._readerMap = params.readerMap this._readerMap = params.readerMap
@ -177,18 +92,10 @@ export class SessionConnection extends PersistentConnection {
this._handleRawMessage = this._handleRawMessage.bind(this) this._handleRawMessage = this._handleRawMessage.bind(this)
} }
private _updateLogPrefix() { async changeDc(dc: tl.RawDcOption, authKey?: Buffer | null): Promise<void> {
this.log.prefix = `[UID ${this._uid}, DC ${this.params.dc.id}] `
}
async changeDc(dc: tl.RawDcOption, authKey?: Buffer): Promise<void> {
this.log.debug('dc changed (has_auth_key = %b) to: %j', authKey, dc)
this._updateLogPrefix()
this._session.reset() this._session.reset()
await this._session.setupKeys(authKey) await this._session.setupKeys(authKey)
this.params.dc = dc await super.changeDc(dc)
this.reconnect()
} }
setupKeys(authKey: Buffer | null): Promise<void> { setupKeys(authKey: Buffer | null): Promise<void> {
@ -212,37 +119,12 @@ export class SessionConnection extends PersistentConnection {
} }
reset(forever = false): void { reset(forever = false): void {
this._initConnectionCalled = false this._session.initConnectionCalled = false
this._resetLastPing(true) this._resetLastPing(true)
this._flushTimer.reset() this._flushTimer.reset()
clearTimeout(this._current429Timeout!) clearTimeout(this._current429Timeout!)
if (forever) { if (forever) {
// reset all the queues, cancel all pending messages, etc
this._session.reset()
for (const info of this._pendingMessages.values()) {
if (info._ === 'rpc') {
info.rpc.promise.reject(new Error('Connection destroyed'))
}
}
this._pendingMessages.clear()
this._recentOutgoingMsgIds.clear()
this._recentIncomingMsgIds.clear()
while (this._queuedRpc.length) {
const rpc = this._queuedRpc.popFront()!
if (rpc.sent === false) {
rpc.promise.reject(new Error('Connection destroyed'))
}
}
this._queuedAcks.length = 0
this._queuedStateReq.length = 0
this._queuedResendReq.length = 0
this._getStateSchedule.clear()
this.removeAllListeners() this.removeAllListeners()
} }
} }
@ -284,11 +166,11 @@ export class SessionConnection extends PersistentConnection {
timeout, timeout,
) )
for (const msgId of this._pendingMessages.keys()) { for (const msgId of this._session.pendingMessages.keys()) {
const info = this._pendingMessages.get(msgId)! const info = this._session.pendingMessages.get(msgId)!
if (info._ === 'container') { if (info._ === 'container') {
this._pendingMessages.delete(msgId) this._session.pendingMessages.delete(msgId)
} else { } else {
this._onMessageFailed(msgId, 'transport flood', true) this._onMessageFailed(msgId, 'transport flood', true)
} }
@ -415,15 +297,15 @@ export class SessionConnection extends PersistentConnection {
return return
} }
if (this._recentIncomingMsgIds.has(messageId)) { if (this._session.recentIncomingMsgIds.has(messageId)) {
this.log.warn('warn: ignoring duplicate message %s', messageId) this.log.warn('warn: ignoring duplicate message %s', messageId)
return return
} }
const message = message_ as mtp.TlObject const message = message_ as mtp.TlObject
this.log.verbose('received %s (msg_id: %s)', message._, messageId) this.log.verbose('received %s (msg_id: %l)', message._, messageId)
this._recentIncomingMsgIds.add(messageId) this._session.recentIncomingMsgIds.add(messageId)
switch (message._) { switch (message._) {
case 'mt_msgs_ack': case 'mt_msgs_ack':
@ -490,6 +372,17 @@ export class SessionConnection extends PersistentConnection {
this._rescheduleInactivity() this._rescheduleInactivity()
} }
if (this.params.disableUpdates) {
this.log.warn(
'received updates, but updates are disabled'
)
break
}
if (!this.params.isMainConnection) {
this.log.warn('received updates on non-main connection')
break
}
this.emit('update', message) this.emit('update', message)
return return
@ -522,8 +415,7 @@ export class SessionConnection extends PersistentConnection {
return return
} }
const msg = this._pendingMessages.get(reqMsgId) const msg = this._session.pendingMessages.get(reqMsgId)
if (!msg) { if (!msg) {
let result let result
@ -538,7 +430,7 @@ export class SessionConnection extends PersistentConnection {
) )
// check if the msg is one of the recent ones // check if the msg is one of the recent ones
if (this._recentOutgoingMsgIds.has(reqMsgId)) { if (this._session.recentOutgoingMsgIds.has(reqMsgId)) {
this.log.debug( this.log.debug(
'received rpc_result again for %l (contains %s)', 'received rpc_result again for %l (contains %s)',
reqMsgId, reqMsgId,
@ -573,7 +465,7 @@ export class SessionConnection extends PersistentConnection {
// initConnection call was definitely received and // initConnection call was definitely received and
// processed by the server, so we no longer need to use it // processed by the server, so we no longer need to use it
if (rpc.initConn) this._initConnectionCalled = true if (rpc.initConn) this._session.initConnectionCalled = true
this.log.verbose('<<< (%s) %j', rpc.method, result) this.log.verbose('<<< (%s) %j', rpc.method, result)
@ -610,11 +502,11 @@ export class SessionConnection extends PersistentConnection {
} }
this._onMessageAcked(reqMsgId) this._onMessageAcked(reqMsgId)
this._pendingMessages.delete(reqMsgId) this._session.pendingMessages.delete(reqMsgId)
} }
private _onMessageAcked(msgId: Long, inContainer = false): void { private _onMessageAcked(msgId: Long, inContainer = false): void {
const msg = this._pendingMessages.get(msgId) const msg = this._session.pendingMessages.get(msgId)
if (!msg) { if (!msg) {
this.log.warn('received ack for unknown message %l', msgId) this.log.warn('received ack for unknown message %l', msgId)
@ -633,7 +525,7 @@ export class SessionConnection extends PersistentConnection {
msg.msgIds.forEach((msgId) => this._onMessageAcked(msgId, true)) msg.msgIds.forEach((msgId) => this._onMessageAcked(msgId, true))
// we no longer need info about the container // we no longer need info about the container
this._pendingMessages.delete(msgId) this._session.pendingMessages.delete(msgId)
break break
case 'rpc': { case 'rpc': {
@ -650,19 +542,19 @@ export class SessionConnection extends PersistentConnection {
if ( if (
!inContainer && !inContainer &&
rpc.containerId && rpc.containerId &&
this._pendingMessages.has(rpc.containerId) this._session.pendingMessages.has(rpc.containerId)
) { ) {
// ack all the messages in that container // ack all the messages in that container
this._onMessageAcked(rpc.containerId) this._onMessageAcked(rpc.containerId)
} }
// this message could also already be in some queue, // this message could also already be in some queue,
removeFromLongArray(this._queuedStateReq, msgId) removeFromLongArray(this._session.queuedStateReq, msgId)
removeFromLongArray(this._queuedResendReq, msgId) removeFromLongArray(this._session.queuedResendReq, msgId)
// if resend/state was already requested, it will simply be ignored // if resend/state was already requested, it will simply be ignored
this._getStateSchedule.remove(rpc) this._session.getStateSchedule.remove(rpc)
break break
} }
default: default:
@ -681,8 +573,7 @@ export class SessionConnection extends PersistentConnection {
reason: string, reason: string,
inContainer = false, inContainer = false,
): void { ): void {
const msgInfo = this._pendingMessages.get(msgId) const msgInfo = this._session.pendingMessages.get(msgId)
if (!msgInfo) { if (!msgInfo) {
this.log.debug( this.log.debug(
'unknown message %l failed because of %s', 'unknown message %l failed because of %s',
@ -725,26 +616,26 @@ export class SessionConnection extends PersistentConnection {
) )
// since the query was rejected, we can let it reassign msg_id to avoid containers // since the query was rejected, we can let it reassign msg_id to avoid containers
this._pendingMessages.delete(msgId) this._session.pendingMessages.delete(msgId)
rpc.msgId = undefined rpc.msgId = undefined
this._enqueueRpc(rpc, true) this._enqueueRpc(rpc, true)
if ( if (
!inContainer && !inContainer &&
rpc.containerId && rpc.containerId &&
this._pendingMessages.has(rpc.containerId) this._session.pendingMessages.has(rpc.containerId)
) { ) {
// fail all the messages in that container // fail all the messages in that container
this._onMessageFailed(rpc.containerId, reason) this._onMessageFailed(rpc.containerId, reason)
} }
// this message could also already be in some queue, // this message could also already be in some queue,
removeFromLongArray(this._queuedStateReq, msgId) removeFromLongArray(this._session.queuedStateReq, msgId)
removeFromLongArray(this._queuedResendReq, msgId) removeFromLongArray(this._session.queuedResendReq, msgId)
// if resend/state was already requested, it will simply be ignored // if resend/state was already requested, it will simply be ignored
this._getStateSchedule.remove(rpc) this._session.getStateSchedule.remove(rpc)
break break
} }
@ -755,7 +646,7 @@ export class SessionConnection extends PersistentConnection {
msgInfo.msgIds.length, msgInfo.msgIds.length,
reason, reason,
) )
this._queuedResendReq.splice(0, 0, ...msgInfo.msgIds) this._session.queuedResendReq.splice(0, 0, ...msgInfo.msgIds)
this._flushTimer.emitWhenIdle() this._flushTimer.emitWhenIdle()
break break
case 'state': case 'state':
@ -765,32 +656,31 @@ export class SessionConnection extends PersistentConnection {
msgInfo.msgIds.length, msgInfo.msgIds.length,
reason, reason,
) )
this._queuedStateReq.splice(0, 0, ...msgInfo.msgIds) this._session.queuedStateReq.splice(0, 0, ...msgInfo.msgIds)
this._flushTimer.emitWhenIdle() this._flushTimer.emitWhenIdle()
break break
} }
this._pendingMessages.delete(msgId) this._session.pendingMessages.delete(msgId)
} }
private _resetLastPing(withTime = false): void { private _resetLastPing(withTime = false): void {
if (withTime) this._lastPingTime = 0 if (withTime) this._lastPingTime = 0
if (!this._lastPingMsgId.isZero()) { if (!this._lastPingMsgId.isZero()) {
this._pendingMessages.delete(this._lastPingMsgId) this._session.pendingMessages.delete(this._lastPingMsgId)
} }
this._lastPingMsgId = Long.ZERO this._lastPingMsgId = Long.ZERO
} }
private _registerOutgoingMsgId(msgId: Long): Long { private _registerOutgoingMsgId(msgId: Long): Long {
this._recentOutgoingMsgIds.add(msgId) this._session.recentOutgoingMsgIds.add(msgId)
return msgId return msgId
} }
private _onPong({ msgId, pingId }: mtp.RawMt_pong): void { private _onPong({ msgId, pingId }: mtp.RawMt_pong): void {
const info = this._pendingMessages.get(msgId) const info = this._session.pendingMessages.get(msgId)
if (!info) { if (!info) {
this.log.warn( this.log.warn(
@ -916,14 +806,14 @@ export class SessionConnection extends PersistentConnection {
firstMsgId, firstMsgId,
) )
for (const msgId of this._pendingMessages.keys()) { for (const msgId of this._session.pendingMessages.keys()) {
const val = this._pendingMessages.get(msgId)! const val = this._session.pendingMessages.get(msgId)!
if (val._ === 'container') { if (val._ === 'container') {
if (msgId.lt(firstMsgId)) { if (msgId.lt(firstMsgId)) {
// all messages in this container will be resent // all messages in this container will be resent
// info about this container is no longer needed // info about this container is no longer needed
this._pendingMessages.delete(msgId) this._session.pendingMessages.delete(msgId)
} }
return return
@ -944,8 +834,7 @@ export class SessionConnection extends PersistentConnection {
answerMsgId: Long, answerMsgId: Long,
): void { ): void {
if (!msgId.isZero()) { if (!msgId.isZero()) {
const info = this._pendingMessages.get(msgId) const info = this._session.pendingMessages.get(msgId)
if (!info) { if (!info) {
this.log.info( this.log.info(
'received message info about unknown message %l', 'received message info about unknown message %l',
@ -986,14 +875,14 @@ export class SessionConnection extends PersistentConnection {
if ( if (
!answerMsgId.isZero() && !answerMsgId.isZero() &&
!this._recentIncomingMsgIds.has(answerMsgId) !this._session.recentIncomingMsgIds.has(answerMsgId)
) { ) {
this.log.debug( this.log.debug(
'received message info for %l, but answer (%l) was not received yet', 'received message info for %l, but answer (%l) was not received yet',
msgId, msgId,
answerMsgId, answerMsgId,
) )
this._queuedResendReq.push(answerMsgId) this._session.queuedResendReq.push(answerMsgId)
this._flushTimer.emitWhenIdle() this._flushTimer.emitWhenIdle()
return return
@ -1019,7 +908,7 @@ export class SessionConnection extends PersistentConnection {
} }
private _onMsgsStateInfo(msg: mtp.RawMt_msgs_state_info): void { private _onMsgsStateInfo(msg: mtp.RawMt_msgs_state_info): void {
const info = this._pendingMessages.get(msg.reqMsgId) const info = this._session.pendingMessages.get(msg.reqMsgId)
if (!info) { if (!info) {
this.log.warn( this.log.warn(
@ -1044,43 +933,28 @@ export class SessionConnection extends PersistentConnection {
} }
private _enqueueRpc(rpc: PendingRpc, force?: boolean) { private _enqueueRpc(rpc: PendingRpc, force?: boolean) {
// already queued or cancelled if (this._session.enqueueRpc(rpc, force))
if ((!force && !rpc.sent) || rpc.cancelled) return this._flushTimer.emitWhenIdle()
rpc.sent = false
rpc.containerId = undefined
this.log.debug(
'enqueued %s for sending (msg_id = %s)',
rpc.method,
rpc.msgId || 'n/a',
)
this._queuedRpc.pushBack(rpc)
this._flushTimer.emitWhenIdle()
} }
_resetSession(): void { _resetSession(): void {
this._queuedDestroySession.push(this._session._sessionId) this._queuedDestroySession.push(this._session._sessionId)
this._session.resetState(true)
this.reconnect() this.reconnect()
this._session.changeSessionId()
this.log.debug('session reset, new sid = %l', this._session._sessionId)
// once we receive new_session_created, all pending messages will be resent. // once we receive new_session_created, all pending messages will be resent.
// clear getState/resend queues because they are not needed anymore
this._queuedStateReq.length = 0
this._queuedResendReq.length = 0
this._flushTimer.reset() this._flushTimer.reset()
} }
private _sendAck(msgId: Long): void { private _sendAck(msgId: Long): void {
if (this._queuedAcks.length === 0) { if (this._session.queuedAcks.length === 0) {
this._flushTimer.emitBeforeNext(30000) this._flushTimer.emitBeforeNext(30000)
} }
this._queuedAcks.push(msgId) this._session.queuedAcks.push(msgId)
if (this._queuedAcks.length >= 100) { if (this._session.queuedAcks.length >= 100) {
this._flushTimer.emitNow() this._flushTimer.emitNow()
} }
} }
@ -1110,7 +984,7 @@ export class SessionConnection extends PersistentConnection {
} }
} }
if (!this._initConnectionCalled) { if (!this._session.initConnectionCalled) {
// we will wrap every rpc call with initConnection // we will wrap every rpc call with initConnection
// until some of the requests wrapped with it is // until some of the requests wrapped with it is
// either acked or returns rpc_result // either acked or returns rpc_result
@ -1239,12 +1113,12 @@ export class SessionConnection extends PersistentConnection {
rpc.cancelled = true rpc.cancelled = true
if (rpc.msgId) { if (rpc.msgId) {
this._queuedCancelReq.push(rpc.msgId) this._session.queuedCancelReq.push(rpc.msgId)
this._flushTimer.emitWhenIdle() this._flushTimer.emitWhenIdle()
} else { } else {
// in case rpc wasn't sent yet (or had some error), // in case rpc wasn't sent yet (or had some error),
// we can simply remove it from queue // we can simply remove it from queue
this._queuedRpc.remove(rpc) this._session.queuedRpc.remove(rpc)
} }
} }
@ -1265,10 +1139,10 @@ export class SessionConnection extends PersistentConnection {
// if there are more queued requests, flush immediately // if there are more queued requests, flush immediately
// (they likely didn't fit into one message) // (they likely didn't fit into one message)
if ( if (
this._queuedRpc.length || this._session.queuedRpc.length ||
this._queuedAcks.length || this._session.queuedAcks.length ||
this._queuedStateReq.length || this._session.queuedStateReq.length ||
this._queuedResendReq.length this._session.queuedResendReq.length
) { ) {
this._flush() this._flush()
} else { } else {
@ -1279,7 +1153,7 @@ export class SessionConnection extends PersistentConnection {
private _doFlush(): void { private _doFlush(): void {
this.log.debug( this.log.debug(
'flushing send queue. queued rpc: %d', 'flushing send queue. queued rpc: %d',
this._queuedRpc.length, this._session.queuedRpc.length
) )
// oh bloody hell mate // oh bloody hell mate
@ -1312,14 +1186,13 @@ export class SessionConnection extends PersistentConnection {
const now = Date.now() const now = Date.now()
if (this._queuedAcks.length) { if (this._session.queuedAcks.length) {
let acks = this._queuedAcks let acks = this._session.queuedAcks
if (acks.length > 8192) { if (acks.length > 8192) {
this._queuedAcks = acks.slice(8192) this._session.queuedAcks = acks.slice(8192)
acks = acks.slice(0, 8192) acks = acks.slice(0, 8192)
} else { } else {
this._queuedAcks = [] this._session.queuedAcks = []
} }
const obj: mtp.RawMt_msgs_ack = { const obj: mtp.RawMt_msgs_ack = {
@ -1341,7 +1214,7 @@ export class SessionConnection extends PersistentConnection {
"didn't receive pong for previous ping (msg_id = %l)", "didn't receive pong for previous ping (msg_id = %l)",
this._lastPingMsgId, this._lastPingMsgId,
) )
this._pendingMessages.delete(this._lastPingMsgId) this._session.pendingMessages.delete(this._lastPingMsgId)
} }
pingId = randomLong() pingId = randomLong()
@ -1358,25 +1231,27 @@ export class SessionConnection extends PersistentConnection {
} }
{ {
if (this._queuedStateReq.length) { if (this._session.queuedStateReq.length) {
let ids = this._queuedStateReq let ids = this._session.queuedStateReq
if (ids.length > 8192) { if (ids.length > 8192) {
this._queuedStateReq = ids.slice(8192) this._session.queuedStateReq = ids.slice(8192)
ids = ids.slice(0, 8192) ids = ids.slice(0, 8192)
} else { } else {
this._queuedStateReq = [] this._session.queuedStateReq = []
} }
getStateMsgIds = ids getStateMsgIds = ids
} }
const idx = this._getStateSchedule.index( const idx = this._session.getStateSchedule.index(
{ getState: now } as any, { getState: now } as any,
true, true,
) )
if (idx > 0) { if (idx > 0) {
const toGetState = this._getStateSchedule.raw.splice(0, idx) const toGetState = this._session.getStateSchedule.raw.splice(
0,
idx
)
if (!getStateMsgIds) getStateMsgIds = [] if (!getStateMsgIds) getStateMsgIds = []
toGetState.forEach((it) => getStateMsgIds!.push(it.msgId!)) toGetState.forEach((it) => getStateMsgIds!.push(it.msgId!))
} }
@ -1396,14 +1271,13 @@ export class SessionConnection extends PersistentConnection {
} }
} }
if (this._queuedResendReq.length) { if (this._session.queuedResendReq.length) {
resendMsgIds = this._queuedResendReq resendMsgIds = this._session.queuedResendReq
if (resendMsgIds.length > 8192) { if (resendMsgIds.length > 8192) {
this._queuedResendReq = resendMsgIds.slice(8192) this._session.queuedResendReq = resendMsgIds.slice(8192)
resendMsgIds = resendMsgIds.slice(0, 8192) resendMsgIds = resendMsgIds.slice(0, 8192)
} else { } else {
this._queuedResendReq = [] this._session.queuedResendReq = []
} }
const obj: mtp.RawMt_msg_resend_req = { const obj: mtp.RawMt_msg_resend_req = {
@ -1416,16 +1290,16 @@ export class SessionConnection extends PersistentConnection {
messageCount += 1 messageCount += 1
} }
if (this._queuedCancelReq.length) { if (this._session.queuedCancelReq.length) {
containerMessageCount += this._queuedCancelReq.length containerMessageCount += this._session.queuedCancelReq.length
containerSize += this._queuedCancelReq.length * 28 containerSize += this._session.queuedCancelReq.length * 28
cancelRpcs = this._queuedCancelReq cancelRpcs = this._session.queuedCancelReq
this._queuedCancelReq = [] this._session.queuedCancelReq = []
} }
if (this._queuedDestroySession.length) { if (this._queuedDestroySession.length) {
containerMessageCount += this._queuedCancelReq.length containerMessageCount += this._session.queuedCancelReq.length
containerSize += this._queuedCancelReq.length * 28 containerSize += this._session.queuedCancelReq.length * 28
destroySessions = this._queuedDestroySession destroySessions = this._queuedDestroySession
this._queuedDestroySession = [] this._queuedDestroySession = []
} }
@ -1434,11 +1308,11 @@ export class SessionConnection extends PersistentConnection {
const rpcToSend: PendingRpc[] = [] const rpcToSend: PendingRpc[] = []
while ( while (
this._queuedRpc.length && this._session.queuedRpc.length &&
containerSize < 32768 && // 2^15 containerSize < 32768 && // 2^15
containerMessageCount < 1020 containerMessageCount < 1020
) { ) {
const msg = this._queuedRpc.popFront()! const msg = this._session.queuedRpc.popFront()!
if (msg.cancelled) continue if (msg.cancelled) continue
// note: we don't check for <2^15 here // note: we don't check for <2^15 here
@ -1497,7 +1371,7 @@ export class SessionConnection extends PersistentConnection {
pingId: pingId!, pingId: pingId!,
containerId: pingMsgId, containerId: pingMsgId,
} }
this._pendingMessages.set(pingMsgId, pingPending) this._session.pendingMessages.set(pingMsgId, pingPending)
otherPendings.push(pingPending) otherPendings.push(pingPending)
} }
@ -1510,7 +1384,7 @@ export class SessionConnection extends PersistentConnection {
msgIds: getStateMsgIds!, msgIds: getStateMsgIds!,
containerId: getStateMsgId, containerId: getStateMsgId,
} }
this._pendingMessages.set(getStateMsgId, getStatePending) this._session.pendingMessages.set(getStateMsgId, getStatePending)
otherPendings.push(getStatePending) otherPendings.push(getStatePending)
} }
@ -1523,7 +1397,7 @@ export class SessionConnection extends PersistentConnection {
msgIds: resendMsgIds!, msgIds: resendMsgIds!,
containerId: resendMsgId, containerId: resendMsgId,
} }
this._pendingMessages.set(resendMsgId, resendPending) this._session.pendingMessages.set(resendMsgId, resendPending)
otherPendings.push(resendPending) otherPendings.push(resendPending)
} }
@ -1541,7 +1415,7 @@ export class SessionConnection extends PersistentConnection {
msgId, msgId,
containerId: cancelMsgId, containerId: cancelMsgId,
} }
this._pendingMessages.set(cancelMsgId, pending) this._session.pendingMessages.set(cancelMsgId, pending)
otherPendings.push(pending) otherPendings.push(pending)
}) })
} }
@ -1560,7 +1434,7 @@ export class SessionConnection extends PersistentConnection {
sessionId, sessionId,
containerId: msgId, containerId: msgId,
} }
this._pendingMessages.set(msgId, pending) this._session.pendingMessages.set(msgId, pending)
otherPendings.push(pending) otherPendings.push(pending)
}) })
} }
@ -1584,7 +1458,7 @@ export class SessionConnection extends PersistentConnection {
msg.msgId = msgId msg.msgId = msgId
msg.seqNo = seqNo msg.seqNo = seqNo
this._pendingMessages.set(msgId, { this._session.pendingMessages.set(msgId, {
_: 'rpc', _: 'rpc',
rpc: msg, rpc: msg,
}) })
@ -1599,11 +1473,11 @@ export class SessionConnection extends PersistentConnection {
// (re-)schedule get_state if needed // (re-)schedule get_state if needed
if (msg.getState) { if (msg.getState) {
this._getStateSchedule.remove(msg) this._session.getStateSchedule.remove(msg)
} }
if (!msg.acked) { if (!msg.acked) {
msg.getState = getStateTime msg.getState = getStateTime
this._getStateSchedule.insert(msg) this._session.getStateSchedule.insert(msg)
} }
writer.long(this._registerOutgoingMsgId(msg.msgId)) writer.long(this._registerOutgoingMsgId(msg.msgId))
@ -1650,7 +1524,10 @@ export class SessionConnection extends PersistentConnection {
}) })
} }
this._pendingMessages.set(containerId, { _: 'container', msgIds }) this._session.pendingMessages.set(containerId, {
_: 'container',
msgIds,
})
} }
const result = writer.result() const result = writer.result()
@ -1685,7 +1562,8 @@ export class SessionConnection extends PersistentConnection {
) )
// put acks in the front so they are the first to be sent // put acks in the front so they are the first to be sent
if (ackMsgIds) this._queuedAcks.splice(0, 0, ...ackMsgIds) if (ackMsgIds)
this._session.queuedAcks.splice(0, 0, ...ackMsgIds)
this._onMessageFailed(rootMsgId, 'unknown error') this._onMessageFailed(rootMsgId, 'unknown error')
}) })
} }

View file

@ -59,7 +59,7 @@ describe('fuzz : transport', function () {
transport: () => new RandomBytesTransport(), transport: () => new RandomBytesTransport(),
apiId: 0, apiId: 0,
apiHash: '', apiHash: '',
primaryDc: defaultDcs.defaultTestDc, defaultDc: defaultDcs.defaultTestDc,
}) })
client.log.level = 0 client.log.level = 0
@ -85,7 +85,7 @@ describe('fuzz : transport', function () {
transport: () => new RandomBytesTransport(), transport: () => new RandomBytesTransport(),
apiId: 0, apiId: 0,
apiHash: '', apiHash: '',
primaryDc: defaultDcs.defaultTestDc, defaultDc: defaultDcs.defaultTestDc,
}) })
client.log.level = 0 client.log.level = 0