from __future__ import annotations

import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path

# The repository root is taken to be two directories above this file.  The
# index is clamped so that importing the module from a shallower location
# (fewer than three ancestors) does not raise IndexError at import time.
_MODULE_ANCESTORS = Path(__file__).resolve().parents
REPO_ROOT = _MODULE_ANCESTORS[min(2, len(_MODULE_ANCESTORS) - 1)]

# Ghidra installation used when the caller does not supply one; the
# GHIDRA_INSTALL_DIR environment variable overrides the built-in path.
DEFAULT_INSTALL_DIR = Path(
    os.environ.get("GHIDRA_INSTALL_DIR", r"I:\Apps\ghidra_12.0.4_PUBLIC")
)
DEFAULT_PROJECT_DIR = REPO_ROOT
DEFAULT_PROJECT_NAME = "Crusader"
DEFAULT_PROGRAM_NAME = "CRUSADER-RAW.EXE"
DEFAULT_FOLDER_PATH = "/"


@dataclass(frozen=True)
class ProjectConfig:
    """Immutable description of which Ghidra project and program to open."""

    # Ghidra installation directory handed to pyghidra.start().
    install_dir: Path = DEFAULT_INSTALL_DIR
    # Directory that contains the Ghidra project.
    project_dir: Path = DEFAULT_PROJECT_DIR
    # Name of the project inside project_dir.
    project_name: str = DEFAULT_PROJECT_NAME
    # Name of the program (domain file) to open inside the project.
    program_name: str = DEFAULT_PROGRAM_NAME
    # Project folder tried first when opening the program.
    folder_path: str = DEFAULT_FOLDER_PATH
    # Passed as the third argument of GhidraProject.openProject().
    restore_project: bool = False


def ensure_pyghidra_started(install_dir: Path | None = None):
    """Start PyGhidra once (with process output silenced) and return the module.

    Args:
        install_dir: Ghidra installation to use; DEFAULT_INSTALL_DIR when falsy.

    Returns:
        The imported ``pyghidra`` module.
    """
    import pyghidra

    resolved_dir = Path(install_dir or DEFAULT_INSTALL_DIR)
    if not pyghidra.started():
        with suppress_process_output():
            pyghidra.start(install_dir=resolved_dir)
    return pyghidra


@contextmanager
def suppress_process_output():
    """Redirect file descriptors 1 and 2 to the null device for the block.

    The redirection happens at descriptor level, so it also silences output
    written by non-Python code in this process.  The original descriptors are
    restored and the temporary duplicates closed on exit, even when saving or
    restoring one of them fails.
    """
    with open(os.devnull, "w", encoding="utf-8") as devnull:
        saved = []  # (target fd, duplicate of the original descriptor)
        try:
            for fd in (1, 2):
                saved.append((fd, os.dup(fd)))
            # Push out anything already buffered so it still reaches the
            # real streams before they are redirected.
            sys.stdout.flush()
            sys.stderr.flush()
            for fd, _ in saved:
                os.dup2(devnull.fileno(), fd)
            yield
        finally:
            try:
                # Text buffered by Python while redirected must be flushed
                # now, otherwise it would surface after the restore.
                sys.stdout.flush()
                sys.stderr.flush()
            finally:
                for fd, duplicate in saved:
                    try:
                        os.dup2(duplicate, fd)
                    finally:
                        os.close(duplicate)


def parse_address_text(address_text: str) -> int:
    """Convert address text to an integer offset.

    ``"SEG:OFF"`` is read as two hexadecimal numbers and combined as
    ``(SEG << 16) + OFF``.  Anything else goes through ``int(text, 0)``, so a
    hexadecimal value needs a ``0x`` prefix and bare digits are decimal.

    Raises:
        ValueError: If the text is not a valid number in the expected base.
    """
    text = address_text.strip()
    segment_text, separator, offset_text = text.partition(":")
    if separator:
        return (int(segment_text, 16) << 16) + int(offset_text, 16)
    return int(text, 0)


