Merge pull request #9 from mtcute/networking-rewrite

Networking rewrite
This commit is contained in:
Alina Tumanova 2023-08-23 23:54:13 +03:00 committed by GitHub
commit e7171e32c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 4163 additions and 2117 deletions

View file

@ -181,6 +181,10 @@ module.exports = {
], ],
globals: { Atomics: 'readonly', SharedArrayBuffer: 'readonly' }, globals: { Atomics: 'readonly', SharedArrayBuffer: 'readonly' },
parser: '@typescript-eslint/parser', parser: '@typescript-eslint/parser',
parserOptions: {
project: true,
tsconfigRootDir: __dirname,
},
plugins: ['@typescript-eslint'], plugins: ['@typescript-eslint'],
rules: { rules: {
// https://github.com/typescript-eslint/typescript-eslint/tree/master/packages/eslint-plugin#supported-rules // https://github.com/typescript-eslint/typescript-eslint/tree/master/packages/eslint-plugin#supported-rules

View file

@ -27,15 +27,15 @@
"@types/node": "18.16.0", "@types/node": "18.16.0",
"@types/node-forge": "1.3.2", "@types/node-forge": "1.3.2",
"@types/ws": "8.5.4", "@types/ws": "8.5.4",
"@typescript-eslint/eslint-plugin": "5.59.8", "@typescript-eslint/eslint-plugin": "6.4.0",
"@typescript-eslint/parser": "5.59.8", "@typescript-eslint/parser": "6.4.0",
"chai": "4.3.7", "chai": "4.3.7",
"dotenv-flow": "3.2.0", "dotenv-flow": "3.2.0",
"eslint": "8.42.0", "eslint": "8.47.0",
"eslint-config-prettier": "8.8.0", "eslint-config-prettier": "8.8.0",
"eslint-import-resolver-typescript": "3.5.5", "eslint-import-resolver-typescript": "3.6.0",
"eslint-plugin-ascii": "1.0.0", "eslint-plugin-ascii": "1.0.0",
"eslint-plugin-import": "2.27.5", "eslint-plugin-import": "2.28.0",
"eslint-plugin-simple-import-sort": "10.0.0", "eslint-plugin-simple-import-sort": "10.0.0",
"glob": "10.2.6", "glob": "10.2.6",
"husky": "^8.0.3", "husky": "^8.0.3",

View file

@ -64,8 +64,11 @@ async function addSingleMethod(state, fileName) {
if ( if (
!stmt.importClause.namedBindings || !stmt.importClause.namedBindings ||
stmt.importClause.namedBindings.kind !== ts.SyntaxKind.NamedImports stmt.importClause.namedBindings.kind !==
) { throwError(stmt, fileName, 'Only named imports are supported!') } ts.SyntaxKind.NamedImports
) {
throwError(stmt, fileName, 'Only named imports are supported!')
}
let module = stmt.moduleSpecifier.text let module = stmt.moduleSpecifier.text
@ -131,11 +134,7 @@ async function addSingleMethod(state, fileName) {
})() })()
if (!isExported && !isPrivate) { if (!isExported && !isPrivate) {
throwError( throwError(stmt, fileName, 'Public methods MUST be exported.')
stmt,
fileName,
'Public methods MUST be exported.',
)
} }
if (isExported && !checkForFlag(stmt, '@internal')) { if (isExported && !checkForFlag(stmt, '@internal')) {
@ -182,16 +181,20 @@ async function addSingleMethod(state, fileName) {
) )
} }
const returnsExported = (stmt.body ? const returnsExported = (
ts.getLeadingCommentRanges(fileFullText, stmt.body.pos + 2) || stmt.body ?
(stmt.statements && ts.getLeadingCommentRanges(
stmt.statements.length && fileFullText,
ts.getLeadingCommentRanges( stmt.body.pos + 2,
fileFullText, ) ||
stmt.statements[0].pos, (stmt.statements &&
)) || stmt.statements.length &&
[] : ts.getLeadingCommentRanges(
[] fileFullText,
stmt.statements[0].pos,
)) ||
[] :
[]
) )
.map((range) => fileFullText.substring(range.pos, range.end)) .map((range) => fileFullText.substring(range.pos, range.end))
.join('\n') .join('\n')
@ -275,7 +278,9 @@ async function addSingleMethod(state, fileName) {
} }
async function main() { async function main() {
const output = fs.createWriteStream(path.join(__dirname, '../src/client.ts')) const output = fs.createWriteStream(
path.join(__dirname, '../src/client.ts'),
)
const state = { const state = {
imports: {}, imports: {},
fields: [], fields: [],
@ -295,7 +300,8 @@ async function main() {
} }
output.write( output.write(
'/* THIS FILE WAS AUTO-GENERATED */\n' + '/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging, @typescript-eslint/unified-signatures */\n' +
'/* THIS FILE WAS AUTO-GENERATED */\n' +
"import { BaseTelegramClient, BaseTelegramClientOptions } from '@mtcute/core'\n" + "import { BaseTelegramClient, BaseTelegramClientOptions } from '@mtcute/core'\n" +
"import { tl } from '@mtcute/tl'\n", "import { tl } from '@mtcute/tl'\n",
) )
@ -336,7 +342,9 @@ async function main() {
* @param name Event name * @param name Event name
* @param handler ${updates.toSentence(type, 'full')} * @param handler ${updates.toSentence(type, 'full')}
*/ */
on(name: '${type.typeName}', handler: ((upd: ${type.updateType}) => void)): this\n`) on(name: '${type.typeName}', handler: ((upd: ${
type.updateType
}) => void)): this\n`)
}) })
const printer = ts.createPrinter() const printer = ts.createPrinter()
@ -406,7 +414,9 @@ on(name: '${type.typeName}', handler: ((upd: ${type.updateType}) => void)): this
it.initializer = undefined it.initializer = undefined
const deleteParents = (obj) => { const deleteParents = (obj) => {
if (Array.isArray(obj)) { return obj.forEach((it) => deleteParents(it)) } if (Array.isArray(obj)) {
return obj.forEach((it) => deleteParents(it))
}
if (obj.parent) delete obj.parent if (obj.parent) delete obj.parent
@ -455,7 +465,7 @@ on(name: '${type.typeName}', handler: ((upd: ${type.updateType}) => void)): this
for (const name of [origName, ...aliases]) { for (const name of [origName, ...aliases]) {
if (!hasOverloads) { if (!hasOverloads) {
if (!comment.match(/\/\*\*?\s*\*\//)) { if (!comment.match(/\/\*\*?\s*\*\//)) {
// empty comment, no need to write it // empty comment, no need to write it
output.write(comment + '\n') output.write(comment + '\n')
} }
@ -465,18 +475,14 @@ on(name: '${type.typeName}', handler: ((upd: ${type.updateType}) => void)): this
} }
if (!overload) { if (!overload) {
classContents.push( classContents.push(`${name} = ${origName}`)
`${name} = ${origName}`,
)
} }
} }
}, },
) )
output.write('}\n') output.write('}\n')
output.write( output.write('\nexport class TelegramClient extends BaseTelegramClient {\n')
'\nexport class TelegramClient extends BaseTelegramClient {\n',
)
state.fields.forEach(({ code }) => output.write(`protected ${code}\n`)) state.fields.forEach(({ code }) => output.write(`protected ${code}\n`))
@ -501,10 +507,9 @@ on(name: '${type.typeName}', handler: ((upd: ${type.updateType}) => void)): this
await fs.promises.writeFile(targetFile, fullSource) await fs.promises.writeFile(targetFile, fullSource)
// fix using eslint // fix using eslint
require('child_process').execSync( require('child_process').execSync(`pnpm exec eslint --fix ${targetFile}`, {
`pnpm exec eslint --fix ${targetFile}`, stdio: 'inherit',
{ stdio: 'inherit' }, })
)
} }
main().catch(console.error) main().catch(console.error)

View file

