diff --git a/packages/client/src/methods/files/download-file.ts b/packages/client/src/methods/files/download-file.ts index 552100ed..2ae72b46 100644 --- a/packages/client/src/methods/files/download-file.ts +++ b/packages/client/src/methods/files/download-file.ts @@ -36,6 +36,15 @@ export function downloadToFile(this: TelegramClient, filename: string, params: F const output = fs.createWriteStream(filename) const stream = this.downloadAsStream(params) + if (params.abortSignal) { + params.abortSignal.addEventListener('abort', () => { + this.log.debug('aborting file download %s - cleaning up', filename) + output.destroy() + stream.destroy() + fs!.rmSync(filename) + }) + } + return new Promise((resolve, reject) => { stream.on('error', reject).pipe(output).on('finish', resolve).on('error', reject) }) diff --git a/packages/client/src/methods/files/download-iterable.ts b/packages/client/src/methods/files/download-iterable.ts index cf5632ef..94a55bbb 100644 --- a/packages/client/src/methods/files/download-iterable.ts +++ b/packages/client/src/methods/files/download-iterable.ts @@ -113,7 +113,7 @@ export async function* downloadAsIterable( offset: chunkSize * chunk, limit: chunkSize, }, - { dcId, kind: connectionKind }, + { dcId, kind: connectionKind, abortSignal: params.abortSignal }, ) } catch (e: unknown) { if (!tl.RpcError.is(e)) throw e diff --git a/packages/client/src/types/files/utils.ts b/packages/client/src/types/files/utils.ts index 3d4697de..cf5008aa 100644 --- a/packages/client/src/types/files/utils.ts +++ b/packages/client/src/types/files/utils.ts @@ -105,4 +105,9 @@ export interface FileDownloadParameters { * @param total Total file size (`Infinity` if not available) */ progressCallback?: (downloaded: number, total: number) => void + + /** + * Abort signal that can be used to cancel the download. + */ + abortSignal?: AbortSignal } diff --git a/packages/core/src/network/mtproto-session.ts b/packages/core/src/network/mtproto-session.ts index 063b6f4f..00624661 100644 --- a/packages/core/src/network/mtproto-session.ts +++ b/packages/core/src/network/mtproto-session.ts @@ -25,6 +25,7 @@ export interface PendingRpc { gzipOverhead?: number sent?: boolean + done?: boolean msgId?: Long seqNo?: number containerId?: Long diff --git a/packages/core/src/network/multi-session-connection.ts b/packages/core/src/network/multi-session-connection.ts index 30909bed..dd6abff0 100644 --- a/packages/core/src/network/multi-session-connection.ts +++ b/packages/core/src/network/multi-session-connection.ts @@ -188,7 +188,12 @@ export class MultiSessionConnection extends EventEmitter { private _nextConnection = 0 - sendRpc(request: T, stack?: string, timeout?: number): Promise { + sendRpc( + request: T, + stack?: string, + timeout?: number, + abortSignal?: AbortSignal, + ): Promise { // if (this.params.isMainConnection) { // find the least loaded connection let min = Infinity @@ -204,7 +209,7 @@ export class MultiSessionConnection extends EventEmitter { } } - return this._connections[minIdx].sendRpc(request, stack, timeout) + return this._connections[minIdx].sendRpc(request, stack, timeout, abortSignal) // } // round-robin connections diff --git a/packages/core/src/network/network-manager.ts b/packages/core/src/network/network-manager.ts index 9a17212e..81417f85 100644 --- a/packages/core/src/network/network-manager.ts +++ b/packages/core/src/network/network-manager.ts @@ -139,6 +139,11 @@ export interface RpcCallOptions { * Overrides `dcId` if set. */ manager?: DcConnectionManager + + /** + * Abort signal for the call. + */ + abortSignal?: AbortSignal } export class DcConnectionManager { @@ -662,7 +667,7 @@ export class NetworkManager { for (let i = 0; i < maxRetryCount; i++) { try { - const res = await multi.sendRpc(message, stack, params?.timeout) + const res = await multi.sendRpc(message, stack, params?.timeout, params?.abortSignal) if (kind === 'main') { this._lastUpdateTime = Date.now() diff --git a/packages/core/src/network/session-connection.ts b/packages/core/src/network/session-connection.ts index 6e007956..9cb69e0a 100644 --- a/packages/core/src/network/session-connection.ts +++ b/packages/core/src/network/session-connection.ts @@ -9,7 +9,6 @@ import { gzipDeflate, gzipInflate } from '@mtcute/tl-runtime/src/platform/gzip' import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types' import { ControllablePromise, - createCancellablePromise, createControllablePromise, EarlyTimer, longFromBuffer, @@ -691,6 +690,23 @@ export class SessionConnection extends PersistentConnection { return } + if (msg._ === 'cancel') { + let result + + try { + result = message.object() as mtp.TlObject + } catch (err) { + this.log.debug('failed to parse rpc_result for cancel request %l, ignoring', reqMsgId) + + return + } + + this.log.debug('received %s for cancelled request %l: %j', result._, reqMsgId) + this._onMessageAcked(reqMsgId) + + return + } + this.log.error('received rpc_result for %s request %l', msg._, reqMsgId) return @@ -707,6 +723,8 @@ export class SessionConnection extends PersistentConnection { this._session.initConnectionCalled = true } + rpc.done = true + this.log.verbose('<<< (%s) %j', rpc.method, result) if (result._ === 'mt_rpc_error') { @@ -1165,7 +1183,12 @@ export class SessionConnection extends PersistentConnection { } } - sendRpc(request: T, stack?: string, timeout?: number): Promise { + sendRpc( + request: T, + stack?: string, + timeout?: number, + abortSignal?: AbortSignal, + ): Promise { if (this._usable && this.params.inactivityTimeout) { this._rescheduleInactivity() } @@ -1244,8 +1267,7 @@ export class SessionConnection extends PersistentConnection { const pending: PendingRpc = { method, - // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment - promise: undefined as any, // because we need the object to make a promise + promise: createControllablePromise(), data: content, stack, // we will need to know size of gzip_packed overhead in _flush() @@ -1254,6 +1276,7 @@ export class SessionConnection extends PersistentConnection { // setting them as well so jit can optimize stuff sent: undefined, + done: undefined, getState: undefined, msgId: undefined, seqNo: undefined, @@ -1263,19 +1286,28 @@ export class SessionConnection extends PersistentConnection { timeout: undefined, } - const promise = createCancellablePromise(this._cancelRpc.bind(this, pending)) - pending.promise = promise + if (abortSignal?.aborted) { + pending.promise.reject(abortSignal.reason) + + return pending.promise + } if (timeout) { pending.timeout = setTimeout(this._cancelRpc, timeout, pending, true) } + if (abortSignal) { + abortSignal.addEventListener('abort', () => this._cancelRpc(pending, false, abortSignal)) + } + this._enqueueRpc(pending, true) - return promise + return pending.promise } - private _cancelRpc(rpc: PendingRpc, onTimeout = false): void { + private _cancelRpc(rpc: PendingRpc, onTimeout = false, abortSignal?: AbortSignal): void { + if (rpc.done) return + if (rpc.cancelled && !onTimeout) { throw new MtcuteError('RPC was already cancelled') } @@ -1286,13 +1318,15 @@ export class SessionConnection extends PersistentConnection { if (onTimeout) { // todo: replace with MtTimeoutError - const error = new tl.RpcError(-503, 'Timeout') + const error = new tl.RpcError(400, 'Client timeout') if (this.params.niceStacks !== false) { makeNiceStack(error, rpc.stack!, rpc.method) } rpc.promise.reject(error) + } else if (abortSignal) { + rpc.promise.reject(abortSignal.reason) } rpc.cancelled = true @@ -1721,7 +1755,7 @@ export class SessionConnection extends PersistentConnection { const rootMsgId = new Long(result.readInt32LE(), result.readInt32LE(4)) this.log.debug( - 'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l', + 'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), cancels = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l', messageCount, packetSize, ackMsgIds?.length || 'false', @@ -1731,6 +1765,8 @@ export class SessionConnection extends PersistentConnection { getStateMsgIds?.map((it) => it.toString()) || 'false', getStateMsgId, resendMsgIds?.map((it) => it.toString()) || 'false', + cancelRpcs, + cancelRpcs?.map((it) => it.toString()) || 'false', resendMsgId, rpcToSend.map((it) => it.method), useContainer, diff --git a/packages/core/src/utils/controllable-promise.ts b/packages/core/src/utils/controllable-promise.ts index f4ee68b9..a688c4e4 100644 --- a/packages/core/src/utils/controllable-promise.ts +++ b/packages/core/src/utils/controllable-promise.ts @@ -6,13 +6,6 @@ export type ControllablePromise = Promise & { reject(err?: unknown): void } -/** - * A promise that can be cancelled. - */ -export type CancellablePromise = Promise & { - cancel(): void -} - /** * The promise was cancelled */ @@ -35,23 +28,3 @@ export function createControllablePromise(): ControllablePromise return promise as ControllablePromise } - -/** - * Creates a promise that can be cancelled. - * - * @param onCancel Callback to call when cancellation is requested - */ -export function createCancellablePromise( - onCancel: () => void, -): ControllablePromise & CancellablePromise { - // todo rethink this in MTQ-20 - - const promise = createControllablePromise() - - ;(promise as unknown as CancellablePromise).cancel = () => { - promise.reject(new PromiseCancelledError()) - onCancel() - } - - return promise as ControllablePromise & CancellablePromise -}