feat(core): requests chaining

This commit is contained in:
alina 🌸 2023-12-11 00:07:41 +03:00
parent aaa2875fe1
commit d1e4a15f81
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
8 changed files with 245 additions and 57 deletions

View file

@ -5,6 +5,7 @@ import { TlBinaryWriter, TlReaderMap, TlSerializationCounter, TlWriterMap } from
import { MtcuteError } from '../types/index.js'
import {
compareLongs,
ControllablePromise,
Deque,
getRandomInt,
@ -24,6 +25,9 @@ export interface PendingRpc {
stack?: string
gzipOverhead?: number
chainId?: string | number
invokeAfter?: Long
sent?: boolean
done?: boolean
msgId?: Long
@ -109,6 +113,9 @@ export class MtprotoSession {
queuedCancelReq: Long[] = []
getStateSchedule = new SortedArray<PendingRpc>([], (a, b) => a.getState! - b.getState!)
chains = new Map<string | number, Long>()
chainsPendingFails = new Map<string | number, SortedArray<PendingRpc>>()
// requests info
pendingMessages = new LongMap<PendingMessage>()
destroySessionIdToMsgId = new LongMap<Long>()
@ -200,6 +207,7 @@ export class MtprotoSession {
this.queuedStateReq.length = 0
this.queuedResendReq.length = 0
this.getStateSchedule.clear()
this.chains.clear()
}
enqueueRpc(rpc: PendingRpc, force?: boolean): boolean {
@ -334,4 +342,42 @@ export class MtprotoSession {
this.lastPingMsgId = Long.ZERO
}
addToChain(chainId: string | number, msgId: Long): Long | undefined {
const prevMsgId = this.chains.get(chainId)
this.chains.set(chainId, msgId)
this.log.debug('added message %l to chain %s (prev: %l)', msgId, chainId, prevMsgId)
return prevMsgId
}
removeFromChain(chainId: string | number, msgId: Long): void {
const lastMsgId = this.chains.get(chainId)
if (!lastMsgId) {
this.log.warn('tried to remove message %l from empty chain %s', msgId, chainId)
return
}
if (lastMsgId.eq(msgId)) {
// last message of the chain, remove it
this.log.debug('chain %s: exhausted, last message %l', msgId, chainId)
this.chains.delete(chainId)
}
// do nothing
}
getPendingChainedFails(chainId: string | number): SortedArray<PendingRpc> {
let arr = this.chainsPendingFails.get(chainId)
if (!arr) {
arr = new SortedArray<PendingRpc>([], (a, b) => compareLongs(a.invokeAfter!, b.invokeAfter!))
this.chainsPendingFails.set(chainId, arr)
}
return arr
}
}

View file

@ -211,6 +211,7 @@ export class MultiSessionConnection extends EventEmitter {
stack?: string,
timeout?: number,
abortSignal?: AbortSignal,
chainId?: string | number,
): Promise<tl.RpcCallReturn[T['_']]> {
// if (this.params.isMainConnection) {
// find the least loaded connection
@ -227,7 +228,7 @@ export class MultiSessionConnection extends EventEmitter {
}
}
return this._connections[minIdx].sendRpc(request, stack, timeout, abortSignal)
return this._connections[minIdx].sendRpc(request, stack, timeout, abortSignal, chainId)
// }
// round-robin connections

View file

@ -155,6 +155,15 @@ export interface RpcCallOptions {
* -503 in case the upstream bot failed to respond.
*/
throw503?: boolean
/**
* Some requests should be processed consecutively, and not in parallel.
* Using the same `chainId` for multiple requests will ensure that they are processed in the order
* of calling `.call()`.
*
* Particularly useful for `messages.sendMessage` and alike.
*/
chainId?: string | number
}
/**
@ -683,7 +692,7 @@ export class NetworkManager {
for (let i = 0; i < maxRetryCount; i++) {
try {
const res = await multi.sendRpc(message, stack, params?.timeout, params?.abortSignal)
const res = await multi.sendRpc(message, stack, params?.timeout, params?.abortSignal, params?.chainId)
if (kind === 'main') {
this._lastUpdateTime = Date.now()
@ -704,7 +713,12 @@ export class NetworkManager {
throw new MtTimeoutError()
}
this._log.warn('Telegram is having internal issues: %d %s, retrying', e.code, e.message)
this._log.warn(
'Telegram is having internal issues: %d:%s (%s), retrying',
e.code,
e.text,
e.message,
)
if (e.text === 'WORKER_BUSY_TOO_LONG_RETRY') {
// according to tdlib, "it is dangerous to resend query without timeout, so use 1"

View file

@ -23,8 +23,6 @@ import { MtprotoSession, PendingMessage, PendingRpc } from './mtproto-session.js
import { PersistentConnection, PersistentConnectionParams } from './persistent-connection.js'
import { TransportError } from './transports/abstract.js'
const TEMP_AUTH_KEY_EXPIRY = 86400
export interface SessionConnectionParams extends PersistentConnectionParams {
initConnection: tl.RawInitConnectionRequest
inactivityTimeout?: number
@ -41,6 +39,9 @@ export interface SessionConnectionParams extends PersistentConnectionParams {
writerMap: TlWriterMap
}
const TEMP_AUTH_KEY_EXPIRY = 86400 // 24 hours
const PING_INTERVAL = 60000 // 1 minute
// destroy_auth_key#d1435160 = DestroyAuthKeyRes;
// const DESTROY_AUTH_KEY = Buffer.from('605134d1', 'hex')
// gzip_packed#3072cfa1 packed_data:string = Object;
@ -49,6 +50,9 @@ const GZIP_PACKED_ID = 0x3072cfa1
const MSG_CONTAINER_ID = 0x73f1f8dc
// rpc_result#f35c6d01 req_msg_id:long result:Object = RpcResult;
const RPC_RESULT_ID = 0xf35c6d01
// invokeAfterMsg#cb9f372d {X:Type} msg_id:long query:!X = X;
const INVOKE_AFTER_MSG_ID = 0xcb9f372d
const INVOKE_AFTER_MSG_SIZE = 12 // 8 (invokeAfterMsg) + 4 (msg_id)
function makeNiceStack(error: tl.RpcError, stack: string, method?: string) {
error.stack = `RpcError (${error.code} ${error.text}): ${error.message}\n at ${method}\n${stack
@ -735,6 +739,7 @@ export class SessionConnection extends PersistentConnection {
// initConnection call was definitely received and
// processed by the server, so we no longer need to use it
// todo: is this the case with failed invokeAfterMsg(s) as well?
if (rpc.initConn) {
this._session.initConnectionCalled = true
}
@ -753,43 +758,74 @@ export class SessionConnection extends PersistentConnection {
rpc.method,
)
if (res.errorMessage === 'AUTH_KEY_PERM_EMPTY') {
// happens when temp auth key is not yet bound
// this shouldn't happen as we block any outbound communications
// until the temp key is derived and bound.
//
// i think it is also possible for the error to be returned
// when the temp key has expired, but this still shouldn't happen
// but this is tg, so something may go wrong, and we will receive this as an error
// (for god's sake why is this not in mtproto and instead hacked into the app layer)
this._authorizePfs()
this._onMessageFailed(reqMsgId, 'AUTH_KEY_PERM_EMPTY', true)
return
}
if (res.errorMessage === 'CONNECTION_NOT_INITED') {
// this seems to sometimes happen when using pfs
// no idea why, but tdlib also seems to handle these, so whatever
this._session.initConnectionCalled = false
this._onMessageFailed(reqMsgId, res.errorMessage, true)
// just setting this flag is not enough because the message
// is already serialized, so we do this awesome hack
this.sendRpc({ _: 'help.getNearestDc' })
.then(() => {
this.log.debug('additional help.getNearestDc for initConnection ok')
})
.catch((err) => {
this.log.debug('additional help.getNearestDc for initConnection error: %s', err)
})
return
}
if (rpc.cancelled) return
switch (res.errorMessage) {
case 'AUTH_KEY_PERM_EMPTY':
// happens when temp auth key is not yet bound
// this shouldn't happen as we block any outbound communications
// until the temp key is derived and bound.
//
// i think it is also possible for the error to be returned
// when the temp key has expired, but this still shouldn't happen
// but this is tg, so something may go wrong, and we will receive this as an error
// (for god's sake why is this not in mtproto and instead hacked into the app layer)
this._authorizePfs()
this._onMessageFailed(reqMsgId, 'AUTH_KEY_PERM_EMPTY', true)
return
case 'CONNECTION_NOT_INITED': {
// this seems to sometimes happen when using pfs
// no idea why, but tdlib also seems to handle these, so whatever
this._session.initConnectionCalled = false
this._onMessageFailed(reqMsgId, res.errorMessage, true)
// just setting this flag is not enough because the message
// is already serialized, so we do this awesome hack
this.sendRpc({ _: 'help.getNearestDc' })
.then(() => {
this.log.debug('additional help.getNearestDc for initConnection ok')
})
.catch((err) => {
this.log.debug('additional help.getNearestDc for initConnection error: %s', err)
})
return
}
case 'MSG_WAIT_TIMEOUT':
case 'MSG_WAIT_FAILED': {
if (!rpc.invokeAfter) {
this.log.warn('received %s for non-chained request %l', res.errorMessage, reqMsgId)
break
}
// in some cases, MSG_WAIT_TIMEOUT is returned instead of MSG_WAIT_FAILED when one of the deps
// failed with MSG_WAIT_TIMEOUT. i have no clue why, this makes zero sense, but what fucking ever
//
// this basically means we can't handle a timeout any different than a general failure,
// because the timeout might not refer to the immediate `.invokeAfter` message, but to
// its arbitrary-depth dependency, so we indeed have to wait for the message ourselves...
if (this._session.pendingMessages.has(rpc.invokeAfter)) {
// the dependency is still pending, postpone the processing
this.log.debug(
'chain %s: waiting for %l before processing %l',
rpc.chainId,
rpc.invokeAfter,
reqMsgId,
)
this._session.getPendingChainedFails(rpc.chainId!).insert(rpc)
} else {
this._session.chains.delete(rpc.chainId!)
this._onMessageFailed(reqMsgId, 'MSG_WAIT_FAILED', true)
}
return
}
}
const error = tl.RpcError.fromTl(res)
if (this.params.niceStacks !== false) {
@ -809,10 +845,40 @@ export class SessionConnection extends PersistentConnection {
rpc.promise.resolve(result)
}
if (rpc.chainId) {
this._processPendingChainedFails(rpc.chainId, reqMsgId)
}
this._onMessageAcked(reqMsgId)
this._session.pendingMessages.delete(reqMsgId)
}
private _processPendingChainedFails(chainId: number | string, sinceMsgId: Long): void {
// sinceMsgId was already definitely received and contained an error.
// we should now re-send all the pending MSG_WAIT_FAILED after it
this._session.removeFromChain(chainId, sinceMsgId)
const oldPending = this._session.chainsPendingFails.get(chainId)
if (!oldPending?.length) {
return
}
const idx = oldPending.index({ invokeAfter: sinceMsgId } as PendingRpc, true)
if (idx === -1) return
const toFail = oldPending.raw.splice(idx)
this.log.debug('chain %s: failing %d dependant messages: %L', chainId, toFail.length, toFail)
// we're failing the rest of the chain, including the last message
this._session.chains.delete(chainId)
for (const rpc of toFail) {
this._onMessageFailed(rpc.msgId!, 'MSG_WAIT_FAILED', true)
}
}
private _onMessageAcked(msgId: Long, inContainer = false): void {
const msg = this._session.pendingMessages.get(msgId)
@ -1204,6 +1270,7 @@ export class SessionConnection extends PersistentConnection {
stack?: string,
timeout?: number,
abortSignal?: AbortSignal,
chainId?: string | number,
): Promise<tl.RpcCallReturn[T['_']]> {
if (this._usable && this.params.inactivityTimeout) {
this._rescheduleInactivity()
@ -1290,8 +1357,10 @@ export class SessionConnection extends PersistentConnection {
// we will need to know size of gzip_packed overhead in _flush()
gzipOverhead: shouldGzip ? 4 + TlSerializationCounter.countBytesOverhead(content.length) : 0,
initConn,
chainId,
// setting them as well so jit can optimize stuff
invokeAfter: undefined,
sent: undefined,
done: undefined,
getState: undefined,
@ -1405,7 +1474,7 @@ export class SessionConnection extends PersistentConnection {
// between multiple connections using the same session
this._flushTimer.emitWhenIdle()
} else {
this._flushTimer.emitBefore(this._session.lastPingTime + 60000)
this._flushTimer.emitBefore(this._session.lastPingTime + PING_INTERVAL)
}
}
@ -1466,16 +1535,17 @@ export class SessionConnection extends PersistentConnection {
const getStateTime = now + 1500
if (now - this._session.lastPingTime > 60000) {
if (now - this._session.lastPingTime > PING_INTERVAL) {
if (!this._session.lastPingMsgId.isZero()) {
this.log.warn("didn't receive pong for previous ping (msg_id = %l)", this._session.lastPingMsgId)
this._session.pendingMessages.delete(this._session.lastPingMsgId)
}
pingId = randomLong()
const obj: mtp.RawMt_ping = {
_: 'mt_ping',
const obj: mtp.RawMt_ping_delay_disconnect = {
_: 'mt_ping_delay_disconnect',
pingId,
disconnectDelay: 75,
}
this._session.lastPingTime = Date.now()
@ -1571,6 +1641,10 @@ export class SessionConnection extends PersistentConnection {
containerSize += msg.data.length + 16
if (msg.gzipOverhead) containerSize += msg.gzipOverhead
if (msg.chainId) {
containerSize += INVOKE_AFTER_MSG_SIZE
}
// if message was already assigned a msg_id,
// we must wrap it in a container with a newer msg_id
if (msg.msgId) forceContainer = true
@ -1699,6 +1773,11 @@ export class SessionConnection extends PersistentConnection {
_: 'rpc',
rpc: msg,
})
if (msg.chainId) {
msg.invokeAfter = this._session.addToChain(msg.chainId, msgId)
this.log.debug('chain %s: invoke %l after %l', msg.chainId, msg.msgId, msg.invokeAfter)
}
} else {
this.log.debug('%s: msg_id already assigned, reusing %l, seqno: %d', msg.method, msg.msgId, msg.seqNo)
}
@ -1716,12 +1795,23 @@ export class SessionConnection extends PersistentConnection {
writer.long(this._registerOutgoingMsgId(msg.msgId))
writer.uint(msg.seqNo!)
const invokeAfterSize = msg.invokeAfter ? INVOKE_AFTER_MSG_SIZE : 0
const writeInvokeAfter = () => {
if (!msg.invokeAfter) return
writer.uint(INVOKE_AFTER_MSG_ID)
writer.long(msg.invokeAfter)
}
if (msg.gzipOverhead) {
writer.uint(msg.data.length + msg.gzipOverhead)
writer.uint(0x3072cfa1) // gzip_packed#3072cfa1
writer.uint(msg.data.length + msg.gzipOverhead + invokeAfterSize)
writeInvokeAfter()
writer.uint(GZIP_PACKED_ID)
writer.bytes(msg.data)
} else {
writer.uint(msg.data.length)
writer.uint(msg.data.length + invokeAfterSize)
writeInvokeAfter()
writer.raw(msg.data)
}
@ -1733,7 +1823,7 @@ export class SessionConnection extends PersistentConnection {
// we couldn't have assigned them earlier because mtproto
// requires them to be >= than the contained messages
// writer.pos is expected to be packetSize
packetSize = writer.pos
const containerId = this._session.getMessageId()
writer.pos = 0
@ -1767,18 +1857,17 @@ export class SessionConnection extends PersistentConnection {
const result = writer.result()
this.log.debug(
'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), cancels = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l',
'sending %d messages: size = %db, acks = %L, ping = %b (msg_id = %l), state_req = %L (msg_id = %l), resend = %L (msg_id = %l), cancels = %L (msg_id = %l), rpc = %s, container = %b, root msg_id = %l',
messageCount,
packetSize,
ackMsgIds?.length || 'false',
ackMsgIds?.map((it) => it.toString()),
Boolean(pingRequest),
ackMsgIds,
pingRequest,
pingMsgId,
getStateMsgIds?.map((it) => it.toString()) || 'false',
getStateMsgIds,
getStateMsgId,
resendMsgIds?.map((it) => it.toString()) || 'false',
resendMsgIds,
cancelRpcs,
cancelRpcs,
cancelRpcs?.map((it) => it.toString()) || 'false',
resendMsgId,
rpcToSend.map((it) => it.method),
useContainer,

View file

@ -171,5 +171,23 @@ describe('logger', () => {
expect(spy).toHaveBeenCalledWith(3, 3, 'base', 'test 123', [])
})
})
describe('%L', () => {
it('should format Long arrays as strings', () => {
const [mgr, spy] = createManager()
mgr.info('test %L', [Long.fromInt(123), Long.fromInt(456)])
expect(spy).toHaveBeenCalledWith(3, 3, 'base', 'test [123, 456]', [])
})
it('should format everything else as n/a', () => {
const [mgr, spy] = createManager()
mgr.info('test %L', 123)
expect(spy).toHaveBeenCalledWith(3, 3, 'base', 'test n/a', [])
})
})
})
})

View file

@ -70,11 +70,12 @@ export class Logger {
fmt.includes('%b') ||
fmt.includes('%j') ||
fmt.includes('%J') ||
fmt.includes('%l')
fmt.includes('%l') ||
fmt.includes('%L')
) {
let idx = 0
fmt = fmt.replace(FORMATTER_RE, (m) => {
if (m === '%h' || m === '%b' || m === '%j' || m === '%J' || m === '%l') {
if (m === '%h' || m === '%b' || m === '%j' || m === '%J' || m === '%l' || m === '%L') {
let val = args[idx]
args.splice(idx, 1)
@ -112,6 +113,12 @@ export class Logger {
})
}
if (m === '%l') return String(val)
if (m === '%L') {
if (!Array.isArray(val)) return 'n/a'
return `[${(val as unknown[]).map(String).join(', ')}]`
}
}
idx++

View file

@ -53,6 +53,18 @@ export function removeFromLongArray(arr: Long[], val: Long): boolean {
return false
}
/**
* Compare two Longs and return -1, 0 or 1,
* to be used as a comparator function.
*/
export function compareLongs(a: Long, b: Long): number {
if (a.eq(b)) return 0
if (a.gt(b)) return 1
return -1
}
/**
* Serialize a Long (int64) to its fast string representation:
* `$high|$low`.

View file

@ -55,6 +55,7 @@ export class SortedArray<T> {
// closest: return the closest value (right-hand side)
// meaning that raw[idx - 1] <= item <= raw[idx]
// in other words, smallest idx such that raw[idx] >= item
index(item: T, closest = false): number {
let lo = 0
let hi = this.raw.length