from collections.abc import Callable, Iterator

from etuples.core import ExpressionTuple
from kanren import run
from unification import var
from unification.variable import Var

from pytensor.graph.basic import Apply, Variable
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.unify import eval_if_etuple


class KanrenRelationSub(NodeRewriter):
    r"""A rewriter that uses `kanren` to match and replace terms.

    See `kanren <https://github.com/pythological/kanren>`__ for more information
    about miniKanren and the API for constructing `kanren` goals.

    Example
    -------

    .. code-block:: python

        from kanren import conso, lall, var

        import pytensor.tensor as pt
        from pytensor.graph.rewriting.kanren import KanrenRelationSub


        def relation(in_lv, out_lv):
            # A `kanren` goal that changes `pt.log` terms to `pt.exp`
            cdr_lv = var()
            return lall(
                conso(pt.log, cdr_lv, in_lv),
                conso(pt.exp, cdr_lv, out_lv),
            )


        kanren_sub_opt = KanrenRelationSub(relation)

    """

    reentrant = True

    def __init__(
        self,
        kanren_relation: Callable[[Variable, Var], Callable],
        results_filter: Callable[
            [Iterator],
            ExpressionTuple | Variable | list[ExpressionTuple | Variable] | None,
        ]
        | None = None,
        node_filter: Callable[[Apply], bool] = lambda x: True,
    ):
        r"""Create a `KanrenRelationSub`.

        Parameters
        ----------
        kanren_relation
            A function that takes an input graph and an output logic variable and
            returns a `kanren` goal.  The input graph is the node's default
            output, or the list of all its outputs when the node has no default
            output.
        results_filter
            A function that takes the direct output of ``kanren.run(None, ...)``
            and returns either a single replacement term, a list of replacement
            terms, or ``None`` when no replacement should be made.  Any
            `ExpressionTuple` in the result is evaluated before being returned
            to the rewriter.  The default implementation returns the first
            result (or ``None`` if there is none).
        node_filter
            A function taking a single node and returns a truthy value when the
            node should be processed.
        """
        if results_filter is None:

            def results_filter(
                x: Iterator,
            ) -> ExpressionTuple | Variable | list[ExpressionTuple | Variable] | None:
                # Take only the first `kanren` result; ``None`` means "no match".
                return next(x, None)

        self.kanren_relation = kanren_relation
        self.results_filter = results_filter
        self.node_filter = node_filter
        super().__init__()

    def transform(self, fgraph, node, enforce_tracks: bool = True):
        """Run the `kanren` relation on `node` and return its replacement(s).

        Returns
        -------
        A list of replacement terms (with `ExpressionTuple`\\s evaluated), or
        ``False`` when the node is filtered out or no usable result is found.
        """
        # Any falsy filter result (e.g. ``numpy.False_`` or ``0``) rejects the
        # node, not only the ``False`` singleton.
        if not self.node_filter(node):
            return False

        try:
            input_expr = node.default_output()
        except ValueError:
            # No single default output: match against all of the node's outputs.
            input_expr = node.outputs

        q = var()
        kanren_results = run(None, q, self.kanren_relation(input_expr, q))

        chosen_res = self.results_filter(kanren_results)  # type: ignore[arg-type]

        if chosen_res is None:
            return False

        if isinstance(chosen_res, list):
            if not chosen_res:
                return False
            return [eval_if_etuple(v) for v in chosen_res]

        # NOTE: never truth-test a PyTensor `Variable` -- for some variables
        # (e.g. outputs of comparison ops) ``bool()`` raises ``TypeError``.
        # Any other falsy result is treated as "no replacement".
        if not isinstance(chosen_res, Variable) and not chosen_res:
            return False

        return [eval_if_etuple(chosen_res)]
