From 08d3afadd8bc4cac60aeaeca29163c5701c6e3b8 Mon Sep 17 00:00:00 2001 From: Alina Sireneva Date: Mon, 13 Nov 2023 02:34:22 +0300 Subject: [PATCH] test(core): more transport tests --- packages/client/scripts/generate-client.cjs | 2 +- .../src/network/transports/obfuscated.test.ts | 197 ++++++++++++++++++ .../core/src/network/transports/obfuscated.ts | 2 +- .../core/src/network/transports/tcp.test.ts | 149 +++++++++++++ .../src/network/transports/websocket.test.ts | 133 ++++++++++++ .../src/utils/crypto/crypto.test-utils.ts | 15 ++ 6 files changed, 496 insertions(+), 2 deletions(-) create mode 100644 packages/core/src/network/transports/obfuscated.test.ts create mode 100644 packages/core/src/network/transports/tcp.test.ts create mode 100644 packages/core/src/network/transports/websocket.test.ts diff --git a/packages/client/scripts/generate-client.cjs b/packages/client/scripts/generate-client.cjs index e3a29dca..0191e001 100644 --- a/packages/client/scripts/generate-client.cjs +++ b/packages/client/scripts/generate-client.cjs @@ -421,7 +421,7 @@ async function main() { } for await (const file of getFiles(path.join(__dirname, '../src/methods'))) { - if (!file.startsWith('.') && file.endsWith('.ts') && !file.endsWith('.web.ts')) { + if (!file.startsWith('.') && file.endsWith('.ts') && !file.endsWith('.web.ts') && !file.endsWith('.test.ts')) { await addSingleMethod(state, file) } } diff --git a/packages/core/src/network/transports/obfuscated.test.ts b/packages/core/src/network/transports/obfuscated.test.ts new file mode 100644 index 00000000..3a8f404c --- /dev/null +++ b/packages/core/src/network/transports/obfuscated.test.ts @@ -0,0 +1,197 @@ +import { describe, expect, it, vi } from 'vitest' + +import { defaultTestCryptoProvider, u8HexDecode } from '../../utils/crypto/crypto.test-utils.js' +import { hexDecodeToBuffer, hexEncode, LogManager } from '../../utils/index.js' +import { IntermediatePacketCodec } from './intermediate.js' +import { MtProxyInfo, ObfuscatedPacketCodec } from './obfuscated.js' + +describe('ObfuscatedPacketCodec', () => { + const create = async (randomSource?: string, proxy?: MtProxyInfo) => { + const codec = new ObfuscatedPacketCodec(new IntermediatePacketCodec(), proxy) + const crypto = await defaultTestCryptoProvider(randomSource) + codec.setup(crypto, new LogManager()) + + return [codec, crypto] as const + } + + describe('tag', () => { + it('should correctly generate random initial payload', async () => { + const random = 'ff'.repeat(64) + const [codec] = await create(random) + + const tag = await codec.tag() + + expect(hexEncode(tag)).toEqual( + 'ff'.repeat(56) + 'fce8ab2203db2bff', // encrypted part + ) + }) + + describe('mtproxy', () => { + it('should correctly generate random initial payload for prod dc', async () => { + const random = 'ff'.repeat(64) + const proxy: MtProxyInfo = { + dcId: 1, + secret: new Uint8Array(16), + test: false, + media: false, + } + const [codec] = await create(random, proxy) + + const tag = await codec.tag() + + expect(hexEncode(tag)).toEqual( + 'ff'.repeat(56) + 'ecec4cbda8bb188b', // encrypted part with dcId = 1 + ) + }) + + it('should correctly generate random initial payload for test dc', async () => { + const random = 'ff'.repeat(64) + const proxy: MtProxyInfo = { + dcId: 1, + secret: new Uint8Array(16), + test: true, + media: false, + } + const [codec] = await create(random, proxy) + + const tag = await codec.tag() + + expect(hexEncode(tag)).toEqual( + 'ff'.repeat(56) + 'ecec4cbdb89c188b', // encrypted part with dcId = 10001 + ) + }) + + it('should correctly generate random initial payload for media dc', async () => { + const random = 'ff'.repeat(64) + const proxy: MtProxyInfo = { + dcId: 1, + secret: new Uint8Array(16), + test: false, + media: true, + } + const [codec] = await create(random, proxy) + + const tag = await codec.tag() + + expect(hexEncode(tag)).toEqual( + 'ff'.repeat(56) + 'ecec4cbd5644188b', // encrypted part with dcId = -1 + ) + }) + }) + + it.each([ + ['ef'], + ['48454144'], + ['504f5354'], + ['47455420'], + ['4f505449'], + ['dddddddd'], + ['eeeeeeee'], + ['16030102'], + ])('should correctly retry for %s prefix', async (prefix) => { + const random = prefix + 'ff'.repeat(64 - prefix.length / 2) + const [codec] = await create(random) + + // generating random payload requires 64 bytes of entropy, so + // if it asks for more, it means it tried to generate it again + await expect(() => codec.tag()).rejects.toThrow('not enough entropy') + }) + }) + + it('should correctly create aes ctr', async () => { + const [codec, crypto] = await create() + + const spyCreateAesCtr = vi.spyOn(crypto, 'createAesCtr') + + await codec.tag() + + expect(spyCreateAesCtr).toHaveBeenCalledTimes(2) + expect(spyCreateAesCtr).toHaveBeenNthCalledWith( + 1, + u8HexDecode('10b6b4ad6d56ef5df9453f88e6ee6adb6e0544ba635dc6a8a990c9b8b980c343'), + u8HexDecode('936b33fa7f97bae025102532233abb26'), + true, + ) + expect(spyCreateAesCtr).toHaveBeenNthCalledWith( + 2, + u8HexDecode('26bb3a2332251025e0ba977ffa336b9343c380b9b8c990a9a8c65d63ba44056e'), + u8HexDecode('db6aeee6883f45f95def566dadb4b610'), + false, + ) + }) + + it('should correctly create aes ctr for mtproxy', async () => { + const proxy: MtProxyInfo = { + dcId: 1, + secret: hexDecodeToBuffer('00112233445566778899aabbccddeeff'), + test: true, + media: false, + } + const [codec, crypto] = await create(undefined, proxy) + + const spyCreateAesCtr = vi.spyOn(crypto, 'createAesCtr') + + await codec.tag() + + expect(spyCreateAesCtr).toHaveBeenCalledTimes(2) + expect(spyCreateAesCtr).toHaveBeenNthCalledWith( + 1, + hexDecodeToBuffer('dd03188944590983e28dad14d97d0952389d118af4ffcbdb28d56a6a612ef7a6'), + u8HexDecode('936b33fa7f97bae025102532233abb26'), + true, + ) + expect(spyCreateAesCtr).toHaveBeenNthCalledWith( + 2, + hexDecodeToBuffer('413b8e08021fbb08a2962b6d7187194fe46565c6b329d3bbdfcffd4870c16119'), + u8HexDecode('db6aeee6883f45f95def566dadb4b610'), + false, + ) + }) + + it('should correctly encrypt the underlying codec', async () => { + const data = hexDecodeToBuffer('6cfeffff') + const msg1 = 'a1020630a410e940' + const msg2 = 'f53ff53f371db495' + + const [codec] = await create() + + await codec.tag() + + expect(hexEncode(await codec.encode(data))).toEqual(msg1) + expect(hexEncode(await codec.encode(data))).toEqual(msg2) + }) + + it('should correctly decrypt the underlying codec', async () => { + const msg1 = 'e8027df708ab3b5c' + const msg2 = '1854be76d2df4949' + + const [codec] = await create() + + await codec.tag() + + const log: string[] = [] + + codec.on('error', (e: Error) => { + log.push(e.toString()) + }) + + codec.feed(hexDecodeToBuffer(msg1)) + codec.feed(hexDecodeToBuffer(msg2)) + + await vi.waitFor(() => expect(log).toEqual(['Error: Transport error: 404', 'Error: Transport error: 404'])) + }) + + it('should correctly reset', async () => { + const inner = new IntermediatePacketCodec() + const spyInnerReset = vi.spyOn(inner, 'reset') + + const codec = new ObfuscatedPacketCodec(inner) + codec.setup(await defaultTestCryptoProvider(), new LogManager()) + + await codec.tag() + + codec.reset() + + expect(spyInnerReset).toHaveBeenCalledTimes(1) + }) +}) diff --git a/packages/core/src/network/transports/obfuscated.ts b/packages/core/src/network/transports/obfuscated.ts index ffd5cb9f..aaa36512 100644 --- a/packages/core/src/network/transports/obfuscated.ts +++ b/packages/core/src/network/transports/obfuscated.ts @@ -30,7 +30,7 @@ export class ObfuscatedPacketCodec extends WrappedCodec implements IPacketCodec if (random[0] === 0xef) continue dv = dataViewFromBuffer(random) - const firstInt = dv.getInt32(0, true) + const firstInt = dv.getUint32(0, true) if ( firstInt === 0x44414548 || // HEAD diff --git a/packages/core/src/network/transports/tcp.test.ts b/packages/core/src/network/transports/tcp.test.ts new file mode 100644 index 00000000..18404139 --- /dev/null +++ b/packages/core/src/network/transports/tcp.test.ts @@ -0,0 +1,149 @@ +import { Socket } from 'net' +import { describe, expect, it, MockedObject, vi } from 'vitest' + +import { defaultTestCryptoProvider, u8HexDecode } from '../../utils/crypto/crypto.test-utils.js' +import { defaultProductionDc, hexDecodeToBuffer, LogManager } from '../../utils/index.js' +import { TransportState } from './abstract.js' +import { TcpTransport } from './tcp.js' + +vi.mock('net', () => ({ + connect: vi.fn().mockImplementation((port: number, ip: string, cb: () => void) => { + cb() + + return { + on: vi.fn(), + write: vi.fn().mockImplementation((data: Uint8Array, cb: () => void) => { + cb() + }), + end: vi.fn(), + removeAllListeners: vi.fn(), + destroy: vi.fn(), + } + }), +})) + +describe('TcpTransport', async () => { + const net = await import('net') + const connect = vi.mocked(net.connect) + + const getLastSocket = () => { + return connect.mock.results[connect.mock.results.length - 1].value as MockedObject + } + + const create = async () => { + const transport = new TcpTransport() + const logger = new LogManager() + logger.level = 0 + transport.setup(await defaultTestCryptoProvider(), logger) + + return transport + } + + it('should initiate a tcp connection to the given dc', async () => { + const t = await create() + + t.connect(defaultProductionDc.main, false) + + expect(connect).toHaveBeenCalledOnce() + expect(connect).toHaveBeenCalledWith( + defaultProductionDc.main.port, + defaultProductionDc.main.ipAddress, + expect.any(Function), + ) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + }) + + it('should set up event handlers', async () => { + const t = await create() + + t.connect(defaultProductionDc.main, false) + + const socket = getLastSocket() + + expect(socket.on).toHaveBeenCalledTimes(3) + expect(socket.on).toHaveBeenCalledWith('data', expect.any(Function)) + expect(socket.on).toHaveBeenCalledWith('error', expect.any(Function)) + expect(socket.on).toHaveBeenCalledWith('close', expect.any(Function)) + }) + + it('should write packet codec tag once connected', async () => { + const t = await create() + + t.connect(defaultProductionDc.main, false) + + const socket = getLastSocket() + + await vi.waitFor(() => + expect(socket.write).toHaveBeenCalledWith( + u8HexDecode('eeeeeeee'), // intermediate + expect.any(Function), + ), + ) + }) + + it('should write to the underlying socket', async () => { + const t = await create() + + t.connect(defaultProductionDc.main, false) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + await t.send(hexDecodeToBuffer('00010203040506070809')) + + const socket = getLastSocket() + + expect(socket.write).toHaveBeenCalledWith(u8HexDecode('0a00000000010203040506070809'), expect.any(Function)) + }) + + it('should correctly close', async () => { + const t = await create() + + t.connect(defaultProductionDc.main, false) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + t.close() + + const socket = getLastSocket() + + expect(socket.removeAllListeners).toHaveBeenCalledOnce() + expect(socket.destroy).toHaveBeenCalledOnce() + }) + + it('should feed data to the packet codec', async () => { + const t = await create() + const codec = t._packetCodec + + const spyFeed = vi.spyOn(codec, 'feed') + + t.connect(defaultProductionDc.main, false) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + const socket = getLastSocket() + + const onDataCall = socket.on.mock.calls.find((c) => (c as string[])[0] === 'data') as unknown as [ + string, + (data: Uint8Array) => void, + ] + onDataCall[1](u8HexDecode('00010203040506070809')) + + expect(spyFeed).toHaveBeenCalledWith(u8HexDecode('00010203040506070809')) + }) + + it('should propagate errors', async () => { + const t = await create() + + const spyEmit = vi.spyOn(t, 'emit').mockImplementation(() => true) + + t.connect(defaultProductionDc.main, false) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + const socket = getLastSocket() + + const onErrorCall = socket.on.mock.calls.find((c) => (c as string[])[0] === 'error') as unknown as [ + string, + (error: Error) => void, + ] + onErrorCall[1](new Error('test error')) + + expect(spyEmit).toHaveBeenCalledWith('error', new Error('test error')) + }) +}) diff --git a/packages/core/src/network/transports/websocket.test.ts b/packages/core/src/network/transports/websocket.test.ts new file mode 100644 index 00000000..05cf14b5 --- /dev/null +++ b/packages/core/src/network/transports/websocket.test.ts @@ -0,0 +1,133 @@ +import { describe, expect, it, Mock, MockedObject, vi } from 'vitest' + +import { defaultTestCryptoProvider, u8HexDecode } from '../../utils/crypto/crypto.test-utils.js' +import { defaultProductionDc, hexDecodeToBuffer, LogManager } from '../../utils/index.js' +import { TransportState } from './abstract.js' +import { WebSocketTransport } from './websocket.js' + +describe('WebSocketTransport', () => { + const create = async () => { + const fakeWs = vi.fn().mockImplementation(() => ({ + addEventListener: vi.fn().mockImplementation((event: string, cb: () => void) => { + if (event === 'open') { + cb() + } + }), + removeEventListener: vi.fn(), + close: vi.fn(), + send: vi.fn(), + })) + + const transport = new WebSocketTransport({ ws: fakeWs }) + const logger = new LogManager() + logger.level = 0 + transport.setup(await defaultTestCryptoProvider(), logger) + + return [transport, fakeWs] as const + } + + const getLastSocket = (ws: Mock) => { + return ws.mock.results[ws.mock.results.length - 1].value as MockedObject + } + + it('should initiate a websocket connection to the given dc', async () => { + const [t, ws] = await create() + + t.connect(defaultProductionDc.main, false) + + expect(ws).toHaveBeenCalledOnce() + expect(ws).toHaveBeenCalledWith('wss://venus.web.telegram.org/apiws', 'binary') + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + }) + + it('should set up event handlers', async () => { + const [t, ws] = await create() + + t.connect(defaultProductionDc.main, false) + const socket = getLastSocket(ws) + + expect(socket.addEventListener).toHaveBeenCalledWith('message', expect.any(Function)) + expect(socket.addEventListener).toHaveBeenCalledWith('error', expect.any(Function)) + expect(socket.addEventListener).toHaveBeenCalledWith('close', expect.any(Function)) + }) + + it('should write packet codec tag to the socket', async () => { + const [t, ws] = await create() + + t.connect(defaultProductionDc.main, false) + const socket = getLastSocket(ws) + + await vi.waitFor(() => + expect(socket.send).toHaveBeenCalledWith( + u8HexDecode( + '29afd26df40fb8ed10b6b4ad6d56ef5df9453f88e6ee6adb6e0544ba635dc6a8a990c9b8b980c343936b33fa7f97bae025102532233abb26d4a1fe6d34f1ba08', + ), + ), + ) + }) + + it('should write to the underlying socket', async () => { + const [t, ws] = await create() + + t.connect(defaultProductionDc.main, false) + const socket = getLastSocket(ws) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + await t.send(hexDecodeToBuffer('00010203040506070809')) + + expect(socket.send).toHaveBeenCalledWith(hexDecodeToBuffer('af020630c8ef14bcf53af33853ea')) + }) + + it('should correctly close', async () => { + const [t, ws] = await create() + + t.connect(defaultProductionDc.main, false) + const socket = getLastSocket(ws) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + t.close() + + expect(socket.removeEventListener).toHaveBeenCalled() + expect(socket.close).toHaveBeenCalled() + }) + + it('should correctly handle incoming messages', async () => { + const [t, ws] = await create() + + const feedSpy = vi.spyOn(t._packetCodec, 'feed') + + t.connect(defaultProductionDc.main, false) + const socket = getLastSocket(ws) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + const data = hexDecodeToBuffer('00010203040506070809') + const message = new MessageEvent('message', { data }) + + const onMessageCall = socket.addEventListener.mock.calls.find(([event]) => event === 'message') as unknown as [ + string, + (evt: MessageEvent) => void, + ] + onMessageCall[1](message) + + expect(feedSpy).toHaveBeenCalledWith(u8HexDecode('00010203040506070809')) + }) + + it('should propagate errors', async () => { + const [t, ws] = await create() + + const spyEmit = vi.spyOn(t, 'emit').mockImplementation(() => true) + + t.connect(defaultProductionDc.main, false) + const socket = getLastSocket(ws) + await vi.waitFor(() => expect(t.state()).toEqual(TransportState.Ready)) + + const error = new Error('test') + const onErrorCall = socket.addEventListener.mock.calls.find(([event]) => event === 'error') as unknown as [ + string, + (evt: { error: Error }) => void, + ] + onErrorCall[1]({ error }) + + expect(spyEmit).toHaveBeenCalledWith('error', error) + }) +}) diff --git a/packages/core/src/utils/crypto/crypto.test-utils.ts b/packages/core/src/utils/crypto/crypto.test-utils.ts index 0b76cca5..cfa47ee5 100644 --- a/packages/core/src/utils/crypto/crypto.test-utils.ts +++ b/packages/core/src/utils/crypto/crypto.test-utils.ts @@ -1,3 +1,4 @@ +/* eslint-disable no-restricted-globals */ import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest' import { gzipSync, inflateSync } from 'zlib' @@ -32,6 +33,10 @@ export function withFakeRandom(provider: ICryptoProvider, source = DEFAULT_ENTRO let offset = 0 function getRandomValues(buf: Uint8Array) { + if (offset + buf.length > sourceBytes.length) { + throw new Error('not enough entropy') + } + buf.set(sourceBytes.subarray(offset, offset + buf.length)) offset += buf.length } @@ -216,3 +221,13 @@ export function testCryptoProvider(c: ICryptoProvider): void { }) }) } + +export function u8HexDecode(hex: string) { + const buf = hexDecodeToBuffer(hex) + + if (Buffer.isBuffer(buf)) { + return new Uint8Array(buf) + } + + return buf +}