fix(core): proper handling of pfs with media dcs

This commit is contained in:
alina 🌸 2024-12-16 11:21:54 +03:00
parent 1c5815ecf0
commit 6b430c5a2a
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
3 changed files with 42 additions and 113 deletions

View file

@ -5,12 +5,10 @@ import type { SessionConnectionParams } from './session-connection.js'
import type { TelegramTransport } from './transports/index.js' import type { TelegramTransport } from './transports/index.js'
import { Deferred, Emitter, unknownToError } from '@fuman/utils' import { Deferred, Emitter, unknownToError } from '@fuman/utils'
import { MtprotoSession } from './mtproto-session.js'
import { SessionConnection } from './session-connection.js' import { SessionConnection } from './session-connection.js'
export class MultiSessionConnection { export class MultiSessionConnection {
private _log: Logger private _log: Logger
readonly _sessions: MtprotoSession[]
private _enforcePfs = false private _enforcePfs = false
// NB: dont forget to update .reset() // NB: dont forget to update .reset()
@ -35,11 +33,10 @@ export class MultiSessionConnection {
if (logPrefix) this._log.prefix = `[${logPrefix}] ` if (logPrefix) this._log.prefix = `[${logPrefix}] `
this._enforcePfs = _count > 1 && params.isMainConnection this._enforcePfs = _count > 1 && params.isMainConnection
this._sessions = []
this._updateConnections() this._updateConnections()
} }
protected _connections: SessionConnection[] = [] readonly _connections: SessionConnection[] = []
setCount(count: number, connect: boolean = this.params.isMainConnection): void { setCount(count: number, connect: boolean = this.params.isMainConnection): void {
this._count = count this._count = count
@ -47,71 +44,11 @@ export class MultiSessionConnection {
this._updateConnections(connect) this._updateConnections(connect)
} }
private _updateSessions(): void { getCount(): number {
// there are two cases return this._count
// 1. this msc is main, in which case every connection should have its own session
// 2. this msc is not main, in which case all connections should share the same session
// if (!this.params.isMainConnection) {
// // case 2
// this._log.debug(
// 'updating sessions count: %d -> 1',
// this._sessions.length,
// )
//
// if (this._sessions.length === 0) {
// this._sessions.push(
// new MtprotoSession(
// this.params.crypto,
// this._log.create('session'),
// this.params.readerMap,
// this.params.writerMap,
// ),
// )
// }
//
// // shouldn't happen, but just in case
// while (this._sessions.length > 1) {
// this._sessions.pop()!.reset()
// }
//
// return
// }
this._log.debug('updating sessions count: %d -> %d', this._sessions.length, this._count)
// case 1
if (this._sessions.length === this._count) return
if (this._sessions.length > this._count) {
// destroy extra sessions
for (let i = this._sessions.length - 1; i >= this._count; i--) {
this._sessions[i].reset()
}
this._sessions.splice(this._count)
return
}
while (this._sessions.length < this._count) {
const idx = this._sessions.length
const session = new MtprotoSession(
this.params.crypto,
this._log.create('session'),
this.params.readerMap,
this.params.writerMap,
this.params.salts,
)
// brvh
if (idx !== 0) session._authKey = this._sessions[0]._authKey
this._sessions.push(session)
}
} }
private _updateConnections(connect = false): void { private _updateConnections(connect = false): void {
this._updateSessions()
if (this._connections.length === this._count) return if (this._connections.length === this._count) return
this._log.debug('updating connections count: %d -> %d', this._connections.length, this._count) this._log.debug('updating connections count: %d -> %d', this._connections.length, this._count)
@ -157,8 +94,6 @@ export class MultiSessionConnection {
// create new connections // create new connections
for (let i = this._connections.length; i < this._count; i++) { for (let i = this._connections.length; i < this._count; i++) {
const session = this._sessions[i] // this.params.isMainConnection ? // :
// this._sessions[0]
const conn = new SessionConnection( const conn = new SessionConnection(
{ {
...this.params, ...this.params,
@ -167,7 +102,7 @@ export class MultiSessionConnection {
withUpdates: withUpdates:
this.params.isMainConnection && this.params.isMainDcConnection && !this.params.disableUpdates, this.params.isMainConnection && this.params.isMainDcConnection && !this.params.disableUpdates,
}, },
session, this._log,
) )
if (this.params.isMainConnection && this.params.isMainDcConnection) { if (this.params.isMainConnection && this.params.isMainDcConnection) {
@ -216,7 +151,6 @@ export class MultiSessionConnection {
_destroyed = false _destroyed = false
async destroy(): Promise<void> { async destroy(): Promise<void> {
await Promise.all(this._connections.map(conn => conn.destroy())) await Promise.all(this._connections.map(conn => conn.destroy()))
this._sessions.forEach(sess => sess.reset())
this.onRequestKeys.clear() this.onRequestKeys.clear()
this.onError.clear() this.onError.clear()
@ -279,14 +213,14 @@ export class MultiSessionConnection {
} }
setAuthKey(authKey: Uint8Array | null, temp = false, idx = 0): void { setAuthKey(authKey: Uint8Array | null, temp = false, idx = 0): void {
const session = this._sessions[idx] const session = this._connections[idx]._session
const key = temp ? session._authKeyTemp : session._authKey const key = temp ? session._authKeyTemp : session._authKey
key.setup(authKey) key.setup(authKey)
} }
resetAuthKeys(): void { resetAuthKeys(): void {
for (const session of this._sessions) { for (const conn of this._connections) {
session.reset(true) conn._session.reset(true)
} }
this.notifyKeyChange() this.notifyKeyChange()
} }
@ -305,21 +239,23 @@ export class MultiSessionConnection {
notifyKeyChange(): void { notifyKeyChange(): void {
// only expected to be called on non-main connections // only expected to be called on non-main connections
const session = this._sessions[0] for (const conn of this._connections) {
const session = conn._session
if (this.params.usePfs && !session._authKeyTemp.ready) { if (this.params.usePfs && !session._authKeyTemp.ready) {
this._log.debug('temp auth key needed but not ready, ignoring key change') this._log.debug('temp auth key needed but not ready, ignoring key change')
return continue
}
if (session.queuedRpc.length) {
// there are pending requests, we need to reconnect.
this._log.debug('notifying key change on the connection due to queued rpc')
this._connections.forEach(conn => conn.onConnected())
}
// connection is idle, we don't need to notify it
} }
if (this._sessions[0].queuedRpc.length) {
// there are pending requests, we need to reconnect.
this._log.debug('notifying key change on the connection due to queued rpc')
this._connections.forEach(conn => conn.onConnected())
}
// connection is idle, we don't need to notify it
} }
notifyNetworkChanged(connected: boolean): void { notifyNetworkChanged(connected: boolean): void {

View file

@ -341,24 +341,13 @@ export class DcConnectionManager {
}) })
connection.onTmpKeyChange.add(([idx, key, expires]) => { connection.onTmpKeyChange.add(([idx, key, expires]) => {
if (kind !== 'main') { if (kind !== 'main') {
this.manager._log.warn('got tmp-key-change from non-main connection, ignoring') // tmp keys in media dcs are ephemeral so there's no point in storing them
return return
} }
this.manager._log.debug('temp key change for dc %d from connection %d', this.dcId, idx) this.manager._log.debug('temp key change for dc %d from connection %d', this.dcId, idx)
// send key to other connections
this.upload.setAuthKey(key, true)
this.download.setAuthKey(key, true)
this.downloadSmall.setAuthKey(key, true)
Promise.resolve(this.manager._storage.provider.authKeys.setTemp(this.dcId, idx, key, expires * 1000)) Promise.resolve(this.manager._storage.provider.authKeys.setTemp(this.dcId, idx, key, expires * 1000))
.then(() => {
this.upload.notifyKeyChange()
this.download.notifyKeyChange()
this.downloadSmall.notifyKeyChange()
})
.catch((e: Error) => { .catch((e: Error) => {
this.manager._log.warn('failed to save temp auth key %d for dc %d: %e', idx, this.dcId, e) this.manager._log.warn('failed to save temp auth key %d for dc %d: %e', idx, this.dcId, e)
this.manager.params.emitError(e) this.manager.params.emitError(e)
@ -442,15 +431,12 @@ export class DcConnectionManager {
if (this.manager.params.usePfs || forcePfs) { if (this.manager.params.usePfs || forcePfs) {
const now = Date.now() const now = Date.now()
await Promise.all( await Promise.all(
this.main._sessions.map(async (_, i) => { Array.from({ length: this.main.getCount() }, async (_, i) => {
const temp = await this.manager._storage.provider.authKeys.getTemp(this.dcId, i, now) const temp = await this.manager._storage.provider.authKeys.getTemp(this.dcId, i, now)
this.main.setAuthKey(temp, true, i) this.main.setAuthKey(temp, true, i)
if (i === 0) { // NB: we do not set temp auth keys for media connections,
this.upload.setAuthKey(temp, true) // as they are ephemeral and dc-bound. doing this *will* lead to unwanted -404s
this.download.setAuthKey(temp, true)
this.downloadSmall.setAuthKey(temp, true)
}
}), }),
) )
} }
@ -915,7 +901,7 @@ export class NetworkManager {
} }
getMtprotoMessageId(): Long { getMtprotoMessageId(): Long {
return this._primaryDc!.main._sessions[0].getMessageId() return this._primaryDc!.main._connections[0]._session.getMessageId()
} }
async recreateDc(dcId: number): Promise<void> { async recreateDc(dcId: number): Promise<void> {

View file

@ -1,17 +1,17 @@
import type { mtp } from '@mtcute/tl' import type { mtp } from '@mtcute/tl'
import type { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime' import type { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
import type { ICorePlatform } from '../types/platform.js' import type { ICorePlatform } from '../types/platform.js'
import type { ICryptoProvider } from '../utils/index.js' import type { ICryptoProvider, Logger } from '../utils/index.js'
import type { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js' import type { PendingMessage, PendingRpc } from './mtproto-session.js'
import type { PersistentConnectionParams } from './persistent-connection.js' import type { PersistentConnectionParams } from './persistent-connection.js'
import type { ServerSaltManager } from './server-salt.js' import type { ServerSaltManager } from './server-salt.js'
import { Deferred, Emitter, timers, u8 } from '@fuman/utils' import { Deferred, Emitter, timers, u8 } from '@fuman/utils'
import { tl } from '@mtcute/tl' import { tl } from '@mtcute/tl'
import { TlBinaryReader, TlBinaryWriter, TlSerializationCounter } from '@mtcute/tl-runtime' import { TlBinaryReader, TlBinaryWriter, TlSerializationCounter } from '@mtcute/tl-runtime'
import Long from 'long' import Long from 'long'
import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types/index.js' import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types/index.js'
import { createAesIgeForMessageOld } from '../utils/crypto/mtproto.js' import { createAesIgeForMessageOld } from '../utils/crypto/mtproto.js'
import { import {
EarlyTimer, EarlyTimer,
@ -20,6 +20,7 @@ import {
removeFromLongArray, removeFromLongArray,
} from '../utils/index.js' } from '../utils/index.js'
import { doAuthorization } from './authorization.js' import { doAuthorization } from './authorization.js'
import { MtprotoSession } from './mtproto-session.js'
import { PersistentConnection } from './persistent-connection.js' import { PersistentConnection } from './persistent-connection.js'
import { TransportError } from './transports/abstract.js' import { TransportError } from './transports/abstract.js'
@ -82,6 +83,7 @@ export class SessionConnection extends PersistentConnection {
private _writerMap: TlWriterMap private _writerMap: TlWriterMap
private _crypto: ICryptoProvider private _crypto: ICryptoProvider
private _salts: ServerSaltManager private _salts: ServerSaltManager
readonly _session: MtprotoSession
// todo: we should probably do adaptive ping interval based on rtt like tdlib: // todo: we should probably do adaptive ping interval based on rtt like tdlib:
// https://github.com/tdlib/td/blob/91aa6c9e4d0774eabf4f8d7f3aa51239032059a6/td/mtproto/SessionConnection.h // https://github.com/tdlib/td/blob/91aa6c9e4d0774eabf4f8d7f3aa51239032059a6/td/mtproto/SessionConnection.h
@ -97,11 +99,15 @@ export class SessionConnection extends PersistentConnection {
readonly onUpdate: Emitter<tl.TypeUpdates> = new Emitter() readonly onUpdate: Emitter<tl.TypeUpdates> = new Emitter()
readonly onFutureSalts: Emitter<mtp.RawMt_future_salt[]> = new Emitter() readonly onFutureSalts: Emitter<mtp.RawMt_future_salt[]> = new Emitter()
constructor( constructor(params: SessionConnectionParams, log: Logger) {
params: SessionConnectionParams, super(params, log.create('conn'))
readonly _session: MtprotoSession, this._session = new MtprotoSession(
) { params.crypto,
super(params, _session.log.create('conn')) log.create('session'),
params.readerMap,
params.writerMap,
params.salts,
)
this._flushTimer.onTimeout(this._flush.bind(this)) this._flushTimer.onTimeout(this._flush.bind(this))
this._pingInterval = params.pingInterval this._pingInterval = params.pingInterval
@ -188,6 +194,7 @@ export class SessionConnection extends PersistentConnection {
this.onError.add((err) => { this.onError.add((err) => {
this.log.warn('caught error after destroying: %s', err) this.log.warn('caught error after destroying: %s', err)
}) })
this._session.reset()
} }
} }