Skip to content

Mixins

Mixins to support the plotting and slicing of the distribution parameters. They are only required for creating new distributions or for understanding the internals.

ContinuousPlotDistMixin

Bases: PlotDistMixin

Provides the plot_pdf method for continuous distributions.

Source code in conjugate/plot.py
class ContinuousPlotDistMixin(PlotDistMixin):
    """Provide the plot_pdf method for continuous distributions."""

    def plot_pdf(self, ax: Optional[plt.Axes] = None, **kwargs) -> plt.Axes:
        """Plot the probability density function of the distribution.

        Args:
            ax: matplotlib Axes, optional
            **kwargs: Additional keyword arguments passed to matplotlib

        Returns:
            new or modified Axes

        Raises:
            ValueError: If the max_value is not set.

        """
        # The x values are built before the axis is resolved, so a missing
        # max_value raises before plt.gca() gets a chance to create a figure.
        domain = self._reshape_x_values(self._create_x_values())
        target_ax = self._settle_axis(ax=ax)
        return self._create_plot_on_axis(domain, target_ax, **kwargs)

    def _create_x_values(self) -> np.ndarray:
        # 100 evenly spaced points from min_value to max_value inclusive.
        lower, upper = self.min_value, self.max_value
        return np.linspace(lower, upper, 100)

    def _setup_labels(self, ax) -> None:
        # Polar axes are deliberately left without axis labels.
        if not isinstance(ax, plt.PolarAxes):
            ax.set_xlabel("Domain")
            ax.set_ylabel("Density $f(x)$")

    def _create_plot_on_axis(self, x, ax, **kwargs) -> plt.Axes:
        density = self.dist.pdf(x)
        # A user-supplied label is adapted to the number of plotted columns.
        label = (
            resolve_label(kwargs.pop("label"), density) if "label" in kwargs else None
        )
        ax.plot(x, density, label=label, **kwargs)
        self._setup_labels(ax=ax)
        ax.set_ylim(0, None)
        return ax

plot_pdf(ax=None, **kwargs)

Plot the pdf of distribution

Parameters:

Name Type Description Default
ax Optional[Axes]

matplotlib Axes, optional

None
**kwargs

Additional keyword arguments to pass to matplotlib

{}

Returns:

Type Description
Axes

new or modified Axes

Raises:

Type Description
ValueError

If the max_value is not set.

Source code in conjugate/plot.py
def plot_pdf(self, ax: Optional[plt.Axes] = None, **kwargs) -> plt.Axes:
    """Plot the probability density function of the distribution.

    Args:
        ax: matplotlib Axes, optional
        **kwargs: Additional keyword arguments passed to matplotlib

    Returns:
        new or modified Axes

    Raises:
        ValueError: If the max_value is not set.

    """
    # Build the x values first so a missing max_value fails before an axis
    # is fetched or created.
    domain = self._reshape_x_values(self._create_x_values())
    target_ax = self._settle_axis(ax=ax)
    return self._create_plot_on_axis(domain, target_ax, **kwargs)

DirichletPlotDistMixin

Bases: ContinuousPlotDistMixin

Plot the pdf using samples from the Dirichlet distribution.

Source code in conjugate/plot.py
class DirichletPlotDistMixin(ContinuousPlotDistMixin):
    """Plot the pdf using samples from the Dirichlet distribution."""

    def plot_pdf(
        self,
        ax: Optional[plt.Axes] = None,
        samples: int = 1_000,
        random_state=None,
        **kwargs,
    ) -> plt.Axes:
        """Plot the pdf by sampling from the distribution.

        One Gaussian KDE is fitted per component (column of the drawn
        samples) and each is drawn as its own line.

        Args:
            ax: matplotlib Axes, optional
            samples: number of samples to take from the distribution
            random_state: random state to use for sampling
            **kwargs: Additional keyword arguments passed to matplotlib

        Returns:
            new or modified Axes

        Raises:
            ValueError: If the max_value is not set.

        """
        from itertools import chain, repeat

        distribution_samples = self.dist.rvs(size=samples, random_state=random_state)

        ax = self._settle_axis(ax=ax)
        xx = self._create_x_values()

        labels = label_to_iterable(
            kwargs.pop("label", None), distribution_samples.shape[1]
        )

        # Pad labels with None so every sample column is still plotted when
        # fewer labels are given. Iteration stops at the last column, so
        # surplus labels are ignored instead of being paired with a missing
        # column (which would pass None to gaussian_kde and crash).
        padded_labels = chain(labels, repeat(None))
        for x, label in zip(distribution_samples.T, padded_labels):
            yy = gaussian_kde(x)(xx)
            ax.plot(xx, yy, label=label, **kwargs)

        self._setup_labels(ax=ax)
        # Densities are non-negative; anchor the y-axis at zero like the
        # continuous pdf plot does.
        ax.set_ylim(0, None)
        return ax

plot_pdf(ax=None, samples=1000, random_state=None, **kwargs)

Plots the pdf by sampling from the distribution.

Parameters:

Name Type Description Default
ax Optional[Axes]

matplotlib Axes, optional

None
samples int

number of samples to take from the distribution

1000
random_state

random state to use for sampling

None
**kwargs

Additional keyword arguments to pass to matplotlib

{}

Returns:

Type Description
Axes

new or modified Axes

Source code in conjugate/plot.py
def plot_pdf(
    self,
    ax: Optional[plt.Axes] = None,
    samples: int = 1_000,
    random_state=None,
    **kwargs,
) -> plt.Axes:
    """Plot the pdf by sampling from the distribution.

    One Gaussian KDE is fitted per component (column of the drawn samples)
    and each is drawn as its own line.

    Args:
        ax: matplotlib Axes, optional
        samples: number of samples to take from the distribution
        random_state: random state to use for sampling
        **kwargs: Additional keyword arguments passed to matplotlib

    Returns:
        new or modified Axes

    Raises:
        ValueError: If the max_value is not set.

    """
    from itertools import chain, repeat

    distribution_samples = self.dist.rvs(size=samples, random_state=random_state)

    ax = self._settle_axis(ax=ax)
    xx = self._create_x_values()

    labels = label_to_iterable(
        kwargs.pop("label", None), distribution_samples.shape[1]
    )

    # Pad labels with None so every sample column is still plotted when fewer
    # labels are given. Iteration stops at the last column, so surplus labels
    # are ignored instead of being paired with a missing column (which would
    # pass None to gaussian_kde and crash).
    padded_labels = chain(labels, repeat(None))
    for x, label in zip(distribution_samples.T, padded_labels):
        yy = gaussian_kde(x)(xx)
        ax.plot(xx, yy, label=label, **kwargs)

    self._setup_labels(ax=ax)
    # Densities are non-negative; anchor the y-axis at zero like the
    # continuous pdf plot does.
    ax.set_ylim(0, None)
    return ax

DiscretePlotMixin

Bases: PlotDistMixin

Adds the plot_pmf method to a class.

Source code in conjugate/plot.py
class DiscretePlotMixin(PlotDistMixin):
    """Add the plot_pmf method to a class."""

    def plot_pmf(
        self, ax: Optional[plt.Axes] = None, mark: str = "o-", **kwargs
    ) -> plt.Axes:
        """Plot the pmf of distribution

        Args:
            ax: matplotlib Axes, optional
            mark: matplotlib line style
            **kwargs: Additional keyword arguments passed to matplotlib

        Returns:
            new or modified Axes

        Raises:
            ValueError: If the max_value is not set.

        """
        x = self._create_x_values()
        x = self._reshape_x_values(x)

        ax = self._settle_axis(ax=ax)
        return self._create_plot_on_axis(x, ax, mark, **kwargs)

    def _create_x_values(self) -> np.ndarray:
        # Unit-spaced points from min_value through max_value.
        return np.arange(self.min_value, self.max_value + 1, 1)

    def _create_plot_on_axis(
        self, x, ax, mark, conditional: bool = False, **kwargs
    ) -> plt.Axes:
        """Draw the pmf evaluated at ``x`` onto ``ax``.

        Args:
            x: points at which the pmf is evaluated (possibly reshaped to a
                column by ``_reshape_x_values``)
            ax: Axes to draw on
            mark: matplotlib format string
            conditional: if True, rescale the probabilities so each plotted
                column sums to one over the plotted range
            **kwargs: forwarded to ``Axes.plot``; a "label" entry is resolved
                against the number of plotted columns first

        Returns:
            the same Axes
        """
        yy = self.dist.pmf(x)
        if conditional:
            # Normalize each column separately. When the parameters are arrays,
            # the column-shaped x broadcasts to one column per distribution,
            # and a single global sum would leave every column summing to
            # less than one. For 1-D or single-column results this equals
            # np.sum(yy).
            yy = yy / np.sum(yy, axis=0, keepdims=True)
            ylabel = f"Conditional Probability $f(x|{self.min_value} \\leq x \\leq {self.max_value})$"
        else:
            ylabel = "Probability $f(x)$"

        if "label" in kwargs:
            label = kwargs.pop("label")
            label = resolve_label(label, yy)
        else:
            label = None

        ax.plot(x, yy, mark, label=label, **kwargs)

        # Label every point on short ranges; otherwise every point gets a
        # minor tick and every fifth a major tick.
        if self.max_value - self.min_value < 15:
            ax.set_xticks(x.ravel())
        else:
            ax.set_xticks(x.ravel(), minor=True)
            ax.set_xticks(x[::5].ravel())

        ax.set_xlabel("Domain")
        ax.set_ylabel(ylabel)
        ax.set_ylim(0, None)
        return ax

plot_pmf(ax=None, mark='o-', **kwargs)

Plot the pmf of distribution

Parameters:

Name Type Description Default
ax Optional[Axes]

matplotlib Axes, optional

None
mark str

matplotlib line style

'o-'
**kwargs

Additional keyword arguments to pass to matplotlib

{}

Returns:

Type Description
Axes

new or modified Axes

Raises:

Type Description
ValueError

If the max_value is not set.

Source code in conjugate/plot.py
def plot_pmf(
    self, ax: Optional[plt.Axes] = None, mark: str = "o-", **kwargs
) -> plt.Axes:
    """Plot the probability mass function of the distribution.

    Args:
        ax: matplotlib Axes, optional
        mark: matplotlib line style
        **kwargs: Additional keyword arguments passed to matplotlib

    Returns:
        new or modified Axes

    Raises:
        ValueError: If the max_value is not set.

    """
    # The x values come first so a missing max_value fails before an axis is
    # fetched or created.
    points = self._reshape_x_values(self._create_x_values())
    target_ax = self._settle_axis(ax=ax)
    return self._create_plot_on_axis(points, target_ax, mark, **kwargs)

PlotDistMixin

Base mixin that adds plotting support. Subclasses must implement the dist property, which provides the scipy distribution.

Source code in conjugate/plot.py
class PlotDistMixin:
    """Base mixin in order to support plotting. Requires the dist attribute of the scipy distribution."""

    @property
    def dist(self) -> Distribution:
        raise NotImplementedError("Implement this property in the subclass.")

    @property
    def max_value(self) -> float:
        """Upper plotting bound; raises ValueError until it has been set."""
        if hasattr(self, "_max_value"):
            return self._max_value
        raise ValueError("Set the max value before plotting.")

    @max_value.setter
    def max_value(self, value: float) -> None:
        self._max_value = value

    def set_max_value(self, value: float) -> "PlotDistMixin":
        """Set the maximum value for plotting and return ``self`` for chaining."""
        self.max_value = value
        return self

    @property
    def min_value(self) -> float:
        """Lower plotting bound; stored as 0.0 on first access if never set."""
        if hasattr(self, "_min_value"):
            return self._min_value
        self._min_value = 0.0
        return self._min_value

    @min_value.setter
    def min_value(self, value: float) -> None:
        self._min_value = value

    def set_min_value(self, value: float) -> "PlotDistMixin":
        """Set the minimum value for plotting."""
        self.min_value = value
        return self

    def set_bounds(self, lower: float, upper: float) -> "PlotDistMixin":
        """Set both the min and max values for plotting."""
        with_lower = self.set_min_value(lower)
        return with_lower.set_max_value(upper)

    def _reshape_x_values(self, x: np.ndarray) -> np.ndarray:
        """Make sure that the values are ready for plotting.

        If any dataclass field is not a plain float (e.g. an array of
        parameters), ``x`` is turned into a column so it broadcasts against
        the parameters; otherwise it is returned unchanged.
        """
        needs_column = any(
            not isinstance(value, float) for value in asdict(self).values()
        )
        return x[:, None] if needs_column else x

    def _settle_axis(self, ax: Optional[plt.Axes] = None) -> plt.Axes:
        # Fall back to the current pyplot axes when none is supplied.
        if ax is None:
            return plt.gca()
        return ax

set_bounds(lower, upper)

Set both the min and max values for plotting.

Source code in conjugate/plot.py
def set_bounds(self, lower: float, upper: float) -> "PlotDistMixin":
    """Set the lower and upper plotting bounds in one call.

    Args:
        lower: new min_value
        upper: new max_value

    Returns:
        Whatever ``set_max_value`` returns (``self`` on PlotDistMixin), so
        calls can be chained.
    """
    with_lower = self.set_min_value(lower)
    return with_lower.set_max_value(upper)

set_min_value(value)

Set the minimum value for plotting.

Source code in conjugate/plot.py
def set_min_value(self, value: float) -> "PlotDistMixin":
    """Store ``value`` as the lower plotting bound.

    Args:
        value: new min_value

    Returns:
        ``self``, so calls can be chained.
    """
    self.min_value = value
    return self

resolve_label(label, yy)

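Adapt a label to the values being plotted. The label is returned unchanged for 1-D or single-column yy; otherwise it is expanded to one label per column. See https://stackoverflow.com/questions/73662931/matplotlib-plot-a-numpy-array-as-many-lines-with-a-single-label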

Source code in conjugate/plot.py
def resolve_label(label: LABEL_INPUT, yy: np.ndarray):
    """Adapt ``label`` to the number of columns in ``yy``.

    A 1-D ``yy``, or a 2-D ``yy`` with a single column, gets the label back
    unchanged. Otherwise the label is expanded with ``label_to_iterable`` to
    one entry per column.

    See https://stackoverflow.com/questions/73662931/matplotlib-plot-a-numpy-array-as-many-lines-with-a-single-label
    """
    if yy.ndim == 1 or yy.shape[1] == 1:
        return label
    return label_to_iterable(label, yy.shape[1])

SliceMixin

Mixin that supports slicing a distribution's parameters.

Source code in conjugate/slice.py
class SliceMixin:
    """Mixin in order to slice the parameters.

    Indexing an instance (``obj[key]``) applies ``key`` to every parameter
    that supports it and builds a new instance of the same class from the
    results. Parameters that cannot be indexed (e.g. plain floats) are passed
    through unchanged. The class must be a dataclass (``asdict`` is used).
    """

    @property
    def params(self):
        """Dataclass field values keyed by field name (via ``asdict``)."""
        return asdict(self)

    def __getitem__(self, key):
        def take_slice(value, key):
            try:
                return value[key]
            except (TypeError, IndexError, KeyError):
                # Non-indexable values (a float raises TypeError, a numpy
                # scalar IndexError) or keys that do not apply to this
                # parameter leave it unchanged. Any other error is a real
                # failure and now propagates instead of being swallowed.
                return value

        new_params = {k: take_slice(value=v, key=key) for k, v in self.params.items()}

        return self.__class__(**new_params)

Comments