From 773530ecd142988d70f7b77528a26b63c23dc829 Mon Sep 17 00:00:00 2001 From: Jordan Rose Date: Mon, 8 Jul 2024 16:35:57 -0700 Subject: [PATCH] Add type annotations to all our Python scripts The main benefit of this is not our *own* type-checking; it's that mypy will error out if you try to use a too-new Python API. And in fact, we were already relying on Python 3.9 and didn't realize. check_code_size.py works with JSON, so it still uses Any a fair bit. --- .github/workflows/lints.yml | 3 +-- bin/update_versions.py | 8 +++--- java/check_code_size.py | 12 +++++---- node/build_node_bridge.py | 6 +++-- rust/bridge/jni/bin/gen_java_decl.py | 32 +++++++++++----------- rust/bridge/node/bin/gen_ts_decl.py | 40 +++++++++++++++------------- 6 files changed, 54 insertions(+), 47 deletions(-) diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index f9bec445..40ed8bd2 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -22,5 +22,4 @@ jobs: - run: pip3 install flake8 mypy - run: shellcheck **/*.sh bin/verify_duplicate_crates bin/adb-run-test - run: python3 -m flake8 . - # Only include typed Python scripts here. - - run: python3 -m mypy bin/fetch_archive.py --python-version 3.8 --strict + - run: python3 -m mypy . --python-version 3.9 --strict diff --git a/bin/update_versions.py b/bin/update_versions.py index 66b99c9c..17fe1f62 100755 --- a/bin/update_versions.py +++ b/bin/update_versions.py @@ -14,7 +14,7 @@ import re import os -def read_version(file, pattern): +def read_version(file: str, pattern: re.Pattern[str]) -> str: with open(file) as f: for line in f: match = pattern.match(line) @@ -23,7 +23,7 @@ def read_version(file, pattern): raise Exception(f"Could not determine version from {file}") -def update_version(file, pattern, new_version): +def update_version(file: str, pattern: re.Pattern[str], new_version: str) -> None: with fileinput.input(files=(file,), inplace=True) as f: for line in f: print(pattern.sub(f"\\g<1>{new_version}\\g<3>", line, count=1), end='') @@ -36,7 +36,7 @@ CARGO_PATTERN = re.compile(r'^(version = ")(.*)(")') RUST_PATTERN = re.compile(r'^(pub const VERSION: &str = ")(.*)(")') -def bridge_path(*bridge): +def bridge_path(*bridge: str) -> str: return os.path.join('rust', 'bridge', *bridge, 'Cargo.toml') @@ -52,7 +52,7 @@ VERSION_FILES = [ ] -def main(): +def main() -> int: os.chdir(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) if len(sys.argv) > 1: diff --git a/java/check_code_size.py b/java/check_code_size.py index ade8fc09..d99a6f5f 100755 --- a/java/check_code_size.py +++ b/java/check_code_size.py @@ -10,15 +10,17 @@ import json import subprocess import sys +from typing import Any, List, Mapping, Optional -def warn(message): + +def warn(message: str) -> None: if 'GITHUB_ACTIONS' in os.environ: print("::warning ::" + message) else: print("warning: " + message, file=sys.stderr) -def print_size_diff(lib_size, old_entry): +def print_size_diff(lib_size: int, old_entry: Mapping[str, Any]) -> None: delta = lib_size - old_entry['size'] delta_fraction = (float(delta) / old_entry['size']) message = "current build is {0}% larger than {1} (current: {2} bytes, {1}: {3} bytes)".format( @@ -33,7 +35,7 @@ def print_size_diff(lib_size, old_entry): print(message) -def current_origin_main_entry(): +def current_origin_main_entry() -> Optional[Mapping[str, Any]]: try: most_recent_main = subprocess.run(["git", "merge-base", "HEAD", "origin/main"], capture_output=True, check=True).stdout.decode().strip() @@ -85,10 +87,10 @@ if origin_main_entry is not None: print_size_diff(lib_size, origin_main_entry) -def print_plot(sizes): +def print_plot(sizes: List[Mapping[str, Any]]) -> None: highest_size = max(recent_sizes, key=lambda x: x['size'])['size'] - scale = 1 * 1024 * 1024 + scale = 1.0 * 1024 * 1024 while scale < highest_size: scale *= 2 scale /= 20 diff --git a/node/build_node_bridge.py b/node/build_node_bridge.py index 5a13596a..d84a6087 100755 --- a/node/build_node_bridge.py +++ b/node/build_node_bridge.py @@ -14,8 +14,10 @@ import subprocess import sys import tarfile +from typing import List, Optional -def maybe_archive_debug_info(*, src_path, src_checksum_path, dst_path, dst_checksum_path): + +def maybe_archive_debug_info(*, src_path: str, src_checksum_path: str, dst_path: str, dst_checksum_path: str) -> None: with open(src_checksum_path, 'rb') as f: digest = hashlib.sha256() # Use read1 to use the file object's buffering. @@ -39,7 +41,7 @@ def maybe_archive_debug_info(*, src_path, src_checksum_path, dst_path, dst_check archive.add(debug_realpath, arcname=os.path.basename(src_path)) -def main(args=None): +def main(args: Optional[List[str]] = None) -> int: if args is None: args = sys.argv diff --git a/rust/bridge/jni/bin/gen_java_decl.py b/rust/bridge/jni/bin/gen_java_decl.py index 51e59a89..b1717556 100755 --- a/rust/bridge/jni/bin/gen_java_decl.py +++ b/rust/bridge/jni/bin/gen_java_decl.py @@ -12,11 +12,13 @@ import subprocess import re import sys +from typing import Iterable, Iterator, Tuple + Args = collections.namedtuple('Args', 'verify') -def parse_args(): - def print_usage_and_exit(): +def parse_args() -> Args: + def print_usage_and_exit() -> None: print(f'usage: {sys.argv[0]} [--verify]', file=sys.stderr) sys.exit(2) @@ -46,13 +48,13 @@ IGNORE_THIS_WARNING = re.compile( ")") -def run_cbindgen(cwd): +def run_cbindgen(cwd: str) -> str: cbindgen = subprocess.Popen(['cbindgen'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (stdout, stderr) = cbindgen.communicate() + (raw_stdout, raw_stderr) = cbindgen.communicate() - stdout = str(stdout.decode('utf8')) - stderr = str(stderr.decode('utf8')) + stdout = str(raw_stdout.decode('utf8')) + stderr = str(raw_stderr.decode('utf8')) unknown_warning = False @@ -72,7 +74,7 @@ def run_cbindgen(cwd): return stdout -def box_primitive_if_needed(typ): +def box_primitive_if_needed(typ: str) -> str: type_map = { "void": "Void", "boolean": "Boolean", @@ -88,7 +90,7 @@ def box_primitive_if_needed(typ): return type_map.get(typ, typ) -def translate_to_java(typ): +def translate_to_java(typ: str) -> Tuple[str, bool]: type_map = { "void": "void", "JString": "String", @@ -137,7 +139,7 @@ JAVA_DECL = re.compile(r""" """, re.VERBOSE) -def parse_decls(cbindgen_output): +def parse_decls(cbindgen_output: str) -> Iterator[str]: cur_type = None for line in cbindgen_output.split('\n'): @@ -172,13 +174,13 @@ def parse_decls(cbindgen_output): " throws Exception" if is_throwing else "")) -def expand_template(template_file, decls): - with open(template_file, "r") as template_file: - contents = template_file.read().replace('\n // INSERT DECLS HERE', "\n".join(decls)) +def expand_template(template_file: str, decls: Iterable[str]) -> str: + with open(template_file, "r") as f: + contents = f.read().replace('\n // INSERT DECLS HERE', "\n".join(decls)) return contents -def verify_contents(expected_output_file, expected_contents): +def verify_contents(expected_output_file: str, expected_contents: str) -> None: with open(expected_output_file) as fh: current_contents = fh.readlines() diff = difflib.unified_diff(current_contents, expected_contents.splitlines(keepends=True)) @@ -189,7 +191,7 @@ def verify_contents(expected_output_file, expected_contents): sys.exit("error: Native.java not up to date; re-run %s!" % sys.argv[0]) -def convert_to_java(rust_crate_dir, java_in_path, java_out_path, verify): +def convert_to_java(rust_crate_dir: str, java_in_path: str, java_out_path: str, verify: bool) -> None: stdout = run_cbindgen(rust_crate_dir) decls = list(parse_decls(stdout)) @@ -206,7 +208,7 @@ def convert_to_java(rust_crate_dir, java_in_path, java_out_path, verify): verify_contents(java_out_path, contents) -def main(): +def main() -> None: args = parse_args() our_abs_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/rust/bridge/node/bin/gen_ts_decl.py b/rust/bridge/node/bin/gen_ts_decl.py index 352c752d..3b820537 100755 --- a/rust/bridge/node/bin/gen_ts_decl.py +++ b/rust/bridge/node/bin/gen_ts_decl.py @@ -13,11 +13,13 @@ import subprocess import re import sys +from typing import Iterable, Iterator, Tuple + Args = collections.namedtuple('Args', ['verify']) -def parse_args(): - def print_usage_and_exit(): +def parse_args() -> Args: + def print_usage_and_exit() -> None: print('usage: %s [--verify]' % sys.argv[0], file=sys.stderr) sys.exit(2) @@ -33,7 +35,7 @@ def parse_args(): return Args(verify=mode is not None) -def split_rust_args(args): +def split_rust_args(args: str) -> Iterator[Tuple[str, str]]: """ Split Rust `arg: Type` pairs separated by commas. @@ -56,7 +58,7 @@ def split_rust_args(args): yield (name.strip(), args.strip()) -def translate_to_ts(typ): +def translate_to_ts(typ: str) -> str: typ = typ.replace(' ', '') type_map = { @@ -153,7 +155,7 @@ DIAGNOSTICS_TO_IGNORE = [ SHOULD_IGNORE_PATTERN = re.compile("(" + ")|(".join(DIAGNOSTICS_TO_IGNORE) + ")") -def camelcase(arg): +def camelcase(arg: str) -> str: return re.sub( # Preserve double-underscores and leading underscores, # but remove single underscores and capitalize the following letter. @@ -162,7 +164,7 @@ def camelcase(arg): arg) -def collect_decls(crate_dir, features=()): +def collect_decls(crate_dir: str, features: Iterable[str] = ()) -> Iterator[str]: args = [ 'cargo', 'rustc', @@ -175,10 +177,10 @@ def collect_decls(crate_dir, features=()): '-Zunpretty=expanded'] rustc = subprocess.Popen(args, cwd=crate_dir, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - (stdout, stderr) = rustc.communicate() + (raw_stdout, raw_stderr) = rustc.communicate() - stdout = str(stdout.decode('utf8')) - stderr = str(stderr.decode('utf8')) + stdout = str(raw_stdout.decode('utf8')) + stderr = str(raw_stderr.decode('utf8')) had_error = False for l in stderr.split('\n'): @@ -215,23 +217,23 @@ def collect_decls(crate_dir, features=()): yield decl continue - (prefix, args, ret_type) = function_match.groups() + (prefix, fn_args, ret_type) = function_match.groups() ts_ret_type = translate_to_ts(ret_type) ts_args = [] - if '::' in args: - raise Exception(f'Paths are not supported. Use alias for the type of \'{args}\'') + if '::' in fn_args: + raise Exception(f'Paths are not supported. Use alias for the type of \'{fn_args}\'') - for (arg_name, arg_type) in split_rust_args(args): + for (arg_name, arg_type) in split_rust_args(fn_args): ts_arg_type = translate_to_ts(arg_type) ts_args.append('%s: %s' % (camelcase(arg_name.strip()), ts_arg_type)) yield '%s(%s): %s;' % (prefix, ', '.join(ts_args), ts_ret_type) -def expand_template(template_file, decls): - with open(template_file, "r") as template_file: - contents = template_file.read() +def expand_template(template_file: str, decls: Iterable[str]) -> str: + with open(template_file, "r") as f: + contents = f.read() contents += "\n" contents += "\n".join(sorted(decls)) contents += "\n" @@ -239,7 +241,7 @@ def expand_template(template_file, decls): return contents -def verify_contents(expected_output_file, expected_contents): +def verify_contents(expected_output_file: str, expected_contents: str) -> None: with open(expected_output_file) as fh: current_contents = fh.readlines() diff = difflib.unified_diff(current_contents, expected_contents.splitlines(keepends=True)) @@ -253,7 +255,7 @@ def verify_contents(expected_output_file, expected_contents): Crate = collections.namedtuple('Crate', ["path", "features"], defaults=[()]) -def convert_to_typescript(rust_crates, ts_in_path, ts_out_path, verify): +def convert_to_typescript(rust_crates: Iterable[Crate], ts_in_path: str, ts_out_path: str, verify: bool) -> None: decls = itertools.chain.from_iterable(collect_decls(crate.path, crate.features) for crate in rust_crates) contents = expand_template(ts_in_path, decls) @@ -267,7 +269,7 @@ def convert_to_typescript(rust_crates, ts_in_path, ts_out_path, verify): verify_contents(ts_out_path, contents) -def main(): +def main() -> None: args = parse_args() our_abs_dir = os.path.dirname(os.path.realpath(__file__)) output_file_name = 'Native.d.ts'