"""Plotting functions for SARA-VIZ."""
from io import BytesIO, StringIO
from pathlib import Path
from typing import Final, BinaryIO
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pandera.pandas as pa
from matplotlib.patches import Rectangle
from pandera.typing import DataFrame
from sara.oar import RTGSchema
# Number of grouping labels that selects the 2-D heatmap view; a single label
# selects the histogram view. Any other label count is not supported.
MAX_LABELS: Final[int] = 2
def cleaner_for_label(x: tuple[str, str]) -> str:
    """Join the first two parts of a column key into a readable plot label.

    Only ``x[0]`` and ``x[1]`` are used; any further elements are ignored.

    Example:
        >>> cleaner_for_label(('rew0', 'money'))
        'rew0: money'
    """
    signal, key = x[0], x[1]
    return "{}: {}".format(signal, key)
def df2tables(
    df: pd.DataFrame,
    col_labels: list,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Aggregate return-to-go per label group and pivot it for plotting.

    Args:
        df: environment DataFrame whose columns are tuple keys; the first
            column whose first element is ``"rtg0"`` is aggregated.
        col_labels: one or two column keys to group by. With two labels the
            first one becomes the pivot columns (x-axis) and the second one
            the pivot index (y-axis). With one label it becomes the index.

    Returns:
        ``(table_rtg, table_count)``: the mean RTG per group and the number of
        non-null RTG samples per group, pivoted identically. Label
        combinations absent from ``df`` are NaN in ``table_rtg`` and 0 in
        ``table_count``. ``table_rtg.name`` holds the flattened RTG label.

    Raises:
        NotImplementedError: when more than two labels are given.
    """
    rtg_key = next(col for col in df.columns if col[0] == "rtg0")
    rtg_name = cleaner_for_label(rtg_key)
    label_names = list(map(cleaner_for_label, col_labels))

    grouped = df.reset_index().groupby(col_labels, observed=True, dropna=False)

    def _reduce(how: str, value_name: str) -> pd.DataFrame:
        # One row per group: the flattened group keys, then the reduced RTG.
        per_group = grouped.apply(
            lambda frame: getattr(frame[rtg_key], how)(), include_groups=False
        )
        flat = per_group.to_frame(value_name).reset_index()
        flat.columns = [*label_names, value_name]
        return flat

    df_rtg = _reduce("mean", rtg_name)
    df_count = _reduce("count", "count")

    # Two labels: first -> pivot columns (x-axis), second -> index (y-axis).
    # One label: it is the index and there are no pivot columns.
    n_labels = len(label_names)
    if n_labels == MAX_LABELS:
        index, columns = [label_names[1]], [label_names[0]]
    elif n_labels == 1:
        index, columns = [label_names[0]], None
    else:
        raise NotImplementedError

    pivot_kwargs = {
        "index": index,
        "columns": columns,
        "observed": True,
        "dropna": False,
    }
    table_rtg = df_rtg.pivot_table(values=rtg_name, **pivot_kwargs)
    table_rtg.name = rtg_name
    # Pivoting introduces NaN for (x, y) pairs missing from the dataset:
    # their count becomes 0 while their RTG is left NaN.
    table_count = df_count.pivot_table(values="count", **pivot_kwargs).fillna(0)
    return table_rtg, table_count
def _finalize_plot(fig, save_path: str | Path | None = None) -> BytesIO:
    """Render ``fig`` to an in-memory PNG, then save or show it, and close it.

    Args:
        fig: the matplotlib figure to finalize.
        save_path: if truthy, the PNG is also written to this destination;
            otherwise the figure is displayed with ``plt.show()``.

    Returns:
        A ``BytesIO`` holding the PNG bytes, rewound to position 0.
    """
    try:
        img_stream = BytesIO()
        fig.savefig(img_stream, format="png", bbox_inches="tight")
        img_stream.seek(0)
        if save_path:
            fig.savefig(save_path, format="png", bbox_inches="tight")
        else:
            plt.show()
    finally:
        # Release the figure even when rendering/saving fails, otherwise
        # pyplot keeps it alive in its figure registry.
        plt.close(fig)
    return img_stream
def plot_heatmap(
    table_rtg: pd.DataFrame,
    table_count: pd.DataFrame,
    filename: str | Path | None = None,
    *,
    normalize_on_x: bool = False,
    argmax_on_x: bool = False,
    csv_prefix: str | Path | None = None,
    csv_buffers: dict | None = None,
) -> tuple[BytesIO, BytesIO | None, BytesIO | None]:
    """Plot the return-to-go heatmap used when two labels are given.

    Each tile is colored by the mean RTG of its (x, y) group and annotated
    with the number of samples behind that mean.

    Args:
        table_rtg: the pivoted table of RTGs from :func:`df2tables`; its
            columns give the x-axis and its index the y-axis.
        table_count: the pivoted table of counts from :func:`df2tables`,
            with the same shape as ``table_rtg``.
        filename: if provided, plot will be saved in this file, otherwise it
            is displayed.
        normalize_on_x: whether to min-max normalize RTG values along the
            x-axis (i.e. within each row) to improve comparability of RTG
            inside this axis. Rows whose values are all equal are drawn at
            0.5; missing tiles (NaN) are left as NaN.
        argmax_on_x: whether to display a red box around the tile that
            realizes the max of RTG along the x-axis, for every row that has
            at least one non-NaN value.
        csv_prefix: if provided, the tables are written to
            ``f"{csv_prefix}_rtg.csv"`` and ``f"{csv_prefix}_count.csv"``.
        csv_buffers: optional mapping with ``"rtg"`` and/or ``"count"`` keys
            pointing to writable binary streams; the matching table is
            written to each stream as CSV bytes and the stream is rewound.

    Returns:
        ``(image, rtg_buffer, count_buffer)`` where ``image`` is a PNG
        ``BytesIO`` rewound to 0 and each buffer is the stream taken from
        ``csv_buffers`` (``None`` when not provided).
    """
    if csv_prefix is not None:
        table_rtg.to_csv(f"{csv_prefix}_rtg.csv")
        table_count.to_csv(f"{csv_prefix}_count.csv")

    buffers_out: dict[str, BytesIO | None] = {"rtg": None, "count": None}
    if csv_buffers is not None:
        for key, table in (("rtg", table_rtg), ("count", table_count)):
            target = csv_buffers.get(key)
            if target is None:
                continue
            # to_csv() without a path returns the CSV text.
            target.write(table.to_csv().encode())
            target.seek(0)
            buffers_out[key] = target

    fig, ax = plt.subplots(figsize=(8, 6))
    data = table_rtg.to_numpy()

    # heatmap
    data_to_plot = data
    if normalize_on_x:
        row_max = np.nanmax(data, axis=1, keepdims=True)
        row_min = np.nanmin(data, axis=1, keepdims=True)
        row_range = row_max - row_min
        with np.errstate(divide="ignore", invalid="ignore"):
            scaled = (data - row_min) / row_range
            data_to_plot = np.where(row_range > 0, scaled, 0.5)
        # Keep missing tiles NaN so they are not painted as if they held data.
        data_to_plot = np.where(np.isnan(data), np.nan, data_to_plot)
    image = ax.imshow(
        data_to_plot, cmap="viridis", interpolation="None", aspect="auto"
    )

    # Annotate each tile with its sample count; imshow centers the tile at
    # row iy / column ix on the data point (x=ix, y=iy).
    for (iy, ix), count in np.ndenumerate(table_count.to_numpy()):
        ax.text(
            ix, iy, f"{count:.0f}", ha="center", va="center",
            color="black", fontsize=12,
        )

    if argmax_on_x:
        # Rows without any value have no argmax: skip them instead of letting
        # nanargmax raise.
        has_values = ~np.isnan(data).all(axis=1)
        for i in np.flatnonzero(has_values):
            j = np.nanargmax(data[i])
            # Tile (i, j) spans [j - 0.5, j + 0.5] x [i - 0.5, i + 0.5] in
            # data coordinates; (j - 0.5, i - 0.5) is its lowest-coordinate
            # corner (drawn top-left because imshow inverts the y-axis).
            ax.add_patch(
                Rectangle(
                    (j - 0.5, i - 0.5), 1, 1,
                    edgecolor="red", facecolor="none", lw=2,
                )
            )

    # Ticks and labels come from the pivoted table_rtg.
    x_labels = table_rtg.columns
    y_labels = table_rtg.index
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_xticklabels(x_labels, rotation=-30)
    ax.set_yticklabels(y_labels)
    fig.colorbar(image, label=getattr(table_rtg, "name", None))
    ax.set_xlabel(x_labels.name)
    ax.set_ylabel(y_labels.name)
    fig.tight_layout()

    img_stream = _finalize_plot(fig, filename)
    return img_stream, buffers_out["rtg"], buffers_out["count"]
def plot_histogram(
    table_rtg: pd.DataFrame,
    table_count: pd.DataFrame,
    filename: str | Path | None = None,
    *,
    normalize_on_x: bool = False,
    argmax_on_x: bool = False,
    csv_prefix: str | Path | None = None,
) -> BytesIO:
    """Plot the return-to-go bar chart used when a single label is given.

    Each bar's height is the mean RTG of one label value, and the bar is
    annotated with the number of samples behind that mean.

    Args:
        table_rtg: the pivoted table of RTGs from :func:`df2tables`
            (one row per label value, a single RTG column).
        table_count: the pivoted table of counts from :func:`df2tables`,
            with a ``"count"`` column aligned with ``table_rtg``.
        filename: if provided, plot will be saved in this file, otherwise it
            is displayed.
        normalize_on_x: whether to min-max normalize bar heights to [0, 1];
            if all values are equal every non-missing bar is drawn at 0.5.
        argmax_on_x: whether to color the highest bar red.
        csv_prefix: if provided, the tables are written to
            ``f"{csv_prefix}_rtg.csv"`` and ``f"{csv_prefix}_count.csv"``.

    Returns:
        The PNG image as a ``BytesIO`` rewound to position 0.
    """
    if csv_prefix is not None:
        table_rtg.to_csv(f"{csv_prefix}_rtg.csv")
        table_count.to_csv(f"{csv_prefix}_count.csv")

    fig, ax = plt.subplots(figsize=(8, 6))
    data = table_rtg.to_numpy().flatten()
    missing = np.isnan(data)
    # nanmax/nanargmax raise on all-NaN (or empty) input, so only compute
    # them when at least one value exists and they are actually requested.
    has_values = not missing.all()
    n_bars = len(table_rtg)

    data_to_plot = data
    if normalize_on_x and has_values:
        hi = np.nanmax(data)
        lo = np.nanmin(data)
        span = hi - lo
        if span > 0:
            data_to_plot = (data - lo) / span
        else:
            # Constant data: mid value, but keep missing bars missing.
            data_to_plot = np.where(missing, np.nan, 0.5)

    colors = None
    if argmax_on_x and has_values:
        colors = ["C0"] * n_bars
        colors[int(np.nanargmax(data))] = "r"

    positions = range(n_bars)
    bars = ax.bar(positions, data_to_plot, color=colors)
    ax.set_xticks(positions)
    ax.set_xticklabels(table_rtg.index)

    # Write each bar's sample count at mid-height of the bar.
    for bar, count in zip(bars, table_count["count"], strict=True):
        x_pos = bar.get_x() + bar.get_width() / 2
        y_pos = bar.get_height() / 2
        ax.text(x_pos, y_pos, f"{count:.0f}", ha="center", va="center", fontsize=10)

    ax.set_xlabel(table_rtg.index.name)
    ax.set_ylabel(table_rtg.columns[0])
    fig.tight_layout()
    return _finalize_plot(fig, filename)
@pa.check_types
def plot_insight(
    df: DataFrame[RTGSchema],
    col_labels: list,  # TODO(anthony): use a cleaner type
    filename: str | Path | None = None,
    *,
    normalize_on_x: bool = False,
    argmax_on_x: bool = False,
    csv_prefix: str | Path | None = None,
    csv_buffers: dict[str, BinaryIO] | None = None,
) -> tuple[BytesIO, BytesIO | None, BytesIO | None]:
    """Plot heatmap or histogram from OAR dataframe with return-to-go.

    A heatmap is drawn when 2 labels are given and a histogram when 1 label
    is given. The heatmap color or histogram height is the return-to-go
    averaged over the groups defined by the labels. The number of samples
    used in each average is displayed in the tiles of the heatmap and in the
    bars of the histogram.

    Args:
        df: the dataframe to plot; it must validate :class:`RTGSchema`.
        col_labels: the labels used to group the dataframe before averaging,
            - first label is xaxis
            - second label is yaxis, only for heatmap
        filename: if provided, plot will be saved in this file
            else, plot will be displayed
        normalize_on_x: whether to normalize return-to-go over x-axis
        argmax_on_x: whether to highlight argmax over x-axis
        csv_prefix: if provided, the aggregated tables are also written to
            ``f"{csv_prefix}_rtg.csv"`` and ``f"{csv_prefix}_count.csv"``.
        csv_buffers: optional mapping with ``"rtg"`` and/or ``"count"`` keys
            pointing to writable binary streams; the matching aggregated
            table is written to each stream as CSV bytes and the stream is
            rewound to position 0.

    Returns:
        ``(image, rtg_buffer, count_buffer)``: the plot as a PNG ``BytesIO``
        rewound to 0, and the ``"rtg"`` / ``"count"`` streams from
        ``csv_buffers`` after they were filled (``None`` when not provided).

    Raises:
        NotImplementedError: if more than two labels are given.

    Examples:
        Initialize OAR dataframe

        >>> import pandas as pd
        >>> from itertools import product
        >>> from sara.oar import OARSchema, enrich_rtg
        >>> oa0s = list(product([1, 2, 3], [1, 2, 3]))
        >>> obs0s, act0s = zip(*oa0s, strict=True)
        >>> rew1s = (float(o==a) for o, a in oa0s)
        >>> df = pd.DataFrame({
        ...     ("obs0", "context"): obs0s,
        ...     ("act0", "choice"): act0s,
        ...     ("rew1", "money"): rew1s})
        >>> df.columns.names = ("signal", "key")
        >>> df = df.set_index(
        ...     pd.MultiIndex.from_product([[0], pd.date_range("2000-01-01", periods=len(obs0s))],
        ...     names=["episode", "date"]))
        >>> df = OARSchema.validate(df)
        >>> df = enrich_rtg(df, .5)

        Plot histogram visualization

        >>> from sara.viz import plot_insight
        >>> from pathlib import Path
        >>> filename_histogram = Path(
        ...     'docs/source/viz/histogram.png')  # to update plot in doc
        >>> filename_heatmap = Path(
        ...     'docs/source/viz/heatmap.png')  # to update plot in doc
        >>> _,_,_ = plot_insight(
        ...     df, col_labels=[("act0", "choice")],
        ...     filename=filename_histogram)  # histogram

        .. image:: histogram.png

        >>> _,_,_=plot_insight(
        ...     df, col_labels=[("act0", "choice"), ("obs0", "context")],
        ...     filename=filename_heatmap)  # heatmap

        .. image:: heatmap.png
    """
    # group against labels and pivot dataframe
    table_rtg, table_count = df2tables(df, col_labels)

    # branch plots
    if len(col_labels) == MAX_LABELS:
        return plot_heatmap(
            table_rtg,
            table_count,
            filename,
            normalize_on_x=normalize_on_x,
            argmax_on_x=argmax_on_x,
            csv_prefix=csv_prefix,
            csv_buffers=csv_buffers,
        )
    if len(col_labels) == 1:
        img = plot_histogram(
            table_rtg,
            table_count,
            filename,
            normalize_on_x=normalize_on_x,
            argmax_on_x=argmax_on_x,
            csv_prefix=csv_prefix,
        )
        # plot_histogram has no csv_buffers support: fill the buffers here so
        # the argument is honored for both views instead of silently ignored.
        written: dict[str, BinaryIO | None] = {"rtg": None, "count": None}
        if csv_buffers is not None:
            for key, table in (("rtg", table_rtg), ("count", table_count)):
                target = csv_buffers.get(key)
                if target is not None:
                    target.write(table.to_csv().encode())
                    target.seek(0)
                    written[key] = target
        return img, written["rtg"], written["count"]
    raise NotImplementedError