Source code for sara.oar.utils

"""Minor tools for OAR schema and OAR dataframe manipulation."""

import pandas as pd
import pandera.pandas as pa
from pandera.typing.pandas import DataFrame

from sara.oar import RTGSchema
from sara.oar.schemas import OARSchema


def get_keys_dataframe(schema: pa.DataFrameSchema) -> pd.DataFrame:
    """Summarise the columns of a schema as a dataframe ready for printing.

    Every schema column is addressed by a name whose first item is the
    signal and whose second item is the key, and must carry a
    ``"description"`` entry in its metadata. Columns whose signal is
    ``"rew1"`` are left out of the summary.

    Args:
        schema: an OARSchema

    Returns:
        pd.DataFrame: one row per retained column, indexed by ``key``, with
        ``signal`` and ``description`` as columns

    Raises:
        ValueError: if two retained columns share the same key, because the
            index is built with ``verify_integrity=True``.

    Examples:
        >>> import numpy as np
        >>> import pandas as pd
        >>> import pandera.pandas as pa
        >>> schema = pa.DataFrameSchema(columns={
        ...     ('act0', 'choice'): pa.Column(int, metadata={
        ...         'description': 'You take the blue pill the story ends, '
        ...                        'you wake up in your bed and believe whatever you '
        ...                        'want to believe. You take the red pill you stay '
        ...                        'in Wonderland, and I show you how deep the rabbit '
        ...                        'hole goes.'}),
        ...     ('rew1', 'reality'): pa.Column(float, metadata={
        ...         'description': 'What is real? How do you define real? If you re '
        ...                        'talking about what you can feel, what you can '
        ...                        'smell, what you can taste and see, then real is '
        ...                        'simply electrical signals interpreted by your '
        ...                        'brain.'})},
        ...     index={'date': pa.Index(np.datetime64)})
        >>> from sara.oar import get_keys_dataframe
        >>> with pd.option_context(  # doctest: +NORMALIZE_WHITESPACE
        ...         "display.max_colwidth", None):
        ...     print(get_keys_dataframe(schema))
        signal description
        key
        choice act0 You take the blue pill the story ends, you wake up in
        your bed and believe whatever you want to believe. You take the red
        pill you stay in Wonderland, and I show you how deep the rabbit hole
        goes.
    """
    columns = schema.columns

    # Keep every column name except those belonging to the "rew1" signal.
    retained = [name for name in columns if name[0] != "rew1"]

    keys = [name[1] for name in retained]
    signals = [name[0] for name in retained]
    descriptions = [columns[name].metadata["description"] for name in retained]

    summary = pd.DataFrame(
        {
            "key": keys,
            "signal": signals,
            "description": descriptions,
        },
    )
    # verify_integrity makes duplicated keys fail loudly instead of silently
    # producing an ambiguous index.
    return summary.set_index("key", verify_integrity=True)
def filter_with_query(
    df: DataFrame[RTGSchema],
    query: str,
) -> DataFrame[RTGSchema]:
    """Filter RTG dataframe with DataFrame.query.

    Use only keys instead of (signal, key) pairs in the query for
    convenience. The input dataframe is not modified: the key-only column
    labels are applied to a shallow copy, and the original column labels are
    put back on the filtered result.

    Args:
        df: the input dataframe is a rtg dataframe because filtering is
            applied after rtg calculations
        query: the query cf :func:`pd.DataFrame.query`

    Returns:
        DataFrame[RTGSchema]: filtered dataframe
    """
    old_columns = df.columns

    # Relabel a shallow copy rather than ``df`` itself: assigning to
    # ``df.columns`` would leave the caller's dataframe with key-only columns
    # after this function returns.
    flat = df.copy(deep=False)
    flat.columns = old_columns.get_level_values("key")

    df_out = flat.query(query)
    # Filtering only drops rows, so the original column labels still line up.
    df_out.columns = old_columns
    return RTGSchema.validate(df_out)
def bin_with_quantiles(
    df: DataFrame[RTGSchema],
    num_quantiles: dict[tuple[str, str], int],
) -> DataFrame[RTGSchema]:
    """Replace values of selected RTG columns by their quantile bin.

    Binning is delegated to ``pd.qcut`` and is computed from the values of
    the input dataframe; the input itself is copied, not modified. Columns
    mapped to ``None`` in ``num_quantiles`` are left as they are.

    Args:
        df: the input dataframe is a rtg dataframe because binning is
            applied after rtg calculations
        num_quantiles: number of quantiles for each column

    Returns:
        DataFrame[RTGSchema]: a new dataframe with bin names instead of
        values
    """
    binned = df.copy()
    for column, n_bins in num_quantiles.items():
        if n_bins is None:
            # No bin count requested for this column: keep the raw values.
            continue
        binned[column] = pd.qcut(
            df[column],
            n_bins,
            duplicates="drop",
            precision=0,
        )
    return RTGSchema.validate(binned)
def discount_from_horizon(
    horizon: float,
    tolerance: float = 0.05,
) -> float:
    r"""Derive a discount from a more intuitive time horizon.

    The returned discount is the one whose geometric series becomes
    negligible beyond the horizon. Specifically,

    .. math::
        \sum_i d^{H+i} = \alpha\sum_i d^i

    where :math:`H` is the horizon and :math:`\alpha` is the tolerance, which
    gives :math:`d = \alpha^{1/H}`.

    Args:
        horizon: a positive scalar setting the time window of interest:

            - a small time window will miss longterm influence of the action,
            - a large time window will dilute influence of the action.
        tolerance: a scalar between 0 and 1 setting negligibility.

    Returns:
        float: the discount

    Examples:
        >>> from sara.oar import discount_from_horizon
        >>> discount_from_horizon(30)
        0.9049661471446958
    """
    # d**H == tolerance  <=>  d == tolerance**(1/H)
    exponent = 1.0 / horizon
    return pow(tolerance, exponent)