"""
All components in atomiq derive from the :class:`Component` class. In addition they can inherit from one or more
primitive classes, depending on the properties of the respective component. If, for example, the component represents
a device that can measure something, it will inherit from :class:`Measurable`. If on the other hand, it can be
switched, it will inherit from :class:`Switchable`.
"""
from __future__ import annotations
from artiq.experiment import kernel, delay
from artiq.language.environment import HasEnvironment
from artiq.language.types import TList, TStr, TFloat, TBool
from artiq.language.units import ms
from atomiq.hooks import hooks
import inspect
import logging
import copy
# NOTE(review): basicConfig() runs at import time and configures the *root* logger for every program that imports
# this module; libraries usually leave handler/level setup to the application.
logging.basicConfig()
# Module-level logger; records from this file are emitted under this module's name.
logger = logging.getLogger(__name__)
def add_or_append(obj, member, value):
    """Set ``obj.<member>`` to ``value``, or extend the existing attribute with ``value``.

    The primitive mixins below use this to register their channels. Because they all write the same ``channels``
    attribute, a component deriving from several of them accumulates the channels of each instead of overwriting
    them.

    Args:
        obj: Object whose attribute is set.
        member: Name of the attribute.
        value: Value to store if the attribute does not exist yet. Otherwise it is combined with the existing value
            via ``+``, so it must be compatible with that value (e.g. both lists).
    """
    if hasattr(obj, member):
        setattr(obj, member, getattr(obj, member) + value)
    else:
        setattr(obj, member, value)
