Crusader_Decomp/tools/pyghidra_crusader/common.py
MaddoScientisto a56851f994 Enhance CLI functionality and improve common utilities
- Added new commands to the CLI for dumping regions, renaming functions by address, and setting various types of comments.
- Implemented JSON output formatting for CLI commands.
- Introduced functions for decompiling and disassembling functions, as well as retrieving cross-references.
- Enhanced common utilities with functions for reading memory regions, iterating Java items, and managing function metadata.
- Added suppress_output context manager to hide process output during Ghidra startup.
- Updated existing functions to improve error handling and output formatting.
2026-03-21 09:44:35 +01:00

547 lines
No EOL
18 KiB
Python

from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
import os
import sys
# Repository root: two directories above the one containing this module.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Ghidra installation; the GHIDRA_INSTALL_DIR environment variable overrides
# the hard-coded developer default below.
DEFAULT_INSTALL_DIR = Path(os.getenv("GHIDRA_INSTALL_DIR", r"I:\Apps\ghidra_11.3.2_PUBLIC"))
# The Ghidra project is expected to live directly in the repository root.
DEFAULT_PROJECT_DIR = REPO_ROOT
DEFAULT_PROJECT_NAME = "Crusader"
DEFAULT_PROGRAM_NAME = "CRUSADER-RAW.EXE"
# Project folder searched first when opening the program.
DEFAULT_FOLDER_PATH = "/"
@dataclass(frozen=True)
class ProjectConfig:
    """Immutable description of which Ghidra project and program to open.

    Every field defaults to the matching module-level ``DEFAULT_*`` value,
    so ``ProjectConfig()`` targets the default Crusader project.
    """

    # Ghidra installation handed to ``ensure_pyghidra_started`` by ``open_project``.
    install_dir: Path = DEFAULT_INSTALL_DIR
    # Directory passed (as a string) to ``GhidraProject.openProject``.
    project_dir: Path = DEFAULT_PROJECT_DIR
    # Project name; also used to look for ``<project_name>.lock`` when reporting open failures.
    project_name: str = DEFAULT_PROJECT_NAME
    # Name of the program (domain file) opened inside the project.
    program_name: str = DEFAULT_PROGRAM_NAME
    # Project folder tried first; ``open_program`` then falls back to "/", "\\" and "".
    folder_path: str = DEFAULT_FOLDER_PATH
    # Forwarded as the third argument (restoreProject) of ``GhidraProject.openProject``.
    restore_project: bool = False
def ensure_pyghidra_started(install_dir: Path | None = None):
    """Start the pyghidra JVM once, quietly, and return the ``pyghidra`` module.

    ``install_dir`` falls back to ``DEFAULT_INSTALL_DIR`` when it is ``None``
    (or otherwise falsy). Process-level stdout/stderr are silenced during
    startup; if pyghidra is already running this just returns the module.
    """
    import pyghidra

    target_dir = Path(install_dir if install_dir else DEFAULT_INSTALL_DIR)
    if pyghidra.started():
        return pyghidra
    with suppress_process_output():
        pyghidra.start(install_dir=target_dir)
    return pyghidra
@contextmanager
def suppress_process_output():
    """Redirect file descriptors 1 and 2 to the null device for the duration of the block.

    Working at the fd level also silences output written by native code or a JVM,
    not just Python's ``sys.stdout``/``sys.stderr``. The original descriptors are
    restored on exit, even if the block raises.
    """

    def _flush_std_streams():
        # sys.stdout / sys.stderr can be None (e.g. under pythonw).
        for stream in (sys.stdout, sys.stderr):
            if stream is not None:
                stream.flush()

    with open(os.devnull, "w", encoding="utf-8") as devnull:
        saved_stdout_fd = os.dup(1)
        try:
            saved_stderr_fd = os.dup(2)
        except OSError:
            # Do not leak the first duplicate if the second one fails.
            os.close(saved_stdout_fd)
            raise
        try:
            # Push out anything already buffered so it still reaches the real outputs.
            _flush_std_streams()
            os.dup2(devnull.fileno(), 1)
            os.dup2(devnull.fileno(), 2)
            yield
        finally:
            try:
                # Flush while fds 1/2 still point at devnull; otherwise Python-level
                # output buffered inside the block would be emitted after restoring.
                _flush_std_streams()
            finally:
                os.dup2(saved_stdout_fd, 1)
                os.dup2(saved_stderr_fd, 2)
                os.close(saved_stdout_fd)
                os.close(saved_stderr_fd)