def to_address(program, address_text: str):
    """Return the address in the program's default address space for the text."""
    address_space = program.getAddressFactory().getDefaultAddressSpace()
    return address_space.getAddress(parse_address_text(address_text))


def format_address(address) -> str:
    """Return the string form of an address object."""
    return str(address)


def iter_java_items(items):
    """Yield every element of a Java-style iterator or any Python iterable.

    Objects exposing both ``hasNext`` and ``next`` are driven through those
    methods; everything else is iterated normally.
    """
    if hasattr(items, "hasNext") and hasattr(items, "next"):
        while items.hasNext():
            yield items.next()
        return
    yield from items


def format_project_error(config: ProjectConfig, exc: Exception) -> RuntimeError:
    """Build (not raise) a RuntimeError describing a failed project open.

    The message names the project and directory, includes ``str(exc)``, and
    adds a hint when ``<project_name>.lock`` exists in the project directory.
    """
    lock_path = config.project_dir / f"{config.project_name}.lock"
    details = [
        f"unable to open project '{config.project_name}' in '{config.project_dir}'",
        str(exc),
    ]
    if lock_path.exists():
        details.append(
            f"project lock present at '{lock_path}'; close Ghidra or work on a project copy for write operations"
        )
    return RuntimeError("; ".join(details))


def open_project(config: ProjectConfig):
    """Open the configured Ghidra project, starting PyGhidra if necessary.

    Raises:
        RuntimeError: If the project cannot be opened; the original exception
            is chained as the cause.
    """
    ensure_pyghidra_started(config.install_dir)
    from ghidra.base.project import GhidraProject

    try:
        return GhidraProject.openProject(
            str(config.project_dir),
            config.project_name,
            config.restore_project,
        )
    except Exception as exc:  # pragma: no cover - depends on local Ghidra state
        raise format_project_error(config, exc) from exc


def _candidate_folder_paths(folder_path: str) -> list[str]:
    """Return the requested folder followed by the fallbacks "/", "\\" and ""."""
    candidates = [folder_path]
    for fallback in ("/", "\\", ""):
        if fallback not in candidates:
            candidates.append(fallback)
    return candidates


@contextmanager
def open_program(config: ProjectConfig, read_only: bool):
    """Open the configured project and program and yield ``(project, program)``.

    The program is looked up in ``config.folder_path`` first and then in the
    fallback folders.  On exit the program is released and the project is
    closed; the project is closed even if releasing the program fails.

    Raises:
        RuntimeError: If the project or the program cannot be opened.
    """
    project = open_project(config)
    try:
        program = None
        last_error = None
        for folder_path in _candidate_folder_paths(config.folder_path):
            try:
                program = project.openProgram(folder_path, config.program_name, read_only)
            except Exception as exc:  # pragma: no cover - depends on local Ghidra state
                last_error = exc
                continue
            if program is not None:
                break
        if program is None:
            raise RuntimeError(
                f"unable to open program '{config.program_name}' from project '{config.project_name}': {last_error}"
            ) from last_error
        try:
            yield project, program
        finally:
            project.close(program)
    finally:
        # Always reached, so the project is never left open behind us.
        project.close()


@contextmanager
def transaction(program, description: str):
    """Run the block inside a program transaction.

    The transaction is ended with ``commit=True`` when the block finishes
    normally and ``commit=False`` when it raises.
    """
    transaction_id = program.startTransaction(description)
    commit = False
    try:
        yield
        commit = True
    finally:
        program.endTransaction(transaction_id, commit)


def list_root_files(project) -> list[str]:
    """Return the names of the files in the project's root folder."""
    return [domain_file.getName() for domain_file in project.getRootFolder().getFiles()]


def get_function(program, entry_text: str):
    """Return the result of ``getFunctionAt`` for the given address text."""
    return program.getFunctionManager().getFunctionAt(to_address(program, entry_text))


def get_function_containing(program, address_text: str):
    """Return the result of ``getFunctionContaining`` for the given address text."""
    return program.getFunctionManager().getFunctionContaining(to_address(program, address_text))


