fix(core): proper handling of pfs with media dcs
This commit is contained in:
parent
1c5815ecf0
commit
6b430c5a2a
3 changed files with 42 additions and 113 deletions
|
@ -5,12 +5,10 @@ import type { SessionConnectionParams } from './session-connection.js'
|
|||
|
||||
import type { TelegramTransport } from './transports/index.js'
|
||||
import { Deferred, Emitter, unknownToError } from '@fuman/utils'
|
||||
import { MtprotoSession } from './mtproto-session.js'
|
||||
import { SessionConnection } from './session-connection.js'
|
||||
|
||||
export class MultiSessionConnection {
|
||||
private _log: Logger
|
||||
readonly _sessions: MtprotoSession[]
|
||||
private _enforcePfs = false
|
||||
|
||||
// NB: dont forget to update .reset()
|
||||
|
@ -35,11 +33,10 @@ export class MultiSessionConnection {
|
|||
if (logPrefix) this._log.prefix = `[${logPrefix}] `
|
||||
this._enforcePfs = _count > 1 && params.isMainConnection
|
||||
|
||||
this._sessions = []
|
||||
this._updateConnections()
|
||||
}
|
||||
|
||||
protected _connections: SessionConnection[] = []
|
||||
readonly _connections: SessionConnection[] = []
|
||||
|
||||
setCount(count: number, connect: boolean = this.params.isMainConnection): void {
|
||||
this._count = count
|
||||
|
@ -47,71 +44,11 @@ export class MultiSessionConnection {
|
|||
this._updateConnections(connect)
|
||||
}
|
||||
|
||||
private _updateSessions(): void {
|
||||
// there are two cases
|
||||
// 1. this msc is main, in which case every connection should have its own session
|
||||
// 2. this msc is not main, in which case all connections should share the same session
|
||||
// if (!this.params.isMainConnection) {
|
||||
// // case 2
|
||||
// this._log.debug(
|
||||
// 'updating sessions count: %d -> 1',
|
||||
// this._sessions.length,
|
||||
// )
|
||||
//
|
||||
// if (this._sessions.length === 0) {
|
||||
// this._sessions.push(
|
||||
// new MtprotoSession(
|
||||
// this.params.crypto,
|
||||
// this._log.create('session'),
|
||||
// this.params.readerMap,
|
||||
// this.params.writerMap,
|
||||
// ),
|
||||
// )
|
||||
// }
|
||||
//
|
||||
// // shouldn't happen, but just in case
|
||||
// while (this._sessions.length > 1) {
|
||||
// this._sessions.pop()!.reset()
|
||||
// }
|
||||
//
|
||||
// return
|
||||
// }
|
||||
|
||||
this._log.debug('updating sessions count: %d -> %d', this._sessions.length, this._count)
|
||||
|
||||
// case 1
|
||||
if (this._sessions.length === this._count) return
|
||||
|
||||
if (this._sessions.length > this._count) {
|
||||
// destroy extra sessions
|
||||
for (let i = this._sessions.length - 1; i >= this._count; i--) {
|
||||
this._sessions[i].reset()
|
||||
}
|
||||
|
||||
this._sessions.splice(this._count)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
while (this._sessions.length < this._count) {
|
||||
const idx = this._sessions.length
|
||||
const session = new MtprotoSession(
|
||||
this.params.crypto,
|
||||
this._log.create('session'),
|
||||
this.params.readerMap,
|
||||
this.params.writerMap,
|
||||
this.params.salts,
|
||||
)
|
||||
|
||||
// brvh
|
||||
if (idx !== 0) session._authKey = this._sessions[0]._authKey
|
||||
|
||||
this._sessions.push(session)
|
||||
}
|
||||
getCount(): number {
|
||||
return this._count
|
||||
}
|
||||
|
||||
private _updateConnections(connect = false): void {
|
||||
this._updateSessions()
|
||||
if (this._connections.length === this._count) return
|
||||
|
||||
this._log.debug('updating connections count: %d -> %d', this._connections.length, this._count)
|
||||
|
@ -157,8 +94,6 @@ export class MultiSessionConnection {
|
|||
|
||||
// create new connections
|
||||
for (let i = this._connections.length; i < this._count; i++) {
|
||||
const session = this._sessions[i] // this.params.isMainConnection ? // :
|
||||
// this._sessions[0]
|
||||
const conn = new SessionConnection(
|
||||
{
|
||||
...this.params,
|
||||
|
@ -167,7 +102,7 @@ export class MultiSessionConnection {
|
|||
withUpdates:
|
||||
this.params.isMainConnection && this.params.isMainDcConnection && !this.params.disableUpdates,
|
||||
},
|
||||
session,
|
||||
this._log,
|
||||
)
|
||||
|
||||
if (this.params.isMainConnection && this.params.isMainDcConnection) {
|
||||
|
@ -216,7 +151,6 @@ export class MultiSessionConnection {
|
|||
_destroyed = false
|
||||
async destroy(): Promise<void> {
|
||||
await Promise.all(this._connections.map(conn => conn.destroy()))
|
||||
this._sessions.forEach(sess => sess.reset())
|
||||
|
||||
this.onRequestKeys.clear()
|
||||
this.onError.clear()
|
||||
|
@ -279,14 +213,14 @@ export class MultiSessionConnection {
|
|||
}
|
||||
|
||||
setAuthKey(authKey: Uint8Array | null, temp = false, idx = 0): void {
|
||||
const session = this._sessions[idx]
|
||||
const session = this._connections[idx]._session
|
||||
const key = temp ? session._authKeyTemp : session._authKey
|
||||
key.setup(authKey)
|
||||
}
|
||||
|
||||
resetAuthKeys(): void {
|
||||
for (const session of this._sessions) {
|
||||
session.reset(true)
|
||||
for (const conn of this._connections) {
|
||||
conn._session.reset(true)
|
||||
}
|
||||
this.notifyKeyChange()
|
||||
}
|
||||
|
@ -305,21 +239,23 @@ export class MultiSessionConnection {
|
|||
|
||||
notifyKeyChange(): void {
|
||||
// only expected to be called on non-main connections
|
||||
const session = this._sessions[0]
|
||||
for (const conn of this._connections) {
|
||||
const session = conn._session
|
||||
|
||||
if (this.params.usePfs && !session._authKeyTemp.ready) {
|
||||
this._log.debug('temp auth key needed but not ready, ignoring key change')
|
||||
if (this.params.usePfs && !session._authKeyTemp.ready) {
|
||||
this._log.debug('temp auth key needed but not ready, ignoring key change')
|
||||
|
||||
return
|
||||
continue
|
||||
}
|
||||
|
||||
if (session.queuedRpc.length) {
|
||||
// there are pending requests, we need to reconnect.
|
||||
this._log.debug('notifying key change on the connection due to queued rpc')
|
||||
this._connections.forEach(conn => conn.onConnected())
|
||||
}
|
||||
|
||||
// connection is idle, we don't need to notify it
|
||||
}
|
||||
|
||||
if (this._sessions[0].queuedRpc.length) {
|
||||
// there are pending requests, we need to reconnect.
|
||||
this._log.debug('notifying key change on the connection due to queued rpc')
|
||||
this._connections.forEach(conn => conn.onConnected())
|
||||
}
|
||||
|
||||
// connection is idle, we don't need to notify it
|
||||
}
|
||||
|
||||
notifyNetworkChanged(connected: boolean): void {
|
||||
|
|
|
@ -341,24 +341,13 @@ export class DcConnectionManager {
|
|||
})
|
||||
connection.onTmpKeyChange.add(([idx, key, expires]) => {
|
||||
if (kind !== 'main') {
|
||||
this.manager._log.warn('got tmp-key-change from non-main connection, ignoring')
|
||||
|
||||
// tmp keys in media dcs are ephemeral so there's no point in storing them
|
||||
return
|
||||
}
|
||||
|
||||
this.manager._log.debug('temp key change for dc %d from connection %d', this.dcId, idx)
|
||||
|
||||
// send key to other connections
|
||||
this.upload.setAuthKey(key, true)
|
||||
this.download.setAuthKey(key, true)
|
||||
this.downloadSmall.setAuthKey(key, true)
|
||||
|
||||
Promise.resolve(this.manager._storage.provider.authKeys.setTemp(this.dcId, idx, key, expires * 1000))
|
||||
.then(() => {
|
||||
this.upload.notifyKeyChange()
|
||||
this.download.notifyKeyChange()
|
||||
this.downloadSmall.notifyKeyChange()
|
||||
})
|
||||
.catch((e: Error) => {
|
||||
this.manager._log.warn('failed to save temp auth key %d for dc %d: %e', idx, this.dcId, e)
|
||||
this.manager.params.emitError(e)
|
||||
|
@ -442,15 +431,12 @@ export class DcConnectionManager {
|
|||
if (this.manager.params.usePfs || forcePfs) {
|
||||
const now = Date.now()
|
||||
await Promise.all(
|
||||
this.main._sessions.map(async (_, i) => {
|
||||
Array.from({ length: this.main.getCount() }, async (_, i) => {
|
||||
const temp = await this.manager._storage.provider.authKeys.getTemp(this.dcId, i, now)
|
||||
this.main.setAuthKey(temp, true, i)
|
||||
|
||||
if (i === 0) {
|
||||
this.upload.setAuthKey(temp, true)
|
||||
this.download.setAuthKey(temp, true)
|
||||
this.downloadSmall.setAuthKey(temp, true)
|
||||
}
|
||||
// NB: we do not set temp auth keys for media connections,
|
||||
// as they are ephemeral and dc-bound. doing this *will* lead to unwanted -404s
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -915,7 +901,7 @@ export class NetworkManager {
|
|||
}
|
||||
|
||||
getMtprotoMessageId(): Long {
|
||||
return this._primaryDc!.main._sessions[0].getMessageId()
|
||||
return this._primaryDc!.main._connections[0]._session.getMessageId()
|
||||
}
|
||||
|
||||
async recreateDc(dcId: number): Promise<void> {
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
import type { mtp } from '@mtcute/tl'
|
||||
import type { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
|
||||
import type { ICorePlatform } from '../types/platform.js'
|
||||
import type { ICryptoProvider } from '../utils/index.js'
|
||||
import type { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js'
|
||||
import type { ICryptoProvider, Logger } from '../utils/index.js'
|
||||
import type { PendingMessage, PendingRpc } from './mtproto-session.js'
|
||||
import type { PersistentConnectionParams } from './persistent-connection.js'
|
||||
|
||||
import type { ServerSaltManager } from './server-salt.js'
|
||||
|
||||
import { Deferred, Emitter, timers, u8 } from '@fuman/utils'
|
||||
import { tl } from '@mtcute/tl'
|
||||
import { TlBinaryReader, TlBinaryWriter, TlSerializationCounter } from '@mtcute/tl-runtime'
|
||||
import Long from 'long'
|
||||
|
||||
import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types/index.js'
|
||||
|
||||
import { createAesIgeForMessageOld } from '../utils/crypto/mtproto.js'
|
||||
import {
|
||||
EarlyTimer,
|
||||
|
@ -20,6 +20,7 @@ import {
|
|||
removeFromLongArray,
|
||||
} from '../utils/index.js'
|
||||
import { doAuthorization } from './authorization.js'
|
||||
import { MtprotoSession } from './mtproto-session.js'
|
||||
import { PersistentConnection } from './persistent-connection.js'
|
||||
import { TransportError } from './transports/abstract.js'
|
||||
|
||||
|
@ -82,6 +83,7 @@ export class SessionConnection extends PersistentConnection {
|
|||
private _writerMap: TlWriterMap
|
||||
private _crypto: ICryptoProvider
|
||||
private _salts: ServerSaltManager
|
||||
readonly _session: MtprotoSession
|
||||
|
||||
// todo: we should probably do adaptive ping interval based on rtt like tdlib:
|
||||
// https://github.com/tdlib/td/blob/91aa6c9e4d0774eabf4f8d7f3aa51239032059a6/td/mtproto/SessionConnection.h
|
||||
|
@ -97,11 +99,15 @@ export class SessionConnection extends PersistentConnection {
|
|||
readonly onUpdate: Emitter<tl.TypeUpdates> = new Emitter()
|
||||
readonly onFutureSalts: Emitter<mtp.RawMt_future_salt[]> = new Emitter()
|
||||
|
||||
constructor(
|
||||
params: SessionConnectionParams,
|
||||
readonly _session: MtprotoSession,
|
||||
) {
|
||||
super(params, _session.log.create('conn'))
|
||||
constructor(params: SessionConnectionParams, log: Logger) {
|
||||
super(params, log.create('conn'))
|
||||
this._session = new MtprotoSession(
|
||||
params.crypto,
|
||||
log.create('session'),
|
||||
params.readerMap,
|
||||
params.writerMap,
|
||||
params.salts,
|
||||
)
|
||||
this._flushTimer.onTimeout(this._flush.bind(this))
|
||||
|
||||
this._pingInterval = params.pingInterval
|
||||
|
@ -188,6 +194,7 @@ export class SessionConnection extends PersistentConnection {
|
|||
this.onError.add((err) => {
|
||||
this.log.warn('caught error after destroying: %s', err)
|
||||
})
|
||||
this._session.reset()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue