sara.oar
The OAR format structures business data for quantitative decision-making with reinforcement learning (RL).
This module contains functions to create, validate and transform OAR pandas dataframes.
- class sara.oar.OARSchema(*args, **kwargs)
The base OAR schema for pandas DataFrames.
- Raises:
SchemaError – if the input dataframe does not validate against the schema
Examples
Initialize a minimal OAR dataframe:
>>> import pandas as pd
>>> from sara.oar import OARSchema
>>> df = pd.DataFrame({
...     ("act0", "choice"): [1, 2, 3],
...     ("rew1", "money"): [1., 2., 3.]})
>>> df.columns.names = ("signal", "key")
>>> df = df.set_index(
...     pd.MultiIndex.from_product(
...         [[0], pd.date_range("2000-01-01", "2000-01-03")],
...         names=["episode", "date"]))
Validate it against OARSchema:
>>> OARSchema.validate(df)
signal               act0  rew1
key                choice money
episode date
0       2000-01-01      1   1.0
        2000-01-02      2   2.0
        2000-01-03      3   3.0
Note
- Episodes ending with term1 == False are considered truncated.
- If no term1 column is present, episodes are considered truncated.
- The date index level should be an integer or a datetime.
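The truncation rules in the note above can be sketched as a small helper that classifies each episode. The name truncated_episodes is hypothetical and not part of sara.oar; the sketch assumes (signal, key) column tuples, an index with an "episode" level, and rows sorted by date within each episode.

```python
import pandas as pd


def truncated_episodes(df: pd.DataFrame) -> pd.Series:
    """Hypothetical helper: True for truncated episodes, False for terminated ones."""
    episodes = df.index.get_level_values("episode")
    term_columns = [column for column in df.columns if column[0] == "term1"]
    if not term_columns:
        # No term1 column at all: every episode is considered truncated.
        return pd.Series(True, index=pd.Index(episodes.unique(), name="episode"))
    # Assumes rows are sorted by date, so .last() picks the final step of each episode.
    last_term = df[term_columns[0]].groupby(level="episode").last()
    # An episode whose last term1 value is False is truncated.
    return ~last_term.astype(bool)
```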
Todo
It may be more general to include an explicit truncation signal: it would allow slicing episodes however we want and eventually batching them.
- sara.oar.bin_with_quantiles(df: DataFrame[RTGSchema], num_quantiles: dict[tuple[str, str], int]) → DataFrame[RTGSchema]
Bin the values of an RTG dataframe into quantiles with pd.qcut.
- Parameters:
df – an RTG dataframe, since binning is applied after the return-to-go is computed
num_quantiles – number of quantiles for each column to bin, keyed by (signal, key) column tuple
- Returns:
a new dataframe with bin names instead of values
- Return type:
DataFrame[RTGSchema]
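A minimal sketch of per-column quantile binning with pd.qcut, assuming num_quantiles maps (signal, key) column tuples to a number of bins as in the signature above. The helper name is hypothetical, and the exact bin labels used by bin_with_quantiles are not documented here, so this sketch keeps the pd.Interval labels that pd.qcut returns by default.

```python
import pandas as pd


def bin_columns_with_quantiles(df: pd.DataFrame, num_quantiles: dict) -> pd.DataFrame:
    """Hypothetical sketch: replace selected columns by their quantile bins."""
    binned = df.copy()
    for column, q in num_quantiles.items():
        # pd.qcut splits the values into q equal-frequency bins and returns
        # a categorical whose categories are pd.Interval objects.
        binned[column] = pd.qcut(df[column], q=q)
    return binned
```

Columns not listed in num_quantiles are left unchanged. Note that pd.qcut raises when quantile edges coincide (for example, many repeated values) unless it is called with duplicates="drop".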
- sara.oar.discount_from_horizon(horizon: float, tolerance: float = 0.05) → float
Select the discount more intuitively, from a time horizon.
The returned discount is the one whose geometric series becomes negligible beyond the horizon.
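Specifically, the discount \(d\) solves
\[\sum_{i=0}^{\infty} d^{H+i} = \alpha\sum_{i=0}^{\infty} d^i\]
where \(H\) is the horizon and \(\alpha\) is the tolerance. Since the left-hand side equals \(d^H\sum_{i=0}^{\infty} d^i\), this reduces to \(d^H = \alpha\), i.e. \(d = \alpha^{1/H}\).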
- Parameters:
horizon – a positive scalar setting the time window of interest:
- a small time window will miss the long-term influence of the action,
- a large time window will dilute the influence of the action.
tolerance – a scalar between 0 and 1 setting the negligibility threshold \(\alpha\).
- Returns:
the discount
- Return type:
float
Examples
>>> from sara.oar import discount_from_horizon
>>> discount_from_horizon(30)
0.9049661471446958
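Factoring \(d^H\) out of the left-hand sum of the defining equation gives \(d^H = \alpha\), so the discount has the closed form \(d = \alpha^{1/H}\). The sketch below assumes that closed form; it reproduces the documented value for a 30-step horizon, but it is a hypothetical stand-in (discount_from_horizon_sketch), not the sara source.

```python
def discount_from_horizon_sketch(horizon: float, tolerance: float = 0.05) -> float:
    """Hypothetical sketch: solve d**horizon == tolerance for the discount d."""
    if horizon <= 0:
        raise ValueError("horizon must be positive")
    if not 0 < tolerance < 1:
        raise ValueError("tolerance must be between 0 and 1")
    return tolerance ** (1.0 / horizon)


d = discount_from_horizon_sketch(30)  # ≈ 0.904966, matching the example above
# The geometric tail beyond the horizon is a fraction d**30 (== tolerance)
# of the whole series.
```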
- sara.oar.enrich_rtg(df: DataFrame[OARSchema], discount: float, value_estimation_mode: Literal['zero', 'mean', 'scale'] = 'scale', rtg_key: str | None = None, mix_guardrail: bool = False) → DataFrame[RTGSchema]
Enrich an OAR dataframe with the return-to-go.
Return-to-go from date \(t\) is:
\[R_t = r_{t+1} + \rho r_{t+2} + \cdots + \rho^{T-t-1}r_T\]
where \(\rho\) is the discount, \(r_{t+i}\) is the reward at date \(t+i\), and date \(T\) is the last date of the terminated episode.
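For a terminated episode, the sum above satisfies the backward recursion \(R_t = r_{t+1} + \rho R_{t+1}\), with zero return after the end of the episode. A minimal single-episode sketch, assuming (as in the enrich_rtg example below) that the rew1 value stored on a row is the reward \(r_{t+1}\) that follows that row's action; the function return_to_go is hypothetical and not necessarily how enrich_rtg computes it.

```python
def return_to_go(rewards: list[float], discount: float) -> list[float]:
    """Hypothetical sketch: discounted return-to-go of one terminated episode.

    rewards[i] is the rew1 value on row i, i.e. r_{t+1} for that row's date t.
    """
    rtg = [0.0] * len(rewards)
    running = 0.0  # the return after the last date of the episode is zero
    for i in reversed(range(len(rewards))):
        # R_t = r_{t+1} + rho * R_{t+1}
        running = rewards[i] + discount * running
        rtg[i] = running
    return rtg


values = return_to_go([1.0, 2.0, 3.0], 0.9)
```

Up to floating-point rounding, values equals [5.23, 4.70, 3.00], the rtg0 column of the example below.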
When the episode is truncated at date \(\tilde T\), the return-to-go is estimated with some bootstrap value \(v\) as:
\[\tilde R_t = r_{t+1} + \rho r_{t+2} + \cdots + \rho^{\tilde T-t-1}r_{\tilde T} + \rho^{\tilde T-t}v\]
- Parameters:
df –
a valid OAR dataframe in which either all episodes are truncated or all episodes are terminated:
when episodes are truncated, they are considered infinite and are bootstrapped;
when episodes are terminated, the bootstrap is ignored.
discount – discount factor
value_estimation_mode –
mode for bootstrap value estimation for truncated episodes,
if “zero”: the bootstrap value is zero
if “mean”: the bootstrap value is estimated from the mean reward \(r\) over the episode as \(v = \frac{r}{1-\rho}\)
if “scale” (default): the return-to-go at \(t\) without bootstrap value is scaled by \(\frac{1}{1-\rho^{\tilde T-t}}\). This is equivalent to bootstrapping with \(v = \frac{r}{1-\rho}\), where \(r\) is the discounted mean of the rewards from \(t+1\) to \(\tilde T\)
rtg_key –
the key for the new rtg0 column;
if None, the key is “cumulative ” followed by the key of the rew1 column (e.g. “cumulative money”)
mix_guardrail – the return-to-go cannot yet be estimated when terminated and truncated episodes are mixed. If True, assert that episodes are not mixed, which is the correct behavior. Defaults to False while a solution is being sought.
- Returns:
the original dataframe with an additional “rtg0” column
- Return type:
DataFrame[RTGSchema]
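The three value_estimation_mode options can be compared on a single truncated episode. This hypothetical sketch mirrors the mode names but is not sara's implementation; it takes the “mean” bootstrap from the mean reward over the whole episode, and implements “scale” literally as a bootstrap with the discounted mean of the remaining rewards, which algebraically equals the partial return divided by \(1-\rho^{\tilde T-t}\).

```python
def bootstrapped_rtg(rewards: list[float], discount: float, mode: str = "scale") -> list[float]:
    """Hypothetical sketch: return-to-go of one truncated episode with a bootstrap."""
    if not 0 <= discount < 1:
        raise ValueError("discount must be in [0, 1)")
    episode_mean = sum(rewards) / len(rewards)
    result = []
    for i in range(len(rewards)):
        tail = rewards[i:]   # r_{t+1}, ..., r_{T~}
        steps = len(tail)    # T~ - t
        partial = sum(discount ** k * r for k, r in enumerate(tail))  # no bootstrap
        if mode == "zero":
            value = 0.0
        elif mode == "mean":
            value = episode_mean / (1 - discount)
        elif mode == "scale":
            # Discounted mean of the tail, then r / (1 - rho) as bootstrap value.
            discounted_mean = partial * (1 - discount) / (1 - discount ** steps)
            value = discounted_mean / (1 - discount)
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        result.append(partial + discount ** steps * value)
    return result
```

With a constant reward \(r\), both “mean” and “scale” give \(\frac{r}{1-\rho}\) at every date (10 for \(r = 1\), \(\rho = 0.9\)), as expected for an infinite episode with constant reward.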
Examples
Initialize an OAR dataframe:
>>> import pandas as pd
>>> df = pd.DataFrame({
...     ("act0", "choice"): [1, 2, 3],
...     ("rew1", "money"): [1., 2., 3.],
...     ("term1", "fundraising end"): [False, False, True]})
>>> df.columns.names = ("signal", "key")
>>> df = df.set_index(
...     pd.MultiIndex.from_product([[0], pd.date_range("2000-01-01", "2000-01-03")],
...                                names=["episode", "date"]))
Validate the initial dataframe:
>>> from sara.oar import OARSchema, RTGSchema, enrich_rtg
>>> OARSchema.validate(df)
signal               act0  rew1           term1
key                choice money fundraising end
episode date
0       2000-01-01      1   1.0           False
        2000-01-02      2   2.0           False
        2000-01-03      3   3.0            True
Enrich the dataframe with the return-to-go; the result then validates against RTGSchema:
>>> df = enrich_rtg(df, 0.9)
>>> RTGSchema.validate(df)
signal               act0  rew1           term1             rtg0
key                choice money fundraising end cumulative money
episode date
0       2000-01-01      1   1.0           False             5.23
        2000-01-02      2   2.0           False             4.70
        2000-01-03      3   3.0            True             3.00
Check that the return-to-go is correct:
>>> import numpy as np
>>> rtg0 = df[("rtg0", "cumulative money")]
>>> rtg0_tgt = np.array([1. + .9*2. + .9**2*3., 2. + .9*3., 3.])
>>> np.allclose(rtg0.to_numpy(), rtg0_tgt)
True
- sara.oar.filter_with_query(df: DataFrame[RTGSchema], query: str) → DataFrame[RTGSchema]
Filter an RTG dataframe with DataFrame.query.
For convenience, the query refers to columns by their key only, instead of (signal, key) pairs.
- Parameters:
df – an RTG dataframe, since filtering is applied after the return-to-go is computed
query – the query string, see pd.DataFrame.query()
- Returns:
filtered dataframe
- Return type:
DataFrame[RTGSchema]
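A minimal sketch of key-only querying, assuming every key is unique across signals: drop the signal level, run DataFrame.query, then restore the (signal, key) columns. The helper name filter_by_keys is hypothetical. Keys containing spaces, such as “cumulative money”, must be wrapped in backticks, which DataFrame.query supports.

```python
import pandas as pd


def filter_by_keys(df: pd.DataFrame, query: str) -> pd.DataFrame:
    """Hypothetical sketch: DataFrame.query on key-only column names."""
    keys = df.columns.get_level_values("key")
    if keys.duplicated().any():
        # A key shared by two signals would make the query ambiguous.
        raise ValueError("keys must be unique across signals")
    flat = df.copy()
    flat.columns = keys  # e.g. ("rtg0", "cumulative money") -> "cumulative money"
    result = flat.query(query)
    result.columns = df.columns  # restore the (signal, key) MultiIndex
    return result
```

For instance, filter_by_keys(df, "`cumulative money` > 4 and choice < 2") keeps the rows whose return-to-go exceeds 4 and whose choice is below 2.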
- sara.oar.get_keys_dataframe(schema: DataFrameSchema) → DataFrame
Transform the columns of a schema into a dataframe ready for printing.
- Parameters:
schema – a pandera DataFrameSchema, such as an OAR schema
- Returns:
a dataframe with key as index and signal and description as columns
- Return type:
pd.DataFrame
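A sketch of this transformation, assuming only that schema.columns maps (signal, key) tuples to column objects carrying a metadata dict, as pandera's DataFrameSchema and Column do. The helper name keys_dataframe_sketch is hypothetical, and giving columns without a description an empty string is an assumption that may differ from get_keys_dataframe.

```python
import pandas as pd


def keys_dataframe_sketch(schema) -> pd.DataFrame:
    """Hypothetical sketch: one row per key with its signal and description."""
    rows = []
    for (signal, key), column in schema.columns.items():
        # Columns declared without metadata fall back to an empty description.
        metadata = getattr(column, "metadata", None) or {}
        rows.append({"key": key, "signal": signal,
                     "description": metadata.get("description", "")})
    frame = pd.DataFrame(rows, columns=["key", "signal", "description"])
    return frame.set_index("key")
```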
Examples
>>> import numpy as np
>>> import pandas as pd
>>> import pandera.pandas as pa
>>> schema = pa.DataFrameSchema(columns={
...     ('act0', 'choice'): pa.Column(int, metadata={
...         'description': 'You take the blue pill the story ends, '
...         'you wake up in your bed and believe whatever you '
...         'want to believe. You take the red pill you stay '
...         'in Wonderland, and I show you how deep the rabbit '
...         'hole goes.'}),
...     ('rew1', 'reality'): pa.Column(float, metadata={
...         'description': 'What is real? How do you define real? If you re '
...         'talking about what you can feel, what you can '
...         'smell, what you can taste and see, then real is '
...         'simply electrical signals interpreted by your '
...         'brain.'})},
...     index=pa.Index(np.datetime64, name='date'))
>>> from sara.oar import get_keys_dataframe
>>> with pd.option_context(
...         "display.max_colwidth", None):
...     print(get_keys_dataframe(schema))
       signal  description
key
choice   act0  You take the blue pill the story ends, you wake up in your bed and believe whatever you want to believe. You take the red pill you stay in Wonderland, and I show you how deep the rabbit hole goes.