fix(core): proper handling of pfs with media dcs
This commit is contained in:
parent
1c5815ecf0
commit
6b430c5a2a
3 changed files with 42 additions and 113 deletions
|
@ -5,12 +5,10 @@ import type { SessionConnectionParams } from './session-connection.js'
|
||||||
|
|
||||||
import type { TelegramTransport } from './transports/index.js'
|
import type { TelegramTransport } from './transports/index.js'
|
||||||
import { Deferred, Emitter, unknownToError } from '@fuman/utils'
|
import { Deferred, Emitter, unknownToError } from '@fuman/utils'
|
||||||
import { MtprotoSession } from './mtproto-session.js'
|
|
||||||
import { SessionConnection } from './session-connection.js'
|
import { SessionConnection } from './session-connection.js'
|
||||||
|
|
||||||
export class MultiSessionConnection {
|
export class MultiSessionConnection {
|
||||||
private _log: Logger
|
private _log: Logger
|
||||||
readonly _sessions: MtprotoSession[]
|
|
||||||
private _enforcePfs = false
|
private _enforcePfs = false
|
||||||
|
|
||||||
// NB: dont forget to update .reset()
|
// NB: dont forget to update .reset()
|
||||||
|
@ -35,11 +33,10 @@ export class MultiSessionConnection {
|
||||||
if (logPrefix) this._log.prefix = `[${logPrefix}] `
|
if (logPrefix) this._log.prefix = `[${logPrefix}] `
|
||||||
this._enforcePfs = _count > 1 && params.isMainConnection
|
this._enforcePfs = _count > 1 && params.isMainConnection
|
||||||
|
|
||||||
this._sessions = []
|
|
||||||
this._updateConnections()
|
this._updateConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
protected _connections: SessionConnection[] = []
|
readonly _connections: SessionConnection[] = []
|
||||||
|
|
||||||
setCount(count: number, connect: boolean = this.params.isMainConnection): void {
|
setCount(count: number, connect: boolean = this.params.isMainConnection): void {
|
||||||
this._count = count
|
this._count = count
|
||||||
|
@ -47,71 +44,11 @@ export class MultiSessionConnection {
|
||||||
this._updateConnections(connect)
|
this._updateConnections(connect)
|
||||||
}
|
}
|
||||||
|
|
||||||
private _updateSessions(): void {
|
getCount(): number {
|
||||||
// there are two cases
|
return this._count
|
||||||
// 1. this msc is main, in which case every connection should have its own session
|
|
||||||
// 2. this msc is not main, in which case all connections should share the same session
|
|
||||||
// if (!this.params.isMainConnection) {
|
|
||||||
// // case 2
|
|
||||||
// this._log.debug(
|
|
||||||
// 'updating sessions count: %d -> 1',
|
|
||||||
// this._sessions.length,
|
|
||||||
// )
|
|
||||||
//
|
|
||||||
// if (this._sessions.length === 0) {
|
|
||||||
// this._sessions.push(
|
|
||||||
// new MtprotoSession(
|
|
||||||
// this.params.crypto,
|
|
||||||
// this._log.create('session'),
|
|
||||||
// this.params.readerMap,
|
|
||||||
// this.params.writerMap,
|
|
||||||
// ),
|
|
||||||
// )
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // shouldn't happen, but just in case
|
|
||||||
// while (this._sessions.length > 1) {
|
|
||||||
// this._sessions.pop()!.reset()
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
|
|
||||||
this._log.debug('updating sessions count: %d -> %d', this._sessions.length, this._count)
|
|
||||||
|
|
||||||
// case 1
|
|
||||||
if (this._sessions.length === this._count) return
|
|
||||||
|
|
||||||
if (this._sessions.length > this._count) {
|
|
||||||
// destroy extra sessions
|
|
||||||
for (let i = this._sessions.length - 1; i >= this._count; i--) {
|
|
||||||
this._sessions[i].reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
this._sessions.splice(this._count)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
while (this._sessions.length < this._count) {
|
|
||||||
const idx = this._sessions.length
|
|
||||||
const session = new MtprotoSession(
|
|
||||||
this.params.crypto,
|
|
||||||
this._log.create('session'),
|
|
||||||
this.params.readerMap,
|
|
||||||
this.params.writerMap,
|
|
||||||
this.params.salts,
|
|
||||||
)
|
|
||||||
|
|
||||||
// brvh
|
|
||||||
if (idx !== 0) session._authKey = this._sessions[0]._authKey
|
|
||||||
|
|
||||||
this._sessions.push(session)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private _updateConnections(connect = false): void {
|
private _updateConnections(connect = false): void {
|
||||||
this._updateSessions()
|
|
||||||
if (this._connections.length === this._count) return
|
if (this._connections.length === this._count) return
|
||||||
|
|
||||||
this._log.debug('updating connections count: %d -> %d', this._connections.length, this._count)
|
this._log.debug('updating connections count: %d -> %d', this._connections.length, this._count)
|
||||||
|
@ -157,8 +94,6 @@ export class MultiSessionConnection {
|
||||||
|
|
||||||
// create new connections
|
// create new connections
|
||||||
for (let i = this._connections.length; i < this._count; i++) {
|
for (let i = this._connections.length; i < this._count; i++) {
|
||||||
const session = this._sessions[i] // this.params.isMainConnection ? // :
|
|
||||||
// this._sessions[0]
|
|
||||||
const conn = new SessionConnection(
|
const conn = new SessionConnection(
|
||||||
{
|
{
|
||||||
...this.params,
|
...this.params,
|
||||||
|
@ -167,7 +102,7 @@ export class MultiSessionConnection {
|
||||||
withUpdates:
|
withUpdates:
|
||||||
this.params.isMainConnection && this.params.isMainDcConnection && !this.params.disableUpdates,
|
this.params.isMainConnection && this.params.isMainDcConnection && !this.params.disableUpdates,
|
||||||
},
|
},
|
||||||
session,
|
this._log,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (this.params.isMainConnection && this.params.isMainDcConnection) {
|
if (this.params.isMainConnection && this.params.isMainDcConnection) {
|
||||||
|
@ -216,7 +151,6 @@ export class MultiSessionConnection {
|
||||||
_destroyed = false
|
_destroyed = false
|
||||||
async destroy(): Promise<void> {
|
async destroy(): Promise<void> {
|
||||||
await Promise.all(this._connections.map(conn => conn.destroy()))
|
await Promise.all(this._connections.map(conn => conn.destroy()))
|
||||||
this._sessions.forEach(sess => sess.reset())
|
|
||||||
|
|
||||||
this.onRequestKeys.clear()
|
this.onRequestKeys.clear()
|
||||||
this.onError.clear()
|
this.onError.clear()
|
||||||
|
@ -279,14 +213,14 @@ export class MultiSessionConnection {
|
||||||
}
|
}
|
||||||
|
|
||||||
setAuthKey(authKey: Uint8Array | null, temp = false, idx = 0): void {
|
setAuthKey(authKey: Uint8Array | null, temp = false, idx = 0): void {
|
||||||
const session = this._sessions[idx]
|
const session = this._connections[idx]._session
|
||||||
const key = temp ? session._authKeyTemp : session._authKey
|
const key = temp ? session._authKeyTemp : session._authKey
|
||||||
key.setup(authKey)
|
key.setup(authKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
resetAuthKeys(): void {
|
resetAuthKeys(): void {
|
||||||
for (const session of this._sessions) {
|
for (const conn of this._connections) {
|
||||||
session.reset(true)
|
conn._session.reset(true)
|
||||||
}
|
}
|
||||||
this.notifyKeyChange()
|
this.notifyKeyChange()
|
||||||
}
|
}
|
||||||
|
@ -305,21 +239,23 @@ export class MultiSessionConnection {
|
||||||
|
|
||||||
notifyKeyChange(): void {
|
notifyKeyChange(): void {
|
||||||
// only expected to be called on non-main connections
|
// only expected to be called on non-main connections
|
||||||
const session = this._sessions[0]
|
for (const conn of this._connections) {
|
||||||
|
const session = conn._session
|
||||||
|
|
||||||
if (this.params.usePfs && !session._authKeyTemp.ready) {
|
if (this.params.usePfs && !session._authKeyTemp.ready) {
|
||||||
this._log.debug('temp auth key needed but not ready, ignoring key change')
|
this._log.debug('temp auth key needed but not ready, ignoring key change')
|
||||||
|
|
||||||
return
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (session.queuedRpc.length) {
|
||||||
|
// there are pending requests, we need to reconnect.
|
||||||
|
this._log.debug('notifying key change on the connection due to queued rpc')
|
||||||
|
this._connections.forEach(conn => conn.onConnected())
|
||||||
|
}
|
||||||
|
|
||||||
|
// connection is idle, we don't need to notify it
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this._sessions[0].queuedRpc.length) {
|
|
||||||
// there are pending requests, we need to reconnect.
|
|
||||||
this._log.debug('notifying key change on the connection due to queued rpc')
|
|
||||||
this._connections.forEach(conn => conn.onConnected())
|
|
||||||
}
|
|
||||||
|
|
||||||
// connection is idle, we don't need to notify it
|
|
||||||
}
|
}
|
||||||
|
|
||||||
notifyNetworkChanged(connected: boolean): void {
|
notifyNetworkChanged(connected: boolean): void {
|
||||||
|
|
|
@ -341,24 +341,13 @@ export class DcConnectionManager {
|
||||||
})
|
})
|
||||||
connection.onTmpKeyChange.add(([idx, key, expires]) => {
|
connection.onTmpKeyChange.add(([idx, key, expires]) => {
|
||||||
if (kind !== 'main') {
|
if (kind !== 'main') {
|
||||||
this.manager._log.warn('got tmp-key-change from non-main connection, ignoring')
|
// tmp keys in media dcs are ephemeral so there's no point in storing them
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
this.manager._log.debug('temp key change for dc %d from connection %d', this.dcId, idx)
|
this.manager._log.debug('temp key change for dc %d from connection %d', this.dcId, idx)
|
||||||
|
|
||||||
// send key to other connections
|
|
||||||
this.upload.setAuthKey(key, true)
|
|
||||||
this.download.setAuthKey(key, true)
|
|
||||||
this.downloadSmall.setAuthKey(key, true)
|
|
||||||
|
|
||||||
Promise.resolve(this.manager._storage.provider.authKeys.setTemp(this.dcId, idx, key, expires * 1000))
|
Promise.resolve(this.manager._storage.provider.authKeys.setTemp(this.dcId, idx, key, expires * 1000))
|
||||||
.then(() => {
|
|
||||||
this.upload.notifyKeyChange()
|
|
||||||
this.download.notifyKeyChange()
|
|
||||||
this.downloadSmall.notifyKeyChange()
|
|
||||||
})
|
|
||||||
.catch((e: Error) => {
|
.catch((e: Error) => {
|
||||||
this.manager._log.warn('failed to save temp auth key %d for dc %d: %e', idx, this.dcId, e)
|
this.manager._log.warn('failed to save temp auth key %d for dc %d: %e', idx, this.dcId, e)
|
||||||
this.manager.params.emitError(e)
|
this.manager.params.emitError(e)
|
||||||
|
@ -442,15 +431,12 @@ export class DcConnectionManager {
|
||||||
if (this.manager.params.usePfs || forcePfs) {
|
if (this.manager.params.usePfs || forcePfs) {
|
||||||
const now = Date.now()
|
const now = Date.now()
|
||||||
await Promise.all(
|
await Promise.all(
|
||||||
this.main._sessions.map(async (_, i) => {
|
Array.from({ length: this.main.getCount() }, async (_, i) => {
|
||||||
const temp = await this.manager._storage.provider.authKeys.getTemp(this.dcId, i, now)
|
const temp = await this.manager._storage.provider.authKeys.getTemp(this.dcId, i, now)
|
||||||
this.main.setAuthKey(temp, true, i)
|
this.main.setAuthKey(temp, true, i)
|
||||||
|
|
||||||
if (i === 0) {
|
// NB: we do not set temp auth keys for media connections,
|
||||||
this.upload.setAuthKey(temp, true)
|
// as they are ephemeral and dc-bound. doing this *will* lead to unwanted -404s
|
||||||
this.download.setAuthKey(temp, true)
|
|
||||||
this.downloadSmall.setAuthKey(temp, true)
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -915,7 +901,7 @@ export class NetworkManager {
|
||||||
}
|
}
|
||||||
|
|
||||||
getMtprotoMessageId(): Long {
|
getMtprotoMessageId(): Long {
|
||||||
return this._primaryDc!.main._sessions[0].getMessageId()
|
return this._primaryDc!.main._connections[0]._session.getMessageId()
|
||||||
}
|
}
|
||||||
|
|
||||||
async recreateDc(dcId: number): Promise<void> {
|
async recreateDc(dcId: number): Promise<void> {
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
import type { mtp } from '@mtcute/tl'
|
import type { mtp } from '@mtcute/tl'
|
||||||
import type { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
|
import type { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
|
||||||
import type { ICorePlatform } from '../types/platform.js'
|
import type { ICorePlatform } from '../types/platform.js'
|
||||||
import type { ICryptoProvider } from '../utils/index.js'
|
import type { ICryptoProvider, Logger } from '../utils/index.js'
|
||||||
import type { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js'
|
import type { PendingMessage, PendingRpc } from './mtproto-session.js'
|
||||||
import type { PersistentConnectionParams } from './persistent-connection.js'
|
import type { PersistentConnectionParams } from './persistent-connection.js'
|
||||||
|
|
||||||
import type { ServerSaltManager } from './server-salt.js'
|
import type { ServerSaltManager } from './server-salt.js'
|
||||||
|
|
||||||
import { Deferred, Emitter, timers, u8 } from '@fuman/utils'
|
import { Deferred, Emitter, timers, u8 } from '@fuman/utils'
|
||||||
import { tl } from '@mtcute/tl'
|
import { tl } from '@mtcute/tl'
|
||||||
import { TlBinaryReader, TlBinaryWriter, TlSerializationCounter } from '@mtcute/tl-runtime'
|
import { TlBinaryReader, TlBinaryWriter, TlSerializationCounter } from '@mtcute/tl-runtime'
|
||||||
import Long from 'long'
|
import Long from 'long'
|
||||||
|
|
||||||
import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types/index.js'
|
import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types/index.js'
|
||||||
|
|
||||||
import { createAesIgeForMessageOld } from '../utils/crypto/mtproto.js'
|
import { createAesIgeForMessageOld } from '../utils/crypto/mtproto.js'
|
||||||
import {
|
import {
|
||||||
EarlyTimer,
|
EarlyTimer,
|
||||||
|
@ -20,6 +20,7 @@ import {
|
||||||
removeFromLongArray,
|
removeFromLongArray,
|
||||||
} from '../utils/index.js'
|
} from '../utils/index.js'
|
||||||
import { doAuthorization } from './authorization.js'
|
import { doAuthorization } from './authorization.js'
|
||||||
|
import { MtprotoSession } from './mtproto-session.js'
|
||||||
import { PersistentConnection } from './persistent-connection.js'
|
import { PersistentConnection } from './persistent-connection.js'
|
||||||
import { TransportError } from './transports/abstract.js'
|
import { TransportError } from './transports/abstract.js'
|
||||||
|
|
||||||
|
@ -82,6 +83,7 @@ export class SessionConnection extends PersistentConnection {
|
||||||
private _writerMap: TlWriterMap
|
private _writerMap: TlWriterMap
|
||||||
private _crypto: ICryptoProvider
|
private _crypto: ICryptoProvider
|
||||||
private _salts: ServerSaltManager
|
private _salts: ServerSaltManager
|
||||||
|
readonly _session: MtprotoSession
|
||||||
|
|
||||||
// todo: we should probably do adaptive ping interval based on rtt like tdlib:
|
// todo: we should probably do adaptive ping interval based on rtt like tdlib:
|
||||||
// https://github.com/tdlib/td/blob/91aa6c9e4d0774eabf4f8d7f3aa51239032059a6/td/mtproto/SessionConnection.h
|
// https://github.com/tdlib/td/blob/91aa6c9e4d0774eabf4f8d7f3aa51239032059a6/td/mtproto/SessionConnection.h
|
||||||
|
@ -97,11 +99,15 @@ export class SessionConnection extends PersistentConnection {
|
||||||
readonly onUpdate: Emitter<tl.TypeUpdates> = new Emitter()
|
readonly onUpdate: Emitter<tl.TypeUpdates> = new Emitter()
|
||||||
readonly onFutureSalts: Emitter<mtp.RawMt_future_salt[]> = new Emitter()
|
readonly onFutureSalts: Emitter<mtp.RawMt_future_salt[]> = new Emitter()
|
||||||
|
|
||||||
constructor(
|
constructor(params: SessionConnectionParams, log: Logger) {
|
||||||
params: SessionConnectionParams,
|
super(params, log.create('conn'))
|
||||||
readonly _session: MtprotoSession,
|
this._session = new MtprotoSession(
|
||||||
) {
|
params.crypto,
|
||||||
super(params, _session.log.create('conn'))
|
log.create('session'),
|
||||||
|
params.readerMap,
|
||||||
|
params.writerMap,
|
||||||
|
params.salts,
|
||||||
|
)
|
||||||
this._flushTimer.onTimeout(this._flush.bind(this))
|
this._flushTimer.onTimeout(this._flush.bind(this))
|
||||||
|
|
||||||
this._pingInterval = params.pingInterval
|
this._pingInterval = params.pingInterval
|
||||||
|
@ -188,6 +194,7 @@ export class SessionConnection extends PersistentConnection {
|
||||||
this.onError.add((err) => {
|
this.onError.add((err) => {
|
||||||
this.log.warn('caught error after destroying: %s', err)
|
this.log.warn('caught error after destroying: %s', err)
|
||||||
})
|
})
|
||||||
|
this._session.reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue