# Source code for sara.oar.utils
"""Minor tools for OAR schema and OAR dataframe manipulation."""
import pandas as pd
import pandera.pandas as pa
from pandera.typing.pandas import DataFrame
from sara.oar import RTGSchema
from sara.oar.schemas import OARSchema
[docs]
def get_keys_dataframe(
    schema: pa.DataFrameSchema,
    *,
    exclude_signals: tuple[str, ...] = ("rew1",),
) -> pd.DataFrame:
    """Transform the columns of a schema into a dataframe ready for printing.

    The schema columns are expected to be labelled by ``(signal, key)``
    pairs and to carry a ``"description"`` entry in their metadata.

    Args:
        schema: a pandera dataframe schema, e.g. an OARSchema.
        exclude_signals: signals whose columns are left out of the result.
            Defaults to ``("rew1",)``, i.e. the ``rew1`` columns are hidden.

    Returns:
        pd.DataFrame: a dataframe with key as index
        and its signal, description as columns

    Raises:
        ValueError: if two of the remaining columns share the same key
            (the index is built with ``verify_integrity=True``).

    Examples:
        >>> import numpy as np
        >>> import pandas as pd
        >>> import pandera.pandas as pa
        >>> schema = pa.DataFrameSchema(columns={
        ...     ('act0', 'choice'): pa.Column(int, metadata={
        ...         'description': 'You take the blue pill the story ends, '
        ...         'you wake up in your bed and believe whatever you '
        ...         'want to believe. You take the red pill you stay '
        ...         'in Wonderland, and I show you how deep the rabbit '
        ...         'hole goes.'}),
        ...     ('rew1', 'reality'): pa.Column(float, metadata={
        ...         'description': 'What is real? How do you define real? If you re '
        ...         'talking about what you can feel, what you can '
        ...         'smell, what you can taste and see, then real is '
        ...         'simply electrical signals interpreted by your '
        ...         'brain.'})},
        ...     index={'date': pa.Index(np.datetime64)})
        >>> from sara.oar import get_keys_dataframe
        >>> with pd.option_context(  # doctest: +NORMALIZE_WHITESPACE
        ...         "display.max_colwidth", None):
        ...     print(get_keys_dataframe(schema))
        signal description
        key
        choice act0 You take the blue pill the story ends, you wake up in your
        bed and believe whatever you want to believe. You take the red
        pill you stay in Wonderland, and I show you how deep the rabbit
        hole goes.
    """
    excluded = frozenset(exclude_signals)
    cols = schema.columns
    # Column labels are (signal, key) tuples: label[0] is the signal,
    # label[1] the key.
    sig_key = [c for c in cols if c[0] not in excluded]
    return pd.DataFrame(
        {
            "key": [c[1] for c in sig_key],
            "signal": [c[0] for c in sig_key],
            "description": [cols[c].metadata["description"] for c in sig_key],
        },
    ).set_index("key", verify_integrity=True)
[docs]
def filter_with_query(
    df: DataFrame[RTGSchema],
    query: str,
) -> DataFrame[RTGSchema]:
    """Filter an RTG dataframe with :meth:`pandas.DataFrame.query`.

    For convenience, the query refers to columns by their ``key`` level only
    instead of by ``(signal, key)`` pairs.

    Args:
        df: the input dataframe is a rtg dataframe because filtering is applied
            after rtg calculations. It is not modified.
        query: the query expression, cf :meth:`pandas.DataFrame.query`.

    Returns:
        DataFrame[RTGSchema]: the filtered dataframe, with the original
        column labels, validated against ``RTGSchema``.
    """
    # Relabel a new object instead of assigning to ``df.columns``: the
    # caller's dataframe must keep its (signal, key) labels, even if the
    # query raises.
    flat = df.set_axis(df.columns.get_level_values("key"), axis=1)
    df_out = flat.query(query)
    # Put the full column labels back on the filtered result.
    df_out.columns = df.columns
    return RTGSchema.validate(df_out)
[docs]
def bin_with_quantiles(
    df: DataFrame[RTGSchema],
    num_quantiles: dict[tuple[str, str], int],
) -> DataFrame[RTGSchema]:
    """Replace values of an RTG dataframe by their quantile bin (pd.qcut).

    Args:
        df: the input dataframe; an RTG dataframe because binning happens
            after the RTG calculations. It is copied, not modified.
        num_quantiles: number of quantiles per ``(signal, key)`` column.
            Columns mapped to ``None``, or not listed, keep their values.

    Returns:
        DataFrame[RTGSchema]: a copy of ``df`` where each selected column holds
        bin names instead of values, validated against ``RTGSchema``.
    """
    binned = df.copy()
    for column, quantiles in num_quantiles.items():
        if quantiles is None:
            continue
        # Bins are computed from the untouched input column; duplicate bin
        # edges are dropped rather than raising, and labels use 0 decimals.
        binned[column] = pd.qcut(
            df[column],
            quantiles,
            duplicates="drop",
            precision=0,
        )
    return RTGSchema.validate(binned)
[docs]
def discount_from_horizon(
    horizon: float,
    tolerance: float = 0.05,
) -> float:
    r"""Select the discount in a more intuitive way.

    Which is the discount such that its geometric series becomes negligible
    beyond the horizon. Specifically,

    .. math::

        \sum_i d^{H+i} = \alpha\sum_i d^i

    where :math:`H` is the horizon and :math:`\alpha` is the tolerance.
    Factoring :math:`d^H` out of the left-hand side gives :math:`d^H = \alpha`,
    hence :math:`d = \alpha^{1/H}`.

    Args:
        horizon: a positive scalar setting the time window of interest:

            - a small time window will miss longterm influence of the action,
            - a large time window will dilute influence of the action.
        tolerance: a scalar between 0 and 1 (inclusive) setting negligibility.

    Returns:
        float: the discount, in ``[0, 1]``.

    Raises:
        ValueError: if ``horizon`` is not positive, or ``tolerance`` is not
            in ``[0, 1]`` (a negative tolerance would otherwise yield a
            complex number, and one above 1 a discount above 1).

    Examples:
        >>> from sara.oar import discount_from_horizon
        >>> discount_from_horizon(30)
        0.9049661471446958
        >>> discount_from_horizon(0)
        Traceback (most recent call last):
            ...
        ValueError: horizon must be positive, got 0
    """
    # Written as ``not (x > 0)`` so that NaN is rejected too.
    if not horizon > 0:
        raise ValueError(f"horizon must be positive, got {horizon!r}")
    if not 0 <= tolerance <= 1:
        raise ValueError(f"tolerance must lie in [0, 1], got {tolerance!r}")
    return tolerance ** (1.0 / horizon)