# Crusader_Decomp/_tmp_psx_gpu_search.py
#
# Repository-viewer metadata captured with this file: Python, 314 lines, 12 KiB,
# last changed 2026-03-30 00:19:01 +02:00.
from __future__ import annotations
import bisect
import json
import struct
import sys
from pathlib import Path
# ---------------------------------------------------------------------------
# Input locations (absolute paths on the author's machine).
# ---------------------------------------------------------------------------
ROOT = Path(r"k:/ghidra/Crusader_Decomp")
GPU_PATH = ROOT / "binary/Crusader - No Remorse (USA) GPU RAM.bin"
L0_WDL_PATH = Path(r"e:/emu/psx/Crusader - No Remorse/LSET1/L0.WDL")

# The exported sprite bundle whose frame 0 is searched for in VRAM.
BUNDLE_DIR = ROOT / "out/psx_wdl/L0/sprite_bundles/bundle_000A1B04"
BUNDLE_JSON = BUNDLE_DIR / "bundle.json"
FRAME_PATH = BUNDLE_DIR / "frame_000.bin"

# ---------------------------------------------------------------------------
# VRAM dump geometry: each row is unpacked elsewhere as 1024 little-endian
# 16-bit words, i.e. 2048 bytes.
# ---------------------------------------------------------------------------
ROW_BYTES = 1024 * 2
GPU_ROWS = 512

# Displayed area copied out of the top-left corner of VRAM, in pixels.
FRAMEBUFFER_WIDTH, FRAMEBUFFER_HEIGHT = 320, 240

# Report sizes: hits listed per category, and CLUT candidates kept in the ranking.
TOP_N = 10
MATCH_TOP_N = 12
# The extractor helpers used below live in ROOT/tools, which is not an installed
# package, so that directory is put at the front of the import path first.
sys.path.insert(0, str(ROOT.joinpath("tools")))
from psx_extract_wdl import (  # noqa: E402 - must follow the sys.path tweak above
    colorize_indexed_pixels,
    psx_555_to_rgba,
    write_overview_grid,
    write_psx_16bpp_png,
)
def find_all(haystack: bytes, needle: bytes):
    """Yield every offset at which *needle* occurs in *haystack*, overlapping hits included.

    The search resumes one byte past each hit, so ``b"aa"`` in ``b"aaaa"`` yields
    0, 1 and 2. An empty needle matches at every offset 0..len(haystack).
    """
    position = haystack.find(needle)
    while position != -1:
        yield position
        position = haystack.find(needle, position + 1)
def count_row_mismatches(left: bytes, right: bytes) -> int:
    """Return how many positions differ between *left* and *right*.

    Only the overlapping prefix is compared (``zip`` semantics); surplus bytes in the
    longer argument are not counted.
    """
    return sum(1 for left_byte, right_byte in zip(left, right) if left_byte != right_byte)
def is_exact_at(rows: list[bytes], candidate_rows: list[bytes], x: int, y: int, width: int) -> bool:
    """Return True when each candidate row equals the ``width``-byte slice of VRAM below (x, y).

    Candidate row ``i`` is compared against ``rows[y + i][x : x + width]``; the check
    stops at the first differing row.
    """
    return all(
        rows[y + offset][x : x + width] == expected
        for offset, expected in enumerate(candidate_rows)
    )
def near_score(
rows: list[bytes],
candidate_rows: list[bytes],
x: int,
y: int,
width: int,
cutoff: int | None,
) -> tuple[int, list[int], bool]:
total = 0
row_mismatches: list[int] = []
for dy, src in enumerate(candidate_rows):
seg = rows[y + dy][x : x + width]
mismatch = 0 if seg == src else count_row_mismatches(seg, src)
total += mismatch
row_mismatches.append(mismatch)
if cutoff is not None and total > cutoff:
return total, row_mismatches, True
return total, row_mismatches, False
def rgba_from_words(words: tuple[int, ...]) -> list[tuple[int, int, int]]:
    """Convert 16-bit colour words (as unpacked from VRAM) into colour triples.

    Each word goes through ``psx_555_to_rgba``, and only the first three components of
    its result are kept (the fourth, presumably alpha, is dropped).
    """
    return [channels[:3] for channels in map(psx_555_to_rgba, words)]
def candidate_match_score(
    framebuffer_rgb: list[tuple[int, int, int]],
    framebuffer_width: int,
    framebuffer_height: int,
    rgba: bytes,
    width: int,
    height: int,
    guess_x: int,
    guess_y: int,
    radius: int = 12,
    step: int = 2,
) -> tuple[int, int, int]:
    """Find the placement of an RGBA sprite that best matches a framebuffer near a guess.

    Every ``step``-th sprite pixel (in both axes) whose alpha byte is non-zero is compared
    with the framebuffer pixel it would cover. A placement's score is the sum of absolute
    R/G/B differences, floor-divided by the number of compared pixels. Candidate top-left
    corners lie within ``radius`` of (``guess_x``, ``guess_y``) and keep the whole sprite
    inside the framebuffer.

    Args:
        framebuffer_rgb: Row-major ``(r, g, b)`` triples, ``framebuffer_width`` per row.
        framebuffer_width: Framebuffer width in pixels.
        framebuffer_height: Framebuffer height in pixels.
        rgba: Row-major sprite pixels, 4 bytes (R, G, B, A) each, ``width`` per row.
        width: Sprite width in pixels.
        height: Sprite height in pixels.
        guess_x: X of the search-window centre (sprite top-left corner).
        guess_y: Y of the search-window centre (sprite top-left corner).
        radius: Half-size of the search window in pixels.
        step: Sampling stride inside the sprite.

    Returns:
        ``(score, x, y)`` of the lowest-scoring placement (the first in row-major order on
        ties). Returns ``(1 << 30, -1, -1)`` when no placement fits or no sampled pixel is
        opaque.
    """
    no_match = (1 << 30, -1, -1)
    x_min = max(0, guess_x - radius)
    x_max = min(framebuffer_width - width, guess_x + radius)
    y_min = max(0, guess_y - radius)
    y_max = min(framebuffer_height - height, guess_y + radius)
    if x_min > x_max or y_min > y_max:
        return no_match

    # The sampled opaque pixels depend only on the sprite, not on the placement. Gather
    # them once, with each one's offset relative to the placement's top-left framebuffer
    # index, instead of re-reading and re-testing alpha for every candidate placement.
    samples: list[tuple[int, int, int, int]] = []
    for sy in range(0, height, step):
        for sx in range(0, width, step):
            src = (sy * width + sx) * 4
            if rgba[src + 3] != 0:
                samples.append((sy * framebuffer_width + sx, rgba[src], rgba[src + 1], rgba[src + 2]))
    if not samples:
        return no_match
    sample_count = len(samples)

    best: tuple[int, int, int] | None = None
    for y in range(y_min, y_max + 1):
        for x in range(x_min, x_max + 1):
            origin = y * framebuffer_width + x
            total = 0
            for offset, red, green, blue in samples:
                screen_r, screen_g, screen_b = framebuffer_rgb[origin + offset]
                total += abs(screen_r - red) + abs(screen_g - green) + abs(screen_b - blue)
            normalized = total // sample_count
            # Strict "<" keeps the earliest (row-major) placement when scores tie.
            if best is None or normalized < best[0]:
                best = (normalized, x, y)
    return best if best is not None else no_match
def _vram_location(x: int, y: int) -> str:
    """Format a VRAM byte position with its 256-byte x / 256-row y "page" and in-page offset."""
    return f"x={x} y={y} page=({x // 256},{y // 256}) in_page=({x % 256},{y % 256})"


def _load_frame0() -> tuple[dict, dict, bytes]:
    """Load the bundle metadata and raw pixels of exported frame 0, validating both.

    Returns:
        ``(bundle, frame_meta, frame)``: the parsed bundle.json, the frame-0 entry of its
        ``exported_frames`` list, and the raw frame bytes (one byte per pixel).

    Raises:
        SystemExit: if frame 0 is not listed, the bundle mode is not 1 (8bpp), the frame
            is empty, or the frame file is not exactly width * height bytes.
    """
    bundle = json.loads(BUNDLE_JSON.read_text(encoding="ascii"))
    frame_meta = next((meta for meta in bundle["exported_frames"] if meta["index"] == 0), None)
    if frame_meta is None:
        # A bare next() would escape as an unexplained StopIteration traceback.
        raise SystemExit(f"frame 0 not listed in {BUNDLE_JSON}")
    mode = bundle["mode"]
    # Check the mode first: the width * height size check below assumes 1 byte per pixel.
    if mode != 1:
        raise SystemExit(f"unexpected mode {mode}, expected 1 for 8bpp")
    expected = frame_meta["width"] * frame_meta["height"]
    if expected == 0:
        raise SystemExit("frame 0 has zero width or height")
    frame = FRAME_PATH.read_bytes()
    if len(frame) != expected:
        raise SystemExit(f"frame byte size mismatch: got {len(frame)}, expected {expected}")
    return bundle, frame_meta, frame


