Source code for qrules.system_control

"""Functions that steer operations of `qrules`."""

import logging
from abc import ABC, abstractmethod
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Type

import attrs

from .particle import Particle, ParticleCollection, ParticleWithSpin
from .quantum_numbers import (
    EdgeQuantumNumber,
    EdgeQuantumNumbers,
    InteractionProperties,
    NodeQuantumNumber,
    NodeQuantumNumbers,
    Parity,
)
from .settings import InteractionType
from .solving import GraphEdgePropertyMap, GraphNodePropertyMap, GraphSettings
from .topology import StateTransitionGraph

_LOGGER = logging.getLogger(__name__)

Strength = float

GraphSettingsGroups = Dict[Strength, List[Tuple[StateTransitionGraph, GraphSettings]]]


[docs]def create_edge_properties( particle: Particle, spin_projection: Optional[float] = None, ) -> GraphEdgePropertyMap: edge_qn_mapping: Dict[str, Type[EdgeQuantumNumber]] = { qn_name: qn_type for qn_name, qn_type in EdgeQuantumNumbers.__dict__.items() if not qn_name.startswith("__") } # Note using attrs.fields does not work here because init=False property_map: GraphEdgePropertyMap = {} isospin = None for qn_name, value in attrs.asdict(particle, recurse=False).items(): if isinstance(value, Parity): value = value.value if qn_name in edge_qn_mapping: property_map[edge_qn_mapping[qn_name]] = value else: if "isospin" in qn_name: isospin = value elif "spin" in qn_name: property_map[EdgeQuantumNumbers.spin_magnitude] = value if spin_projection is not None: property_map[EdgeQuantumNumbers.spin_projection] = spin_projection if isospin is not None: property_map[EdgeQuantumNumbers.isospin_magnitude] = isospin.magnitude property_map[EdgeQuantumNumbers.isospin_projection] = isospin.projection return property_map
[docs]def create_node_properties( node_props: InteractionProperties, ) -> GraphNodePropertyMap: node_qn_mapping: Dict[str, Type[NodeQuantumNumber]] = { qn_name: qn_type for qn_name, qn_type in NodeQuantumNumbers.__dict__.items() if not qn_name.startswith("__") } # Note using attrs.fields does not work here because init=False property_map: GraphNodePropertyMap = {} for qn_name, value in attrs.asdict(node_props).items(): if value is None: continue if qn_name in node_qn_mapping: property_map[node_qn_mapping[qn_name]] = value else: msg = ( "Missmatch between InteractionProperties and NodeQuantumNumbers." f" NodeQuantumNumbers does not define {qn_name}" ) raise TypeError(msg) return property_map
[docs]def find_particle( # noqa: D417 state: GraphEdgePropertyMap, particle_db: ParticleCollection ) -> ParticleWithSpin: """Create a Particle with spin projection from a qn dictionary. The implementation assumes the edge properties match the attributes of a particle inside the `.ParticleCollection`. Args: edge_props: The quantum number dictionary. particle_db: A `.ParticleCollection` which is used to retrieve a reference :code:`state` to lower the memory footprint. Raises: KeyError: If the edge properties do not contain the pid information or no particle with the same pid is found in the `.ParticleCollection`. ValueError: If the edge properties do not contain spin projection info. """ particle = particle_db.find(int(state[EdgeQuantumNumbers.pid])) spin_projection = state.get(EdgeQuantumNumbers.spin_projection) if spin_projection is None: msg = f"{GraphEdgePropertyMap.__name__} does not contain a spin projection" raise ValueError(msg) return particle, spin_projection
[docs]def create_interaction_properties( qn_solution: GraphNodePropertyMap, ) -> InteractionProperties: converted_solution = {k.__name__: v for k, v in qn_solution.items()} kw_args = { x.name: converted_solution[x.name] for x in attrs.fields(InteractionProperties) # type: ignore[arg-type] if x.name in converted_solution } return attrs.evolve(InteractionProperties(), **kw_args) # type: ignore[arg-type]
[docs]def filter_interaction_types( valid_determined_interaction_types: List[InteractionType], allowed_interaction_types: List[InteractionType], ) -> List[InteractionType]: int_type_intersection = list( set(allowed_interaction_types) & set(valid_determined_interaction_types) ) if int_type_intersection: return int_type_intersection _LOGGER.warning( ( "The specified list of interaction types %s" " does not intersect with the valid list of interaction types %s" ".\nUsing valid list instead." ), allowed_interaction_types, valid_determined_interaction_types, ) return valid_determined_interaction_types
[docs]class InteractionDeterminator(ABC): """Interface for interaction determination."""
[docs] @abstractmethod def check( self, in_edge_props: List[ParticleWithSpin], out_edge_props: List[ParticleWithSpin], node_props: InteractionProperties, ) -> List[InteractionType]: pass
[docs]class GammaCheck(InteractionDeterminator): """Conservation check for photons."""
[docs] def check( self, in_edge_props: List[ParticleWithSpin], out_edge_props: List[ParticleWithSpin], node_props: InteractionProperties, ) -> List[InteractionType]: int_types = list(InteractionType) for particle, _ in in_edge_props + out_edge_props: if "gamma" in particle.name: int_types = [InteractionType.EM] break return int_types
[docs]class LeptonCheck(InteractionDeterminator): """Conservation check lepton numbers."""
[docs] def check( self, in_edge_props: List[ParticleWithSpin], out_edge_props: List[ParticleWithSpin], node_props: InteractionProperties, ) -> List[InteractionType]: node_interaction_types = list(InteractionType) for particle, _ in in_edge_props + out_edge_props: if particle.is_lepton(): if particle.name.startswith("nu("): node_interaction_types = [InteractionType.WEAK] break node_interaction_types = [ InteractionType.EM, InteractionType.WEAK, ] return node_interaction_types
[docs]def remove_duplicate_solutions( solutions: List[StateTransitionGraph[ParticleWithSpin]], remove_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ignore_qns_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ) -> List[StateTransitionGraph[ParticleWithSpin]]: if remove_qns_list is None: remove_qns_list = set() if ignore_qns_list is None: ignore_qns_list = set() _LOGGER.info("removing duplicate solutions...") _LOGGER.info(f"removing these qns from graphs: {remove_qns_list}") _LOGGER.info(f"ignoring qns in graph comparison: {ignore_qns_list}") filtered_solutions: List[StateTransitionGraph[ParticleWithSpin]] = [] remove_counter = 0 for sol_graph in solutions: sol_graph = _remove_qns_from_graph(sol_graph, remove_qns_list) found_graph = _check_equal_ignoring_qns( sol_graph, filtered_solutions, ignore_qns_list ) if found_graph is None: filtered_solutions.append(sol_graph) else: # check if found solution also has the prefactors # if not overwrite them remove_counter += 1 _LOGGER.info(f"removed {remove_counter} solutions") return filtered_solutions
def _remove_qns_from_graph( graph: StateTransitionGraph[ParticleWithSpin], qn_list: Set[Type[NodeQuantumNumber]], ) -> StateTransitionGraph[ParticleWithSpin]: new_node_props = {} for node_id in graph.topology.nodes: node_props = graph.get_node_props(node_id) new_node_props[node_id] = attrs.evolve( node_props, **{x.__name__: None for x in qn_list} ) return graph.evolve(node_props=new_node_props) def _check_equal_ignoring_qns( ref_graph: StateTransitionGraph, solutions: List[StateTransitionGraph], ignored_qn_list: Set[Type[NodeQuantumNumber]], ) -> Optional[StateTransitionGraph]: """Define equal operator for graphs, ignoring certain quantum numbers.""" if not isinstance(ref_graph, StateTransitionGraph): msg = "Reference graph has to be of type StateTransitionGraph" raise TypeError(msg) found_graph = None node_comparator = NodePropertyComparator(ignored_qn_list) for graph in solutions: if isinstance(graph, StateTransitionGraph) and graph.compare( ref_graph, edge_comparator=lambda e1, e2: e1 == e2, node_comparator=node_comparator, ): found_graph = graph break return found_graph
[docs]class NodePropertyComparator: """Functor for comparing node properties in two graphs.""" def __init__( self, ignored_qn_list: Optional[Set[Type[NodeQuantumNumber]]] = None, ) -> None: self.__ignored_qn_list = ignored_qn_list if ignored_qn_list else set()
[docs] def __call__( self, node_props1: InteractionProperties, node_props2: InteractionProperties, ) -> bool: return attrs.evolve( node_props1, **{x.__name__: None for x in self.__ignored_qn_list}, ) == attrs.evolve( node_props2, **{x.__name__: None for x in self.__ignored_qn_list}, )
[docs]def filter_graphs( graphs: List[StateTransitionGraph], filters: Iterable[Callable[[StateTransitionGraph], bool]], ) -> List[StateTransitionGraph]: r"""Implement filtering of a list of `.StateTransitionGraph` 's. This function can be used to select a subset of `.StateTransitionGraph` 's from a list. Only the graphs passing all supplied filters will be returned. Note: For the more advanced user, lambda functions can be used as filters. Example: Selecting only the solutions, in which the :math:`\rho` decays via p-wave: .. code-block:: python my_filter = require_interaction_property( "rho", InteractionQuantumNumberNames.L, create_spin_domain([1], True), ) filtered_solutions = filter_graphs(solutions, [my_filter]) """ filtered_graphs = graphs for filter_ in filters: if not filtered_graphs: break filtered_graphs = list(filter(filter_, filtered_graphs)) return filtered_graphs
[docs]def require_interaction_property( ingoing_particle_name: str, interaction_qn: Type[NodeQuantumNumber], allowed_values: List, ) -> Callable[[StateTransitionGraph[ParticleWithSpin]], bool]: """Filter function. Closure, which can be used as a filter function in :func:`.filter_graphs`. It selects graphs based on a requirement on the property of specific interaction nodes. Args: ingoing_particle_name: name of particle, used to find nodes which have a particle with this name as "ingoing" interaction_qn: interaction quantum number allowed_values: list of allowed values, that the interaction quantum number may take Return: Callable[Any, bool]: - *True* if the graph has nodes with an ingoing particle of the given name, and the graph fullfills the quantum number requirement - *False* otherwise """ def check(graph: StateTransitionGraph[ParticleWithSpin]) -> bool: node_ids = _find_node_ids_with_ingoing_particle_name( graph, ingoing_particle_name ) if not node_ids: return False for i in node_ids: if ( getattr(graph.get_node_props(i), interaction_qn.__name__) not in allowed_values ): return False return True return check
def _find_node_ids_with_ingoing_particle_name( graph: StateTransitionGraph[ParticleWithSpin], ingoing_particle_name: str ) -> List[int]: topology = graph.topology found_node_ids = [] for node_id in topology.nodes: for edge_id in topology.get_edge_ids_ingoing_to_node(node_id): edge_props = graph.get_edge_props(edge_id) edge_particle_name = edge_props[0].name if str(ingoing_particle_name) in str(edge_particle_name): found_node_ids.append(node_id) break return found_node_ids