def parse_address_text(address_text: str) -> int:
    """Parse user-supplied address text into an integer offset.

    ``"SEG:OFF"`` (both parts hexadecimal, split at the first colon) is combined
    as ``(SEG << 16) + OFF``. NOTE(review): this is not the real-mode
    ``SEG * 16 + OFF`` linear formula; presumably it matches how this project's
    program is laid out — confirm.

    Any other text goes through ``int(text, 0)``: ``0x``/``0o``/``0b`` prefixes
    are honoured, unprefixed digits are decimal, and zero-padded numbers such as
    ``00401000`` are rejected with ``ValueError``.
    """
    cleaned = address_text.strip()
    segment_part, separator, offset_part = cleaned.partition(":")
    if not separator:
        return int(cleaned, 0)
    return (int(segment_part, 16) << 16) + int(offset_part, 16)
def to_address(program, address_text: str):
    """Convert address text (see ``parse_address_text``) to an ``Address`` in the program's default space."""
    factory = program.getAddressFactory()
    default_space = factory.getDefaultAddressSpace()
    offset = parse_address_text(address_text)
    return default_space.getAddress(offset)
def format_address(address) -> str:
    """Render an address (any object) as text using its ``str()`` form."""
    rendered = str(address)
    return rendered
def iter_java_items(items):
    """Yield the elements of a Java-style iterator or an ordinary Python iterable.

    Objects exposing both ``hasNext`` and ``next`` are drained through those
    methods; everything else is iterated with Python's iteration protocol.
    """
    looks_like_java_iterator = hasattr(items, "hasNext") and hasattr(items, "next")
    if not looks_like_java_iterator:
        yield from items
        return
    while items.hasNext():
        yield items.next()
def format_project_error(config: ProjectConfig, exc: Exception) -> RuntimeError:
    """Build (without raising) a RuntimeError explaining a failed project open.

    The message names the project and directory, then includes ``str(exc)``.
    If ``<project_name>.lock`` exists in ``project_dir``, a hint about the lock
    is appended. Parts are separated by ``"; "``.
    """
    lock_file = config.project_dir / f"{config.project_name}.lock"
    message = f"unable to open project '{config.project_name}' in '{config.project_dir}'; {exc!s}"
    if lock_file.exists():
        message += (
            f"; project lock present at '{lock_file}'; "
            "close Ghidra or work on a project copy for write operations"
        )
    return RuntimeError(message)
def open_project(config: ProjectConfig):
    """Open and return the ``GhidraProject`` described by ``config``.

    Starts pyghidra first if needed. Any exception from
    ``GhidraProject.openProject`` is replaced by the RuntimeError built by
    ``format_project_error``, chained to the original.
    """
    ensure_pyghidra_started(config.install_dir)
    from ghidra.base.project import GhidraProject

    try:
        project = GhidraProject.openProject(
            str(config.project_dir),
            config.project_name,
            config.restore_project,
        )
    except Exception as exc:  # pragma: no cover - depends on local Ghidra state
        raise format_project_error(config, exc) from exc
    return project
def _candidate_folder_paths(folder_path: str) -> list[str]:
candidates = [folder_path]
for fallback in ("/", "\\", ""):
if fallback not in candidates:
candidates.append(fallback)
return candidates
@contextmanager
def open_program(config: ProjectConfig, read_only: bool):
    """Open ``config.program_name`` and yield ``(project, program)``.

    ``config.folder_path`` is tried first, then the fallbacks from
    ``_candidate_folder_paths``. On exit the program and then the project are
    closed; this context manager never saves, so call ``save_program`` inside
    the block to keep changes.

    Raises:
        RuntimeError: if the project cannot be opened (from ``open_project``) or the
            program cannot be opened from any candidate folder. The message lists the
            failure for every folder tried.
    """
    project = open_project(config)
    program = None
    try:
        failures = []
        for folder_path in _candidate_folder_paths(config.folder_path):
            try:
                program = project.openProgram(folder_path, config.program_name, read_only)
            except Exception as exc:  # pragma: no cover - depends on local Ghidra state
                failures.append((folder_path, exc))
            else:
                break
        if program is None:
            # Report every attempt: the first folder's error is usually the informative
            # one, while later fallbacks typically just say "not found".
            details = "; ".join(f"{folder!r}: {error}" for folder, error in failures)
            raise RuntimeError(
                f"unable to open program '{config.program_name}' from project '{config.project_name}': {details}"
            ) from (failures[-1][1] if failures else None)
        yield project, program
    finally:
        if project is not None:
            try:
                if program is not None:
                    project.close(program)
            finally:
                # Always release the project (and its lock), even if closing the program failed.
                project.close()
