diff --git a/packages/dispatcher/scripts/generate.cjs b/packages/dispatcher/scripts/generate.cjs index 5f6acdc5..e2620e17 100644 --- a/packages/dispatcher/scripts/generate.cjs +++ b/packages/dispatcher/scripts/generate.cjs @@ -1,5 +1,4 @@ -/* eslint-disable no-restricted-globals */ -const { types, toSentence, replaceSections, formatFile } = require('../../client/scripts/generate-updates') +const { types, toSentence, replaceSections, formatFile } = require('../../client/scripts/generate-updates.cjs') function generateHandler() { const lines = [] @@ -41,7 +40,7 @@ function generateDispatcher() { * @param group Handler group index */ on${type.handlerTypeName}(handler: ${type.handlerTypeName}Handler${ - type.state ? `<${type.context}, State extends never ? never : UpdateState>` : '' + type.state ? `<${type.context}, State extends never ? never : UpdateState>` : '' }['callback'], group?: number): void ${ @@ -58,7 +57,7 @@ ${ filter: UpdateFilter<${type.context}, Mod, State>, handler: ${type.handlerTypeName}Handler, State extends never ? never : UpdateState>['callback'], +}, Mod>, State extends never ? never : UpdateState>['callback'], group?: number ): void ` : @@ -75,7 +74,7 @@ ${ on${type.handlerTypeName}( filter: UpdateFilter<${type.context}, Mod>, handler: ${type.handlerTypeName}Handler${ - type.state ? ', State extends never ? never : UpdateState' : '' + type.state ? ', State extends never ? never : UpdateState' : '' }>['callback'], group?: number ): void diff --git a/packages/dispatcher/src/dispatcher.ts b/packages/dispatcher/src/dispatcher.ts index 3dffde34..0d35b515 100644 --- a/packages/dispatcher/src/dispatcher.ts +++ b/packages/dispatcher/src/dispatcher.ts @@ -60,22 +60,49 @@ import { } from './handler.js' // end-codegen-imports import { PropagationAction } from './propagation.js' -import { defaultStateKeyDelegate, IStateStorage, StateKeyDelegate, UpdateState } from './state/index.js' +import { + defaultStateKeyDelegate, + isCompatibleStorage, + IStateStorage, + StateKeyDelegate, + UpdateState, +} from './state/index.js' + +export interface DispatcherParams { + /** + * If this dispatcher can be used as a scene, its unique name. + * + * Should not be set manually, use {@link Dispatcher#scene} instead + */ + sceneName?: string + + /** + * Custom storage for the dispatcher. + * + * @default Client's storage + */ + storage?: IStateStorage + + /** + * Custom key delegate for the dispatcher. + */ + key?: StateKeyDelegate +} /** * Updates dispatcher */ -export class Dispatcher { +export class Dispatcher { private _groups: Record> = {} private _groupsOrder: number[] = [] private _client?: TelegramClient private _parent?: Dispatcher - private _children: Dispatcher[] = [] + private _children: Dispatcher[] = [] - private _scenes?: Record> - private _scene?: SceneName + private _scenes?: Record> + private _scene?: string private _sceneScoped?: boolean private _storage?: State extends never ? undefined : IStateStorage @@ -87,67 +114,96 @@ export class Dispatcher { private _errorHandler?: ( err: Error, update: ParsedUpdate & T, - state?: UpdateState, + state?: UpdateState, ) => MaybeAsync private _preUpdateHandler?: ( update: ParsedUpdate & T, - state?: UpdateState, + state?: UpdateState, ) => MaybeAsync private _postUpdateHandler?: ( handled: boolean, update: ParsedUpdate & T, - state?: UpdateState, + state?: UpdateState, ) => MaybeAsync - /** - * Create a new dispatcher, that will be used as a child, - * optionally providing a custom key delegate - */ - constructor(key?: StateKeyDelegate) - /** - * Create a new dispatcher, that will be used as a child, optionally - * providing custom storage and key delegate - */ - constructor(storage: IStateStorage, key?: StateKeyDelegate) - /** - * Create a new dispatcher and bind it to client and optionally - * FSM storage - */ - constructor( - client: TelegramClient, - ...args: (() => State) extends () => never ? [] : [IStateStorage, StateKeyDelegate?] - ) - constructor( - client?: TelegramClient | IStateStorage | StateKeyDelegate, - storage?: IStateStorage | StateKeyDelegate, - key?: StateKeyDelegate, - ) { + protected constructor(client?: TelegramClient, params?: DispatcherParams) { this.dispatchRawUpdate = this.dispatchRawUpdate.bind(this) this.dispatchUpdate = this.dispatchUpdate.bind(this) + // eslint-disable-next-line prefer-const + let { storage, key, sceneName } = params ?? {} + if (client) { - if (client instanceof TelegramClient) { - this.bindToClient(client) + this.bindToClient(client) - if (storage) { - this._storage = storage as any - this._stateKeyDelegate = (key ?? defaultStateKeyDelegate) as any - } - } else if (typeof client === 'function') { - // is StateKeyDelegate - this._customStateKeyDelegate = client as any - } else { - this._customStorage = client as any + if (!storage) { + const _storage = client.storage - if (storage) { - this._customStateKeyDelegate = client as any + if (!isCompatibleStorage(_storage)) { + throw new MtArgumentError( + 'Storage used by the client is not compatible with the dispatcher. Please provide a compatible storage manually', + ) } + + storage = _storage + } + + if (storage) { + this._storage = storage as any + this._stateKeyDelegate = (key ?? defaultStateKeyDelegate) as any + } + } else { + // child dispatcher without client + + if (storage) { + this._customStorage = storage as any + } + + if (key) { + this._customStateKeyDelegate = key as any + } + + if (sceneName) { + if (sceneName[0] === '$') { + throw new MtArgumentError('Scene name cannot start with $') + } + + this._scene = sceneName } } } + /** + * Create a new dispatcher and bind it to the client. + */ + static for(client: TelegramClient, params?: DispatcherParams): Dispatcher { + return new Dispatcher(client, params) + } + + /** + * Create a new child dispatcher. + */ + static child(params?: DispatcherParams): Dispatcher { + return new Dispatcher(undefined, params) + } + + /** + * Create a new scene dispatcher + */ + static scene>( + name: string, + params?: Omit, + ): Dispatcher { + return new Dispatcher(undefined, { sceneName: name, ...params }) + } + + /** For scene dispatchers, name of the scene */ + get sceneName(): string | undefined { + return this._scene + } + /** * Bind the dispatcher to the client. * Called by the constructor automatically if @@ -289,7 +345,7 @@ export class Dispatcher { private async _dispatchUpdateNowImpl( update: ParsedUpdate, // this is getting a bit crazy lol - parsedState?: UpdateState | null, + parsedState?: UpdateState | null, parsedScene?: string | null, forceScene?: true, parsedContext?: UpdateContextType, @@ -525,9 +581,7 @@ export class Dispatcher { * @param handler Error handler */ onError( - handler: - | ((err: Error, update: ParsedUpdate & T, state?: UpdateState) => MaybeAsync) - | null, + handler: ((err: Error, update: ParsedUpdate & T, state?: UpdateState) => MaybeAsync) | null, ): void { if (handler) this._errorHandler = handler else this._errorHandler = undefined @@ -547,10 +601,7 @@ export class Dispatcher { */ onPreUpdate( handler: - | (( - update: ParsedUpdate & T, - state?: UpdateState, - ) => MaybeAsync) + | ((update: ParsedUpdate & T, state?: UpdateState) => MaybeAsync) | null, ): void { if (handler) this._preUpdateHandler = handler @@ -570,9 +621,7 @@ export class Dispatcher { * @param handler Pre-update middleware */ onPostUpdate( - handler: - | ((handled: boolean, update: ParsedUpdate & T, state?: UpdateState) => MaybeAsync) - | null, + handler: ((handled: boolean, update: ParsedUpdate & T, state?: UpdateState) => MaybeAsync) | null, ): void { if (handler) this._postUpdateHandler = handler else this._postUpdateHandler = undefined @@ -582,17 +631,13 @@ export class Dispatcher { * Set error handler that will propagate * the error to the parent dispatcher */ - propagateErrorToParent( - err: Error, - update: ParsedUpdate, - state?: UpdateState, - ): MaybeAsync { + propagateErrorToParent(err: Error, update: ParsedUpdate, state?: UpdateState): MaybeAsync { if (!this.parent) { throw new MtArgumentError('This dispatcher is not a child') } if (this.parent._errorHandler) { - return this.parent._errorHandler(err, update, state) + return this.parent._errorHandler(err, update, state as any) } throw err } @@ -607,7 +652,7 @@ export class Dispatcher { return this._parent ?? null } - private _prepareChild(child: Dispatcher): void { + private _prepareChild(child: Dispatcher): void { if (child._client) { throw new MtArgumentError( 'Provided dispatcher is ' + @@ -638,7 +683,7 @@ export class Dispatcher { * * @param child Other dispatcher */ - addChild(child: Dispatcher): void { + addChild(child: Dispatcher): void { if (this._children.includes(child)) return this._prepareChild(child) @@ -658,7 +703,7 @@ export class Dispatcher { * @param scene Dispatcher representing the scene * @param scoped Whether to use scoped FSM storage for the scene */ - addScene(uid: SceneName, scene: Dispatcher, scoped: false): void + addScene(scene: Dispatcher, scoped: false): void /** * Add a dispatcher as a scene with a scoped state * @@ -672,26 +717,23 @@ export class Dispatcher { * @param scene Dispatcher representing the scene * @param scoped Whether to use scoped FSM storage for the scene (defaults to `true`) */ - addScene(uid: SceneName, scene: Dispatcher, scoped?: true): void - addScene(uid: SceneName, scene: Dispatcher, scoped = true): void { + addScene(scene: Dispatcher, scoped?: true): void + addScene(scene: Dispatcher, scoped = true): void { if (!this._scenes) this._scenes = {} - if (uid in this._scenes) { - throw new MtArgumentError(`Scene with UID ${uid} is already registered!`) + if (!scene._scene) { + throw new MtArgumentError( + 'Non-scene dispatcher passed to addScene. Use `Dispatcher.scene()` to create one.', + ) } - if (uid[0] === '$') { - throw new MtArgumentError('Scene UID cannot start with $') - } - - if (scene._scene) { - throw new MtArgumentError(`This dispatcher is already registered as scene ${scene._scene}`) + if (scene._scene in this._scenes) { + throw new MtArgumentError(`Scene with name ${scene._scene} is already registered!`) } this._prepareChild(scene) - scene._scene = uid scene._sceneScoped = scoped - this._scenes[uid] = scene + this._scenes[scene._scene] = scene } /** @@ -705,7 +747,7 @@ export class Dispatcher { * * @param child Other dispatcher */ - removeChild(child: Dispatcher): void { + removeChild(child: Dispatcher): void { const idx = this._children.indexOf(child) if (idx > -1) { @@ -732,7 +774,7 @@ export class Dispatcher { * * @param other Other dispatcher */ - extend(other: Dispatcher): void { + extend(other: Dispatcher): void { if (other._customStorage || other._customStateKeyDelegate) { throw new MtArgumentError('Provided dispatcher has custom storage and cannot be extended from.') } @@ -772,7 +814,7 @@ export class Dispatcher { delete myScenes[key] } - this.addScene(key as any, myScenes[key] as any, myScenes[key]._sceneScoped as any) + this.addScene(myScenes[key] as any, myScenes[key]._sceneScoped as any) }) } @@ -791,8 +833,8 @@ export class Dispatcher { * * @param children Whether to also clone children and scenes */ - clone(children = false): Dispatcher { - const dp = new Dispatcher() + clone(children = false): Dispatcher { + const dp = new Dispatcher() // copy handlers. Object.keys(this._groups).forEach((key) => { @@ -819,12 +861,7 @@ export class Dispatcher { if (this._scenes) { Object.keys(this._scenes).forEach((key) => { const scene = this._scenes![key].clone(true) - dp.addScene( - key as any, - scene as any, - - this._scenes![key]._sceneScoped as any, - ) + dp.addScene(scene as any, this._scenes![key]._sceneScoped as any) }) } } @@ -841,7 +878,7 @@ export class Dispatcher { * @param key State storage key * @template S State type, defaults to dispatcher's state type. Only checked at compile-time */ - getState(key: string): UpdateState + getState(key: string): UpdateState /** * Get update state object for the given object. @@ -853,8 +890,8 @@ export class Dispatcher { * @param object Object for which the state should be fetched * @template S State type, defaults to dispatcher's state type. Only checked at compile-time */ - getState(object: Parameters[0]): Promise> - getState(object: string | Parameters[0]): MaybeAsync> { + getState(object: Parameters[0]): Promise> + getState(object: string | Parameters[0]): MaybeAsync> { if (!this._storage) { throw new MtArgumentError('Cannot use getUpdateState() filter without state storage') } @@ -895,7 +932,7 @@ export class Dispatcher { * This will load the state for the given object * ignoring local custom storage, key delegate and scene scope. */ - getGlobalState(object: Parameters[0]): Promise> { + getGlobalState(object: Parameters[0]): Promise> { if (!this._parent) { throw new MtArgumentError('This dispatcher does not have a parent') } @@ -963,10 +1000,7 @@ export class Dispatcher { * @param group Handler group index */ onNewMessage( - handler: NewMessageHandler< - MessageContext, - State extends never ? never : UpdateState - >['callback'], + handler: NewMessageHandler>['callback'], group?: number, ): void @@ -981,7 +1015,7 @@ export class Dispatcher { filter: UpdateFilter, handler: NewMessageHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -997,7 +1031,7 @@ export class Dispatcher { filter: UpdateFilter, handler: NewMessageHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1014,10 +1048,7 @@ export class Dispatcher { * @param group Handler group index */ onEditMessage( - handler: EditMessageHandler< - MessageContext, - State extends never ? never : UpdateState - >['callback'], + handler: EditMessageHandler>['callback'], group?: number, ): void @@ -1032,7 +1063,7 @@ export class Dispatcher { filter: UpdateFilter, handler: EditMessageHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1048,7 +1079,7 @@ export class Dispatcher { filter: UpdateFilter, handler: EditMessageHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1065,10 +1096,7 @@ export class Dispatcher { * @param group Handler group index */ onMessageGroup( - handler: MessageGroupHandler< - MessageContext, - State extends never ? never : UpdateState - >['callback'], + handler: MessageGroupHandler>['callback'], group?: number, ): void @@ -1083,7 +1111,7 @@ export class Dispatcher { filter: UpdateFilter, handler: MessageGroupHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1099,7 +1127,7 @@ export class Dispatcher { filter: UpdateFilter, handler: MessageGroupHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1222,7 +1250,7 @@ export class Dispatcher { onCallbackQuery( handler: CallbackQueryHandler< CallbackQueryContext, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1238,7 +1266,7 @@ export class Dispatcher { filter: UpdateFilter, handler: CallbackQueryHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void @@ -1254,7 +1282,7 @@ export class Dispatcher { filter: UpdateFilter, handler: CallbackQueryHandler< filters.Modify, - State extends never ? never : UpdateState + State extends never ? never : UpdateState >['callback'], group?: number, ): void diff --git a/packages/dispatcher/src/filters/group.ts b/packages/dispatcher/src/filters/group.ts index b6df1306..6ad76ce2 100644 --- a/packages/dispatcher/src/filters/group.ts +++ b/packages/dispatcher/src/filters/group.ts @@ -12,7 +12,7 @@ import { Modify, UpdateFilter } from './types.js' * @param filter * @returns */ -export function every( +export function every( filter: UpdateFilter, ): UpdateFilter< MessageContext, @@ -57,8 +57,11 @@ export function every( * @param filter * @returns */ -// eslint-disable-next-line -export function some(filter: UpdateFilter): UpdateFilter { +export function some( + // eslint-disable-next-line + filter: UpdateFilter, + // eslint-disable-next-line +): UpdateFilter { return (ctx, state) => { let i = 0 const upds = ctx.messages diff --git a/packages/dispatcher/src/filters/logic.ts b/packages/dispatcher/src/filters/logic.ts index 982091ec..dc8c63af 100644 --- a/packages/dispatcher/src/filters/logic.ts +++ b/packages/dispatcher/src/filters/logic.ts @@ -22,7 +22,7 @@ export const any: UpdateFilter = () => true * * @param fn Filter to negate */ -export function not( +export function not( fn: UpdateFilter, ): UpdateFilter, State> { return (upd, state) => { @@ -37,16 +37,39 @@ export function not( // i couldn't come up with proper types for these 😭 // if you know how to do this better - PRs are welcome! -export function and( +export function and( fn1: UpdateFilter, fn2: UpdateFilter, ): UpdateFilter -export function and( +export function and< + Base1, + Mod1, + State1 extends object, + Base2, + Mod2, + State2 extends object, + Base3, + Mod3, + State3 extends object, +>( fn1: UpdateFilter, fn2: UpdateFilter, fn3: UpdateFilter, ): UpdateFilter -export function and( +export function and< + Base1, + Mod1, + State1 extends object, + Base2, + Mod2, + State2 extends object, + Base3, + Mod3, + State3 extends object, + Base4, + Mod4, + State4 extends object, +>( fn1: UpdateFilter, fn2: UpdateFilter, fn3: UpdateFilter, @@ -55,19 +78,19 @@ export function and( fn1: UpdateFilter, fn2: UpdateFilter, @@ -82,22 +105,22 @@ export function and< export function and< Base1, Mod1, - State1, + State1 extends object, Base2, Mod2, - State2, + State2 extends object, Base3, Mod3, - State3, + State3 extends object, Base4, Mod4, - State4, + State4 extends object, Base5, Mod5, - State5, + State5 extends object, Base6, Mod6, - State6, + State6 extends object, >( fn1: UpdateFilter, fn2: UpdateFilter, @@ -159,18 +182,41 @@ export function and(...fns: UpdateFilter[]): UpdateFilter( +export function or( fn1: UpdateFilter, fn2: UpdateFilter, ): UpdateFilter -export function or( +export function or< + Base1, + Mod1, + State1 extends object, + Base2, + Mod2, + State2 extends object, + Base3, + Mod3, + State3 extends object, +>( fn1: UpdateFilter, fn2: UpdateFilter, fn3: UpdateFilter, ): UpdateFilter -export function or( +export function or< + Base1, + Mod1, + State1 extends object, + Base2, + Mod2, + State2 extends object, + Base3, + Mod3, + State3 extends object, + Base4, + Mod4, + State4 extends object, +>( fn1: UpdateFilter, fn2: UpdateFilter, fn3: UpdateFilter, @@ -180,19 +226,19 @@ export function or( fn1: UpdateFilter, fn2: UpdateFilter, @@ -208,22 +254,22 @@ export function or< export function or< Base1, Mod1, - State1, + State1 extends object, Base2, Mod2, - State2, + State2 extends object, Base3, Mod3, - State3, + State3 extends object, Base4, Mod4, - State4, + State4 extends object, Base5, Mod5, - State5, + State5 extends object, Base6, Mod6, - State6, + State6 extends object, >( fn1: UpdateFilter, fn2: UpdateFilter, diff --git a/packages/dispatcher/src/filters/message.ts b/packages/dispatcher/src/filters/message.ts index 289974bb..3cb56889 100644 --- a/packages/dispatcher/src/filters/message.ts +++ b/packages/dispatcher/src/filters/message.ts @@ -15,6 +15,7 @@ import { Video, } from '@mtcute/client' +import { MessageContext } from '../index.js' import { Modify, UpdateFilter } from './types.js' /** @@ -208,3 +209,25 @@ export const sender = ): UpdateFilter }> => (msg) => msg.sender.type === type + +/** + * Filter that matches messages that are replies to some other message. + * + * Optionally, you can pass a filter that will be applied to the replied message. + */ +export const replyTo = + ( + filter?: UpdateFilter, + ): UpdateFilter Promise }, State> => + async (msg, state) => { + if (!msg.replyToMessageId) return false + + const reply = await msg.getReplyTo() + if (!reply) return false + + msg.getReplyTo = () => Promise.resolve(reply) + + if (!filter) return true + + return filter(reply, state) + } diff --git a/packages/dispatcher/src/filters/state.ts b/packages/dispatcher/src/filters/state.ts index 2d0a0bdf..8892eea5 100644 --- a/packages/dispatcher/src/filters/state.ts +++ b/packages/dispatcher/src/filters/state.ts @@ -20,7 +20,7 @@ export const stateEmpty: UpdateFilter = async (upd, state) => { * * @param predicate State predicate */ -export const state = ( +export const state = ( predicate: (state: T) => MaybeAsync, // eslint-disable-next-line @typescript-eslint/ban-types ): UpdateFilter => { diff --git a/packages/dispatcher/src/filters/types.ts b/packages/dispatcher/src/filters/types.ts index 1083862d..7832bdf7 100644 --- a/packages/dispatcher/src/filters/types.ts +++ b/packages/dispatcher/src/filters/types.ts @@ -75,7 +75,7 @@ import { UpdateState } from '../state/update-state.js' */ // we need the second parameter because it carries meta information // eslint-disable-next-line @typescript-eslint/no-unused-vars -export type UpdateFilter = ( +export type UpdateFilter = ( update: Base, state?: UpdateState, ) => MaybeAsync diff --git a/packages/dispatcher/src/state/key.ts b/packages/dispatcher/src/state/key.ts index 3d888015..fab9f302 100644 --- a/packages/dispatcher/src/state/key.ts +++ b/packages/dispatcher/src/state/key.ts @@ -1,4 +1,4 @@ -import { assertNever, MaybeAsync } from '@mtcute/client' +import { assertNever, Chat, MaybeAsync, User } from '@mtcute/client' import { CallbackQueryContext, MessageContext } from '../context/index.js' @@ -10,7 +10,7 @@ import { CallbackQueryContext, MessageContext } from '../context/index.js' * @param msg Message or callback from which to derive the key * @param scene Current scene UID, or `null` if none */ -export type StateKeyDelegate = (upd: MessageContext | CallbackQueryContext) => MaybeAsync +export type StateKeyDelegate = (upd: MessageContext | CallbackQueryContext | User | Chat) => MaybeAsync /** * Default state key delegate. @@ -24,6 +24,11 @@ export type StateKeyDelegate = (upd: MessageContext | CallbackQueryContext) => M * - If in group/channel/supergroup (i.e. `upd.chatType !== 'user'`), `upd.chatId + '_' + upd.user.id` */ export const defaultStateKeyDelegate: StateKeyDelegate = (upd): string | null => { + if ('type' in upd) { + // User | Chat + return String(upd.id) + } + if (upd._name === 'new_message') { switch (upd.chat.chatType) { case 'private': diff --git a/packages/dispatcher/src/state/storage.ts b/packages/dispatcher/src/state/storage.ts index 4fac20aa..017b611b 100644 --- a/packages/dispatcher/src/state/storage.ts +++ b/packages/dispatcher/src/state/storage.ts @@ -80,3 +80,18 @@ export interface IStateStorage { */ resetRateLimit(key: string): MaybeAsync } + +export function isCompatibleStorage(storage: unknown): storage is IStateStorage { + return ( + typeof storage === 'object' && + storage !== null && + 'getState' in storage && + 'setState' in storage && + 'deleteState' in storage && + 'getCurrentScene' in storage && + 'setCurrentScene' in storage && + 'deleteCurrentScene' in storage && + 'getRateLimit' in storage && + 'resetRateLimit' in storage + ) +} diff --git a/packages/dispatcher/src/state/update-state.ts b/packages/dispatcher/src/state/update-state.ts index 0fc560db..582dd23d 100644 --- a/packages/dispatcher/src/state/update-state.ts +++ b/packages/dispatcher/src/state/update-state.ts @@ -1,6 +1,9 @@ +/* eslint-disable dot-notation */ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { MtArgumentError, MtcuteError } from '@mtcute/client' import { sleep } from '@mtcute/client/utils.js' +import type { Dispatcher } from '../dispatcher.js' import { IStateStorage } from './storage.js' /** @@ -18,13 +21,13 @@ export class RateLimitError extends MtcuteError { * @template State Type that represents the state * @template SceneName Possible scene names */ -export class UpdateState { +export class UpdateState { private _key: string private _localKey!: string private _storage: IStateStorage - private _scene: SceneName | null + private _scene: string | null private _scoped?: boolean private _cached?: State | null @@ -34,7 +37,7 @@ export class UpdateState { constructor( storage: IStateStorage, key: string, - scene: SceneName | null, + scene: string | null, scoped?: boolean, customStorage?: IStateStorage, customKey?: string, @@ -50,7 +53,8 @@ export class UpdateState { this._updateLocalKey() } - get scene(): SceneName | null { + /** Name of the current scene */ + get scene(): string | null { return this._scene } @@ -68,7 +72,7 @@ export class UpdateState { * @param fallback Default state value * @param force Whether to ignore cached state (def. `false`) */ - async get(fallback: State, force?: boolean): Promise + async get(fallback: State | (() => State), force?: boolean): Promise /** * Retrieve the state from the storage, falling back to default @@ -77,27 +81,32 @@ export class UpdateState { * @param fallback Default state value * @param force Whether to ignore cached state (def. `false`) */ - async get(fallback?: State, force?: boolean): Promise + async get(fallback?: State | (() => State), force?: boolean): Promise /** * Retrieve the state from the storage * * @param force Whether to ignore cached state (def. `false`) */ async get(force?: boolean): Promise - async get(fallback?: State | boolean, force?: boolean): Promise { + async get(fallback?: State | (() => State) | boolean, force?: boolean): Promise { if (typeof fallback === 'boolean') { force = fallback fallback = undefined } if (!force && this._cached !== undefined) { - if (!this._cached && fallback) return fallback + if (!this._cached && fallback) { + return typeof fallback === 'function' ? fallback() : fallback + } return this._cached } let res = (await this._localStorage.getState(this._localKey)) as State | null - if (!res && fallback) res = fallback + + if (!res && fallback) { + res = typeof fallback === 'function' ? fallback() : fallback + } this._cached = res return res @@ -128,7 +137,16 @@ export class UpdateState { * @param ttl TTL for the new state (in seconds) * @param forceLoad Whether to force load the old state from storage */ - async merge(state: Partial, fallback?: State, ttl?: number, forceLoad = false): Promise { + async merge( + state: Partial, + params: { + fallback?: State | (() => State) + ttl?: number + forceLoad?: boolean + } = {}, + ): Promise { + const { fallback, ttl, forceLoad } = params + const old = await this.get(forceLoad) if (!old) { @@ -136,7 +154,9 @@ export class UpdateState { throw new MtArgumentError('Cannot use merge on empty state without fallback.') } - await this.set({ ...fallback, ...state }, ttl) + const fallback_ = typeof fallback === 'function' ? fallback() : fallback + + await this.set({ ...fallback_, ...state }, ttl) } else { await this.set({ ...old, ...state }, ttl) } @@ -156,14 +176,43 @@ export class UpdateState { /** * Enter some scene - * - * @param scene Scene name - * @param ttl TTL for the scene (in seconds) */ - async enter(scene: SceneName, ttl?: number): Promise { - this._scene = scene + async enter>( + scene: Scene, + params?: { + /** + * Initial state for the scene + * + * Note that this will only work if the scene uses the same key delegate as this state. + */ + with?: SceneState + + /** TTL for the scene (in seconds) */ + ttl?: number + }, + ): Promise { + const { with: with_, ttl } = params ?? {} + + if (!scene['_scene']) { + throw new MtArgumentError('Cannot enter a non-scene Dispatcher') + } + + if (!scene['_parent']) { + throw new MtArgumentError('This scene has not been registered') + } + + this._scene = scene['_scene'] this._updateLocalKey() - await this._storage.setCurrentScene(this._key, scene, ttl) + + await this._storage.setCurrentScene(this._key, this._scene, ttl) + + if (with_) { + if (scene['_customStateKeyDelegate']) { + throw new MtArgumentError('Cannot use `with` parameter when the scene uses a custom state key delegate') + } + + await scene.getState(this._key).set(with_, ttl) + } } /** diff --git a/packages/dispatcher/src/wizard.ts b/packages/dispatcher/src/wizard.ts index 6957b62b..bf7e6240 100644 --- a/packages/dispatcher/src/wizard.ts +++ b/packages/dispatcher/src/wizard.ts @@ -32,10 +32,7 @@ interface WizardInternalState { * that can be used to simplify implementing * step-by-step scenes. */ -export class WizardScene extends Dispatcher< - State & WizardInternalState, - SceneName -> { +export class WizardScene extends Dispatcher { private _steps = 0 private _defaultState: State & WizardInternalState = {} as State & WizardInternalState @@ -54,7 +51,7 @@ export class WizardScene extends Dispa /** * Go to the Nth step */ - async goToStep(state: UpdateState, step: number) { + async goToStep(state: UpdateState, step: number) { if (step >= this._steps) { await state.exit() } else { @@ -65,7 +62,7 @@ export class WizardScene extends Dispa /** * Skip N steps */ - async skip(state: UpdateState, count = 1) { + async skip(state: UpdateState, count = 1) { const { $step } = (await state.get()) || {} if ($step === undefined) throw new Error('Wizard state is not initialized') @@ -96,7 +93,7 @@ export class WizardScene extends Dispa addStep( handler: ( msg: MessageContext, - state: UpdateState, + state: UpdateState, ) => MaybeAsync, ): void { const step = this._steps++