def read_region_bytes(program, start_text: str, end_text: str) -> bytes:
    """Read the bytes of the inclusive address range ``start..end``.

    Raises:
        ValueError: If the end address lies before the start address.
    """
    memory = program.getMemory()
    start = to_address(program, start_text)
    end = to_address(program, end_text)
    size = end.subtract(start) + 1
    # The range is inclusive, so a valid range always has at least one byte;
    # a size of zero means the end is one below the start.
    if size <= 0:
        raise ValueError(f"invalid address range: {start_text}..{end_text}")

    data = bytearray()
    current = start
    for _ in range(size):
        # Mask so that negative (signed) byte values become 0..255.
        data.append(int(memory.getByte(current)) & 0xFF)
        current = current.next()
    return bytes(data)


def iter_functions(program):
    """Return the function manager's ``getFunctions(True)`` iterator."""
    return program.getFunctionManager().getFunctions(True)


def function_signature(function) -> str:
    """Return ``function.getPrototypeString(True, True)``."""
    return function.getPrototypeString(True, True)


def function_body_range(function) -> tuple[str, str]:
    """Return the formatted minimum and maximum address of the function body."""
    body = function.getBody()
    return format_address(body.getMinAddress()), format_address(body.getMaxAddress())


def format_function_summary(function) -> str:
    """Return a four-line summary: name, signature, entry point and body range."""
    body_start, body_end = function_body_range(function)
    entry = format_address(function.getEntryPoint())
    return (
        f"Function: {function.getName()} at {entry}\n"
        f"Signature: {function_signature(function)}\n"
        f"Entry: {entry}\n"
        f"Body: {body_start} - {body_end}"
    )


def list_segments(program, offset: int = 0, limit: int = 100):
    """Return one dict per memory block, skipping ``offset`` and capped at ``limit``.

    A ``limit`` of zero or less yields an empty list.
    """
    if limit <= 0:
        return []
    segments = []
    for index, block in enumerate(program.getMemory().getBlocks()):
        if index < offset:
            continue
        segments.append(
            {
                "name": block.getName(),
                "start": format_address(block.getStart()),
                "end": format_address(block.getEnd()),
                "length": int(block.getSize()),
                "initialized": bool(block.isInitialized()),
                "read": bool(block.isRead()),
                "write": bool(block.isWrite()),
                "execute": bool(block.isExecute()),
            }
        )
        if len(segments) >= limit:
            break
    return segments


def list_data_items(program, offset: int = 0, limit: int = 100):
    """Return one dict per defined data item, skipping ``offset`` and capped at ``limit``.

    A ``limit`` of zero or less yields an empty list.
    """
    if limit <= 0:
        return []
    defined_data = program.getListing().getDefinedData(True)
    items = []
    for index, data in enumerate(iter_java_items(defined_data)):
        if index < offset:
            continue
        value = data.getValue()
        items.append(
            {
                "address": format_address(data.getAddress()),
                "length": int(data.getLength()),
                "mnemonic": data.getMnemonicString(),
                "value": None if value is None else str(value),
            }
        )
        if len(items) >= limit:
            break
    return items


def list_classes(program, offset: int = 0, limit: int = 100):
    """Return class symbols sorted by (parent, name) and paged by offset/limit.

    Each entry has ``name`` and ``parent``; ``parent`` is None when the class
    has no parent namespace or the parent is the global namespace.  A
    ``limit`` of zero or less yields an empty list and a negative ``offset``
    is treated as zero.
    """
    if limit <= 0:
        return []
    from ghidra.program.model.symbol import SymbolType

    classes = []
    for symbol in iter_java_items(program.getSymbolTable().getDefinedSymbols()):
        if symbol.getSymbolType() != SymbolType.CLASS:
            continue
        namespace = symbol.getObject()
        parent = namespace.getParentNamespace() if namespace is not None else None
        classes.append(
            {
                "name": symbol.getName(),
                "parent": None if parent is None or parent.isGlobal() else parent.getName(),
            }
        )
    # Sorting needs the full set, so paging is applied afterwards.
    classes.sort(key=lambda entry: (entry["parent"] or "", entry["name"]))
    start = max(offset, 0)
    return classes[start:start + limit]


