feat: abort signals
This commit is contained in:
parent
31b41c93fc
commit
85ca3b4603
8 changed files with 75 additions and 41 deletions
|
@ -36,6 +36,15 @@ export function downloadToFile(this: TelegramClient, filename: string, params: F
|
|||
const output = fs.createWriteStream(filename)
|
||||
const stream = this.downloadAsStream(params)
|
||||
|
||||
if (params.abortSignal) {
|
||||
params.abortSignal.addEventListener('abort', () => {
|
||||
this.log.debug('aborting file download %s - cleaning up', filename)
|
||||
output.destroy()
|
||||
stream.destroy()
|
||||
fs!.rmSync(filename)
|
||||
})
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
stream.on('error', reject).pipe(output).on('finish', resolve).on('error', reject)
|
||||
})
|
||||
|
|
|
@ -113,7 +113,7 @@ export async function* downloadAsIterable(
|
|||
offset: chunkSize * chunk,
|
||||
limit: chunkSize,
|
||||
},
|
||||
{ dcId, kind: connectionKind },
|
||||
{ dcId, kind: connectionKind, abortSignal: params.abortSignal },
|
||||
)
|
||||
} catch (e: unknown) {
|
||||
if (!tl.RpcError.is(e)) throw e
|
||||
|
|
|
@ -105,4 +105,9 @@ export interface FileDownloadParameters {
|
|||
* @param total Total file size (`Infinity` if not available)
|
||||
*/
|
||||
progressCallback?: (downloaded: number, total: number) => void
|
||||
|
||||
/**
|
||||
* Abort signal that can be used to cancel the download.
|
||||
*/
|
||||
abortSignal?: AbortSignal
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ export interface PendingRpc {
|
|||
gzipOverhead?: number
|
||||
|
||||
sent?: boolean
|
||||
done?: boolean
|
||||
msgId?: Long
|
||||
seqNo?: number
|
||||
containerId?: Long
|
||||
|
|
|
@ -188,7 +188,12 @@ export class MultiSessionConnection extends EventEmitter {
|
|||
|
||||
private _nextConnection = 0
|
||||
|
||||
sendRpc<T extends tl.RpcMethod>(request: T, stack?: string, timeout?: number): Promise<tl.RpcCallReturn[T['_']]> {
|
||||
sendRpc<T extends tl.RpcMethod>(
|
||||
request: T,
|
||||
stack?: string,
|
||||
timeout?: number,
|
||||
abortSignal?: AbortSignal,
|
||||
): Promise<tl.RpcCallReturn[T['_']]> {
|
||||
// if (this.params.isMainConnection) {
|
||||
// find the least loaded connection
|
||||
let min = Infinity
|
||||
|
@ -204,7 +209,7 @@ export class MultiSessionConnection extends EventEmitter {
|
|||
}
|
||||
}
|
||||
|
||||
return this._connections[minIdx].sendRpc(request, stack, timeout)
|
||||
return this._connections[minIdx].sendRpc(request, stack, timeout, abortSignal)
|
||||
// }
|
||||
|
||||
// round-robin connections
|
||||
|
|
|
@ -139,6 +139,11 @@ export interface RpcCallOptions {
|
|||
* Overrides `dcId` if set.
|
||||
*/
|
||||
manager?: DcConnectionManager
|
||||
|
||||
/**
|
||||
* Abort signal for the call.
|
||||
*/
|
||||
abortSignal?: AbortSignal
|
||||
}
|
||||
|
||||
export class DcConnectionManager {
|
||||
|
@ -662,7 +667,7 @@ export class NetworkManager {
|
|||
|
||||
for (let i = 0; i < maxRetryCount; i++) {
|
||||
try {
|
||||
const res = await multi.sendRpc(message, stack, params?.timeout)
|
||||
const res = await multi.sendRpc(message, stack, params?.timeout, params?.abortSignal)
|
||||
|
||||
if (kind === 'main') {
|
||||
this._lastUpdateTime = Date.now()
|
||||
|
|
|
@ -9,7 +9,6 @@ import { gzipDeflate, gzipInflate } from '@mtcute/tl-runtime/src/platform/gzip'
|
|||
import { MtArgumentError, MtcuteError, MtTimeoutError } from '../types'
|
||||
import {
|
||||
ControllablePromise,
|
||||
createCancellablePromise,
|
||||
createControllablePromise,
|
||||
EarlyTimer,
|
||||
longFromBuffer,
|
||||
|
@ -691,6 +690,23 @@ export class SessionConnection extends PersistentConnection {
|
|||
return
|
||||
}
|
||||
|
||||
if (msg._ === 'cancel') {
|
||||
let result
|
||||
|
||||
try {
|
||||
result = message.object() as mtp.TlObject
|
||||
} catch (err) {
|
||||
this.log.debug('failed to parse rpc_result for cancel request %l, ignoring', reqMsgId)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
this.log.debug('received %s for cancelled request %l: %j', result._, reqMsgId)
|
||||
this._onMessageAcked(reqMsgId)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
this.log.error('received rpc_result for %s request %l', msg._, reqMsgId)
|
||||
|
||||
return
|
||||
|
@ -707,6 +723,8 @@ export class SessionConnection extends PersistentConnection {
|
|||
this._session.initConnectionCalled = true
|
||||
}
|
||||
|
||||
rpc.done = true
|
||||
|
||||
this.log.verbose('<<< (%s) %j', rpc.method, result)
|
||||
|
||||
if (result._ === 'mt_rpc_error') {
|
||||
|
@ -1165,7 +1183,12 @@ export class SessionConnection extends PersistentConnection {
|
|||
}
|
||||
}
|
||||
|
||||
sendRpc<T extends tl.RpcMethod>(request: T, stack?: string, timeout?: number): Promise<tl.RpcCallReturn[T['_']]> {
|
||||
sendRpc<T extends tl.RpcMethod>(
|
||||
request: T,
|
||||
stack?: string,
|
||||
timeout?: number,
|
||||
abortSignal?: AbortSignal,
|
||||
): Promise<tl.RpcCallReturn[T['_']]> {
|
||||
if (this._usable && this.params.inactivityTimeout) {
|
||||
this._rescheduleInactivity()
|
||||
}
|
||||
|
@ -1244,8 +1267,7 @@ export class SessionConnection extends PersistentConnection {
|
|||
|
||||
const pending: PendingRpc = {
|
||||
method,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
|
||||
promise: undefined as any, // because we need the object to make a promise
|
||||
promise: createControllablePromise(),
|
||||
data: content,
|
||||
stack,
|
||||
// we will need to know size of gzip_packed overhead in _flush()
|
||||
|
@ -1254,6 +1276,7 @@ export class SessionConnection extends PersistentConnection {
|
|||
|
||||
// setting them as well so jit can optimize stuff
|
||||
sent: undefined,
|
||||
done: undefined,
|
||||
getState: undefined,
|
||||
msgId: undefined,
|
||||
seqNo: undefined,
|
||||
|
@ -1263,19 +1286,28 @@ export class SessionConnection extends PersistentConnection {
|
|||
timeout: undefined,
|
||||
}
|
||||
|
||||
const promise = createCancellablePromise<any>(this._cancelRpc.bind(this, pending))
|
||||
pending.promise = promise
|
||||
if (abortSignal?.aborted) {
|
||||
pending.promise.reject(abortSignal.reason)
|
||||
|
||||
return pending.promise
|
||||
}
|
||||
|
||||
if (timeout) {
|
||||
pending.timeout = setTimeout(this._cancelRpc, timeout, pending, true)
|
||||
}
|
||||
|
||||
if (abortSignal) {
|
||||
abortSignal.addEventListener('abort', () => this._cancelRpc(pending, false, abortSignal))
|
||||
}
|
||||
|
||||
this._enqueueRpc(pending, true)
|
||||
|
||||
return promise
|
||||
return pending.promise
|
||||
}
|
||||
|
||||
private _cancelRpc(rpc: PendingRpc, onTimeout = false): void {
|
||||
private _cancelRpc(rpc: PendingRpc, onTimeout = false, abortSignal?: AbortSignal): void {
|
||||
if (rpc.done) return
|
||||
|
||||
if (rpc.cancelled && !onTimeout) {
|
||||
throw new MtcuteError('RPC was already cancelled')
|
||||
}
|
||||
|
@ -1286,13 +1318,15 @@ export class SessionConnection extends PersistentConnection {
|
|||
|
||||
if (onTimeout) {
|
||||
// todo: replace with MtTimeoutError
|
||||
const error = new tl.RpcError(-503, 'Timeout')
|
||||
const error = new tl.RpcError(400, 'Client timeout')
|
||||
|
||||
if (this.params.niceStacks !== false) {
|
||||
makeNiceStack(error, rpc.stack!, rpc.method)
|
||||
}
|
||||
|
||||
rpc.promise.reject(error)
|
||||
} else if (abortSignal) {
|
||||
rpc.promise.reject(abortSignal.reason)
|
||||
}
|
||||
|
||||
rpc.cancelled = true
|
||||
|
@ -1721,7 +1755,7 @@ export class SessionConnection extends PersistentConnection {
|
|||
const rootMsgId = new Long(result.readInt32LE(), result.readInt32LE(4))
|
||||
|
||||
this.log.debug(
|
||||
'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l',
|
||||
'sending %d messages: size = %db, acks = %d (msg_id = %s), ping = %s (msg_id = %s), state_req = %s (msg_id = %s), resend = %s (msg_id = %s), cancels = %s (msg_id = %s), rpc = %s, container = %s, root msg_id = %l',
|
||||
messageCount,
|
||||
packetSize,
|
||||
ackMsgIds?.length || 'false',
|
||||
|
@ -1731,6 +1765,8 @@ export class SessionConnection extends PersistentConnection {
|
|||
getStateMsgIds?.map((it) => it.toString()) || 'false',
|
||||
getStateMsgId,
|
||||
resendMsgIds?.map((it) => it.toString()) || 'false',
|
||||
cancelRpcs,
|
||||
cancelRpcs?.map((it) => it.toString()) || 'false',
|
||||
resendMsgId,
|
||||
rpcToSend.map((it) => it.method),
|
||||
useContainer,
|
||||
|
|
|
@ -6,13 +6,6 @@ export type ControllablePromise<T = unknown> = Promise<T> & {
|
|||
reject(err?: unknown): void
|
||||
}
|
||||
|
||||
/**
|
||||
* A promise that can be cancelled.
|
||||
*/
|
||||
export type CancellablePromise<T = unknown> = Promise<T> & {
|
||||
cancel(): void
|
||||
}
|
||||
|
||||
/**
|
||||
* The promise was cancelled
|
||||
*/
|
||||
|
@ -35,23 +28,3 @@ export function createControllablePromise<T = unknown>(): ControllablePromise<T>
|
|||
|
||||
return promise as ControllablePromise<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a promise that can be cancelled.
|
||||
*
|
||||
* @param onCancel Callback to call when cancellation is requested
|
||||
*/
|
||||
export function createCancellablePromise<T = unknown>(
|
||||
onCancel: () => void,
|
||||
): ControllablePromise<T> & CancellablePromise<T> {
|
||||
// todo rethink this in MTQ-20
|
||||
|
||||
const promise = createControllablePromise()
|
||||
|
||||
;(promise as unknown as CancellablePromise<T>).cancel = () => {
|
||||
promise.reject(new PromiseCancelledError())
|
||||
onCancel()
|
||||
}
|
||||
|
||||
return promise as ControllablePromise<T> & CancellablePromise<T>
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue