From c5438a2f2978fee57dcf5987e027cfc2c5a6b527 Mon Sep 17 00:00:00 2001 From: alina sireneva Date: Sun, 7 Jul 2024 03:06:12 +0300 Subject: [PATCH] feat(core): outgoing request middlewares --- packages/core/src/highlevel/base.ts | 10 +- packages/core/src/network/index.ts | 9 +- .../core/src/network/middlewares/bundle.ts | 2 + .../core/src/network/middlewares/index.ts | 3 + .../core/src/network/middlewares/on-error.ts | 49 +++++ .../core/src/network/middlewares/on-method.ts | 54 ++++++ packages/core/src/network/network-manager.ts | 49 ++++- packages/core/src/utils/composer.test.ts | 179 ++++++++++++++++++ packages/core/src/utils/composer.ts | 30 +++ packages/core/src/utils/index.ts | 1 + 10 files changed, 377 insertions(+), 9 deletions(-) create mode 100644 packages/core/src/network/middlewares/bundle.ts create mode 100644 packages/core/src/network/middlewares/index.ts create mode 100644 packages/core/src/network/middlewares/on-error.ts create mode 100644 packages/core/src/network/middlewares/on-method.ts create mode 100644 packages/core/src/utils/composer.test.ts create mode 100644 packages/core/src/utils/composer.ts diff --git a/packages/core/src/highlevel/base.ts b/packages/core/src/highlevel/base.ts index c507195a..f6708d9a 100644 --- a/packages/core/src/highlevel/base.ts +++ b/packages/core/src/highlevel/base.ts @@ -11,7 +11,9 @@ import { asyncResettable, computeNewPasswordHash, computeSrpParams, + ICryptoProvider, isTlRpcError, + Logger, readStringSession, StringSessionData, writeStringSession, @@ -45,10 +47,10 @@ export class BaseTelegramClient implements ITelegramClient { private _serverUpdatesHandler: ServerUpdateHandler = () => {} private _connectionStateHandler: (state: ConnectionState) => void = () => {} - readonly log - readonly mt - readonly crypto - readonly storage + readonly log: Logger + readonly mt: MtClient + readonly crypto: ICryptoProvider + readonly storage: TelegramStorageManager constructor(readonly params: BaseTelegramClientOptions) { this.log = this.params.logger ?? new LogManager('client') diff --git a/packages/core/src/network/index.ts b/packages/core/src/network/index.ts index a1dba3f2..abd57885 100644 --- a/packages/core/src/network/index.ts +++ b/packages/core/src/network/index.ts @@ -1,5 +1,12 @@ export * from './client.js' -export type { ConnectionKind, NetworkManagerExtraParams, RpcCallOptions } from './network-manager.js' +export * from './middlewares/index.js' +export type { + ConnectionKind, + NetworkManagerExtraParams, + RpcCallMiddleware, + RpcCallMiddlewareContext, + RpcCallOptions, +} from './network-manager.js' export * from './reconnection.js' export * from './session-connection.js' export * from './transports/index.js' diff --git a/packages/core/src/network/middlewares/bundle.ts b/packages/core/src/network/middlewares/bundle.ts new file mode 100644 index 00000000..00f48163 --- /dev/null +++ b/packages/core/src/network/middlewares/bundle.ts @@ -0,0 +1,2 @@ +export * from './on-error.js' +export * from './on-method.js' diff --git a/packages/core/src/network/middlewares/index.ts b/packages/core/src/network/middlewares/index.ts new file mode 100644 index 00000000..0954b25f --- /dev/null +++ b/packages/core/src/network/middlewares/index.ts @@ -0,0 +1,3 @@ +import * as networkMiddlewares from './bundle.js' + +export { networkMiddlewares } diff --git a/packages/core/src/network/middlewares/on-error.ts b/packages/core/src/network/middlewares/on-error.ts new file mode 100644 index 00000000..07c5585f --- /dev/null +++ b/packages/core/src/network/middlewares/on-error.ts @@ -0,0 +1,49 @@ +import { mtp } from '@mtcute/tl' + +import { MaybePromise } from '../../types/utils.js' +import { isTlRpcError } from '../../utils/type-assertions.js' +import { RpcCallMiddleware, RpcCallMiddlewareContext } from '../network-manager.js' + +/** + * Middleware that will call `handler` whenever an RPC error happens, + * with the error object itself. + * + * The handler can either return nothing + * (in which case the original error will be thrown), a new error + * (via the `_: 'mt_rpc_error'` object), or any other value, which + * will be returned as the result of the RPC call. + * + * Note that the return value is **not type-checked** + * due to limitations of TypeScript. You'll probably want to use `satisfies` + * keyword to ensure the return value is correct, for example: + * + * ```ts + * networkMiddlewares.onRpcError(async (ctx, error) => { + * if (rpc.request._ === 'help.getNearestDc') { + * return { + * _: 'nearestDc', + * country: 'RU', + * thisDc: 2, + * nearestDc: 2, + * } satisfies tl.RpcCallReturn['help.getNearestDc'] + * } + * }) + * ``` + */ +export function onRpcError( + handler: (ctx: RpcCallMiddlewareContext, error: mtp.RawMt_rpc_error) => MaybePromise, +): RpcCallMiddleware { + return async (ctx, next) => { + let res = await next(ctx) + + if (isTlRpcError(res)) { + const handlerRes = await handler(ctx, res) + + if (handlerRes !== undefined) { + res = handlerRes + } + } + + return res + } +} diff --git a/packages/core/src/network/middlewares/on-method.ts b/packages/core/src/network/middlewares/on-method.ts new file mode 100644 index 00000000..b8430ad4 --- /dev/null +++ b/packages/core/src/network/middlewares/on-method.ts @@ -0,0 +1,54 @@ +import { tl } from '@mtcute/tl' + +import { Middleware } from '../../utils/composer.js' +import { RpcCallMiddleware, RpcCallMiddlewareContext } from '../network-manager.js' + +/** + * Middleware that will call `handler` whenever `method` RPC method is called. + * + * This helper exists due to TypeScript limitations not allowing us to + * properly type the return type without explicit type annotations, + * for a bit more type-safe and clean code: + * + * ```ts + * // before + * async (ctx, next) => { + * if (rpc.request._ === 'help.getNearestDc') { + * return { + * _: 'nearestDc', + * country: 'RU', + * thisDc: 2, + * nearestDc: 2, + * } satisfies tl.RpcCallReturn['help.getNearestDc'] + * } + * + * return next(ctx) + * } + * + * // after + * onMethod('help.getNearestDc', async () => ({ + * _: 'nearestDc' as const, // (otherwise ts will infer this as `string` and will complain) + * country: 'RU', + * thisDc: 2, + * nearestDc: 2, + * }) + * ``` + */ +export function onMethod( + method: T, + middleware: Middleware< + Omit & { + request: Extract + }, + tl.RpcCallReturn[T] + >, +): RpcCallMiddleware { + return async (ctx, next) => { + if (ctx.request._ !== method) { + return next(ctx) + } + + // eslint-disable-next-line + return middleware(ctx as any, next) + } +} diff --git a/packages/core/src/network/network-manager.ts b/packages/core/src/network/network-manager.ts index 59e3640e..33765100 100644 --- a/packages/core/src/network/network-manager.ts +++ b/packages/core/src/network/network-manager.ts @@ -4,6 +4,7 @@ import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime' import { getPlatform } from '../platform.js' import { StorageManager } from '../storage/storage.js' import { MtArgumentError, MtcuteError, MtTimeoutError, MtUnsupportedError } from '../types/index.js' +import { ComposedMiddleware, composeMiddlewares, Middleware } from '../utils/composer.js' import { ControllablePromise, createControllablePromise, @@ -113,6 +114,14 @@ export interface NetworkManagerExtraParams { * @default 60000 (60 seconds). */ inactivityTimeout?: number + + /** + * List of middlewares to use for the network manager + * + * > **Note**: these middlewares apply to **outgoing requests only**. + * > If you need to handle incoming updates, use a {@link Dispatcher} instead. + */ + middlewares?: Middleware[] } /** Options that can be customized when making an RPC call */ @@ -198,6 +207,13 @@ export interface RpcCallOptions { chainId?: string | number } +export interface RpcCallMiddlewareContext { + request: tl.RpcMethod + manager: NetworkManager + params?: RpcCallOptions +} +export type RpcCallMiddleware = Middleware + /** * Wrapper over all connection pools for a single DC. */ @@ -501,6 +517,8 @@ export class NetworkManager { this._connectionCount = params.connectionCount ?? defaultConnectionCountDelegate this._updateHandler = params.onUpdate + this.call = this._composeCall(params.middlewares) + this._onConfigChanged = this._onConfigChanged.bind(this) config.onReload(this._onConfigChanged) @@ -752,11 +770,34 @@ export class NetworkManager { await this._switchPrimaryDc(this._dcConnections.get(newDc)!) } - private _floodWaitedRequests = new Map() - async call( + readonly call: ( message: T, params?: RpcCallOptions, - ): Promise { + ) => Promise + + private _composeCall = (middlewares?: Middleware[]) => { + if (!middlewares?.length) { + return this._call + } + + const final: ComposedMiddleware = async (ctx) => { + return this._call(ctx.request, ctx.params) + } + const composed = composeMiddlewares(middlewares, final) + + return async (message: T, params?: RpcCallOptions): Promise => + composed({ + request: message, + manager: this, + params, + }) + } + + private _floodWaitedRequests = new Map() + private _call = async ( + message: T, + params?: RpcCallOptions, + ): Promise => { if (!this._primaryDc) { throw new MtcuteError('Not connected to any DC') } @@ -871,7 +912,7 @@ export class NetworkManager { if (params?.localMigrate) { manager = await this._getOtherDc(newDc) } else { - this._log.info('Migrate error, new dc = %d', newDc) + this._log.info('received %s, migrating to dc %d', err, newDc) await this.changePrimaryDc(newDc) manager = this._primaryDc! diff --git a/packages/core/src/utils/composer.test.ts b/packages/core/src/utils/composer.test.ts new file mode 100644 index 00000000..9fa79ad4 --- /dev/null +++ b/packages/core/src/utils/composer.test.ts @@ -0,0 +1,179 @@ +/* eslint-disable @typescript-eslint/require-await */ +import { describe, expect, it } from 'vitest' + +import { composeMiddlewares, Middleware } from './composer.js' + +describe('composeMiddlewares', () => { + it('should compose middlewares', async () => { + const trace: unknown[] = [] + const middlewares: Middleware[] = [ + async (ctx, next) => { + trace.push(ctx) + trace.push(1) + await next([...ctx, 1]) + trace.push(6) + }, + async (ctx, next) => { + trace.push(ctx) + trace.push(2) + await next([...ctx, 2]) + trace.push(5) + }, + async (ctx, next) => { + trace.push(ctx) + trace.push(3) + await next([...ctx, 3]) + trace.push(4) + }, + ] + + const composed = composeMiddlewares(middlewares, async (res) => { + result = res + }) + + let result: readonly number[] = [] + await composed([]) + + expect(trace).toEqual([[], 1, [1], 2, [1, 2], 3, 4, 5, 6]) + expect(result).toEqual([1, 2, 3]) + }) + + it('should handle multiple calls to final', async () => { + const trace: unknown[] = [] + + const middlewares: Middleware[] = [ + async (ctx, next) => { + trace.push(1) + await next([2]) + trace.push(3) + await next([4]) + trace.push(5) + }, + ] + + const composed = composeMiddlewares(middlewares, async (res) => { + trace.push(res) + }) + + await composed([]) + + expect(trace).toEqual([1, [2], 3, [4], 5]) + }) + + it('should handle multiple calls to next midway', async () => { + const trace: unknown[] = [] + + const middlewares: Middleware[] = [ + async (ctx, next) => { + trace.push(1) + await next([2]) + trace.push(3) + await next([4]) + trace.push(5) + }, + (ctx, next) => next([6, ...ctx]), + ] + + const composed = composeMiddlewares(middlewares, async (res) => { + trace.push(res) + }) + + await composed([]) + + expect(trace).toEqual([1, [6, 2], 3, [6, 4], 5]) + }) + + it('should handle leaf middleware', async () => { + const trace: unknown[] = [] + + const middlewares: Middleware[] = [ + async (ctx, next) => { + trace.push(1) + + return next(ctx) + }, + async () => { + /* do nothing */ + }, + ] + + const composed = composeMiddlewares(middlewares, async (res) => { + trace.push(res) // should not be called + }) + + await composed([]) + + expect(trace).toEqual([1]) + }) + + it('should propagate return value', async () => { + const trace: unknown[] = [] + + const middlewares: Middleware[] = [ + async (ctx, next) => { + trace.push(1) + const res = await next([2]) + trace.push(3) + const res2 = await next([3, 4, 5]) + trace.push(6) + + return res + res2 + }, + async (ctx, next) => { + trace.push(-1) + + return (await next(ctx)) + 1 + }, + ] + + const composed = composeMiddlewares(middlewares, async (res) => { + trace.push(res) + + return res.length + }) + + const result = await composed([]) + + expect(trace).toEqual([1, -1, [2], 3, -1, [3, 4, 5], 6]) + expect(result).toBe(6) + }) + + it('should propagate errors', async () => { + const trace: unknown[] = [] + + const middlewares: Middleware[] = [ + async (ctx, next) => { + trace.push(1) + + try { + await next(2) + } catch (e) { + trace.push('caught error') + } + + trace.push(3) + await next(4) + trace.push(5) + }, + (ctx, next) => next(ctx), // pass-thru + async (ctx, next) => { + if (ctx === 2) { + trace.push('error') + throw new Error('error') + } else { + trace.push('ok') + + return next(ctx) + } + }, + ] + + const composed = composeMiddlewares(middlewares, async (res) => { + trace.push(`final ${res}`) + }) + + await composed(0) + + expect(trace).toEqual([1, 'error', 'caught error', 3, 'ok', 'final 4', 5]) + }) +}) diff --git a/packages/core/src/utils/composer.ts b/packages/core/src/utils/composer.ts new file mode 100644 index 00000000..5c4037b9 --- /dev/null +++ b/packages/core/src/utils/composer.ts @@ -0,0 +1,30 @@ +export type Middleware = ( + ctx: Context, + next: (ctx: Context) => Promise, +) => Promise +export type ComposedMiddleware = (ctx: Context) => Promise + +export function composeMiddlewares( + middlewares: Middleware[], + final: ComposedMiddleware, +): ComposedMiddleware { + middlewares = middlewares.slice() + middlewares.push(final) + + function dispatch(i: number, ctx: Context): Promise { + const fn = middlewares[i] + if (!fn) return final(ctx) + + return fn(ctx, boundDispatches[i + 1]) + } + + const boundDispatches: Array<(ctx: Context) => Promise> = [] + + for (let i = 0; i < middlewares.length; i++) { + boundDispatches.push(dispatch.bind(null, i)) + } + + return function (context: Context): Promise { + return boundDispatches[0](context) + } +} diff --git a/packages/core/src/utils/index.ts b/packages/core/src/utils/index.ts index 2a9c6b56..f3b0c972 100644 --- a/packages/core/src/utils/index.ts +++ b/packages/core/src/utils/index.ts @@ -8,6 +8,7 @@ export * from '../storage/service/default-dcs.js' export * from './async-lock.js' export * from './bigint-utils.js' export * from './buffer-utils.js' +export * from './composer.js' export * from './condition-variable.js' export * from './controllable-promise.js' export * from './crypto/index.js'