Source code for qrules.topology

# pylint: disable=too-many-lines
"""Functionality for `Topology` and `Transition` instances.

.. rubric:: Main interfaces

- `Topology` and its builder functions :func:`create_isobar_topologies` and
  :func:`create_n_body_topology`.
- `Transition` and its two implementations `MutableTransition` and `FrozenTransition`.

.. autolink-preface::

    from qrules.topology import (
        create_isobar_topologies,
        create_n_body_topology,
    )
"""

import copy
import itertools
import logging
import sys
from abc import ABC, abstractmethod
from collections import abc
from functools import total_ordering
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    FrozenSet,
    Generic,
    ItemsView,
    Iterable,
    Iterator,
    KeysView,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    ValuesView,
    overload,
)

import attrs
from attrs import define, field, frozen
from attrs.validators import deep_iterable, deep_mapping, instance_of

from qrules._implementers import implement_pretty_repr

if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol

if TYPE_CHECKING:
    from IPython.lib.pretty import PrettyPrinter

_LOGGER = logging.getLogger(__name__)


class _Comparable(Protocol):
    @abstractmethod
    def __lt__(self, other: Any) -> bool:
        ...


KT = TypeVar("KT", bound=_Comparable)
VT = TypeVar("VT")


[docs]@total_ordering class FrozenDict( # pylint: disable=too-many-ancestors abc.Hashable, abc.Mapping, Generic[KT, VT] ): """An **immutable** and **hashable** version of a `dict`. `FrozenDict` makes it possible to make classes hashable if they are decorated with :func:`attr.frozen` and contain `~typing.Mapping`-like attributes. If these attributes were to be implemented with a normal `dict`, the instance is strictly speaking still mutable (even if those attributes are a `property`) and the class is therefore not safely hashable. .. warning:: The keys have to be comparable, that is, they need to have a :meth:`~object.__lt__` method. """ def __init__(self, mapping: Optional[Mapping] = None): self.__mapping: Dict[KT, VT] = {} if mapping is not None: self.__mapping = dict(mapping) self.__hash = hash(None) if len(self.__mapping) != 0: self.__hash = 0 for key_value_pair in self.items(): self.__hash ^= hash(key_value_pair) def __repr__(self) -> str: return f"{type(self).__name__}({self.__mapping})" def _repr_pretty_(self, p: "PrettyPrinter", cycle: bool) -> None: class_name = type(self).__name__ if cycle: p.text(f"{class_name}(...)") else: with p.group(indent=2, open=f"{class_name}({{"): for key, value in self.items(): p.breakable() p.text(f"{key}: ") p.pretty(value) # type: ignore[attr-defined] p.text(",") p.breakable() p.text("})") def __iter__(self) -> Iterator[KT]: return iter(self.__mapping) def __len__(self) -> int: return len(self.__mapping) def __getitem__(self, key: KT) -> VT: return self.__mapping[key] def __gt__(self, other: Any) -> bool: if isinstance(other, abc.Mapping): sorted_self = _convert_mapping_to_sorted_tuple(self) sorted_other = _convert_mapping_to_sorted_tuple(other) return sorted_self > sorted_other raise NotImplementedError( f"Can only compare {type(self).__name__} with a mapping," f" not with {type(other).__name__}" ) def __hash__(self) -> int: return self.__hash def keys(self) -> KeysView[KT]: return self.__mapping.keys() def items(self) -> ItemsView[KT, VT]: return self.__mapping.items() def values(self) -> ValuesView[VT]: return self.__mapping.values()
def _convert_mapping_to_sorted_tuple( mapping: Mapping[KT, VT], ) -> Tuple[Tuple[KT, VT], ...]: return tuple((key, mapping[key]) for key in sorted(mapping.keys())) def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: if optional_int is None: return None return int(optional_int)
[docs]@frozen(order=True) class Edge: """Struct-like definition of an edge, used in `Topology.edges`.""" originating_node_id: Optional[int] = field(default=None, converter=_to_optional_int) """Node ID where the `Edge` **starts**. An `Edge` is **incoming to** a `Topology` if its `originating_node_id` is `None` (see `~Topology.incoming_edge_ids`). """ ending_node_id: Optional[int] = field(default=None, converter=_to_optional_int) """Node ID where the `Edge` **ends**. An `Edge` is **outgoing from** a `Topology` if its `ending_node_id` is `None` (see `~Topology.outgoing_edge_ids`). """
[docs] def get_connected_nodes(self) -> Set[int]: """Get all node IDs to which the `Edge` is connected.""" connected_nodes = {self.ending_node_id, self.originating_node_id} connected_nodes.discard(None) return connected_nodes # type: ignore[return-value]
def _to_topology_nodes(inst: Iterable[int]) -> FrozenSet[int]: return frozenset(inst) def _to_topology_edges(inst: Mapping[int, Edge]) -> FrozenDict[int, Edge]: return FrozenDict(inst)
[docs]@implement_pretty_repr @frozen(order=True) class Topology: # noqa: D416 """Directed Feynman-like graph without edge or node properties. A `Topology` is **directed** in the sense that its edges are ingoing and outgoing to specific nodes. This is to mimic Feynman graphs, which have a time axis. Note that a `Topology` is not strictly speaking a graph from graph theory, because it allows open edges, like a Feynman-diagram. The edges and nodes can be provided with properties with a `Transition`, which contains a `~Transition.topology`. As opposed to a `MutableTopology`, a `Topology` is frozen, hashable, and ordered, so that it can be used as a kind of fingerprint for a `Transition`. In addition, the IDs of `edges` are guaranteed to be sequential integers and follow a specific pattern: - `incoming_edge_ids` (`~Transition.initial_states`) are always negative. - `outgoing_edge_ids` (`~Transition.final_states`) lie in the range :code:`0...n-1` with :code:`n` the number of final states. - `intermediate_edge_ids` continue counting from :code:`n`. See also :meth:`MutableTopology.organize_edge_ids`. Example ------- **Isobar decay** topologies can best be created as follows: >>> topologies = create_isobar_topologies(number_of_final_states=3) >>> len(topologies) 1 >>> topologies[0] Topology(nodes=..., edges=...) """ nodes: FrozenSet[int] = field( converter=_to_topology_nodes, validator=deep_iterable(member_validator=instance_of(int)), ) """A node is a point where different `edges` connect.""" edges: FrozenDict[int, Edge] = field( converter=_to_topology_edges, validator=deep_mapping( key_validator=instance_of(int), value_validator=instance_of(Edge) ), ) """Mapping of edge IDs to their corresponding `Edge` definition.""" incoming_edge_ids: FrozenSet[int] = field(init=False, repr=False) """Edge IDs of edges that have no `~Edge.originating_node_id`. `Transition.initial_states` provide properties for these edges. """ outgoing_edge_ids: FrozenSet[int] = field(init=False, repr=False) """Edge IDs of edges that have no `~Edge.ending_node_id`. `Transition.final_states` provide properties for these edges. """ intermediate_edge_ids: FrozenSet[int] = field(init=False, repr=False) """Edge IDs of edges that connect two `nodes`.""" def __attrs_post_init__(self) -> None: self.__verify() incoming = sorted( edge_id for edge_id, edge in self.edges.items() if edge.originating_node_id is None ) outgoing = sorted( edge_id for edge_id, edge in self.edges.items() if edge.ending_node_id is None ) inter = sorted(set(self.edges) - set(incoming) - set(outgoing)) expected = list(range(-len(incoming), 0)) if sorted(incoming) != expected: raise ValueError(f"Incoming edge IDs should be {expected}, not {incoming}.") n_out = len(outgoing) expected = list(range(0, n_out)) if sorted(outgoing) != expected: raise ValueError(f"Outgoing edge IDs should be {expected}, not {outgoing}.") expected = list(range(n_out, n_out + len(inter))) if sorted(inter) != expected: raise ValueError(f"Intermediate edge IDs should be {expected}.") object.__setattr__(self, "incoming_edge_ids", frozenset(incoming)) object.__setattr__(self, "outgoing_edge_ids", frozenset(outgoing)) object.__setattr__(self, "intermediate_edge_ids", frozenset(inter)) def __verify(self) -> None: """Verify if there are no dangling edges or nodes.""" for edge_id, edge in self.edges.items(): connected_nodes = edge.get_connected_nodes() if not connected_nodes: raise ValueError( f"Edge nr. {edge_id} is not connected to any other node ({edge})" ) if not connected_nodes <= self.nodes: raise ValueError( f"{edge} (ID: {edge_id}) has non-existing node IDs.\n" f"Available node IDs: {self.nodes}" ) self.__check_isolated_nodes() def __check_isolated_nodes(self) -> None: if len(self.nodes) < 2: return for node_id in self.nodes: surrounding_nodes = self.__get_surrounding_nodes(node_id) if not surrounding_nodes: raise ValueError(f"Node {node_id} is not connected to any other node") def __get_surrounding_nodes(self, node_id: int) -> Set[int]: surrounding_nodes = set() for edge in self.edges.values(): connected_nodes = edge.get_connected_nodes() if node_id in connected_nodes: surrounding_nodes |= connected_nodes surrounding_nodes.discard(node_id) return surrounding_nodes
[docs] def is_isomorphic(self, other: "Topology") -> bool: """Check if two graphs are isomorphic. Returns `True` if the two graphs have a one-to-one mapping of the node IDs and edge IDs. .. warning:: Not yet implemented. """ raise NotImplementedError
[docs] def get_edge_ids_ingoing_to_node(self, node_id: int) -> Set[int]: return { edge_id for edge_id, edge in self.edges.items() if edge.ending_node_id == node_id }
[docs] def get_edge_ids_outgoing_from_node(self, node_id: int) -> Set[int]: return { edge_id for edge_id, edge in self.edges.items() if edge.originating_node_id == node_id }
[docs] def get_originating_final_state_edge_ids(self, node_id: int) -> Set[int]: fs_edges = self.outgoing_edge_ids edge_ids = set() temp_edge_list = self.get_edge_ids_outgoing_from_node(node_id) while temp_edge_list: new_temp_edge_list = set() for edge_id in temp_edge_list: if edge_id in fs_edges: edge_ids.add(edge_id) else: new_node_id = self.edges[edge_id].ending_node_id if new_node_id is not None: new_temp_edge_list.update( self.get_edge_ids_outgoing_from_node(new_node_id) ) temp_edge_list = new_temp_edge_list return edge_ids
[docs] def get_originating_initial_state_edge_ids(self, node_id: int) -> Set[int]: is_edges = self.incoming_edge_ids edge_ids: Set[int] = set() temp_edge_list = self.get_edge_ids_ingoing_to_node(node_id) while temp_edge_list: new_temp_edge_list = set() for edge_id in temp_edge_list: if edge_id in is_edges: edge_ids.add(edge_id) else: new_node_id = self.edges[edge_id].originating_node_id if new_node_id is not None: new_temp_edge_list.update( self.get_edge_ids_ingoing_to_node(new_node_id) ) temp_edge_list = new_temp_edge_list return edge_ids
[docs] def relabel_edges(self, old_to_new: Mapping[int, int]) -> "Topology": """Create a new `Topology` with new edge IDs. This method is particularly useful when creating permutations of a `Topology`, e.g.: >>> topologies = create_isobar_topologies(3) >>> len(topologies) 1 >>> topology = topologies[0] >>> final_state_ids = topology.outgoing_edge_ids >>> permuted_topologies = { ... topology.relabel_edges(dict(zip(final_state_ids, permutation))) ... for permutation in itertools.permutations(final_state_ids) ... } >>> len(permuted_topologies) 3 """ new_to_old = {j: i for i, j in old_to_new.items()} new_edges = { old_to_new.get(i, new_to_old.get(i, i)): edge for i, edge in self.edges.items() } return attrs.evolve(self, edges=new_edges)
[docs] def swap_edges(self, edge_id1: int, edge_id2: int) -> "Topology": return self.relabel_edges({edge_id1: edge_id2, edge_id2: edge_id1})
[docs]def get_originating_node_list(topology: Topology, edge_ids: Iterable[int]) -> List[int]: """Get list of node ids from which the supplied edges originate from. Args: topology: The `Topology` on which to perform the search. edge_ids ([int]): A list of edge ids for which the origin node is searched for. """ def __get_originating_node(edge_id: int) -> Optional[int]: return topology.edges[edge_id].originating_node_id return [node_id for node_id in map(__get_originating_node, edge_ids) if node_id]
def _to_mutable_topology_nodes(inst: Iterable[int]) -> Set[int]: return set(inst) def _to_mutable_topology_edges(inst: Mapping[int, Edge]) -> Dict[int, Edge]: return dict(inst)
[docs]@define class MutableTopology: """Mutable version of a `Topology`. A `MutableTopology` can be used to conveniently build up a `Topology` (see e.g. `SimpleStateTransitionTopologyBuilder`). It does not have restrictions on the numbering of edge and node IDs. """ nodes: Set[int] = field( converter=_to_mutable_topology_nodes, factory=set, on_setattr=deep_iterable(member_validator=instance_of(int)), ) """See `Topology.nodes`.""" edges: Dict[int, Edge] = field( converter=_to_mutable_topology_edges, factory=dict, on_setattr=deep_mapping( key_validator=instance_of(int), value_validator=instance_of(Edge) ), ) """See `Topology.edges`."""
[docs] def add_node(self, node_id: int) -> None: """Adds a node with number :code:`node_id`. Raises: ValueError: if :code:`node_id` already exists in `nodes`. """ if node_id in self.nodes: raise ValueError(f"Node nr. {node_id} already exists") self.nodes.add(node_id)
[docs] def add_edges(self, edge_ids: Iterable[int]) -> None: """Add edges with the ids in the :code:`edge_ids` list. Raises: ValueError: if :code:`edge_ids` already exist in `edges`. """ for edge_id in edge_ids: if edge_id in self.edges: raise ValueError(f"Edge nr. {edge_id} already exists") self.edges[edge_id] = Edge()
[docs] def attach_edges_to_node_ingoing( self, ingoing_edge_ids: Iterable[int], node_id: int ) -> None: """Attach existing edges to nodes. So that the are ingoing to these nodes. Args: ingoing_edge_ids ([int]): list of edge ids, that will be attached node_id (int): id of the node to which the edges will be attached Raises: ValueError: if an edge not doesn't exist. ValueError: if an edge ID is already an ingoing node. """ # first check if the ingoing edges are all available for edge_id in ingoing_edge_ids: if edge_id not in self.edges: raise ValueError(f"Edge nr. {edge_id} does not exist") if self.edges[edge_id].ending_node_id is not None: raise ValueError( f"Edge nr. {edge_id} is already ingoing to" f" node {self.edges[edge_id].ending_node_id}" ) # update the newly connected edges for edge_id in ingoing_edge_ids: edge = self.edges[edge_id] self.edges[edge_id] = Edge( ending_node_id=node_id, originating_node_id=edge.originating_node_id, )
[docs] def attach_edges_to_node_outgoing( self, outgoing_edge_ids: Iterable[int], node_id: int ) -> None: # first check if the ingoing edges are all available for edge_id in outgoing_edge_ids: if edge_id not in self.edges: raise ValueError(f"Edge nr. {edge_id} does not exist") if self.edges[edge_id].originating_node_id is not None: raise ValueError( f"Edge nr. {edge_id} is already outgoing from" f" node {self.edges[edge_id].originating_node_id}" ) # update the edges for edge_id in outgoing_edge_ids: edge = self.edges[edge_id] self.edges[edge_id] = Edge( ending_node_id=edge.ending_node_id, originating_node_id=node_id, )
[docs] def organize_edge_ids(self) -> "MutableTopology": """Organize edge IDS so that they lie in range :code:`[-m, n+i]`. Here, :code:`m` is the number of `.incoming_edge_ids`, :code:`n` is the number of `.outgoing_edge_ids`, and :code:`i` is the number of `.intermediate_edge_ids`. In other words, relabel the edges so that: - incoming edge IDs lie in the range :code:`[-1, -2, ...]`, - outgoing edge IDs lie in the range :code:`[0, 1, ..., n]`, - intermediate edge IDs lie in the range :code:`[n+1, n+2, ...]`. """ incoming = { i for i, edge in self.edges.items() if edge.originating_node_id is None } outgoing = { edge_id for edge_id, edge in self.edges.items() if edge.ending_node_id is None } intermediate = set(self.edges) - incoming - outgoing new_to_old_id = enumerate( list(incoming) + list(outgoing) + list(intermediate), start=-len(incoming), ) old_to_new_id = {j: i for i, j in new_to_old_id} new_edges = {old_to_new_id.get(i, i): edge for i, edge in self.edges.items()} return attrs.evolve(self, edges=new_edges)
[docs] def freeze(self) -> Topology: """Create an immutable `Topology` from this `MutableTopology`. You may need to call :meth:`organize_edge_ids` first. """ return Topology(self.nodes, self.edges)
[docs]@define class InteractionNode: """Helper class for the `.SimpleStateTransitionTopologyBuilder`.""" number_of_ingoing_edges: int = field(validator=instance_of(int)) number_of_outgoing_edges: int = field(validator=instance_of(int)) def __attrs_post_init__(self) -> None: if self.number_of_ingoing_edges < 1: raise ValueError("Number of incoming edges has to be larger than 0") if self.number_of_outgoing_edges < 1: raise ValueError("Number of outgoing edges has to be larger than 0")
[docs]class SimpleStateTransitionTopologyBuilder: """Simple topology builder. Recursively tries to add the interaction nodes to available open end edges/lines in all combinations until the number of open end lines matches the final state lines. """ def __init__(self, interaction_node_set: Iterable[InteractionNode]) -> None: if not isinstance(interaction_node_set, list): raise TypeError("interaction_node_set must be a list") self.interaction_node_set: List[InteractionNode] = list(interaction_node_set)
[docs] def build( self, number_of_initial_edges: int, number_of_final_edges: int ) -> Tuple[Topology, ...]: number_of_initial_edges = int(number_of_initial_edges) number_of_final_edges = int(number_of_final_edges) if number_of_initial_edges < 1: raise ValueError("number_of_initial_edges has to be larger than 0") if number_of_final_edges < 1: raise ValueError("number_of_final_edges has to be larger than 0") _LOGGER.info("building topology graphs...") # result list graph_tuple_list: List[Tuple[MutableTopology, List[int]]] = [] # create seed graph seed_graph = MutableTopology() current_open_end_edges = list(range(number_of_initial_edges)) seed_graph.add_edges(current_open_end_edges) extendable_graph_list = [(seed_graph, current_open_end_edges)] while extendable_graph_list: active_graph_list = extendable_graph_list extendable_graph_list = [] for active_graph in active_graph_list: # check if finished if ( len(active_graph[1]) == number_of_final_edges and len(active_graph[0].nodes) > 0 ): graph_tuple_list.append(active_graph) continue extendable_graph_list.extend(self._extend_graph(active_graph)) _LOGGER.info("finished building topology graphs...") # strip the current open end edges list from the result graph tuples topologies = [] for graph_tuple in graph_tuple_list: topology = graph_tuple[0] topology = topology.organize_edge_ids() topologies.append(topology.freeze()) return tuple(topologies)
def _extend_graph( self, pair: Tuple[MutableTopology, Sequence[int]] ) -> List[Tuple[MutableTopology, List[int]]]: extended_graph_list: List[Tuple[MutableTopology, List[int]]] = [] topology, current_open_end_edges = pair # Try to extend the graph with interaction nodes # that have equal or less ingoing lines than active lines for interaction_node in self.interaction_node_set: if interaction_node.number_of_ingoing_edges <= len(current_open_end_edges): # make all combinations combis = list( itertools.combinations( current_open_end_edges, interaction_node.number_of_ingoing_edges, ) ) # remove all combinations that originate from the same nodes for comb1, comb2 in itertools.combinations(combis, 2): if get_originating_node_list( topology, comb1 # type: ignore[arg-type] ) == get_originating_node_list( topology, comb2 # type: ignore[arg-type] ): # type: ignore[arg-type] combis.remove(comb2) for combi in combis: new_graph = _attach_node_to_edges(pair, interaction_node, combi) extended_graph_list.append(new_graph) return extended_graph_list
[docs]def create_isobar_topologies( number_of_final_states: int, ) -> Tuple[Topology, ...]: """Builder function to create a set of unique isobar decay topologies. Args: number_of_final_states: The number of `~Topology.outgoing_edge_ids` (`~.Transition.final_states`). Returns: A sorted `tuple` of non-isomorphic `Topology` instances, all with the same number of final states. Example: >>> topologies = create_isobar_topologies(number_of_final_states=4) >>> len(topologies) 2 >>> len(topologies[0].outgoing_edge_ids) 4 >>> len(set(topologies)) # hashable 2 >>> list(topologies) == sorted(topologies) # ordered True """ if number_of_final_states < 2: raise ValueError("At least two final states required for an isobar decay") builder = SimpleStateTransitionTopologyBuilder([InteractionNode(1, 2)]) topologies = builder.build( number_of_initial_edges=1, number_of_final_edges=number_of_final_states, ) return tuple(sorted(topologies))
[docs]def create_n_body_topology( number_of_initial_states: int, number_of_final_states: int ) -> Topology: """Create a `Topology` that connects all edges through a single node. These types of ":math:`n`-body topologies" are particularly important for :func:`.check_reaction_violations` and :mod:`.conservation_rules`. Args: number_of_initial_states: The number of `~Topology.incoming_edge_ids` (`~.Transition.initial_states`). number_of_final_states: The number of `~Topology.outgoing_edge_ids` (`~.Transition.final_states`). Example: >>> topology = create_n_body_topology( ... number_of_initial_states=2, ... number_of_final_states=5, ... ) >>> topology Topology(nodes=..., edges...) >>> len(topology.nodes) 1 >>> len(topology.incoming_edge_ids) 2 >>> len(topology.outgoing_edge_ids) 5 """ n_in = number_of_initial_states n_out = number_of_final_states builder = SimpleStateTransitionTopologyBuilder( [ InteractionNode( number_of_ingoing_edges=n_in, number_of_outgoing_edges=n_out, ) ] ) topologies = builder.build( number_of_initial_edges=n_in, number_of_final_edges=n_out, ) decay_name = f"{n_in} to {n_out}" if len(topologies) == 0: raise ValueError(f"Could not create n-body decay for {decay_name}") if len(topologies) > 1: raise RuntimeError(f"Several n-body decays for {decay_name}") topology = next(iter(topologies)) return topology
def _attach_node_to_edges( graph: Tuple[MutableTopology, Sequence[int]], interaction_node: InteractionNode, ingoing_edge_ids: Iterable[int], ) -> Tuple[MutableTopology, List[int]]: temp_graph = copy.deepcopy(graph[0]) new_open_end_lines = list(copy.deepcopy(graph[1])) # add node new_node_id = len(temp_graph.nodes) temp_graph.add_node(new_node_id) # attach the edges to the node temp_graph.attach_edges_to_node_ingoing(ingoing_edge_ids, new_node_id) # update the newly connected edges for edge_id in ingoing_edge_ids: new_open_end_lines.remove(edge_id) # make new edges for the outgoing lines new_edge_start_id = len(temp_graph.edges) new_edge_ids = list( range( new_edge_start_id, new_edge_start_id + interaction_node.number_of_outgoing_edges, ) ) temp_graph.add_edges(new_edge_ids) temp_graph.attach_edges_to_node_outgoing(new_edge_ids, new_node_id) for edge_id in new_edge_ids: new_open_end_lines.append(edge_id) return (temp_graph, new_open_end_lines) # pylint: disable=invalid-name EdgeType = TypeVar("EdgeType") NodeType = TypeVar("NodeType") NewEdgeType = TypeVar("NewEdgeType") NewNodeType = TypeVar("NewNodeType")
[docs]class Transition(ABC, Generic[EdgeType, NodeType]): """Mapping of edge and node properties over a `.Topology`. This **interface** class describes a transition from an initial state to a final state by providing a mapping of properties over the `~Topology.edges` and `~Topology.nodes` of its `topology`. Since a `Topology` behaves like a Feynman graph, **edges** are considered as "`states`" and **nodes** are considered as `interactions` between those states. There are two implementation classes: - `FrozenTransition`: a complete, hashable and ordered mapping of properties over the `~Topology.edges` and `~Topology.nodes` in its `~FrozenTransition.topology`. - `MutableTransition`: comparable to `MutableTopology` in that it is used internally when finding solutions through the `.StateTransitionManager` etc. These classes are also provided with **mixin** attributes `initial_states`, `final_states`, `intermediate_states`, and :meth:`filter_states`. """ @property @abstractmethod def topology(self) -> Topology: """`Topology` over which `states` and `interactions` are defined.""" @property @abstractmethod def states(self) -> Mapping[int, EdgeType]: """Mapping of properties over its `topology` `~Topology.edges`.""" @property @abstractmethod def interactions(self) -> Mapping[int, NodeType]: """Mapping of properties over its `topology` `~Topology.nodes`.""" @property def initial_states(self) -> Dict[int, EdgeType]: """Properties for the `~Topology.incoming_edge_ids`.""" return self.filter_states(self.topology.incoming_edge_ids) @property def final_states(self) -> Dict[int, EdgeType]: """Properties for the `~Topology.outgoing_edge_ids`.""" return self.filter_states(self.topology.outgoing_edge_ids) @property def intermediate_states(self) -> Dict[int, EdgeType]: """Properties for the intermediate edges (connecting two nodes).""" return self.filter_states(self.topology.intermediate_edge_ids)
[docs] def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, EdgeType]: """Filter `states` by a selection of :code:`edge_ids`.""" return {i: self.states[i] for i in edge_ids}
[docs]@implement_pretty_repr @frozen(order=True) class FrozenTransition(Transition, Generic[EdgeType, NodeType]): """Defines a frozen mapping of edge and node properties on a `Topology`.""" topology: Topology = field(validator=instance_of(Topology)) states: FrozenDict[int, EdgeType] = field(converter=FrozenDict) interactions: FrozenDict[int, NodeType] = field(converter=FrozenDict) def __attrs_post_init__(self) -> None: _assert_all_defined(self.topology.nodes, self.interactions) _assert_all_defined(self.topology.edges, self.states)
[docs] def unfreeze(self) -> "MutableTransition[EdgeType, NodeType]": """Convert into a `MutableTransition`.""" return MutableTransition(self.topology, self.states, self.interactions)
@overload def convert(self) -> "FrozenTransition[EdgeType, NodeType]": ... @overload def convert( self, state_converter: Callable[[EdgeType], NewEdgeType] ) -> "FrozenTransition[NewEdgeType, NodeType]": ... @overload def convert( self, *, interaction_converter: Callable[[NodeType], NewNodeType] ) -> "FrozenTransition[EdgeType, NewNodeType]": ... @overload def convert( self, state_converter: Callable[[EdgeType], NewEdgeType], interaction_converter: Callable[[NodeType], NewNodeType], ) -> "FrozenTransition[NewEdgeType, NewNodeType]": ...
[docs] def convert(self, state_converter=None, interaction_converter=None): # type: ignore[no-untyped-def] """Cast the edge and/or node properties to another type.""" # pylint: disable=unnecessary-lambda if state_converter is None: state_converter = _identity_function if interaction_converter is None: interaction_converter = _identity_function return FrozenTransition( self.topology, states={i: state_converter(state) for i, state in self.states.items()}, interactions={ i: interaction_converter(interaction) for i, interaction in self.interactions.items() }, )
def _identity_function(obj: Any) -> Any: return obj def _cast_states(obj: Mapping[int, EdgeType]) -> Dict[int, EdgeType]: return dict(obj) def _cast_interactions(obj: Mapping[int, NodeType]) -> Dict[int, NodeType]: return dict(obj)
[docs]@implement_pretty_repr @define class MutableTransition(Transition, Generic[EdgeType, NodeType]): """Mutable implementation of a `Transition`. Mainly used internally by the `.StateTransitionManager` to build solutions. """ topology: Topology = field(validator=instance_of(Topology)) states: Dict[int, EdgeType] = field(converter=_cast_states, factory=dict) interactions: Dict[int, NodeType] = field( converter=_cast_interactions, factory=dict )
[docs] def compare( self, other: "MutableTransition", state_comparator: Optional[Callable[[EdgeType, EdgeType], bool]] = None, interaction_comparator: Optional[Callable[[NodeType, NodeType], bool]] = None, ) -> bool: if self.topology != other.topology: return False if state_comparator is not None: for i in self.topology.edges: if not state_comparator(self.states[i], other.states[i]): return False if interaction_comparator is not None: for i in self.topology.nodes: if not interaction_comparator( self.interactions[i], other.interactions[i] ): return False return True
[docs] def swap_edges(self, edge_id1: int, edge_id2: int) -> None: self.topology = self.topology.swap_edges(edge_id1, edge_id2) value1: Optional[EdgeType] = None value2: Optional[EdgeType] = None if edge_id1 in self.states: value1 = self.states.pop(edge_id1) if edge_id2 in self.states: value2 = self.states.pop(edge_id2) if value1 is not None: self.states[edge_id2] = value1 if value2 is not None: self.states[edge_id1] = value2
[docs] def freeze(self) -> "FrozenTransition[EdgeType, NodeType]": """Convert into a `FrozenTransition`.""" return FrozenTransition(self.topology, self.states, self.interactions)
def _assert_all_defined(items: Iterable, properties: Iterable) -> None: existing = set(items) defined = set(properties) if existing & defined != existing: raise ValueError( "Some items have no property assigned to them." f" Available items: {existing}, items with property: {defined}" ) # pyright: reportUnusedFunction=false def _assert_not_overdefined(items: Iterable, properties: Iterable) -> None: existing = set(items) defined = set(properties) over_defined = defined - existing if over_defined: raise ValueError( "Properties have been defined for items that don't exist." f" Available items: {existing}, over-defined: {over_defined}" )