"""Structure-aware masking pipeline excerpt.

This example shows how the MegaCpp POC masking pipeline keeps token-aligned
metadata valid after a fill-in-the-middle (FIM) transform. The problem it solves
is silent metadata drift: if the transform reorders tokens but chunk boundaries
or structure labels are not reordered to match, the model trains on corrupted
supervision.
"""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class FimResult:
    """Outcome of a fill-in-the-middle (FIM) transform on one token sequence.

    The split indices are positions in the original, pre-FIM sequence and
    carve it into three segments:

        prefix = [0, split_start)
        middle = [split_start, split_end)
        suffix = [split_end, end)
    """

    # Token ids; presumably the post-transform sequence -- confirm against the
    # producer of this result.
    tokens: list[int]
    # False means no FIM rearrangement happened; the metadata helpers then
    # leave their inputs in the original order.
    was_transformed: bool
    # Start of the middle segment in the original token order.
    split_start: int = 0
    # End (exclusive) of the middle segment in the original token order.
    split_end: int = 0
    # True selects the SPM layout (suffix before prefix); False selects PSM.
    is_spm: bool = False


def permute_metadata_for_fim(
    meta_array: list[int],
    fim_result: FimResult,
    sentinel_value: int = 0,
) -> list[int]:
    """Remap a token-aligned metadata array through a FIM permutation.

    The layout matches the real MegaCpp POC masking contract:
      PSM: [FIM_PREFIX] prefix [FIM_SUFFIX] suffix [FIM_MIDDLE] middle [EOT]
      SPM: [FIM_PREFIX] [FIM_SUFFIX] suffix [FIM_MIDDLE] prefix middle [EOT]

    Args:
        meta_array: One metadata value per original (pre-FIM) token.
        fim_result: Describes whether a transform happened, where the middle
            segment was cut (``split_start``/``split_end`` in original token
            positions) and which layout (SPM or PSM) was used.
        sentinel_value: Value placed at each of the four special-token
            positions in the output.

    Returns:
        A new list aligned to the FIM token order. When no transform was
        applied, this is a copy of ``meta_array``, so callers may mutate the
        result without touching the input in either case.

    Raises:
        ValueError: If the split indices do not satisfy
            ``0 <= split_start <= split_end <= len(meta_array)``. Python slicing
            would otherwise silently truncate or misalign the metadata.
    """
    if not fim_result.was_transformed:
        return list(meta_array)

    split_start, split_end = fim_result.split_start, fim_result.split_end
    if not 0 <= split_start <= split_end <= len(meta_array):
        raise ValueError(
            f"FIM split [{split_start}, {split_end}) is out of range for "
            f"{len(meta_array)} metadata entries"
        )

    prefix_meta = meta_array[:split_start]
    middle_meta = meta_array[split_start:split_end]
    suffix_meta = meta_array[split_end:]
    sentinel = sentinel_value

    if fim_result.is_spm:
        # [FIM_PREFIX] [FIM_SUFFIX] suffix [FIM_MIDDLE] prefix middle [EOT]
        return [
            sentinel,
            sentinel,
            *suffix_meta,
            sentinel,
            *prefix_meta,
            *middle_meta,
            sentinel,
        ]
    # [FIM_PREFIX] prefix [FIM_SUFFIX] suffix [FIM_MIDDLE] middle [EOT]
    return [
        sentinel,
        *prefix_meta,
        sentinel,
        *suffix_meta,
        sentinel,
        *middle_meta,
        sentinel,
    ]


def _chunk_token_offset(chunk: dict) -> int:
    """Return a chunk's start offset as an int, treating a missing key as 0."""
    return int(chunk.get("token_offset", 0))


def remap_chunk_boundaries_for_fim(
    chunk_boundaries: list[dict],
    fim_result: FimResult,
    original_token_count: int,
) -> tuple[list[dict], list[int], list[int]]:
    """Move chunk offsets into the new FIM token order.

    Each chunk spans from its ``token_offset`` up to the next chunk's offset;
    the last chunk runs to ``original_token_count``. Chunks that cross a FIM
    split point are dropped. That keeps call/type graph metadata conservative
    instead of pretending partially moved chunks are still meaningful.

    Offsets are read with ``int(chunk.get("token_offset", 0))`` everywhere, so
    a missing key counts as offset 0 and numeric strings are accepted.

    Args:
        chunk_boundaries: Chunk dicts carrying a ``token_offset`` key. Any other
            keys are carried through unchanged.
        fim_result: Whether a transform happened, the original-order split
            indices of the middle segment, and whether the SPM layout was used.
        original_token_count: Number of tokens before the FIM transform.

    Returns:
        ``(chunks, starts, ends)``, sorted by start position in the output token
        order, with ``ends`` exclusive. When a transform was applied, ``chunks``
        are shallow copies whose ``token_offset`` is rewritten. Otherwise they
        are the input dicts themselves.

    Raises:
        ValueError: If a transform was applied to a non-empty chunk list and the
            split indices do not satisfy
            ``0 <= split_start <= split_end <= original_token_count``.
    """
    ordered = sorted(chunk_boundaries, key=_chunk_token_offset)
    orig_starts = [_chunk_token_offset(chunk) for chunk in ordered]
    # A chunk ends where the next one starts; the last runs to the end.
    orig_ends = orig_starts[1:] + [original_token_count] if orig_starts else []

    if not fim_result.was_transformed or not ordered:
        return ordered, orig_starts, orig_ends

    split_start, split_end = fim_result.split_start, fim_result.split_end
    if not 0 <= split_start <= split_end <= original_token_count:
        raise ValueError(
            f"FIM split [{split_start}, {split_end}) is out of range for "
            f"{original_token_count} tokens"
        )

    suffix_len = original_token_count - split_end
    if fim_result.is_spm:
        # [FIM_PREFIX] [FIM_SUFFIX] suffix [FIM_MIDDLE] prefix middle [EOT]
        suffix_offset = 2
        prefix_offset = suffix_offset + suffix_len + 1
        middle_offset = prefix_offset + split_start
    else:
        # [FIM_PREFIX] prefix [FIM_SUFFIX] suffix [FIM_MIDDLE] middle [EOT]
        prefix_offset = 1
        suffix_offset = prefix_offset + split_start + 1
        middle_offset = suffix_offset + suffix_len + 1

    placed: list[tuple[int, int, dict]] = []
    for chunk, chunk_start, chunk_end in zip(ordered, orig_starts, orig_ends):
        if chunk_end <= split_start:
            new_start = prefix_offset + chunk_start
            new_end = prefix_offset + chunk_end
        elif chunk_start >= split_end:
            new_start = suffix_offset + (chunk_start - split_end)
            new_end = suffix_offset + (chunk_end - split_end)
        elif chunk_start >= split_start and chunk_end <= split_end:
            new_start = middle_offset + (chunk_start - split_start)
            new_end = middle_offset + (chunk_end - split_start)
        else:
            # Straddles a split point: its tokens are no longer contiguous.
            continue

        new_chunk = dict(chunk)
        new_chunk["token_offset"] = new_start
        placed.append((new_start, new_end, new_chunk))

    # The FIM layout moves segments (e.g. PSM puts the suffix before the
    # middle), so re-sort by the new start to match the untransformed path's
    # ascending-offset contract. The sort is stable, so ties keep their order.
    placed.sort(key=lambda item: item[0])
    remapped = [chunk for _, _, chunk in placed]
    starts = [start for start, _, _ in placed]
    ends = [end for _, end, _ in placed]
    return remapped, starts, ends