def _load_vram_rows() -> list[bytes]:
    """Read the GPU RAM dump and split it into GPU_ROWS rows of ROW_BYTES bytes each."""
    gpu = GPU_PATH.read_bytes()
    if len(gpu) != ROW_BYTES * GPU_ROWS:
        raise SystemExit(f"unexpected GPU dump size {len(gpu)}")
    return [gpu[y * ROW_BYTES : (y + 1) * ROW_BYTES] for y in range(GPU_ROWS)]


def _load_raw_palettes() -> list[bytes]:
    """Slice the palette blob referenced by the L0.WDL header into 0x200-byte blocks.

    The blob's offset and size are read as little-endian 32-bit values from header bytes
    8-11 and 12-15. Each 0x200-byte block (256 two-byte entries) is searched for later.
    """
    l0_data = L0_WDL_PATH.read_bytes()
    palette_offset = int.from_bytes(l0_data[8:12], "little")
    palette_size = int.from_bytes(l0_data[12:16], "little")
    if palette_size != 0x1000:
        raise SystemExit(f"unexpected palette size 0x{palette_size:X}")
    palette_blob = l0_data[palette_offset : palette_offset + palette_size]
    if len(palette_blob) != palette_size:
        # A truncated file would otherwise yield a short, meaningless final block.
        raise SystemExit(f"palette blob truncated: got 0x{len(palette_blob):X} bytes, expected 0x{palette_size:X}")
    return [palette_blob[offset : offset + 0x200] for offset in range(0, len(palette_blob), 0x200)]


def _first_row_hits(rows: list[bytes], first_row: bytes, height: int) -> list[tuple[int, int]]:
    """Return every (x, y) where *first_row* occurs in a VRAM row with *height* rows of room."""
    return [(x, y) for y in range(GPU_ROWS - height + 1) for x in find_all(rows[y], first_row)]


def _build_live_clut_atlas(
    rows: list[bytes], frame: bytes, width: int, height: int, mode: int
) -> tuple[list[dict[str, object]], list[str]]:
    """Colorize frame 0 with every 256-word window in VRAM rows 0xF0-0xF7; write the atlas.

    Windows start every 16 words, 16 per row, giving 128 candidates. Writes the atlas PNG
    and a label file to BUNDLE_DIR and returns ``(entries, labels)`` in matching order.
    """
    entries: list[dict[str, object]] = []
    labels: list[str] = []
    for clut_row in range(8):
        y = 0xF0 + clut_row
        row_words = struct.unpack(f"<{ROW_BYTES // 2}H", rows[y])
        for column in range(16):
            x = column * 16  # in 16-bit words, not bytes
            palette = list(row_words[x : x + 256])
            rgba = colorize_indexed_pixels(frame, width, height, mode, palette)
            entries.append({"width": width, "height": height, "rgba": rgba})
            labels.append(f"index={clut_row * 16 + column} x={x} y={y}")
    atlas_path = BUNDLE_DIR / "live_vram_clut_atlas.png"
    labels_path = BUNDLE_DIR / "live_vram_clut_atlas.txt"
    write_overview_grid(atlas_path, entries, columns=16)
    labels_path.write_text("\n".join(labels) + "\n", encoding="ascii")
    print(f"live_vram_clut_atlas={atlas_path}")
    print(f"live_vram_clut_labels={labels_path}")
    return entries, labels


def _report_raw_palette_hits(rows: list[bytes], palettes_256: list[bytes]) -> None:
    """Print where each raw L0.WDL palette block occurs byte-for-byte in VRAM rows 240-255."""
    print(f"raw_palette_blocks_256={len(palettes_256)}")
    for palette_index, palette in enumerate(palettes_256):
        # x is a byte offset within the row (any alignment), not a 16-bit word index.
        palette_hits = [(x, y) for y in range(240, 256) for x in find_all(rows[y], palette)]
        print(f" palette_{palette_index}_hits={len(palette_hits)}")
        for x, y in palette_hits[:TOP_N]:
            print(f" palette_{palette_index} x={x} y={y} row_band={y - 240}")


