# pylint: disable=too-many-return-statements
"""Serialization module for `qrules`.

The `.io` module provides tools to export objects from `qrules` to disk and to
import them again, so that they can be used by external packages, or just to
store (cache) the state of the system.
"""
import json
from collections import abc
from pathlib import Path
from typing import Optional
import attr
import yaml
from qrules.particle import Particle, ParticleCollection
from qrules.topology import StateTransitionGraph, Topology
from qrules.transition import (
ProblemSet,
ReactionInfo,
State,
StateTransition,
StateTransitionCollection,
)
from . import _dict, _dot
def asdict(instance: object) -> dict:
    """Convert a supported `qrules` object to a `dict`.

    The conversion is selected from the type of :code:`instance`:
    `.Particle`, `.ParticleCollection`, `.StateTransitionGraph`, and
    `.Topology` each have their own converter, while `.ReactionInfo`,
    `.State`, `.StateTransition`, and `.StateTransitionCollection` are
    converted recursively with :func:`attr.asdict`, keeping only the
    attributes that have :code:`init` enabled.

    Args:
        instance: The object to convert.

    Returns:
        A `dict` representation of :code:`instance`.

    Raises:
        NotImplementedError: If there is no conversion for the class of
            :code:`instance`.
    """
    # pylint: disable=protected-access
    def is_init_attribute(attribute, _) -> bool:
        # Only keep attributes that have `init` enabled.
        return attribute.init

    attrs_based_classes = (
        ReactionInfo,
        State,
        StateTransition,
        StateTransitionCollection,
    )
    if isinstance(instance, Particle):
        definition = _dict.from_particle(instance)
    elif isinstance(instance, ParticleCollection):
        definition = _dict.from_particle_collection(instance)
    elif isinstance(instance, attrs_based_classes):
        definition = attr.asdict(
            instance,
            recurse=True,
            filter=is_init_attribute,
            value_serializer=_dict._value_serializer,
        )
    elif isinstance(instance, StateTransitionGraph):
        definition = _dict.from_stg(instance)
    elif isinstance(instance, Topology):
        definition = _dict.from_topology(instance)
    else:
        raise NotImplementedError(
            "No conversion for dict available for class"
            f" {instance.__class__.__name__}"
        )
    return definition
def fromdict(definition: dict) -> object:
    """Build a `qrules` object from its `dict` representation.

    The kind of object is derived from the top-level keys of
    :code:`definition`. A definition that contains at least all required
    `.Particle` fields is built as a `.Particle`; every other kind requires
    an exact match of the set of keys.

    Args:
        definition: The `dict` representation of the object.

    Returns:
        The object built by the builder that matches the keys.

    Raises:
        NotImplementedError: If the keys do not match any known object type.
    """
    available_keys = set(definition.keys())
    # A particle definition may carry optional fields in addition to the
    # required ones, so it is matched as a superset and checked first.
    if available_keys >= __REQUIRED_PARTICLE_FIELDS:
        return _dict.build_particle(definition)
    # Remaining types are identified by an exact key match, tried in order.
    builders_by_keys = (
        ({"particles"}, _dict.build_particle_collection),
        ({"transition_groups", "formalism"}, _dict.build_reaction_info),
        (
            {"topology", "states", "interactions"},
            _dict.build_state_transition,
        ),
        ({"transitions"}, _dict.build_stc),
        ({"topology", "edge_props", "node_props"}, _dict.build_stg),
        (__REQUIRED_TOPOLOGY_FIELDS, _dict.build_topology),
    )
    for expected_keys, build in builders_by_keys:
        if available_keys == expected_keys:
            return build(definition)
    raise NotImplementedError(
        f"Could not determine type from keys {available_keys}"
    )
# Names of the `Particle` attributes that have no default value. `fromdict`
# treats a `dict` that contains at least these keys as a `Particle`.
__REQUIRED_PARTICLE_FIELDS = {
    field.name
    for field in attr.fields(Particle)
    # `attr.NOTHING` is a sentinel: test identity, so that the result never
    # depends on the `__eq__` implementation of an attribute's default value.
    if field.default is attr.NOTHING
}
# Names of the `Topology` attributes that have `init` enabled. `fromdict`
# treats a `dict` with exactly these keys as a `Topology`.
__REQUIRED_TOPOLOGY_FIELDS = {
    field.name for field in attr.fields(Topology) if field.init
}
def asdot(
    instance: object,
    *,
    render_node: bool = False,
    render_final_state_id: bool = True,
    render_resonance_id: bool = False,
    render_initial_state_id: bool = False,
    strip_spin: bool = False,
    collapse_graphs: bool = False,
) -> str:
    """Convert an `object` to a DOT language `str`.

    Only works for objects that can be represented as a graph, particularly a
    `.StateTransitionGraph` or a `list` of `.StateTransitionGraph` instances.

    Args:
        instance: the input `object` that is to be rendered as DOT (graphviz)
            language.

        strip_spin: Normally, each `.StateTransitionGraph` has a `.Particle`
            with a spin projection on its edges. This option hides the
            projections, leaving only `.Particle` names on edges. Only
            forwarded when a collection of graphs is rendered.

        collapse_graphs: Group all transitions by equivalent kinematic
            topology and combine all allowed particles on each edge. Only
            forwarded when a collection of graphs is rendered.

        render_node: Whether or not to render node ID (in the case of a
            `.Topology`) and/or node properties (in the case of a
            `.StateTransitionGraph`). Meaning of the labels:

            - :math:`P`: parity prefactor
            - :math:`s`: tuple of **coupled spin** magnitude and its
              projection
            - :math:`l`: tuple of **angular momentum** and its projection

            See `.InteractionProperties` for more info.

        render_final_state_id: Add edge IDs for the final state edges.

        render_resonance_id: Add edge IDs for the intermediate state edges.

        render_initial_state_id: Add edge IDs for the initial state edges.

    Raises:
        NotImplementedError: If :code:`instance` is neither a supported graph
            type nor an iterable.

    .. seealso:: :doc:`/usage/visualize`
    """
    # Rendering flags that both the single-graph and the multi-graph
    # converter receive.
    shared_options = {
        "render_node": render_node,
        "render_final_state_id": render_final_state_id,
        "render_resonance_id": render_resonance_id,
        "render_initial_state_id": render_initial_state_id,
    }
    if isinstance(instance, StateTransition):
        instance = instance.to_graph()
    if isinstance(instance, (ProblemSet, StateTransitionGraph, Topology)):
        return _dot.graph_to_dot(instance, **shared_options)
    if isinstance(instance, (ReactionInfo, StateTransitionCollection)):
        instance = instance.to_graphs()
    if not isinstance(instance, abc.Iterable):
        raise NotImplementedError(
            f"Cannot convert a {instance.__class__.__name__} to DOT language"
        )
    return _dot.graph_list_to_dot(
        instance,
        **shared_options,
        strip_spin=strip_spin,
        collapse_graphs=collapse_graphs,
    )