def search_functions_by_name(program, query: str, offset: int = 0, limit: int = 100):
    """Return functions whose name contains ``query`` (case-insensitive).

    ``offset`` counts matching functions only.  A ``limit`` of zero or less
    yields an empty list.
    """
    if limit <= 0:
        return []
    lowered = query.lower()
    matches = []
    skipped = 0
    for function in iter_java_items(iter_functions(program)):
        if lowered not in function.getName().lower():
            continue
        if skipped < offset:
            skipped += 1
            continue
        matches.append(function)
        if len(matches) >= limit:
            break
    return matches


def get_functions_by_exact_name(program, name: str):
    """Return every function whose name equals ``name`` exactly."""
    return [
        function
        for function in iter_java_items(iter_functions(program))
        if function.getName() == name
    ]


def create_function(program, entry_text: str, name: str, body_start: str | None, body_end: str | None):
    """Create a user-defined function and return the function manager's result.

    ``body_start`` and ``body_end`` each fall back to ``entry_text`` when
    falsy, which gives a body covering only the entry address.
    """
    from ghidra.program.model.address import AddressSet
    from ghidra.program.model.symbol import SourceType

    entry_address = to_address(program, entry_text)
    body_start_address = to_address(program, body_start or entry_text)
    body_end_address = to_address(program, body_end or entry_text)
    body = AddressSet(body_start_address, body_end_address)
    return program.getFunctionManager().createFunction(
        name,
        entry_address,
        body,
        SourceType.USER_DEFINED,
    )


def remove_function(program, entry_text: str) -> bool:
    """Remove the function at the address and return the manager's result as bool."""
    return bool(program.getFunctionManager().removeFunction(to_address(program, entry_text)))


def rename_function(program, entry_text: str, new_name: str):
    """Rename the function at ``entry_text`` as user-defined and return it.

    Raises:
        ValueError: If there is no function at the address.
    """
    from ghidra.program.model.symbol import SourceType

    function = get_function(program, entry_text)
    if function is None:
        raise ValueError(f"no function found at {entry_text}")
    function.setName(new_name, SourceType.USER_DEFINED)
    return function


def decompile_function(program, function, timeout_seconds: int = 30) -> str:
    """Decompile ``function`` and return the C text.

    Raises:
        RuntimeError: If decompilation does not complete or yields no text.
    """
    from ghidra.app.decompiler import DecompInterface
    from ghidra.util.task import ConsoleTaskMonitor

    interface = DecompInterface()
    try:
        # Opened inside the try so the interface is disposed on any failure.
        interface.openProgram(program)
        result = interface.decompileFunction(function, timeout_seconds, ConsoleTaskMonitor())
        if not result.decompileCompleted():
            error_message = result.getErrorMessage() or "decompilation did not complete"
            raise RuntimeError(error_message)
        decompiled = result.getDecompiledFunction()
        if decompiled is None:
            raise RuntimeError("decompiler returned no function text")
        return decompiled.getC()
    finally:
        interface.dispose()


def disassemble_function(program, function) -> list[str]:
    """Return one text line per instruction in the function body.

    Each line is ``"<address>: <instruction>"``.  Call instructions that have
    outgoing references get ``" -> <target>"`` appended, using the first
    reference and including the callee name when a function starts there.
    An end-of-line comment, if present, is appended after ``" ; "``.
    """
    from ghidra.program.model.listing import CodeUnit

    listing = program.getListing()
    function_manager = program.getFunctionManager()
    lines = []
    for instruction in iter_java_items(listing.getInstructions(function.getBody(), True)):
        line = f"{format_address(instruction.getAddress())}: {instruction.toString()}"
        if instruction.getFlowType().isCall():
            references = instruction.getReferencesFrom()
            if references:
                target = references[0].getToAddress()
                target_function = function_manager.getFunctionAt(target)
                if target_function is not None:
                    line += f" -> {target_function.getName()} @ {format_address(target)}"
                else:
                    line += f" -> {format_address(target)}"
        comment = instruction.getComment(CodeUnit.EOL_COMMENT)
        if comment:
            line += f" ; {comment}"
        lines.append(line)
    return lines
