Source code for sara.viz.viz

"""Plotting function for SARA-VIZ."""

from io import BytesIO, StringIO
from pathlib import Path
from typing import Final, BinaryIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pandera.pandas as pa
from matplotlib.patches import Rectangle
from pandera.typing import DataFrame

from sara.oar import RTGSchema

MAX_LABELS: Final[int] = 2


def cleaner_for_label(x: tuple[str, str]) -> str:
    """Build a human-readable plot label from a two-level column key.

    The first two items of ``x`` are joined with ``": "``.

    Example:
        >>> cleaner_for_label(('rew0', 'money'))
        'rew0: money'
    """
    signal = x[0]
    key = x[1]
    return f"{signal}: {key}"


def df2tables(
    df: pd.DataFrame,
    col_labels: list,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Aggregate the RTG column per label group and pivot it for plotting.

    The first column of ``df`` whose first level equals ``"rtg0"`` is averaged
    and counted inside every group defined by ``col_labels``; both results are
    then pivoted into tables.

    Args:
        df: the dataframe to aggregate; its index is reset before grouping.
        col_labels: one or two column keys used as grouping labels. With two
            labels, the first one becomes the pivot columns and the second one
            the pivot index; with a single label it becomes the pivot index.

    Returns:
        ``(table_rtg, table_count)``: the pivoted mean RTG (carrying a ``name``
        attribute holding the flattened RTG label) and the pivoted sample
        count, where missing combinations are filled with 0.

    Raises:
        NotImplementedError: if the number of labels is neither 1 nor 2.
    """
    rtg_col = next(col for col in df.columns if col[0] == "rtg0")
    rtg_name = cleaner_for_label(rtg_col)
    flat_labels = [cleaner_for_label(label) for label in col_labels]

    grouped = df.reset_index().groupby(col_labels, observed=True, dropna=False)

    def _aggregate(reducer, value_name):
        # One row per group: flattened label columns followed by the value.
        frame = grouped.apply(reducer, include_groups=False)
        frame = frame.to_frame(value_name).reset_index()
        frame.columns = [*flat_labels, value_name]
        return frame

    mean_frame = _aggregate(lambda group: group[rtg_col].mean(), rtg_name)
    count_frame = _aggregate(lambda group: group[rtg_col].count(), "count")

    # Choose which label goes on the pivot index and which on the columns.
    n_labels = len(flat_labels)
    if n_labels == MAX_LABELS:
        pivot_index, pivot_columns = [flat_labels[1]], [flat_labels[0]]
    elif n_labels == 1:
        pivot_index, pivot_columns = [flat_labels[0]], None
    else:
        raise NotImplementedError

    table_rtg = mean_frame.pivot_table(
        index=pivot_index,
        columns=pivot_columns,
        values=rtg_name,
        observed=True,
        dropna=False,
    )
    table_rtg.name = rtg_name

    table_count = count_frame.pivot_table(
        index=pivot_index,
        columns=pivot_columns,
        values="count",
        observed=True,
        dropna=False,
    )

    # Pivoting can introduce NaNs for (x, y) pairs absent from the dataset:
    # their count becomes 0 whereas their RTG is deliberately left as NaN.
    return table_rtg, table_count.fillna(0)


def _finalize_plot(fig, save_path: str | Path | None = None) -> BytesIO:
    """Render ``fig`` as PNG in memory, then save or show it and close it.

    Args:
        fig: the figure to render; it must provide ``savefig``.
        save_path: if truthy, the figure is additionally written to this path
            as PNG; otherwise the figure is displayed with ``plt.show()``.

    Returns:
        A ``BytesIO`` containing the PNG image, rewound to position 0.
    """
    img_stream = BytesIO()
    try:
        fig.savefig(img_stream, format="png", bbox_inches="tight")
        img_stream.seek(0)

        if save_path:
            fig.savefig(save_path, format="png", bbox_inches="tight")
        else:
            plt.show()
    finally:
        # Release the figure even when rendering or saving fails, otherwise
        # pyplot keeps a reference to it and figures pile up across calls.
        plt.close(fig)
    return img_stream


def plot_heatmap(
    table_rtg: pd.DataFrame,
    table_count: pd.DataFrame,
    filename: str | Path | None = None,
    *,
    normalize_on_x: bool = False,
    argmax_on_x: bool = False,
    csv_prefix: str | Path | None = None,
    csv_buffers: dict | None = None,
) -> tuple[BytesIO, BytesIO | None, BytesIO | None]:
    """Plot the Q table heatmap in case of 2 labels are given.

    Args:
        table_rtg: the pivoted table of RTGs from :func:`df2tables`
        table_count: the pivoted table of counts from :func:`df2tables`
        filename: if provided, plot will be saved in this file
        normalize_on_x: whether to min-max normalize RTG values along xaxis
            (i.e. inside each row) to improve comparability of RTG inside
            this axis. Rows without spread are displayed at 0.5 and cells
            without RTG stay empty.
        argmax_on_x: whether to display a red box around tiles that realizes
            the max of RTG along xaxis. Rows holding no RTG at all get no box.
        csv_prefix: if provided, both tables are also written as CSV files
            named ``{csv_prefix}_rtg.csv`` and ``{csv_prefix}_count.csv``.
        csv_buffers: optional mapping whose ``"rtg"`` and/or ``"count"``
            entries are writable binary buffers. The matching table is
            written into each provided buffer as encoded CSV text and the
            buffer is then rewound to position 0.

    Returns:
        A tuple ``(image, rtg_buffer, count_buffer)``: the PNG image stream of
        the plot, followed by the ``"rtg"`` and ``"count"`` buffers taken from
        ``csv_buffers`` once filled (``None`` for each one not provided).
    """
    rtg_buf_out = None
    count_buf_out = None

    if csv_prefix is not None:
        table_rtg.to_csv(f"{csv_prefix}_rtg.csv")
        table_count.to_csv(f"{csv_prefix}_count.csv")

    if csv_buffers is not None:
        rtg_target = csv_buffers.get("rtg")
        count_target = csv_buffers.get("count")

        if rtg_target is not None:
            s = StringIO()
            table_rtg.to_csv(s)
            rtg_target.write(s.getvalue().encode())
            rtg_target.seek(0)
            rtg_buf_out = rtg_target

        if count_target is not None:
            s = StringIO()
            table_count.to_csv(s)
            count_target.write(s.getvalue().encode())
            count_target.seek(0)
            count_buf_out = count_target

    fig, ax = plt.subplots(figsize=(8, 6))

    data = table_rtg.to_numpy()

    # heatmap
    data_to_plot = data
    if normalize_on_x:
        mx = np.nanmax(data, axis=1, keepdims=True)
        mn = np.nanmin(data, axis=1, keepdims=True)
        rge = mx - mn
        has_range = rge > 0
        # Divide by 1 where the row has no spread to avoid 0/0 warnings;
        # those cells are overwritten just below anyway.
        safe_rge = np.where(has_range, rge, 1.0)
        # Rows without spread are shown mid-scale, but cells without any RTG
        # must stay NaN so they are not painted as if they held data.
        flat_fill = np.where(np.isnan(data), np.nan, 0.5)
        data_to_plot = np.where(has_range, (data - mn) / safe_rge, flat_fill)
    cax = ax.imshow(data_to_plot, cmap="viridis", interpolation="None", aspect="auto")

    # count
    for (iy, ix), c in np.ndenumerate(table_count.to_numpy()):
        ax.text(ix, iy, f"{c:.0f}", ha="center", va="center", color="black", fontsize=12)

    if argmax_on_x:
        for i, row in enumerate(data):
            # np.nanargmax raises on all-NaN input: nothing to highlight there.
            if np.isnan(row).all():
                continue
            j = np.nanargmax(row)
            # The anchor point for the rectangle is the bottom-left corner of
            # the cell. For a cell at (row, col) = (i, j), the corner is at
            # (j-0.5, i-0.5).
            rect = Rectangle(
                (j - 0.5, i - 0.5),  # (x,y) bottom-left corner
                1,
                1,  # width, height
                edgecolor="red",
                facecolor="none",
                lw=2,
            )
            ax.add_patch(rect)

    y_labels = table_rtg.index
    y_name = y_labels.name
    x_labels = table_rtg.columns
    x_name = x_labels.name

    # Set the ticks and labels using the labels from the pivoted table_rtg
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_xticklabels(x_labels, rotation=-30)
    ax.set_yticklabels(y_labels)

    # Add a color bar and titles; ``name`` is set by df2tables and may be
    # missing on tables built elsewhere.
    fig.colorbar(cax, label=getattr(table_rtg, "name", None))
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)

    plt.tight_layout()
    img_stream = _finalize_plot(fig, filename)
    return img_stream, rtg_buf_out, count_buf_out


def plot_histogram(
    table_rtg: pd.DataFrame,
    table_count: pd.DataFrame,
    filename: str | Path | None = None,
    *,
    normalize_on_x: bool = False,
    argmax_on_x: bool = False,
    csv_prefix: str | Path | None = None,
) -> BytesIO:
    """Plot the histogram in case 1 label is given.

    Args:
        table_rtg: the pivoted table of RTGs from :func:`df2tables`
        table_count: the pivoted table of counts from :func:`df2tables`,
            read through its ``"count"`` column
        filename: if provided, plot will be saved in this file
        normalize_on_x: whether to min-max normalize the bar heights. When
            all RTGs are equal the bars are displayed at 0.5; bars without
            RTG stay empty.
        argmax_on_x: whether to color in red the bar realizing the max RTG.
        csv_prefix: if provided, both tables are also written as CSV files
            named ``{csv_prefix}_rtg.csv`` and ``{csv_prefix}_count.csv``.

    Returns:
        The PNG image stream of the plot.
    """
    if csv_prefix is not None:
        table_rtg.to_csv(f"{csv_prefix}_rtg.csv")
        table_count.to_csv(f"{csv_prefix}_count.csv")

    fig, ax = plt.subplots(figsize=(8, 6))

    data = table_rtg.to_numpy().flatten()
    # The nan-aware reductions raise on empty or all-NaN input, so they are
    # only evaluated when there is at least one real value and when needed.
    has_values = data.size > 0 and not np.isnan(data).all()

    # histogram
    data_to_plot = data
    if normalize_on_x and has_values:
        mn = np.nanmin(data)
        rge = np.nanmax(data) - mn
        if rge > 0:
            data_to_plot = (data - mn) / rge
        else:
            # No spread: show the bars mid-scale but keep missing RTGs as NaN.
            data_to_plot = np.where(np.isnan(data), np.nan, 0.5)

    colors = None
    if argmax_on_x:
        colors = len(table_rtg) * ["C0"]
        if has_values:
            colors[np.nanargmax(data)] = "r"

    positions = range(len(table_rtg))
    ax.bar(positions, data_to_plot, color=colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(table_rtg.index)

    # count, dirty: relies on ax.patches holding exactly one patch per bar
    for patch, count in zip(ax.patches, table_count["count"], strict=True):
        # Get the coordinates for the text
        x_pos = patch.get_x() + patch.get_width() / 2
        y_pos = patch.get_height() / 2
        # Add the text to the plot
        ax.text(x_pos, y_pos, f"{count:.0f}", ha="center", va="center", fontsize=10)

    x_name = table_rtg.index.name
    y_name = table_rtg.columns[0]

    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)

    plt.tight_layout()
    return _finalize_plot(fig, filename)


@pa.check_types
def plot_insight(
    df: DataFrame[RTGSchema],
    col_labels: list,  # TODO(anthony): use a cleaner type
    filename: str | Path | None = None,
    *,
    normalize_on_x: bool = False,
    argmax_on_x: bool = False,
    csv_prefix: str | Path | None = None,
    csv_buffers: dict[str, BinaryIO] | None = None,
) -> tuple[BytesIO, BytesIO | None, BytesIO | None]:
    """Plot heatmap or histogram from OAR dataframe with return-to-go.

    Heatmap if there is 2 labels, histogram is for 1 label.
    Heatmap color or histogram height is the return-to-go averaged against
    grouping labels. The number of samples used in the average is displayed
    as a number in the tiles for the heatmap and in the bar for the histogram.

    Args:
        df: the dataframe to plot, it should validate :class:`RTGSchema`,
        col_labels: the labels used to group the dataframe before averaging,

            - first label is xaxis
            - second label is yaxis, only for heatmap
        filename: if provided, plot will be saved in this file
            else, plot will be displayed
        normalize_on_x: whether to normalize return-to-go over x-axis
        argmax_on_x: whether to highlight argmax over x-axis
        csv_prefix: if provided, the pivoted tables are also written as CSV
            files named ``{csv_prefix}_rtg.csv`` and
            ``{csv_prefix}_count.csv``
        csv_buffers: optional mapping whose ``"rtg"`` and/or ``"count"``
            entries are binary buffers that receive the pivoted tables as
            CSV; only forwarded to the heatmap, ignored for the histogram

    Returns:
        A tuple ``(image, rtg_buffer, count_buffer)``: the PNG image stream of
        the plot followed by the two CSV buffers returned by
        :func:`plot_heatmap`. Both buffers are always ``None`` for the
        histogram.

    Raises:
        NotImplementedError: if the number of labels is neither 1 nor 2.

    Examples:
        Initialize OAR dataframe

        >>> import pandas as pd
        >>> from itertools import product
        >>> from sara.oar import OARSchema, enrich_rtg
        >>> oa0s = list(product([1, 2, 3], [1, 2, 3]))
        >>> obs0s, act0s = zip(*oa0s, strict=True)
        >>> rew1s = (float(o==a) for o, a in oa0s)
        >>> df = pd.DataFrame({
        ...     ("obs0", "context"): obs0s,
        ...     ("act0", "choice"): act0s,
        ...     ("rew1", "money"): rew1s})
        >>> df.columns.names = ("signal", "key")
        >>> df = df.set_index(
        ...     pd.MultiIndex.from_product([[0], pd.date_range("2000-01-01", periods=len(obs0s))],
        ...     names=["episode", "date"]))
        >>> df = OARSchema.validate(df)
        >>> df = enrich_rtg(df, .5)

        Plot histogram visualization

        >>> from sara.viz import plot_insight
        >>> from pathlib import Path
        >>> filename_histogram = Path(
        ...     'docs/source/viz/histogram.png')  # to update plot in doc
        >>> filename_heatmap = Path(
        ...     'docs/source/viz/heatmap.png')  # to update plot in doc
        >>> _,_,_ = plot_insight(
        ...     df, col_labels=[("act0", "choice")],
        ...     filename=filename_histogram)  # histogram

        .. image:: histogram.png

        >>> _,_,_=plot_insight(
        ...     df, col_labels=[("act0", "choice"), ("obs0", "context")],
        ...     filename=filename_heatmap)  # heatmap

        .. image:: heatmap.png
    """
    # group against labels and pivot dataframe
    table_rtg, table_count = df2tables(df, col_labels)

    # branch plots
    if len(col_labels) == MAX_LABELS:
        return plot_heatmap(
            table_rtg,
            table_count,
            filename,
            normalize_on_x=normalize_on_x,
            argmax_on_x=argmax_on_x,
            csv_prefix=csv_prefix,
            csv_buffers=csv_buffers,
        )
    elif len(col_labels) == 1:
        img = plot_histogram(
            table_rtg,
            table_count,
            filename,
            normalize_on_x=normalize_on_x,
            argmax_on_x=argmax_on_x,
            csv_prefix=csv_prefix,
        )
        # The histogram does not fill CSV buffers, hence the two None.
        return img, None, None
    else:
        raise NotImplementedError