from __future__ import annotations

import dataclasses as d
import typing as t

import asyncio
import json
import fractions
import pathlib
import struct
import subprocess
import tempfile

from .. import jani, model

from import modest

from . import generic
from .abstract import Oracle
from .dump_nn import dump_nn

    import torch  # type: ignore

# Binary layouts shared with the connecting client. The "!" prefix selects
# network (big-endian) byte order with standard sizes; "I" is a 4-byte
# unsigned integer. HEADER is read once at the start of every connection.
HEADER = struct.Struct("!II")  # (num_features, num_actions)
DECISION = struct.Struct("!I")  # (action,)

# Bundle of options applied to a `modes` command line before it is executed.
# NOTE(review): the annotated fields with defaults suggest this was declared
# as a dataclass; no decorator is present in this copy of the file — confirm.
class ModesOptions:
    # Consulted by `apply`; the effect it has on the command is not visible here.
    max_run_length_as_end: bool = True
    # Presumably extra command-line options passed through verbatim — TODO confirm.
    additional_options: t.Sequence[str] = ()

    # Intended to adjust `command` in place (it returns None).
    def apply(self, command: t.List[str]) -> None:
        # NOTE(review): the body of this branch, and anything that followed it,
        # is missing from this copy of the file, so the class does not parse.
        # Restore the missing statements from the upstream source.
        if self.max_run_length_as_end:

# Shared instance used as the default for the `options` parameter of the
# check functions in this module.
_DEFAULT_OPTIONS = ModesOptions()

# Presumably adds the command-line options that select the action and
# observation spaces to `command` — TODO confirm against the upstream source.
# NOTE(review): the statements inside the branches below are missing from this
# copy of the file, so the function does not parse as written.
def _apply_actions_observations(
    command: t.List[str], actions: generic.Actions, observations: generic.Observations
) -> None:
    if actions is generic.Actions.EDGE_BY_LABEL:
        # NOTE(review): as written, this assert sits inside the EDGE_BY_LABEL
        # branch and could never hold; it presumably belonged to a lost `else:`.
        assert actions is generic.Actions.EDGE_BY_INDEX
    # NOTE(review): both branches below have lost their bodies; the final
    # assert presumably belonged to a lost `else:` as well.
    if observations is generic.Observations.LOCAL_AND_GLOBAL:
    elif observations is generic.Observations.OMNISCIENT:
        assert observations is generic.Observations.GLOBAL_ONLY

def _create_vector_decoder(num_features: int) -> struct.Struct:
    return struct.Struct(f"!{num_features}f")

def _create_available_decoder(num_actions: int) -> struct.Struct:
    return struct.Struct(f"!{num_actions}?")