@contextmanager
def transaction(program, description: str):
    """Wrap the block in a Ghidra transaction.

    The transaction is committed if the block finishes normally and rolled back
    (ended with ``commit=False``) if it raises; the exception is then re-raised.
    """
    tx_id = program.startTransaction(description)
    try:
        yield
    except BaseException:
        program.endTransaction(tx_id, False)
        raise
    program.endTransaction(tx_id, True)
def list_root_files(project) -> list[str]:
    """Return the names of the domain files directly inside the project's root folder."""
    root_folder = project.getRootFolder()
    return [entry.getName() for entry in root_folder.getFiles()]
def get_function(program, entry_text: str):
    """Return the function whose entry point is exactly at ``entry_text``, or ``None``."""
    function_manager = program.getFunctionManager()
    entry_address = to_address(program, entry_text)
    return function_manager.getFunctionAt(entry_address)
def get_function_containing(program, address_text: str):
    """Return the function whose body contains ``address_text``, or ``None``."""
    function_manager = program.getFunctionManager()
    probe = to_address(program, address_text)
    return function_manager.getFunctionContaining(probe)
def read_region_bytes(program, start_text: str, end_text: str) -> bytes:
    """Read the inclusive address range ``start_text..end_text`` from program memory.

    Both bounds are parsed with ``to_address``. Bytes are fetched one at a time
    with ``Memory.getByte`` and masked to 0..255, because Java bytes are signed.

    Raises:
        ValueError: if the end address lies before the start address.
    """
    memory = program.getMemory()
    start = to_address(program, start_text)
    end = to_address(program, end_text)
    # subtract() yields end - start. Test the distance itself: checking the size
    # (distance + 1) for < 0 let end == start - 1 slip through as an empty read.
    distance = end.subtract(start)
    if distance < 0:
        raise ValueError(f"invalid address range: {start_text}..{end_text}")
    size = distance + 1
    data = bytearray(size)
    current = start
    for index in range(size):
        data[index] = int(memory.getByte(current)) & 0xFF
        current = current.next()
    return bytes(data)
def iter_functions(program):
    """Return the program's functions iterator, in forward address order."""
    forward = True
    function_manager = program.getFunctionManager()
    return function_manager.getFunctions(forward)
def function_signature(function) -> str:
    """Return the function's prototype string, formal form, including the calling convention."""
    formal_signature = True
    include_calling_convention = True
    return function.getPrototypeString(formal_signature, include_calling_convention)
def function_body_range(function) -> tuple[str, str]:
    """Return the lowest and highest addresses of the function body as text."""
    body = function.getBody()
    lowest = body.getMinAddress()
    highest = body.getMaxAddress()
    return (format_address(lowest), format_address(highest))
def format_function_summary(function) -> str:
    """Return a four-line human-readable summary.

    The lines give the name/entry point, the signature, the entry point again,
    and the body's address range.
    """
    body_start, body_end = function_body_range(function)
    entry_text = format_address(function.getEntryPoint())
    summary_lines = [
        f"Function: {function.getName()} at {entry_text}",
        f"Signature: {function_signature(function)}",
        f"Entry: {entry_text}",
        f"Body: {body_start} - {body_end}",
    ]
    return "\n".join(summary_lines)
def list_segments(program, offset: int = 0, limit: int = 100):
    """Describe the program's memory blocks, paginated.

    Blocks are visited in the order ``Memory.getBlocks()`` returns them.

    Args:
        program: Ghidra program whose memory blocks are listed.
        offset: number of leading blocks to skip.
        limit: maximum number of entries to return; ``<= 0`` returns ``[]``.

    Returns:
        A list of dicts with keys:
        - ``name``
        - ``start`` and ``end``: address text
        - ``length``: ``block.getSize()``
        - ``initialized``, ``read``, ``write``, ``execute``: Python bools
    """
    # The limit is only checked after appending, so without this guard a
    # non-positive limit still returned one block.
    if limit <= 0:
        return []
    memory = program.getMemory()
    matches = []
    skipped = 0
    for block in memory.getBlocks():
        if skipped < offset:
            skipped += 1
            continue
        matches.append(
            {
                "name": block.getName(),
                "start": format_address(block.getStart()),
                "end": format_address(block.getEnd()),
                "length": int(block.getSize()),
                "initialized": bool(block.isInitialized()),
                "read": bool(block.isRead()),
                "write": bool(block.isWrite()),
                "execute": bool(block.isExecute()),
            }
        )
        if len(matches) >= limit:
            break
    return matches
