feat: abort signals

This commit is contained in:
alina 🌸 2023-10-05 18:10:15 +03:00
parent 31b41c93fc
commit 85ca3b4603
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
8 changed files with 75 additions and 41 deletions

View file

@ -36,6 +36,15 @@ export function downloadToFile(this: TelegramClient, filename: string, params: F
const output = fs.createWriteStream(filename)
const stream = this.downloadAsStream(params)
if (params.abortSignal) {
params.abortSignal.addEventListener('abort', () => {
this.log.debug('aborting file download %s - cleaning up', filename)
output.destroy()
stream.destroy()
fs!.rmSync(filename)
})
}
return new Promise((resolve, reject) => {
stream.on('error', reject).pipe(output).on('finish', resolve).on('error', reject)
})

View file

@ -113,7 +113,7 @@ export async function* downloadAsIterable(
offset: chunkSize * chunk,
limit: chunkSize,
},
{ dcId, kind: connectionKind },
{ dcId, kind: connectionKind, abortSignal: params.abortSignal },
)
} catch (e: unknown) {
if (!tl.RpcError.is(e)) throw e

View file

@ -105,4 +105,9 @@ export interface FileDownloadParameters {
* @param total Total file size (`Infinity` if not available)
*/
progressCallback?: (downloaded: number, total: number) => void
/**
* Abort signal that can be used to cancel the download.
*/
abortSignal?: AbortSignal
}

View file

@ -25,6 +25,7 @@ export interface PendingRpc {
gzipOverhead?: number
sent?: boolean
done?: boolean
msgId?: Long
seqNo?: number
containerId?: Long

View file

@ -188,7 +188,12 @@ export class MultiSessionConnection extends EventEmitter {
private _nextConnection = 0
sendRpc<T extends tl.RpcMethod>(request: T, stack?: string, timeout?: number): Promise<tl.RpcCallReturn[T['_']]> {
sendRpc<T extends tl.RpcMethod>(
request: T,
stack?: string,
timeout?: number,
abortSignal?: AbortSignal,
): Promise<tl.RpcCallReturn[T['_']]> {
// if (this.params.isMainConnection) {
// find the least loaded connection
let min = Infinity
@ -204,7 +209,7 @@ export class MultiSessionConnection extends EventEmitter {
}
}
return this._connections[minIdx].sendRpc(request, stack, timeout)
return this._connections[minIdx].sendRpc(request, stack, timeout, abortSignal)
// }
// round-robin connections

View file

@ -139,6 +139,11 @@ export interface RpcCallOptions {
* Overrides `dcId` if set.
*/
manager?: DcConnectionManager
/**
* Abort signal for the call.
*/
abortSignal?: AbortSignal
}
export class DcConnectionManager {
@ -662,7 +667,7 @@ export class NetworkManager {
for (let i = 0; i < maxRetryCount; i++) {
try {
const res = await multi.sendRpc(message, stack, params?.timeout)
const res = await multi.sendRpc(message, stack, params?.timeout, params?.abortSignal)
if (kind === 'main') {
this._lastUpdateTime = Date.now()

View file

@ -9,7 +9,6 @@ import { gzipDeflate, gzipInflate } from '@mtcute/tl-runtime/src/platform/gzip'
import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types'
import {
ControllablePromise,
createCancellablePromise,
createControllablePromise,
EarlyTimer,
longFromBuffer,
@ -691,6 +690,23 @@ export class SessionConnection extends PersistentConnection {
return
}
if (msg._ === 'cancel') {
let result
try {
result = message.object() as mtp.TlObject
} catch (err) {
this.log.debug('failed to parse rpc_result for cancel request %l, ignoring', reqMsgId)
return
}
this.log.debug('received %s for cancelled request %l: %j', result._, reqMsgId)
this._onMessageAcked(reqMsgId)
return
}
this.log.error('received rpc_result for %s request %l', msg._, reqMsgId)
return
@ -707,6 +723,8 @@ export class SessionConnection extends PersistentConnection {
this._session.initConnectionCalled = true
}
rpc.done = true
this.log.verbose('<<< (%s) %j', rpc.method, result)
if (result._ === 'mt_rpc_error') {
@ -1165,7 +1183,12 @@ export class SessionConnection extends PersistentConnection {
}
}
sendRpc<T extends tl.RpcMethod>(request: T, stack?: string, timeout?: number): Promise<tl.RpcCallReturn[T['_']]> {
sendRpc<T extends tl.RpcMethod>(
request: T,
stack?: string,
timeout?: number,
abortSignal?: AbortSignal,
): Promise<tl.RpcCallReturn[T['_']]> {
if (this._usable && this.params.inactivityTimeout) {
this._rescheduleInactivity()
}
@ -1244,8 +1267,7 @@ export class SessionConnection extends PersistentConnection {
const pending: PendingRpc = {
method,
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
promise: undefined as any, // because we need the object to make a promise
promise: createControllablePromise(),
data: content,
stack,
// we will need to know size of gzip_packed overhead in _flush()
@ -1254,6 +1276,7 @@ export class SessionConnection extends PersistentConnection {
// setting them as well so jit can optimize stuff
sent: undefined,
done: undefined,
getState: undefined,
msgId: undefined,
seqNo: undefined,
@ -1263,19 +1286,28 @@ export class SessionConnection extends PersistentConnection {
timeout: undefined,
}
const promise = createCancellablePromise<any>(this._cancelRpc.bind(this, pending))
pending.promise = promise
if (abortSignal?.aborted) {
pending.promise.reject(abortSignal.reason)
return pending.promise
}
if (timeout) {
pending.timeout = setTimeout(this._cancelRpc, timeout, pending, true)
}
if (abortSignal) {
abortSignal.addEventListener('abort', () => this._cancelRpc(pending, false, abortSignal))
}
this._enqueueRpc(pending, true)
return promise
return pending.promise
}
private _cancelRpc(rpc: PendingRpc, onTimeout = false): void {
private _cancelRpc(rpc: PendingRpc, onTimeout = false, abortSignal?: AbortSignal): void {
if (rpc.done) return
if (rpc.cancelled && !onTimeout) {
throw new MtcuteError('RPC was already cancelled')
}
@ -1286,13 +1318,15 @@ export class SessionConnection extends PersistentConnection {
if (onTimeout) {
// todo: replace with MtTimeoutError
const error = new tl.RpcError(-503, 'Timeout')
const error = new tl.RpcError(400, 'Client timeout')
if (this.params.niceStacks !== false) {
makeNiceStack(error, rpc.stack!, rpc.method)
}
rpc.promise.reject(error)
} else if (abortSignal) {
rpc.promise.reject(abortSignal.reason)
}
rpc.cancelled = true
@ -1721,7 +1755,7 @@ export class SessionConnection extends PersistentConnection {
const rootMsgId = new Long(result.readInt32LE(), result.readInt32LE(4))
this.log.debug(
'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l',
'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), cancels = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l',
messageCount,
packetSize,
ackMsgIds?.length || 'false',
@ -1731,6 +1765,8 @@ export class SessionConnection extends PersistentConnection {
getStateMsgIds?.map((it) => it.toString()) || 'false',
getStateMsgId,
resendMsgIds?.map((it) => it.toString()) || 'false',
cancelRpcs,
cancelRpcs?.map((it) => it.toString()) || 'false',
resendMsgId,
rpcToSend.map((it) => it.method),
useContainer,

View file

@ -6,13 +6,6 @@ export type ControllablePromise<T = unknown> = Promise<T> & {
reject(err?: unknown): void
}
/**
* A promise that can be cancelled.
*/
export type CancellablePromise<T = unknown> = Promise<T> & {
cancel(): void
}
/**
* The promise was cancelled
*/
@ -35,23 +28,3 @@ export function createControllablePromise<T = unknown>(): ControllablePromise<T>
return promise as ControllablePromise<T>
}
/**
* Creates a promise that can be cancelled.
*
* @param onCancel Callback to call when cancellation is requested
*/
export function createCancellablePromise<T = unknown>(
onCancel: () => void,
): ControllablePromise<T> & CancellablePromise<T> {
// todo rethink this in MTQ-20
const promise = createControllablePromise()
;(promise as unknown as CancellablePromise<T>).cancel = () => {
promise.reject(new PromiseCancelledError())
onCancel()
}
return promise as ControllablePromise<T> & CancellablePromise<T>
}