def _dump_framebuffer(rows: list[bytes]) -> list[tuple[int, int, int]]:
    """Write the top-left framebuffer area and a fixed crop of it as PNGs.

    Returns:
        The framebuffer as row-major colour triples, for sprite matching.
    """
    framebuffer_path = ROOT / "binary/psx_framebuffer_left.png"
    framebuffer_crop_path = ROOT / "binary/psx_framebuffer_console_crop.png"
    line_bytes = FRAMEBUFFER_WIDTH * 2
    framebuffer_bytes = b"".join(row[:line_bytes] for row in rows[:FRAMEBUFFER_HEIGHT])
    write_psx_16bpp_png(framebuffer_path, framebuffer_bytes, FRAMEBUFFER_WIDTH, FRAMEBUFFER_HEIGHT)
    framebuffer_words = struct.unpack(f"<{FRAMEBUFFER_WIDTH * FRAMEBUFFER_HEIGHT}H", framebuffer_bytes)
    framebuffer_rgb = rgba_from_words(framebuffer_words)

    # Hand-picked crop rectangle in pixels (the file name suggests an on-screen console).
    crop_x, crop_y, crop_width, crop_height = 70, 0, 210, 110
    crop_start = crop_x * 2
    crop_bytes = b"".join(
        rows[crop_y + y][crop_start : crop_start + crop_width * 2] for y in range(crop_height)
    )
    write_psx_16bpp_png(framebuffer_crop_path, crop_bytes, crop_width, crop_height)
    print(f"framebuffer_left={framebuffer_path}")
    print(f"framebuffer_console_crop={framebuffer_crop_path}")
    return framebuffer_rgb


def _rank_live_cluts(
    framebuffer_rgb: list[tuple[int, int, int]],
    entries: list[dict[str, object]],
    labels: list[str],
    width: int,
    height: int,
) -> None:
    """Score each colorized candidate against the framebuffer; write and print the best ones."""
    rankings: list[tuple[int, int, int, int]] = []
    for palette_index, entry in enumerate(entries):
        score, best_x, best_y = candidate_match_score(
            framebuffer_rgb,
            FRAMEBUFFER_WIDTH,
            FRAMEBUFFER_HEIGHT,
            entry["rgba"],
            width,
            height,
            # NOTE(review): hard-coded search centre, presumably where the sprite
            # appears in the captured frame - confirm.
            guess_x=107,
            guess_y=12,
        )
        rankings.append((score, palette_index, best_x, best_y))
    rankings.sort()

    ranking_path = BUNDLE_DIR / "live_vram_clut_rank.txt"
    top_atlas_path = BUNDLE_DIR / "live_vram_clut_top_matches.png"
    best_candidate_path = BUNDLE_DIR / "live_vram_clut_best.png"
    print(f"best_live_vram_clut_matches_top_{MATCH_TOP_N}={min(MATCH_TOP_N, len(rankings))}")
    ranking_lines: list[str] = []
    top_entries: list[dict[str, object]] = []
    for score, palette_index, best_x, best_y in rankings[:MATCH_TOP_N]:
        line = f"score={score} {labels[palette_index]} screen=({best_x},{best_y})"
        ranking_lines.append(line)
        print(f" {line}")
        top_entries.append(entries[palette_index])
    ranking_path.write_text("\n".join(ranking_lines) + "\n", encoding="ascii")
    print(f"live_vram_clut_rank={ranking_path}")
    write_overview_grid(top_atlas_path, top_entries, columns=4)
    print(f"live_vram_clut_top_matches={top_atlas_path}")
    if rankings:
        best_entry = entries[rankings[0][1]]
        write_overview_grid(best_candidate_path, [best_entry], columns=1)
        print(f"live_vram_clut_best={best_candidate_path}")