def list_data_items(program, offset: int = 0, limit: int = 100):
    """List the program's defined data items in forward address order, paginated.

    Args:
        program: Ghidra program whose listing is scanned.
        offset: number of leading data items to skip.
        limit: maximum number of entries to return; ``<= 0`` returns ``[]``.

    Returns:
        A list of dicts with keys:
        - ``address``
        - ``length``
        - ``mnemonic``
        - ``value``: ``str(value)``, or ``None`` when the item has no value
    """
    # The limit is only checked after appending, so a non-positive limit must be
    # handled up front (previously it still returned one item).
    if limit <= 0:
        return []
    listing = program.getListing()
    matches = []
    skipped = 0
    for data in iter_java_items(listing.getDefinedData(True)):
        if skipped < offset:
            skipped += 1
            continue
        value = data.getValue()
        matches.append(
            {
                "address": format_address(data.getAddress()),
                "length": int(data.getLength()),
                "mnemonic": data.getMnemonicString(),
                "value": None if value is None else str(value),
            }
        )
        if len(matches) >= limit:
            break
    return matches
def list_classes(program, offset: int = 0, limit: int = 100):
    """List class namespaces sorted by ``(parent, name)``, paginated.

    Each entry has ``name`` and ``parent``. ``parent`` is ``None`` when the
    class sits in the global namespace or has no namespace object.

    Args:
        program: Ghidra program whose symbol table is scanned.
        offset: number of leading (sorted) entries to skip; negative values skip nothing.
        limit: maximum number of entries to return; ``<= 0`` returns ``[]``.
    """
    # Normalise pagination like the streaming list_* helpers. Raw slicing would
    # treat a negative offset as "count from the end" and a negative limit as
    # "drop trailing items".
    if limit <= 0:
        return []
    start = max(offset, 0)

    from ghidra.program.model.symbol import SymbolType

    symbol_table = program.getSymbolTable()
    matches = []
    for symbol in iter_java_items(symbol_table.getDefinedSymbols()):
        if symbol.getSymbolType() != SymbolType.CLASS:
            continue
        namespace = symbol.getObject()
        parent = namespace.getParentNamespace() if namespace is not None else None
        matches.append(
            {
                "name": symbol.getName(),
                "parent": None if parent is None or parent.isGlobal() else parent.getName(),
            }
        )
    # Sorting needs every class, so pagination is applied afterwards by slicing.
    matches.sort(key=lambda entry: (entry["parent"] or "", entry["name"]))
    return matches[start: start + limit]
def search_functions_by_name(program, query: str, offset: int = 0, limit: int = 100):
    """Return functions whose name contains ``query`` (case-insensitive), paginated.

    Matching uses ``str.lower()`` on both sides. ``offset`` counts matching
    functions only. Function objects are returned in forward iteration order.

    Args:
        program: Ghidra program to search.
        query: substring to look for; ``""`` matches every function.
        offset: number of leading matches to skip.
        limit: maximum number of functions to return; ``<= 0`` returns ``[]``.
    """
    # The limit is only checked after appending, so a non-positive limit must be
    # handled up front (previously it still returned one match).
    if limit <= 0:
        return []
    lowered = query.lower()
    matches = []
    skipped = 0
    for function in iter_java_items(iter_functions(program)):
        if lowered not in function.getName().lower():
            continue
        if skipped < offset:
            skipped += 1
            continue
        matches.append(function)
        if len(matches) >= limit:
            break
    return matches
def get_functions_by_exact_name(program, name: str):
    """Return every function whose name equals ``name`` exactly (case-sensitive), in iteration order."""
    return [
        candidate
        for candidate in iter_java_items(iter_functions(program))
        if candidate.getName() == name
    ]