@ -1,3 +1,4 @@
/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
/* THIS FILE WAS AUTO-GENERATED */ /* THIS FILE WAS AUTO-GENERATED */
import { Readable } from 'stream' import { Readable } from 'stream'
@ -8,7 +9,6 @@ import {
Deque, Deque,
MaybeArray, MaybeArray,
MaybeAsync, MaybeAsync,
SessionConnection,
SortedLinkedList, SortedLinkedList,
} from '@mtcute/core' } from '@mtcute/core'
import { ConditionVariable } from '@mtcute/core/src/utils/condition-variable' import { ConditionVariable } from '@mtcute/core/src/utils/condition-variable'
@ -131,7 +131,6 @@ import { getMessages } from './methods/messages/get-messages'
import { getMessagesUnsafe } from './methods/messages/get-messages-unsafe' import { getMessagesUnsafe } from './methods/messages/get-messages-unsafe'
import { getReactionUsers } from './methods/messages/get-reaction-users' import { getReactionUsers } from './methods/messages/get-reaction-users'
import { getScheduledMessages } from './methods/messages/get-scheduled-messages' import { getScheduledMessages } from './methods/messages/get-scheduled-messages'
import { _normalizeInline } from './methods/messages/normalize-inline'
import { _parseEntities } from './methods/messages/parse-entities' import { _parseEntities } from './methods/messages/parse-entities'
import { pinMessage } from './methods/messages/pin-message' import { pinMessage } from './methods/messages/pin-message'
import { readHistory } from './methods/messages/read-history' import { readHistory } from './methods/messages/read-history'
@ -1909,14 +1908,15 @@ export interface TelegramClient extends BaseTelegramClient {
/** /**
* Total file size. Automatically inferred for Buffer, File and local files. * Total file size. Automatically inferred for Buffer, File and local files.
*
* When using with streams, if `fileSize` is not passed, the entire file is
* first loaded into memory to determine file size, and used as a Buffer later.
* This might be a major performance bottleneck, so be sure to provide file size
* when using streams and file size is known (which often is the case).
*/ */
fileSize?: number fileSize?: number
/**
* If the file size is unknown, you can provide an estimate,
* which will be used to determine appropriate part size.
*/
estimatedSize?: number
/** /**
* File MIME type. By default is automatically inferred from magic number * File MIME type. By default is automatically inferred from magic number
* If MIME can't be inferred, it defaults to `application/octet-stream` * If MIME can't be inferred, it defaults to `application/octet-stream`
@ -1931,11 +1931,16 @@ export interface TelegramClient extends BaseTelegramClient {
*/ */
partSize?: number partSize?: number
/**
* Number of parts to be sent in parallel per connection.
*/
requestsPerConnection?: number
/** /**
* Function that will be called after some part has been uploaded. * Function that will be called after some part has been uploaded.
* *
* @param uploaded Number of bytes already uploaded * @param uploaded Number of bytes already uploaded
* @param total Total file size * @param total Total file size, if known
*/ */
progressCallback?: (uploaded: number, total: number) => void progressCallback?: (uploaded: number, total: number) => void
}): Promise<UploadedFile> }): Promise<UploadedFile>
@ -2759,10 +2764,6 @@ export interface TelegramClient extends BaseTelegramClient {
messageIds: number[] messageIds: number[]
): Promise<(Message | null)[]> ): Promise<(Message | null)[]>
_normalizeInline(
id: string | tl.TypeInputBotInlineMessageID
): Promise<[tl.TypeInputBotInlineMessageID, SessionConnection]>
_parseEntities( _parseEntities(
text?: string | FormattedString<string>, text?: string | FormattedString<string>,
mode?: string | null, mode?: string | null,
@ -4024,8 +4025,6 @@ export class TelegramClient extends BaseTelegramClient {
protected _selfUsername: string | null protected _selfUsername: string | null
protected _pendingConversations: Record<number, Conversation[]> protected _pendingConversations: Record<number, Conversation[]>
protected _hasConversations: boolean protected _hasConversations: boolean
protected _downloadConnections: Record<number, SessionConnection>
protected _connectionsForInline: Record<number, SessionConnection>
protected _parseModes: Record<string, IMessageEntityParser> protected _parseModes: Record<string, IMessageEntityParser>
protected _defaultParseMode: string | null protected _defaultParseMode: string | null
protected _updatesLoopActive: boolean protected _updatesLoopActive: boolean
@ -4060,8 +4059,6 @@ export class TelegramClient extends BaseTelegramClient {
this.log.prefix = '[USER N/A] ' this.log.prefix = '[USER N/A] '
this._pendingConversations = {} this._pendingConversations = {}
this._hasConversations = false this._hasConversations = false
this._downloadConnections = {}
this._connectionsForInline = {}
this._parseModes = {} this._parseModes = {}
this._defaultParseMode = null this._defaultParseMode = null
this._updatesLoopActive = false this._updatesLoopActive = false
@ -4213,7 +4210,6 @@ export class TelegramClient extends BaseTelegramClient {
getMessages = getMessages getMessages = getMessages
getReactionUsers = getReactionUsers getReactionUsers = getReactionUsers
getScheduledMessages = getScheduledMessages getScheduledMessages = getScheduledMessages
_normalizeInline = _normalizeInline
_parseEntities = _parseEntities _parseEntities = _parseEntities
pinMessage = pinMessage pinMessage = pinMessage
readHistory = readHistory readHistory = readHistory

View file

@ -2,12 +2,7 @@
import { Readable } from 'stream' import { Readable } from 'stream'
// @copy // @copy
import { import { AsyncLock, MaybeArray, MaybeAsync } from '@mtcute/core'
AsyncLock,
MaybeArray,
MaybeAsync,
SessionConnection,
} from '@mtcute/core'
// @copy // @copy
import { Logger } from '@mtcute/core/src/utils/logger' import { Logger } from '@mtcute/core/src/utils/logger'
// @copy // @copy

View file

@ -38,16 +38,18 @@ export async function checkPassword(
'user', 'user',
) )
this.log.prefix = `[USER ${this._userId}] `
this._userId = res.user.id this._userId = res.user.id
this.log.prefix = `[USER ${this._userId}] `
this._isBot = false this._isBot = false
this._selfChanged = true this._selfChanged = true
this._selfUsername = res.user.username ?? null this._selfUsername = res.user.username ?? null
await this.network.notifyLoggedIn(res)
await this._fetchUpdatesState() await this._fetchUpdatesState()
await this._saveStorage() await this._saveStorage()
// telegram ignores invokeWithoutUpdates for auth methods // telegram ignores invokeWithoutUpdates for auth methods
if (this._disableUpdates) this.primaryConnection._resetSession() if (this.network.params.disableUpdates) this.network.resetSessions()
else this.startUpdatesLoop() else this.startUpdatesLoop()
return new User(this, res.user) return new User(this, res.user)

View file

@ -19,7 +19,7 @@ export async function sendCode(
const res = await this.call({ const res = await this.call({
_: 'auth.sendCode', _: 'auth.sendCode',
phoneNumber: phone, phoneNumber: phone,
apiId: this._initConnectionParams.apiId, apiId: this.network._initConnectionParams.apiId,
apiHash: this._apiHash, apiHash: this._apiHash,
settings: { _: 'codeSettings' }, settings: { _: 'codeSettings' },
}) })

View file

@ -17,7 +17,7 @@ export async function signInBot(
const res = await this.call({ const res = await this.call({
_: 'auth.importBotAuthorization', _: 'auth.importBotAuthorization',
flags: 0, flags: 0,
apiId: this._initConnectionParams.apiId, apiId: this.network._initConnectionParams.apiId,
apiHash: this._apiHash, apiHash: this._apiHash,
botAuthToken: token, botAuthToken: token,
}) })
@ -33,16 +33,19 @@ export async function signInBot(
'user', 'user',
) )
this.log.prefix = `[USER ${this._userId}] `
this._userId = res.user.id this._userId = res.user.id
this.log.prefix = `[USER ${this._userId}] `
this._isBot = true this._isBot = true
this._selfUsername = res.user.username! this._selfUsername = res.user.username!
this._selfChanged = true this._selfChanged = true
await this.network.notifyLoggedIn(res)
await this._fetchUpdatesState() await this._fetchUpdatesState()
await this._saveStorage() await this._saveStorage()
// telegram ignores invokeWithoutUpdates for auth methods // telegram ignores invokeWithoutUpdates for auth methods
if (this._disableUpdates) this.primaryConnection._resetSession() if (this.network.params.disableUpdates) this.network.resetSessions()
else this.startUpdatesLoop() else this.startUpdatesLoop()
return new User(this, res.user) return new User(this, res.user)

View file

@ -41,16 +41,18 @@ export async function signIn(
assertTypeIs('signIn (@ auth.signIn -> user)', res.user, 'user') assertTypeIs('signIn (@ auth.signIn -> user)', res.user, 'user')
this.log.prefix = `[USER ${this._userId}] `
this._userId = res.user.id this._userId = res.user.id
this.log.prefix = `[USER ${this._userId}] `
this._isBot = false this._isBot = false
this._selfChanged = true this._selfChanged = true
this._selfUsername = res.user.username ?? null this._selfUsername = res.user.username ?? null
await this.network.notifyLoggedIn(res)
await this._fetchUpdatesState() await this._fetchUpdatesState()
await this._saveStorage() await this._saveStorage()
// telegram ignores invokeWithoutUpdates for auth methods // telegram ignores invokeWithoutUpdates for auth methods
if (this._disableUpdates) this.primaryConnection._resetSession() if (this.network.params.disableUpdates) this.network.resetSessions()
else this.startUpdatesLoop() else this.startUpdatesLoop()
return new User(this, res.user) return new User(this, res.user)

View file

@ -32,15 +32,18 @@ export async function signUp(
assertTypeIs('signUp (@ auth.signUp)', res, 'auth.authorization') assertTypeIs('signUp (@ auth.signUp)', res, 'auth.authorization')
assertTypeIs('signUp (@ auth.signUp -> user)', res.user, 'user') assertTypeIs('signUp (@ auth.signUp -> user)', res.user, 'user')
this.log.prefix = `[USER ${this._userId}] `
this._userId = res.user.id this._userId = res.user.id
this.log.prefix = `[USER ${this._userId}] `
this._isBot = false this._isBot = false
this._selfChanged = true this._selfChanged = true
await this.network.notifyLoggedIn(res)
await this._fetchUpdatesState() await this._fetchUpdatesState()
await this._saveStorage() await this._saveStorage()
// telegram ignores invokeWithoutUpdates for auth methods // telegram ignores invokeWithoutUpdates for auth methods
if (this._disableUpdates) this.primaryConnection._resetSession() if (this.network.params.disableUpdates) this.network.resetSessions()
else this.startUpdatesLoop() else this.startUpdatesLoop()
return new User(this, res.user) return new User(this, res.user)

View file

@ -78,7 +78,7 @@ export async function startTest(
if (!availableDcs.find((dc) => dc.id === id)) { throw new MtArgumentError(`${phone} has invalid DC ID (${id})`) } if (!availableDcs.find((dc) => dc.id === id)) { throw new MtArgumentError(`${phone} has invalid DC ID (${id})`) }
} else { } else {
let dcId = this._primaryDc.id let dcId = this._defaultDc.id
if (params.dcId) { if (params.dcId) {
if (!availableDcs.find((dc) => dc.id === params!.dcId)) { throw new MtArgumentError(`DC ID is invalid (${dcId})`) } if (!availableDcs.find((dc) => dc.id === params!.dcId)) { throw new MtArgumentError(`DC ID is invalid (${dcId})`) }

View file

@ -155,7 +155,9 @@ export async function start(
me.isBot, me.isBot,
) )
if (!this._disableUpdates) { this.network.setIsPremium(me.isPremium)
if (!this.network.params.disableUpdates) {
this._catchUpChannels = Boolean(params.catchUp) this._catchUpChannels = Boolean(params.catchUp)
if (!params.catchUp) { if (!params.catchUp) {
@ -175,14 +177,18 @@ export async function start(
if (!(e instanceof tl.errors.AuthKeyUnregisteredError)) throw e if (!(e instanceof tl.errors.AuthKeyUnregisteredError)) throw e
} }
if (!params.phone && !params.botToken) { throw new MtArgumentError('Neither phone nor bot token were provided') } if (!params.phone && !params.botToken) {
throw new MtArgumentError('Neither phone nor bot token were provided')
}
let phone = params.phone ? await resolveMaybeDynamic(params.phone) : null let phone = params.phone ? await resolveMaybeDynamic(params.phone) : null
if (phone) { if (phone) {
phone = normalizePhoneNumber(phone) phone = normalizePhoneNumber(phone)
if (!params.code) { throw new MtArgumentError('You must pass `code` to use `phone`') } if (!params.code) {
throw new MtArgumentError('You must pass `code` to use `phone`')
}
} else { } else {
const botToken = params.botToken ? const botToken = params.botToken ?
await resolveMaybeDynamic(params.botToken) : await resolveMaybeDynamic(params.botToken) :

View file

@ -1,12 +1,8 @@
import { tl } from '@mtcute/tl' import { tl } from '@mtcute/tl'
import { TelegramClient } from '../../client' import { TelegramClient } from '../../client'
import { import { GameHighScore, InputPeerLike, PeersIndex } from '../../types'
GameHighScore, import { normalizeInlineId } from '../../utils/inline-utils'
InputPeerLike,
MtInvalidPeerTypeError,
PeersIndex,
} from '../../types'
import { normalizeToInputUser } from '../../utils/peer-utils' import { normalizeToInputUser } from '../../utils/peer-utils'
/** /**
@ -57,7 +53,7 @@ export async function getInlineGameHighScores(
messageId: string | tl.TypeInputBotInlineMessageID, messageId: string | tl.TypeInputBotInlineMessageID,
userId?: InputPeerLike, userId?: InputPeerLike,
): Promise<GameHighScore[]> { ): Promise<GameHighScore[]> {
const [id, connection] = await this._normalizeInline(messageId) const id = await normalizeInlineId(messageId)
let user: tl.TypeInputUser let user: tl.TypeInputUser
@ -73,7 +69,7 @@ export async function getInlineGameHighScores(
id, id,
userId: user, userId: user,
}, },
{ connection }, { dcId: id.dcId },
) )
const peers = PeersIndex.from(res) const peers = PeersIndex.from(res)

View file

@ -2,6 +2,7 @@ import { tl } from '@mtcute/tl'
import { TelegramClient } from '../../client' import { TelegramClient } from '../../client'
import { InputPeerLike, Message, MtInvalidPeerTypeError } from '../../types' import { InputPeerLike, Message, MtInvalidPeerTypeError } from '../../types'
import { normalizeInlineId } from '../../utils/inline-utils'
import { normalizeToInputUser } from '../../utils/peer-utils' import { normalizeToInputUser } from '../../utils/peer-utils'
/** /**
@ -86,7 +87,7 @@ export async function setInlineGameScore(
const user = normalizeToInputUser(await this.resolvePeer(userId), userId) const user = normalizeToInputUser(await this.resolvePeer(userId), userId)
const [id, connection] = await this._normalizeInline(messageId) const id = await normalizeInlineId(messageId)
await this.call( await this.call(
{ {
@ -97,6 +98,6 @@ export async function setInlineGameScore(
editMessage: !params.noEdit, editMessage: !params.noEdit,
force: params.force, force: params.force,
}, },
{ connection }, { dcId: id.dcId },
) )
} }

View file

@ -1,13 +0,0 @@
import { SessionConnection } from '@mtcute/core'
import { TelegramClient } from '../../client'
// @extension
interface FilesExtension {
_downloadConnections: Record<number, SessionConnection>
}
// @initialize
function _initializeFiles(this: TelegramClient): void {
this._downloadConnections = {}
}

View file

@ -1,3 +1,4 @@
import { ConditionVariable, ConnectionKind } from '@mtcute/core'
import { import {
fileIdToInputFileLocation, fileIdToInputFileLocation,
fileIdToInputWebFileLocation, fileIdToInputWebFileLocation,
@ -14,6 +15,12 @@ import {
} from '../../types' } from '../../types'
import { determinePartSize } from '../../utils/file-utils' import { determinePartSize } from '../../utils/file-utils'
// small files (less than 128 kb) are downloaded using the "downloadSmall" pool
// furthermore, if the file is small and is located on our main DC, it will be downloaded
// using the current main connection
const SMALL_FILE_MAX_SIZE = 131072
const REQUESTS_PER_CONNECTION = 3 // some arbitrary magic value that seems to work best
/** /**
* Download a file and return it as an iterable, which yields file contents * Download a file and return it as an iterable, which yields file contents
* in chunks of a given size. Order of the chunks is guaranteed to be * in chunks of a given size. Order of the chunks is guaranteed to be
@ -26,17 +33,7 @@ export async function* downloadAsIterable(
this: TelegramClient, this: TelegramClient,
params: FileDownloadParameters, params: FileDownloadParameters,
): AsyncIterableIterator<Buffer> { ): AsyncIterableIterator<Buffer> {
const partSizeKb = const offset = params.offset ?? 0
params.partSize ??
(params.fileSize ? determinePartSize(params.fileSize) : 64)
if (partSizeKb % 4 !== 0) {
throw new MtArgumentError(
`Invalid part size: ${partSizeKb}. Must be divisible by 4.`,
)
}
let offset = params.offset ?? 0
if (offset % 4096 !== 0) { if (offset % 4096 !== 0) {
throw new MtArgumentError( throw new MtArgumentError(
@ -76,26 +73,54 @@ export async function* downloadAsIterable(
const isWeb = tl.isAnyInputWebFileLocation(location) const isWeb = tl.isAnyInputWebFileLocation(location)
// we will receive a FileMigrateError in case this is invalid // we will receive a FileMigrateError in case this is invalid
if (!dcId) dcId = this._primaryDc.id if (!dcId) dcId = this._defaultDc.id
const partSizeKb =
params.partSize ?? (fileSize ? determinePartSize(fileSize) : 64)
if (partSizeKb % 4 !== 0) {
throw new MtArgumentError(
`Invalid part size: ${partSizeKb}. Must be divisible by 4.`,
)
}
const chunkSize = partSizeKb * 1024 const chunkSize = partSizeKb * 1024
let limit = let limitBytes = params.limit ?? fileSize ?? Infinity
params.limit ?? if (limitBytes === 0) return
// derive limit from chunk size, file size and offset
(fileSize ?
~~((fileSize + chunkSize - offset - 1) / chunkSize) :
// we will receive an error when we have reached the end anyway
Infinity)
let connection = this._downloadConnections[dcId] let numChunks =
limitBytes === Infinity ?
Infinity :
~~((limitBytes + chunkSize - offset - 1) / chunkSize)
if (!connection) { let nextChunkIdx = 0
connection = await this.createAdditionalConnection(dcId) let nextWorkerChunkIdx = 0
this._downloadConnections[dcId] = connection const nextChunkCv = new ConditionVariable()
const buffer: Record<number, Buffer> = {}
const isSmall = fileSize && fileSize <= SMALL_FILE_MAX_SIZE
let connectionKind: ConnectionKind
if (isSmall) {
connectionKind =
dcId === this.network.getPrimaryDcId() ? 'main' : 'downloadSmall'
} else {
connectionKind = 'download'
} }
const poolSize = this.network.getPoolSize(connectionKind, dcId)
const requestCurrent = async (): Promise<Buffer> => { this.log.debug(
'Downloading file of size %d from dc %d using %s connection pool (pool size: %d)',
limitBytes,
dcId,
connectionKind,
poolSize,
)
const downloadChunk = async (
chunk = nextWorkerChunkIdx++,
): Promise<void> => {
let result: let result:
| tl.RpcCallReturn['upload.getFile'] | tl.RpcCallReturn['upload.getFile']
| tl.RpcCallReturn['upload.getWebFile'] | tl.RpcCallReturn['upload.getWebFile']
@ -106,22 +131,17 @@ export async function* downloadAsIterable(
_: isWeb ? 'upload.getWebFile' : 'upload.getFile', _: isWeb ? 'upload.getWebFile' : 'upload.getFile',
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
location: location as any, location: location as any,
offset, offset: chunkSize * chunk,
limit: chunkSize, limit: chunkSize,
}, },
{ connection }, { dcId, kind: connectionKind },
) )
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) { } catch (e: any) {
if (e.constructor === tl.errors.FileMigrateXError) { if (e.constructor === tl.errors.FileMigrateXError) {
connection = this._downloadConnections[e.new_dc] dcId = e.new_dc
if (!connection) { return downloadChunk(chunk)
connection = await this.createAdditionalConnection(e.new_dc)
this._downloadConnections[e.new_dc] = connection
}
return requestCurrent()
} else if (e.constructor === tl.errors.FilerefUpgradeNeededError) { } else if (e.constructor === tl.errors.FilerefUpgradeNeededError) {
// todo: implement someday // todo: implement someday
// see: https://github.com/LonamiWebs/Telethon/blob/0e8bd8248cc649637b7c392616887c50986427a0/telethon/client/downloads.py#L99 // see: https://github.com/LonamiWebs/Telethon/blob/0e8bd8248cc649637b7c392616887c50986427a0/telethon/client/downloads.py#L99
@ -141,25 +161,65 @@ export async function* downloadAsIterable(
if ( if (
result._ === 'upload.webFile' && result._ === 'upload.webFile' &&
result.size && result.size &&
limit === Infinity limitBytes === Infinity
) { ) {
limit = result.size limitBytes = result.size
numChunks = ~~((limitBytes + chunkSize - offset - 1) / chunkSize)
} }
return result.bytes buffer[chunk] = result.bytes
if (chunk === nextChunkIdx) {
nextChunkCv.notify()
}
if (
nextWorkerChunkIdx < numChunks &&
result.bytes.length === chunkSize
) {
return downloadChunk()
}
} }
for (let i = 0; i < limit; i++) { let error: unknown = undefined
const buf = await requestCurrent() Promise.all(
Array.from(
{ length: Math.min(poolSize * REQUESTS_PER_CONNECTION, numChunks) },
downloadChunk,
),
)
.catch((e) => {
this.log.debug('download workers errored: %s', e.message)
error = e
nextChunkCv.notify()
})
.then(() => {
this.log.debug('download workers finished')
})
if (buf.length === 0) { let position = offset
// we've reached the end
return while (position < limitBytes) {
await nextChunkCv.wait()
if (error) throw error
while (nextChunkIdx in buffer) {
const buf = buffer[nextChunkIdx]
delete buffer[nextChunkIdx]
position += buf.length
params.progressCallback?.(position, limitBytes)
yield buf
nextChunkIdx++
if (buf.length < chunkSize) {
// we received the last chunk
return
}
} }
yield buf
offset += chunkSize
params.progressCallback?.(offset, limit)
} }
} }

View file

@ -3,7 +3,7 @@ import { fromBuffer as fileTypeFromBuffer } from 'file-type'
import type { ReadStream } from 'fs' import type { ReadStream } from 'fs'
import { Readable } from 'stream' import { Readable } from 'stream'
import { randomLong } from '@mtcute/core' import { AsyncLock, randomLong } from '@mtcute/core'
import { tl } from '@mtcute/tl' import { tl } from '@mtcute/tl'
import { TelegramClient } from '../../client' import { TelegramClient } from '../../client'
@ -13,7 +13,6 @@ import {
bufferToStream, bufferToStream,
convertWebStreamToNodeReadable, convertWebStreamToNodeReadable,
readBytesFromStream, readBytesFromStream,
readStreamUntilEnd,
} from '../../utils/stream-utils' } from '../../utils/stream-utils'
let fs: any = null let fs: any = null
@ -29,6 +28,14 @@ const OVERRIDE_MIME: Record<string, string> = {
'audio/opus': 'audio/ogg', 'audio/opus': 'audio/ogg',
} }
// small files (less than 128 kb) are uploaded using the current connection and not the "upload" pool
const SMALL_FILE_MAX_SIZE = 131072
const BIG_FILE_MIN_SIZE = 10485760 // files >10 MB are considered "big"
const DEFAULT_FILE_NAME = 'unnamed'
const REQUESTS_PER_CONNECTION = 3
const MAX_PART_COUNT = 4000 // 512 kb * 4000 = 2000 MiB
const MAX_PART_COUNT_PREMIUM = 8000 // 512 kb * 8000 = 4000 MiB
/** /**
* Upload a file to Telegram servers, without actually * Upload a file to Telegram servers, without actually
* sending a message anywhere. Useful when an `InputFile` is required. * sending a message anywhere. Useful when an `InputFile` is required.
@ -60,14 +67,15 @@ export async function uploadFile(
/** /**
* Total file size. Automatically inferred for Buffer, File and local files. * Total file size. Automatically inferred for Buffer, File and local files.
*
* When using with streams, if `fileSize` is not passed, the entire file is
* first loaded into memory to determine file size, and used as a Buffer later.
* This might be a major performance bottleneck, so be sure to provide file size
* when using streams and file size is known (which often is the case).
*/ */
fileSize?: number fileSize?: number
/**
* If the file size is unknown, you can provide an estimate,
* which will be used to determine appropriate part size.
*/
estimatedSize?: number
/** /**
* File MIME type. By default is automatically inferred from magic number * File MIME type. By default is automatically inferred from magic number
* If MIME can't be inferred, it defaults to `application/octet-stream` * If MIME can't be inferred, it defaults to `application/octet-stream`
@ -82,11 +90,16 @@ export async function uploadFile(
*/ */
partSize?: number partSize?: number
/**
* Number of parts to be sent in parallel per connection.
*/
requestsPerConnection?: number
/** /**
* Function that will be called after some part has been uploaded. * Function that will be called after some part has been uploaded.
* *
* @param uploaded Number of bytes already uploaded * @param uploaded Number of bytes already uploaded
* @param total Total file size * @param total Total file size, if known
*/ */
progressCallback?: (uploaded: number, total: number) => void progressCallback?: (uploaded: number, total: number) => void
}, },
@ -94,7 +107,7 @@ export async function uploadFile(
// normalize params // normalize params
let file = params.file let file = params.file
let fileSize = -1 // unknown let fileSize = -1 // unknown
let fileName = 'unnamed' let fileName = DEFAULT_FILE_NAME
let fileMime = params.fileMime let fileMime = params.fileMime
if (Buffer.isBuffer(file)) { if (Buffer.isBuffer(file)) {
@ -162,12 +175,12 @@ export async function uploadFile(
} }
} }
if (fileName === 'unnamed') { if (fileName === DEFAULT_FILE_NAME) {
// try to infer from url // try to infer from url
const url = new URL(file.url) const url = new URL(file.url)
const name = url.pathname.split('/').pop() const name = url.pathname.split('/').pop()
if (name && name.indexOf('.') > -1) { if (name && name.includes('.')) {
fileName = name fileName = name
} }
} }
@ -192,42 +205,88 @@ export async function uploadFile(
// set file size if not automatically inferred // set file size if not automatically inferred
if (fileSize === -1 && params.fileSize) fileSize = params.fileSize if (fileSize === -1 && params.fileSize) fileSize = params.fileSize
if (fileSize === -1) { let partSizeKb = params.partSize
// load the entire stream into memory
const buffer = await readStreamUntilEnd(file as Readable) if (!partSizeKb) {
fileSize = buffer.length if (fileSize === -1) {
file = bufferToStream(buffer) partSizeKb = params.estimatedSize ?
determinePartSize(params.estimatedSize) :
64
} else {
partSizeKb = determinePartSize(fileSize)
}
} }
if (!(file instanceof Readable)) { if (!(file instanceof Readable)) {
throw new MtArgumentError('Could not convert input `file` to stream!') throw new MtArgumentError('Could not convert input `file` to stream!')
} }
const partSizeKb = params.partSize ?? determinePartSize(fileSize)
if (partSizeKb > 512) { if (partSizeKb > 512) {
throw new MtArgumentError(`Invalid part size: ${partSizeKb}KB`) throw new MtArgumentError(`Invalid part size: ${partSizeKb}KB`)
} }
const partSize = partSizeKb * 1024 const partSize = partSizeKb * 1024
const isBig = fileSize > 10485760 // 10 MB let partCount =
const hash = this._crypto.createMd5() fileSize === -1 ? -1 : ~~((fileSize + partSize - 1) / partSize)
const maxPartCount = this.network.params.isPremium ?
MAX_PART_COUNT_PREMIUM :
MAX_PART_COUNT
if (partCount > maxPartCount) {
throw new MtArgumentError(
`File is too large (max ${maxPartCount} parts, got ${partCount})`,
)
}
const isBig = fileSize === -1 || fileSize > BIG_FILE_MIN_SIZE
const isSmall = fileSize !== -1 && fileSize < SMALL_FILE_MAX_SIZE
const connectionKind = isSmall ? 'main' : 'upload'
const connectionPoolSize = Math.min(
this.network.getPoolSize(connectionKind),
partCount,
)
const requestsPerConnection =
params.requestsPerConnection ?? REQUESTS_PER_CONNECTION
const partCount = ~~((fileSize + partSize - 1) / partSize)
this.log.debug( this.log.debug(
'uploading %d bytes file in %d chunks, each %d bytes', 'uploading %d bytes file in %d chunks, each %d bytes in %s connection pool of size %d',
fileSize, fileSize,
partCount, partCount,
partSize, partSize,
connectionKind,
connectionPoolSize,
) )
// why is the file id generated by the client? // why is the file id generated by the client?
// isn't the server supposed to generate it and handle collisions? // isn't the server supposed to generate it and handle collisions?
const fileId = randomLong() const fileId = randomLong()
let pos = 0 const stream = file
for (let idx = 0; idx < partCount; idx++) { let pos = 0
const part = await readBytesFromStream(file, partSize) let idx = 0
const lock = new AsyncLock()
const uploadNextPart = async (): Promise<void> => {
const thisIdx = idx++
let part
try {
await lock.acquire()
part = await readBytesFromStream(stream, partSize)
} finally {
lock.release()
}
if (fileSize === -1 && stream.readableEnded) {
fileSize = pos + (part?.length ?? 0)
partCount = ~~((fileSize + partSize - 1) / partSize)
this.log.debug(
'readable ended, file size = %d, part count = %d',
fileSize,
partCount,
)
}
if (!part) { if (!part) {
throw new MtArgumentError( throw new MtArgumentError(
@ -236,15 +295,15 @@ export async function uploadFile(
} }
if (!Buffer.isBuffer(part)) { if (!Buffer.isBuffer(part)) {
throw new MtArgumentError(`Part ${idx} was not a Buffer!`) throw new MtArgumentError(`Part ${thisIdx} was not a Buffer!`)
} }
if (part.length > partSize) { if (part.length > partSize) {
throw new MtArgumentError( throw new MtArgumentError(
`Part ${idx} had invalid size (expected ${partSize}, got ${part.length})`, `Part ${thisIdx} had invalid size (expected ${partSize}, got ${part.length})`,
) )
} }
if (idx === 0 && fileMime === undefined) { if (thisIdx === 0 && fileMime === undefined) {
const fileType = await fileTypeFromBuffer(part) const fileType = await fileTypeFromBuffer(part)
fileMime = fileType?.mime fileMime = fileType?.mime
@ -260,37 +319,43 @@ export async function uploadFile(
} }
} }
if (!isBig) {
// why md5 only small files?
// big files have more chance of corruption, but whatever
// also isn't integrity guaranteed by mtproto?
await hash.update(part)
}
pos += part.length
// why // why
const request = isBig ? const request = isBig ?
({ ({
_: 'upload.saveBigFilePart', _: 'upload.saveBigFilePart',
fileId, fileId,
filePart: idx, filePart: thisIdx,
fileTotalParts: partCount, fileTotalParts: partCount,
bytes: part, bytes: part,
} as tl.upload.RawSaveBigFilePartRequest) : } satisfies tl.upload.RawSaveBigFilePartRequest) :
({ ({
_: 'upload.saveFilePart', _: 'upload.saveFilePart',
fileId, fileId,
filePart: idx, filePart: thisIdx,
bytes: part, bytes: part,
} as tl.upload.RawSaveFilePartRequest) } satisfies tl.upload.RawSaveFilePartRequest)
const result = await this.call(request) const result = await this.call(request, { kind: connectionKind })
if (!result) throw new Error(`Failed to upload part ${idx}`) if (!result) throw new Error(`Failed to upload part ${idx}`)
pos += part.length
params.progressCallback?.(pos, fileSize) params.progressCallback?.(pos, fileSize)
if (idx === partCount) return
return uploadNextPart()
} }
await Promise.all(
Array.from(
{
length: connectionPoolSize * requestsPerConnection,
},
uploadNextPart,
),
)
let inputFile: tl.TypeInputFile let inputFile: tl.TypeInputFile
if (isBig) { if (isBig) {
@ -306,7 +371,7 @@ export async function uploadFile(
id: fileId, id: fileId,
parts: partCount, parts: partCount,
name: fileName, name: fileName,
md5Checksum: (await hash.digest()).toString('hex'), md5Checksum: '', // tdlib doesn't do this, why should we?
} }
} }

View file

@ -7,6 +7,7 @@ import {
InputMediaLike, InputMediaLike,
ReplyMarkup, ReplyMarkup,
} from '../../types' } from '../../types'
import { normalizeInlineId } from '../../utils/inline-utils'
/** /**
* Edit sent inline message text, media and reply markup. * Edit sent inline message text, media and reply markup.
@ -75,7 +76,7 @@ export async function editInlineMessage(
let entities: tl.TypeMessageEntity[] | undefined let entities: tl.TypeMessageEntity[] | undefined
let media: tl.TypeInputMedia | undefined = undefined let media: tl.TypeInputMedia | undefined = undefined
const [id, connection] = await this._normalizeInline(messageId) const id = await normalizeInlineId(messageId)
if (params.media) { if (params.media) {
media = await this._normalizeInputMedia(params.media, params, true) media = await this._normalizeInputMedia(params.media, params, true)
@ -111,7 +112,7 @@ export async function editInlineMessage(
entities, entities,
media, media,
}, },
{ connection }, { dcId: id.dcId },
) )
return return

View file

@ -1,37 +0,0 @@
import { SessionConnection } from '@mtcute/core'
import { tl } from '@mtcute/tl'
import { TelegramClient } from '../../client'
import { parseInlineMessageId } from '../../utils/inline-utils'
// @extension
interface InlineExtension {
_connectionsForInline: Record<number, SessionConnection>
}
// @initialize
function _initializeInline(this: TelegramClient) {
this._connectionsForInline = {}
}
/** @internal */
export async function _normalizeInline(
this: TelegramClient,
id: string | tl.TypeInputBotInlineMessageID,
): Promise<[tl.TypeInputBotInlineMessageID, SessionConnection]> {
if (typeof id === 'string') {
id = parseInlineMessageId(id)
}
let connection = this.primaryConnection
if (id.dcId !== connection.params.dc.id) {
if (!(id.dcId in this._connectionsForInline)) {
this._connectionsForInline[id.dcId] =
await this.createAdditionalConnection(id.dcId)
}
connection = this._connectionsForInline[id.dcId]
}
return [id, connection]
}

View file

@ -558,11 +558,15 @@ async function _fetchPeersForShort(
if ( if (
msg.replyTo._ === 'messageReplyHeader' && msg.replyTo._ === 'messageReplyHeader' &&
!(await fetchPeer(msg.replyTo.replyToPeerId)) !(await fetchPeer(msg.replyTo.replyToPeerId))
) { return null } ) {
return null
}
if ( if (
msg.replyTo._ === 'messageReplyStoryHeader' && msg.replyTo._ === 'messageReplyStoryHeader' &&
!(await fetchPeer(msg.replyTo.userId)) !(await fetchPeer(msg.replyTo.userId))
) { return null } ) {
return null
}
} }
if (msg._ !== 'messageService') { if (msg._ !== 'messageService') {
@ -791,7 +795,7 @@ async function _fetchChannelDifference(
if (!_pts) _pts = fallbackPts if (!_pts) _pts = fallbackPts
if (!_pts) { if (!_pts) {
this._updsLog.warn( this._updsLog.debug(
'fetchChannelDifference failed for channel %d: base pts not available', 'fetchChannelDifference failed for channel %d: base pts not available',
channelId, channelId,
) )
@ -956,19 +960,13 @@ async function _fetchDifference(
this: TelegramClient, this: TelegramClient,
requestedDiff: Record<number, Promise<void>>, requestedDiff: Record<number, Promise<void>>,
): Promise<void> { ): Promise<void> {
let isFirst = true
for (;;) { for (;;) {
const diff = await this.call( const diff = await this.call({
{ _: 'updates.getDifference',
_: 'updates.getDifference', pts: this._pts!,
pts: this._pts!, date: this._date!,
date: this._date!, qts: this._qts!,
qts: this._qts!, })
},
// { flush: !isFirst }
)
isFirst = false
switch (diff._) { switch (diff._) {
case 'updates.differenceEmpty': case 'updates.differenceEmpty':
@ -1210,16 +1208,21 @@ async function _onUpdate(
case 'dummyUpdate': case 'dummyUpdate':
// we just needed to apply new pts values // we just needed to apply new pts values
return return
case 'updateDcOptions': case 'updateDcOptions': {
if (!this._config) { const config = this.network.config.getNow()
this._config = await this.call({ _: 'help.getConfig' })
if (config) {
this.network.config.setConfig({
...config,
dcOptions: upd.dcOptions,
})
} else { } else {
(this._config as tl.Mutable<tl.TypeConfig>).dcOptions = await this.network.config.update(true)
upd.dcOptions
} }
break break
}
case 'updateConfig': case 'updateConfig':
this._config = await this.call({ _: 'help.getConfig' }) await this.network.config.update(true)
break break
case 'updateUserName': case 'updateUserName':
if (upd.userId === this._userId) { if (upd.userId === this._userId) {
@ -1753,10 +1756,12 @@ export async function _updatesLoop(this: TelegramClient): Promise<void> {
log.debug( log.debug(
'waiting for %d pending diffs before processing unordered: %j', 'waiting for %d pending diffs before processing unordered: %j',
pendingDiffs.length, pendingDiffs.length,
Object.keys(requestedDiff), // fixme Object.keys(requestedDiff),
) )
// this.primaryConnection._flushSendQueue() // fixme // is this necessary?
// this.primaryConnection._flushSendQueue()
await Promise.all(pendingDiffs) await Promise.all(pendingDiffs)
// diff results may as well contain new diffs to be requested // diff results may as well contain new diffs to be requested
@ -1764,7 +1769,7 @@ export async function _updatesLoop(this: TelegramClient): Promise<void> {
log.debug( log.debug(
'pending diffs awaited, new diffs requested: %d (%j)', 'pending diffs awaited, new diffs requested: %d (%j)',
pendingDiffs.length, pendingDiffs.length,
Object.keys(requestedDiff), // fixme Object.keys(requestedDiff),
) )
} }
@ -1784,11 +1789,12 @@ export async function _updatesLoop(this: TelegramClient): Promise<void> {
log.debug( log.debug(
'waiting for %d pending diffs after processing unordered: %j', 'waiting for %d pending diffs after processing unordered: %j',
pendingDiffs.length, pendingDiffs.length,
Object.keys(requestedDiff), // fixme Object.keys(requestedDiff),
) )
// fixme // is this necessary?
// this.primaryConnection._flushSendQueue() // this.primaryConnection._flushSendQueue()
await Promise.all(pendingDiffs) await Promise.all(pendingDiffs)
// diff results may as well contain new diffs to be requested // diff results may as well contain new diffs to be requested
@ -1796,7 +1802,7 @@ export async function _updatesLoop(this: TelegramClient): Promise<void> {
log.debug( log.debug(
'pending diffs awaited, new diffs requested: %d (%j)', 'pending diffs awaited, new diffs requested: %d (%j)',
pendingDiffs.length, pendingDiffs.length,
Object.keys(requestedDiff), // fixme Object.keys(requestedDiff),
) )
} }
@ -1815,5 +1821,4 @@ export async function _updatesLoop(this: TelegramClient): Promise<void> {
export function _keepAliveAction(this: TelegramClient): void { export function _keepAliveAction(this: TelegramClient): void {
this._updsLog.debug('no updates for >15 minutes, catching up') this._updsLog.debug('no updates for >15 minutes, catching up')
this._handleUpdate({ _: 'updatesTooLong' }) this._handleUpdate({ _: 'updatesTooLong' })
// this.catchUp().catch((err) => this._emitError(err))
} }

View file

@ -133,8 +133,7 @@ export class Conversation {
const pending = this.client['_pendingConversations'] const pending = this.client['_pendingConversations']
const idx = const idx = pending[this._chatId].indexOf(this)
pending[this._chatId].indexOf(this)
if (idx > -1) { if (idx > -1) {
// just in case // just in case
@ -143,8 +142,7 @@ export class Conversation {
if (!pending[this._chatId].length) { if (!pending[this._chatId].length) {
delete pending[this._chatId] delete pending[this._chatId]
} }
this.client['_hasConversations'] = this.client['_hasConversations'] = Object.keys(pending).length > 0
Object.keys(pending).length > 0
// reset pending status // reset pending status
this._queuedNewMessage.clear() this._queuedNewMessage.clear()
@ -279,6 +277,7 @@ export class Conversation {
if (timeout !== null) { if (timeout !== null) {
timer = setTimeout(() => { timer = setTimeout(() => {
console.log('timed out')
promise.reject(new tl.errors.TimeoutError()) promise.reject(new tl.errors.TimeoutError())
this._queuedNewMessage.removeBy((it) => it.promise === promise) this._queuedNewMessage.removeBy((it) => it.promise === promise)
}, timeout) }, timeout)
@ -537,7 +536,9 @@ export class Conversation {
it.promise.resolve(msg) it.promise.resolve(msg)
delete this._pendingEditMessage[msg.id] delete this._pendingEditMessage[msg.id]
} }
})().catch((e) => this.client['_emitError'](e)) })().catch((e) => {
this.client['_emitError'](e)
})
} }
private _onHistoryRead(upd: HistoryReadUpdate) { private _onHistoryRead(upd: HistoryReadUpdate) {

View file

@ -4,6 +4,7 @@ import { tl } from '@mtcute/tl'
import { TelegramClient } from '../../client' import { TelegramClient } from '../../client'
import { makeInspectable } from '../utils' import { makeInspectable } from '../utils'
import { FileDownloadParameters } from './utils'
/** /**
* Information about file location. * Information about file location.
@ -50,48 +51,61 @@ export class FileLocation {
* in chunks of a given size. Order of the chunks is guaranteed to be * in chunks of a given size. Order of the chunks is guaranteed to be
* consecutive. * consecutive.
* *
* Shorthand for `client.downloadAsIterable({ location: this })` * @param params Download parameters
*
* @link TelegramClient.downloadAsIterable * @link TelegramClient.downloadAsIterable
*/ */
downloadIterable(): AsyncIterableIterator<Buffer> { downloadIterable(
return this.client.downloadAsIterable({ location: this }) params?: Partial<FileDownloadParameters>,
): AsyncIterableIterator<Buffer> {
return this.client.downloadAsIterable({
...params,
location: this,
})
} }
/** /**
* Download a file and return it as a Node readable stream, * Download a file and return it as a Node readable stream,
* streaming file contents. * streaming file contents.
* *
* Shorthand for `client.downloadAsStream({ location: this })`
*
* @link TelegramClient.downloadAsStream * @link TelegramClient.downloadAsStream
*/ */
downloadStream(): Readable { downloadStream(params?: Partial<FileDownloadParameters>): Readable {
return this.client.downloadAsStream({ location: this }) return this.client.downloadAsStream({
...params,
location: this,
})
} }
/** /**
* Download a file and return its contents as a Buffer. * Download a file and return its contents as a Buffer.
* *
* Shorthand for `client.downloadAsBuffer({ location: this })` * @param params File download parameters
*
* @link TelegramClient.downloadAsBuffer * @link TelegramClient.downloadAsBuffer
*/ */
downloadBuffer(): Promise<Buffer> { downloadBuffer(params?: Partial<FileDownloadParameters>): Promise<Buffer> {
return this.client.downloadAsBuffer({ location: this }) return this.client.downloadAsBuffer({
...params,
location: this,
})
} }
/** /**
* Download a remote file to a local file (only for NodeJS). * Download a remote file to a local file (only for NodeJS).
* Promise will resolve once the download is complete. * Promise will resolve once the download is complete.
* *
* Shorthand for `client.downloadToFile(filename, { location: this })`
*
* @param filename Local file name * @param filename Local file name
* @param params File download parameters
* @link TelegramClient.downloadToFile * @link TelegramClient.downloadToFile
*/ */
downloadToFile(filename: string): Promise<void> { downloadToFile(
return this.client.downloadToFile(filename, { location: this }) filename: string,
params?: Partial<FileDownloadParameters>,
): Promise<void> {
return this.client.downloadToFile(filename, {
...params,
location: this,
fileSize: this.fileSize,
})
} }
} }

View file

@ -97,7 +97,7 @@ export interface FileDownloadParameters {
offset?: number offset?: number
/** /**
* Number of chunks (!) of that given size that will be downloaded. * Number of bytes to be downloaded.
* By default, downloads the entire file * By default, downloads the entire file
*/ */
limit?: number limit?: number

View file

@ -1,4 +1,4 @@
import { MustEqual } from '@mtcute/core' import { MustEqual, RpcCallOptions } from '@mtcute/core'
import { tl } from '@mtcute/tl' import { tl } from '@mtcute/tl'
import { TelegramClient } from '../../client' import { TelegramClient } from '../../client'
@ -31,9 +31,7 @@ export class TakeoutSession {
*/ */
async call<T extends tl.RpcMethod>( async call<T extends tl.RpcMethod>(
message: MustEqual<T, tl.RpcMethod>, message: MustEqual<T, tl.RpcMethod>,
params?: { params?: RpcCallOptions,
throwFlood: boolean
},
): Promise<tl.RpcCallReturn[T['_']]> { ): Promise<tl.RpcCallReturn[T['_']]> {
return this.client.call( return this.client.call(
{ {

View file

@ -5,10 +5,9 @@ import { MtArgumentError } from '../types'
* for upload/download operations. * for upload/download operations.
*/ */
export function determinePartSize(fileSize: number): number { export function determinePartSize(fileSize: number): number {
if (fileSize <= 104857600) return 128 // 100 MB if (fileSize <= 262078465) return 128 // 200 MB
if (fileSize <= 786432000) return 256 // 750 MB if (fileSize <= 786432000) return 256 // 750 MB
if (fileSize <= 2097152000) return 512 // 2000 MB if (fileSize <= 2097152000) return 512 // 2000 MB
if (fileSize <= 4194304000) return 1024 // 4000 MB
throw new MtArgumentError('File is too large') throw new MtArgumentError('File is too large')
} }

View file

@ -66,3 +66,11 @@ export function encodeInlineMessageId(
return encodeUrlSafeBase64(writer.result()) return encodeUrlSafeBase64(writer.result())
} }
export function normalizeInlineId(id: string | tl.TypeInputBotInlineMessageID) {
if (typeof id === 'string') {
return parseInlineMessageId(id)
}
return id
}

View file

@ -35,7 +35,9 @@ class NodeReadable extends Readable {
return return
} }
if (this.push(res.value)) { if (this.push(res.value)) {
return doRead() doRead()
return
} }
this._reading = false this._reading = false
this._reader.releaseLock() this._reader.releaseLock()
@ -49,7 +51,9 @@ class NodeReadable extends Readable {
const promise = new Promise<void>((resolve) => { const promise = new Promise<void>((resolve) => {
this._doneReading = resolve this._doneReading = resolve
}) })
promise.then(() => this._handleDestroy(err, callback)) promise.then(() => {
this._handleDestroy(err, callback)
})
} else { } else {
this._handleDestroy(err, callback) this._handleDestroy(err, callback)
} }
@ -71,26 +75,6 @@ export function convertWebStreamToNodeReadable(
return new NodeReadable(webStream, opts) return new NodeReadable(webStream, opts)
} }
export async function readStreamUntilEnd(stream: Readable): Promise<Buffer> {
const chunks = []
let length = 0
while (stream.readable) {
const c = await stream.read()
if (c === null) break
length += c.length
if (length > 2097152000) {
throw new Error('File is too big')
}
chunks.push(c)
}
return Buffer.concat(chunks)
}
export function bufferToStream(buf: Buffer): Readable { export function bufferToStream(buf: Buffer): Readable {
return new Readable({ return new Readable({
read() { read() {
@ -109,15 +93,17 @@ export async function readBytesFromStream(
let res = stream.read(size) let res = stream.read(size)
if (!res) { if (!res) {
return new Promise((resolve) => { return new Promise((resolve, reject) => {
stream.on('readable', function handler() { stream.on('readable', function handler() {
res = stream.read(size) res = stream.read(size)
if (res) { if (res) {
stream.off('readable', handler) stream.off('readable', handler)
stream.off('error', reject)
resolve(res) resolve(res)
} }
}) })
stream.on('error', reject)
}) })
} }

View file

@ -1,21 +0,0 @@
import { expect } from 'chai'
import { describe, it } from 'mocha'
import { Readable } from 'stream'
import { readStreamUntilEnd } from '../src/utils/stream-utils'
describe('readStreamUntilEnd', () => {
it('should read stream until end', async () => {
const stream = new Readable({
read() {
this.push(Buffer.from('aaeeff', 'hex'))
this.push(Buffer.from('ff33ee', 'hex'))
this.push(null)
},
})
expect((await readStreamUntilEnd(stream)).toString('hex')).eq(
'aaeeffff33ee',
)
})
})

View file

@ -8,12 +8,16 @@ import defaultWriterMap from '@mtcute/tl/binary/writer'
import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime' import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
import { import {
defaultReconnectionStrategy,
defaultTransportFactory,
ReconnectionStrategy, ReconnectionStrategy,
SessionConnection, SessionConnection,
TransportFactory, TransportFactory,
} from './network' } from './network'
import { ConfigManager } from './network/config-manager'
import {
NetworkManager,
NetworkManagerExtraParams,
RpcCallOptions,
} from './network/network-manager'
import { PersistentConnectionParams } from './network/persistent-connection' import { PersistentConnectionParams } from './network/persistent-connection'
import { ITelegramStorage, MemoryStorage } from './storage' import { ITelegramStorage, MemoryStorage } from './storage'
import { MustEqual } from './types' import { MustEqual } from './types'
@ -29,11 +33,10 @@ import {
getAllPeersFrom, getAllPeersFrom,
ICryptoProvider, ICryptoProvider,
LogManager, LogManager,
sleep, readStringSession,
toggleChannelIdMark, toggleChannelIdMark,
writeStringSession,
} from './utils' } from './utils'
import { addPublicKey } from './utils/crypto/keys'
import { readStringSession, writeStringSession } from './utils/string-session'
export interface BaseTelegramClientOptions { export interface BaseTelegramClientOptions {
/** /**
@ -74,15 +77,15 @@ export interface BaseTelegramClientOptions {
* When session already contains primary DC, this parameter is ignored. * When session already contains primary DC, this parameter is ignored.
* Defaults to Production DC 2. * Defaults to Production DC 2.
*/ */
primaryDc?: tl.RawDcOption defaultDc?: tl.RawDcOption
/** /**
* Whether to connect to test servers. * Whether to connect to test servers.
* *
* If passed, {@link primaryDc} defaults to Test DC 2. * If passed, {@link defaultDc} defaults to Test DC 2.
* *
* **Must** be passed if using test servers, even if * **Must** be passed if using test servers, even if
* you passed custom {@link primaryDc} * you passed custom {@link defaultDc}
*/ */
testMode?: boolean testMode?: boolean
@ -123,7 +126,7 @@ export interface BaseTelegramClientOptions {
* *
* @default 5 * @default 5
*/ */
rpcRetryCount?: number maxRetryCount?: number
/** /**
* If true, every single API call will be wrapped with `tl.invokeWithoutUpdates`, * If true, every single API call will be wrapped with `tl.invokeWithoutUpdates`,
@ -152,6 +155,11 @@ export interface BaseTelegramClientOptions {
*/ */
niceStacks?: boolean niceStacks?: boolean
/**
* Extra parameters for {@link NetworkManager}
*/
network?: NetworkManagerExtraParams
/** /**
* **EXPERT USE ONLY!** * **EXPERT USE ONLY!**
* *
@ -178,93 +186,52 @@ export interface BaseTelegramClientOptions {
export class BaseTelegramClient extends EventEmitter { export class BaseTelegramClient extends EventEmitter {
/** /**
* `initConnection` params taken from {@link BaseTelegramClient.Options.initConnectionOptions}. * Crypto provider taken from {@link BaseTelegramClientOptions.crypto}
*/
protected readonly _initConnectionParams: tl.RawInitConnectionRequest
/**
* Crypto provider taken from {@link BaseTelegramClient.Options.crypto}
*/ */
protected readonly _crypto: ICryptoProvider protected readonly _crypto: ICryptoProvider
/** /**
* Transport factory taken from {@link BaseTelegramClient.Options.transport} * Telegram storage taken from {@link BaseTelegramClientOptions.storage}
*/
protected readonly _transportFactory: TransportFactory
/**
* Telegram storage taken from {@link BaseTelegramClient.Options.storage}
*/ */
readonly storage: ITelegramStorage readonly storage: ITelegramStorage
/** /**
* API hash taken from {@link BaseTelegramClient.Options.apiHash} * API hash taken from {@link BaseTelegramClientOptions.apiHash}
*/ */
protected readonly _apiHash: string protected readonly _apiHash: string
/** /**
* "Use IPv6" taken from {@link BaseTelegramClient.Options.useIpv6} * "Use IPv6" taken from {@link BaseTelegramClientOptions.useIpv6}
*/ */
protected readonly _useIpv6: boolean protected readonly _useIpv6: boolean
/** /**
* "Test mode" taken from {@link BaseTelegramClient.Options.testMode} * "Test mode" taken from {@link BaseTelegramClientOptions.testMode}
*/ */
protected readonly _testMode: boolean protected readonly _testMode: boolean
/** /**
* Reconnection strategy taken from {@link BaseTelegramClient.Options.reconnectionStrategy} * Primary DC taken from {@link BaseTelegramClientOptions.defaultDc},
*/
protected readonly _reconnectionStrategy: ReconnectionStrategy<PersistentConnectionParams>
/**
* Flood sleep threshold taken from {@link BaseTelegramClient.Options.floodSleepThreshold}
*/
protected readonly _floodSleepThreshold: number
/**
* RPC retry count taken from {@link BaseTelegramClient.Options.rpcRetryCount}
*/
protected readonly _rpcRetryCount: number
/**
* "Disable updates" taken from {@link BaseTelegramClient.Options.disableUpdates}
*/
protected readonly _disableUpdates: boolean
/**
* Primary DC taken from {@link BaseTelegramClient.Options.primaryDc},
* loaded from session or changed by other means (like redirecting). * loaded from session or changed by other means (like redirecting).
*/ */
protected _primaryDc: tl.RawDcOption protected _defaultDc: tl.RawDcOption
private _niceStacks: boolean private _niceStacks: boolean
readonly _layer: number readonly _layer: number
readonly _readerMap: TlReaderMap readonly _readerMap: TlReaderMap
readonly _writerMap: TlWriterMap readonly _writerMap: TlWriterMap
private _keepAliveInterval?: NodeJS.Timeout
protected _lastUpdateTime = 0 protected _lastUpdateTime = 0
private _floodWaitedRequests: Record<string, number> = {}
protected _config?: tl.RawConfig protected _config = new ConfigManager(() =>
protected _cdnConfig?: tl.RawCdnConfig this.call({ _: 'help.getConfig' }),
)
private _additionalConnections: SessionConnection[] = []
// not really connected, but rather "connect() was called" // not really connected, but rather "connect() was called"
private _connected: ControllablePromise<void> | boolean = false private _connected: ControllablePromise<void> | boolean = false
private _onError?: (err: unknown, connection?: SessionConnection) => void private _onError?: (err: unknown, connection?: SessionConnection) => void
/**
* The primary {@link SessionConnection} that is used for
* most of the communication with Telegram.
*
* Methods for downloading/uploading files may create additional connections as needed.
*/
primaryConnection!: SessionConnection
private _importFrom?: string private _importFrom?: string
private _importForce?: boolean private _importForce?: boolean
@ -278,7 +245,8 @@ export class BaseTelegramClient extends EventEmitter {
// eslint-disable-next-line @typescript-eslint/no-unused-vars // eslint-disable-next-line @typescript-eslint/no-unused-vars
protected _handleUpdate(update: tl.TypeUpdates): void {} protected _handleUpdate(update: tl.TypeUpdates): void {}
readonly log = new LogManager() readonly log = new LogManager('client')
readonly network: NetworkManager
constructor(opts: BaseTelegramClientOptions) { constructor(opts: BaseTelegramClientOptions) {
super() super()
@ -290,14 +258,13 @@ export class BaseTelegramClient extends EventEmitter {
throw new Error('apiId must be a number or a numeric string!') throw new Error('apiId must be a number or a numeric string!')
} }
this._transportFactory = opts.transport ?? defaultTransportFactory
this._crypto = (opts.crypto ?? defaultCryptoProviderFactory)() this._crypto = (opts.crypto ?? defaultCryptoProviderFactory)()
this.storage = opts.storage ?? new MemoryStorage() this.storage = opts.storage ?? new MemoryStorage()
this._apiHash = opts.apiHash this._apiHash = opts.apiHash
this._useIpv6 = Boolean(opts.useIpv6) this._useIpv6 = Boolean(opts.useIpv6)
this._testMode = Boolean(opts.testMode) this._testMode = Boolean(opts.testMode)
let dc = opts.primaryDc let dc = opts.defaultDc
if (!dc) { if (!dc) {
if (this._testMode) { if (this._testMode) {
@ -309,42 +276,47 @@ export class BaseTelegramClient extends EventEmitter {
} }
} }
this._primaryDc = dc this._defaultDc = dc
this._reconnectionStrategy =
opts.reconnectionStrategy ?? defaultReconnectionStrategy
this._floodSleepThreshold = opts.floodSleepThreshold ?? 10000
this._rpcRetryCount = opts.rpcRetryCount ?? 5
this._disableUpdates = opts.disableUpdates ?? false
this._niceStacks = opts.niceStacks ?? true this._niceStacks = opts.niceStacks ?? true
this._layer = opts.overrideLayer ?? tl.LAYER this._layer = opts.overrideLayer ?? tl.LAYER
this._readerMap = opts.readerMap ?? defaultReaderMap this._readerMap = opts.readerMap ?? defaultReaderMap
this._writerMap = opts.writerMap ?? defaultWriterMap this._writerMap = opts.writerMap ?? defaultWriterMap
this.network = new NetworkManager(
{
apiId,
crypto: this._crypto,
disableUpdates: opts.disableUpdates ?? false,
initConnectionOptions: opts.initConnectionOptions,
layer: this._layer,
log: this.log,
readerMap: this._readerMap,
writerMap: this._writerMap,
reconnectionStrategy: opts.reconnectionStrategy,
storage: this.storage,
testMode: this._testMode,
transport: opts.transport,
_emitError: this._emitError.bind(this),
floodSleepThreshold: opts.floodSleepThreshold ?? 10000,
maxRetryCount: opts.maxRetryCount ?? 5,
isPremium: false,
useIpv6: Boolean(opts.useIpv6),
keepAliveAction: this._keepAliveAction.bind(this),
...(opts.network ?? {}),
},
this._config,
)
this.storage.setup?.(this.log, this._readerMap, this._writerMap) this.storage.setup?.(this.log, this._readerMap, this._writerMap)
}
let deviceModel = 'mtcute on ' protected _keepAliveAction(): void {
if (typeof process !== 'undefined' && typeof require !== 'undefined') { // core does not have update handling, so we just use getState so the server knows
// eslint-disable-next-line @typescript-eslint/no-var-requires // we still do need updates
const os = require('os') this.call({ _: 'updates.getState' }).catch((e) => {
deviceModel += `${os.type()} ${os.arch()} ${os.release()}` this.log.error('failed to send keep-alive: %s', e)
} else if (typeof navigator !== 'undefined') { })
deviceModel += navigator.userAgent
} else deviceModel += 'unknown'
this._initConnectionParams = {
_: 'initConnection',
deviceModel,
systemVersion: '1.0',
appVersion: '1.0.0',
systemLangCode: 'en',
langPack: '', // "langPacks are for official apps only"
langCode: 'en',
...(opts.initConnectionOptions ?? {}),
apiId,
query: null as any,
}
} }
protected async _loadStorage(): Promise<void> { protected async _loadStorage(): Promise<void> {
@ -356,72 +328,6 @@ export class BaseTelegramClient extends EventEmitter {
await this.storage.save?.() await this.storage.save?.()
} }
protected _keepAliveAction(): void {
if (this._disableUpdates) return
// telegram asks to fetch pending updates
// if there are no updates for 15 minutes.
// core does not have update handling,
// so we just use getState so the server knows
// we still do need updates
this.call({ _: 'updates.getState' }).catch((e) => {
if (!(e instanceof tl.errors.RpcError)) {
this.primaryConnection.reconnect()
}
})
}
private _cleanupPrimaryConnection(forever = false): void {
if (forever && this.primaryConnection) this.primaryConnection.destroy()
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
}
private _setupPrimaryConnection(): void {
this._cleanupPrimaryConnection(true)
this.primaryConnection = new SessionConnection(
{
crypto: this._crypto,
initConnection: this._initConnectionParams,
transportFactory: this._transportFactory,
dc: this._primaryDc,
testMode: this._testMode,
reconnectionStrategy: this._reconnectionStrategy,
layer: this._layer,
disableUpdates: this._disableUpdates,
readerMap: this._readerMap,
writerMap: this._writerMap,
},
this.log.create('connection'),
)
this.primaryConnection.on('usable', () => {
this._lastUpdateTime = Date.now()
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
this._keepAliveInterval = setInterval(async () => {
if (Date.now() - this._lastUpdateTime > 900_000) {
this._keepAliveAction()
this._lastUpdateTime = Date.now()
}
}, 60_000)
})
this.primaryConnection.on('update', (update) => {
this._lastUpdateTime = Date.now()
this._handleUpdate(update)
})
this.primaryConnection.on('wait', () =>
this._cleanupPrimaryConnection(),
)
this.primaryConnection.on('key-change', async (key) => {
this.storage.setAuthKeyFor(this._primaryDc.id, key)
await this._saveStorage()
})
this.primaryConnection.on('error', (err) =>
this._emitError(err, this.primaryConnection),
)
}
/** /**
* Initialize the connection to the primary DC. * Initialize the connection to the primary DC.
* *
@ -430,61 +336,59 @@ export class BaseTelegramClient extends EventEmitter {
*/ */
async connect(): Promise<void> { async connect(): Promise<void> {
if (this._connected) { if (this._connected) {
// avoid double-connect
await this._connected await this._connected
return return
} }
this._connected = createControllablePromise() // we cant do this in constructor because we need to support subclassing
this.network.setUpdateHandler(this._handleUpdate.bind(this))
const promise = (this._connected = createControllablePromise())
await this._loadStorage() await this._loadStorage()
const primaryDc = await this.storage.getDefaultDc() const primaryDc = await this.storage.getDefaultDc()
if (primaryDc !== null) this._primaryDc = primaryDc if (primaryDc !== null) this._defaultDc = primaryDc
this._setupPrimaryConnection() const defaultDcAuthKey = await this.storage.getAuthKeyFor(
this._defaultDc.id,
await this.primaryConnection.setupKeys(
await this.storage.getAuthKeyFor(this._primaryDc.id),
) )
if ( if ((this._importForce || !defaultDcAuthKey) && this._importFrom) {
(this._importForce || !this.primaryConnection.getAuthKey()) &&
this._importFrom
) {
const data = readStringSession(this._readerMap, this._importFrom) const data = readStringSession(this._readerMap, this._importFrom)
if (data.testMode !== !this._testMode) { if (data.testMode !== this._testMode) {
throw new Error( throw new Error(
'This session string is not for the current backend', 'This session string is not for the current backend. ' +
`Session is ${
data.testMode ? 'test' : 'prod'
}, but the client is ${
this._testMode ? 'test' : 'prod'
}`,
) )
} }
this._primaryDc = this.primaryConnection.params.dc = data.primaryDc this._defaultDc = data.primaryDc
await this.storage.setDefaultDc(data.primaryDc) await this.storage.setDefaultDc(data.primaryDc)
if (data.self) { if (data.self) {
await this.storage.setSelf(data.self) await this.storage.setSelf(data.self)
} }
await this.primaryConnection.setupKeys(data.authKey) // await this.primaryConnection.setupKeys(data.authKey)
await this.storage.setAuthKeyFor(data.primaryDc.id, data.authKey) await this.storage.setAuthKeyFor(data.primaryDc.id, data.authKey)
await this._saveStorage(true) await this._saveStorage(true)
} }
this._connected.resolve() this.network
this._connected = true .connect(this._defaultDc)
.then(() => {
this.primaryConnection.connect() promise.resolve()
} this._connected = true
})
/** .catch((err) => this._emitError(err))
* Wait until this client is usable (i.e. connection is fully ready)
*/
async waitUntilUsable(): Promise<void> {
return new Promise((resolve) => {
this.primaryConnection.once('usable', resolve)
})
} }
/** /**
@ -499,107 +403,13 @@ export class BaseTelegramClient extends EventEmitter {
async close(): Promise<void> { async close(): Promise<void> {
await this._onClose() await this._onClose()
this._cleanupPrimaryConnection(true) this._config.destroy()
// close additional connections this.network.destroy()
this._additionalConnections.forEach((conn) => conn.destroy())
await this._saveStorage() await this._saveStorage()
await this.storage.destroy?.() await this.storage.destroy?.()
} }
/**
* Utility function to find the DC by its ID.
*
* @param id Datacenter ID
* @param preferMedia Whether to prefer media-only DCs
* @param cdn Whether the needed DC is a CDN DC
*/
async getDcById(
id: number,
preferMedia = false,
cdn = false,
): Promise<tl.RawDcOption> {
if (!this._config) {
this._config = await this.call({ _: 'help.getConfig' })
}
if (cdn && !this._cdnConfig) {
this._cdnConfig = await this.call({ _: 'help.getCdnConfig' })
for (const key of this._cdnConfig.publicKeys) {
await addPublicKey(this._crypto, key.publicKey)
}
}
if (this._useIpv6) {
// first try to find ipv6 dc
let found
if (preferMedia) {
found = this._config.dcOptions.find(
(it) =>
it.id === id &&
it.mediaOnly &&
it.cdn === cdn &&
it.ipv6 &&
!it.tcpoOnly,
)
}
if (!found) {
found = this._config.dcOptions.find(
(it) =>
it.id === id &&
it.cdn === cdn &&
it.ipv6 &&
!it.tcpoOnly,
)
}
if (found) return found
}
let found
if (preferMedia) {
found = this._config.dcOptions.find(
(it) =>
it.id === id &&
it.mediaOnly &&
it.cdn === cdn &&
!it.tcpoOnly &&
!it.ipv6,
)
}
if (!found) {
found = this._config.dcOptions.find(
(it) =>
it.id === id && it.cdn === cdn && !it.tcpoOnly && !it.ipv6,
)
}
if (found) return found
throw new Error(`Could not find${cdn ? ' CDN' : ''} DC ${id}`)
}
/**
* Change primary DC and write that fact to the storage.
* Will immediately reconnect to another DC.
*
* @param newDc New DC or its ID
*/
async changeDc(newDc: tl.RawDcOption | number): Promise<void> {
if (typeof newDc === 'number') {
newDc = await this.getDcById(newDc)
}
this._primaryDc = newDc
await this.storage.setDefaultDc(newDc)
await this._saveStorage()
await this.primaryConnection.changeDc(newDc)
}
/** /**
* Make an RPC call to the primary DC. * Make an RPC call to the primary DC.
* This method handles DC migration, flood waits and retries automatically. * This method handles DC migration, flood waits and retries automatically.
@ -615,227 +425,18 @@ export class BaseTelegramClient extends EventEmitter {
*/ */
async call<T extends tl.RpcMethod>( async call<T extends tl.RpcMethod>(
message: MustEqual<T, tl.RpcMethod>, message: MustEqual<T, tl.RpcMethod>,
params?: { params?: RpcCallOptions,
throwFlood?: boolean
connection?: SessionConnection
timeout?: number
},
): Promise<tl.RpcCallReturn[T['_']]> { ): Promise<tl.RpcCallReturn[T['_']]> {
if (this._connected !== true) { if (this._connected !== true) {
await this.connect() await this.connect()
} }
// do not send requests that are in flood wait
if (message._ in this._floodWaitedRequests) {
const delta = this._floodWaitedRequests[message._] - Date.now()
if (delta <= 3000) {
// flood waits below 3 seconds are "ignored"
delete this._floodWaitedRequests[message._]
} else if (delta <= this._floodSleepThreshold) {
await sleep(delta)
delete this._floodWaitedRequests[message._]
} else {
throw new tl.errors.FloodWaitXError(delta / 1000)
}
}
const connection = params?.connection ?? this.primaryConnection
let lastError: Error | null = null
const stack = this._niceStacks ? new Error().stack : undefined const stack = this._niceStacks ? new Error().stack : undefined
for (let i = 0; i < this._rpcRetryCount; i++) { const res = await this.network.call(message, params, stack)
try { await this._cachePeersFrom(res)
const res = await connection.sendRpc(
message,
stack,
params?.timeout,
)
await this._cachePeersFrom(res)
return res return res
} catch (e: any) {
lastError = e
if (e instanceof tl.errors.InternalError) {
this.log.warn('Telegram is having internal issues: %s', e)
if (e.message === 'WORKER_BUSY_TOO_LONG_RETRY') {
// according to tdlib, "it is dangerous to resend query without timeout, so use 1"
await sleep(1000)
}
continue
}
if (
e.constructor === tl.errors.FloodWaitXError ||
e.constructor === tl.errors.SlowmodeWaitXError ||
e.constructor === tl.errors.FloodTestPhoneWaitXError
) {
if (e.constructor !== tl.errors.SlowmodeWaitXError) {
// SLOW_MODE_WAIT is chat-specific, not request-specific
this._floodWaitedRequests[message._] =
Date.now() + e.seconds * 1000
}
// In test servers, FLOOD_WAIT_0 has been observed, and sleeping for
// such a short amount will cause retries very fast leading to issues
if (e.seconds === 0) {
(e as any).seconds = 1
}
if (
params?.throwFlood !== true &&
e.seconds <= this._floodSleepThreshold
) {
this.log.info('Flood wait for %d seconds', e.seconds)
await sleep(e.seconds * 1000)
continue
}
}
if (connection.params.dc.id === this._primaryDc.id) {
if (
e.constructor === tl.errors.PhoneMigrateXError ||
e.constructor === tl.errors.UserMigrateXError ||
e.constructor === tl.errors.NetworkMigrateXError
) {
this.log.info('Migrate error, new dc = %d', e.new_dc)
await this.changeDc(e.new_dc)
continue
}
} else if (
e.constructor === tl.errors.AuthKeyUnregisteredError
) {
// we can try re-exporting auth from the primary connection
this.log.warn('exported auth key error, re-exporting..')
const auth = await this.call({
_: 'auth.exportAuthorization',
dcId: connection.params.dc.id,
})
await connection.sendRpc({
_: 'auth.importAuthorization',
id: auth.id,
bytes: auth.bytes,
})
continue
}
throw e
}
}
throw lastError
}
/**
* Creates an additional connection to a given DC.
* This will use auth key for that DC that was already stored
* in the session, or generate a new auth key by exporting
* authorization from primary DC and importing it to the new DC.
* New connection will use the same crypto provider, `initConnection`,
* transport and reconnection strategy as the primary connection
*
* This method is quite low-level and you shouldn't usually care about this
* when using high-level API provided by `@mtcute/client`.
*
* @param dcId DC id, to which the connection will be created
* @param cdn Whether that DC is a CDN DC
* @param inactivityTimeout
* Inactivity timeout for the connection (in ms), after which the transport will be closed.
* Note that connection can still be used normally, it's just the transport which is closed.
* Defaults to 5 min
*/
async createAdditionalConnection(
dcId: number,
params?: {
// todo proper docs
// default = false
media?: boolean
// default = fa;se
cdn?: boolean
// default = 300_000
inactivityTimeout?: number
// default = false
disableUpdates?: boolean
},
): Promise<SessionConnection> {
const dc = await this.getDcById(dcId, params?.media, params?.cdn)
const connection = new SessionConnection(
{
dc,
testMode: this._testMode,
crypto: this._crypto,
initConnection: this._initConnectionParams,
transportFactory: this._transportFactory,
reconnectionStrategy: this._reconnectionStrategy,
inactivityTimeout: params?.inactivityTimeout ?? 300_000,
layer: this._layer,
disableUpdates: params?.disableUpdates,
readerMap: this._readerMap,
writerMap: this._writerMap,
},
this.log.create('connection'),
)
connection.on('error', (err) => this._emitError(err, connection))
await connection.setupKeys(await this.storage.getAuthKeyFor(dc.id))
connection.connect()
if (!connection.getAuthKey()) {
this.log.info('exporting auth to DC %d', dcId)
const auth = await this.call({
_: 'auth.exportAuthorization',
dcId,
})
await connection.sendRpc({
_: 'auth.importAuthorization',
id: auth.id,
bytes: auth.bytes,
})
// connection.authKey was already generated at this point
this.storage.setAuthKeyFor(dc.id, connection.getAuthKey()!)
await this._saveStorage()
} else {
// in case the auth key is invalid
const dcId = dc.id
connection.on('key-change', async (key) => {
// we don't need to export, it will be done by `.call()`
// in case this error is returned
//
// even worse, exporting here will lead to a race condition,
// and may result in redundant re-exports.
this.storage.setAuthKeyFor(dcId, key)
await this._saveStorage()
})
}
this._additionalConnections.push(connection)
return connection
}
/**
* Destroy a connection that was previously created using
* {@link BaseTelegramClient.createAdditionalConnection}.
* Passing any other connection will not have any effect.
*
* @param connection Connection created with {@link BaseTelegramClient.createAdditionalConnection}
*/
async destroyAdditionalConnection(
connection: SessionConnection,
): Promise<void> {
const idx = this._additionalConnections.indexOf(connection)
if (idx === -1) return
await connection.destroy()
this._additionalConnections.splice(idx, 1)
} }
/** /**
@ -849,11 +450,7 @@ export class BaseTelegramClient extends EventEmitter {
* @param factory New transport factory * @param factory New transport factory
*/ */
changeTransport(factory: TransportFactory): void { changeTransport(factory: TransportFactory): void {
this.primaryConnection.changeTransport(factory) this.network.changeTransport(factory)
this._additionalConnections.forEach((conn) =>
conn.changeTransport(factory),
)
} }
/** /**
@ -865,7 +462,9 @@ export class BaseTelegramClient extends EventEmitter {
* the connection in which the error has occurred, in case * the connection in which the error has occurred, in case
* this was connection-related error. * this was connection-related error.
*/ */
onError(handler: typeof this._onError): void { onError(
handler: (err: unknown, connection?: SessionConnection) => void,
): void {
this._onError = handler this._onError = handler
} }
@ -950,9 +549,8 @@ export class BaseTelegramClient extends EventEmitter {
} }
} }
await this.storage.updatePeers(parsedPeers)
if (count > 0) { if (count > 0) {
await this.storage.updatePeers(parsedPeers)
this.log.debug('cached %d peers', count) this.log.debug('cached %d peers', count)
} }
@ -975,16 +573,18 @@ export class BaseTelegramClient extends EventEmitter {
* > with [@BotFather](//t.me/botfather) * > with [@BotFather](//t.me/botfather)
*/ */
async exportSession(): Promise<string> { async exportSession(): Promise<string> {
if (!this.primaryConnection.getAuthKey()) { const primaryDc = await this.storage.getDefaultDc()
throw new Error('Auth key is not generated yet') if (!primaryDc) throw new Error('No default DC set')
}
const authKey = await this.storage.getAuthKeyFor(primaryDc.id)
if (!authKey) throw new Error('Auth key is not ready yet')
return writeStringSession(this._writerMap, { return writeStringSession(this._writerMap, {
version: 1, version: 1,
self: await this.storage.getSelf(), self: await this.storage.getSelf(),
testMode: this._testMode, testMode: this._testMode,
primaryDc: this._primaryDc, primaryDc,
authKey: this.primaryConnection.getAuthKey()!, authKey,
}) })
} }
@ -996,7 +596,7 @@ export class BaseTelegramClient extends EventEmitter {
* *
* Also note that the session will only be imported in case * Also note that the session will only be imported in case
* the storage is missing authorization (i.e. does not contain * the storage is missing authorization (i.e. does not contain
* auth key for the primary DC), otherwise it will be ignored. * auth key for the primary DC), otherwise it will be ignored (unless `force).
* *
* @param session Session string to import * @param session Session string to import
* @param force Whether to overwrite existing session * @param force Whether to overwrite existing session

View file

@ -0,0 +1,166 @@
import Long from 'long'
import { tl } from '@mtcute/tl'
import { TlBinaryReader, TlReaderMap } from '@mtcute/tl-runtime'
import { buffersEqual, ICryptoProvider, Logger, randomBytes } from '../utils'
import { createAesIgeForMessage } from '../utils/crypto/mtproto'
export class AuthKey {
ready = false
key!: Buffer
id!: Buffer
clientSalt!: Buffer
serverSalt!: Buffer
constructor(
readonly _crypto: ICryptoProvider,
readonly log: Logger,
readonly _readerMap: TlReaderMap,
) {}
match(keyId: Buffer): boolean {
return this.ready && buffersEqual(keyId, this.id)
}
async setup(authKey?: Buffer | null): Promise<void> {
if (!authKey) return this.reset()
this.ready = true
this.key = authKey
this.clientSalt = authKey.slice(88, 120)
this.serverSalt = authKey.slice(96, 128)
this.id = (await this._crypto.sha1(authKey)).slice(-8)
this.log.verbose('auth key set up, id = %h', this.id)
}
async encryptMessage(
message: Buffer,
serverSalt: Long,
sessionId: Long,
): Promise<Buffer> {
if (!this.ready) throw new Error('Keys are not set up!')
let padding =
(16 /* header size */ + message.length + 12) /* min padding */ % 16
padding = 12 + (padding ? 16 - padding : 0)
const buf = Buffer.alloc(16 + message.length + padding)
buf.writeInt32LE(serverSalt.low)
buf.writeInt32LE(serverSalt.high, 4)
buf.writeInt32LE(sessionId.low, 8)
buf.writeInt32LE(sessionId.high, 12)
message.copy(buf, 16)
randomBytes(padding).copy(buf, 16 + message.length)
const messageKey = (
await this._crypto.sha256(Buffer.concat([this.clientSalt, buf]))
).slice(8, 24)
const ige = await createAesIgeForMessage(
this._crypto,
this.key,
messageKey,
true,
)
const encryptedData = await ige.encrypt(buf)
return Buffer.concat([this.id, messageKey, encryptedData])
}
async decryptMessage(
data: Buffer,
sessionId: Long,
callback: (msgId: tl.Long, seqNo: number, data: TlBinaryReader) => void,
): Promise<void> {
const messageKey = data.slice(8, 24)
const encryptedData = data.slice(24)
const ige = await createAesIgeForMessage(
this._crypto,
this.key,
messageKey,
false,
)
const innerData = await ige.decrypt(encryptedData)
const expectedMessageKey = (
await this._crypto.sha256(
Buffer.concat([this.serverSalt, innerData]),
)
).slice(8, 24)
if (!buffersEqual(messageKey, expectedMessageKey)) {
this.log.warn(
'[%h] received message with invalid messageKey = %h (expected %h)',
messageKey,
expectedMessageKey,
)
return
}
const innerReader = new TlBinaryReader(this._readerMap, innerData)
innerReader.seek(8) // skip salt
const sessionId_ = innerReader.long()
const messageId = innerReader.long(true)
if (sessionId_.neq(sessionId)) {
this.log.warn(
'ignoring message with invalid sessionId = %h',
sessionId_,
)
return
}
const seqNo = innerReader.uint()
const length = innerReader.uint()
if (length > innerData.length - 32 /* header size */) {
this.log.warn(
'ignoring message with invalid length: %d > %d',
length,
innerData.length - 32,
)
return
}
if (length % 4 !== 0) {
this.log.warn(
'ignoring message with invalid length: %d is not a multiple of 4',
length,
)
return
}
const paddingSize = innerData.length - length - 32 // header size
if (paddingSize < 12 || paddingSize > 1024) {
this.log.warn(
'ignoring message with invalid padding size: %d',
paddingSize,
)
return
}
callback(messageId, seqNo, innerReader)
}
copyFrom(authKey: AuthKey): void {
this.ready = authKey.ready
this.key = authKey.key
this.id = authKey.id
this.serverSalt = authKey.serverSalt
this.clientSalt = authKey.clientSalt
}
reset(): void {
this.ready = false
}
}

View file

@ -9,7 +9,12 @@ import {
TlSerializationCounter, TlSerializationCounter,
} from '@mtcute/tl-runtime' } from '@mtcute/tl-runtime'
import { bigIntToBuffer, bufferToBigInt, ICryptoProvider } from '../utils' import {
bigIntToBuffer,
bufferToBigInt,
ICryptoProvider,
Logger,
} from '../utils'
import { import {
buffersEqual, buffersEqual,
randomBytes, randomBytes,
@ -17,12 +22,123 @@ import {
xorBufferInPlace, xorBufferInPlace,
} from '../utils/buffer-utils' } from '../utils/buffer-utils'
import { findKeyByFingerprints } from '../utils/crypto/keys' import { findKeyByFingerprints } from '../utils/crypto/keys'
import { millerRabin } from '../utils/crypto/miller-rabin'
import { generateKeyAndIvFromNonce } from '../utils/crypto/mtproto' import { generateKeyAndIvFromNonce } from '../utils/crypto/mtproto'
import { SessionConnection } from './session-connection' import { SessionConnection } from './session-connection'
// Heavily based on code from https://github.com/LonamiWebs/Telethon/blob/master/telethon/network/authenticator.py // Heavily based on code from https://github.com/LonamiWebs/Telethon/blob/master/telethon/network/authenticator.py
// see https://core.telegram.org/mtproto/security_guidelines
const DH_SAFETY_RANGE = bigInt[2].pow(2048 - 64) const DH_SAFETY_RANGE = bigInt[2].pow(2048 - 64)
const KNOWN_DH_PRIME = bigInt(
'C71CAEB9C6B1C9048E6C522F70F13F73980D40238E3E21C14934D037563D930F48198A0AA7C14058229493D22530F4DBFA336F6E0AC925139543AED44CCE7C3720FD51F69458705AC68CD4FE6B6B13ABDC9746512969328454F18FAF8C595F642477FE96BB2A941D5BCD1D4AC8CC49880708FA9B378E3C4F3A9060BEE67CF9A4A4A695811051907E162753B56B0F6B410DBA74D8A84B2A14B3144E0EF1284754FD17ED950D5965B4B9DD46582DB1178D169C6BC465B0D6FF9CA3928FEF5B9AE4E418FC15E83EBEA0F87FA9FF5EED70050DED2849F47BF959D956850CE929851F0D8115F635B105EE2E4E15D04B2454BF6F4FADF034B10403119CD8E3B92FCC5B',
16,
)
const TWO_POW_2047 = bigInt[2].pow(2047)
const TWO_POW_2048 = bigInt[2].pow(2048)
interface CheckedPrime {
prime: bigInt.BigInteger
generators: number[]
}
const checkedPrimesCache: CheckedPrime[] = []
function checkDhPrime(log: Logger, dhPrime: bigInt.BigInteger, g: number) {
if (KNOWN_DH_PRIME.eq(dhPrime)) {
log.debug('server is using known dh prime, skipping validation')
return
}
let checkedPrime = checkedPrimesCache.find((x) => x.prime.eq(dhPrime))
if (!checkedPrime) {
if (
dhPrime.lesserOrEquals(TWO_POW_2047) ||
dhPrime.greaterOrEquals(TWO_POW_2048)
) {
throw new Error('Step 3: dh_prime is not in the 2048-bit range')
}
if (!millerRabin(dhPrime)) {
throw new Error('Step 3: dh_prime is not prime')
}
if (!millerRabin(dhPrime.minus(1).divide(2))) {
throw new Error(
'Step 3: dh_prime is not a safe prime - (dh_prime-1)/2 is not prime',
)
}
log.debug('dh_prime is probably prime')
checkedPrime = {
prime: dhPrime,
generators: [],
}
checkedPrimesCache.push(checkedPrime)
} else {
log.debug('dh_prime is probably prime (cached)')
}
const generatorChecked = checkedPrime.generators.includes(g)
if (generatorChecked) {
log.debug('g = %d is already checked for dh_prime', g)
return
}
switch (g) {
case 2:
if (dhPrime.mod(8).notEquals(7)) {
throw new Error('Step 3: ivalid g - dh_prime mod 8 != 7')
}
break
case 3:
if (dhPrime.mod(3).notEquals(2)) {
throw new Error('Step 3: ivalid g - dh_prime mod 3 != 2')
}
break
case 4:
break
case 5: {
const mod = dhPrime.mod(5)
if (mod.notEquals(1) && mod.notEquals(4)) {
throw new Error(
'Step 3: ivalid g - dh_prime mod 5 != 1 && dh_prime mod 5 != 4',
)
}
break
}
case 6: {
const mod = dhPrime.mod(24)
if (mod.notEquals(19) && mod.notEquals(23)) {
throw new Error(
'Step 3: ivalid g - dh_prime mod 24 != 19 && dh_prime mod 24 != 23',
)
}
break
}
case 7: {
const mod = dhPrime.mod(7)
if (mod.notEquals(3) && mod.notEquals(5) && mod.notEquals(6)) {
throw new Error(
'Step 3: ivalid g - dh_prime mod 7 != 3 && dh_prime mod 7 != 5 && dh_prime mod 7 != 6',
)
}
break
}
default:
throw new Error(`Step 3: ivalid g - unknown g = ${g}`)
}
checkedPrime.generators.push(g)
log.debug('g = %d is safe to use with dh_prime', g)
}
async function rsaPad( async function rsaPad(
data: Buffer, data: Buffer,
@ -102,6 +218,7 @@ async function rsaEncrypt(
export async function doAuthorization( export async function doAuthorization(
connection: SessionConnection, connection: SessionConnection,
crypto: ICryptoProvider, crypto: ICryptoProvider,
expiresIn?: number,
): Promise<[Buffer, Long, number]> { ): Promise<[Buffer, Long, number]> {
// eslint-disable-next-line dot-notation // eslint-disable-next-line dot-notation
const session = connection['_session'] const session = connection['_session']
@ -128,23 +245,26 @@ export async function doAuthorization(
async function readNext(): Promise<mtp.TlObject> { async function readNext(): Promise<mtp.TlObject> {
return TlBinaryReader.deserializeObject( return TlBinaryReader.deserializeObject(
readerMap, readerMap,
await connection.waitForNextMessage(), await connection.waitForUnencryptedMessage(),
20, // skip mtproto header 20, // skip mtproto header
) )
} }
const log = connection.log.create('auth') const log = connection.log.create('auth')
if (expiresIn) log.prefix = '[PFS] '
const nonce = randomBytes(16) const nonce = randomBytes(16)
// Step 1: PQ request // Step 1: PQ request
log.debug('starting PQ handshake, nonce = %h', nonce) log.debug('starting PQ handshake (temp = %b), nonce = %h', expiresIn, nonce)
await sendPlainMessage({ _: 'mt_req_pq_multi', nonce }) await sendPlainMessage({ _: 'mt_req_pq_multi', nonce })
const resPq = await readNext() const resPq = await readNext()
if (resPq._ !== 'mt_resPQ') throw new Error('Step 1: answer was ' + resPq._) if (resPq._ !== 'mt_resPQ') throw new Error('Step 1: answer was ' + resPq._)
if (!buffersEqual(resPq.nonce, nonce)) { throw new Error('Step 1: invalid nonce from server') } if (!buffersEqual(resPq.nonce, nonce)) {
throw new Error('Step 1: invalid nonce from server')
}
const serverKeys = resPq.serverPublicKeyFingerprints.map((it) => const serverKeys = resPq.serverPublicKeyFingerprints.map((it) =>
it.toUnsigned().toString(16), it.toUnsigned().toString(16),
@ -175,8 +295,8 @@ export async function doAuthorization(
if (connection.params.testMode) dcId += 10000 if (connection.params.testMode) dcId += 10000
if (connection.params.dc.mediaOnly) dcId = -dcId if (connection.params.dc.mediaOnly) dcId = -dcId
const _pqInnerData: mtp.RawMt_p_q_inner_data_dc = { const _pqInnerData: mtp.TypeP_Q_inner_data = {
_: 'mt_p_q_inner_data_dc', _: expiresIn ? 'mt_p_q_inner_data_temp_dc' : 'mt_p_q_inner_data_dc',
pq: resPq.pq, pq: resPq.pq,
p, p,
q, q,
@ -184,6 +304,7 @@ export async function doAuthorization(
newNonce, newNonce,
serverNonce: resPq.serverNonce, serverNonce: resPq.serverNonce,
dc: dcId, dc: dcId,
expiresIn: expiresIn!, // whatever
} }
const pqInnerData = TlBinaryWriter.serializeObject(writerMap, _pqInnerData) const pqInnerData = TlBinaryWriter.serializeObject(writerMap, _pqInnerData)
@ -204,12 +325,20 @@ export async function doAuthorization(
}) })
const serverDhParams = await readNext() const serverDhParams = await readNext()
if (!mtp.isAnyServer_DH_Params(serverDhParams)) { throw new Error('Step 2.1: answer was ' + serverDhParams._) } if (!mtp.isAnyServer_DH_Params(serverDhParams)) {
throw new Error('Step 2.1: answer was ' + serverDhParams._)
}
if (serverDhParams._ !== 'mt_server_DH_params_ok') { throw new Error('Step 2.1: answer was ' + serverDhParams._) } if (serverDhParams._ !== 'mt_server_DH_params_ok') {
throw new Error('Step 2.1: answer was ' + serverDhParams._)
}
if (!buffersEqual(serverDhParams.nonce, nonce)) { throw Error('Step 2: invalid nonce from server') } if (!buffersEqual(serverDhParams.nonce, nonce)) {
if (!buffersEqual(serverDhParams.serverNonce, resPq.serverNonce)) { throw Error('Step 2: invalid server nonce from server') } throw Error('Step 2: invalid nonce from server')
}
if (!buffersEqual(serverDhParams.serverNonce, resPq.serverNonce)) {
throw Error('Step 2: invalid server nonce from server')
}
// type was removed from schema in July 2021 // type was removed from schema in July 2021
// if (serverDhParams._ === 'mt_server_DH_params_fail') { // if (serverDhParams._ === 'mt_server_DH_params_fail') {
@ -222,7 +351,9 @@ export async function doAuthorization(
log.debug('server DH ok') log.debug('server DH ok')
if (serverDhParams.encryptedAnswer.length % 16 !== 0) { throw new Error('Step 2: AES block size is invalid') } if (serverDhParams.encryptedAnswer.length % 16 !== 0) {
throw new Error('Step 2: AES block size is invalid')
}
// Step 3: complete DH exchange // Step 3: complete DH exchange
const [key, iv] = await generateKeyAndIvFromNonce( const [key, iv] = await generateKeyAndIvFromNonce(
@ -248,20 +379,28 @@ export async function doAuthorization(
plainTextAnswer.slice(20, serverDhInnerReader.pos), plainTextAnswer.slice(20, serverDhInnerReader.pos),
), ),
) )
) { throw new Error('Step 3: invalid inner data hash') } ) {
throw new Error('Step 3: invalid inner data hash')
}
if (serverDhInner._ !== 'mt_server_DH_inner_data') { throw Error('Step 3: inner data was ' + serverDhInner._) } if (serverDhInner._ !== 'mt_server_DH_inner_data') {
if (!buffersEqual(serverDhInner.nonce, nonce)) { throw Error('Step 3: invalid nonce from server') } throw Error('Step 3: inner data was ' + serverDhInner._)
if (!buffersEqual(serverDhInner.serverNonce, resPq.serverNonce)) { throw Error('Step 3: invalid server nonce from server') } }
if (!buffersEqual(serverDhInner.nonce, nonce)) {
throw Error('Step 3: invalid nonce from server')
}
if (!buffersEqual(serverDhInner.serverNonce, resPq.serverNonce)) {
throw Error('Step 3: invalid server nonce from server')
}
const dhPrime = bufferToBigInt(serverDhInner.dhPrime) const dhPrime = bufferToBigInt(serverDhInner.dhPrime)
const timeOffset = Math.floor(Date.now() / 1000) - serverDhInner.serverTime const timeOffset = Math.floor(Date.now() / 1000) - serverDhInner.serverTime
// dhPrime is not checked because who cares lol :D
const g = bigInt(serverDhInner.g) const g = bigInt(serverDhInner.g)
const gA = bufferToBigInt(serverDhInner.gA) const gA = bufferToBigInt(serverDhInner.gA)
checkDhPrime(log, dhPrime, serverDhInner.g)
let retryId = Long.ZERO let retryId = Long.ZERO
const serverSalt = xorBuffer( const serverSalt = xorBuffer(
newNonce.slice(0, 8), newNonce.slice(0, 8),
@ -276,15 +415,24 @@ export async function doAuthorization(
const authKeyAuxHash = (await crypto.sha1(authKey)).slice(0, 8) const authKeyAuxHash = (await crypto.sha1(authKey)).slice(0, 8)
// validate DH params // validate DH params
if (g.lesserOrEquals(1) || g.greaterOrEquals(dhPrime.minus(bigInt.one))) { throw new Error('g is not within (1, dh_prime - 1)') } if (
g.lesserOrEquals(1) ||
g.greaterOrEquals(dhPrime.minus(bigInt.one))
) {
throw new Error('g is not within (1, dh_prime - 1)')
}
if ( if (
gA.lesserOrEquals(1) || gA.lesserOrEquals(1) ||
gA.greaterOrEquals(dhPrime.minus(bigInt.one)) gA.greaterOrEquals(dhPrime.minus(bigInt.one))
) { throw new Error('g_a is not within (1, dh_prime - 1)') } ) {
throw new Error('g_a is not within (1, dh_prime - 1)')
}
if ( if (
gB.lesserOrEquals(1) || gB.lesserOrEquals(1) ||
gB.greaterOrEquals(dhPrime.minus(bigInt.one)) gB.greaterOrEquals(dhPrime.minus(bigInt.one))
) { throw new Error('g_b is not within (1, dh_prime - 1)') } ) {
throw new Error('g_b is not within (1, dh_prime - 1)')
}
if (gA.lt(DH_SAFETY_RANGE) || gA.gt(dhPrime.minus(DH_SAFETY_RANGE))) { if (gA.lt(DH_SAFETY_RANGE) || gA.gt(dhPrime.minus(DH_SAFETY_RANGE))) {
throw new Error( throw new Error(
@ -334,10 +482,16 @@ export async function doAuthorization(
const dhGen = await readNext() const dhGen = await readNext()
if (!mtp.isAnySet_client_DH_params_answer(dhGen)) { throw new Error('Step 4: answer was ' + dhGen._) } if (!mtp.isAnySet_client_DH_params_answer(dhGen)) {
throw new Error('Step 4: answer was ' + dhGen._)
}
if (!buffersEqual(dhGen.nonce, nonce)) { throw Error('Step 4: invalid nonce from server') } if (!buffersEqual(dhGen.nonce, nonce)) {
if (!buffersEqual(dhGen.serverNonce, resPq.serverNonce)) { throw Error('Step 4: invalid server nonce from server') } throw Error('Step 4: invalid nonce from server')
}
if (!buffersEqual(dhGen.serverNonce, resPq.serverNonce)) {
throw Error('Step 4: invalid server nonce from server')
}
log.debug('DH result: %s', dhGen._) log.debug('DH result: %s', dhGen._)
@ -351,7 +505,9 @@ export async function doAuthorization(
Buffer.concat([newNonce, Buffer.from([2]), authKeyAuxHash]), Buffer.concat([newNonce, Buffer.from([2]), authKeyAuxHash]),
) )
if (!buffersEqual(expectedHash.slice(4, 20), dhGen.newNonceHash2)) { throw Error('Step 4: invalid retry nonce hash from server') } if (!buffersEqual(expectedHash.slice(4, 20), dhGen.newNonceHash2)) {
throw Error('Step 4: invalid retry nonce hash from server')
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
retryId = Long.fromBytesLE(authKeyAuxHash as any) retryId = Long.fromBytesLE(authKeyAuxHash as any)
continue continue
@ -363,7 +519,9 @@ export async function doAuthorization(
Buffer.concat([newNonce, Buffer.from([1]), authKeyAuxHash]), Buffer.concat([newNonce, Buffer.from([1]), authKeyAuxHash]),
) )
if (!buffersEqual(expectedHash.slice(4, 20), dhGen.newNonceHash1)) { throw Error('Step 4: invalid nonce hash from server') } if (!buffersEqual(expectedHash.slice(4, 20), dhGen.newNonceHash1)) {
throw Error('Step 4: invalid nonce hash from server')
}
log.info('authorization successful') log.info('authorization successful')

View file

@ -0,0 +1,103 @@
import { tl } from '@mtcute/tl'
export class ConfigManager {
constructor(private _update: () => Promise<tl.RawConfig>) {}
private _destroyed = false
private _config?: tl.RawConfig
private _cdnConfig?: tl.RawCdnConfig
private _updateTimeout?: NodeJS.Timeout
private _updatingPromise?: Promise<void>
private _listeners: ((config: tl.RawConfig) => void)[] = []
get isStale(): boolean {
return !this._config || this._config.expires < Date.now() / 1000
}
update(force = false): Promise<void> {
if (!force && !this.isStale) return Promise.resolve()
if (this._updatingPromise) return this._updatingPromise
return (this._updatingPromise = this._update().then((config) => {
if (this._destroyed) return
this.setConfig(config)
}))
}
setConfig(config: tl.RawConfig): void {
this._config = config
if (this._updateTimeout) clearTimeout(this._updateTimeout)
this._updateTimeout = setTimeout(
() => this.update(),
(config.expires - Date.now() / 1000) * 1000,
)
for (const cb of this._listeners) cb(config)
}
onConfigUpdate(cb: (config: tl.RawConfig) => void): void {
this._listeners.push(cb)
}
offConfigUpdate(cb: (config: tl.RawConfig) => void): void {
const idx = this._listeners.indexOf(cb)
if (idx >= 0) this._listeners.splice(idx, 1)
}
getNow(): tl.RawConfig | undefined {
return this._config
}
async get(): Promise<tl.RawConfig> {
if (this.isStale) await this.update()
return this._config!
}
destroy(): void {
if (this._updateTimeout) clearTimeout(this._updateTimeout)
this._listeners.length = 0
this._destroyed = true
}
async findOption(params: {
dcId: number
allowIpv6?: boolean
preferIpv6?: boolean
allowMedia?: boolean
preferMedia?: boolean
cdn?: boolean
}): Promise<tl.RawDcOption | undefined> {
if (this.isStale) await this.update()
const options = this._config!.dcOptions.filter((opt) => {
if (opt.tcpoOnly) return false // unsupported
if (opt.ipv6 && !params.allowIpv6) return false
if (opt.mediaOnly && !params.allowMedia) return false
if (opt.cdn && !params.cdn) return false
return opt.id === params.dcId
})
if (params.preferMedia && params.preferIpv6) {
const r = options.find((opt) => opt.mediaOnly && opt.ipv6)
if (r) return r
}
if (params.preferMedia) {
const r = options.find((opt) => opt.mediaOnly)
if (r) return r
}
if (params.preferIpv6) {
const r = options.find((opt) => opt.ipv6)
if (r) return r
}
return options[0]
}
}

View file

@ -1,3 +1,8 @@
export {
ConnectionKind,
NetworkManagerExtraParams,
RpcCallOptions,
} from './network-manager'
export * from './reconnection' export * from './reconnection'
export * from './session-connection' export * from './session-connection'
export * from './transports' export * from './transports'

View file

@ -2,30 +2,96 @@ import Long from 'long'
import { mtp, tl } from '@mtcute/tl' import { mtp, tl } from '@mtcute/tl'
import { import {
TlBinaryReader,
TlBinaryWriter, TlBinaryWriter,
TlReaderMap, TlReaderMap,
TlSerializationCounter, TlSerializationCounter,
TlWriterMap, TlWriterMap,
} from '@mtcute/tl-runtime' } from '@mtcute/tl-runtime'
import { getRandomInt, ICryptoProvider, Logger, randomLong } from '../utils' import {
import { buffersEqual, randomBytes } from '../utils/buffer-utils' ControllablePromise,
import { createAesIgeForMessage } from '../utils/crypto/mtproto' Deque,
getRandomInt,
ICryptoProvider,
Logger,
LongMap,
LruSet,
randomLong,
SortedArray,
} from '../utils'
import { AuthKey } from './auth-key'
export interface PendingRpc {
method: string
data: Buffer
promise: ControllablePromise
stack?: string
gzipOverhead?: number
sent?: boolean
msgId?: Long
seqNo?: number
containerId?: Long
acked?: boolean
initConn?: boolean
getState?: number
cancelled?: boolean
timeout?: NodeJS.Timeout
}
export type PendingMessage =
| {
_: 'rpc'
rpc: PendingRpc
}
| {
_: 'container'
msgIds: Long[]
}
| {
_: 'state'
msgIds: Long[]
containerId: Long
}
| {
_: 'resend'
msgIds: Long[]
containerId: Long
}
| {
_: 'ping'
pingId: Long
containerId: Long
}
| {
_: 'destroy_session'
sessionId: Long
containerId: Long
}
| {
_: 'cancel'
msgId: Long
containerId: Long
}
| {
_: 'future_salts'
containerId: Long
}
| {
_: 'bind'
promise: ControllablePromise
}
/** /**
* Class encapsulating a single MTProto session. * Class encapsulating a single MTProto session and storing
* Provides means to en-/decrypt messages * all the relevant state
*/ */
export class MtprotoSession { export class MtprotoSession {
readonly _crypto: ICryptoProvider
_sessionId = randomLong() _sessionId = randomLong()
_authKey?: Buffer _authKey = new AuthKey(this._crypto, this.log, this._readerMap)
_authKeyId?: Buffer _authKeyTemp = new AuthKey(this._crypto, this.log, this._readerMap)
_authKeyClientSalt?: Buffer _authKeyTempSecondary = new AuthKey(this._crypto, this.log, this._readerMap)
_authKeyServerSalt?: Buffer
_timeOffset = 0 _timeOffset = 0
_lastMessageId = Long.ZERO _lastMessageId = Long.ZERO
@ -33,190 +99,129 @@ export class MtprotoSession {
serverSalt = Long.ZERO serverSalt = Long.ZERO
/// state ///
// recent msg ids
recentOutgoingMsgIds = new LruSet<Long>(1000, false, true)
recentIncomingMsgIds = new LruSet<Long>(1000, false, true)
// queues
queuedRpc = new Deque<PendingRpc>()
queuedAcks: Long[] = []
queuedStateReq: Long[] = []
queuedResendReq: Long[] = []
queuedCancelReq: Long[] = []
getStateSchedule = new SortedArray<PendingRpc>(
[],
(a, b) => a.getState! - b.getState!,
)
// requests info
pendingMessages = new LongMap<PendingMessage>()
destroySessionIdToMsgId = new LongMap<Long>()
lastPingRtt = NaN
lastPingTime = 0
lastPingMsgId = Long.ZERO
lastSessionCreatedUid = Long.ZERO
initConnectionCalled = false
authorizationPending = false
next429Timeout = 1000
current429Timeout?: NodeJS.Timeout
next429ResetTimeout?: NodeJS.Timeout
constructor( constructor(
crypto: ICryptoProvider, readonly _crypto: ICryptoProvider,
readonly log: Logger, readonly log: Logger,
readonly _readerMap: TlReaderMap, readonly _readerMap: TlReaderMap,
readonly _writerMap: TlWriterMap, readonly _writerMap: TlWriterMap,
) { ) {
this._crypto = crypto this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] `
} }
/** Whether session contains authKey */ get hasPendingMessages(): boolean {
get authorized(): boolean { return Boolean(
return this._authKey !== undefined this.queuedRpc.length ||
this.queuedAcks.length ||
this.queuedStateReq.length ||
this.queuedResendReq.length,
)
} }
/** Setup keys based on authKey */ /**
async setupKeys(authKey?: Buffer | null): Promise<void> { * Reset session by resetting auth key(s) and session state
if (authKey) { */
this._authKey = authKey reset(withAuthKey = false): void {
this._authKeyClientSalt = authKey.slice(88, 120) if (withAuthKey) {
this._authKeyServerSalt = authKey.slice(96, 128) this._authKey.reset()
this._authKeyId = (await this._crypto.sha1(this._authKey)).slice(-8) this._authKeyTemp.reset()
} else { this._authKeyTempSecondary.reset()
this._authKey = undefined
this._authKeyClientSalt = undefined
this._authKeyServerSalt = undefined
this._authKeyId = undefined
} }
clearTimeout(this.current429Timeout)
this.resetState()
this.resetLastPing(true)
} }
/** Reset session by removing authKey and values derived from it */ /**
reset(): void { * Reset session state and generate a new session ID.
*
* By default, also cancels any pending RPC requests.
* If `keepPending` is set to `true`, pending requests will be kept
*/
resetState(keepPending = false): void {
this._lastMessageId = Long.ZERO this._lastMessageId = Long.ZERO
this._seqNo = 0 this._seqNo = 0
this._authKey = undefined
this._authKeyClientSalt = undefined
this._authKeyServerSalt = undefined
this._authKeyId = undefined
this._sessionId = randomLong() this._sessionId = randomLong()
// no need to reset server salt this.log.debug('session reset, new sid = %h', this._sessionId)
this.log.prefix = `[SESSION ${this._sessionId.toString(16)}] `
// reset session state
if (!keepPending) {
for (const info of this.pendingMessages.values()) {
if (info._ === 'rpc') {
info.rpc.promise.reject(new Error('Session is reset'))
}
}
this.pendingMessages.clear()
}
this.recentOutgoingMsgIds.clear()
this.recentIncomingMsgIds.clear()
if (!keepPending) {
while (this.queuedRpc.length) {
const rpc = this.queuedRpc.popFront()!
if (rpc.sent === false) {
rpc.promise.reject(new Error('Session is reset'))
}
}
}
this.queuedAcks.length = 0
this.queuedStateReq.length = 0
this.queuedResendReq.length = 0
this.getStateSchedule.clear()
} }
changeSessionId(): void { enqueueRpc(rpc: PendingRpc, force?: boolean): boolean {
this._sessionId = randomLong() // already queued or cancelled
this._seqNo = 0 if ((!force && !rpc.sent) || rpc.cancelled) return false
}
/** Encrypt a single MTProto message using session's keys */ rpc.sent = false
async encryptMessage(message: Buffer): Promise<Buffer> { rpc.containerId = undefined
if (!this._authKey) throw new Error('Keys are not set up!') this.log.debug(
'enqueued %s for sending (msg_id = %s)',
let padding = rpc.method,
(16 /* header size */ + message.length + 12) /* min padding */ % 16 rpc.msgId || 'n/a',
padding = 12 + (padding ? 16 - padding : 0)
const buf = Buffer.alloc(16 + message.length + padding)
buf.writeInt32LE(this.serverSalt!.low)
buf.writeInt32LE(this.serverSalt!.high, 4)
buf.writeInt32LE(this._sessionId.low, 8)
buf.writeInt32LE(this._sessionId.high, 12)
message.copy(buf, 16)
randomBytes(padding).copy(buf, 16 + message.length)
const messageKey = (
await this._crypto.sha256(
Buffer.concat([this._authKeyClientSalt!, buf]),
)
).slice(8, 24)
const ige = await createAesIgeForMessage(
this._crypto,
this._authKey,
messageKey,
true,
) )
const encryptedData = await ige.encrypt(buf) this.queuedRpc.pushBack(rpc)
return Buffer.concat([this._authKeyId!, messageKey, encryptedData]) return true
}
/** Decrypt a single MTProto message using session's keys */
async decryptMessage(
data: Buffer,
callback: (msgId: tl.Long, seqNo: number, data: TlBinaryReader) => void,
): Promise<void> {
if (!this._authKey) throw new Error('Keys are not set up!')
const authKeyId = data.slice(0, 8)
const messageKey = data.slice(8, 24)
let encryptedData = data.slice(24)
if (!buffersEqual(authKeyId, this._authKeyId!)) {
this.log.warn(
'[%h] warn: received message with unknown authKey = %h (expected %h)',
this._sessionId,
authKeyId,
this._authKeyId,
)
return
}
const padSize = encryptedData.length % 16
if (padSize !== 0) {
// data came from a codec that uses non-16-based padding.
// it is safe to drop those padding bytes
encryptedData = encryptedData.slice(0, -padSize)
}
const ige = await createAesIgeForMessage(
this._crypto,
this._authKey!,
messageKey,
false,
)
const innerData = await ige.decrypt(encryptedData)
const expectedMessageKey = (
await this._crypto.sha256(
Buffer.concat([this._authKeyServerSalt!, innerData]),
)
).slice(8, 24)
if (!buffersEqual(messageKey, expectedMessageKey)) {
this.log.warn(
'[%h] received message with invalid messageKey = %h (expected %h)',
this._sessionId,
messageKey,
expectedMessageKey,
)
return
}
const innerReader = new TlBinaryReader(this._readerMap, innerData)
innerReader.seek(8) // skip salt
const sessionId = innerReader.long()
const messageId = innerReader.long(true)
if (sessionId.neq(this._sessionId)) {
this.log.warn(
'ignoring message with invalid sessionId = %h',
sessionId,
)
return
}
const seqNo = innerReader.uint()
const length = innerReader.uint()
if (length > innerData.length - 32 /* header size */) {
this.log.warn(
'ignoring message with invalid length: %d > %d',
length,
innerData.length - 32,
)
return
}
if (length % 4 !== 0) {
this.log.warn(
'ignoring message with invalid length: %d is not a multiple of 4',
length,
)
return
}
const paddingSize = innerData.length - length - 32 // header size
if (paddingSize < 12 || paddingSize > 1024) {
this.log.warn(
'ignoring message with invalid padding size: %d',
paddingSize,
)
return
}
callback(messageId, seqNo, innerReader)
} }
getMessageId(): Long { getMessageId(): Long {
@ -237,16 +242,55 @@ export class MtprotoSession {
} }
getSeqNo(isContentRelated = true): number { getSeqNo(isContentRelated = true): number {
let seqNo = this._seqNo * 2 let seqNo = this._seqNo
if (isContentRelated) { if (isContentRelated) {
seqNo += 1 seqNo += 1
this._seqNo += 1 this._seqNo += 2
} }
return seqNo return seqNo
} }
/** Encrypt a single MTProto message using session's keys */
async encryptMessage(message: Buffer): Promise<Buffer> {
const key = this._authKeyTemp.ready ? this._authKeyTemp : this._authKey
return key.encryptMessage(message, this.serverSalt, this._sessionId)
}
/** Decrypt a single MTProto message using session's keys */
async decryptMessage(
data: Buffer,
callback: Parameters<AuthKey['decryptMessage']>[2],
): Promise<void> {
if (!this._authKey.ready) throw new Error('Keys are not set up!')
const authKeyId = data.slice(0, 8)
let key: AuthKey
if (this._authKey.match(authKeyId)) {
key = this._authKey
} else if (this._authKeyTemp.match(authKeyId)) {
key = this._authKeyTemp
} else if (this._authKeyTempSecondary.match(authKeyId)) {
key = this._authKeyTempSecondary
} else {
this.log.warn(
'received message with unknown authKey = %h (expected %h or %h or %h)',
authKeyId,
this._authKey.id,
this._authKeyTemp.id,
this._authKeyTempSecondary.id,
)
return
}
return key.decryptMessage(data, this._sessionId, callback)
}
writeMessage( writeMessage(
writer: TlBinaryWriter, writer: TlBinaryWriter,
content: tl.TlObject | mtp.TlObject | Buffer, content: tl.TlObject | mtp.TlObject | Buffer,
@ -270,4 +314,43 @@ export class MtprotoSession {
return messageId return messageId
} }
onTransportFlood(callback: () => void) {
if (this.current429Timeout) return // already waiting
// all active queries must be resent after a timeout
this.resetLastPing(true)
const timeout = this.next429Timeout
this.next429Timeout = Math.min(this.next429Timeout * 2, 32000)
clearTimeout(this.current429Timeout)
clearTimeout(this.next429ResetTimeout)
this.current429Timeout = setTimeout(() => {
this.current429Timeout = undefined
callback()
}, timeout)
this.next429ResetTimeout = setTimeout(() => {
this.next429ResetTimeout = undefined
this.next429Timeout = 1000
}, 60000)
this.log.debug(
'transport flood, waiting for %d ms before proceeding',
timeout,
)
return Date.now() + timeout
}
resetLastPing(withTime = false): void {
if (withTime) this.lastPingTime = 0
if (!this.lastPingMsgId.isZero()) {
this.pendingMessages.delete(this.lastPingMsgId)
}
this.lastPingMsgId = Long.ZERO
}
} }

View file

@ -0,0 +1,332 @@
import EventEmitter from 'events'
import { tl } from '@mtcute/tl'
import { Logger } from '../utils'
import { MtprotoSession } from './mtproto-session'
import {
SessionConnection,
SessionConnectionParams,
} from './session-connection'
import { TransportFactory } from './transports'
export class MultiSessionConnection extends EventEmitter {
private _log: Logger
readonly _sessions: MtprotoSession[]
private _enforcePfs = false
constructor(
readonly params: SessionConnectionParams,
private _count: number,
log: Logger,
logPrefix = '',
) {
super()
this._log = log.create('multi')
if (logPrefix) this._log.prefix = `[${logPrefix}] `
this._enforcePfs = _count > 1 && params.isMainConnection
this._sessions = []
this._updateConnections()
}
protected _connections: SessionConnection[] = []
setCount(count: number, connect = this.params.isMainConnection): void {
this._count = count
this._updateConnections(connect)
}
private _updateSessions(): void {
// there are two cases
// 1. this msc is main, in which case every connection should have its own session
// 2. this msc is not main, in which case all connections should share the same session
// if (!this.params.isMainConnection) {
// // case 2
// this._log.debug(
// 'updating sessions count: %d -> 1',
// this._sessions.length,
// )
//
// if (this._sessions.length === 0) {
// this._sessions.push(
// new MtprotoSession(
// this.params.crypto,
// this._log.create('session'),
// this.params.readerMap,
// this.params.writerMap,
// ),
// )
// }
//
// // shouldn't happen, but just in case
// while (this._sessions.length > 1) {
// this._sessions.pop()!.reset()
// }
//
// return
// }
this._log.debug(
'updating sessions count: %d -> %d',
this._sessions.length,
this._count,
)
// case 1
if (this._sessions.length === this._count) return
if (this._sessions.length > this._count) {
// destroy extra sessions
for (let i = this._sessions.length - 1; i >= this._count; i--) {
this._sessions[i].reset()
}
this._sessions.splice(this._count)
return
}
while (this._sessions.length < this._count) {
const idx = this._sessions.length
const session = new MtprotoSession(
this.params.crypto,
this._log.create('session'),
this.params.readerMap,
this.params.writerMap,
)
// brvh
if (idx !== 0) session._authKey = this._sessions[0]._authKey
this._sessions.push(session)
}
}
private _updateConnections(connect = false): void {
this._updateSessions()
if (this._connections.length === this._count) return
this._log.debug(
'updating connections count: %d -> %d',
this._connections.length,
this._count,
)
const newEnforcePfs = this._count > 1 && this.params.isMainConnection
const enforcePfsChanged = newEnforcePfs !== this._enforcePfs
if (enforcePfsChanged) {
this._log.debug(
'enforcePfs changed: %s -> %s',
this._enforcePfs,
newEnforcePfs,
)
this._enforcePfs = newEnforcePfs
}
if (this._connections.length > this._count) {
// destroy extra connections
for (let i = this._connections.length - 1; i >= this._count; i--) {
this._connections[i].removeAllListeners()
this._connections[i].destroy()
}
this._connections.splice(this._count)
return
}
if (enforcePfsChanged) {
this._connections.forEach((conn) => {
conn.setUsePfs(this.params.usePfs || this._enforcePfs)
})
}
// create new connections
for (let i = this._connections.length; i < this._count; i++) {
const session = this._sessions[i] // this.params.isMainConnection ? // :
// this._sessions[0]
const conn = new SessionConnection(
{
...this.params,
usePfs: this.params.usePfs || this._enforcePfs,
isMainConnection: this.params.isMainConnection && i === 0,
withUpdates:
this.params.isMainConnection &&
!this.params.disableUpdates,
},
session,
)
if (this.params.isMainConnection) {
conn.on('update', (update) => this.emit('update', update))
}
conn.on('error', (err) => this.emit('error', err, conn))
conn.on('key-change', (key) => {
this.emit('key-change', i, key)
// notify other connections
for (const conn_ of this._connections) {
if (conn_ === conn) continue
conn_.onConnected()
}
})
conn.on('tmp-key-change', (key, expires) =>
this.emit('tmp-key-change', i, key, expires),
)
conn.on('auth-begin', () => {
this._log.debug('received auth-begin from connection %d', i)
this.emit('auth-begin', i)
// we need to reset temp auth keys if there are any left
this._connections.forEach((conn_) => {
conn_._session._authKeyTemp.reset()
if (conn_ !== conn) conn_.reconnect()
})
})
conn.on('usable', () => this.emit('usable', i))
conn.on('request-auth', () => this.emit('request-auth', i))
conn.on('flood-done', () => {
this._log.debug('received flood-done from connection %d', i)
this._connections.forEach((it) => it.flushWhenIdle())
})
this._connections.push(conn)
if (connect) conn.connect()
}
}
_destroyed = false
destroy(): void {
this._connections.forEach((conn) => conn.destroy())
this._sessions.forEach((sess) => sess.reset())
this.removeAllListeners()
this._destroyed = true
}
private _nextConnection = 0
sendRpc<T extends tl.RpcMethod>(
request: T,
stack?: string,
timeout?: number,
): Promise<tl.RpcCallReturn[T['_']]> {
// if (this.params.isMainConnection) {
// find the least loaded connection
let min = Infinity
let minIdx = 0
for (let i = 0; i < this._connections.length; i++) {
const conn = this._connections[i]
const total =
conn._session.queuedRpc.length +
conn._session.pendingMessages.size()
if (total < min) {
min = total
minIdx = i
}
}
return this._connections[minIdx].sendRpc(request, stack, timeout)
// }
// round-robin connections
// since they all share the same session, it doesn't matter which one we use
// the connection chosen here will only affect the first attempt at sending
// return this._connections[
// this._nextConnection++ % this._connections.length
// ].sendRpc(request, stack, timeout)
}
connect(): void {
for (const conn of this._connections) {
conn.connect()
}
}
ensureConnected(): void {
if (this._connections[0].isConnected) return
this.connect()
}
async setAuthKey(
authKey: Buffer | null,
temp = false,
idx = 0,
): Promise<void> {
const session = this._sessions[idx]
const key = temp ? session._authKeyTemp : session._authKey
await key.setup(authKey)
}
resetAuthKeys(): void {
for (const session of this._sessions) {
session.reset(true)
}
this.notifyKeyChange()
}
setInactivityTimeout(timeout?: number): void {
this._log.debug('setting inactivity timeout to %s', timeout)
// for future connections (if any)
this.params.inactivityTimeout = timeout
// for current connections
for (const conn of this._connections) {
conn.setInactivityTimeout(timeout)
}
}
notifyKeyChange(): void {
// only expected to be called on non-main connections
const session = this._sessions[0]
if (this.params.usePfs && !session._authKeyTemp.ready) {
this._log.debug(
'temp auth key needed but not ready, ignoring key change',
)
return
}
if (this._sessions[0].queuedRpc.length) {
// there are pending requests, we need to reconnect.
this._log.debug(
'notifying key change on the connection due to queued rpc',
)
this._connections[0].onConnected()
}
// connection is idle, we don't need to notify it
}
requestAuth(): void {
this._connections[0]._authorize()
}
resetSessions(): void {
if (this.params.isMainConnection) {
for (const conn of this._connections) {
conn._resetSession()
}
} else {
this._connections[0]._resetSession()
}
}
changeTransport(factory: TransportFactory): void {
this._connections.forEach((conn) => conn.changeTransport(factory))
}
getPoolSize(): number {
return this._connections.length
}
}

View file

@ -0,0 +1,866 @@
import { tl } from '@mtcute/tl'
import { TlReaderMap, TlWriterMap } from '@mtcute/tl-runtime'
import { ITelegramStorage } from '../storage'
import {
createControllablePromise,
ICryptoProvider,
Logger,
sleep,
} from '../utils'
import { ConfigManager } from './config-manager'
import { MultiSessionConnection } from './multi-session-connection'
import { PersistentConnectionParams } from './persistent-connection'
import {
defaultReconnectionStrategy,
ReconnectionStrategy,
} from './reconnection'
import {
SessionConnection,
SessionConnectionParams,
} from './session-connection'
import { defaultTransportFactory, TransportFactory } from './transports'
export type ConnectionKind = 'main' | 'upload' | 'download' | 'downloadSmall'
const CLIENT_ERRORS = {
'303': 1,
'400': 1,
'401': 1,
'403': 1,
'404': 1,
'406': 1,
'420': 1,
}
/**
* Params passed into {@link NetworkManager} by {@link TelegramClient}.
* This type is intended for internal usage only.
*/
export interface NetworkManagerParams {
storage: ITelegramStorage
crypto: ICryptoProvider
log: Logger
apiId: number
initConnectionOptions?: Partial<
Omit<tl.RawInitConnectionRequest, 'apiId' | 'query'>
>
transport?: TransportFactory
reconnectionStrategy?: ReconnectionStrategy<PersistentConnectionParams>
floodSleepThreshold: number
maxRetryCount: number
disableUpdates?: boolean
testMode: boolean
layer: number
useIpv6: boolean
readerMap: TlReaderMap
writerMap: TlWriterMap
isPremium: boolean
_emitError: (err: Error, connection?: SessionConnection) => void
keepAliveAction: () => void
}
export type ConnectionCountDelegate = (
kind: ConnectionKind,
dcId: number,
isPremium: boolean
) => number
const defaultConnectionCountDelegate: ConnectionCountDelegate = (
kind,
dcId,
isPremium,
) => {
switch (kind) {
case 'main':
return 1
case 'upload':
return isPremium || (dcId !== 2 && dcId !== 4) ? 8 : 4
case 'download':
return isPremium ? 8 : 2
case 'downloadSmall':
return 2
}
}
/**
* Additional params passed into {@link NetworkManager} by the user
* that customize the behavior of the manager
*/
export interface NetworkManagerExtraParams {
/**
* Whether to use PFS (Perfect Forward Secrecy) for all connections.
* This is disabled by default
*/
usePfs?: boolean
/**
* Connection count for each connection kind.
* The function should be pure to avoid unexpected behavior.
*
* Defaults to TDLib logic:
* - main: handled internally, **cannot be changed here**
* - upload: if premium or dc id is other than 2 or 4, then 8, otherwise 4
* - download: if premium then 8, otherwise 2
* - downloadSmall: 2
*/
connectionCount?: ConnectionCountDelegate
/**
* Idle timeout for non-main connections, in ms
* Defaults to 60 seconds.
*/
inactivityTimeout?: number
}
export interface RpcCallOptions {
/**
* If the call results in a `FLOOD_WAIT_X` error,
* the maximum amount of time to wait before retrying.
*
* If set to `0`, the call will not be retried.
*
* @default {@link BaseTelegramClientOptions.floodSleepThreshold}
*/
floodSleepThreshold?: number
/**
* If the call results in an internal server error or a flood wait,
* the maximum amount of times to retry the call.
*
* @default {@link BaseTelegramClientOptions.maxRetryCount}
*/
maxRetryCount?: number
/**
* Timeout for the call, in milliseconds.
*
* @default Infinity
*/
timeout?: number
/**
* Kind of connection to use for this call.
*
* @default 'main'
*/
kind?: ConnectionKind
/**
* ID of the DC to use for this call
*/
dcId?: number
/**
* DC connection manager to use for this call.
* Overrides `dcId` if set.
*/
manager?: DcConnectionManager
}
export class DcConnectionManager {
private __baseConnectionParams = (): SessionConnectionParams => ({
crypto: this.manager.params.crypto,
initConnection: this.manager._initConnectionParams,
transportFactory: this.manager._transportFactory,
dc: this._dc,
testMode: this.manager.params.testMode,
reconnectionStrategy: this.manager._reconnectionStrategy,
layer: this.manager.params.layer,
disableUpdates: this.manager.params.disableUpdates,
readerMap: this.manager.params.readerMap,
writerMap: this.manager.params.writerMap,
usePfs: this.manager.params.usePfs,
isMainConnection: false,
inactivityTimeout: this.manager.params.inactivityTimeout ?? 60_000,
})
private _log = this.manager._log.create('dc-manager')
main: MultiSessionConnection
upload = new MultiSessionConnection(
this.__baseConnectionParams(),
this.manager._connectionCount(
'upload',
this._dc.id,
this.manager.params.isPremium,
),
this._log,
'UPLOAD',
)
download = new MultiSessionConnection(
this.__baseConnectionParams(),
this.manager._connectionCount(
'download',
this._dc.id,
this.manager.params.isPremium,
),
this._log,
'DOWNLOAD',
)
downloadSmall = new MultiSessionConnection(
this.__baseConnectionParams(),
this.manager._connectionCount(
'downloadSmall',
this._dc.id,
this.manager.params.isPremium,
),
this._log,
'DOWNLOAD_SMALL',
)
private get _mainConnectionCount() {
if (!this.isPrimary) return 1
return this.manager.config.getNow()?.tmpSessions ?? 1
}
constructor(
readonly manager: NetworkManager,
readonly dcId: number,
readonly _dc: tl.RawDcOption,
public isPrimary = false,
) {
this._log.prefix = `[DC ${dcId}] `
const mainParams = this.__baseConnectionParams()
mainParams.isMainConnection = true
if (isPrimary) {
mainParams.inactivityTimeout = undefined
}
this.main = new MultiSessionConnection(
mainParams,
this._mainConnectionCount,
this._log,
'MAIN',
)
this._setupMulti('main')
this._setupMulti('upload')
this._setupMulti('download')
this._setupMulti('downloadSmall')
}
private _setupMulti(kind: ConnectionKind): void {
const connection = this[kind]
connection.on('key-change', (idx, key) => {
if (kind !== 'main') {
// main connection is responsible for authorization,
// and keys are then sent to other connections
this.manager._log.warn(
'got key-change from non-main connection, ignoring',
)
return
}
this.manager._log.debug(
'key change for dc %d from connection %d',
this.dcId,
idx,
)
this.manager._storage.setAuthKeyFor(this.dcId, key)
// send key to other connections
Promise.all([
this.upload.setAuthKey(key),
this.download.setAuthKey(key),
this.downloadSmall.setAuthKey(key),
]).then(() => {
this.upload.notifyKeyChange()
this.download.notifyKeyChange()
this.downloadSmall.notifyKeyChange()
})
})
connection.on('tmp-key-change', (idx, key, expires) => {
if (kind !== 'main') {
this.manager._log.warn(
'got tmp-key-change from non-main connection, ignoring',
)
return
}
this.manager._log.debug(
'temp key change for dc %d from connection %d',
this.dcId,
idx,
)
this.manager._storage.setTempAuthKeyFor(
this.dcId,
idx,
key,
expires * 1000,
)
// send key to other connections
Promise.all([
this.upload.setAuthKey(key, true),
this.download.setAuthKey(key, true),
this.downloadSmall.setAuthKey(key, true),
]).then(() => {
this.upload.notifyKeyChange()
this.download.notifyKeyChange()
this.downloadSmall.notifyKeyChange()
})
})
connection.on('auth-begin', () => {
// we need to propagate auth-begin to all connections
// to avoid them sending requests before auth is complete
if (kind !== 'main') {
this.manager._log.warn(
'got auth-begin from non-main connection, ignoring',
)
return
}
// reset key on non-main connections
// this event was already propagated to additional main connections
this.upload.resetAuthKeys()
this.download.resetAuthKeys()
this.downloadSmall.resetAuthKeys()
})
connection.on('request-auth', () => {
this.main.requestAuth()
})
connection.on('error', (err, conn) => {
this.manager.params._emitError(err, conn)
})
}
setIsPrimary(isPrimary: boolean): void {
if (this.isPrimary === isPrimary) return
this.isPrimary = isPrimary
if (isPrimary) {
this.main.setInactivityTimeout(undefined)
} else {
this.main.setInactivityTimeout(
this.manager.params.inactivityTimeout ?? 60_000,
)
}
}
setIsPremium(isPremium: boolean): void {
this.upload.setCount(
this.manager._connectionCount('upload', this._dc.id, isPremium),
)
this.download.setCount(
this.manager._connectionCount('download', this._dc.id, isPremium),
)
this.downloadSmall.setCount(
this.manager._connectionCount(
'downloadSmall',
this._dc.id,
isPremium,
),
)
}
async loadKeys(): Promise<boolean> {
const permanent = await this.manager._storage.getAuthKeyFor(this.dcId)
await Promise.all([
this.main.setAuthKey(permanent),
this.upload.setAuthKey(permanent),
this.download.setAuthKey(permanent),
this.downloadSmall.setAuthKey(permanent),
])
if (!permanent) {
return false
}
if (this.manager.params.usePfs) {
await Promise.all(
this.main._sessions.map(async (_, i) => {
const temp = await this.manager._storage.getAuthKeyFor(
this.dcId,
i,
)
await this.main.setAuthKey(temp, true, i)
if (i === 0) {
await Promise.all([
this.upload.setAuthKey(temp, true),
this.download.setAuthKey(temp, true),
this.downloadSmall.setAuthKey(temp, true),
])
}
}),
)
}
return true
}
}
export class NetworkManager {
readonly _log = this.params.log.create('network')
readonly _storage = this.params.storage
readonly _initConnectionParams: tl.RawInitConnectionRequest
readonly _transportFactory: TransportFactory
readonly _reconnectionStrategy: ReconnectionStrategy<PersistentConnectionParams>
readonly _connectionCount: ConnectionCountDelegate
protected readonly _dcConnections: Record<number, DcConnectionManager> = {}
protected _primaryDc?: DcConnectionManager
private _keepAliveInterval?: NodeJS.Timeout
private _lastUpdateTime = 0
private _updateHandler: (upd: tl.TypeUpdates) => void = () => {}
constructor(
readonly params: NetworkManagerParams & NetworkManagerExtraParams,
readonly config: ConfigManager,
) {
let deviceModel = 'mtcute on '
let appVersion = 'unknown'
if (typeof process !== 'undefined' && typeof require !== 'undefined') {
// eslint-disable-next-line @typescript-eslint/no-var-requires
const os = require('os')
deviceModel += `${os.type()} ${os.arch()} ${os.release()}`
try {
// for production builds
// eslint-disable-next-line @typescript-eslint/no-var-requires
appVersion = require('../package.json').version
} catch (e) {
try {
// for development builds (additional /src/ in path)
// eslint-disable-next-line @typescript-eslint/no-var-requires
appVersion = require('../../package.json').version
} catch (e) {}
}
} else if (typeof navigator !== 'undefined') {
deviceModel += navigator.userAgent
} else deviceModel += 'unknown'
this._initConnectionParams = {
_: 'initConnection',
deviceModel,
systemVersion: '1.0',
appVersion,
systemLangCode: 'en',
langPack: '', // "langPacks are for official apps only"
langCode: 'en',
...(params.initConnectionOptions ?? {}),
apiId: params.apiId,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
query: null as any,
}
this._transportFactory = params.transport ?? defaultTransportFactory
this._reconnectionStrategy =
params.reconnectionStrategy ?? defaultReconnectionStrategy
this._connectionCount =
params.connectionCount ?? defaultConnectionCountDelegate
this._onConfigChanged = this._onConfigChanged.bind(this)
config.onConfigUpdate(this._onConfigChanged)
}
private _switchPrimaryDc(dc: DcConnectionManager) {
if (this._primaryDc && this._primaryDc !== dc) {
this._primaryDc.setIsPrimary(false)
}
this._primaryDc = dc
dc.setIsPrimary(true)
dc.main.on('usable', () => {
this._lastUpdateTime = Date.now()
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
this._keepAliveInterval = setInterval(async () => {
if (Date.now() - this._lastUpdateTime > 900_000) {
// telegram asks to fetch pending updates if there are no updates for 15 minutes.
// it is up to the user to decide whether to do it or not
this.params.keepAliveAction()
this._lastUpdateTime = Date.now()
}
}, 60_000)
Promise.resolve(this._storage.getSelf()).then((self) => {
if (self?.isBot) {
// bots may receive tmpSessions, which we should respect
this.config.update(true).catch((e) => {
this.params._emitError(e)
})
}
})
})
dc.main.on('update', (update) => {
this._lastUpdateTime = Date.now()
this._updateHandler(update)
})
dc.loadKeys()
.catch((e) => {
this.params._emitError(e)
})
.then(() => {
dc.main.ensureConnected()
})
}
private _dcCreationPromise: Record<number, Promise<void>> = {}
async _getOtherDc(dcId: number): Promise<DcConnectionManager> {
if (!this._dcConnections[dcId]) {
if (dcId in this._dcCreationPromise) {
this._log.debug('waiting for DC %d to be created', dcId)
await this._dcCreationPromise[dcId]
return this._dcConnections[dcId]
}
const promise = createControllablePromise<void>()
this._dcCreationPromise[dcId] = promise
this._log.debug('creating new DC %d', dcId)
try {
const dcOption = await this.config.findOption({
dcId,
allowIpv6: this.params.useIpv6,
preferIpv6: this.params.useIpv6,
allowMedia: true,
preferMedia: true,
cdn: false,
})
if (!dcOption) {
throw new Error(`Could not find DC ${dcId}`)
}
const dc = new DcConnectionManager(this, dcId, dcOption)
if (!(await dc.loadKeys())) {
dc.main.requestAuth()
}
this._dcConnections[dcId] = dc
promise.resolve()
} catch (e) {
promise.reject(e)
}
}
return this._dcConnections[dcId]
}
/**
* Perform initial connection to the default DC
*
* @param defaultDc Default DC to connect to
*/
async connect(defaultDc: tl.RawDcOption): Promise<void> {
if (this._dcConnections[defaultDc.id]) {
// shouldn't happen
throw new Error('DC manager already exists')
}
const dc = new DcConnectionManager(this, defaultDc.id, defaultDc)
this._dcConnections[defaultDc.id] = dc
this._switchPrimaryDc(dc)
}
private async _exportAuthTo(manager: DcConnectionManager): Promise<void> {
const auth = await this.call({
_: 'auth.exportAuthorization',
dcId: manager.dcId,
})
const res = await this.call(
{
_: 'auth.importAuthorization',
id: auth.id,
bytes: auth.bytes,
},
{ manager },
)
if (res._ !== 'auth.authorization') {
throw new Error(
`Unexpected response from auth.importAuthorization: ${res._}`,
)
}
}
async exportAuth(): Promise<void> {
const dcs: Record<number, number> = {}
const config = await this.config.get()
for (const dc of config.dcOptions) {
if (dc.cdn) continue
dcs[dc.id] = dc.id
}
for (const dc of Object.values(dcs)) {
if (dc === this._primaryDc!.dcId) continue
this._log.debug('exporting auth for dc %d', dc)
const manager = await this._getOtherDc(dc)
await this._exportAuthTo(manager)
}
}
setIsPremium(isPremium: boolean): void {
this._log.debug('setting isPremium to %s', isPremium)
this.params.isPremium = isPremium
Object.values(this._dcConnections).forEach((dc) => {
dc.setIsPremium(isPremium)
})
}
async notifyLoggedIn(auth: tl.auth.TypeAuthorization): Promise<void> {
if (
auth._ === 'auth.authorizationSignUpRequired' ||
auth.user._ === 'userEmpty'
) {
return
}
if (auth.tmpSessions) {
this._primaryDc?.main.setCount(auth.tmpSessions)
}
this.setIsPremium(auth.user.premium!)
await this.exportAuth()
}
resetSessions(): void {
const dc = this._primaryDc
if (!dc) return
dc.main.resetSessions()
dc.upload.resetSessions()
dc.download.resetSessions()
dc.downloadSmall.resetSessions()
}
private _onConfigChanged(config: tl.RawConfig): void {
if (config.tmpSessions) {
this._primaryDc?.main.setCount(config.tmpSessions)
}
}
async changePrimaryDc(newDc: number): Promise<void> {
if (newDc === this._primaryDc?.dcId) return
const option = await this.config.findOption({
dcId: newDc,
allowIpv6: this.params.useIpv6,
preferIpv6: this.params.useIpv6,
cdn: false,
allowMedia: false,
})
if (!option) {
throw new Error(`DC ${newDc} not found`)
}
if (!this._dcConnections[newDc]) {
this._dcConnections[newDc] = new DcConnectionManager(
this,
newDc,
option,
)
}
this._storage.setDefaultDc(option)
this._switchPrimaryDc(this._dcConnections[newDc])
}
private _floodWaitedRequests: Record<string, number> = {}
async call<T extends tl.RpcMethod>(
message: T,
params?: RpcCallOptions,
stack?: string,
): Promise<tl.RpcCallReturn[T['_']]> {
if (!this._primaryDc) {
throw new Error('Not connected to any DC')
}
const floodSleepThreshold =
params?.floodSleepThreshold ?? this.params.floodSleepThreshold
const maxRetryCount = params?.maxRetryCount ?? this.params.maxRetryCount
// do not send requests that are in flood wait
if (message._ in this._floodWaitedRequests) {
const delta = this._floodWaitedRequests[message._] - Date.now()
if (delta <= 3000) {
// flood waits below 3 seconds are "ignored"
delete this._floodWaitedRequests[message._]
} else if (delta <= this.params.floodSleepThreshold) {
await sleep(delta)
delete this._floodWaitedRequests[message._]
} else {
throw new tl.errors.FloodWaitXError(delta / 1000)
}
}
let lastError: Error | null = null
const kind = params?.kind ?? 'main'
let manager: DcConnectionManager
if (params?.manager) {
manager = params.manager
} else if (params?.dcId && params.dcId !== this._primaryDc.dcId) {
manager = await this._getOtherDc(params.dcId)
} else {
manager = this._primaryDc
}
let multi = manager[kind]
for (let i = 0; i < maxRetryCount; i++) {
try {
const res = await multi.sendRpc(message, stack, params?.timeout)
if (kind === 'main') {
this._lastUpdateTime = Date.now()
}
return res
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (e: any) {
lastError = e
if (e.code && !(e.code in CLIENT_ERRORS)) {
this._log.warn(
'Telegram is having internal issues: %d %s, retrying',
e.code,
e.message,
)
if (e.message === 'WORKER_BUSY_TOO_LONG_RETRY') {
// according to tdlib, "it is dangerous to resend query without timeout, so use 1"
await sleep(1000)
}
continue
}
if (
e.constructor === tl.errors.FloodWaitXError ||
e.constructor === tl.errors.SlowmodeWaitXError ||
e.constructor === tl.errors.FloodTestPhoneWaitXError
) {
if (e.constructor !== tl.errors.SlowmodeWaitXError) {
// SLOW_MODE_WAIT is chat-specific, not request-specific
this._floodWaitedRequests[message._] =
Date.now() + e.seconds * 1000
}
// In test servers, FLOOD_WAIT_0 has been observed, and sleeping for
// such a short amount will cause retries very fast leading to issues
if (e.seconds === 0) {
(e as tl.Mutable<typeof e>).seconds = 1
}
if (e.seconds <= floodSleepThreshold) {
this._log.info('Flood wait for %d seconds', e.seconds)
await sleep(e.seconds * 1000)
continue
}
}
if (manager === this._primaryDc) {
if (
e.constructor === tl.errors.PhoneMigrateXError ||
e.constructor === tl.errors.UserMigrateXError ||
e.constructor === tl.errors.NetworkMigrateXError
) {
this._log.info('Migrate error, new dc = %d', e.new_dc)
await this.changePrimaryDc(e.new_dc)
manager = this._primaryDc!
multi = manager[kind]
continue
}
} else if (
e.constructor === tl.errors.AuthKeyUnregisteredError
) {
// we can try re-exporting auth from the primary connection
this._log.warn(
'exported auth key error, trying re-exporting..',
)
await this._exportAuthTo(manager)
continue
}
throw e
}
}
throw lastError
}
setUpdateHandler(handler: NetworkManager['_updateHandler']): void {
this._updateHandler = handler
}
changeTransport(factory: TransportFactory): void {
Object.values(this._dcConnections).forEach((dc) => {
dc.main.changeTransport(factory)
dc.upload.changeTransport(factory)
dc.download.changeTransport(factory)
dc.downloadSmall.changeTransport(factory)
})
}
getPoolSize(kind: ConnectionKind, dcId?: number) {
const dc = dcId ? this._dcConnections[dcId] : this._primaryDc
if (!dc) {
if (!this._primaryDc) {
throw new Error('Not connected to any DC')
}
// guess based on the provided delegate. it is most likely correct,
// but we should give actual values if possible
return this._connectionCount(
kind,
dcId ?? this._primaryDc.dcId,
this.params.isPremium,
)
}
return dc[kind].getPoolSize()
}
getPrimaryDcId() {
if (!this._primaryDc) throw new Error('Not connected to any DC')
return this._primaryDc.dcId
}
destroy(): void {
for (const dc of Object.values(this._dcConnections)) {
dc.main.destroy()
dc.upload.destroy()
dc.download.destroy()
dc.downloadSmall.destroy()
}
if (this._keepAliveInterval) clearInterval(this._keepAliveInterval)
this.config.offConfigUpdate(this._onConfigChanged)
}
}

View file

@ -3,10 +3,6 @@ import EventEmitter from 'events'
import { tl } from '@mtcute/tl' import { tl } from '@mtcute/tl'
import { ICryptoProvider, Logger } from '../utils' import { ICryptoProvider, Logger } from '../utils'
import {
ControllablePromise,
createControllablePromise,
} from '../utils/controllable-promise'
import { ReconnectionStrategy } from './reconnection' import { ReconnectionStrategy } from './reconnection'
import { import {
ITelegramTransport, ITelegramTransport,
@ -23,13 +19,18 @@ export interface PersistentConnectionParams {
inactivityTimeout?: number inactivityTimeout?: number
} }
let nextConnectionUid = 0
/** /**
* Base class for persistent connections. * Base class for persistent connections.
* Only used for {@link PersistentConnection} and used as a mean of code splitting. * Only used for {@link PersistentConnection} and used as a mean of code splitting.
* This class doesn't know anything about MTProto, it just manages the transport.
*/ */
export abstract class PersistentConnection extends EventEmitter { export abstract class PersistentConnection extends EventEmitter {
private _uid = nextConnectionUid++
readonly params: PersistentConnectionParams readonly params: PersistentConnectionParams
private _transport!: ITelegramTransport protected _transport!: ITelegramTransport
private _sendOnceConnected: Buffer[] = [] private _sendOnceConnected: Buffer[] = []
@ -41,10 +42,7 @@ export abstract class PersistentConnection extends EventEmitter {
// inactivity timeout // inactivity timeout
private _inactivityTimeout: NodeJS.Timeout | null = null private _inactivityTimeout: NodeJS.Timeout | null = null
private _inactive = false private _inactive = true
// waitForMessage
private _pendingWaitForMessages: ControllablePromise<Buffer>[] = []
_destroyed = false _destroyed = false
_usable = false _usable = false
@ -62,6 +60,14 @@ export abstract class PersistentConnection extends EventEmitter {
super() super()
this.params = params this.params = params
this.changeTransport(params.transportFactory) this.changeTransport(params.transportFactory)
this.log.prefix = `[UID ${this._uid}] `
this._onInactivityTimeout = this._onInactivityTimeout.bind(this)
}
get isConnected(): boolean {
return this._transport.state() !== TransportState.Idle
} }
changeTransport(factory: TransportFactory): void { changeTransport(factory: TransportFactory): void {
@ -73,18 +79,36 @@ export abstract class PersistentConnection extends EventEmitter {
this._transport.setup?.(this.params.crypto, this.log) this._transport.setup?.(this.params.crypto, this.log)
this._transport.on('ready', this.onTransportReady.bind(this)) this._transport.on('ready', this.onTransportReady.bind(this))
this._transport.on('message', this.onTransportMessage.bind(this)) this._transport.on('message', this.onMessage.bind(this))
this._transport.on('error', this.onTransportError.bind(this)) this._transport.on('error', this.onTransportError.bind(this))
this._transport.on('close', this.onTransportClose.bind(this)) this._transport.on('close', this.onTransportClose.bind(this))
} }
onTransportReady(): void { onTransportReady(): void {
// transport ready does not mean actual mtproto is ready // transport ready does not mean actual mtproto is ready
if (this._sendOnceConnected.length) { if (this._sendOnceConnected.length) {
this._transport.send(Buffer.concat(this._sendOnceConnected)) const sendNext = () => {
if (!this._sendOnceConnected.length) {
this.onConnected()
return
}
const data = this._sendOnceConnected.shift()!
this._transport
.send(data)
.then(sendNext)
.catch((err) => {
this.log.error('error sending queued data: %s', err)
this._sendOnceConnected.unshift(data)
})
}
sendNext()
return
} }
this._sendOnceConnected = []
this.onConnected() this.onConnected()
} }
@ -101,32 +125,12 @@ export abstract class PersistentConnection extends EventEmitter {
} }
onTransportError(err: Error): void { onTransportError(err: Error): void {
if (this._pendingWaitForMessages.length) {
this._pendingWaitForMessages.shift()!.reject(err)
return
}
this._lastError = err this._lastError = err
this.onError(err) this.onError(err)
// transport is expected to emit `close` after `error` // transport is expected to emit `close` after `error`
} }
onTransportMessage(data: Buffer): void {
if (this._pendingWaitForMessages.length) {
this._pendingWaitForMessages.shift()!.resolve(data)
return
}
this.onMessage(data)
}
onTransportClose(): void { onTransportClose(): void {
Object.values(this._pendingWaitForMessages).forEach((prom) =>
prom.reject(new Error('Connection closed')),
)
// transport closed because of inactivity // transport closed because of inactivity
// obviously we dont want to reconnect then // obviously we dont want to reconnect then
if (this._inactive) return if (this._inactive) return
@ -139,13 +143,20 @@ export abstract class PersistentConnection extends EventEmitter {
this._consequentFails, this._consequentFails,
this._previousWait, this._previousWait,
) )
if (wait === false) return this.destroy()
if (wait === false) {
this.destroy()
return
}
this.emit('wait', wait) this.emit('wait', wait)
this._previousWait = wait this._previousWait = wait
if (this._reconnectionTimeout != null) { clearTimeout(this._reconnectionTimeout) } if (this._reconnectionTimeout != null) {
clearTimeout(this._reconnectionTimeout)
}
this._reconnectionTimeout = setTimeout(() => { this._reconnectionTimeout = setTimeout(() => {
if (this._destroyed) return if (this._destroyed) return
this._reconnectionTimeout = null this._reconnectionTimeout = null
@ -154,10 +165,14 @@ export abstract class PersistentConnection extends EventEmitter {
} }
connect(): void { connect(): void {
if (this._transport.state() !== TransportState.Idle) { throw new Error('Connection is already opened!') } if (this.isConnected) {
throw new Error('Connection is already opened!')
}
if (this._destroyed) throw new Error('Connection is already destroyed!') if (this._destroyed) throw new Error('Connection is already destroyed!')
if (this._reconnectionTimeout != null) { clearTimeout(this._reconnectionTimeout) } if (this._reconnectionTimeout != null) {
clearTimeout(this._reconnectionTimeout)
}
this._inactive = false this._inactive = false
this._transport.connect(this.params.dc, this.params.testMode) this._transport.connect(this.params.dc, this.params.testMode)
@ -168,8 +183,12 @@ export abstract class PersistentConnection extends EventEmitter {
} }
destroy(): void { destroy(): void {
if (this._reconnectionTimeout != null) { clearTimeout(this._reconnectionTimeout) } if (this._reconnectionTimeout != null) {
if (this._inactivityTimeout != null) { clearTimeout(this._inactivityTimeout) } clearTimeout(this._reconnectionTimeout)
}
if (this._inactivityTimeout != null) {
clearTimeout(this._inactivityTimeout)
}
this._transport.close() this._transport.close()
this._transport.removeAllListeners() this._transport.removeAllListeners()
@ -179,15 +198,32 @@ export abstract class PersistentConnection extends EventEmitter {
protected _rescheduleInactivity(): void { protected _rescheduleInactivity(): void {
if (!this.params.inactivityTimeout) return if (!this.params.inactivityTimeout) return
if (this._inactivityTimeout) clearTimeout(this._inactivityTimeout) if (this._inactivityTimeout) clearTimeout(this._inactivityTimeout)
this._inactivityTimeout = setTimeout(() => { this._inactivityTimeout = setTimeout(
this.log.info( this._onInactivityTimeout,
'disconnected because of inactivity for %d', this.params.inactivityTimeout,
this.params.inactivityTimeout, )
) }
this._inactive = true
this._inactivityTimeout = null protected _onInactivityTimeout(): void {
this._transport.close() this.log.info(
}, this.params.inactivityTimeout) 'disconnected because of inactivity for %d',
this.params.inactivityTimeout,
)
this._inactive = true
this._inactivityTimeout = null
this._transport.close()
}
setInactivityTimeout(timeout?: number): void {
this.params.inactivityTimeout = timeout
if (this._inactivityTimeout) {
clearTimeout(this._inactivityTimeout)
}
if (timeout) {
this._rescheduleInactivity()
}
} }
async send(data: Buffer): Promise<void> { async send(data: Buffer): Promise<void> {
@ -201,11 +237,4 @@ export abstract class PersistentConnection extends EventEmitter {
this._sendOnceConnected.push(data) this._sendOnceConnected.push(data)
} }
} }
waitForNextMessage(): Promise<Buffer> {
const promise = createControllablePromise<Buffer>()
this._pendingWaitForMessages.push(promise)
return promise
}
} }

File diff suppressed because it is too large Load diff

View file

@ -58,6 +58,8 @@ export interface ITelegramTransport extends EventEmitter {
* This method is called before any other. * This method is called before any other.
*/ */
setup?(crypto: ICryptoProvider, log: Logger): void setup?(crypto: ICryptoProvider, log: Logger): void
getMtproxyInfo?(): tl.RawInputClientProxy
} }
/** Transport factory function */ /** Transport factory function */

View file

@ -48,7 +48,9 @@ export abstract class BaseTcpTransport
// eslint-disable-next-line @typescript-eslint/no-unused-vars // eslint-disable-next-line @typescript-eslint/no-unused-vars
connect(dc: tl.RawDcOption, testMode: boolean): void { connect(dc: tl.RawDcOption, testMode: boolean): void {
if (this._state !== TransportState.Idle) { throw new Error('Transport is not IDLE') } if (this._state !== TransportState.Idle) {
throw new Error('Transport is not IDLE')
}
if (!this.packetCodecInitialized) { if (!this.packetCodecInitialized) {
this._packetCodec.setup?.(this._crypto, this.log) this._packetCodec.setup?.(this._crypto, this.log)
@ -69,7 +71,9 @@ export abstract class BaseTcpTransport
this.handleConnect.bind(this), this.handleConnect.bind(this),
) )
this._socket.on('data', (data) => this._packetCodec.feed(data)) this._socket.on('data', (data) => {
this._packetCodec.feed(data)
})
this._socket.on('error', this.handleError.bind(this)) this._socket.on('error', this.handleError.bind(this))
this._socket.on('close', this.close.bind(this)) this._socket.on('close', this.close.bind(this))
} }
@ -87,7 +91,7 @@ export abstract class BaseTcpTransport
this._packetCodec.reset() this._packetCodec.reset()
} }
async handleError(error: Error): Promise<void> { handleError(error: Error): void {
this.log.error('error: %s', error.stack) this.log.error('error: %s', error.stack)
this.emit('error', error) this.emit('error', error)
} }
@ -99,7 +103,11 @@ export abstract class BaseTcpTransport
if (initialMessage.length) { if (initialMessage.length) {
this._socket!.write(initialMessage, (err) => { this._socket!.write(initialMessage, (err) => {
if (err) { if (err) {
this.emit('error', err) this.log.error(
'failed to write initial message: %s',
err.stack,
)
this.emit('error')
this.close() this.close()
} else { } else {
this._state = TransportState.Ready this._state = TransportState.Ready
@ -113,12 +121,20 @@ export abstract class BaseTcpTransport
} }
async send(bytes: Buffer): Promise<void> { async send(bytes: Buffer): Promise<void> {
if (this._state !== TransportState.Ready) { throw new Error('Transport is not READY') } if (this._state !== TransportState.Ready) {
throw new Error('Transport is not READY')
}
const framed = await this._packetCodec.encode(bytes) const framed = await this._packetCodec.encode(bytes)
return new Promise((res, rej) => { return new Promise((resolve, reject) => {
this._socket!.write(framed, (err) => (err ? rej(err) : res())) this._socket!.write(framed, (error) => {
if (error) {
reject(error)
} else {
resolve()
}
})
}) })
} }
} }

View file

@ -78,12 +78,27 @@ export interface ITelegramStorage {
/** /**
* Get auth_key for a given DC * Get auth_key for a given DC
* (returning null will start authorization) * (returning null will start authorization)
* For temp keys: should also return null if the key has expired
*
* @param dcId DC ID
* @param tempIndex Index of the temporary key (usually 0, used for multi-connections)
*/ */
getAuthKeyFor(dcId: number): MaybeAsync<Buffer | null> getAuthKeyFor(dcId: number, tempIndex?: number): MaybeAsync<Buffer | null>
/** /**
* Set auth_key for a given DC * Set auth_key for a given DC
*/ */
setAuthKeyFor(dcId: number, key: Buffer | null): MaybeAsync<void> setAuthKeyFor(dcId: number, key: Buffer | null): MaybeAsync<void>
/**
* Set temp_auth_key for a given DC
* expiresAt is unix time in ms
*/
setTempAuthKeyFor(dcId: number, index: number, key: Buffer | null, expiresAt: number): MaybeAsync<void>
/**
* Remove all saved auth keys (both temp and perm)
* for the given DC. Used when perm_key becomes invalid,
* meaning all temp_keys also become invalid
*/
dropAuthKeysFor(dcId: number): MaybeAsync<void>
/** /**
* Get information about currently logged in user (if available) * Get information about currently logged in user (if available)

View file

@ -15,6 +15,8 @@ export interface MemorySessionState {
defaultDc: tl.RawDcOption | null defaultDc: tl.RawDcOption | null
authKeys: Record<number, Buffer | null> authKeys: Record<number, Buffer | null>
authKeysTemp: Record<string, Buffer | null>
authKeysTempExpiry: Record<string, number>
// marked peer id -> entity info // marked peer id -> entity info
entities: Record<number, PeerInfoWithUpdated> entities: Record<number, PeerInfoWithUpdated>
@ -110,6 +112,8 @@ export class MemoryStorage implements ITelegramStorage, IStateStorage {
$version: CURRENT_VERSION, $version: CURRENT_VERSION,
defaultDc: null, defaultDc: null,
authKeys: {}, authKeys: {},
authKeysTemp: {},
authKeysTempExpiry: {},
entities: {}, entities: {},
phoneIndex: {}, phoneIndex: {},
usernameIndex: {}, usernameIndex: {},
@ -187,14 +191,43 @@ export class MemoryStorage implements ITelegramStorage, IStateStorage {
this._state.defaultDc = dc this._state.defaultDc = dc
} }
setTempAuthKeyFor(
dcId: number,
index: number,
key: Buffer | null,
expiresAt: number,
): void {
const k = `${dcId}:${index}`
this._state.authKeysTemp[k] = key
this._state.authKeysTempExpiry[k] = expiresAt
}
setAuthKeyFor(dcId: number, key: Buffer | null): void { setAuthKeyFor(dcId: number, key: Buffer | null): void {
this._state.authKeys[dcId] = key this._state.authKeys[dcId] = key
} }
getAuthKeyFor(dcId: number): Buffer | null { getAuthKeyFor(dcId: number, tempIndex?: number): Buffer | null {
if (tempIndex !== undefined) {
const k = `${dcId}:${tempIndex}`
if (Date.now() > (this._state.authKeysTempExpiry[k] ?? 0)) { return null }
return this._state.authKeysTemp[k]
}
return this._state.authKeys[dcId] ?? null return this._state.authKeys[dcId] ?? null
} }
dropAuthKeysFor(dcId: number): void {
this._state.authKeys[dcId] = null
Object.keys(this._state.authKeysTemp).forEach((key) => {
if (key.startsWith(`${dcId}:`)) {
delete this._state.authKeysTemp[key]
delete this._state.authKeysTempExpiry[key]
}
})
}
updatePeers(peers: PeerInfoWithUpdated[]): MaybeAsync<void> { updatePeers(peers: PeerInfoWithUpdated[]): MaybeAsync<void> {
for (const peer of peers) { for (const peer of peers) {
this._cachedFull.set(peer.id, peer.full) this._cachedFull.set(peer.id, peer.full)

View file

@ -16,7 +16,9 @@ export function bigIntToBuffer(
): Buffer { ): Buffer {
const array = value.toArray(256).value const array = value.toArray(256).value
if (length !== 0 && array.length > length) { throw new Error('Value out of bounds') } if (length !== 0 && array.length > length) {
throw new Error('Value out of bounds')
}
if (length !== 0) { if (length !== 0) {
// padding // padding
@ -60,6 +62,23 @@ export function randomBigInt(size: number): BigInteger {
return bufferToBigInt(randomBytes(size)) return bufferToBigInt(randomBytes(size))
} }
/**
* Generate a random big integer of the given size (in bits)
* @param bits
*/
export function randomBigIntBits(bits: number): BigInteger {
let num = randomBigInt(Math.ceil(bits / 8))
const bitLength = num.bitLength()
if (bitLength.gt(bits)) {
const toTrim = bigInt.randBetween(bitLength.minus(bits), 8)
num = num.shiftRight(toTrim)
}
return num
}
/** /**
* Generate a random big integer in the range [min, max) * Generate a random big integer in the range [min, max)
* *
@ -80,3 +99,20 @@ export function randomBigIntInRange(
return min.plus(result) return min.plus(result)
} }
/**
* Compute the multiplicity of 2 in the prime factorization of n
* @param n
*/
export function twoMultiplicity(n: BigInteger): BigInteger {
if (n === bigInt.zero) return bigInt.zero
let m = bigInt.zero
let pow = bigInt.one
while (true) {
if (!n.and(pow).isZero()) return m
m = m.plus(bigInt.one)
pow = pow.shiftLeft(1)
}
}

View file

@ -8,12 +8,6 @@ export interface IEncryptionScheme {
decrypt(data: Buffer): MaybeAsync<Buffer> decrypt(data: Buffer): MaybeAsync<Buffer>
} }
export interface IHashMethod {
update(data: Buffer): MaybeAsync<void>
digest(): MaybeAsync<Buffer>
}
export interface ICryptoProvider { export interface ICryptoProvider {
initialize?(): MaybeAsync<void> initialize?(): MaybeAsync<void>
@ -38,8 +32,6 @@ export interface ICryptoProvider {
createAesEcb(key: Buffer): IEncryptionScheme createAesEcb(key: Buffer): IEncryptionScheme
createMd5(): IHashMethod
factorizePQ(pq: Buffer): MaybeAsync<[Buffer, Buffer]> factorizePQ(pq: Buffer): MaybeAsync<[Buffer, Buffer]>
} }

View file

@ -3,7 +3,6 @@ import {
BaseCryptoProvider, BaseCryptoProvider,
ICryptoProvider, ICryptoProvider,
IEncryptionScheme, IEncryptionScheme,
IHashMethod,
} from './abstract' } from './abstract'
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
@ -108,15 +107,6 @@ export class ForgeCryptoProvider
) )
} }
createMd5(): IHashMethod {
const hash = forge.md.md5.create()
return {
update: (data) => hash.update(data.toString('binary')),
digest: () => Buffer.from(hash.digest().data, 'binary'),
}
}
hmacSha256(data: Buffer, key: Buffer): MaybeAsync<Buffer> { hmacSha256(data: Buffer, key: Buffer): MaybeAsync<Buffer> {
const hmac = forge.hmac.create() const hmac = forge.hmac.create()
hmac.start('sha256', key.toString('binary')) hmac.start('sha256', key.toString('binary'))

View file

@ -0,0 +1,43 @@
import bigInt, { BigInteger } from 'big-integer'
import { randomBigIntBits, twoMultiplicity } from '../bigint-utils'
export function millerRabin(n: BigInteger, rounds = 20): boolean {
// small numbers: 0, 1 are not prime, 2, 3 are prime
if (n.lt(bigInt[4])) return n.gt(bigInt[1])
if (n.isEven() || n.isNegative()) return false
const nBits = n.bitLength().toJSNumber()
const nSub = n.minus(1)
const r = twoMultiplicity(nSub)
const d = nSub.shiftRight(r)
for (let i = 0; i < rounds; i++) {
let base
do {
base = randomBigIntBits(nBits)
} while (base.leq(bigInt.one) || base.geq(nSub))
let x = base.modPow(d, n)
if (x.eq(bigInt.one) || x.eq(nSub)) continue
let i = bigInt.zero
let y: BigInteger
while (i.lt(r)) {
y = x.modPow(bigInt[2], n)
if (x.eq(bigInt.one)) return false
if (x.eq(nSub)) break
i = i.plus(bigInt.one)
x = y
}
if (i.eq(r)) return false
}
return true
}

View file

@ -11,7 +11,6 @@ import {
BaseCryptoProvider, BaseCryptoProvider,
ICryptoProvider, ICryptoProvider,
IEncryptionScheme, IEncryptionScheme,
IHashMethod,
} from './abstract' } from './abstract'
export class NodeCryptoProvider export class NodeCryptoProvider
@ -83,10 +82,6 @@ export class NodeCryptoProvider
return createHash('sha256').update(data).digest() return createHash('sha256').update(data).digest()
} }
createMd5(): IHashMethod {
return createHash('md5') as unknown as IHashMethod
}
hmacSha256(data: Buffer, key: Buffer): MaybeAsync<Buffer> { hmacSha256(data: Buffer, key: Buffer): MaybeAsync<Buffer> {
return createHmac('sha256', key).update(data).digest() return createHmac('sha256', key).update(data).digest()
} }

View file

@ -58,8 +58,8 @@ export class EarlyTimer {
* Emit the timer right now * Emit the timer right now
*/ */
emitNow(): void { emitNow(): void {
this._handler()
this.reset() this.reset()
this._handler()
} }
/** /**

View file

@ -73,14 +73,27 @@ export class Logger {
const val = args[idx] const val = args[idx]
args.splice(idx, 1) args.splice(idx, 1)
if (m === '%h') return Buffer.isBuffer(val) ? val.toString('hex') : String(val)
if (m === '%h') {
if (Buffer.isBuffer(val)) return val.toString('hex')
if (typeof val === 'number') return val.toString(16)
return String(val)
}
if (m === '%b') return String(Boolean(val)) if (m === '%b') return String(Boolean(val))
if (m === '%j') { if (m === '%j') {
return JSON.stringify(val, (k, v) => { return JSON.stringify(val, (k, v) => {
if (typeof v === 'object' && v.type === 'Buffer' && Array.isArray(v.data)) { if (
typeof v === 'object' &&
v.type === 'Buffer' &&
Array.isArray(v.data)
) {
let str = Buffer.from(v.data).toString('base64') let str = Buffer.from(v.data).toString('base64')
if (str.length > 300) str = str.slice(0, 300) + '...'
if (str.length > 300) {
str = str.slice(0, 300) + '...'
}
return str return str
} }
@ -137,10 +150,10 @@ export class LogManager extends Logger {
static DEBUG = 4 static DEBUG = 4
static VERBOSE = 5 static VERBOSE = 5
constructor() { constructor(tag = 'base') {
// workaround because we cant pass this to super // workaround because we cant pass this to super
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
super(null as any, 'base') super(null as any, tag)
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
;(this as any).mgr = this ;(this as any).mgr = this
} }

View file

@ -16,6 +16,21 @@ export function randomLong(unsigned = false): Long {
return new Long(lo, hi, unsigned) return new Long(lo, hi, unsigned)
} }
/**
* Read a Long from a buffer
*
* @param buf Buffer to read from
* @param unsigned Whether the number should be unsigned
* @param le Whether the number is little-endian
*/
export function longFromBuffer(buf: Buffer, unsigned = false, le = true): Long {
if (le) {
return new Long(buf.readInt32LE(0), buf.readInt32LE(4), unsigned)
}
return new Long(buf.readInt32BE(4), buf.readInt32BE(0), unsigned)
}
/** /**
* Remove a Long from an array * Remove a Long from an array
* *

View file

@ -2,7 +2,7 @@ import { isatty } from 'tty'
const isTty = isatty(process.stdout.fd) const isTty = isatty(process.stdout.fd)
const BASE_FORMAT = isTty ? '[%s] [%s] %s%s\x1b[0m - ' : '[%s] [%s] %s - ' const BASE_FORMAT = isTty ? '%s [%s] [%s%s\x1b[0m] ' : '%s [%s] [%s] '
const LEVEL_NAMES = isTty ? const LEVEL_NAMES = isTty ?
[ [
'', // OFF '', // OFF

View file

@ -1,4 +1,4 @@
const BASE_FORMAT = '[%s] [%с%s%с] %c%s%c - ' const BASE_FORMAT = '%s [%с%s%с] [%c%s%c] '
const LEVEL_NAMES = [ const LEVEL_NAMES = [
'', // OFF '', // OFF
'ERR', 'ERR',

View file

@ -164,31 +164,6 @@ export function testCryptoProvider(c: ICryptoProvider): void {
'99706487a1cde613bc6de0b6f24b1c7aa448c8b9c3403e3467a8cad89340f53b', '99706487a1cde613bc6de0b6f24b1c7aa448c8b9c3403e3467a8cad89340f53b',
) )
}) })
it('should calculate md5', async () => {
const test = async (...parts: string[]): Promise<Buffer> => {
const md5 = c.createMd5()
for (const p of parts) await md5.update(Buffer.from(p, 'hex'))
return md5.digest()
}
expect((await test()).toString('hex')).eq(
'd41d8cd98f00b204e9800998ecf8427e',
)
expect((await test('aaeeff')).toString('hex')).eq(
'9c20ec5e212b4fcfa4666a8b165c6d5d',
)
expect((await test('aaeeffffeeaa')).toString('hex')).eq(
'cf216071768a7b610d079e5eb7b68b74',
)
expect((await test('aaeeff', 'ffeeaa')).toString('hex')).eq(
'cf216071768a7b610d079e5eb7b68b74',
)
expect((await test('aa', 'ee', 'ff', 'ffeeaa')).toString('hex')).eq(
'cf216071768a7b610d079e5eb7b68b74',
)
})
} }
describe('NodeCryptoProvider', () => { describe('NodeCryptoProvider', () => {

View file

@ -1,71 +1,71 @@
import { expect } from 'chai' // import { expect } from 'chai'
import { randomBytes } from 'crypto' // import { randomBytes } from 'crypto'
import { describe, it } from 'mocha' // import { describe, it } from 'mocha'
//
import __tlReaderMap from '@mtcute/tl/binary/reader' // import __tlReaderMap from '@mtcute/tl/binary/reader'
import { TlBinaryReader } from '@mtcute/tl-runtime' // import { TlBinaryReader } from '@mtcute/tl-runtime'
//
import { createTestTelegramClient } from './utils' // import { createTestTelegramClient } from './utils'
//
// eslint-disable-next-line @typescript-eslint/no-var-requires // // eslint-disable-next-line @typescript-eslint/no-var-requires
require('dotenv-flow').config() // require('dotenv-flow').config()
//
describe('fuzz : packet', async function () { // describe('fuzz : packet', async function () {
this.timeout(45000) // this.timeout(45000)
//
it('random packet', async () => { // it('random packet', async () => {
const client = createTestTelegramClient() // const client = createTestTelegramClient()
//
await client.connect() // await client.connect()
await client.waitUntilUsable() // await client.waitUntilUsable()
//
let errors = 0 // let errors = 0
//
const conn = client.primaryConnection // const conn = client.primaryConnection
// eslint-disable-next-line dot-notation // // eslint-disable-next-line dot-notation
const mtproto = conn['_session'] // const mtproto = conn['_session']
//
for (let i = 0; i < 100; i++) { // for (let i = 0; i < 100; i++) {
const payload = randomBytes(Math.round(Math.random() * 16) * 16) // const payload = randomBytes(Math.round(Math.random() * 16) * 16)
//
try { // try {
// eslint-disable-next-line dot-notation // // eslint-disable-next-line dot-notation
conn['_handleRawMessage']( // conn['_handleRawMessage'](
mtproto.getMessageId().sub(1), // mtproto.getMessageId().sub(1),
0, // 0,
new TlBinaryReader(__tlReaderMap, payload), // new TlBinaryReader(__tlReaderMap, payload),
) // )
} catch (e) { // } catch (e) {
errors += 1 // errors += 1
} // }
} // }
//
// similar test, but this time only using object ids that do exist // // similar test, but this time only using object ids that do exist
const objectIds = Object.keys(__tlReaderMap) // const objectIds = Object.keys(__tlReaderMap)
//
for (let i = 0; i < 100; i++) { // for (let i = 0; i < 100; i++) {
const payload = randomBytes( // const payload = randomBytes(
(Math.round(Math.random() * 16) + 1) * 16, // (Math.round(Math.random() * 16) + 1) * 16,
) // )
const objectId = parseInt( // const objectId = parseInt(
objectIds[Math.round(Math.random() * objectIds.length)], // objectIds[Math.round(Math.random() * objectIds.length)],
) // )
payload.writeUInt32LE(objectId, 0) // payload.writeUInt32LE(objectId, 0)
//
try { // try {
// eslint-disable-next-line dot-notation // // eslint-disable-next-line dot-notation
conn['_handleRawMessage']( // conn['_handleRawMessage'](
mtproto.getMessageId().sub(1), // mtproto.getMessageId().sub(1),
0, // 0,
new TlBinaryReader(__tlReaderMap, payload), // new TlBinaryReader(__tlReaderMap, payload),
) // )
} catch (e) { // } catch (e) {
errors += 1 // errors += 1
} // }
} // }
//
await client.close() // await client.close()
//
expect(errors).gt(0) // expect(errors).gt(0)
}) // })
}) // })

View file

@ -1,77 +1,77 @@
import { expect } from 'chai' // import { expect } from 'chai'
import { randomBytes } from 'crypto' // import { randomBytes } from 'crypto'
import { describe, it } from 'mocha' // import { describe, it } from 'mocha'
//
import { sleep } from '../../src' // import { sleep } from '../../src'
import { createTestTelegramClient } from './utils' // import { createTestTelegramClient } from './utils'
//
// eslint-disable-next-line @typescript-eslint/no-var-requires // // eslint-disable-next-line @typescript-eslint/no-var-requires
require('dotenv-flow').config() // require('dotenv-flow').config()
//
describe('fuzz : session', async function () { // describe('fuzz : session', async function () {
this.timeout(45000) // this.timeout(45000)
//
it('random auth_key', async () => { // it('random auth_key', async () => {
const client = createTestTelegramClient() // const client = createTestTelegramClient()
//
// random key // // random key
const initKey = randomBytes(256) // const initKey = randomBytes(256)
await client.storage.setAuthKeyFor(2, initKey) // await client.storage.setAuthKeyFor(2, initKey)
//
// client is supposed to handle this and generate a new key // // client is supposed to handle this and generate a new key
//
const errors: unknown[] = [] // const errors: Error[] = []
//
const errorHandler = (err: unknown) => { // const errorHandler = (err: Error) => {
errors.push(err) // errors.push(err)
} // }
//
client.onError(errorHandler) // client.onError(errorHandler)
//
await client.connect() // await client.connect()
//
await sleep(10000) // await sleep(10000)
//
await client.close() // await client.close()
//
expect(errors.length).eq(0) // expect(errors.length).eq(0)
//
expect((await client.storage.getAuthKeyFor(2))?.toString('hex')).not.eq( // expect((await client.storage.getAuthKeyFor(2))?.toString('hex')).not.eq(
initKey.toString('hex'), // initKey.toString('hex'),
) // )
}) // })
//
it('random auth_key for other dc', async () => { // it('random auth_key for other dc', async () => {
const client = createTestTelegramClient() // const client = createTestTelegramClient()
//
// random key for dc1 // // random key for dc1
const initKey = randomBytes(256) // const initKey = randomBytes(256)
await client.storage.setAuthKeyFor(1, initKey) // await client.storage.setAuthKeyFor(1, initKey)
//
// client is supposed to handle this and generate a new key // // client is supposed to handle this and generate a new key
//
const errors: unknown[] = [] // const errors: Error[] = []
//
const errorHandler = (err: unknown) => { // const errorHandler = (err: Error) => {
errors.push(err) // errors.push(err)
} // }
//
client.onError(errorHandler) // client.onError(errorHandler)
//
await client.connect() // await client.connect()
await client.waitUntilUsable() // await client.waitUntilUsable()
//
const conn = await client.createAdditionalConnection(1) // const conn = await client.createAdditionalConnection(1)
await conn.sendRpc({ _: 'help.getConfig' }) // await conn.sendRpc({ _: 'help.getConfig' })
//
await sleep(10000) // await sleep(10000)
//
await client.close() // await client.close()
//
expect(errors.length).eq(0) // expect(errors.length).eq(0)
//
expect((await client.storage.getAuthKeyFor(1))?.toString('hex')).not.eq( // expect((await client.storage.getAuthKeyFor(1))?.toString('hex')).not.eq(
initKey.toString('hex'), // initKey.toString('hex'),
) // )
}) // })
}) // })

View file

@ -1,128 +1,127 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ // import { expect } from 'chai'
import { expect } from 'chai' // import { randomBytes } from 'crypto'
import { randomBytes } from 'crypto' // import { EventEmitter } from 'events'
import { EventEmitter } from 'events' // import { describe, it } from 'mocha'
import { describe, it } from 'mocha' //
// import {
import { // BaseTelegramClient,
BaseTelegramClient, // defaultDcs,
defaultDcs, // ITelegramTransport,
ITelegramTransport, // NodeCryptoProvider,
NodeCryptoProvider, // sleep,
sleep, // tl,
tl, // TransportState,
TransportState, // } from '../../src'
} from '../../src' //
// // eslint-disable-next-line @typescript-eslint/no-var-requires
// eslint-disable-next-line @typescript-eslint/no-var-requires // require('dotenv-flow').config()
require('dotenv-flow').config() //
// class RandomBytesTransport extends EventEmitter implements ITelegramTransport {
class RandomBytesTransport extends EventEmitter implements ITelegramTransport { // dc: tl.RawDcOption
dc!: tl.RawDcOption // interval?: NodeJS.Timeout
interval?: NodeJS.Timeout //
// close(): void {
close(): void { // clearInterval(this.interval)
clearInterval(this.interval) // this.emit('close')
this.emit('close') // this.interval = undefined
this.interval = undefined // }
} //
// connect(dc: tl.RawDcOption): void {
connect(dc: tl.RawDcOption): void { // this.dc = dc
this.dc = dc //
// setTimeout(() => this.emit('ready'), 0)
setTimeout(() => this.emit('ready'), 0) //
// this.interval = setInterval(() => {
this.interval = setInterval(() => { // this.emit('message', randomBytes(64))
this.emit('message', randomBytes(64)) // }, 100)
}, 100) // }
} //
// currentDc(): tl.RawDcOption | null {
currentDc(): tl.RawDcOption | null { // return this.dc
return this.dc // }
} //
// send(_data: Buffer): Promise<void> {
send(_data: Buffer): Promise<void> { // return Promise.resolve()
return Promise.resolve() // }
} //
// state(): TransportState {
state(): TransportState { // return this.interval ? TransportState.Ready : TransportState.Idle
return this.interval ? TransportState.Ready : TransportState.Idle // }
} // }
} //
// describe('fuzz : transport', function () {
describe('fuzz : transport', function () { // this.timeout(30000)
this.timeout(30000) //
// it('RandomBytesTransport (no auth)', async () => {
it('RandomBytesTransport (no auth)', async () => { // const client = new BaseTelegramClient({
const client = new BaseTelegramClient({ // crypto: () => new NodeCryptoProvider(),
crypto: () => new NodeCryptoProvider(), // transport: () => new RandomBytesTransport(),
transport: () => new RandomBytesTransport(), // apiId: 0,
apiId: 0, // apiHash: '',
apiHash: '', // defaultDc: defaultDcs.defaultTestDc,
primaryDc: defaultDcs.defaultTestDc, // })
}) // client.log.level = 0
client.log.level = 0 //
// const errors: Error[] = []
const errors: Error[] = [] //
// client.onError((err) => {
client.onError((err) => { // errors.push(err)
errors.push(err) // })
}) //
// await client.connect()
await client.connect() // await sleep(15000)
await sleep(15000) // await client.close()
await client.close() //
// expect(errors.length).gt(0)
expect(errors.length).gt(0) // errors.forEach((err) => {
errors.forEach((err) => { // expect(err.message).match(/unknown object id/i)
expect(err.message).match(/unknown object id/i) // })
}) // })
}) //
// it('RandomBytesTransport (with auth)', async () => {
it('RandomBytesTransport (with auth)', async () => { // const client = new BaseTelegramClient({
const client = new BaseTelegramClient({ // crypto: () => new NodeCryptoProvider(),
crypto: () => new NodeCryptoProvider(), // transport: () => new RandomBytesTransport(),
transport: () => new RandomBytesTransport(), // apiId: 0,
apiId: 0, // apiHash: '',
apiHash: '', // defaultDc: defaultDcs.defaultTestDc,
primaryDc: defaultDcs.defaultTestDc, // })
}) // client.log.level = 0
client.log.level = 0 //
// // random key just to make it think it already has one
// random key just to make it think it already has one // await client.storage.setAuthKeyFor(2, randomBytes(256))
await client.storage.setAuthKeyFor(2, randomBytes(256)) //
// // in this case, there will be no actual errors, only
// in this case, there will be no actual errors, only // // warnings like 'received message with unknown authKey'
// warnings like 'received message with unknown authKey' // //
// // // to test for that, we hook into `decryptMessage` and make
// to test for that, we hook into `decryptMessage` and make // // sure that it returns `null`
// sure that it returns `null` //
// await client.connect()
await client.connect() //
// let hadNonNull = false
let hadNonNull = false //
// const decryptMessage =
const decryptMessage = // // eslint-disable-next-line dot-notation
// eslint-disable-next-line dot-notation // client.primaryConnection['_session'].decryptMessage
client.primaryConnection['_session'].decryptMessage //
// // ехал any через any
// ехал any через any // // видит any - any, any
// видит any - any, any // // сунул any any в any
// сунул any any в any // // any any any any
// any any any any // // eslint-disable-next-line dot-notation
// eslint-disable-next-line dot-notation // ;(client.primaryConnection['_session'] as any).decryptMessage = (
;(client.primaryConnection['_session'] as any).decryptMessage = ( // buf: any,
buf: any, // cb: any,
cb: any, // ) =>
) => // decryptMessage.call(this, buf, (...args: any[]) => {
decryptMessage.call(this, buf, (...args: any[]) => { // cb(...(args as any))
cb(...(args as any)) // hadNonNull = true
hadNonNull = true // })
}) //
// await sleep(15000)
await sleep(15000) // await client.close()
await client.close() //
// expect(hadNonNull).false
expect(hadNonNull).false // })
}) // })
})

View file

@ -0,0 +1,139 @@
import bigInt from 'big-integer'
import { expect } from 'chai'
import { describe, it } from 'mocha'
import { millerRabin } from '../src/utils/crypto/miller-rabin'
describe('miller-rabin test', function () {
this.timeout(10000) // since miller-rabin factorization relies on RNG, it may take a while (or may not!)
const testMillerRabin = (n: bigInt.BigNumber, isPrime: boolean) => {
expect(millerRabin(bigInt(n as number))).eq(isPrime)
}
it('should correctly label small primes as probable primes', () => {
const smallOddPrimes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
for (const prime of smallOddPrimes) {
testMillerRabin(prime, true)
}
})
it('should correctly label small odd composite numbers as composite', () => {
const smallOddPrimes = [9, 15, 21, 25, 27, 33, 35]
for (const prime of smallOddPrimes) {
testMillerRabin(prime, false)
}
})
// primes are generated using `openssl prime -generate -bits <bits>`
it('should work for 512-bit numbers', () => {
testMillerRabin(
'8411445470921866378538628788380866906358949375899610911537071281076627385046125382763689993349183284546479522400013151510610266158235924343045768103605519',
true,
)
testMillerRabin(
'11167561990563990242158096122232207092938761092751537312016255867850441858086589598418467012717458858604863547175649456433632887622140170743409535470973399',
true,
)
testMillerRabin(
'11006717791910450367418249787526506184731090161438431250022510598653874155081488487035840577645711578911087148186160668569071839053453201592321650008610329',
true,
)
testMillerRabin(
'12224330340162812215033324917156282302617911690617664923428569636370785775561435789211091021550357876767050350997458404009005800772805534351607294516706177',
true,
)
// above numbers but -2 (not prime)
testMillerRabin(
'8411445470921866378538628788380866906358949375899610911537071281076627385046125382763689993349183284546479522400013151510610266158235924343045768103605517',
false,
)
testMillerRabin(
'11167561990563990242158096122232207092938761092751537312016255867850441858086589598418467012717458858604863547175649456433632887622140170743409535470973397',
false,
)
testMillerRabin(
'11006717791910450367418249787526506184731090161438431250022510598653874155081488487035840577645711578911087148186160668569071839053453201592321650008610327',
false,
)
testMillerRabin(
'12224330340162812215033324917156282302617911690617664923428569636370785775561435789211091021550357876767050350997458404009005800772805534351607294516706175',
false,
)
})
it('should work for 1024-bit numbers', () => {
testMillerRabin(
'94163180970530844245052892199633535954736903357996153321496979115367320260897793334681106861766748541439161886270777106456088209508872459550450259737267142959061663564218457086654112219462515165219295402175541003899136060178102898376369981338103600856012709228116661479275753497725541132207243717937379815409',
true,
)
testMillerRabin(
'97324962433497727515811278760066576725849776656602017497363465683978397629803148191267105308901733336070351381654371470561376353774017284623969415330564867697353080030917333974193741719718950105404732792050882127213356260415251087867407489400712288570880407613514781891914135956778687719588061176455381937003',
true,
)
testMillerRabin(
'92511311413226091818378551616231701579277597795073142338527410334932345968554993390789667936819230228388142960299649466238701015865565141753710450319875546944139442823075990348978746055937500467483161699883905850192191164043687791185635729923497381849380102040768674652775240505782671289535260164547714030567',
true,
)
testMillerRabin(
'98801756216479639848708157708947504990501845258427605711852570166662700681215707617225664134994147912417941920327932092748574265476658124536672887141144222716123085451749764522435906007567360583062117498919471220566974634924384147341592903939264267901029640119196259026154529723870788246284629644039137378253',
true,
)
// above numbers but -2 (not prime)
testMillerRabin(
'94163180970530844245052892199633535954736903357996153321496979115367320260897793334681106861766748541439161886270777106456088209508872459550450259737267142959061663564218457086654112219462515165219295402175541003899136060178102898376369981338103600856012709228116661479275753497725541132207243717937379815407',
false,
)
testMillerRabin(
'97324962433497727515811278760066576725849776656602017497363465683978397629803148191267105308901733336070351381654371470561376353774017284623969415330564867697353080030917333974193741719718950105404732792050882127213356260415251087867407489400712288570880407613514781891914135956778687719588061176455381937001',
false,
)
testMillerRabin(
'92511311413226091818378551616231701579277597795073142338527410334932345968554993390789667936819230228388142960299649466238701015865565141753710450319875546944139442823075990348978746055937500467483161699883905850192191164043687791185635729923497381849380102040768674652775240505782671289535260164547714030565',
false,
)
testMillerRabin(
'98801756216479639848708157708947504990501845258427605711852570166662700681215707617225664134994147912417941920327932092748574265476658124536672887141144222716123085451749764522435906007567360583062117498919471220566974634924384147341592903939264267901029640119196259026154529723870788246284629644039137378251',
false,
)
})
it('should work for 2048-bit numbers', () => {
testMillerRabin(
'28608382334358769588283288249494859626901014972463291352091976543138105382282108662849885913053034513852843449409838151123568984617793641641937583673207501643041336002587032201383537626393235736734494131431069043382068545865505150651648610506542819001961332454611129372758714288168807328523359776577571626967649079147416191592855529888846889532625386469236278694936872628305052827422772792103722178298844645210242389265273407924858034431614414896134561928996888883994953322861399988094086562513898527391555490352156627307769278185444897960555995383228897584818577375695810423475039211516849716140051437120083274285367',
true,
)
testMillerRabin(
'30244022694659482453371920976249272809817388822378671144866806600284132009663832003348737406289715119965835410140834733465553787513841966120831322372642881643693711233087233983267648392814127424201572290931937482043046169402667397610783447368703776842799852222745601531140231486417855517072392416789672922529566643118973930252809010605519948446055538976582290902060054788109497630796585770940656002892943575479533099350429655210881833493066716819282707441553612603960556051122162329171373373251909387401572866056121964608595895425640834764028568120995397759283490218181167000161310959711677055741632674632758727382743',
true,
)
testMillerRabin(
'30560953105766401423987964658775999222308579908395527900931049506803845883459894704297458477118152899910620180302473409631442956208933061650967001020981432894530064472547770442696756724169958362395601360296775798187903794894866967342028337982275745956538015473621792510615113531964380246815875830970404687926061637030085629909804357717955251735074071072456074274947993921828878633638119117086342305530526661796817095624933200483138188878398983149622639425550360394901699701985050966685840649129419227936413574227792077082510807968104733387734970009620450108276446659342203263759999068046251645984039420643003580284779',
true,
)
// above numbers but -2 (not prime)
testMillerRabin(
'28608382334358769588283288249494859626901014972463291352091976543138105382282108662849885913053034513852843449409838151123568984617793641641937583673207501643041336002587032201383537626393235736734494131431069043382068545865505150651648610506542819001961332454611129372758714288168807328523359776577571626967649079147416191592855529888846889532625386469236278694936872628305052827422772792103722178298844645210242389265273407924858034431614414896134561928996888883994953322861399988094086562513898527391555490352156627307769278185444897960555995383228897584818577375695810423475039211516849716140051437120083274285365',
false,
)
testMillerRabin(
'30244022694659482453371920976249272809817388822378671144866806600284132009663832003348737406289715119965835410140834733465553787513841966120831322372642881643693711233087233983267648392814127424201572290931937482043046169402667397610783447368703776842799852222745601531140231486417855517072392416789672922529566643118973930252809010605519948446055538976582290902060054788109497630796585770940656002892943575479533099350429655210881833493066716819282707441553612603960556051122162329171373373251909387401572866056121964608595895425640834764028568120995397759283490218181167000161310959711677055741632674632758727382741',
false,
)
testMillerRabin(
'30560953105766401423987964658775999222308579908395527900931049506803845883459894704297458477118152899910620180302473409631442956208933061650967001020981432894530064472547770442696756724169958362395601360296775798187903794894866967342028337982275745956538015473621792510615113531964380246815875830970404687926061637030085629909804357717955251735074071072456074274947993921828878633638119117086342305530526661796817095624933200483138188878398983149622639425550360394901699701985050966685840649129419227936413574227792077082510807968104733387734970009620450108276446659342203263759999068046251645984039420643003580284777',
false,
)
// dh_prime used by telegram, as seen in https://core.telegram.org/mtproto/security_guidelines
const telegramDhPrime =
'C7 1C AE B9 C6 B1 C9 04 8E 6C 52 2F 70 F1 3F 73 98 0D 40 23 8E 3E 21 C1 49 34 D0 37 56 3D 93 0F 48 19 8A 0A A7 C1 40 58 22 94 93 D2 25 30 F4 DB FA 33 6F 6E 0A C9 25 13 95 43 AE D4 4C CE 7C 37 20 FD 51 F6 94 58 70 5A C6 8C D4 FE 6B 6B 13 AB DC 97 46 51 29 69 32 84 54 F1 8F AF 8C 59 5F 64 24 77 FE 96 BB 2A 94 1D 5B CD 1D 4A C8 CC 49 88 07 08 FA 9B 37 8E 3C 4F 3A 90 60 BE E6 7C F9 A4 A4 A6 95 81 10 51 90 7E 16 27 53 B5 6B 0F 6B 41 0D BA 74 D8 A8 4B 2A 14 B3 14 4E 0E F1 28 47 54 FD 17 ED 95 0D 59 65 B4 B9 DD 46 58 2D B1 17 8D 16 9C 6B C4 65 B0 D6 FF 9C A3 92 8F EF 5B 9A E4 E4 18 FC 15 E8 3E BE A0 F8 7F A9 FF 5E ED 70 05 0D ED 28 49 F4 7B F9 59 D9 56 85 0C E9 29 85 1F 0D 81 15 F6 35 B1 05 EE 2E 4E 15 D0 4B 24 54 BF 6F 4F AD F0 34 B1 04 03 11 9C D8 E3 B9 2F CC 5B'
testMillerRabin(bigInt(telegramDhPrime.replace(/ /g, ''), 16), true)
})
})

View file

@ -4,7 +4,8 @@
"outDir": "./dist" "outDir": "./dist"
}, },
"include": [ "include": [
"./src" "./src",
"./tests"
], ],
"typedocOptions": { "typedocOptions": {
"name": "@mtcute/core", "name": "@mtcute/core",

View file

@ -91,10 +91,20 @@ export class MtProxyTcpTransport extends BaseTcpTransport {
} }
} }
getMtproxyInfo(): tl.RawInputClientProxy {
return {
_: 'inputClientProxy',
address: this._proxy.host,
port: this._proxy.port,
}
}
_packetCodec!: IPacketCodec _packetCodec!: IPacketCodec
connect(dc: tl.RawDcOption, testMode: boolean): void { connect(dc: tl.RawDcOption, testMode: boolean): void {
if (this._state !== TransportState.Idle) { throw new Error('Transport is not IDLE') } if (this._state !== TransportState.Idle) {
throw new Error('Transport is not IDLE')
}
if (this._packetCodec && this._currentDc?.id !== dc.id) { if (this._packetCodec && this._currentDc?.id !== dc.id) {
// dc changed, thus the codec's init will change too // dc changed, thus the codec's init will change too

View file

@ -12,6 +12,6 @@
}, },
"dependencies": { "dependencies": {
"@mtcute/core": "workspace:^1.0.0", "@mtcute/core": "workspace:^1.0.0",
"ip6": "0.2.10" "ip6": "0.2.7"
} }
} }

View file

@ -54,61 +54,64 @@ function getInputPeer(
throw new Error(`Invalid peer type: ${row.type}`) throw new Error(`Invalid peer type: ${row.type}`)
} }
const CURRENT_VERSION = 2 const CURRENT_VERSION = 3
// language=SQLite // language=SQLite format=false
const TEMP_AUTH_TABLE = `
create table temp_auth_keys (
dc integer not null,
idx integer not null,
key blob not null,
expires integer not null,
primary key (dc, idx)
);
`
// language=SQLite format=false
const SCHEMA = ` const SCHEMA = `
create table kv create table kv (
( key text primary key,
key text primary key,
value text not null value text not null
); );
create table state create table state (
( key text primary key,
key text primary key, value text not null,
value text not null,
expires number expires number
); );
create table auth_keys create table auth_keys (
( dc integer primary key,
dc integer primary key,
key blob not null key blob not null
); );
create table pts ${TEMP_AUTH_TABLE}
(
create table pts (
channel_id integer primary key, channel_id integer primary key,
pts integer not null pts integer not null
); );
create table entities create table entities (
( id integer primary key,
id integer primary key, hash text not null,
hash text not null, type text not null,
type text not null,
username text, username text,
phone text, phone text,
updated integer not null, updated integer not null,
"full" blob "full" blob
); );
create index idx_entities_username on entities (username); create index idx_entities_username on entities (username);
create index idx_entities_phone on entities (phone); create index idx_entities_phone on entities (phone);
` `
// language=SQLite format=false
const RESET = ` const RESET = `
delete delete from kv where key <> 'ver';
from kv delete from state;
where key <> 'ver'; delete from auth_keys;
delete delete from pts;
from state; delete from entities
delete
from auth_keys;
delete
from pts;
delete
from entities
` `
const USERNAME_TTL = 86400000 // 24 hours const USERNAME_TTL = 86400000 // 24 hours
@ -144,8 +147,14 @@ const STATEMENTS = {
delState: 'delete from state where key = ?', delState: 'delete from state where key = ?',
getAuth: 'select key from auth_keys where dc = ?', getAuth: 'select key from auth_keys where dc = ?',
getAuthTemp:
'select key from temp_auth_keys where dc = ? and idx = ? and expires > ?',
setAuth: 'insert or replace into auth_keys (dc, key) values (?, ?)', setAuth: 'insert or replace into auth_keys (dc, key) values (?, ?)',
setAuthTemp:
'insert or replace into temp_auth_keys (dc, idx, key, expires) values (?, ?, ?, ?)',
delAuth: 'delete from auth_keys where dc = ?', delAuth: 'delete from auth_keys where dc = ?',
delAuthTemp: 'delete from temp_auth_keys where dc = ? and idx = ?',
delAllAuthTemp: 'delete from temp_auth_keys where dc = ?',
getPts: 'select pts from pts where channel_id = ?', getPts: 'select pts from pts where channel_id = ?',
setPts: 'insert or replace into pts (channel_id, pts) values (?, ?)', setPts: 'insert or replace into pts (channel_id, pts) values (?, ?)',
@ -376,12 +385,24 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
'Unsupported session version, please migrate manually', 'Unsupported session version, please migrate manually',
) )
} }
if (from === 2) {
// PFS support added
this._db.exec(TEMP_AUTH_TABLE)
from = 3
}
if (from !== CURRENT_VERSION) {
// an assertion just in case i messed up
throw new Error('Migration incomplete')
}
} }
private _initializeStatements(): void { private _initializeStatements(): void {
this._statements = {} as unknown as typeof this._statements this._statements = {} as unknown as typeof this._statements
Object.entries(STATEMENTS).forEach(([name, sql]) => { Object.entries(STATEMENTS).forEach(([name, sql]) => {
this._statements[name as keyof typeof this._statements] = this._db.prepare(sql) this._statements[name as keyof typeof this._statements] =
this._db.prepare(sql)
}) })
} }
@ -397,7 +418,7 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
const versionResult = this._db const versionResult = this._db
.prepare("select value from kv where key = 'ver'") .prepare("select value from kv where key = 'ver'")
.get() .get()
const version = (versionResult as { value: number }).value const version = Number((versionResult as { value: number }).value)
this.log.debug('current db version = %d', version) this.log.debug('current db version = %d', version)
@ -426,7 +447,10 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
load(): void { load(): void {
this._db = sqlite3(this._filename, { this._db = sqlite3(this._filename, {
verbose: this.log.mgr.level === 5 ? this.log.verbose as Options['verbose'] : undefined, verbose:
this.log.mgr.level === 5 ?
(this.log.verbose as Options['verbose']) :
undefined,
}) })
this._initialize() this._initialize()
@ -481,8 +505,14 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
return this._getFromKv('def_dc') return this._getFromKv('def_dc')
} }
getAuthKeyFor(dcId: number): Buffer | null { getAuthKeyFor(dcId: number, tempIndex?: number): Buffer | null {
const row = this._statements.getAuth.get(dcId) let row
if (tempIndex !== undefined) {
row = this._statements.getAuthTemp.get(dcId, tempIndex, Date.now())
} else {
row = this._statements.getAuth.get(dcId)
}
return row ? (row as { key: Buffer }).key : null return row ? (row as { key: Buffer }).key : null
} }
@ -494,6 +524,27 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
]) ])
} }
setTempAuthKeyFor(
dcId: number,
index: number,
key: Buffer | null,
expires: number,
): void {
this._pending.push([
key === null ?
this._statements.delAuthTemp :
this._statements.setAuthTemp,
key === null ? [dcId, index] : [dcId, index, key, expires],
])
}
dropAuthKeysFor(dcId: number): void {
this._pending.push(
[this._statements.delAuth, [dcId]],
[this._statements.delAllAuthTemp, [dcId]],
)
}
getSelf(): ITelegramStorage.SelfInfo | null { getSelf(): ITelegramStorage.SelfInfo | null {
return this._getFromKv('self') return this._getFromKv('self')
} }
@ -601,7 +652,9 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
const cached = this._cache?.get(peerId) const cached = this._cache?.get(peerId)
if (cached) return cached.peer if (cached) return cached.peer
const row = this._statements.getEntById.get(peerId) as SqliteEntity | null const row = this._statements.getEntById.get(
peerId,
) as SqliteEntity | null
if (row) { if (row) {
const peer = getInputPeer(row) const peer = getInputPeer(row)
@ -617,7 +670,9 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
} }
getPeerByPhone(phone: string): tl.TypeInputPeer | null { getPeerByPhone(phone: string): tl.TypeInputPeer | null {
const row = this._statements.getEntByPhone.get(phone) as SqliteEntity | null const row = this._statements.getEntByPhone.get(
phone,
) as SqliteEntity | null
if (row) { if (row) {
const peer = getInputPeer(row) const peer = getInputPeer(row)
@ -633,7 +688,9 @@ export class SqliteStorage implements ITelegramStorage, IStateStorage {
} }
getPeerByUsername(username: string): tl.TypeInputPeer | null { getPeerByUsername(username: string): tl.TypeInputPeer | null {
const row = this._statements.getEntByUser.get(username.toLowerCase()) as SqliteEntity | null const row = this._statements.getEntByUser.get(
username.toLowerCase(),
) as SqliteEntity | null
if (!row || Date.now() - row.updated > USERNAME_TTL) return null if (!row || Date.now() - row.updated > USERNAME_TTL) return null
if (row) { if (row) {

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

View file

@ -9,6 +9,7 @@
"inlineSources": true, "inlineSources": true,
"declaration": true, "declaration": true,
"esModuleInterop": true, "esModuleInterop": true,
"allowSyntheticDefaultImports": true,
"strict": true, "strict": true,
"noImplicitAny": true, "noImplicitAny": true,
"noImplicitThis": true, "noImplicitThis": true,