refactor(dispatcher): improved surface api

This commit is contained in:
alina 🌸 2023-10-26 22:22:40 +03:00
parent c061581abb
commit 022481966b
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
11 changed files with 342 additions and 177 deletions

View file

@ -1,5 +1,4 @@
/* eslint-disable no-restricted-globals */
const { types, toSentence, replaceSections, formatFile } = require('../../client/scripts/generate-updates')
const { types, toSentence, replaceSections, formatFile } = require('../../client/scripts/generate-updates.cjs')
function generateHandler() {
const lines = []
@ -41,7 +40,7 @@ function generateDispatcher() {
* @param group Handler group index
*/
on${type.handlerTypeName}(handler: ${type.handlerTypeName}Handler${
type.state ? `<${type.context}, State extends never ? never : UpdateState<State, SceneName>>` : ''
type.state ? `<${type.context}, State extends never ? never : UpdateState<State>>` : ''
}['callback'], group?: number): void
${
@ -58,7 +57,7 @@ ${
filter: UpdateFilter<${type.context}, Mod, State>,
handler: ${type.handlerTypeName}Handler<filters.Modify<${
type.context
}, Mod>, State extends never ? never : UpdateState<State, SceneName>>['callback'],
}, Mod>, State extends never ? never : UpdateState<State>>['callback'],
group?: number
): void
` :
@ -75,7 +74,7 @@ ${
on${type.handlerTypeName}<Mod>(
filter: UpdateFilter<${type.context}, Mod>,
handler: ${type.handlerTypeName}Handler<filters.Modify<${type.context}, Mod>${
type.state ? ', State extends never ? never : UpdateState<State, SceneName>' : ''
type.state ? ', State extends never ? never : UpdateState<State>' : ''
}>['callback'],
group?: number
): void

View file

@ -60,22 +60,49 @@ import {
} from './handler.js'
// end-codegen-imports
import { PropagationAction } from './propagation.js'
import { defaultStateKeyDelegate, IStateStorage, StateKeyDelegate, UpdateState } from './state/index.js'
import {
defaultStateKeyDelegate,
isCompatibleStorage,
IStateStorage,
StateKeyDelegate,
UpdateState,
} from './state/index.js'
export interface DispatcherParams {
/**
* If this dispatcher can be used as a scene, its unique name.
*
* Should not be set manually, use {@link Dispatcher#scene} instead
*/
sceneName?: string
/**
* Custom storage for the dispatcher.
*
* @default Client's storage
*/
storage?: IStateStorage
/**
* Custom key delegate for the dispatcher.
*/
key?: StateKeyDelegate
}
/**
* Updates dispatcher
*/
export class Dispatcher<State = never, SceneName extends string = string> {
export class Dispatcher<State extends object = never> {
private _groups: Record<number, Record<UpdateHandler['name'], UpdateHandler[]>> = {}
private _groupsOrder: number[] = []
private _client?: TelegramClient
private _parent?: Dispatcher<any>
private _children: Dispatcher<any, any>[] = []
private _children: Dispatcher<any>[] = []
private _scenes?: Record<string, Dispatcher<any, SceneName>>
private _scene?: SceneName
private _scenes?: Record<string, Dispatcher<any>>
private _scene?: string
private _sceneScoped?: boolean
private _storage?: State extends never ? undefined : IStateStorage
@ -87,65 +114,94 @@ export class Dispatcher<State = never, SceneName extends string = string> {
private _errorHandler?: <T = {}>(
err: Error,
update: ParsedUpdate & T,
state?: UpdateState<State, SceneName>,
state?: UpdateState<State>,
) => MaybeAsync<boolean>
private _preUpdateHandler?: <T = {}>(
update: ParsedUpdate & T,
state?: UpdateState<State, SceneName>,
state?: UpdateState<State>,
) => MaybeAsync<PropagationAction | void>
private _postUpdateHandler?: <T = {}>(
handled: boolean,
update: ParsedUpdate & T,
state?: UpdateState<State, SceneName>,
state?: UpdateState<State>,
) => MaybeAsync<void>
/**
* Create a new dispatcher, that will be used as a child,
* optionally providing a custom key delegate
*/
constructor(key?: StateKeyDelegate)
/**
* Create a new dispatcher, that will be used as a child, optionally
* providing custom storage and key delegate
*/
constructor(storage: IStateStorage, key?: StateKeyDelegate)
/**
* Create a new dispatcher and bind it to client and optionally
* FSM storage
*/
constructor(
client: TelegramClient,
...args: (() => State) extends () => never ? [] : [IStateStorage, StateKeyDelegate?]
)
constructor(
client?: TelegramClient | IStateStorage | StateKeyDelegate,
storage?: IStateStorage | StateKeyDelegate,
key?: StateKeyDelegate,
) {
protected constructor(client?: TelegramClient, params?: DispatcherParams) {
this.dispatchRawUpdate = this.dispatchRawUpdate.bind(this)
this.dispatchUpdate = this.dispatchUpdate.bind(this)
// eslint-disable-next-line prefer-const
let { storage, key, sceneName } = params ?? {}
if (client) {
if (client instanceof TelegramClient) {
this.bindToClient(client)
if (!storage) {
const _storage = client.storage
if (!isCompatibleStorage(_storage)) {
throw new MtArgumentError(
'Storage used by the client is not compatible with the dispatcher. Please provide a compatible storage manually',
)
}
storage = _storage
}
if (storage) {
this._storage = storage as any
this._stateKeyDelegate = (key ?? defaultStateKeyDelegate) as any
}
} else if (typeof client === 'function') {
// is StateKeyDelegate
this._customStateKeyDelegate = client as any
} else {
this._customStorage = client as any
// child dispatcher without client
if (storage) {
this._customStateKeyDelegate = client as any
this._customStorage = storage as any
}
if (key) {
this._customStateKeyDelegate = key as any
}
if (sceneName) {
if (sceneName[0] === '$') {
throw new MtArgumentError('Scene name cannot start with $')
}
this._scene = sceneName
}
}
}
/**
* Create a new dispatcher and bind it to the client.
*/
static for<State extends object = never>(client: TelegramClient, params?: DispatcherParams): Dispatcher<State> {
return new Dispatcher<State>(client, params)
}
/**
* Create a new child dispatcher.
*/
static child<State extends object = never>(params?: DispatcherParams): Dispatcher<State> {
return new Dispatcher<State>(undefined, params)
}
/**
* Create a new scene dispatcher
*/
static scene<State extends object = Record<never, never>>(
name: string,
params?: Omit<DispatcherParams, 'sceneName'>,
): Dispatcher<State> {
return new Dispatcher<State>(undefined, { sceneName: name, ...params })
}
/** For scene dispatchers, name of the scene */
get sceneName(): string | undefined {
return this._scene
}
/**
@ -289,7 +345,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
private async _dispatchUpdateNowImpl(
update: ParsedUpdate,
// this is getting a bit crazy lol
parsedState?: UpdateState<State, SceneName> | null,
parsedState?: UpdateState<State> | null,
parsedScene?: string | null,
forceScene?: true,
parsedContext?: UpdateContextType,
@ -525,9 +581,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param handler Error handler
*/
onError<T = {}>(
handler:
| ((err: Error, update: ParsedUpdate & T, state?: UpdateState<State, SceneName>) => MaybeAsync<boolean>)
| null,
handler: ((err: Error, update: ParsedUpdate & T, state?: UpdateState<State>) => MaybeAsync<boolean>) | null,
): void {
if (handler) this._errorHandler = handler
else this._errorHandler = undefined
@ -547,10 +601,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
*/
onPreUpdate<T = {}>(
handler:
| ((
update: ParsedUpdate & T,
state?: UpdateState<State, SceneName>,
) => MaybeAsync<PropagationAction | void>)
| ((update: ParsedUpdate & T, state?: UpdateState<State>) => MaybeAsync<PropagationAction | void>)
| null,
): void {
if (handler) this._preUpdateHandler = handler
@ -570,9 +621,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param handler Pre-update middleware
*/
onPostUpdate<T = {}>(
handler:
| ((handled: boolean, update: ParsedUpdate & T, state?: UpdateState<State, SceneName>) => MaybeAsync<void>)
| null,
handler: ((handled: boolean, update: ParsedUpdate & T, state?: UpdateState<State>) => MaybeAsync<void>) | null,
): void {
if (handler) this._postUpdateHandler = handler
else this._postUpdateHandler = undefined
@ -582,17 +631,13 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* Set error handler that will propagate
* the error to the parent dispatcher
*/
propagateErrorToParent(
err: Error,
update: ParsedUpdate,
state?: UpdateState<State, SceneName>,
): MaybeAsync<boolean> {
propagateErrorToParent(err: Error, update: ParsedUpdate, state?: UpdateState<State>): MaybeAsync<boolean> {
if (!this.parent) {
throw new MtArgumentError('This dispatcher is not a child')
}
if (this.parent._errorHandler) {
return this.parent._errorHandler(err, update, state)
return this.parent._errorHandler(err, update, state as any)
}
throw err
}
@ -607,7 +652,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
return this._parent ?? null
}
private _prepareChild(child: Dispatcher<any, any>): void {
private _prepareChild(child: Dispatcher<any>): void {
if (child._client) {
throw new MtArgumentError(
'Provided dispatcher is ' +
@ -638,7 +683,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
*
* @param child Other dispatcher
*/
addChild(child: Dispatcher<State, SceneName>): void {
addChild(child: Dispatcher<State>): void {
if (this._children.includes(child)) return
this._prepareChild(child)
@ -658,7 +703,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param scene Dispatcher representing the scene
* @param scoped Whether to use scoped FSM storage for the scene
*/
addScene(uid: SceneName, scene: Dispatcher<State, SceneName>, scoped: false): void
addScene(scene: Dispatcher<State>, scoped: false): void
/**
* Add a dispatcher as a scene with a scoped state
*
@ -672,26 +717,23 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param scene Dispatcher representing the scene
* @param scoped Whether to use scoped FSM storage for the scene (defaults to `true`)
*/
addScene(uid: SceneName, scene: Dispatcher<any, SceneName>, scoped?: true): void
addScene(uid: SceneName, scene: Dispatcher<any, SceneName>, scoped = true): void {
addScene(scene: Dispatcher<any>, scoped?: true): void
addScene(scene: Dispatcher<any>, scoped = true): void {
if (!this._scenes) this._scenes = {}
if (uid in this._scenes) {
throw new MtArgumentError(`Scene with UID ${uid} is already registered!`)
if (!scene._scene) {
throw new MtArgumentError(
'Non-scene dispatcher passed to addScene. Use `Dispatcher.scene()` to create one.',
)
}
if (uid[0] === '$') {
throw new MtArgumentError('Scene UID cannot start with $')
}
if (scene._scene) {
throw new MtArgumentError(`This dispatcher is already registered as scene ${scene._scene}`)
if (scene._scene in this._scenes) {
throw new MtArgumentError(`Scene with name ${scene._scene} is already registered!`)
}
this._prepareChild(scene)
scene._scene = uid
scene._sceneScoped = scoped
this._scenes[uid] = scene
this._scenes[scene._scene] = scene
}
/**
@ -705,7 +747,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
*
* @param child Other dispatcher
*/
removeChild(child: Dispatcher<any, any>): void {
removeChild(child: Dispatcher<any>): void {
const idx = this._children.indexOf(child)
if (idx > -1) {
@ -732,7 +774,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
*
* @param other Other dispatcher
*/
extend(other: Dispatcher<State, SceneName>): void {
extend(other: Dispatcher<State>): void {
if (other._customStorage || other._customStateKeyDelegate) {
throw new MtArgumentError('Provided dispatcher has custom storage and cannot be extended from.')
}
@ -772,7 +814,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
delete myScenes[key]
}
this.addScene(key as any, myScenes[key] as any, myScenes[key]._sceneScoped as any)
this.addScene(myScenes[key] as any, myScenes[key]._sceneScoped as any)
})
}
@ -791,8 +833,8 @@ export class Dispatcher<State = never, SceneName extends string = string> {
*
* @param children Whether to also clone children and scenes
*/
clone(children = false): Dispatcher<State, SceneName> {
const dp = new Dispatcher<State, SceneName>()
clone(children = false): Dispatcher<State> {
const dp = new Dispatcher<State>()
// copy handlers.
Object.keys(this._groups).forEach((key) => {
@ -819,12 +861,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
if (this._scenes) {
Object.keys(this._scenes).forEach((key) => {
const scene = this._scenes![key].clone(true)
dp.addScene(
key as any,
scene as any,
this._scenes![key]._sceneScoped as any,
)
dp.addScene(scene as any, this._scenes![key]._sceneScoped as any)
})
}
}
@ -841,7 +878,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param key State storage key
* @template S State type, defaults to dispatcher's state type. Only checked at compile-time
*/
getState<S = State>(key: string): UpdateState<S, SceneName>
getState<S extends object = State>(key: string): UpdateState<S>
/**
* Get update state object for the given object.
@ -853,8 +890,8 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param object Object for which the state should be fetched
* @template S State type, defaults to dispatcher's state type. Only checked at compile-time
*/
getState<S = State>(object: Parameters<StateKeyDelegate>[0]): Promise<UpdateState<S, SceneName>>
getState<S = State>(object: string | Parameters<StateKeyDelegate>[0]): MaybeAsync<UpdateState<S, SceneName>> {
getState<S extends object = State>(object: Parameters<StateKeyDelegate>[0]): Promise<UpdateState<S>>
getState<S extends object = State>(object: string | Parameters<StateKeyDelegate>[0]): MaybeAsync<UpdateState<S>> {
if (!this._storage) {
throw new MtArgumentError('Cannot use getUpdateState() filter without state storage')
}
@ -895,7 +932,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* This will load the state for the given object
* ignoring local custom storage, key delegate and scene scope.
*/
getGlobalState<T>(object: Parameters<StateKeyDelegate>[0]): Promise<UpdateState<T, SceneName>> {
getGlobalState<T extends object>(object: Parameters<StateKeyDelegate>[0]): Promise<UpdateState<T>> {
if (!this._parent) {
throw new MtArgumentError('This dispatcher does not have a parent')
}
@ -963,10 +1000,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param group Handler group index
*/
onNewMessage(
handler: NewMessageHandler<
MessageContext,
State extends never ? never : UpdateState<State, SceneName>
>['callback'],
handler: NewMessageHandler<MessageContext, State extends never ? never : UpdateState<State>>['callback'],
group?: number,
): void
@ -981,7 +1015,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<MessageContext, Mod, State>,
handler: NewMessageHandler<
filters.Modify<MessageContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -997,7 +1031,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<MessageContext, Mod>,
handler: NewMessageHandler<
filters.Modify<MessageContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1014,10 +1048,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param group Handler group index
*/
onEditMessage(
handler: EditMessageHandler<
MessageContext,
State extends never ? never : UpdateState<State, SceneName>
>['callback'],
handler: EditMessageHandler<MessageContext, State extends never ? never : UpdateState<State>>['callback'],
group?: number,
): void
@ -1032,7 +1063,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<MessageContext, Mod, State>,
handler: EditMessageHandler<
filters.Modify<MessageContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1048,7 +1079,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<MessageContext, Mod>,
handler: EditMessageHandler<
filters.Modify<MessageContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1065,10 +1096,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
* @param group Handler group index
*/
onMessageGroup(
handler: MessageGroupHandler<
MessageContext,
State extends never ? never : UpdateState<State, SceneName>
>['callback'],
handler: MessageGroupHandler<MessageContext, State extends never ? never : UpdateState<State>>['callback'],
group?: number,
): void
@ -1083,7 +1111,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<MessageContext, Mod, State>,
handler: MessageGroupHandler<
filters.Modify<MessageContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1099,7 +1127,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<MessageContext, Mod>,
handler: MessageGroupHandler<
filters.Modify<MessageContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1222,7 +1250,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
onCallbackQuery(
handler: CallbackQueryHandler<
CallbackQueryContext,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1238,7 +1266,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<CallbackQueryContext, Mod, State>,
handler: CallbackQueryHandler<
filters.Modify<CallbackQueryContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void
@ -1254,7 +1282,7 @@ export class Dispatcher<State = never, SceneName extends string = string> {
filter: UpdateFilter<CallbackQueryContext, Mod>,
handler: CallbackQueryHandler<
filters.Modify<CallbackQueryContext, Mod>,
State extends never ? never : UpdateState<State, SceneName>
State extends never ? never : UpdateState<State>
>['callback'],
group?: number,
): void

View file

@ -12,7 +12,7 @@ import { Modify, UpdateFilter } from './types.js'
* @param filter
* @returns
*/
export function every<Mod, State>(
export function every<Mod, State extends object>(
filter: UpdateFilter<Message, Mod, State>,
): UpdateFilter<
MessageContext,
@ -57,8 +57,11 @@ export function every<Mod, State>(
* @param filter
* @returns
*/
export function some<State extends object>(
// eslint-disable-next-line
export function some<State>(filter: UpdateFilter<Message, any, State>): UpdateFilter<MessageContext, {}, State> {
filter: UpdateFilter<Message, any, State>,
// eslint-disable-next-line
): UpdateFilter<MessageContext, {}, State> {
return (ctx, state) => {
let i = 0
const upds = ctx.messages

View file

@ -22,7 +22,7 @@ export const any: UpdateFilter<any> = () => true
*
* @param fn Filter to negate
*/
export function not<Base, Mod, State>(
export function not<Base, Mod, State extends object>(
fn: UpdateFilter<Base, Mod, State>,
): UpdateFilter<Base, Invert<Base, Mod>, State> {
return (upd, state) => {
@ -37,16 +37,39 @@ export function not<Base, Mod, State>(
// i couldn't come up with proper types for these 😭
// if you know how to do this better - PRs are welcome!
export function and<Base1, Mod1, State1, Base2, Mod2, State2>(
export function and<Base1, Mod1, State1 extends object, Base2, Mod2, State2 extends object>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
): UpdateFilter<Base1 & Base2, Mod1 & Mod2, State1 | State2>
export function and<Base1, Mod1, State1, Base2, Mod2, State2, Base3, Mod3, State3>(
export function and<
Base1,
Mod1,
State1 extends object,
Base2,
Mod2,
State2 extends object,
Base3,
Mod3,
State3 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
fn3: UpdateFilter<Base3, Mod3, State3>,
): UpdateFilter<Base1 & Base2 & Base3, Mod1 & Mod2 & Mod3, State1 | State2 | State3>
export function and<Base1, Mod1, State1, Base2, Mod2, State2, Base3, Mod3, State3, Base4, Mod4, State4>(
export function and<
Base1,
Mod1,
State1 extends object,
Base2,
Mod2,
State2 extends object,
Base3,
Mod3,
State3 extends object,
Base4,
Mod4,
State4 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
fn3: UpdateFilter<Base3, Mod3, State3>,
@ -55,19 +78,19 @@ export function and<Base1, Mod1, State1, Base2, Mod2, State2, Base3, Mod3, State
export function and<
Base1,
Mod1,
State1,
State1 extends object,
Base2,
Mod2,
State2,
State2 extends object,
Base3,
Mod3,
State3,
State3 extends object,
Base4,
Mod4,
State4,
State4 extends object,
Base5,
Mod5,
State5,
State5 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
@ -82,22 +105,22 @@ export function and<
export function and<
Base1,
Mod1,
State1,
State1 extends object,
Base2,
Mod2,
State2,
State2 extends object,
Base3,
Mod3,
State3,
State3 extends object,
Base4,
Mod4,
State4,
State4 extends object,
Base5,
Mod5,
State5,
State5 extends object,
Base6,
Mod6,
State6,
State6 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
@ -159,18 +182,41 @@ export function and(...fns: UpdateFilter<any, any, any>[]): UpdateFilter<any, an
}
}
export function or<Base1, Mod1, State1, Base2, Mod2, State2>(
export function or<Base1, Mod1, State1 extends object, Base2, Mod2, State2 extends object>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
): UpdateFilter<Base1 & Base2, Mod1 | Mod2, State1 | State2>
export function or<Base1, Mod1, State1, Base2, Mod2, State2, Base3, Mod3, State3>(
export function or<
Base1,
Mod1,
State1 extends object,
Base2,
Mod2,
State2 extends object,
Base3,
Mod3,
State3 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
fn3: UpdateFilter<Base3, Mod3, State3>,
): UpdateFilter<Base1 & Base2 & Base3, Mod1 | Mod2 | Mod3, State1 | State2 | State3>
export function or<Base1, Mod1, State1, Base2, Mod2, State2, Base3, Mod3, State3, Base4, Mod4, State4>(
export function or<
Base1,
Mod1,
State1 extends object,
Base2,
Mod2,
State2 extends object,
Base3,
Mod3,
State3 extends object,
Base4,
Mod4,
State4 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
fn3: UpdateFilter<Base3, Mod3, State3>,
@ -180,19 +226,19 @@ export function or<Base1, Mod1, State1, Base2, Mod2, State2, Base3, Mod3, State3
export function or<
Base1,
Mod1,
State1,
State1 extends object,
Base2,
Mod2,
State2,
State2 extends object,
Base3,
Mod3,
State3,
State3 extends object,
Base4,
Mod4,
State4,
State4 extends object,
Base5,
Mod5,
State5,
State5 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,
@ -208,22 +254,22 @@ export function or<
export function or<
Base1,
Mod1,
State1,
State1 extends object,
Base2,
Mod2,
State2,
State2 extends object,
Base3,
Mod3,
State3,
State3 extends object,
Base4,
Mod4,
State4,
State4 extends object,
Base5,
Mod5,
State5,
State5 extends object,
Base6,
Mod6,
State6,
State6 extends object,
>(
fn1: UpdateFilter<Base1, Mod1, State1>,
fn2: UpdateFilter<Base2, Mod2, State2>,

View file

@ -15,6 +15,7 @@ import {
Video,
} from '@mtcute/client'
import { MessageContext } from '../index.js'
import { Modify, UpdateFilter } from './types.js'
/**
@ -208,3 +209,25 @@ export const sender =
): UpdateFilter<Message, { sender: Extract<Message['sender'], { type: T }> }> =>
(msg) =>
msg.sender.type === type
/**
* Filter that matches messages that are replies to some other message.
*
* Optionally, you can pass a filter that will be applied to the replied message.
*/
export const replyTo =
<Mod, State extends object>(
filter?: UpdateFilter<Message, Mod, State>,
): UpdateFilter<MessageContext, { getReplyTo: () => Promise<Message & Mod> }, State> =>
async (msg, state) => {
if (!msg.replyToMessageId) return false
const reply = await msg.getReplyTo()
if (!reply) return false
msg.getReplyTo = () => Promise.resolve(reply)
if (!filter) return true
return filter(reply, state)
}

View file

@ -20,7 +20,7 @@ export const stateEmpty: UpdateFilter<any> = async (upd, state) => {
*
* @param predicate State predicate
*/
export const state = <T>(
export const state = <T extends object>(
predicate: (state: T) => MaybeAsync<boolean>,
// eslint-disable-next-line @typescript-eslint/ban-types
): UpdateFilter<any, {}, T> => {

View file

@ -75,7 +75,7 @@ import { UpdateState } from '../state/update-state.js'
*/
// we need the second parameter because it carries meta information
// eslint-disable-next-line @typescript-eslint/no-unused-vars
export type UpdateFilter<Base, Mod = {}, State = never> = (
export type UpdateFilter<Base, Mod = {}, State extends object = never> = (
update: Base,
state?: UpdateState<State>,
) => MaybeAsync<boolean>

View file

@ -1,4 +1,4 @@
import { assertNever, MaybeAsync } from '@mtcute/client'
import { assertNever, Chat, MaybeAsync, User } from '@mtcute/client'
import { CallbackQueryContext, MessageContext } from '../context/index.js'
@ -10,7 +10,7 @@ import { CallbackQueryContext, MessageContext } from '../context/index.js'
* @param msg Message or callback from which to derive the key
* @param scene Current scene UID, or `null` if none
*/
export type StateKeyDelegate = (upd: MessageContext | CallbackQueryContext) => MaybeAsync<string | null>
export type StateKeyDelegate = (upd: MessageContext | CallbackQueryContext | User | Chat) => MaybeAsync<string | null>
/**
* Default state key delegate.
@ -24,6 +24,11 @@ export type StateKeyDelegate = (upd: MessageContext | CallbackQueryContext) => M
* - If in group/channel/supergroup (i.e. `upd.chatType !== 'user'`), `upd.chatId + '_' + upd.user.id`
*/
export const defaultStateKeyDelegate: StateKeyDelegate = (upd): string | null => {
if ('type' in upd) {
// User | Chat
return String(upd.id)
}
if (upd._name === 'new_message') {
switch (upd.chat.chatType) {
case 'private':

View file

@ -80,3 +80,18 @@ export interface IStateStorage {
*/
resetRateLimit(key: string): MaybeAsync<void>
}
export function isCompatibleStorage(storage: unknown): storage is IStateStorage {
return (
typeof storage === 'object' &&
storage !== null &&
'getState' in storage &&
'setState' in storage &&
'deleteState' in storage &&
'getCurrentScene' in storage &&
'setCurrentScene' in storage &&
'deleteCurrentScene' in storage &&
'getRateLimit' in storage &&
'resetRateLimit' in storage
)
}

View file

@ -1,6 +1,9 @@
/* eslint-disable dot-notation */
/* eslint-disable @typescript-eslint/no-explicit-any */
import { MtArgumentError, MtcuteError } from '@mtcute/client'
import { sleep } from '@mtcute/client/utils.js'
import type { Dispatcher } from '../dispatcher.js'
import { IStateStorage } from './storage.js'
/**
@ -18,13 +21,13 @@ export class RateLimitError extends MtcuteError {
* @template State Type that represents the state
* @template SceneName Possible scene names
*/
export class UpdateState<State, SceneName extends string = string> {
export class UpdateState<State extends object> {
private _key: string
private _localKey!: string
private _storage: IStateStorage
private _scene: SceneName | null
private _scene: string | null
private _scoped?: boolean
private _cached?: State | null
@ -34,7 +37,7 @@ export class UpdateState<State, SceneName extends string = string> {
constructor(
storage: IStateStorage,
key: string,
scene: SceneName | null,
scene: string | null,
scoped?: boolean,
customStorage?: IStateStorage,
customKey?: string,
@ -50,7 +53,8 @@ export class UpdateState<State, SceneName extends string = string> {
this._updateLocalKey()
}
get scene(): SceneName | null {
/** Name of the current scene */
get scene(): string | null {
return this._scene
}
@ -68,7 +72,7 @@ export class UpdateState<State, SceneName extends string = string> {
* @param fallback Default state value
* @param force Whether to ignore cached state (def. `false`)
*/
async get(fallback: State, force?: boolean): Promise<State>
async get(fallback: State | (() => State), force?: boolean): Promise<State>
/**
* Retrieve the state from the storage, falling back to default
@ -77,27 +81,32 @@ export class UpdateState<State, SceneName extends string = string> {
* @param fallback Default state value
* @param force Whether to ignore cached state (def. `false`)
*/
async get(fallback?: State, force?: boolean): Promise<State | null>
async get(fallback?: State | (() => State), force?: boolean): Promise<State | null>
/**
* Retrieve the state from the storage
*
* @param force Whether to ignore cached state (def. `false`)
*/
async get(force?: boolean): Promise<State | null>
async get(fallback?: State | boolean, force?: boolean): Promise<State | null> {
async get(fallback?: State | (() => State) | boolean, force?: boolean): Promise<State | null> {
if (typeof fallback === 'boolean') {
force = fallback
fallback = undefined
}
if (!force && this._cached !== undefined) {
if (!this._cached && fallback) return fallback
if (!this._cached && fallback) {
return typeof fallback === 'function' ? fallback() : fallback
}
return this._cached
}
let res = (await this._localStorage.getState(this._localKey)) as State | null
if (!res && fallback) res = fallback
if (!res && fallback) {
res = typeof fallback === 'function' ? fallback() : fallback
}
this._cached = res
return res
@ -128,7 +137,16 @@ export class UpdateState<State, SceneName extends string = string> {
* @param ttl TTL for the new state (in seconds)
* @param forceLoad Whether to force load the old state from storage
*/
async merge(state: Partial<State>, fallback?: State, ttl?: number, forceLoad = false): Promise<State> {
async merge(
state: Partial<State>,
params: {
fallback?: State | (() => State)
ttl?: number
forceLoad?: boolean
} = {},
): Promise<State> {
const { fallback, ttl, forceLoad } = params
const old = await this.get(forceLoad)
if (!old) {
@ -136,7 +154,9 @@ export class UpdateState<State, SceneName extends string = string> {
throw new MtArgumentError('Cannot use merge on empty state without fallback.')
}
await this.set({ ...fallback, ...state }, ttl)
const fallback_ = typeof fallback === 'function' ? fallback() : fallback
await this.set({ ...fallback_, ...state }, ttl)
} else {
await this.set({ ...old, ...state }, ttl)
}
@ -156,14 +176,43 @@ export class UpdateState<State, SceneName extends string = string> {
/**
* Enter some scene
*
* @param scene Scene name
* @param ttl TTL for the scene (in seconds)
*/
async enter(scene: SceneName, ttl?: number): Promise<void> {
this._scene = scene
async enter<SceneState extends object, Scene extends Dispatcher<SceneState>>(
scene: Scene,
params?: {
/**
* Initial state for the scene
*
* Note that this will only work if the scene uses the same key delegate as this state.
*/
with?: SceneState
/** TTL for the scene (in seconds) */
ttl?: number
},
): Promise<void> {
const { with: with_, ttl } = params ?? {}
if (!scene['_scene']) {
throw new MtArgumentError('Cannot enter a non-scene Dispatcher')
}
if (!scene['_parent']) {
throw new MtArgumentError('This scene has not been registered')
}
this._scene = scene['_scene']
this._updateLocalKey()
await this._storage.setCurrentScene(this._key, scene, ttl)
await this._storage.setCurrentScene(this._key, this._scene, ttl)
if (with_) {
if (scene['_customStateKeyDelegate']) {
throw new MtArgumentError('Cannot use `with` parameter when the scene uses a custom state key delegate')
}
await scene.getState(this._key).set(with_, ttl)
}
}
/**

View file

@ -32,10 +32,7 @@ interface WizardInternalState {
* that can be used to simplify implementing
* step-by-step scenes.
*/
export class WizardScene<State, SceneName extends string = string> extends Dispatcher<
State & WizardInternalState,
SceneName
> {
export class WizardScene<State extends object> extends Dispatcher<State & WizardInternalState> {
private _steps = 0
private _defaultState: State & WizardInternalState = {} as State & WizardInternalState
@ -54,7 +51,7 @@ export class WizardScene<State, SceneName extends string = string> extends Dispa
/**
* Go to the Nth step
*/
async goToStep(state: UpdateState<WizardInternalState, SceneName>, step: number) {
async goToStep(state: UpdateState<WizardInternalState>, step: number) {
if (step >= this._steps) {
await state.exit()
} else {
@ -65,7 +62,7 @@ export class WizardScene<State, SceneName extends string = string> extends Dispa
/**
* Skip N steps
*/
async skip(state: UpdateState<WizardInternalState, SceneName>, count = 1) {
async skip(state: UpdateState<WizardInternalState>, count = 1) {
const { $step } = (await state.get()) || {}
if ($step === undefined) throw new Error('Wizard state is not initialized')
@ -96,7 +93,7 @@ export class WizardScene<State, SceneName extends string = string> extends Dispa
addStep(
handler: (
msg: MessageContext,
state: UpdateState<State & WizardInternalState, SceneName>,
state: UpdateState<State & WizardInternalState>,
) => MaybeAsync<WizardSceneAction | number>,
): void {
const step = this._steps++