Source code for momba.utils.distribution
# -*- coding:utf-8 -*-
#
# Copyright (C) 2020, Maximilian Köhl <koehl@cs.uni-saarland.de>
from __future__ import annotations
import typing as t
import fractions
import random
from mxu.maps import FrozenMap
ElementT = t.TypeVar("ElementT", bound=t.Hashable)
[docs]
class Distribution(t.Generic[ElementT]):
"""A probability distribution."""
_mapping: FrozenMap[ElementT, fractions.Fraction]
@classmethod
def create_dirac(cls, element: ElementT) -> Distribution[ElementT]:
return cls({element: 1})
@classmethod
def create_uniform(cls, *elements: ElementT) -> Distribution[ElementT]:
probability = fractions.Fraction(1, len(elements))
return cls({element: probability for element in elements})
def __init__(
self, mapping: t.Mapping[ElementT, t.Union[int, float, fractions.Fraction]]
) -> None:
self._mapping = FrozenMap.transfer_ownership(
{
element: fractions.Fraction(probability)
for element, probability in mapping.items()
}
)
assert all(probability >= 0 for probability in self._mapping.values())
# assert sum(self._mapping.values()) == 1
def __str__(self) -> str:
return f"Distribution({self._mapping})"
@property
def support(self) -> t.List[ElementT]:
return list(
element for element, probability in self._mapping.items() if probability > 0
)
@property
def is_dirac(self) -> bool:
return len(self.support) == 1
def get_probability(self, element: ElementT) -> fractions.Fraction:
return self._mapping.get(element, fractions.Fraction(0))
[docs]
def pick(self) -> ElementT:
"""Picks an element at random according to the distribution."""
max_denominator = max(
probability.denominator for probability in self._mapping.values()
)
outcome = fractions.Fraction(
random.randint(0, max_denominator), max_denominator
)
total = fractions.Fraction(0)
for element in self.support:
total += self.get_probability(element)
if outcome <= total:
return element
raise RuntimeError("empty distribution is not possible")