Source code for momba.gym.generic

# -*- coding:utf-8 -*-
#
# Copyright (C) 2019-2021, Saarland University
# Copyright (C) 2019-2021, Maximilian Köhl <koehl@cs.uni-saarland.de>

from __future__ import annotations

import dataclasses as d
import typing as t

import enum
import random
import warnings

from .. import engine, model
from ..engine import explore

from . import abstract


[docs]class Actions(enum.Enum): """Specifies how edges of the controlled automaton are chosen.""" EDGE_BY_INDEX = enum.auto() """The edge is chosen based on its index.""" EDGE_BY_LABEL = enum.auto() """The edge is chosen based on its label. """
[docs]class Observations(enum.Enum): """Specifies what is observable by the agent.""" GLOBAL_ONLY = enum.auto() """Only global variables are observable.""" LOCAL_AND_GLOBAL = enum.auto() """Local and global variables are observable.""" OMNISCIENT = enum.auto() """All (non-transient) variables are observable."""
class _ActionResolver: num_actions: int def available(self, state: engine.State[engine.DiscreteTime]) -> t.Sequence[bool]: raise NotImplementedError() def resolve( self, transitions: t.Sequence[engine.Transition[engine.DiscreteTime]], action: int, ) -> t.Sequence[engine.Transition[engine.DiscreteTime]]: raise NotImplementedError() @d.dataclass(frozen=True) class _EdgeByIndexResolver(_ActionResolver): instance: model.Instance num_actions: int @classmethod def create(cls, instance: model.Instance) -> _EdgeByIndexResolver: return cls(instance, len(instance.automaton.edges)) def available(self, state: engine.State[engine.DiscreteTime]) -> t.Sequence[bool]: available_actions: t.Set[int] = set() for transition in state.transitions: available_actions.add(transition.index_vector.get(self.instance, -1)) available_actions.discard(-1) return [action in available_actions for action in range(self.num_actions)] def resolve( self, transitions: t.Sequence[engine.Transition[engine.DiscreteTime]], action: int, ) -> t.Sequence[engine.Transition[engine.DiscreteTime]]: return [ transition for transition in transitions if transition.index_vector.get(self.instance, -1) == action ] @d.dataclass(frozen=True) class _EdgeByLabelResolver(_ActionResolver): instance: model.Instance num_actions: int action_mapping: t.Mapping[int, model.ActionType] = d.field(default_factory=dict) reverse_action_mapping: t.Mapping[model.ActionType, int] = d.field( default_factory=dict ) @classmethod def create( cls, ctx: model.Context, instance: model.Instance ) -> _EdgeByLabelResolver: action_types: t.Set[model.ActionType] = set() for edge in instance.automaton.edges: if edge.action_pattern is not None: action_types.add(edge.action_pattern.action_type) num_actions = len(action_types) action_mapping: t.Dict[int, model.ActionType] = {} # `ctx.action_types` should be ordered as in the JANI file for action_type in ctx.action_types.values(): if action_type in action_types: action_mapping[len(action_mapping)] = action_type return cls( instance, num_actions, action_mapping, reverse_action_mapping={ typ: number for number, typ in action_mapping.items() }, ) def available(self, state: engine.State[engine.DiscreteTime]) -> t.Sequence[bool]: available_actions: t.Set[int] = set() for transition in state.transitions: instance_action = transition.action_vector.get(self.instance, None) if instance_action is not None: available_actions.add( self.reverse_action_mapping[instance_action.action_type] ) available_actions.add(transition.index_vector.get(self.instance, -1)) available_actions.discard(-1) return [action in available_actions for action in range(self.num_actions)] def resolve( self, transitions: t.Sequence[engine.Transition[engine.DiscreteTime]], action: int, ) -> t.Sequence[engine.Transition[engine.DiscreteTime]]: action_type = self.action_mapping[action] result = [] for transition in transitions: instance_action = transition.action_vector.get(self.instance, None) if ( instance_action is not None and instance_action.action_type == action_type ): result.append(transition) return transitions def _create_action_resolver( actions: Actions, ctx: model.Context, instance: model.Instance ) -> _ActionResolver: if actions is Actions.EDGE_BY_INDEX: return _EdgeByIndexResolver.create(instance) else: assert actions is Actions.EDGE_BY_LABEL return _EdgeByLabelResolver.create(ctx, instance) def _count_features(value: engine.Value) -> int: if value.is_int or value.is_bool or value.is_float: return 1 else: return sum(map(_count_features, value.as_array)) def _extend_state_vector(vector: t.List[float], value: engine.Value) -> None: if value.is_array: for value in value.as_array: _extend_state_vector(vector, value) else: vector.append(float(value._value))
[docs]@d.dataclass(frozen=True) class Rewards: """Specifies the rewards for a reachability objective.""" goal_reached: float = 100 """The reward when a goal state is reached. """ dead_end: float = -100 """The reward when a dead end or bad state is reached.""" step_taken: float = 0 """The reward when a valid decision has been taken.""" invalid_action: float = -100 """The reward when an invalid decision has been taken."""
DEFAULT_REWARD_STRUCTURE = Rewards()
[docs]@d.dataclass(frozen=True) class Objective: r"""A reach-avoid objective. Represents a reach-avoid objective of the form :math:`\lnot\phi\mathbin{\mathbf{U}}\psi`, i.e., :math:`\lnot\phi` has to be true until the goal :math:`\psi` is reached. Used in conjunction with :class:`Rewards` to provide rewards. """ goal_predicate: model.Expression r"""A boolean expression for :math:`\psi`.""" dead_predicate: model.Expression r"""A boolean expression for :math:`\phi`."""
def _extract_property(prop: model.Expression) -> Objective: if isinstance(prop, model.properties.Aggregate): assert prop.function in { model.operators.AggregationFunction.MIN, model.operators.AggregationFunction.MAX, model.operators.AggregationFunction.VALUES, model.operators.AggregationFunction.EXISTS, model.operators.AggregationFunction.FORALL, }, f"Unsupported aggregation function {prop.function}" assert isinstance( prop.predicate, model.properties.StateSelector ), f"Unsupported state predicate {prop.predicate} in aggregation" assert ( prop.predicate.predicate is model.properties.StatePredicate.INITIAL ), "Unsupported state predicate for aggregation." prop = prop.values if isinstance(prop, model.properties.Probability): prop = prop.formula if isinstance(prop, model.properties.UnaryPathFormula): assert ( prop.operator is model.operators.UnaryPathOperator.EVENTUALLY ), "Unsupported unary path formula." return Objective( goal_predicate=prop.formula, dead_predicate=model.ensure_expr(False) ) elif isinstance(prop, model.properties.BinaryPathFormula): assert ( prop.operator is model.operators.BinaryPathOperator.UNTIL ), "Unsupported binary path formula." left = prop.left right = prop.right return Objective( goal_predicate=right, dead_predicate=model.expressions.logic_not(left) ) else: raise Exception("Unsupported property!") @d.dataclass(frozen=True) class _Context: explorer: engine.Explorer[engine.DiscreteTime] controlled_instance: model.Instance initial_state: engine.State[engine.DiscreteTime] objective: Objective dead_predicate: explore.CompiledGlobalExpression[engine.DiscreteTime] goal_predicate: explore.CompiledGlobalExpression[engine.DiscreteTime] rewards: Rewards actions: Actions observations: Observations action_resolver: _ActionResolver global_variables: t.Tuple[str, ...] local_variables: t.Tuple[str, ...] other_variables: t.Mapping[model.Instance, t.Tuple[str, ...]] num_features: int @classmethod def create( cls, explorer: engine.Explorer[engine.DiscreteTime], controlled_instance: model.Instance, property_name: str, *, rewards: t.Optional[Rewards] = None, actions: Actions = Actions.EDGE_BY_INDEX, observations: Observations = Observations.GLOBAL_ONLY, ) -> _Context: network = explorer.network initial_states = explorer.initial_states objective = _extract_property( network.ctx.get_property_definition_by_name(property_name).expression ) goal_predicate = explorer.compile_global_expression(objective.goal_predicate) dead_predicate = explorer.compile_global_expression(objective.dead_predicate) if len(initial_states) != 1: raise Exception("Invalid number of initial states.") (initial_state,) = initial_states num_features = 0 global_variables = [] for declaration in network.ctx.global_scope.variable_declarations: if declaration.is_transient: continue global_variables.append(declaration.identifier) value = initial_state.global_env[declaration.identifier] num_features += _count_features(value) local_variables = [] if observations in {Observations.LOCAL_AND_GLOBAL, Observations.OMNISCIENT}: for ( declaration ) in controlled_instance.automaton.scope.variable_declarations: if declaration.is_transient: continue local_variables.append(declaration.identifier) value = initial_state.get_local_env(controlled_instance)[ declaration.identifier ] num_features += _count_features(value) other_variables: t.Dict[model.Instance, t.Tuple[str, ...]] = {} if observations is Observations.OMNISCIENT: for instance in sorted( network.instances, key=lambda instance: str(instance.automaton.name) ): if instance is controlled_instance: # do not include the local variables twice continue instance_variables: t.List[str] = [] for declaration in instance.automaton.scope.variable_declarations: if declaration.is_transient: continue instance_variables.append(declaration.identifier) value = initial_state.get_local_env(instance)[ declaration.identifier ] num_features += _count_features(value) if instance_variables: other_variables[instance] = tuple(sorted(instance_variables)) return cls( explorer=explorer, controlled_instance=controlled_instance, initial_state=initial_state, objective=objective, goal_predicate=goal_predicate, dead_predicate=dead_predicate, actions=actions, observations=observations, global_variables=tuple(sorted(global_variables)), local_variables=tuple(sorted(local_variables)), other_variables=other_variables, rewards=rewards or Rewards(), action_resolver=_create_action_resolver( actions, network.ctx, controlled_instance ), num_features=num_features, ) class GenericExplorer(abstract.Explorer): _ctx: _Context state: engine.State[engine.DiscreteTime] def __init__( self, *, _ctx: _Context, _state: t.Optional[engine.State[engine.DiscreteTime]] = None, ) -> None: self._ctx = _ctx self.state = _state or _ctx.initial_state @classmethod def create( cls, explorer: engine.Explorer[engine.DiscreteTime], controlled_instance: model.Instance, property_name: str, *, rewards: t.Optional[Rewards] = None, actions: Actions = Actions.EDGE_BY_INDEX, observations: Observations = Observations.GLOBAL_ONLY, ) -> GenericExplorer: ctx = _Context.create( explorer, controlled_instance, property_name, rewards=rewards, actions=actions, observations=observations, ) return cls(_ctx=ctx, _state=ctx.initial_state) def _has_choice(self, state: engine.State[engine.DiscreteTime]) -> bool: return any( self._ctx.controlled_instance in transition.index_vector for transition in state.transitions ) def _is_final(self, state: engine.State[engine.DiscreteTime]) -> bool: return self._has_choice(state) or self._has_terminated(state) def _explore_until_choice(self) -> None: while not self._is_final(self.state): if len(self.state.transitions) > 1: warnings.warn( "Uncontrolled nondeterminism has been resolved uniformly." ) self.state = random.choice(self.state.transitions).destinations.pick().state def _explore_successors( self, state: engine.State[engine.DiscreteTime] ) -> t.Dict[engine.State[engine.DiscreteTime], float]: pending = [(1.0, state)] result: t.Dict[engine.State[engine.DiscreteTime], float] = {} while pending: probability, state = pending.pop() if self._is_final(state): if state not in result: result[state] = 0.0 result[state] += probability else: if len(state.transitions) > 1: warnings.warn( "Uncontrolled nondeterminism has been resolved uniformly." ) transition_probability = probability / len(state.transitions) for transition in state.transitions: for destination in transition.destinations.support: pending.append( ( transition_probability * float( transition.destinations.get_probability(destination) ), destination.state, ) ) return result @property def num_actions(self) -> int: return self._ctx.action_resolver.num_actions @property def num_features(self) -> int: return self._ctx.num_features def _state_vector( self, state: engine.State[engine.DiscreteTime] ) -> t.Sequence[float]: vector: t.List[float] = [] for variable in self._ctx.global_variables: value = state.global_env[variable] _extend_state_vector(vector, value) local_env = state.get_local_env(self._ctx.controlled_instance) for variable in self._ctx.local_variables: _extend_state_vector(vector, local_env[variable]) for instance, variables in self._ctx.other_variables.items(): local_env = state.get_local_env(instance) for variable in variables: _extend_state_vector(vector, local_env[variable]) return tuple(vector) @property def state_vector(self) -> t.Sequence[float]: return self._state_vector(self.state) def _has_terminated(self, state: engine.State[engine.DiscreteTime]) -> bool: is_dead = self._ctx.dead_predicate.evaluate(state).as_bool return not state.transitions or is_dead or self._has_reached_goal(state) def _has_reached_goal(self, state: engine.State[engine.DiscreteTime]) -> bool: return self._ctx.goal_predicate.evaluate(state).as_bool def _get_reward(self, state: engine.State[engine.DiscreteTime]) -> float: if self._has_reached_goal(state): return self._ctx.rewards.goal_reached elif self._has_terminated(state): return self._ctx.rewards.dead_end else: return self._ctx.rewards.step_taken @property def has_terminated(self) -> bool: return self._has_terminated(self.state) @property def has_reached_goal(self) -> bool: return self._has_reached_goal(self.state) @property def available_actions(self) -> t.Sequence[bool]: return tuple(self._ctx.action_resolver.available(self.state)) @property def available_transitions(self) -> t.Sequence[abstract.Transition]: if self.has_terminated: return [] result = [] for action, available in enumerate(self.available_actions): if not available: continue transitions = self._ctx.action_resolver.resolve( self.state.transitions, action ) if len(transitions) > 1: warnings.warn( "Uncontrolled nondeterminism has been resolved uniformly." ) transition_probability = 1.0 / len(transitions) destinations: t.Dict[t.Sequence[float], abstract.Destination] = {} for transition in transitions: for destination in transition.destinations.support: for ( final_successor, probability, ) in self._explore_successors(destination.state).items(): final_successor_vector = self._state_vector(final_successor) reward = self._get_reward(final_successor) probability *= transition_probability probability *= transition.destinations.get_probability( destination ) if final_successor_vector in destinations: probability += destinations[ final_successor_vector ].probability destinations[final_successor_vector] = abstract.Destination( final_successor_vector, reward, probability ) result.append( abstract.Transition(action, destinations=tuple(destinations.values())) ) return result def step(self, action: int) -> float: if self.has_terminated: if self.has_reached_goal: return self._ctx.rewards.goal_reached else: return self._ctx.rewards.dead_end selected_transitions = self._ctx.action_resolver.resolve( self.state.transitions, action ) if not selected_transitions: return self._ctx.rewards.invalid_action else: if len(selected_transitions) > 1: warnings.warn( "Uncontrolled nondeterminism has been resolved uniformly." ) self.state = random.choice(selected_transitions).destinations.pick().state self._explore_until_choice() return self._get_reward(self.state) def reset(self, *, explore_until_choice: bool = True) -> None: self.state = self._ctx.initial_state if explore_until_choice: self._explore_until_choice() def fork(self) -> GenericExplorer: return GenericExplorer(_ctx=self._ctx, _state=self.state)