#!/usr/bin/env python3
import argparse
import difflib
from pathlib import Path
import re
import sys


def _strip_size_attr(xml_text, section_name):
    """Remove ``size="..."`` from every ``<ProgramSection>`` named *section_name*.

    Two regex passes cover both attribute orders: first tags where
    ``name`` precedes ``size``, then tags where ``size`` precedes ``name``.
    Only double-quoted, non-empty ``size`` values are matched, and the
    whitespace in front of the attribute is removed with it.  Every other
    tag and attribute is left untouched.

    Args:
        xml_text: Full text of a flash_placement.xml file.
        section_name: Exact value of the ``name`` attribute (e.g. ``".text"``);
            it is regex-escaped, so dots match literally.

    Returns:
        The text with the matching ``size`` attributes stripped.
    """
    quoted = re.escape(section_name)
    name_then_size = re.compile(
        r'(<ProgramSection\b[^>]*\bname="' + quoted + r'"[^>]*?)\s+size="[^"]+"'
    )
    size_then_name = re.compile(
        r'(<ProgramSection\b[^>]*?)\s+size="[^"]+"([^>]*\bname="' + quoted + r'"[^>]*>)'
    )
    # Order matters: the name-first pass runs before the size-first pass.
    for pattern, template in ((name_then_size, r"\1"), (size_then_name, r"\1\2")):
        xml_text = pattern.sub(template, xml_text)
    return xml_text


def _iter_flash_placements(paths):
    """Yield each SES ``flash_placement.xml`` reachable from *paths*, once.

    Args:
        paths: Iterable of ``Path`` objects.  A regular file is yielded only
            if it is itself named ``flash_placement.xml``.  A directory is
            searched recursively, and only matches that have a path component
            equal to ``"ses"`` are kept.  A path that does not exist triggers
            a warning on stderr and is otherwise skipped.

    The ``"ses"`` test looks at every component of the found path, including
    those of the directory argument itself.  A directory that already lies
    under an ``ses`` folder therefore accepts every match beneath it.

    Yields:
        Matching paths, sorted within each directory argument.  A file reached
        through several overlapping arguments is yielded only the first time.
    """
    seen = set()

    def _first_time(candidate):
        # resolve() collapses "."/".."/symlinks so overlapping roots
        # (e.g. "." and "./ses") don't report the same file twice.
        key = candidate.resolve()
        if key in seen:
            return False
        seen.add(key)
        return True

    for path in paths:
        if path.is_file():
            if path.name == "flash_placement.xml" and _first_time(path):
                yield path
            continue
        if not path.is_dir():
            if not path.exists():
                print(
                    "warning: {}: no such file or directory".format(path),
                    file=sys.stderr,
                )
            continue
        # Sort so output (diffs, messages) is stable across filesystems.
        for xml_path in sorted(path.rglob("flash_placement.xml")):
            if "ses" in xml_path.parts and _first_time(xml_path):
                yield xml_path


def _read_utf8(path):
    """Return the UTF-8 text of *path* with line endings untouched, or None.

    Reading raw bytes avoids text-mode newline translation.  Files that
    cannot be read or are not valid UTF-8 are reported on stderr rather
    than aborting the whole run.
    """
    try:
        return path.read_bytes().decode("utf-8")
    except (OSError, UnicodeDecodeError) as err:
        print("warning: skipping {}: {}".format(path, err), file=sys.stderr)
        return None


def _remove_section_limits(xml_text, sections=(".text", ".rodata")):
    """Strip the ``size`` attribute from each ProgramSection named in *sections*."""
    for section in sections:
        xml_text = _strip_size_attr(xml_text, section)
    return xml_text


def main(argv):
    """Command-line entry point: remove size limits from .text/.rodata.

    Scans the given files/directories (default: the current directory) and
    rewrites each matching flash_placement.xml whose ``.text``/``.rodata``
    ProgramSection carries a ``size`` attribute.  With ``--dry-run`` a
    unified diff of the pending change is printed instead of writing.

    Files are read and written as raw UTF-8 bytes, so existing line endings
    (e.g. CRLF) are preserved exactly.  Text-mode I/O would convert them to
    the host platform's newline.

    Args:
        argv: Command-line arguments, excluding the program name.

    Returns:
        Exit status: 0 on success, 1 if any changed file could not be
        written.  Unreadable/undecodable files are warned about and skipped.
    """
    parser = argparse.ArgumentParser(
        description="Remove size limits from .text/.rodata in SES flash_placement.xml files."
    )
    parser.add_argument(
        "paths",
        nargs="*",
        help="Files or directories to scan (default: current directory).",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print files that would be modified without writing.",
    )
    args = parser.parse_args(argv)

    roots = [Path(p) for p in args.paths] if args.paths else [Path.cwd()]

    changed = 0   # files that differ (diffed in dry-run, written otherwise)
    failures = 0  # files whose rewrite could not be saved
    for xml_path in _iter_flash_placements(roots):
        original = _read_utf8(xml_path)
        if original is None:
            continue
        updated = _remove_section_limits(original)
        if updated == original:
            continue

        if args.dry_run:
            # lineterm="" + print() gives clean output even for CRLF files
            # or files lacking a trailing newline.
            for line in difflib.unified_diff(
                original.splitlines(),
                updated.splitlines(),
                fromfile=str(xml_path),
                tofile=str(xml_path),
                lineterm="",
            ):
                print(line)
            changed += 1
            continue

        try:
            xml_path.write_bytes(updated.encode("utf-8"))
        except OSError as err:
            print("error: cannot write {}: {}".format(xml_path, err), file=sys.stderr)
            failures += 1
            continue
        changed += 1

    if changed:
        if not args.dry_run:
            print("updated {} file(s)".format(changed))
    elif not failures:
        print("no changes needed")

    return 1 if failures else 0


if __name__ == "__main__":
    # sys.exit() raises SystemExit carrying main()'s return value as the status.
    sys.exit(main(sys.argv[1:]))
