From 04d97d8d65ba9b559c2dc0b952c5afc2ce489294 Mon Sep 17 00:00:00 2001 From: Niklas Higi Date: Sat, 2 Jan 2021 18:51:29 +0100 Subject: [PATCH] Introduce `observeAsync` helper to improve async error handling --- src/patch-apk.ts | 55 ++++------ src/patch-app-bundle.ts | 16 ++- src/tasks/disable-certificate-pinning.ts | 125 +++++++++++------------ src/utils/download-file.ts | 42 ++++++++ src/utils/download-tool.ts | 71 +++---------- src/utils/observe-async.ts | 18 ++++ src/utils/observe-process.ts | 42 ++++---- 7 files changed, 185 insertions(+), 184 deletions(-) create mode 100644 src/utils/download-file.ts create mode 100644 src/utils/observe-async.ts diff --git a/src/patch-apk.ts b/src/patch-apk.ts index 95d917c..ae06564 100644 --- a/src/patch-apk.ts +++ b/src/patch-apk.ts @@ -1,6 +1,6 @@ import * as path from 'path' +import { once } from 'events' import * as fs from './utils/fs' -import { Observable } from 'rxjs' import Listr from 'listr' import chalk from 'chalk' @@ -9,6 +9,7 @@ import downloadTools from './tasks/download-tools' import modifyManifest from './tasks/modify-manifest' import createNetworkSecurityConfig from './tasks/create-netsec-config' import disableCertificatePinning from './tasks/disable-certificate-pinning' +import observeAsync from './utils/observe-async' export default function patchApk(taskOptions: TaskOptions) { const { @@ -52,20 +53,16 @@ export default function patchApk(taskOptions: TaskOptions) { { title: 'Waiting for you to make changes', enabled: () => wait, - task: (_) => { - return new Observable(subscriber => { - process.stdin.setEncoding('utf-8') - process.stdin.setRawMode(true) + task: () => observeAsync(async next => { + process.stdin.setEncoding('utf-8') + process.stdin.setRawMode(true) - subscriber.next("Press any key to continue.") + next('Press any key to continue.') + await once(process.stdin, 'data') - process.stdin.once('data', () => { - subscriber.complete() - process.stdin.setRawMode(false) - process.stdin.pause() - }) - }) - }, + process.stdin.setRawMode(false) + process.stdin.pause() + }) }, { title: 'Encoding patched APK file', @@ -73,16 +70,13 @@ export default function patchApk(taskOptions: TaskOptions) { new Listr([ { title: 'Encoding using AAPT2', - task: (_, task) => new Observable(subscriber => { - apktool.encode(decodeDir, tmpApkPath, true).subscribe( - line => subscriber.next(line), - () => { - subscriber.complete() - task.skip('Failed, falling back to AAPT...') - fallBackToAapt = true - }, - () => subscriber.complete(), - ) + task: (_, task) => observeAsync(async next => { + try { + await apktool.encode(decodeDir, tmpApkPath, true).forEach(next) + } catch { + task.skip('Failed, falling back to AAPT...') + fallBackToAapt = true + } }), }, { @@ -94,17 +88,12 @@ export default function patchApk(taskOptions: TaskOptions) { }, { title: 'Signing patched APK file', - task: () => new Observable(subscriber => { - (async () => { - await uberApkSigner - .sign([tmpApkPath], { zipalign: true }) - .forEach(line => subscriber.next(line)) - .catch(error => subscriber.error(error)) - - await fs.copyFile(tmpApkPath, outputPath) + task: () => observeAsync(async next => { + await uberApkSigner + .sign([tmpApkPath], { zipalign: true }) + .forEach(line => next(line)) - subscriber.complete() - })() + await fs.copyFile(tmpApkPath, outputPath) }), }, ]) diff --git a/src/patch-app-bundle.ts b/src/patch-app-bundle.ts index 212f775..ccf4832 100644 --- a/src/patch-app-bundle.ts +++ b/src/patch-app-bundle.ts @@ -1,5 +1,4 @@ import { unzip, zip } from '@tybys/cross-zip' -import { Observable } from 'rxjs' import * as fs from './utils/fs' import * as path from 'path' import globby from 'globby' @@ -7,6 +6,7 @@ import Listr from 'listr' import patchApk from './patch-apk' import { TaskOptions } from './cli' +import observeAsync from './utils/observe-async' export function patchXapkBundle(options: TaskOptions) { return patchAppBundle(options, { isXapk: true }) @@ -48,16 +48,12 @@ function patchAppBundle( }, { title: 'Signing APKs', - task: () => new Observable(subscriber => { - (async () => { - const apkFiles = await globby(path.join(bundleDir, '**/*.apk')) + task: () => observeAsync(async next => { + const apkFiles = await globby(path.join(bundleDir, '**/*.apk')) - await uberApkSigner - .sign(apkFiles, { zipalign: false }) - .forEach(line => subscriber.next(line)) - - subscriber.complete() - })() + await uberApkSigner + .sign(apkFiles, { zipalign: false }) + .forEach(line => next(line)) }), }, { diff --git a/src/tasks/disable-certificate-pinning.ts b/src/tasks/disable-certificate-pinning.ts index 57f7365..66cd523 100644 --- a/src/tasks/disable-certificate-pinning.ts +++ b/src/tasks/disable-certificate-pinning.ts @@ -4,8 +4,8 @@ import * as fs from '../utils/fs' import globby from 'globby' import escapeStringRegexp from 'escape-string-regexp' -import { Observable } from 'rxjs' import { ListrTaskWrapper } from 'listr' +import observeAsync from '../utils/observe-async' const INTERFACE_LINE = '.implements Ljavax/net/ssl/X509TrustManager;' @@ -40,81 +40,78 @@ const RETURN_EMPTY_ARRAY_FIX = [ ] export default async function disableCertificatePinning(directoryPath: string, task: ListrTaskWrapper) { - return new Observable(observer => { - (async () => { - observer.next('Finding smali files...') + return observeAsync(async next => { + next('Finding smali files...') - // Convert Windows path (using backslashes) to POSIX path (using slashes) - const directoryPathPosix = directoryPath.split(path.sep).join(path.posix.sep) - const globPattern = path.posix.join(directoryPathPosix, 'smali*/**/*.smali') + // Convert Windows path (using backslashes) to POSIX path (using slashes) + const directoryPathPosix = directoryPath.split(path.sep).join(path.posix.sep) + const globPattern = path.posix.join(directoryPathPosix, 'smali*/**/*.smali') - const smaliFiles = await globby(globPattern) + const smaliFiles = await globby(globPattern) - let pinningFound = false + let pinningFound = false - for (const filePath of smaliFiles) { - observer.next(`Scanning ${path.basename(filePath)}...`) + for (const filePath of smaliFiles) { + next(`Scanning ${path.basename(filePath)}...`) - let originalContent = await fs.readFile(filePath, 'utf-8') + let originalContent = await fs.readFile(filePath, 'utf-8') - // Don't scan classes that don't implement the interface - if (!originalContent.includes(INTERFACE_LINE)) continue + // Don't scan classes that don't implement the interface + if (!originalContent.includes(INTERFACE_LINE)) continue - if (os.type() === 'Windows_NT') { - // Replace CRLF with LF, so that patches can just use '\n' - originalContent = originalContent.replace(/\r\n/g, '\n') - } - - let patchedContent = originalContent - - for (const pattern of METHOD_PATTERNS) { - patchedContent = patchedContent.replace( - pattern, ( - _, - openingLine: string, - body: string, - closingLine: string, - ) => { - const bodyLines = body - .split('\n') - .map(line => line.replace(/^ /, '')) - - const fixLines = openingLine.includes('getAcceptedIssuers') - ? RETURN_EMPTY_ARRAY_FIX - : RETURN_VOID_FIX - - const patchedBodyLines = [ - '# inserted by apk-mitm to disable certificate pinning', - ...fixLines, - '', - '# commented out by apk-mitm to disable old method body', - '# ', - ...bodyLines.map(line => `# ${line}`) - ] - - return [ - openingLine, - ...patchedBodyLines.map(line => ` ${line}`), - closingLine, - ].map(line => line.trimEnd()).join('\n') - }, - ) - } + if (os.type() === 'Windows_NT') { + // Replace CRLF with LF, so that patches can just use '\n' + originalContent = originalContent.replace(/\r\n/g, '\n') + } - if (originalContent !== patchedContent) { - pinningFound = true + let patchedContent = originalContent + + for (const pattern of METHOD_PATTERNS) { + patchedContent = patchedContent.replace( + pattern, ( + _, + openingLine: string, + body: string, + closingLine: string, + ) => { + const bodyLines = body + .split('\n') + .map(line => line.replace(/^ /, '')) + + const fixLines = openingLine.includes('getAcceptedIssuers') + ? RETURN_EMPTY_ARRAY_FIX + : RETURN_VOID_FIX + + const patchedBodyLines = [ + '# inserted by apk-mitm to disable certificate pinning', + ...fixLines, + '', + '# commented out by apk-mitm to disable old method body', + '# ', + ...bodyLines.map(line => `# ${line}`) + ] + + return [ + openingLine, + ...patchedBodyLines.map(line => ` ${line}`), + closingLine, + ].map(line => line.trimEnd()).join('\n') + }, + ) + } - if (os.type() === 'Windows_NT') { - // Replace LF with CRLF again - patchedContent = patchedContent.replace(/\n/g, '\r\n') - } + if (originalContent !== patchedContent) { + pinningFound = true - await fs.writeFile(filePath, patchedContent) + if (os.type() === 'Windows_NT') { + // Replace LF with CRLF again + patchedContent = patchedContent.replace(/\n/g, '\r\n') } + + await fs.writeFile(filePath, patchedContent) } + } - if (!pinningFound) task.skip('No certificate pinning logic found.') - observer.complete() - })() + if (!pinningFound) task.skip('No certificate pinning logic found.') }) } diff --git a/src/utils/download-file.ts b/src/utils/download-file.ts new file mode 100644 index 0000000..4b28ead --- /dev/null +++ b/src/utils/download-file.ts @@ -0,0 +1,42 @@ +import * as fs from './fs' +import { Observable } from 'rxjs' +import followRedirects = require('follow-redirects') +const { https } = followRedirects + +export default function downloadFile(url: string, path: string) { + return new Observable(subscriber => { + https.get(url, response => { + if (response.statusCode !== 200) { + const error = new Error(`The URL "${url}" returned status code ${response.statusCode}, expected 200.`) + + // Cancel download with error + response.destroy(error) + } + + const fileStream = fs.createWriteStream(path) + + const totalLength = parseInt(response.headers['content-length']) + let currentLength = 0 + + const reportProgress = () => { + const percentage = currentLength / totalLength + subscriber.next(`${(percentage * 100).toFixed(2)}% done (${formatBytes(currentLength)} / ${formatBytes(totalLength)} MB)`) + } + reportProgress() + + response.pipe(fileStream) + + response.on('data', (chunk: Buffer) => { + currentLength += chunk.byteLength + reportProgress() + }) + response.on('error', error => subscriber.error(error)) + + fileStream.on('close', () => subscriber.complete()) + }).on('error', error => subscriber.error(error)) + }) +} + +function formatBytes(bytes: number) { + return (bytes / 1000000).toFixed(2) +} diff --git a/src/utils/download-tool.ts b/src/utils/download-tool.ts index d5a4990..05a8c06 100644 --- a/src/utils/download-tool.ts +++ b/src/utils/download-tool.ts @@ -2,11 +2,11 @@ import * as fs from 'fs' import { promises as fsp } from 'fs' import * as pathUtils from 'path' import envPaths = require('env-paths') -import { Observable } from 'rxjs' import { ListrTaskWrapper } from 'listr' -import followRedirects = require('follow-redirects') + import Tool from '../tools/tool' -const { https } = followRedirects +import observeAsync from './observe-async' +import downloadFile from './download-file' const cachePath = envPaths('apk-mitm', { suffix: '' }).cache @@ -18,71 +18,34 @@ export default function createToolDownloadTask(tool: Tool) { return task.skip('Using custom version') const fileName = `${tool.name}-${tool.version.name}.jar` - return downloadFile(task, tool.version.downloadUrl, fileName) + return downloadCachedFile(task, tool.version.downloadUrl, fileName) }, } } -function downloadFile( +function downloadCachedFile( task: ListrTaskWrapper, url: string, fileName: string, ) { - return new Observable(subscriber => { - (async () => { - const finalFilePath = getCachedPath(fileName) - - if (fs.existsSync(finalFilePath)) { - task.skip('Version already downloaded!') - subscriber.complete() - return - } - - // Ensure cache directory exists - await fsp.mkdir(cachePath, { recursive: true }) - - // Prevent file corruption by using a temporary file name - const downloadFilePath = finalFilePath + '.dl' - - https.get(url, response => { - if (response.statusCode !== 200) { - const error = new Error(`The URL "${url}" returned status code ${response.statusCode}, expected 200.`) - - // Cancel download with error - response.destroy(error) - } + return observeAsync(async next => { + const finalFilePath = getCachedPath(fileName) - const fileStream = fs.createWriteStream(downloadFilePath) + if (fs.existsSync(finalFilePath)) { + task.skip('Version already downloaded!') + return + } - const totalLength = parseInt(response.headers['content-length']) - let currentLength = 0 + // Ensure cache directory exists + await fsp.mkdir(cachePath, { recursive: true }) - const reportProgress = () => { - const percentage = currentLength / totalLength - subscriber.next(`${(percentage * 100).toFixed(2)}% done (${formatBytes(currentLength)} / ${formatBytes(totalLength)} MB)`) - } - reportProgress() - - response.pipe(fileStream) - - response.on('data', (chunk: Buffer) => { - currentLength += chunk.byteLength - reportProgress() - }) - - fileStream.on('close', async () => { - await fsp.rename(downloadFilePath, finalFilePath) - subscriber.complete() - }) - }).on('error', error => subscriber.error(error)) - })() + // Prevent file corruption by using a temporary file name + const downloadFilePath = finalFilePath + '.dl' + await downloadFile(url, downloadFilePath).forEach(next) + await fsp.rename(downloadFilePath, finalFilePath) }) } export function getCachedPath(name: string) { return pathUtils.join(cachePath, name) } - -function formatBytes(bytes: number) { - return (bytes / 1000000).toFixed(2) -} diff --git a/src/utils/observe-async.ts b/src/utils/observe-async.ts new file mode 100644 index 0000000..c97697b --- /dev/null +++ b/src/utils/observe-async.ts @@ -0,0 +1,18 @@ +import { Observable } from 'rxjs' + +/** + * Wraps an async function and produces an `Observable` that reacts to the + * function resolving (`complete` notification), rejecting (`error` + * notification), and calling the `next` callback (`next` notification), making + * it easier to write `async`/`await`-based code that reports its progress + * through an `Observable` *without* forgetting to handle errors. + */ +export default function observeAsync( + fn: (next: (value: T) => void) => Promise, +): Observable { + return new Observable(subscriber => { + fn(value => subscriber.next(value)) + .then(() => subscriber.complete()) + .catch(error => subscriber.error(error)) + }) +} diff --git a/src/utils/observe-process.ts b/src/utils/observe-process.ts index e2cf484..5635653 100644 --- a/src/utils/observe-process.ts +++ b/src/utils/observe-process.ts @@ -2,36 +2,32 @@ import * as fs from '../utils/fs' import * as pathUtils from 'path' import { ExecaChildProcess } from 'execa' import { Observable } from 'rxjs' +import observeAsync from './observe-async' export default function observeProcess( process: ExecaChildProcess, logName: string, ): Observable { - return new Observable(subscriber => { - (async () => { - await fs.mkdir('logs', { recursive: true }) + return observeAsync(async next => { + await fs.mkdir('logs', { recursive: true }) - const fileName = pathUtils.join('logs', `${logName}.log`) - const failedFileName = pathUtils.join('logs', `${logName}.failed.log`) - const stream = fs.createWriteStream(fileName) + const fileName = pathUtils.join('logs', `${logName}.log`) + const failedFileName = pathUtils.join('logs', `${logName}.failed.log`) + const stream = fs.createWriteStream(fileName) - process - .then(() => { - stream.close() - subscriber.complete() - }) - .catch(async error => { - stream.close() - await fs.rename(fileName, failedFileName) + process.stdout.on('data', (data: Buffer) => { + next(data.toString().trim()) + stream.write(data) + }) + process.stderr.on('data', (data: Buffer) => stream.write(data)) - subscriber.error(error) - }) - - process.stdout.on('data', (data: Buffer) => { - subscriber.next(data.toString().trim()) - stream.write(data) - }) - process.stderr.on('data', (data: Buffer) => stream.write(data)) - })() + try { + await process + } catch (error) { + await fs.rename(fileName, failedFileName) + throw error + } finally { + stream.close() + } }) }