def _reference_dict(reference) -> dict[str, str | int]:
    """Convert a reference object into a plain dict of its key attributes."""
    return {
        "from": format_address(reference.getFromAddress()),
        "to": format_address(reference.getToAddress()),
        "type": str(reference.getReferenceType()),
        "operand_index": int(reference.getOperandIndex()),
    }


def get_xrefs_to(program, address_text: str, offset: int = 0, limit: int = 100) -> list[dict[str, str | int]]:
    """Return references pointing to the address, paged by offset/limit.

    A ``limit`` of zero or less yields an empty list.
    """
    if limit <= 0:
        return []
    reference_manager = program.getReferenceManager()
    target_address = to_address(program, address_text)
    results = []
    references = iter_java_items(reference_manager.getReferencesTo(target_address))
    for index, reference in enumerate(references):
        if index < offset:
            continue
        results.append(_reference_dict(reference))
        if len(results) >= limit:
            break
    return results


def get_xrefs_from(program, address_text: str, offset: int = 0, limit: int = 100) -> list[dict[str, str | int]]:
    """Return references originating at the address, paged by offset/limit.

    A ``limit`` of zero or less yields an empty list.
    """
    if limit <= 0:
        return []
    reference_manager = program.getReferenceManager()
    source_address = to_address(program, address_text)
    results = []
    references = iter_java_items(reference_manager.getReferencesFrom(source_address))
    for index, reference in enumerate(references):
        if index < offset:
            continue
        results.append(_reference_dict(reference))
        if len(results) >= limit:
            break
    return results


def list_strings(program, offset: int = 0, limit: int = 2000, filter_text: str | None = None):
    """Return defined data items that have a string value.

    Args:
        offset: Number of matching strings to skip.
        limit: Maximum number of entries; zero or less yields an empty list.
        filter_text: Optional case-insensitive substring the text must contain.

    Returns:
        A list of dicts with ``address``, ``length`` and ``text``.
    """
    if limit <= 0:
        return []
    lowered_filter = filter_text.lower() if filter_text else None
    matches = []
    skipped = 0
    for data in iter_java_items(program.getListing().getDefinedData(True)):
        if not data.hasStringValue():
            continue
        text = str(data.getValue())
        if lowered_filter and lowered_filter not in text.lower():
            continue
        if skipped < offset:
            skipped += 1
            continue
        matches.append(
            {
                "address": format_address(data.getAddress()),
                "length": int(data.getLength()),
                "text": text,
            }
        )
        if len(matches) >= limit:
            break
    return matches


def list_imports(program, offset: int = 0, limit: int = 100):
    """Return external locations across all external libraries.

    ``offset`` and ``limit`` apply to the combined sequence over every
    library.  A ``limit`` of zero or less yields an empty list.
    """
    if limit <= 0:
        return []
    external_manager = program.getExternalManager()
    matches = []
    skipped = 0
    for library_name in external_manager.getExternalLibraryNames():
        for location in iter_java_items(external_manager.getExternalLocations(library_name)):
            if skipped < offset:
                skipped += 1
                continue
            label = location.getLabel()
            address = location.getAddress()
            matches.append(
                {
                    "library": str(library_name),
                    "label": str(label) if label is not None else None,
                    "address": format_address(address) if address is not None else None,
                }
            )
            if len(matches) >= limit:
                return matches
    return matches