def load(filename: str) -> object:
    """Load a `qrules` object from a JSON or YAML file.

    The file format is derived from the (case-insensitive) file extension.
    The parsed content is converted to an object with `fromdict`.

    Args:
        filename: Path to a :code:`.json`, :code:`.yaml`, or :code:`.yml`
            file.

    Returns:
        The object that `fromdict` builds from the content of the file.

    Raises:
        ValueError: If :code:`filename` has no file extension.
        NotImplementedError: If there is no loader for the file extension, or
            if `fromdict` cannot determine the object type.
    """
    # Decode explicitly as UTF-8 (the interchange encoding of JSON and the
    # default for YAML) instead of relying on the locale's preferred encoding,
    # which would misread non-ASCII content on some platforms.
    with open(filename, encoding="utf-8") as stream:
        file_extension = _get_file_extension(filename)
        if file_extension == "json":
            definition = json.load(stream)
            return fromdict(definition)
        if file_extension in ["yaml", "yml"]:
            # SafeLoader only constructs plain Python objects.
            definition = yaml.load(stream, Loader=yaml.SafeLoader)
            return fromdict(definition)
    raise NotImplementedError(
        f'No loader defined for file type "{file_extension}"'
    )
class _IncreasedIndent(yaml.Dumper):
    """YAML dumper with adapted indentation and extra line breaks.

    Used by `write` as the :code:`Dumper` for YAML output.
    """

    # pylint: disable=too-many-ancestors
    def increase_indent(
        self, flow: bool = False, indentless: bool = False
    ) -> None:
        """Increase the indent, always passing :code:`indentless=False`.

        The :code:`indentless` argument is deliberately ignored.
        """
        return super().increase_indent(flow=flow, indentless=False)

    def write_line_break(self, data: Optional[str] = None) -> None:
        """Write a line break, followed by a second one at indent depth one.

        See https://stackoverflow.com/a/44284819.
        """
        super().write_line_break(data)
        # NOTE(review): a single entry in `indents` presumably corresponds to
        # the top level of the document, which results in an empty line
        # between top-level entries -- confirm against the linked answer.
        has_single_indent = len(self.indents) == 1
        if has_single_indent:
            super().write_line_break()
def write(instance: object, filename: str) -> None:
    """Write a `qrules` object to a JSON, YAML, or DOT file.

    The output format is derived from the (case-insensitive) file extension:
    :code:`.json`, :code:`.yaml` / :code:`.yml`, or :code:`.gv` (DOT
    language).

    Args:
        instance: The object to write. For JSON and YAML output, it is
            converted with `asdict`. For DOT output, a `str` is written as it
            is (for instance the output of `asdot`), while any other object
            is first converted with `asdot` using its default arguments.
        filename: Path of the output file. An existing file is overwritten.

    Raises:
        ValueError: If :code:`filename` has no file extension.
        NotImplementedError: If there is no writer for the file extension, or
            if :code:`instance` cannot be converted by `asdict` or `asdot`.
    """
    # Validate the extension and serialize the object *before* opening the
    # file. Opening in "w" mode creates or truncates the file, which must not
    # happen if the extension is missing or unsupported, or if the conversion
    # of the instance fails.
    file_extension = _get_file_extension(filename)
    if file_extension == "json":
        definition = asdict(instance)
        with open(filename, "w", encoding="utf-8") as stream:
            json.dump(definition, stream, indent=2)
        return
    if file_extension in ["yaml", "yml"]:
        definition = asdict(instance)
        with open(filename, "w", encoding="utf-8") as stream:
            yaml.dump(
                definition,
                stream,
                sort_keys=False,
                Dumper=_IncreasedIndent,
                default_flow_style=False,
            )
        return
    if file_extension == "gv":
        if isinstance(instance, str):  # direct output of asdot
            output_str = instance
        else:
            output_str = asdot(instance)
        # The file is opened exactly once, with an explicit UTF-8 encoding.
        with open(filename, "w", encoding="utf-8") as stream:
            stream.write(output_str)
        return
    raise NotImplementedError(
        f'No writer defined for file type "{file_extension}"'
    )
def _get_file_extension(filename: str) -> str:
path = Path(filename)
extension = path.suffix.lower()
if not extension:
raise ValueError(f'No file extension in file name "{filename}"')
extension = extension[1:]
return extension