from __future__ import annotations

import sys
from dataclasses import dataclass, field
from typing import Callable

from .formats import FLAG_FLIPPED, MapItem, ShapeArchive, ShapeFrame, ShapeInfo


@dataclass(frozen=True)
class InvalidRenderItem:
    """Details of a map item that was skipped because its frame failed to decode."""

    shape: int
    frame: int
    x: int
    y: int
    z: int
    source: str
    # Text of the exception raised while decoding the frame.
    reason: str


@dataclass
class SortNode:
    """A decoded map item plus everything needed to order it for painting.

    Instances are created by ``build_sort_node``.  ``occluded``, ``order`` and
    ``depends`` are filled in later by ``prepare_sorted_items`` and
    ``resolve_paint_order``.
    """

    item: MapItem
    info: ShapeInfo
    frame: ShapeFrame
    pixels: list[int]
    # Bounding rectangle of the frame, derived from (sx_bot, sy_bot) and the
    # frame's offsets and size.
    left: int
    top: int
    right: int
    bottom: int
    # Extents of the item's box: (x, y, z) is copied from the item and
    # (x_left, y_far, z_top) is offset from it by the shape's dimensions.
    x: int
    x_left: int
    y: int
    y_far: int
    z: int
    z_top: int
    # Projected 2D points of the box (see build_sort_node for the formulas).
    sx_left: int
    sx_right: int
    sx_top: int
    sy_top: int
    sx_bot: int
    sy_bot: int
    # Flags used as tie-breakers by ``below``.
    fbigsq: bool  # x and y dimensions are equal and at least 128
    flat: bool  # z dimension is zero
    occl: bool  # occluder flag set and the shape is not translucent
    solid: bool
    draw: bool
    roof: bool
    noisy: bool
    anim: bool
    trans: bool
    fixed: bool
    land: bool
    sprite: bool
    invitem: bool
    # True once another node has been found to completely hide this one.
    occluded: bool = False
    # -1: not visited yet, -2: currently being resolved, >= 0: paint position.
    order: int = -1
    # Nodes that must be painted before this one.
    depends: list["SortNode"] = field(default_factory=list)

    def list_less_than(self, other: "SortNode") -> bool:
        """Return True if ``self`` sorts before ``other`` in the coarse list order.

        Compares ``sprite`` first, then ``z``; on a tie a flat node sorts
        before a non-flat one.
        """
        if self.sprite != other.sprite:
            return self.sprite < other.sprite
        if self.z != other.z:
            return self.z < other.z
        return self.flat > other.flat

    def overlap(self, other: "SortNode") -> bool:
        """Return True if the two nodes may cover each other when drawn.

        The frame rectangles must intersect, and none of the separation tests
        (horizontal extents and the four diagonal tests, which weight the
        vertical difference by 2) may report the nodes as clear of each other.
        """
        if not rect_intersects(self, other):
            return False

        point_top_diff = (self.sx_top - other.sx_bot, self.sy_top - other.sy_bot)
        point_bot_diff = (self.sx_bot - other.sx_top, self.sy_bot - other.sy_top)

        dot_top_left = point_top_diff[0] + point_top_diff[1] * 2
        dot_top_right = -point_top_diff[0] + point_top_diff[1] * 2
        dot_bot_left = point_bot_diff[0] - point_bot_diff[1] * 2
        dot_bot_right = -point_bot_diff[0] - point_bot_diff[1] * 2

        right_clear = self.sx_right <= other.sx_left
        left_clear = self.sx_left >= other.sx_right
        top_left_clear = dot_top_left >= 0
        top_right_clear = dot_top_right >= 0
        bot_left_clear = dot_bot_left >= 0
        bot_right_clear = dot_bot_right >= 0

        clear = (
            right_clear
            or left_clear
            or (bot_right_clear or bot_left_clear)
            or (top_right_clear or top_left_clear)
        )
        return not clear

    def occludes(self, other: "SortNode") -> bool:
        """Return True if ``self`` covers ``other`` according to every containment test.

        The frame rectangle of ``self`` must contain that of ``other``, and the
        horizontal extents plus the four diagonal tests must all agree.
        """
        if not rect_contains(self, other):
            return False

        point_top_diff = (self.sx_top - other.sx_top, self.sy_top - other.sy_top)
        point_bot_diff = (self.sx_bot - other.sx_bot, self.sy_bot - other.sy_bot)

        dot_top_left = point_top_diff[0] + point_top_diff[1] * 2
        dot_top_right = -point_top_diff[0] + point_top_diff[1] * 2
        dot_bot_left = point_bot_diff[0] - point_bot_diff[1] * 2
        dot_bot_right = -point_bot_diff[0] - point_bot_diff[1] * 2

        right_res = self.sx_right >= other.sx_right
        left_res = self.sx_left <= other.sx_left
        top_left_res = dot_top_left <= 0
        top_right_res = dot_top_right <= 0
        bot_left_res = dot_bot_left <= 0
        bot_right_res = dot_bot_right <= 0

        return (
            right_res
            and left_res
            and bot_right_res
            and bot_left_res
            and top_right_res
            and top_left_res
        )

    def below(self, other: "SortNode") -> bool:
        """Return True if ``self`` must be painted before ``other``.

        The checks form a priority chain: the first one that can tell the two
        nodes apart decides the result, so their order is significant and must
        not be rearranged.
        """
        if self.sprite != other.sprite:
            return self.sprite < other.sprite

        if self.flat and other.flat:
            # Two flat nodes: a differing z decides immediately.
            if self.z != other.z:
                return self.z < other.z
        elif self.invitem == other.invitem:
            # Boxes that do not overlap in z at all.
            if self.z_top <= other.z:
                return True
            if self.z >= other.z_top:
                return False

        # Separation along y.  A node is "y-flat" when it has no y extent.
        y_flat_self = self.y_far == self.y
        y_flat_other = other.y_far == other.y
        if y_flat_self and y_flat_other:
            if self.y // 32 != other.y // 32:
                return self.y < other.y
        else:
            if self.y <= other.y_far:
                return True
            if self.y_far >= other.y:
                return False

        # Separation along x, same scheme as for y.
        x_flat_self = self.x_left == self.x
        x_flat_other = other.x_left == other.x
        if x_flat_self and x_flat_other:
            if self.x // 32 != other.x // 32:
                return self.x < other.x
        else:
            if self.x <= other.x_left:
                return True
            if self.x_left >= other.x:
                return False

        # z comparison again, this time with the tops lowered by 8 units.
        if self.z_top - 8 <= other.z and self.z < other.z_top - 8:
            return True
        if self.z >= other.z_top - 8 and self.z_top - 8 > other.z:
            return False

        # Exactly one node is y-flat: compare at 32-unit granularity.
        if y_flat_self != y_flat_other:
            if self.y // 32 <= other.y_far // 32:
                return True
            if self.y_far // 32 >= other.y // 32:
                return False

            y_center_self = (self.y_far // 32 + self.y // 32) // 2
            y_center_other = (other.y_far // 32 + other.y // 32) // 2
            if y_center_self != y_center_other:
                return y_center_self < y_center_other

        # Exactly one node is x-flat: same treatment along x.
        if x_flat_self != x_flat_other:
            if self.x // 32 <= other.x_left // 32:
                return True
            if self.x_left // 32 >= other.x // 32:
                return False

            x_center_self = (self.x_left // 32 + self.x // 32) // 2
            x_center_other = (other.x_left // 32 + other.x // 32) // 2
            if x_center_self != x_center_other:
                return x_center_self < x_center_other

        # Flag-based tie-breakers that only apply when a flat node is involved.
        if self.flat or other.flat:
            if self.z != other.z:
                return self.z < other.z
            if self.invitem != other.invitem:
                return self.invitem < other.invitem
            if self.flat != other.flat:
                return self.flat > other.flat
            if self.trans != other.trans:
                return self.trans < other.trans
            if self.anim != other.anim:
                return self.anim < other.anim
            if self.draw != other.draw:
                return self.draw > other.draw
            if self.solid != other.solid:
                return self.solid > other.solid
            if self.occl != other.occl:
                return self.occl > other.occl
            if self.fbigsq != other.fbigsq:
                return self.fbigsq > other.fbigsq

        if self.x == other.x and self.y == other.y and self.trans != other.trans:
            return self.trans < other.trans

        # Between two land nodes the non-roof goes first; otherwise a roof does.
        if self.land and other.land and self.roof != other.roof:
            return self.roof < other.roof
        if self.roof != other.roof:
            return self.roof > other.roof

        if self.z != other.z:
            return self.z < other.z

        if x_flat_self or x_flat_other or y_flat_self or y_flat_other:
            if self.sx_left != other.sx_left:
                return self.sx_left > other.sx_left
            if self.sy_bot != other.sy_bot:
                return self.sy_bot < other.sy_bot

        # Remaining positional tie-breakers.
        if self.x + self.y != other.x + other.y:
            return self.x + self.y < other.x + other.y
        if self.x_left + self.y_far != other.x_left + other.y_far:
            return self.x_left + self.y_far < other.x_left + other.y_far
        if self.y != other.y:
            return self.y < other.y
        if self.x != other.x:
            return self.x < other.x

        # Last resort: order by shape number, then by frame number.
        if self.item.shape != other.item.shape:
            return self.item.shape < other.item.shape
        return self.item.frame < other.item.frame


def rect_intersects(left: SortNode, right: SortNode) -> bool:
    """Return True if the two frame rectangles share any area.

    Rectangles that merely touch along an edge do not count as intersecting.
    """
    return (
        left.left < right.right
        and left.right > right.left
        and left.top < right.bottom
        and left.bottom > right.top
    )


def rect_contains(outer: SortNode, inner: SortNode) -> bool:
    """Return True if ``inner``'s frame rectangle lies within ``outer``'s (edges inclusive)."""
    return (
        outer.left <= inner.left
        and outer.top <= inner.top
        and outer.right >= inner.right
        and outer.bottom >= inner.bottom
    )


def build_sort_node(item: MapItem, info: ShapeInfo, frame: ShapeFrame, pixels: list[int]) -> SortNode:
    """Build the ``SortNode`` for ``item`` from its shape info and decoded frame.

    Args:
        item: Map item supplying position, flags, shape and frame numbers.
        info: Shape info supplying the dimensions and the boolean flags.
        frame: Decoded frame supplying offsets and size for the rectangle.
        pixels: Decoded pixel data, stored on the node unchanged.

    Returns:
        A new ``SortNode`` with ``sprite`` set to False and the sorting state
        (``occluded``, ``order``, ``depends``) at its defaults.
    """
    flipped = bool(item.flags & FLAG_FLIPPED)

    # A flipped item swaps its x and y dimensions.
    xdim = info.y * 32 if flipped else info.x * 32
    ydim = info.x * 32 if flipped else info.y * 32
    zdim = info.z * 8

    x = item.x
    y = item.y
    z = item.z
    x_left = x - xdim
    y_far = y - ydim
    z_top = z + zdim

    # 2D projection of the box corners: horizontal is (x - y) / 4 and vertical
    # is (x + y) / 8 - z, each term floor-divided separately.
    sx_left = x_left // 4 - y // 4
    sx_right = x // 4 - y_far // 4
    sx_top = x_left // 4 - y_far // 4
    sy_top = x_left // 8 + y_far // 8 - z_top
    sx_bot = x // 4 - y // 4
    sy_bot = x // 8 + y // 8 - z

    # The frame rectangle hangs off the bottom point; a flipped frame mirrors
    # its horizontal offset.
    left = (sx_bot + frame.xoff - frame.width) if flipped else (sx_bot - frame.xoff)
    top = sy_bot - frame.yoff
    right = left + frame.width
    bottom = top + frame.height

    return SortNode(
        item=item,
        info=info,
        frame=frame,
        pixels=pixels,
        left=left,
        top=top,
        right=right,
        bottom=bottom,
        x=x,
        x_left=x_left,
        y=y,
        y_far=y_far,
        z=z,
        z_top=z_top,
        sx_left=sx_left,
        sx_right=sx_right,
        sx_top=sx_top,
        sy_top=sy_top,
        sx_bot=sx_bot,
        sy_bot=sy_bot,
        fbigsq=xdim == ydim and xdim >= 128,
        flat=zdim == 0,
        occl=info.is_occl and not info.is_translucent,
        solid=info.is_solid,
        draw=info.is_draw,
        roof=info.is_roof,
        noisy=info.is_noisy,
        anim=info.anim_type != 0,
        trans=info.is_translucent,
        fixed=info.is_fixed,
        land=info.is_land,
        sprite=False,
        invitem=info.is_invitem,
    )


def insert_dependency_sorted(depends: list[SortNode], node: SortNode) -> bool:
    """Insert ``node`` into ``depends`` ahead of the first entry it sorts before.

    The scan stops at the insertion point, so a duplicate is only detected if
    it is met before that point.

    Returns:
        True if ``node`` was inserted or appended, False if the same object
        was encountered during the scan.
    """
    for index, current in enumerate(depends):
        if current is node:
            return False
        if node.list_less_than(current):
            depends.insert(index, node)
            return True
    depends.append(node)
    return True


def resolve_paint_order(
    ordered: list[SortNode],
    progress: Callable[[str], None] | None = None,
    checkpoint_every: int = 0,
) -> list[SortNode]:
    """Return the non-occluded nodes in the order they should be painted.

    Each node is emitted after the nodes in its ``depends`` list (depth-first).
    A dependency that is currently being resolved indicates a cycle; scanning
    that node's remaining dependencies stops there and the node is emitted.
    Occluded nodes are never emitted and keep ``order == -1``.

    The traversal uses an explicit stack instead of recursion so that long
    dependency chains cannot exceed the interpreter's recursion limit.

    Args:
        ordered: Nodes to resolve; each node's ``order`` is updated in place.
        progress: Optional callback receiving status messages.
        checkpoint_every: Emit a status message every this many painted nodes;
            values of 0 or less disable the periodic messages.

    Returns:
        The painted nodes, in paint order; ``node.order`` equals its index.
    """
    painted: list[SortNode] = []

    def paint(node: SortNode) -> None:
        # Orders are handed out sequentially, so the next one is len(painted).
        node.order = len(painted)
        painted.append(node)
        if progress is not None and checkpoint_every > 0 and len(painted) % checkpoint_every == 0:
            progress(f"paint resolved={len(painted)} of {len(ordered)}")

    def visit(root: SortNode) -> None:
        if root.occluded or root.order >= 0:
            return

        root.order = -2  # mark as in progress
        # Each entry pairs a node with an iterator over its not-yet-examined
        # dependencies, so the scan resumes where it left off after a descent.
        stack = [(root, iter(root.depends))]
        while stack:
            node, pending = stack[-1]
            descend = None
            for dependency in pending:
                if dependency.order == -2:
                    # Cycle: give up on the remaining dependencies.
                    break
                if dependency.order == -1 and not dependency.occluded:
                    descend = dependency
                    break
            if descend is not None:
                descend.order = -2
                stack.append((descend, iter(descend.depends)))
                continue
            stack.pop()
            paint(node)

    for node in ordered:
        if node.order == -1:
            visit(node)

    if progress is not None:
        progress(f"paint complete resolved={len(painted)} of {len(ordered)}")
    return painted


def prepare_sorted_items(
    items: list[MapItem],
    archive: ShapeArchive,
    shape_infos: list[ShapeInfo],
    progress: Callable[[str], None] | None = None,
    checkpoint_every: int = 0,
    max_invalid_details: int = 20,
) -> tuple[int, int, int, int, list[SortNode], int, int, list[InvalidRenderItem]]:
    """Decode, sort and dependency-link ``items``, then resolve their paint order.

    Args:
        items: Map items to process.
        archive: Source of decoded frames via ``decode_frame(shape, frame)``.
        shape_infos: Shape info table indexed by ``item.shape``.
        progress: Optional callback receiving status messages.
        checkpoint_every: Emit a status message every this many items (and,
            during paint resolution, painted nodes); 0 or less disables them.
        max_invalid_details: Maximum number of ``InvalidRenderItem`` records
            kept; further invalid items are only counted.

    Returns:
        ``(min_left, min_top, max_right, max_bottom, painted, occluded_count,
        invalid_item_count, invalid_items)``.  The bounds cover every decoded
        node, occluded ones included.  If no item decodes, they keep their
        initial values (``sys.maxsize`` / ``-sys.maxsize``).
    """
    ordered: list[SortNode] = []
    min_left = sys.maxsize
    min_top = sys.maxsize
    max_right = -sys.maxsize
    max_bottom = -sys.maxsize
    occluded_count = 0
    invalid_item_count = 0
    invalid_items: list[InvalidRenderItem] = []
    dependency_count = 0

    for item_index, item in enumerate(items, start=1):
        try:
            frame, pixels = archive.decode_frame(item.shape, item.frame)
        except (IndexError, ValueError) as error:
            invalid_item_count += 1
            if len(invalid_items) < max_invalid_details:
                invalid_items.append(
                    InvalidRenderItem(
                        shape=item.shape,
                        frame=item.frame,
                        x=item.x,
                        y=item.y,
                        z=item.z,
                        source=item.source,
                        reason=str(error),
                    )
                )
            continue

        node = build_sort_node(item, shape_infos[item.shape], frame, pixels)
        min_left = min(min_left, node.left)
        min_top = min(min_top, node.top)
        max_right = max(max_right, node.right)
        max_bottom = max(max_bottom, node.bottom)

        # len(ordered) doubles as the "no insertion point found yet" marker.
        insert_at = len(ordered)
        for index, other in enumerate(ordered):
            # The insertion point is the first existing node this one sorts before.
            if insert_at == len(ordered) and node.list_less_than(other):
                insert_at = index

            if other.occluded:
                continue
            if not node.overlap(other):
                continue

            if node.below(other):
                if other.occl and other.occludes(node):
                    # Completely hidden: no further comparisons are needed.
                    node.occluded = True
                    occluded_count += 1
                    break
                if insert_dependency_sorted(other.depends, node):
                    dependency_count += 1
            else:
                if node.occl and node.occludes(other):
                    # The hidden node stays in the list; it is only flagged.
                    if not other.occluded:
                        other.occluded = True
                        occluded_count += 1
                else:
                    if insert_dependency_sorted(node.depends, other):
                        dependency_count += 1

        ordered.insert(insert_at, node)

        if progress is not None and checkpoint_every > 0 and item_index % checkpoint_every == 0:
            progress(
                "sort "
                f"processed={item_index} valid={len(ordered)} occluded={occluded_count} invalid={invalid_item_count} "
                f"dependencies={dependency_count}"
            )

    if progress is not None:
        progress(
            "sort complete "
            f"processed={len(items)} valid={len(ordered)} occluded={occluded_count} invalid={invalid_item_count} "
            f"dependencies={dependency_count}"
        )

    return (
        min_left,
        min_top,
        max_right,
        max_bottom,
        resolve_paint_order(ordered, progress=progress, checkpoint_every=checkpoint_every),
        occluded_count,
        invalid_item_count,
        invalid_items,
    )