fix(web): wait for the websocket to close

This commit is contained in:
alina 🌸 2024-04-22 19:16:47 +03:00
parent 40cc10cd87
commit 78457fb158
Signed by: teidesu
SSH key fingerprint: SHA256:uNeCpw6aTSU4aIObXLvHfLkDa82HWH9EiOj9AXOIRpI
8 changed files with 68 additions and 35 deletions

View file

@ -351,7 +351,7 @@ export class MtClient extends EventEmitter {
*/ */
async close(): Promise<void> { async close(): Promise<void> {
this._config.destroy() this._config.destroy()
this.network.destroy() await this.network.destroy()
await this.storage.save() await this.storage.save()
await this.storage.destroy?.() await this.storage.destroy?.()

View file

@ -116,7 +116,9 @@ export class MultiSessionConnection extends EventEmitter {
// destroy extra connections // destroy extra connections
for (let i = this._connections.length - 1; i >= this._count; i--) { for (let i = this._connections.length - 1; i >= this._count; i--) {
this._connections[i].removeAllListeners() this._connections[i].removeAllListeners()
this._connections[i].destroy() this._connections[i].destroy().catch((err) => {
this._log.warn('error destroying connection: %s', err)
})
} }
this._connections.splice(this._count) this._connections.splice(this._count)
@ -199,8 +201,8 @@ export class MultiSessionConnection extends EventEmitter {
} }
_destroyed = false _destroyed = false
destroy(): void { async destroy(): Promise<void> {
this._connections.forEach((conn) => conn.destroy()) await Promise.all(this._connections.map((conn) => conn.destroy()))
this._sessions.forEach((sess) => sess.reset()) this._sessions.forEach((sess) => sess.reset())
this.removeAllListeners() this.removeAllListeners()

View file

@ -413,11 +413,11 @@ export class DcConnectionManager {
return true return true
} }
destroy() { async destroy() {
this.main.destroy() await this.main.destroy()
this.upload.destroy() await this.upload.destroy()
this.download.destroy() await this.download.destroy()
this.downloadSmall.destroy() await this.downloadSmall.destroy()
this._salts.destroy() this._salts.destroy()
} }
} }
@ -861,9 +861,9 @@ export class NetworkManager {
return this._primaryDc.dcId return this._primaryDc.dcId
} }
destroy(): void { async destroy(): Promise<void> {
for (const dc of this._dcConnections.values()) { for (const dc of this._dcConnections.values()) {
dc.destroy() await dc.destroy()
} }
this.config.offReload(this._onConfigChanged) this.config.offReload(this._onConfigChanged)
this._resetOnNetworkChange?.() this._resetOnNetworkChange?.()

View file

@ -69,7 +69,9 @@ export abstract class PersistentConnection extends EventEmitter {
changeTransport(factory: TransportFactory): void { changeTransport(factory: TransportFactory): void {
if (this._transport) { if (this._transport) {
this._transport.close() Promise.resolve(this._transport.close()).catch((err) => {
this.log.warn('error closing previous transport: %s', err)
})
} }
this._transport = factory() this._transport = factory()
@ -149,7 +151,9 @@ export abstract class PersistentConnection extends EventEmitter {
) )
if (wait === false) { if (wait === false) {
this.destroy() this.destroy().catch((err) => {
this.log.warn('error destroying connection: %s', err)
})
return return
} }
@ -192,7 +196,9 @@ export abstract class PersistentConnection extends EventEmitter {
// if we are already connected // if we are already connected
if (this.isConnected) { if (this.isConnected) {
this._shouldReconnectImmediately = true this._shouldReconnectImmediately = true
this._transport.close() Promise.resolve(this._transport.close()).catch((err) => {
this.log.error('error closing transport: %s', err)
})
return return
} }
@ -201,12 +207,12 @@ export abstract class PersistentConnection extends EventEmitter {
this.connect() this.connect()
} }
disconnectManual(): void { async disconnectManual(): Promise<void> {
this._disconnectedManually = true this._disconnectedManually = true
this._transport.close() await this._transport.close()
} }
destroy(): void { async destroy(): Promise<void> {
if (this._reconnectionTimeout != null) { if (this._reconnectionTimeout != null) {
clearTimeout(this._reconnectionTimeout) clearTimeout(this._reconnectionTimeout)
} }
@ -214,7 +220,7 @@ export abstract class PersistentConnection extends EventEmitter {
clearTimeout(this._inactivityTimeout) clearTimeout(this._inactivityTimeout)
} }
this._transport.close() await this._transport.close()
this._transport.removeAllListeners() this._transport.removeAllListeners()
this._destroyed = true this._destroyed = true
} }
@ -229,7 +235,9 @@ export abstract class PersistentConnection extends EventEmitter {
this.log.info('disconnected because of inactivity for %d', this.params.inactivityTimeout) this.log.info('disconnected because of inactivity for %d', this.params.inactivityTimeout)
this._inactive = true this._inactive = true
this._inactivityTimeout = null this._inactivityTimeout = null
this._transport.close() Promise.resolve(this._transport.close()).catch((err) => {
this.log.warn('error closing transport: %s', err)
})
} }
setInactivityTimeout(timeout?: number): void { setInactivityTimeout(timeout?: number): void {

View file

@ -155,8 +155,8 @@ export class SessionConnection extends PersistentConnection {
this.reset() this.reset()
} }
destroy(): void { async destroy(): Promise<void> {
super.destroy() await super.destroy()
this.reset(true) this.reset(true)
} }
@ -1462,7 +1462,9 @@ export class SessionConnection extends PersistentConnection {
if (online) { if (online) {
this.reconnect() this.reconnect()
} else { } else {
this.disconnectManual() this.disconnectManual().catch((err) => {
this.log.warn('error while disconnecting: %s', err)
})
} }
} }

View file

@ -48,7 +48,7 @@ export interface ITelegramTransport extends EventEmitter {
*/ */
connect(dc: BasicDcOption, testMode: boolean): void connect(dc: BasicDcOption, testMode: boolean): void
/** call to close existing connection to some DC */ /** call to close existing connection to some DC */
close(): void close(): MaybePromise<void>
/** send a message */ /** send a message */
send(data: Uint8Array): Promise<void> send(data: Uint8Array): Promise<void>

View file

@ -11,20 +11,24 @@ const p = getPlatform()
describe('WebSocketTransport', () => { describe('WebSocketTransport', () => {
const create = async () => { const create = async () => {
let closeListener: () => void = () => {}
const fakeWs = vi.fn().mockImplementation(() => ({ const fakeWs = vi.fn().mockImplementation(() => ({
addEventListener: vi.fn().mockImplementation((event: string, cb: () => void) => { addEventListener: vi.fn().mockImplementation((event: string, cb: () => void) => {
if (event === 'open') { if (event === 'open') {
cb() cb()
} }
if (event === 'close') {
closeListener = cb
}
}), }),
removeEventListener: vi.fn(), removeEventListener: vi.fn(),
close: vi.fn(), close: vi.fn().mockImplementation(() => closeListener()),
send: vi.fn(), send: vi.fn(),
})) }))
const transport = new WebSocketTransport({ ws: fakeWs }) const transport = new WebSocketTransport({ ws: fakeWs })
const logger = new LogManager() const logger = new LogManager()
logger.level = 0 logger.level = 10
transport.setup(await defaultTestCryptoProvider(), logger) transport.setup(await defaultTestCryptoProvider(), logger)
return [transport, fakeWs] as const return [transport, fakeWs] as const
@ -89,9 +93,9 @@ describe('WebSocketTransport', () => {
const socket = getLastSocket(ws) const socket = getLastSocket(ws)
await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready))
t.close() await t.close()
console.log('kek')
expect(socket.removeEventListener).toHaveBeenCalled()
expect(socket.close).toHaveBeenCalled() expect(socket.close).toHaveBeenCalled()
}) })

View file

@ -9,7 +9,13 @@ import {
ObfuscatedPacketCodec, ObfuscatedPacketCodec,
TransportState, TransportState,
} from '@mtcute/core' } from '@mtcute/core'
import { BasicDcOption, ICryptoProvider, Logger } from '@mtcute/core/utils.js' import {
BasicDcOption,
ControllablePromise,
createControllablePromise,
ICryptoProvider,
Logger,
} from '@mtcute/core/utils.js'
export type WebSocketConstructor = { export type WebSocketConstructor = {
new (address: string, protocol?: string): WebSocket new (address: string, protocol?: string): WebSocket
@ -69,8 +75,6 @@ export abstract class BaseWebSocketTransport extends EventEmitter implements ITe
this._baseDomain = baseDomain this._baseDomain = baseDomain
this._subdomains = subdomains this._subdomains = subdomains
this._WebSocket = ws this._WebSocket = ws
this.close = this.close.bind(this)
} }
private _updateLogPrefix() { private _updateLogPrefix() {
@ -123,20 +127,33 @@ export abstract class BaseWebSocketTransport extends EventEmitter implements ITe
) )
this._socket.addEventListener('open', this.handleConnect.bind(this)) this._socket.addEventListener('open', this.handleConnect.bind(this))
this._socket.addEventListener('error', this.handleError.bind(this)) this._socket.addEventListener('error', this.handleError.bind(this))
this._socket.addEventListener('close', this.close) this._socket.addEventListener('close', this.handleClosed.bind(this))
} }
close(): void { private _closeWaiters: ControllablePromise<void>[] = []
async close(): Promise<void> {
if (this._state === TransportState.Idle) return if (this._state === TransportState.Idle) return
this.log.info('connection closed')
this._state = TransportState.Idle const promise = createControllablePromise<void>()
this._socket!.removeEventListener('close', this.close) this._closeWaiters.push(promise)
this._socket!.close() this._socket!.close()
return promise
}
handleClosed(): void {
this.log.info('connection closed')
this._state = TransportState.Idle
this._socket = null this._socket = null
this._currentDc = null this._currentDc = null
this._packetCodec.reset() this._packetCodec.reset()
this.emit('close') this.emit('close')
for (const waiter of this._closeWaiters) {
waiter.resolve()
}
this._closeWaiters = []
} }
handleError(event: Event | { error: Error }): void { handleError(event: Event | { error: Error }): void {