# Source code for momba.gym.abstract

# -*- coding:utf-8 -*-
#
# Copyright (C) 2019-2021, Saarland University
# Copyright (C) 2019-2021, Maximilian Köhl <koehl@cs.uni-saarland.de>

from __future__ import annotations

import dataclasses as d
import typing as t

import abc


# Boolean vector with one flag per action, indicating whether that action is
# currently available (the return type of ``Explorer.available_actions``).
AvailableVector = t.Sequence[bool]
# Numeric feature vector describing a single explorer state (the return type
# of ``Explorer.state_vector``).
StateVector = t.Sequence[float]


@d.dataclass(eq=True, frozen=True)
class Destination:
    """An immutable outcome of a :class:`Transition`.

    Bundles the state vector of the destination with the reward and the
    probability attached to it.
    """

    # State vector of the destination state.
    state: StateVector
    # Reward attached to this destination.
    reward: float
    # Probability attached to this destination.
    probability: float


@d.dataclass(eq=True, frozen=True)
class Transition:
    """An immutable transition labeled with an action.

    Taking ``action`` leads to one of the listed ``destinations``.
    """

    # The action labeling this transition.
    action: int
    # The possible destinations of this transition.
    destinations: t.Sequence[Destination]


class Explorer(abc.ABC):
    """State space explorer for training decision agents.

    Subclasses must implement every abstract member below. The explorer keeps
    a current state that can be inspected (:attr:`state_vector`,
    :attr:`available_actions`, :attr:`available_transitions`), advanced with
    :meth:`step`, rewound with :meth:`reset`, and duplicated with
    :meth:`fork`.
    """

    # NOTE: every abstract member raises ``NotImplementedError`` rather than
    # returning ``None``. A subclass that accidentally delegates to ``super()``
    # therefore fails loudly instead of propagating a bogus value.

    @property
    @abc.abstractmethod
    def num_actions(self) -> int:
        """The number of *actions*."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def num_features(self) -> int:
        """The number of features of the state vector."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def has_terminated(self) -> bool:
        """Indicates whether the explorer is in a terminal state."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def state_vector(self) -> StateVector:
        """The state vector of the current explorer state."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def available_actions(self) -> AvailableVector:
        """A boolean vector indicating which actions are available."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def available_transitions(self) -> t.Sequence[Transition]:
        """A sequence of available transitions."""
        raise NotImplementedError()

    @abc.abstractmethod
    def step(self, action: int) -> float:
        """Take a step with the given action and return the reward.

        Precondition: the explorer must not be in a terminal state, i.e.
        :attr:`has_terminated` must be false.

        :param action: The action to take.
        :returns: The reward obtained by taking the step.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def reset(self) -> None:
        """Reset the explorer to the initial state."""
        raise NotImplementedError()

    @abc.abstractmethod
    def fork(self) -> Explorer:
        """Fork the explorer with the current state.

        :returns: A new explorer starting from this explorer's current state.
        """
        raise NotImplementedError()
class Oracle(t.Protocol):
    """An *oracle* selects an action based on the state and the available actions.

    This is a structural (``typing.Protocol``) type: any callable with a
    matching ``__call__`` signature satisfies it for static type checking.
    It is not decorated with ``typing.runtime_checkable``, so it cannot be
    used with ``isinstance``.
    """

    def __call__(self, state: StateVector, available: AvailableVector) -> int:
        """Select an action given the state and the action availability vector.

        :param state: The state vector of the current state.
        :param available: A boolean vector indicating which actions are available.
        :returns: The selected action.
        """
        ...