def list_exports(program, offset: int = 0, limit: int = 100):
    """Return the program's external entry points, paged by offset/limit.

    Each entry carries ``address``, ``name`` and ``kind``.  When a function
    starts at the address its name is used and the kind is ``"function"``;
    otherwise the primary symbol supplies both, and with no symbol the name
    is None and the kind ``"unknown"``.  A ``limit`` of zero or less yields
    an empty list.
    """
    if limit <= 0:
        return []
    symbol_table = program.getSymbolTable()
    function_manager = program.getFunctionManager()
    matches = []
    entry_points = iter_java_items(symbol_table.getExternalEntryPointIterator())
    for index, address in enumerate(entry_points):
        if index < offset:
            continue
        function = function_manager.getFunctionAt(address)
        primary_symbol = symbol_table.getPrimarySymbol(address)
        if function is not None:
            name = function.getName()
            kind = "function"
        elif primary_symbol is not None:
            name = primary_symbol.getName()
            kind = str(primary_symbol.getSymbolType())
        else:
            name = None
            kind = "unknown"
        matches.append(
            {
                "address": format_address(address),
                "name": name,
                "kind": kind,
            }
        )
        if len(matches) >= limit:
            break
    return matches


def list_namespaces(program, offset: int = 0, limit: int = 100):
    """Return namespace, class and library symbols, paged by offset/limit.

    ``parent`` is None when there is no parent namespace or the parent is the
    global namespace.  A ``limit`` of zero or less yields an empty list.
    """
    if limit <= 0:
        return []
    from ghidra.program.model.symbol import SymbolType

    wanted_types = (SymbolType.NAMESPACE, SymbolType.CLASS, SymbolType.LIBRARY)
    matches = []
    skipped = 0
    for symbol in iter_java_items(program.getSymbolTable().getDefinedSymbols()):
        symbol_type = symbol.getSymbolType()
        if symbol_type not in wanted_types:
            continue
        namespace = symbol.getObject()
        parent = namespace.getParentNamespace() if namespace is not None else None
        parent_name = None if parent is None or parent.isGlobal() else parent.getName()
        if skipped < offset:
            skipped += 1
            continue
        matches.append(
            {
                "name": symbol.getName(),
                "type": str(symbol_type),
                "parent": parent_name,
            }
        )
        if len(matches) >= limit:
            break
    return matches


def run_script_file(script_path: Path, globals_dict: dict):
    """Execute a Python script file and return the namespace it ran in.

    The script runs in a copy of ``globals_dict``; ``__name__`` defaults to
    ``"__main__"`` and ``__file__`` to the script path unless already set.

    NOTE(review): the file content is passed to ``exec`` unchecked, so only
    trusted scripts should be run through this function.
    """
    script_globals = dict(globals_dict)
    script_globals.setdefault("__name__", "__main__")
    script_globals.setdefault("__file__", str(script_path))
    code = compile(script_path.read_text(encoding="utf-8"), str(script_path), "exec")
    exec(code, script_globals, script_globals)
    return script_globals


def set_comment(program, address_text: str, comment: str, comment_type: str):
    """Set a comment on the code unit at the address.

    Args:
        comment_type: One of "pre", "plate", "eol", "repeatable" or "post".

    When no code unit starts at the address but a function does, the comment
    is set on the function instead and ``comment_type`` is not used.

    Raises:
        ValueError: For an unsupported comment type, or when neither a code
            unit nor a function exists at the address.
    """
    from ghidra.program.model.listing import CodeUnit

    comment_types = {
        "pre": CodeUnit.PRE_COMMENT,
        "plate": CodeUnit.PLATE_COMMENT,
        "eol": CodeUnit.EOL_COMMENT,
        "repeatable": CodeUnit.REPEATABLE_COMMENT,
        "post": CodeUnit.POST_COMMENT,
    }
    if comment_type not in comment_types:
        raise ValueError(f"unsupported comment type: {comment_type}")

    target_address = to_address(program, address_text)
    code_unit = program.getListing().getCodeUnitAt(target_address)
    if code_unit is not None:
        code_unit.setComment(comment_types[comment_type], comment)
        return

    function = program.getFunctionManager().getFunctionAt(target_address)
    if function is None:
        raise ValueError(f"no code unit or function found at {address_text}")
    function.setComment(comment)


def save_program(project, program):
    """Save the program through its project."""
    project.save(program)