class Component(HasEnvironment):
    """An atomiq Component

    Every component in atomiq inherits from this class. It provides basic functionality for automatic and recursive
    building and initialization of components (prepare, build, prerun). It also takes care of joining kernel
    invariants along the inheritance tree.

    .. Note::
       The arguments `parent` and `identifier` are automatically passed to the component object by the atomiq object
       builder.

    Args:
        parent: The parent context of the component. Usually this is the experiment that uses the component
        identifier: A unique name to identify the component.
        debug_output: Set whether the component should show debug output. Using this switch rather than the debug
                      kernel logger can allow the compiler to not include the debug commands in the kernel code if the
                      output is not needed.
    """
    kernel_invariants = {"identifier", "core", "experiment", "debug_output"}

    def __init__(self, parent, identifier: TStr, debug_output: TBool = False, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used by this base class.
        self.experiment = parent.experiment
        super().__init__(self.experiment)
        self.identifier = identifier
        self.debug_output = debug_output
        self.core = self.experiment.core
        # Combine our kernel_invariants with those declared by every class further up the MRO (including the
        # primitive mixins), so the instance ends up with the union of all of them.
        self._kernel_invariants = copy.copy(self.kernel_invariants)
        for cls in inspect.getmro(type(self))[1:]:
            if hasattr(cls, "kernel_invariants"):
                self._kernel_invariants.update(cls.kernel_invariants)
        logger.debug("%s: Joined kernel_invariants: %s", self.identifier, self._kernel_invariants)
        for hook, _ in hooks:
            setattr(self, f"_{hook}_done", False)
        self._prepare_done = False
        self._build_done = False
        self._hooks_done = []

    def _iter_subcomponents(self):
        """Yield ``(attribute_name, component)`` for every instance attribute that holds a :class:`Component`.

        The attribute names are snapshotted first. The caller typically recurses into each child between yields,
        and the snapshot prevents a "dictionary changed size during iteration" error if that recursion adds
        attributes.
        """
        for name in list(self.__dict__):
            value = getattr(self, name)
            if isinstance(value, Component):
                yield name, value

    def _recursive_prepare(self):
        """Run the prepare phase for all sub-components first, then for this component.

        Also installs the joined kernel invariants on the instance. :meth:`_prepare` is called at most once per
        component, and only if the subclass overrides it.
        """
        for _, child in self._iter_subcomponents():
            child._recursive_prepare()
        # use prepare stage to install the unified kernel invariants collected in __init__
        self.kernel_invariants = self._kernel_invariants
        if not self._prepare_done:
            if self.__class__._prepare != Component._prepare:
                logger.info("Doing prepare for %s", self.identifier)
                self._prepare()
            self._prepare_done = True

    def _prepare(self):
        """
        Specify here what should be done for this component in the prepare phase
        """
        pass

    def _recursive_build(self):
        """Run the build phase for all sub-components first, then for this component.

        :meth:`_build` is called at most once per component, and only if the subclass overrides it.
        """
        for _, child in self._iter_subcomponents():
            child._recursive_build()
        if not self._build_done:
            if self.__class__._build != Component._build:
                logger.info("Doing build for %s", self.identifier)
                self._build()
            self._build_done = True

    def _build(self):
        """
        Specify here what should be done for this component in the build phase
        """
        pass

    @kernel
    def _do_prerun(self):
        """Kernel wrapper around :meth:`_prerun`.

        Calls ``core.break_realtime()``, adds 0.5 ms of slack, logs which component is being handled and then runs
        :meth:`_prerun`.
        """
        self.experiment.core.break_realtime()
        delay(0.5*ms)
        self.experiment.log.info("Doing prerun for {0}", [self.identifier])
        self._prerun()

    def required_components(self, ancestors=None):
        """Collect this component and every component reachable through its attributes.

        Args:
            ancestors: Attribute path (list of attribute names) leading to this component. Defaults to an empty
                path. A fresh list is used when omitted, so results never share a mutable default.

        Returns:
            A list of ``(component, ancestors)`` tuples. Each child's entries precede its parent's, and the
            component on which the method was called comes last.
        """
        if ancestors is None:
            ancestors = []
        req_components = []
        for name, child in self._iter_subcomponents():
            req_components += child.required_components(ancestors + [name])
        return req_components + [(self, ancestors)]

    @kernel
    def _prerun(self):
        """
        Specify here what should be done for this component before the run starts. In contrast to the _build() method,
        the _prerun() routine is executed on the core device before the actual experiment starts.
        """
        pass
class Measurable():
    """Primitive for entities with one or more channels at which data can be measured."""
    kernel_invariants = {"channels"}

    def __init__(self, channels: TList(TStr)):
        """
        A Measurable has one or more channels at which data can be measured

        Args:
            channels: Names of the measurement channels. If ``self.channels`` already exists, they are appended to it.
        """
        add_or_append(self, "channels", channels)

    @kernel
    def measure(self, channel: TStr = "") -> TFloat:
        """Measure `channel` and return the result. Must be implemented by subclasses."""
        raise NotImplementedError("Implement measure method")

    def measurement_channels(self):
        """Return the registered channel names (``self.channels``)."""
        return self.channels
 
class Triggerable():
    """Primitive for entities with one or more channels that can be triggered."""
    kernel_invariants = {"channels"}

    def __init__(self, channels: TList(TStr)):
        """
        A Triggerable has one or more channel(s) that can be triggered

        Args:
            channels: Names of the trigger channels. If ``self.channels`` already exists, they are appended to it.
        """
        add_or_append(self, "channels", channels)

    @kernel
    def fire(self, channel: TStr = ""):
        """Trigger `channel`. Must be implemented by subclasses."""
        raise NotImplementedError("Implement fire() method")
 
class Switchable():
    """Primitive for entities with one or more channels that can be switched on or off.

    Subclasses implement :meth:`on`, :meth:`off` and :meth:`is_on`. :meth:`toggle` and :meth:`pulse` are built on
    top of those three.
    """
    kernel_invariants = {"channels"}

    def __init__(self, channels: TList(TStr)):
        """
        A Switchable has one or more channel(s) that can be switched on or off

        Args:
            channels: Names of the switchable channels. If ``self.channels`` already exists, they are appended to it.
        """
        add_or_append(self, "channels", channels)

    @kernel
    def on(self, channel: TStr = None):
        """Switch `channel` on. Must be implemented by subclasses."""
        raise NotImplementedError("Implement on() method")

    @kernel
    def off(self, channel: TStr = None):
        """Switch `channel` off. Must be implemented by subclasses."""
        raise NotImplementedError("Implement off() method")

    @kernel
    def is_on(self, channel: TStr = None):
        """Return whether `channel` is currently on. Must be implemented by subclasses."""
        raise NotImplementedError("Implement is_on() method")

    @kernel
    def toggle(self, channel: TStr = None):
        """Invert the state of `channel` using :meth:`is_on`, :meth:`on` and :meth:`off`."""
        if self.is_on(channel):
            self.off(channel)
        else:
            self.on(channel)

    @kernel
    def pulse(self, pulsetime: TFloat, channel: TStr = ""):
        """Switch `channel` on, wait `pulsetime`, then switch it off again.

        Args:
            pulsetime: Duration of the pulse, passed to ``delay``.
            channel: Channel to pulse. It is forwarded to :meth:`on` and :meth:`off`.
        """
        # Forward `channel` (as toggle() does); otherwise the argument would be silently ignored.
        self.on(channel)
        delay(pulsetime)
        self.off(channel)
 
class Parametrizable():
    """Primitive for entities that can be controlled by one or more continuous parameters."""
    kernel_invariants = {"channels"}

    def __init__(self, channels: TList(TStr)):
        """
        A Parametrizable is an entity that can be controlled by one or more continuous parameter(s)

        Args:
            channels: Names of the parameter channels. If ``self.channels`` already exists, they are appended to it.
        """
        add_or_append(self, "channels", channels)

    @kernel
    def set_parameter(self, value: TFloat, channel: TStr = None):
        """Set the parameter `channel` to `value`.

        Dispatches to a ``set_<channel>(value)`` method if the object provides one.

        Raises:
            NotImplementedError: If no channel is given, or no ``set_<channel>`` method exists.
        """
        if channel is None:
            # Without this guard, the error message below would concatenate str + None and raise TypeError instead.
            raise NotImplementedError("Implement set_parameter() method in class " + self.__class__.__name__ +
                                      " or pass a channel name")
        if hasattr(self, "set_"+channel):
            getattr(self, "set_"+channel)(value)
        else:
            raise NotImplementedError("Implement set_" + channel + "() method in class " + self.__class__.__name__ +
                                      " or an appropriate ancestor")
 
class Remote():
    """Abstract base for devices that are controlled through a non-realtime link rather than directly by ARTIQ."""

    def __init__(self, remote_reference: str, sync: bool = False):
        """
        An abstract class to represent a remote device, i.e. a device that is not directly attached to the realtime
        control system (ARTIQ) but is rather controlled through a non-realtime link (e.g. HEROS, pyon, pyro, etc.)

        Args:
            remote_reference: A reference to the remote site that can handle the component. E.g. URL, IP, port, UID..
            sync: Is synchronous operation required, i.e. do we need to wait for the response of the remote site?
        """
        self.remote_reference = remote_reference
        self.sync = sync