diff --git a/packages/core/src/network/client.ts b/packages/core/src/network/client.ts index d58eba62..5163a02f 100644 --- a/packages/core/src/network/client.ts +++ b/packages/core/src/network/client.ts @@ -351,7 +351,7 @@ export class MtClient extends EventEmitter { */ async close(): Promise { this._config.destroy() - this.network.destroy() + await this.network.destroy() await this.storage.save() await this.storage.destroy?.() diff --git a/packages/core/src/network/multi-session-connection.ts b/packages/core/src/network/multi-session-connection.ts index e301294f..3ff7021a 100644 --- a/packages/core/src/network/multi-session-connection.ts +++ b/packages/core/src/network/multi-session-connection.ts @@ -116,7 +116,9 @@ export class MultiSessionConnection extends EventEmitter { // destroy extra connections for (let i = this._connections.length - 1; i >= this._count; i--) { this._connections[i].removeAllListeners() - this._connections[i].destroy() + this._connections[i].destroy().catch((err) => { + this._log.warn('error destroying connection: %s', err) + }) } this._connections.splice(this._count) @@ -199,8 +201,8 @@ export class MultiSessionConnection extends EventEmitter { } _destroyed = false - destroy(): void { - this._connections.forEach((conn) => conn.destroy()) + async destroy(): Promise { + await Promise.all(this._connections.map((conn) => conn.destroy())) this._sessions.forEach((sess) => sess.reset()) this.removeAllListeners() diff --git a/packages/core/src/network/network-manager.ts b/packages/core/src/network/network-manager.ts index da746da2..df4e1d7d 100644 --- a/packages/core/src/network/network-manager.ts +++ b/packages/core/src/network/network-manager.ts @@ -413,11 +413,11 @@ export class DcConnectionManager { return true } - destroy() { - this.main.destroy() - this.upload.destroy() - this.download.destroy() - this.downloadSmall.destroy() + async destroy() { + await this.main.destroy() + await this.upload.destroy() + await this.download.destroy() + await this.downloadSmall.destroy() this._salts.destroy() } } @@ -861,9 +861,9 @@ export class NetworkManager { return this._primaryDc.dcId } - destroy(): void { + async destroy(): Promise { for (const dc of this._dcConnections.values()) { - dc.destroy() + await dc.destroy() } this.config.offReload(this._onConfigChanged) this._resetOnNetworkChange?.() diff --git a/packages/core/src/network/persistent-connection.ts b/packages/core/src/network/persistent-connection.ts index 4c9b2cac..6c799bae 100644 --- a/packages/core/src/network/persistent-connection.ts +++ b/packages/core/src/network/persistent-connection.ts @@ -69,7 +69,9 @@ export abstract class PersistentConnection extends EventEmitter { changeTransport(factory: TransportFactory): void { if (this._transport) { - this._transport.close() + Promise.resolve(this._transport.close()).catch((err) => { + this.log.warn('error closing previous transport: %s', err) + }) } this._transport = factory() @@ -149,7 +151,9 @@ export abstract class PersistentConnection extends EventEmitter { ) if (wait === false) { - this.destroy() + this.destroy().catch((err) => { + this.log.warn('error destroying connection: %s', err) + }) return } @@ -192,7 +196,9 @@ export abstract class PersistentConnection extends EventEmitter { // if we are already connected if (this.isConnected) { this._shouldReconnectImmediately = true - this._transport.close() + Promise.resolve(this._transport.close()).catch((err) => { + this.log.error('error closing transport: %s', err) + }) return } @@ -201,12 +207,12 @@ export abstract class PersistentConnection extends EventEmitter { this.connect() } - disconnectManual(): void { + async disconnectManual(): Promise { this._disconnectedManually = true - this._transport.close() + await this._transport.close() } - destroy(): void { + async destroy(): Promise { if (this._reconnectionTimeout != null) { clearTimeout(this._reconnectionTimeout) } @@ -214,7 +220,7 @@ export abstract class PersistentConnection extends EventEmitter { clearTimeout(this._inactivityTimeout) } - this._transport.close() + await this._transport.close() this._transport.removeAllListeners() this._destroyed = true } @@ -229,7 +235,9 @@ export abstract class PersistentConnection extends EventEmitter { this.log.info('disconnected because of inactivity for %d', this.params.inactivityTimeout) this._inactive = true this._inactivityTimeout = null - this._transport.close() + Promise.resolve(this._transport.close()).catch((err) => { + this.log.warn('error closing transport: %s', err) + }) } setInactivityTimeout(timeout?: number): void { diff --git a/packages/core/src/network/session-connection.ts b/packages/core/src/network/session-connection.ts index 401162ba..0dc08df2 100644 --- a/packages/core/src/network/session-connection.ts +++ b/packages/core/src/network/session-connection.ts @@ -155,8 +155,8 @@ export class SessionConnection extends PersistentConnection { this.reset() } - destroy(): void { - super.destroy() + async destroy(): Promise { + await super.destroy() this.reset(true) } @@ -1462,7 +1462,9 @@ export class SessionConnection extends PersistentConnection { if (online) { this.reconnect() } else { - this.disconnectManual() + this.disconnectManual().catch((err) => { + this.log.warn('error while disconnecting: %s', err) + }) } } diff --git a/packages/core/src/network/transports/abstract.ts b/packages/core/src/network/transports/abstract.ts index eaf70218..94f2be42 100644 --- a/packages/core/src/network/transports/abstract.ts +++ b/packages/core/src/network/transports/abstract.ts @@ -48,7 +48,7 @@ export interface ITelegramTransport extends EventEmitter { */ connect(dc: BasicDcOption, testMode: boolean): void /** call to close existing connection to some DC */ - close(): void + close(): MaybePromise /** send a message */ send(data: Uint8Array): Promise diff --git a/packages/web/src/websocket.test.ts b/packages/web/src/websocket.test.ts index 91630352..13c7e51a 100644 --- a/packages/web/src/websocket.test.ts +++ b/packages/web/src/websocket.test.ts @@ -11,20 +11,24 @@ const p = getPlatform() describe('WebSocketTransport', () => { const create = async () => { + let closeListener: () => void = () => {} const fakeWs = vi.fn().mockImplementation(() => ({ addEventListener: vi.fn().mockImplementation((event: string, cb: () => void) => { if (event === 'open') { cb() } + if (event === 'close') { + closeListener = cb + } }), removeEventListener: vi.fn(), - close: vi.fn(), + close: vi.fn().mockImplementation(() => closeListener()), send: vi.fn(), })) const transport = new WebSocketTransport({ ws: fakeWs }) const logger = new LogManager() - logger.level = 0 + logger.level = 10 transport.setup(await defaultTestCryptoProvider(), logger) return [transport, fakeWs] as const @@ -89,9 +93,9 @@ describe('WebSocketTransport', () => { const socket = getLastSocket(ws) await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) - t.close() + await t.close() + console.log('kek') - expect(socket.removeEventListener).toHaveBeenCalled() expect(socket.close).toHaveBeenCalled() }) diff --git a/packages/web/src/websocket.ts b/packages/web/src/websocket.ts index a612960e..a957cc8e 100644 --- a/packages/web/src/websocket.ts +++ b/packages/web/src/websocket.ts @@ -9,7 +9,13 @@ import { ObfuscatedPacketCodec, TransportState, } from '@mtcute/core' -import { BasicDcOption, ICryptoProvider, Logger } from '@mtcute/core/utils.js' +import { + BasicDcOption, + ControllablePromise, + createControllablePromise, + ICryptoProvider, + Logger, +} from '@mtcute/core/utils.js' export type WebSocketConstructor = { new (address: string, protocol?: string): WebSocket @@ -69,8 +75,6 @@ export abstract class BaseWebSocketTransport extends EventEmitter implements ITe this._baseDomain = baseDomain this._subdomains = subdomains this._WebSocket = ws - - this.close = this.close.bind(this) } private _updateLogPrefix() { @@ -123,20 +127,33 @@ export abstract class BaseWebSocketTransport extends EventEmitter implements ITe ) this._socket.addEventListener('open', this.handleConnect.bind(this)) this._socket.addEventListener('error', this.handleError.bind(this)) - this._socket.addEventListener('close', this.close) + this._socket.addEventListener('close', this.handleClosed.bind(this)) } - close(): void { + private _closeWaiters: ControllablePromise[] = [] + async close(): Promise { if (this._state === TransportState.Idle) return - this.log.info('connection closed') - this._state = TransportState.Idle - this._socket!.removeEventListener('close', this.close) + const promise = createControllablePromise() + this._closeWaiters.push(promise) + this._socket!.close() + + return promise + } + + handleClosed(): void { + this.log.info('connection closed') + this._state = TransportState.Idle this._socket = null this._currentDc = null this._packetCodec.reset() this.emit('close') + + for (const waiter of this._closeWaiters) { + waiter.resolve() + } + this._closeWaiters = [] } handleError(event: Event | { error: Error }): void {