# Stream server that reads state vectors and availability flags from a client
# using the binary layouts defined in this module.
# NOTE(review): this copy of the class is damaged and does not parse. Visible
# gaps: there is no `try:` for the `except` in `_handle_connection`, both
# `unpack(` calls are never closed, the right-hand side of `decision =` is
# truncated, the `except` body is empty, and the `asyncio.start_server(` call
# has lost an argument and its closing parenthesis. Restore from upstream.
# NOTE(review): `OracleServer(oracle)` is called elsewhere in this module, so
# the class presumably carried a dataclass decorator that is missing here.
class OracleServer:
    # The decision-making object whose answers are served.
    oracle: Oracle

    # NOTE(review): asyncio treats an empty host as "all interfaces"; the
    # original default may have been lost — confirm.
    host: str = ""
    # None until `start` overwrites it with the port actually bound.
    port: t.Optional[int] = None

    # Set by `start`; `stop` requires it to be set.
    server: t.Optional[asyncio.AbstractServer] = None

    async def _handle_connection(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        address = writer.get_extra_info("peername")
        print(f"New connection from {address}.")
        # The client first announces the sizes of the vectors it will send.
        num_features, num_actions = HEADER.unpack(await reader.readexactly(HEADER.size))
        print(f"Features: {num_features}, Actions: {num_actions}")
        state_decoder = _create_vector_decoder(num_features)
        available_decoder = _create_available_decoder(num_actions)
            # Each iteration reads one state vector followed by one set of
            # availability flags.
            while True:
                state = state_decoder.unpack(
                    await reader.readexactly(state_decoder.size)
                available = available_decoder.unpack(
                    await reader.readexactly(available_decoder.size)
                # NOTE(review): truncated — presumably the oracle is called
                # here and its decision written to `writer` before the drain.
                decision =, available)
                await writer.drain()
        # `readexactly` raises this when the peer closes the stream mid-read.
        except asyncio.IncompleteReadError:

    async def start(self) -> None:
        self.server = await asyncio.start_server(
            self._handle_connection,, self.port
        assert self.server.sockets is not None
        # Record the port of the first listening socket.
        self.port = self.server.sockets[0].getsockname()[1]
        await self.server.start_serving()

    async def stop(self) -> None:
        assert self.server is not None
        # NOTE(review): no `close()` call is visible before this wait; it may
        # have been lost along with the other missing lines — confirm.
        await self.server.wait_closed()

# Shared empty mapping used as the default for the `parameters` argument.
# Every call that omits `parameters` receives this same object, so it must
# never be mutated.
_NO_PARAMETERS: t.Dict[str, t.Any] = {}

def _load_result(output_file: pathlib.Path) -> t.Mapping[str, fractions.Fraction]:
    result = json.loads(output_file.read_text(encoding="utf-8-sig"))
    return {
        dataset["property"]: fractions.Fraction(dataset["value"])
        for dataset in result["data"]
        if "data" in dataset
        for dataset in dataset["data"]
        if "property" in dataset

# Starts an OracleServer for `oracle`, runs the toolset executable as a
# subprocess, and waits for that subprocess to exit.
# NOTE(review): this copy is damaged and does not parse — the `arguments` list
# literal has lost its contents and closing bracket, and the
# `create_subprocess_exec(` call is never closed. `model_path`,
# `automaton_name`, `output_file` and `options` are not used in the visible
# lines; presumably the missing lines consumed them. Restore from upstream.
async def _check_oracle(
    model_path: pathlib.Path,
    automaton_name: str,
    oracle: Oracle,
    output_file: pathlib.Path,
    parameters: t.Mapping[str, t.Any],
    toolset: modest.Toolset,
    options: ModesOptions,
    actions: generic.Actions,
    observations: generic.Observations,
) -> None:
    server = OracleServer(oracle)
    await server.start()
    # `server.port` holds the port chosen when the server was started.
    address = f"localhost:{server.port}"
    print(f"Address: {address}")
    arguments: t.List[str] = [
    # Each model parameter is passed as `-E name=value`.
    for param_name, param_value in parameters.items():
        arguments.extend(["-E", f"{param_name}={param_value}"])
    _apply_actions_observations(arguments, actions, observations)
    process = await asyncio.subprocess.create_subprocess_exec(
        toolset.executable, *arguments
    # NOTE(review): the exit status is not checked and no `server.stop()` call
    # is visible after the wait — confirm against the upstream source.
    await process.wait()

# NOTE(review): this copy is damaged and does not parse — the assert below has
# lost its subject and the `_check_oracle(` call has lost its arguments and
# closing parenthesis. Restore from the upstream source.
async def check_oracle_async(
    network: model.Network,
    instance: model.Instance,
    oracle: Oracle,
    parameters: t.Mapping[str, t.Any] = _NO_PARAMETERS,
    toolset: modest.Toolset = modest.toolset,
    options: ModesOptions = _DEFAULT_OPTIONS,
    actions: generic.Actions = generic.Actions.EDGE_BY_INDEX,
    observations: generic.Observations = generic.Observations.GLOBAL_ONLY,
) -> t.Mapping[str, fractions.Fraction]:
    """Checks an arbitrary Python function implementing a decision agent."""
    # The model and the result file live in a temporary directory that is
    # removed when the block exits, so the result is loaded inside it.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)
        model_path = temp_path / "model.jani"
        model_path.write_text(jani.dump_model(network), encoding="utf-8")
        output_file = temp_path / "output.json"
        # NOTE(review): presumably asserted that the automaton of `instance`
        # has a name before passing it on — TODO confirm.
        assert is not None
        await _check_oracle(
        return _load_result(output_file)

def check_oracle(
    network: model.Network,
    instance: model.Instance,
    oracle: Oracle,
    *,
    parameters: t.Mapping[str, t.Any] = _NO_PARAMETERS,
    toolset: modest.Toolset = modest.toolset,
    options: ModesOptions = _DEFAULT_OPTIONS,
    actions: generic.Actions = generic.Actions.EDGE_BY_INDEX,
    observations: generic.Observations = generic.Observations.GLOBAL_ONLY,
) -> t.Mapping[str, fractions.Fraction]:
    """
    Checks an arbitrary Python function.

    This is the blocking counterpart of :func:`check_oracle_async`. It runs
    the coroutine to completion on a fresh event loop and therefore cannot
    be called from code that is already running inside an event loop.

    Arguments:
        network: A JANI automaton network.
        instance: An instance of an automaton in the provided network. The
            decision-making agent trained on the resulting environment is
            assumed to act by resolving the non-determinism in this automaton.
        oracle: The Python function to check.
        parameters: Allows defining values for parameters of the JANI model.
        toolset: Specifies the location of the Modest Toolset.
        options: Additional options for the check.
        actions: Specifies the action space for the environment.
        observations: Specifies the observation space for the environment.

    Returns:
        A mapping from property names to their values as fractions.
    """
    # `asyncio.run` replaces `get_event_loop().run_until_complete(...)`:
    # obtaining a loop implicitly is deprecated when none is running.
    return asyncio.run(
        check_oracle_async(
            network,
            instance,
            oracle,
            parameters=parameters,
            toolset=toolset,
            options=options,
            actions=actions,
            observations=observations,
        )
    )
def check_nn(
    network: model.Network,
    instance: model.Instance,
    nn: torch.nn.Module,
    *,
    parameters: t.Mapping[str, t.Any] = _NO_PARAMETERS,
    toolset: modest.Toolset = modest.toolset,
    options: ModesOptions = _DEFAULT_OPTIONS,
    actions: generic.Actions = generic.Actions.EDGE_BY_INDEX,
    observations: generic.Observations = generic.Observations.GLOBAL_ONLY,
) -> t.Mapping[str, fractions.Fraction]:
    """
    Checks a PyTorch neural network.

    The model and the network are written to a temporary directory, the
    ``modes`` tool of the toolset is run on them, and the JSON result file
    it produces is loaded.

    Arguments:
        network: A JANI automaton network.
        instance: An instance of an automaton in the provided network. The
            decision-making agent trained on the resulting environment is
            assumed to act by resolving the non-determinism in this automaton.
        nn: The PyTorch neural network.
        parameters: Allows defining values for parameters of the JANI model.
        toolset: Specifies the location of the Modest Toolset.
        options: Additional options applied to the command line.
        actions: Specifies the action space for the environment.
        observations: Specifies the observation space for the environment.

    Returns:
        A mapping from property names to their values as fractions.

    Raises:
        subprocess.CalledProcessError: If the tool exits with a non-zero
            status.
    """
    # Everything lives in a temporary directory that is removed on exit, so
    # the result has to be loaded before the block ends.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = pathlib.Path(temp_dir)
        model_path = temp_path / "model.jani"
        model_path.write_text(jani.dump_model(network), encoding="utf-8")
        nn_path = temp_path / "nn.json"
        # Explicit encoding, consistent with how the model file is written.
        nn_path.write_text(dump_nn(nn), encoding="utf-8")
        output_file = temp_path / "output.json"
        command: t.List[str] = [
            str(toolset.executable),
            "modes",
            str(model_path),
            "-O",
            str(output_file),
            "json",
            "-R",
            "NN",
            "-NN",
            str(nn_path),
            "-A",
            # NOTE(review): the argument of this `str(...)` was lost in this
            # copy of the file. The name of the automaton of `instance` is a
            # reconstruction — confirm against the upstream source.
            str(instance.automaton.name),
            "--threads",
            "1",
        ]
        # Each model parameter is passed as `-E name=value`.
        for param_name, param_value in parameters.items():
            command.extend(["-E", f"{param_name}={param_value}"])
        options.apply(command)
        _apply_actions_observations(command, actions, observations)
        subprocess.check_call(command)
        return _load_result(output_file)