def _report_near_matches(
    rows: list[bytes],
    searches: tuple[tuple[str, list[tuple[int, int]], list[bytes]], ...],
    width: int,
) -> None:
    """Print the TOP_N first-row hits with the fewest byte mismatches over the whole frame.

    Args:
        rows: VRAM rows.
        searches: ``(orientation, hits, candidate_rows)`` triples to score.
        width: Frame width in bytes.
    """
    ranked: list[tuple[int, int, int, str, list[int]]] = []
    cutoff: int | None = None
    for orientation, hits, candidate_rows in searches:
        for x, y in hits:
            total, row_mismatches, pruned = near_score(rows, candidate_rows, x, y, width, cutoff)
            # cutoff is only set once the list is full, and it always equals the worst
            # kept total, so a pruned candidate can never enter the list.
            if pruned:
                continue
            bisect.insort_left(ranked, (total, y, x, orientation, row_mismatches))
            if len(ranked) > TOP_N:
                ranked.pop()
            if len(ranked) == TOP_N:
                cutoff = ranked[-1][0]
    print(f"best_near_matches_top_{TOP_N}={len(ranked)}")
    for total, y, x, orientation, row_mismatches in ranked:
        nonzero_rows = [(index, mismatch) for index, mismatch in enumerate(row_mismatches) if mismatch]
        sample = ", ".join(f"r{index}={mismatch}" for index, mismatch in nonzero_rows[:8])
        if len(nonzero_rows) > 8:
            sample += ", ..."
        print(f" {orientation} {_vram_location(x, y)} mismatches={total} details=[{sample or 'all rows exact'}]")


def main() -> None:
    """Search a PSX VRAM dump for frame 0 of a sprite bundle and look for its palette.

    Steps, in order:
      1. Look for the frame in VRAM, as-is and horizontally flipped.
      2. Colorize the frame with candidate palettes taken from VRAM.
      3. Look for the raw L0.WDL palette blocks in VRAM.
      4. Dump the framebuffer.
      5. Rank the candidate palettes by how well the colorized sprite matches the framebuffer.
      6. If no exact frame match was found, report the closest near matches.

    Results are printed to stdout and several PNG/text files are written.
    """
    bundle, frame_meta, frame = _load_frame0()
    width = frame_meta["width"]
    height = frame_meta["height"]
    mode = bundle["mode"]
    rows = _load_vram_rows()
    palettes_256 = _load_raw_palettes()

    # Frame 0 as 8bpp rows, plus a horizontally mirrored copy.
    frame_rows = [frame[i * width : (i + 1) * width] for i in range(height)]
    flip_rows = [row[::-1] for row in frame_rows]
    # Cheap first pass: positions where the first row matches; full checks only there.
    normal_hits = _first_row_hits(rows, frame_rows[0], height)
    flipped_hits = _first_row_hits(rows, flip_rows[0], height)
    exact_normal = [(x, y) for x, y in normal_hits if is_exact_at(rows, frame_rows, x, y, width)]
    exact_flipped = [(x, y) for x, y in flipped_hits if is_exact_at(rows, flip_rows, x, y, width)]

    print(f"bundle_offset=0x{bundle['offset']:X} mode={mode} frame_count={bundle['frame_count']}")
    print(
        "frame0 "
        f"width={width} height={height} origin=({frame_meta['origin_x']},{frame_meta['origin_y']}) "
        f"data_start={frame_meta['data_start']} consumed={frame_meta['consumed']}"
    )
    print(f"frame_bytes={len(frame)} gpu_dump_bytes={sum(len(row) for row in rows)}")
    print(f"row0_hits normal={len(normal_hits)} flipped={len(flipped_hits)}")
    for orientation, matches in (("normal", exact_normal), ("flipped", exact_flipped)):
        print(f"exact_full_matches_{orientation}={len(matches)}")
        for x, y in matches[:TOP_N]:
            print(f" {orientation} {_vram_location(x, y)}")

    live_entries, live_labels = _build_live_clut_atlas(rows, frame, width, height, mode)
    _report_raw_palette_hits(rows, palettes_256)
    framebuffer_rgb = _dump_framebuffer(rows)
    _rank_live_cluts(framebuffer_rgb, live_entries, live_labels, width, height)

    if exact_normal or exact_flipped:
        return
    _report_near_matches(
        rows,
        (("normal", normal_hits, frame_rows), ("flipped", flipped_hits, flip_rows)),
        width,
    )


if __name__ == "__main__":
    main()