diff --git a/packages/dispatcher/src/filters.ts b/packages/dispatcher/src/filters.ts index 178c4c4f..581b0355 100644 --- a/packages/dispatcher/src/filters.ts +++ b/packages/dispatcher/src/filters.ts @@ -21,7 +21,8 @@ import { Invoice, Game, WebPage, - MessageAction, RawLocation, + MessageAction, + RawLocation, } from '@mtcute/client' import { MaybeArray } from '@mtcute/core' import { ChatMemberUpdate } from './updates' @@ -31,6 +32,22 @@ import { UserStatusUpdate } from './updates/user-status-update' import { PollVoteUpdate } from './updates/poll-vote' import { UserTypingUpdate } from './updates/user-typing-update' +function extractText( + obj: Message | InlineQuery | ChosenInlineResult | CallbackQuery +): string | null { + if (obj.constructor === Message) { + return obj.text + } else if (obj.constructor === InlineQuery) { + return obj.query + } else if (obj.constructor === ChosenInlineResult) { + return obj.id + } else if (obj.constructor === CallbackQuery) { + if (obj.raw.data) return obj.dataStr + } + + return null +} + /** * Type describing a primitive filter, which is a function taking some `Base` * and a {@link TelegramClient}, checking it against some condition @@ -318,6 +335,11 @@ export namespace filters { } } + /** + * Filter that matches any update + */ + export const any: UpdateFilter = () => true + /** * Filter messages generated by yourself (including Saved Messages) */ @@ -718,16 +740,16 @@ export namespace filters { /** * Filter messages containing any location (live or static). */ - export const anyLocation: UpdateFilter = (msg) => - msg.media instanceof RawLocation + export const anyLocation: UpdateFilter = ( + msg + ) => msg.media instanceof RawLocation /** * Filter messages containing a static (non-live) location. */ - export const location: UpdateFilter< - Message, - { media: LiveLocation } - > = (msg) => msg.media?.type === 'location' + export const location: UpdateFilter = ( + msg + ) => msg.media?.type === 'location' /** * Filter messages containing a live location. @@ -772,7 +794,7 @@ export namespace filters { * - for `Message`, `Message.text` is used * - for `InlineQuery`, `InlineQuery.query` is used * - for {@link ChosenInlineResult}, {@link ChosenInlineResult.id} is used - * - for `CallbackQuery`, `CallbackQuery.dataStr` + * - for `CallbackQuery`, `CallbackQuery.dataStr` is used * * When a regex matches, the match array is stored in a * type-safe extension field `.match` of the object @@ -785,16 +807,10 @@ export namespace filters { Message | InlineQuery | ChosenInlineResult | CallbackQuery, { match: RegExpMatchArray } > => (obj) => { - let m: RegExpMatchArray | null = null - if (obj.constructor === Message) { - m = obj.text.match(regex) - } else if (obj.constructor === InlineQuery) { - m = obj.query.match(regex) - } else if (obj.constructor === ChosenInlineResult) { - m = obj.id.match(regex) - } else if (obj.constructor === CallbackQuery) { - if (obj.raw.data) m = obj.dataStr!.match(regex) - } + const txt = extractText(obj) + if (!txt) return false + + const m = txt.match(regex) if (m) { ;(obj as any).match = m @@ -803,6 +819,122 @@ export namespace filters { return false } + /** + * Filter objects which contain the exact text given + * - for `Message`, `Message.text` is used + * - for `InlineQuery`, `InlineQuery.query` is used + * - for {@link ChosenInlineResult}, {@link ChosenInlineResult.id} is used + * - for `CallbackQuery`, `CallbackQuery.dataStr` is used + * + * @param str String to be matched + * @param ignoreCase Whether string case should be ignored + */ + export const equals = ( + str: string, + ignoreCase = false + ): UpdateFilter< + Message | InlineQuery | ChosenInlineResult | CallbackQuery + > => { + if (ignoreCase) { + str = str.toLowerCase() + return (obj) => extractText(obj)?.toLowerCase() === str + } + + return (obj) => extractText(obj) === str + } + + /** + * Filter objects which contain the text given (as a substring) + * - for `Message`, `Message.text` is used + * - for `InlineQuery`, `InlineQuery.query` is used + * - for {@link ChosenInlineResult}, {@link ChosenInlineResult.id} is used + * - for `CallbackQuery`, `CallbackQuery.dataStr` is used + * + * @param str Substring to be matched + * @param ignoreCase Whether string case should be ignored + */ + export const contains = ( + str: string, + ignoreCase = false + ): UpdateFilter< + Message | InlineQuery | ChosenInlineResult | CallbackQuery + > => { + if (ignoreCase) { + str = str.toLowerCase() + return (obj) => { + const txt = extractText(obj) + return txt != null && txt.toLowerCase().indexOf(str) > -1 + } + } + + return (obj) => { + const txt = extractText(obj) + return txt != null && txt.indexOf(str) > -1 + } + } + + /** + * Filter objects which contain the text starting with a given string + * - for `Message`, `Message.text` is used + * - for `InlineQuery`, `InlineQuery.query` is used + * - for {@link ChosenInlineResult}, {@link ChosenInlineResult.id} is used + * - for `CallbackQuery`, `CallbackQuery.dataStr` is used + * + * @param str Substring to be matched + * @param ignoreCase Whether string case should be ignored + */ + export const startsWith = ( + str: string, + ignoreCase = false + ): UpdateFilter< + Message | InlineQuery | ChosenInlineResult | CallbackQuery + > => { + if (ignoreCase) { + str = str.toLowerCase() + + return (obj) => { + const txt = extractText(obj) + return txt != null && txt.toLowerCase().substring(0, str.length) === str + } + } + + return (obj) => { + const txt = extractText(obj) + return txt != null && txt.substring(0, str.length) === str + } + } + + /** + * Filter objects which contain the text ending with a given string + * - for `Message`, `Message.text` is used + * - for `InlineQuery`, `InlineQuery.query` is used + * - for {@link ChosenInlineResult}, {@link ChosenInlineResult.id} is used + * - for `CallbackQuery`, `CallbackQuery.dataStr` is used + * + * @param str Substring to be matched + * @param ignoreCase Whether string case should be ignored + */ + export const endsWith = ( + str: string, + ignoreCase = false + ): UpdateFilter< + Message | InlineQuery | ChosenInlineResult | CallbackQuery + > => { + if (ignoreCase) { + str = str.toLowerCase() + + return (obj) => { + const txt = extractText(obj) + return txt != null && txt.toLowerCase().substring(0, str.length) === str + } + } + + return (obj) => { + const txt = extractText(obj) + return txt != null && txt.substring(0, str.length) === str + } + } + /** * Filter messages that call the given command(s).. * @@ -859,7 +991,8 @@ export namespace filters { const lastGroup = m[m.length - 1] if (lastGroup && msg.client['_isBot']) { // check bot username - if (lastGroup !== msg.client['_selfUsername']) return false + if (lastGroup !== msg.client['_selfUsername']) + return false } const match = m.slice(1, -1) @@ -889,10 +1022,7 @@ export namespace filters { * Shorthand filter that matches /start commands sent to bot's * private messages. */ - export const start = and( - chat('private'), - command('start') - ) + export const start = and(chat('private'), command('start')) /** * Filter for deep links (i.e. `/start `). @@ -900,44 +1030,40 @@ export namespace filters { * If the parameter is a regex, groups are added to `msg.command`, * meaning that the first group is available in `msg.command[2]`. */ - export const deeplink = (params: MaybeArray): UpdateFilter => { + export const deeplink = ( + params: MaybeArray + ): UpdateFilter => { if (!Array.isArray(params)) { - return and( - start, - (msg: Message & { command: string[] }) => { - if (msg.command.length !== 2) return false - - const p = msg.command[1] - if (typeof params === 'string' && p === params) return true - - const m = p.match(params) - if (!m) return false - - msg.command.push(...m.slice(1)) - return true - } - ) - } - - return and( - start, - (msg: Message & { command: string[] }) => { + return and(start, (msg: Message & { command: string[] }) => { if (msg.command.length !== 2) return false const p = msg.command[1] - for (const param of params) { - if (typeof param === 'string' && p === param) return true + if (typeof params === 'string' && p === params) return true - const m = p.match(param) - if (!m) continue + const m = p.match(params) + if (!m) return false - msg.command.push(...m.slice(1)) - return true - } + msg.command.push(...m.slice(1)) + return true + }) + } - return false + return and(start, (msg: Message & { command: string[] }) => { + if (msg.command.length !== 2) return false + + const p = msg.command[1] + for (const param of params) { + if (typeof param === 'string' && p === param) return true + + const m = p.match(param) + if (!m) continue + + msg.command.push(...m.slice(1)) + return true } - ) + + return false + }) } /**