def create_function(program, entry_text: str, name: str, body_start: str | None, body_end: str | None):
    """Create a USER_DEFINED function named ``name`` at ``entry_text``.

    The body is the single address range ``body_start..body_end``. Either bound
    that is omitted (``None`` or empty) defaults to the entry address, so with
    no bounds the body is just the entry address. Returns the result of
    ``FunctionManager.createFunction``.
    """
    from ghidra.program.model.address import AddressSet
    from ghidra.program.model.symbol import SourceType

    first_text = body_start if body_start else entry_text
    last_text = body_end if body_end else entry_text
    entry_address = to_address(program, entry_text)
    body = AddressSet(to_address(program, first_text), to_address(program, last_text))
    function_manager = program.getFunctionManager()
    return function_manager.createFunction(name, entry_address, body, SourceType.USER_DEFINED)
def remove_function(program, entry_text: str) -> bool:
    """Remove the function whose entry point is ``entry_text``; return whether Ghidra reported a removal."""
    function_manager = program.getFunctionManager()
    removed = function_manager.removeFunction(to_address(program, entry_text))
    return bool(removed)
def rename_function(program, entry_text: str, new_name: str):
    """Rename the function at ``entry_text`` (source type USER_DEFINED) and return it.

    Raises:
        ValueError: if no function starts at ``entry_text``.
    """
    from ghidra.program.model.symbol import SourceType

    target = get_function(program, entry_text)
    if target is not None:
        target.setName(new_name, SourceType.USER_DEFINED)
        return target
    raise ValueError(f"no function found at {entry_text}")
def decompile_function(program, function, timeout_seconds: int = 30) -> str:
    """Decompile ``function`` and return its C text.

    A fresh ``DecompInterface`` is opened for ``program`` and is always
    disposed, even when opening or decompiling fails.

    Raises:
        RuntimeError: if any of the following happens:
            - the decompiler cannot be opened for the program;
            - ``decompileCompleted()`` is false (e.g. on timeout);
            - no decompiled function is returned.
    """
    from ghidra.app.decompiler import DecompInterface
    from ghidra.util.task import ConsoleTaskMonitor

    interface = DecompInterface()
    try:
        # openProgram reports failure through its boolean result rather than an
        # exception, so check it explicitly and surface the decompiler's own message.
        # It sits inside the try so dispose() still runs if it raises.
        if not interface.openProgram(program):
            raise RuntimeError(interface.getLastMessage() or "unable to open program in decompiler")
        result = interface.decompileFunction(function, timeout_seconds, ConsoleTaskMonitor())
        if not result.decompileCompleted():
            error_message = result.getErrorMessage() or "decompilation did not complete"
            raise RuntimeError(error_message)
        decompiled = result.getDecompiledFunction()
        if decompiled is None:
            raise RuntimeError("decompiler returned no function text")
        return decompiled.getC()
    finally:
        interface.dispose()
def disassemble_function(program, function) -> list[str]:
    """Return one text line per instruction in ``function``'s body, in forward order.

    Each line is ``"<address>: <instruction>"``. Call instructions with a known
    destination get ``" -> <name> @ <address>"``, or ``" -> <address>"`` when
    no function starts there. An EOL comment, if any, is appended as
    ``" ; <comment>"``.
    """
    from ghidra.program.model.listing import CodeUnit

    listing = program.getListing()
    function_manager = program.getFunctionManager()  # hoisted out of the loop
    lines = []
    for instruction in iter_java_items(listing.getInstructions(function.getBody(), True)):
        line = f"{format_address(instruction.getAddress())}: {instruction.toString()}"
        if instruction.getFlowType().isCall():
            target = _call_target(instruction)
            if target is not None:
                target_function = function_manager.getFunctionAt(target)
                if target_function is not None:
                    line += f" -> {target_function.getName()} @ {format_address(target)}"
                else:
                    line += f" -> {format_address(target)}"
        comment = instruction.getComment(CodeUnit.EOL_COMMENT)
        if comment:
            line += f" ; {comment}"
        lines.append(line)
    return lines


def _call_target(instruction):
    """Return the destination address of a call instruction's references, or ``None``.

    References whose type is a call are preferred. Indirect calls may also carry
    non-call references (e.g. a data read of the pointer operand), and blindly
    taking the first reference could report that instead. When no call-typed
    reference exists, the first reference is used (the previous behaviour).
    """
    references = list(instruction.getReferencesFrom())
    if not references:
        return None
    for reference in references:
        if reference.getReferenceType().isCall():
            return reference.getToAddress()
    return references[0].getToAddress()
