sara.oar

The OAR format structures business data for quantitative decision-making with reinforcement learning (RL).

This module contains functions to create, validate and transform OAR pandas dataframes.

class sara.oar.OARSchema(*args, **kwargs)

The base OAR schema for pandas DataFrames.

Raises:

SchemaError – if the input dataframe does not validate against the schema

Examples

Initialize a minimal OAR dataframe:

>>> import pandas as pd
>>> from sara.oar import OARSchema
>>> df = pd.DataFrame({
...      ("act0", "choice"): [1, 2, 3],
...      ("rew1", "money"): [1., 2., 3.]})
>>> df.columns.names = ("signal", "key")
>>> df = df.set_index(
...     pd.MultiIndex.from_product(
...         [[0], pd.date_range("2000-01-01", "2000-01-03")],
...         names=["episode", "date"]))

Validate it against OARSchema:

>>> OARSchema.validate(df)
signal               act0  rew1
key                choice money
episode date
0       2000-01-01      1   1.0
        2000-01-02      2   2.0
        2000-01-03      3   3.0
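As a rough illustration of the layout this example relies on, the sketch below checks three things: the column level names, the index level names, and the date rule stated in the Note. `check_oar_layout` is a hypothetical helper written for this page. It is not part of sara.oar and does not replace OARSchema, which may enforce more rules.

```python
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype, is_integer_dtype


def check_oar_layout(df: pd.DataFrame) -> None:
    """Hypothetical helper: check only the layout used in the example.

    Not sara.oar.OARSchema: signal names, signal dtypes and any other
    rule of the real schema are ignored.
    """
    # Columns are (signal, key) pairs, e.g. ("act0", "choice").
    if list(df.columns.names) != ["signal", "key"]:
        raise ValueError("columns must be a MultiIndex named ('signal', 'key')")
    # Rows are indexed by episode, then by date within the episode.
    if list(df.index.names) != ["episode", "date"]:
        raise ValueError("index must be a MultiIndex named ('episode', 'date')")
    # Dates are integers or datetimes (see the Note).
    dates = df.index.get_level_values("date")
    if not (is_integer_dtype(dates) or is_datetime64_any_dtype(dates)):
        raise ValueError("date must be an integer or a datetime")


df = pd.DataFrame({("act0", "choice"): [1, 2, 3], ("rew1", "money"): [1.0, 2.0, 3.0]})
df.columns.names = ("signal", "key")
df.index = pd.MultiIndex.from_product(
    [[0], pd.date_range("2000-01-01", "2000-01-03")], names=["episode", "date"])
check_oar_layout(df)  # passes: same layout as the OARSchema example
```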

Note

  • episodes ending with term1 == False are considered truncated

  • if no term1 column is present, all episodes are considered truncated

  • the date index level should be an integer or a datetime
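The truncation rules above can be sketched as one flag per episode. `truncated_episodes` is a hypothetical helper, not part of sara.oar. It assumes at most one term1 column and rows sorted by date within each episode.

```python
import pandas as pd


def truncated_episodes(df: pd.DataFrame) -> pd.Series:
    """Hypothetical helper: True for truncated episodes, False for terminated ones.

    Assumes at most one term1 column and rows sorted by date within episodes.
    """
    term_cols = [col for col in df.columns if col[0] == "term1"]
    if not term_cols:
        # No term1 column: every episode is considered truncated.
        episodes = df.index.get_level_values("episode").unique()
        return pd.Series(True, index=episodes)
    # An episode whose last term1 value is False is truncated.
    last_term = df[term_cols[0]].groupby(level="episode").last()
    return ~last_term.astype(bool)


index = pd.MultiIndex.from_tuples(
    [(0, 0), (0, 1), (1, 0), (1, 1)], names=["episode", "date"])
df = pd.DataFrame({("rew1", "money"): [1.0, 2.0, 3.0, 4.0],
                   ("term1", "end"): [False, True, False, False]}, index=index)
df.columns.names = ("signal", "key")
flags = truncated_episodes(df)  # episode 0 terminated, episode 1 truncated
```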

Todo

  • consider adding an explicit truncation signal: it would be more general, allowing episodes to be sliced as needed and eventually batched

class sara.oar.RTGSchema(*args, **kwargs)

An OARSchema with a mandatory “rtg0” (return-to-go) signal.

sara.oar.bin_with_quantiles(df: DataFrame[RTGSchema], num_quantiles: dict[tuple[str, str], int]) → DataFrame[RTGSchema]

Bin the values of an RTG dataframe into quantiles with pd.qcut.

Parameters:
  • df – an RTG dataframe, since binning is applied after the return-to-go computation

  • num_quantiles – number of quantiles for each column, keyed by (signal, key)

Returns:

a new dataframe with bin names instead of values

Return type:

DataFrame[RTGSchema]
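Here is a minimal sketch of per-column quantile binning with pd.qcut. It assumes the dictionary maps full (signal, key) column names to a number of quantiles, as the type hint suggests. `bin_columns` is hypothetical and may differ from bin_with_quantiles, for instance in how bins are labelled or how duplicate edges are handled.

```python
import pandas as pd


def bin_columns(df: pd.DataFrame, num_quantiles: dict[tuple[str, str], int]) -> pd.DataFrame:
    """Hypothetical sketch: replace selected columns with their quantile bins."""
    binned = df.copy()
    for column, q in num_quantiles.items():
        # pd.qcut assigns each value to one of q equal-frequency bins;
        # duplicates="drop" merges bins whose edges coincide.
        binned[column] = pd.qcut(df[column], q=q, duplicates="drop")
    return binned


df = pd.DataFrame({("rew1", "money"): [1.0, 2.0, 3.0, 4.0],
                   ("rtg0", "cumulative money"): [10.0, 20.0, 30.0, 40.0]})
df.columns.names = ("signal", "key")
binned = bin_columns(df, {("rew1", "money"): 2})
```

With the default labels, each binned cell holds a pandas Interval. Passing labels=False to pd.qcut would give integer bin codes instead.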

sara.oar.discount_from_horizon(horizon: float, tolerance: float = 0.05) → float

Select the discount in a more intuitive way, from a time horizon.

The discount is chosen so that its geometric series becomes negligible beyond the horizon.

Specifically,

\[\sum_{i \ge 0} d^{H+i} = \alpha\sum_{i \ge 0} d^i\]

where \(d\) is the discount, \(H\) is the horizon and \(\alpha\) is the tolerance.
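Both sums are geometric series with ratio \(d\), assuming \(0 < d < 1\). Factoring \(d^H\) out of the left-hand side leaves \(d^H = \alpha\), that is \(d = \alpha^{1/H}\). A sketch of that closed form follows. `discount_from_horizon_sketch` is a hypothetical name, and the real function may compute the value differently.

```python
def discount_from_horizon_sketch(horizon: float, tolerance: float = 0.05) -> float:
    """Closed-form solution of d ** horizon == tolerance (a sketch, not sara's code)."""
    if horizon <= 0:
        raise ValueError("horizon must be positive")
    if not 0.0 < tolerance < 1.0:
        raise ValueError("tolerance must lie strictly between 0 and 1")
    return tolerance ** (1.0 / horizon)


d = discount_from_horizon_sketch(30)
# Share of the infinite geometric series that lies beyond the horizon.
tail_share = d ** 30
```

For a horizon of 30 and the default tolerance, this gives about 0.90497. The share of the series mass beyond the horizon is then the tolerance, up to floating-point rounding.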

Parameters:
  • horizon

    a positive scalar setting the time window of interest:

    • a small time window will miss the long-term influence of the action,

    • a large time window will dilute the influence of the action.

  • tolerance – a scalar between 0 and 1 setting the negligibility threshold.

Returns:

the discount

Return type:

float

Examples

>>> from sara.oar import discount_from_horizon
>>> discount_from_horizon(30)
0.9049661471446958

sara.oar.enrich_rtg(df: DataFrame[OARSchema], discount: float, value_estimation_mode: Literal['zero', 'mean', 'scale'] = 'scale', rtg_key: str | None = None, mix_guardrail: bool = False) → DataFrame[RTGSchema]

Enrich an OAR dataframe with the return-to-go.

Return-to-go from date \(t\) is:

\[R_t = r_{t+1} + \rho r_{t+2} + \cdots + \rho^{T-t-1}r_T\]

where \(\rho\) is the discount, \(r_{t+i}\) is the reward at date \(t+i\), and \(T\) is the last date of the terminated episode.
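Since \(R_t = r_{t+1} + \rho R_{t+1}\) with \(R_T = 0\), the return-to-go of a terminated episode can be computed in one backward pass. The sketch below assumes, as in the example further down, that the rew1 value stored in row \(t\) is \(r_{t+1}\). `return_to_go` is a hypothetical helper that works on a plain list of rewards.

```python
def return_to_go(rewards: list[float], discount: float) -> list[float]:
    """Discounted return-to-go of one terminated episode (hypothetical helper).

    rewards[k] is the rew1 value of row k, i.e. r_{t+1} for row t, so the
    recursion R_t = r_{t+1} + discount * R_{t+1} runs backwards from the end.
    """
    rtg = [0.0] * len(rewards)
    acc = 0.0  # R_T = 0 after the last date
    for k in reversed(range(len(rewards))):
        acc = rewards[k] + discount * acc
        rtg[k] = acc
    return rtg


rtg = return_to_go([1.0, 2.0, 3.0], 0.9)  # about [5.23, 4.7, 3.0]
```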

When the episode is truncated at date \(\tilde T\), the return-to-go is estimated with a bootstrap value \(v\) as:

\[\tilde R_t = r_{t+1} + \rho r_{t+2} + \cdots + \rho^{\tilde T-t-1}r_{\tilde T} + \rho^{\tilde T-t}v\]

Parameters:
  • df

    a valid OAR dataframe; either all episodes are truncated or all episodes are terminated:

    • when episodes are truncated, they are considered infinite and their return-to-go is bootstrapped.

    • when episodes are terminated, the bootstrap value is ignored.

  • discount – the discount factor \(\rho\)

  • value_estimation_mode

    how the bootstrap value \(v\) is estimated for truncated episodes:

    • if “zero”: the bootstrap value is zero.

    • if “mean”: the bootstrap value is estimated from the mean reward \(r\) over the episode as \(v = \frac{r}{1-\rho}\).

    • if “scale” (default): the return-to-go at \(t\) without bootstrap is scaled by \(\frac{1}{1-\rho^{\tilde T-t}}\). This is equivalent to bootstrapping with \(v = \frac{r}{1-\rho}\), where \(r\) is the discounted mean reward from \(t+1\) to \(\tilde T\).

  • rtg_key

    the key of the new rtg0 column:

    • if None, the key is “cumulative ” followed by the key of the rew1 column (e.g. “cumulative money”).

  • mix_guardrail – the return-to-go cannot yet be estimated on a mix of terminated and truncated episodes. If True, assert that episodes are not mixed, which is the correct behavior. Defaults to False while a solution is sought.
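The three bootstrap modes can be compared on a single truncated episode. Let \(G\) be the sum without bootstrap and \(n = \tilde T - t\). The discounted mean is then \(r = G(1-\rho)/(1-\rho^n)\), and \(G + \rho^n\frac{r}{1-\rho} = \frac{G}{1-\rho^n}\), which is why “scale” matches a bootstrap.

This is a sketch under stated assumptions, not sara's implementation:

  • `truncated_return_to_go` is hypothetical.

  • rewards are a plain list whose entry \(k\) is the rew1 value of row \(k\).

  • “mean” uses the undiscounted mean reward of the whole episode.

The final loop checks that identity numerically.

```python
from typing import Literal


def truncated_return_to_go(
    rewards: list[float],
    discount: float,
    mode: Literal["zero", "mean", "scale"] = "scale",
) -> list[float]:
    """Bootstrapped return-to-go of ONE truncated episode (illustrative sketch).

    rewards[k] is the rew1 value of row k; row k has n = len(rewards) - k
    remaining rewards, so its bootstrap weight is discount ** n.
    """
    n_rows = len(rewards)
    # Discounted sums without bootstrap (the "zero" mode), computed backwards.
    partial = [0.0] * n_rows
    acc = 0.0
    for k in reversed(range(n_rows)):
        acc = rewards[k] + discount * acc
        partial[k] = acc
    if mode == "zero":
        return partial
    if mode == "mean":
        # Bootstrap value v = mean episode reward / (1 - discount).
        v = (sum(rewards) / n_rows) / (1.0 - discount)
        return [partial[k] + discount ** (n_rows - k) * v for k in range(n_rows)]
    if mode == "scale":
        # Rescale each partial sum by 1 / (1 - discount ** n).
        return [partial[k] / (1.0 - discount ** (n_rows - k)) for k in range(n_rows)]
    raise ValueError(f"unknown mode: {mode!r}")


# "scale" equals bootstrapping with v = r / (1 - discount), where r is the
# discounted mean of the remaining rewards.
rewards, rho = [1.0, 2.0, 3.0], 0.9
scaled = truncated_return_to_go(rewards, rho, "scale")
for k, g in enumerate(truncated_return_to_go(rewards, rho, "zero")):
    n = len(rewards) - k
    r_bar = g / sum(rho ** i for i in range(n))
    assert abs(g + rho ** n * r_bar / (1.0 - rho) - scaled[k]) < 1e-9
```

On a constant reward of 1 with \(\rho = 0.9\), both “mean” and “scale” return \(\frac{1}{1-\rho} = 10\) at every date. This is the value of the infinite episode the bootstrap stands in for.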

Returns:

the original dataframe with an additional “rtg0” column

Return type:

DataFrame[RTGSchema]

Examples

Initialize an OAR dataframe:

>>> import pandas as pd
>>> df = pd.DataFrame({
...      ("act0", "choice"): [1, 2, 3],
...      ("rew1", "money"): [1., 2., 3.],
...      ("term1", "fundraising end"): [False, False, True]})
>>> df.columns.names = ("signal", "key")
>>> df = df.set_index(
...     pd.MultiIndex.from_product([[0], pd.date_range("2000-01-01", "2000-01-03")],
...     names=["episode", "date"]))

Validate the initial dataframe:

>>> from sara.oar import OARSchema, RTGSchema, enrich_rtg
>>> OARSchema.validate(df)
signal               act0  rew1           term1
key                choice money fundraising end
episode date
0       2000-01-01      1   1.0           False
        2000-01-02      2   2.0           False
        2000-01-03      3   3.0            True

Enrich the dataframe with the return-to-go; df then validates against RTGSchema:

>>> df = enrich_rtg(df, 0.9)
>>> RTGSchema.validate(df)
signal               act0  rew1           term1             rtg0
key                choice money fundraising end cumulative money
episode date
0       2000-01-01      1   1.0           False             5.23
        2000-01-02      2   2.0           False             4.70
        2000-01-03      3   3.0            True             3.00

Check that the return-to-go is correct:

>>> import numpy as np
>>> rtg0 = df[("rtg0", "cumulative money")]
>>> rtg0_tgt = np.array([1. + .9*2. + .9**2*3., 2. + .9*3., 3.])
>>> np.allclose(rtg0.to_numpy(), rtg0_tgt)
True

sara.oar.filter_with_query(df: DataFrame[RTGSchema], query: str) → DataFrame[RTGSchema]

Filter an RTG dataframe with DataFrame.query.

For convenience, refer to columns by key alone in the query, rather than by (signal, key) pairs.

Parameters:
  • df – an RTG dataframe, since filtering is applied after the return-to-go computation

  • query – the query string; see pd.DataFrame.query()

Returns:

the filtered dataframe

Return type:

DataFrame[RTGSchema]
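The key-only convention can be sketched in two steps: drop the signal level before evaluating the query, then apply the resulting mask to the original dataframe. `query_by_key` is hypothetical and assumes keys are unique across signals. As with pd.DataFrame.query, keys that are not valid Python identifiers, such as “fundraising end”, would need backticks.

```python
import pandas as pd


def query_by_key(df: pd.DataFrame, query: str) -> pd.DataFrame:
    """Hypothetical sketch: filter rows with a query that names columns by key only.

    Assumes the column level named "key" is unique across signals.
    """
    flat = df.copy()
    # Drop the "signal" level so that e.g. ("rew1", "money") becomes "money".
    flat.columns = df.columns.get_level_values("key")
    mask = flat.eval(query)
    # Apply the boolean mask to the original (signal, key) dataframe.
    return df.loc[mask.to_numpy()]


df = pd.DataFrame({("act0", "choice"): [1, 2, 3], ("rew1", "money"): [1.0, 2.0, 3.0]})
df.columns.names = ("signal", "key")
filtered = query_by_key(df, "money > 1.5")
```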

sara.oar.get_keys_dataframe(schema: DataFrameSchema) → DataFrame

Transform the columns of a schema into a dataframe ready for printing.

Parameters:

schema – a pandera DataFrameSchema with (signal, key) column names, e.g. one describing an OAR dataframe

Returns:

a dataframe indexed by key, with signal and description as columns

Return type:

pd.DataFrame
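One possible way to build such a table is duck-typed on pandera's DataFrameSchema. Its columns mapping goes from (signal, key) names to Column objects carrying a metadata dictionary, as in the example below. `keys_table` is a hypothetical helper. The stand-in schema built with SimpleNamespace replaces a real pandera schema only so the sketch runs without pandera.

```python
from types import SimpleNamespace

import pandas as pd


def keys_table(schema) -> pd.DataFrame:
    """Hypothetical sketch: one row per key with its signal and description.

    Duck-typed on pandera's DataFrameSchema: schema.columns maps each
    (signal, key) name to a column object whose metadata may hold a
    "description" entry.
    """
    rows = []
    for (signal, key), column in schema.columns.items():
        metadata = column.metadata or {}
        rows.append({"key": key, "signal": signal,
                     "description": metadata.get("description", "")})
    return pd.DataFrame(rows, columns=["key", "signal", "description"]).set_index("key")


# Stand-in objects so the sketch runs without pandera installed.
schema = SimpleNamespace(columns={
    ("act0", "choice"): SimpleNamespace(metadata={"description": "Blue or red pill."}),
    ("rew1", "reality"): SimpleNamespace(metadata=None),
})
table = keys_table(schema)
```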

Examples

>>> import numpy as np
>>> import pandas as pd
>>> import pandera.pandas as pa
>>> schema = pa.DataFrameSchema(columns={
...     ('act0', 'choice'): pa.Column(int, metadata={
...         'description': 'You take the blue pill the story ends, '
...                        'you wake up in your bed and believe whatever you '
...                        'want to believe. You take the red pill you stay '
...                        'in Wonderland, and I show you how deep the rabbit '
...                        'hole goes.'}),
...     ('rew1', 'reality'): pa.Column(float, metadata={
...         'description': 'What is real? How do you define real? If you re '
...                        'talking about what you can feel, what you can '
...                        'smell, what you can taste and see, then real is '
...                        'simply electrical signals interpreted by your '
...                        'brain.'})},
...     index=pa.Index(np.datetime64, name='date'))
>>> from sara.oar import get_keys_dataframe
>>> with pd.option_context(
...     "display.max_colwidth", None):
...     print(get_keys_dataframe(schema))
       signal                                                      description
key
choice   act0  You take the blue pill the story ends, you wake up in your
               bed and believe whatever you want to believe. You take the red
               pill you stay in Wonderland, and I show you how deep the rabbit
               hole goes.