From 0290bb429a8483b32ba8bb9b73146e1d523b0c9f Mon Sep 17 00:00:00 2001 From: alina sireneva Date: Mon, 1 Jul 2024 12:05:57 +0300 Subject: [PATCH] fix(core): abort signal with workers --- packages/core/src/highlevel/worker/invoker.ts | 16 ++++++++- packages/core/src/highlevel/worker/port.ts | 21 +++++++++-- .../core/src/highlevel/worker/protocol.ts | 22 +++++++----- packages/core/src/highlevel/worker/worker.ts | 35 ++++++++++++++++++- 4 files changed, 82 insertions(+), 12 deletions(-) diff --git a/packages/core/src/highlevel/worker/invoker.ts b/packages/core/src/highlevel/worker/invoker.ts index 424a5d89..9a5ecffb 100644 --- a/packages/core/src/highlevel/worker/invoker.ts +++ b/packages/core/src/highlevel/worker/invoker.ts @@ -10,7 +10,7 @@ export class WorkerInvoker { private _nextId = 0 private _pending = new Map() - private _invoke(target: InvokeTarget, method: string, args: unknown[], isVoid: boolean) { + private _invoke(target: InvokeTarget, method: string, args: unknown[], isVoid: boolean, abortSignal?: AbortSignal) { const id = this._nextId++ this.send({ @@ -20,6 +20,14 @@ export class WorkerInvoker { method, args, void: isVoid, + withAbort: Boolean(abortSignal), + }) + + abortSignal?.addEventListener('abort', () => { + this.send({ + type: 'abort', + id, + }) }) if (!isVoid) { @@ -40,6 +48,12 @@ export class WorkerInvoker { this._invoke(target, method, args, true) } + invokeWithAbort(target: InvokeTarget, method: string, args: unknown[], abortSignal: AbortSignal): Promise { + if (abortSignal.aborted) return Promise.reject(abortSignal.reason) + + return this._invoke(target, method, args, false, abortSignal) as Promise + } + handleResult(msg: Extract) { const promise = this._pending.get(msg.id) if (!promise) return diff --git a/packages/core/src/highlevel/worker/port.ts b/packages/core/src/highlevel/worker/port.ts index 11142b9f..116aff98 100644 --- a/packages/core/src/highlevel/worker/port.ts +++ b/packages/core/src/highlevel/worker/port.ts @@ -1,3 +1,7 @@ +import { tl } from '@mtcute/tl' + +import { RpcCallOptions } from '../../network/network-manager.js' +import { MustEqual } from '../../types/utils.js' import { LogManager } from '../../utils/logger.js' import { ConnectionState, ITelegramClient, ServerUpdateHandler } from '../client.types.js' import { PeersIndex } from '../types/peers/peers-index.js' @@ -28,7 +32,6 @@ export abstract class TelegramWorkerPort imp readonly notifyLoggedOut readonly notifyChannelOpened readonly notifyChannelClosed - readonly call readonly importSession readonly exportSession readonly handleClientUpdate @@ -63,7 +66,6 @@ export abstract class TelegramWorkerPort imp this.notifyLoggedOut = bind('notifyLoggedOut') this.notifyChannelOpened = bind('notifyChannelOpened') this.notifyChannelClosed = bind('notifyChannelClosed') - this.call = bind('call') this.importSession = bind('importSession') this.exportSession = bind('exportSession') this.handleClientUpdate = bind('handleClientUpdate', true) @@ -77,6 +79,21 @@ export abstract class TelegramWorkerPort imp this.stopUpdatesLoop = bind('stopUpdatesLoop') } + call( + message: MustEqual, + params?: RpcCallOptions, + ): Promise { + if (params?.abortSignal) { + const { abortSignal, ...rest } = params + + return this._invoker.invokeWithAbort('client', 'call', [message, rest], abortSignal) as Promise< + tl.RpcCallReturn[T['_']] + > + } + + return this._invoker.invoke('client', 'call', [message, params]) as Promise + } + abstract connectToWorker(worker: SomeWorker, handler: ClientMessageHandler): [SendFn, () => void] private _serverUpdatesHandler: ServerUpdateHandler = () => {} diff --git a/packages/core/src/highlevel/worker/protocol.ts b/packages/core/src/highlevel/worker/protocol.ts index 144c371e..5837959a 100644 --- a/packages/core/src/highlevel/worker/protocol.ts +++ b/packages/core/src/highlevel/worker/protocol.ts @@ -5,14 +5,20 @@ import { tl } from '@mtcute/tl' import { ConnectionState } from '../client.types.js' import { SerializedError } from './errors.js' -export type WorkerInboundMessage = { - type: 'invoke' - id: number - target: 'custom' | 'client' | 'storage' | 'storage-self' | 'storage-peers' | 'app-config' - method: string - args: unknown[] - void: boolean -} +export type WorkerInboundMessage = + | { + type: 'invoke' + id: number + target: 'custom' | 'client' | 'storage' | 'storage-self' | 'storage-peers' | 'app-config' + method: string + args: unknown[] + void: boolean + withAbort: boolean + } + | { + type: 'abort' + id: number + } export type WorkerOutboundMessage = | { type: 'server_update'; update: tl.TypeUpdates } diff --git a/packages/core/src/highlevel/worker/worker.ts b/packages/core/src/highlevel/worker/worker.ts index b4bea663..404786d4 100644 --- a/packages/core/src/highlevel/worker/worker.ts +++ b/packages/core/src/highlevel/worker/worker.ts @@ -13,12 +13,22 @@ export abstract class TelegramWorker { abstract registerWorker(handler: WorkerMessageHandler): RespondFn + readonly pendingAborts = new Map() + constructor(readonly params: TelegramWorkerOptions) { this.broadcast = this.registerWorker((message, respond) => { switch (message.type) { case 'invoke': this.onInvoke(message, respond) break + case 'abort': { + const abort = this.pendingAborts.get(message.id) + + if (abort) { + abort.abort() + this.pendingAborts.delete(message.id) + } + } } }) @@ -116,9 +126,28 @@ export abstract class TelegramWorker { return } + let args = msg.args + + if (msg.target === 'client' && msg.method === 'call' && msg.withAbort) { + const abort = new AbortController() + this.pendingAborts.set(msg.id, abort) + + args = [ + args[0], + { + ...(args[1] as object), + abortSignal: abort.signal, + }, + ] + } + // eslint-disable-next-line @typescript-eslint/no-unsafe-call - Promise.resolve(method.apply(target, msg.args)) + Promise.resolve(method.apply(target, args)) .then((res) => { + if (msg.withAbort) { + this.pendingAborts.delete(msg.id) + } + if (msg.void) return respond({ @@ -128,6 +157,10 @@ export abstract class TelegramWorker { }) }) .catch((err) => { + if (msg.withAbort) { + this.pendingAborts.delete(msg.id) + } + respond({ type: 'result', id: msg.id,