def _reference_dict(reference) -> dict[str, str | int]:
    """Flatten a Ghidra reference into a JSON-friendly dict.

    Keys are ``from``, ``to``, ``type`` (the reference type's ``str()``) and
    ``operand_index``.
    """
    entry = {}
    entry["from"] = format_address(reference.getFromAddress())
    entry["to"] = format_address(reference.getToAddress())
    entry["type"] = str(reference.getReferenceType())
    entry["operand_index"] = int(reference.getOperandIndex())
    return entry
def get_xrefs_to(program, address_text: str, offset: int = 0, limit: int = 100) -> list[dict[str, str | int]]:
reference_manager = program.getReferenceManager()
target_address = to_address(program, address_text)
results = []
skipped = 0
for reference in iter_java_items(reference_manager.getReferencesTo(target_address)):
if skipped < offset:
skipped += 1
continue
results.append(_reference_dict(reference))
if len(results) >= limit:
break
return results
def get_xrefs_from(program, address_text: str, offset: int = 0, limit: int = 100) -> list[dict[str, str | int]]:
reference_manager = program.getReferenceManager()
source_address = to_address(program, address_text)
results = []
skipped = 0
for reference in iter_java_items(reference_manager.getReferencesFrom(source_address)):
if skipped < offset:
skipped += 1
continue
results.append(_reference_dict(reference))
if len(results) >= limit:
break
return results
def list_strings(program, offset: int = 0, limit: int = 2000, filter_text: str | None = None):
listing = program.getListing()
matches = []
skipped = 0
lowered_filter = filter_text.lower() if filter_text else None
for data in iter_java_items(listing.getDefinedData(True)):
if not data.hasStringValue():
continue
text = str(data.getValue())
if lowered_filter and lowered_filter not in text.lower():
continue
if skipped < offset:
skipped += 1
continue
matches.append(
{
"address": format_address(data.getAddress()),
"length": int(data.getLength()),
"text": text,
}
)
if len(matches) >= limit:
break
return matches
def list_imports(program, offset: int = 0, limit: int = 100):
    """List external locations grouped by external library, paginated.

    ``offset`` and ``limit`` apply across all libraries combined.

    Args:
        program: Ghidra program whose external manager is queried.
        offset: number of leading external locations to skip.
        limit: maximum number of entries to return; ``<= 0`` returns ``[]``.

    Returns:
        A list of dicts with ``library``, ``label`` (or ``None``) and
        ``address`` (or ``None``).
    """
    # The limit is only checked after appending, so a non-positive limit must be
    # handled up front (previously it still returned one import).
    if limit <= 0:
        return []
    external_manager = program.getExternalManager()
    matches = []
    skipped = 0
    for library_name in external_manager.getExternalLibraryNames():
        for location in iter_java_items(external_manager.getExternalLocations(library_name)):
            if skipped < offset:
                skipped += 1
                continue
            label = location.getLabel()
            address = location.getAddress()
            matches.append(
                {
                    "library": str(library_name),
                    "label": str(label) if label is not None else None,
                    "address": format_address(address) if address is not None else None,
                }
            )
            if len(matches) >= limit:
                return matches
    return matches
def list_exports(program, offset: int = 0, limit: int = 100):
    """List the program's external entry points, paginated.

    Args:
        program: Ghidra program whose symbol table is queried.
        offset: number of leading entry points to skip.
        limit: maximum number of entries to return; ``<= 0`` returns ``[]``.

    Returns:
        A list of dicts with keys:
        - ``address``
        - ``name``: the function name, else the primary symbol's name, else ``None``
        - ``kind``: ``"function"``, else the primary symbol's type, else ``"unknown"``
    """
    # The limit is only checked after appending, so a non-positive limit must be
    # handled up front (previously it still returned one export).
    if limit <= 0:
        return []
    symbol_table = program.getSymbolTable()
    function_manager = program.getFunctionManager()
    matches = []
    skipped = 0
    for address in iter_java_items(symbol_table.getExternalEntryPointIterator()):
        if skipped < offset:
            skipped += 1
            continue
        function = function_manager.getFunctionAt(address)
        primary_symbol = symbol_table.getPrimarySymbol(address)
        if function is not None:
            name, kind = function.getName(), "function"
        elif primary_symbol is not None:
            name, kind = primary_symbol.getName(), str(primary_symbol.getSymbolType())
        else:
            name, kind = None, "unknown"
        matches.append(
            {
                "address": format_address(address),
                "name": name,
                "kind": kind,
            }
        )
        if len(matches) >= limit:
            break
    return matches
def list_namespaces(program, offset: int = 0, limit: int = 100):
    """List namespace, class and library symbols in symbol-table order, paginated.

    Args:
        program: Ghidra program whose symbol table is scanned.
        offset: number of leading matching symbols to skip.
        limit: maximum number of entries to return; ``<= 0`` returns ``[]``.

    Returns:
        A list of dicts with keys:
        - ``name``
        - ``type``: ``str()`` of the symbol type
        - ``parent``: ``None`` when the parent is global or unknown
    """
    # The limit is only checked after appending, so a non-positive limit must be
    # handled up front (previously it still returned one namespace).
    if limit <= 0:
        return []

    from ghidra.program.model.symbol import SymbolType

    namespace_types = (SymbolType.NAMESPACE, SymbolType.CLASS, SymbolType.LIBRARY)
    symbol_table = program.getSymbolTable()
    matches = []
    skipped = 0
    for symbol in iter_java_items(symbol_table.getDefinedSymbols()):
        symbol_type = symbol.getSymbolType()
        if symbol_type not in namespace_types:
            continue
        # Skip before resolving the parent so offset entries cost no extra lookups.
        if skipped < offset:
            skipped += 1
            continue
        namespace = symbol.getObject()
        parent = namespace.getParentNamespace() if namespace is not None else None
        parent_name = None if parent is None or parent.isGlobal() else parent.getName()
        matches.append(
            {
                "name": symbol.getName(),
                "type": str(symbol_type),
                "parent": parent_name,
            }
        )
        if len(matches) >= limit:
            break
    return matches
def run_script_file(script_path: Path, globals_dict: dict):
    """Execute the UTF-8 Python file at ``script_path`` and return its namespace.

    The script runs in a shallow copy of ``globals_dict``; the caller's dict is
    not modified. ``__name__`` defaults to ``"__main__"`` and ``__file__`` to
    the script path unless ``globals_dict`` already provides them. The same
    dict serves as globals and locals, so top-level names defined by the script
    appear in the returned mapping.

    NOTE: this executes arbitrary code; only run trusted scripts.
    """
    namespace = dict(globals_dict)
    for key, fallback in (("__name__", "__main__"), ("__file__", str(script_path))):
        namespace.setdefault(key, fallback)
    source_text = script_path.read_text(encoding="utf-8")
    compiled = compile(source_text, str(script_path), "exec")
    exec(compiled, namespace, namespace)
    return namespace
def set_comment(program, address_text: str, comment: str, comment_type: str):
    """Set a comment of kind ``comment_type`` at ``address_text``.

    ``comment_type`` is one of ``"pre"``, ``"plate"``, ``"eol"``,
    ``"repeatable"`` or ``"post"``. If a code unit starts at the address, its
    comment of that kind is set.

    Otherwise the comment goes on a function whose entry point is at that
    address (e.g. a function with no code unit there):
    - ``"repeatable"`` uses ``Function.setRepeatableComment``;
    - every other kind uses ``Function.setComment``, as before.

    Raises:
        ValueError: if ``comment_type`` is unknown, or neither a code unit nor a
            function exists at the address.
    """
    # Map names to CodeUnit constant names so the argument can be validated
    # before Ghidra classes are imported.
    comment_attr_names = {
        "pre": "PRE_COMMENT",
        "plate": "PLATE_COMMENT",
        "eol": "EOL_COMMENT",
        "repeatable": "REPEATABLE_COMMENT",
        "post": "POST_COMMENT",
    }
    if comment_type not in comment_attr_names:
        raise ValueError(f"unsupported comment type: {comment_type}")

    from ghidra.program.model.listing import CodeUnit

    listing = program.getListing()
    target_address = to_address(program, address_text)
    code_unit = listing.getCodeUnitAt(target_address)
    if code_unit is None:
        function = program.getFunctionManager().getFunctionAt(target_address)
        if function is not None:
            # Honour a repeatable request instead of silently writing a plate comment.
            if comment_type == "repeatable":
                function.setRepeatableComment(comment)
            else:
                function.setComment(comment)
            return
        raise ValueError(f"no code unit or function found at {address_text}")
    code_unit.setComment(getattr(CodeUnit, comment_attr_names[comment_type]), comment)
def save_program(project, program):
    """Persist ``program`` through the owning project's ``save`` method; returns ``None``."""
    persist = project.save
    persist(program)