# Source code for sara.oar.rtg

"""Utilities."""

from typing import Literal

import numpy as np
import pandas as pd
import pandera.pandas as pa
from pandera.typing.pandas import DataFrame

from .schemas import OARSchema, RTGSchema


def discounted_cumsum(
    rews: np.ndarray,
    discount: float,
    val: float,
) -> np.ndarray:
    r"""Return-to-go estimated with Python for loop.

    Return-to-go is defined by:

    .. math::
        R_t &= \sum_{i=t}^{T-1} \gamma^{i-t}r_{i+1} \\
            &= r_{t+1} + \gamma R_{t+1}

    where :math:`\gamma` is the discount factor and :math:`T` is the episode
    termination. If the episode is instead truncated at :math:`\tilde T`,
    Return-to-go is estimated with:

    .. math::
        \tilde R_t = \sum_{i=t}^{\tilde T-1} \gamma^{i-t}r_{i+1}
        + \gamma^{\tilde T-t}V_{\tilde T}

    where :math:`V_{\tilde T}` is a bootstrap value.

    Note:
        - I tried to jit compile with numba but compilation overhead was bottleneck.
        - can be batched with padding
        - integer (or boolean) rewards are promoted to ``float64`` so the
          discounted sums are not truncated; floating (and complex) rewards
          keep their dtype.

    Args:
        rews (np.ndarray): :math:`r_1, r_2, ..., r_T` rewards along episode
        discount (float): discount factor :math:`\gamma`
        val (float): estimation of :math:`R_T` (bootstrapped value)

    Returns:
        np.ndarray: the estimated return-to-go :math:`R_0, R_1, ..., R_{T-1}`,
        with the same shape as ``rews``

    """
    rews = np.asarray(rews)
    # ``np.empty_like`` on an integer array would silently truncate every
    # discounted sum stored in it, so promote non-inexact rewards to float64.
    dtype = rews.dtype if np.issubdtype(rews.dtype, np.inexact) else np.float64
    rtgs = np.empty(rews.shape, dtype=dtype)
    rtg = val
    # backward recursion R_t = r_{t+1} + gamma * R_{t+1}, seeded with R_T = val
    for t in reversed(range(len(rews))):
        rtg = rews[t] + discount * rtg
        rtgs[t] = rtg
    return rtgs


def get_rtg_by_episode(
    df: pd.DataFrame,
    discount: float,
    value_estimation_mode: Literal["zero", "mean", "scale"],
) -> pd.DataFrame:
    """Apply :func:`discounted_cumsum` to each episodes.

    The episode is considered terminated when the last value of its ``term1``
    column is truthy; otherwise (or without ``term1`` column) it is truncated
    and bootstrapped according to ``value_estimation_mode``.

    Args:
        df: the episodic dataframe (after groupby over level not date); its
            index must have a ``"date"`` level with uniform spacing and its
            columns must contain a ``"rew1"`` signal (``"term1"`` is optional)
        discount: the discount
        value_estimation_mode: mode for bootstrap value estimation,
            one of "zero", "mean" or "scale"

    Returns:
        pd.DataFrame: single-column ``"rtg0"`` dataframe holding the
            return-to-go over the episode, with the same index as ``df``.
            Its dtype is the reward dtype when it is floating, ``float64``
            otherwise.

    Raises:
        ValueError: if ``value_estimation_mode`` is unknown, or if a truncated
            episode must be bootstrapped ("mean" or "scale") with
            ``discount >= 1`` (the bootstrap diverges).
        NotImplementedError: if the dates are not uniformly spaced.

    """
    if value_estimation_mode not in ("zero", "mean", "scale"):
        msg = (
            "value_estimation_mode must be one of 'zero', 'mean', 'scale', "
            f"got {value_estimation_mode!r}"
        )
        raise ValueError(msg)

    dates = df.index.get_level_values("date")
    diff = np.diff(dates.to_numpy())
    if len(diff) > 0 and not bool((diff == diff[0]).all()):
        msg = "Don't know what to do yet when date diff is not uniform"
        raise NotImplementedError(msg)

    rew1_key = next(k for k in df.columns if k[0] == "rew1")
    term1_key = next((k for k in df.columns if k[0] == "term1"), None)
    terminated = term1_key is not None and bool(df[term1_key].to_numpy()[-1])

    rew_dtype = df[rew1_key].dtype
    # return-to-go are fractional discounted sums: keep a float reward dtype,
    # otherwise use float64 (an int column would truncate, and pandas refuses
    # to coerce non-integral floats into an integer column).
    out_dtype = rew_dtype if pd.api.types.is_float_dtype(rew_dtype) else np.float64
    rews = df[rew1_key].to_numpy()
    if not np.issubdtype(rews.dtype, np.floating):
        rews = rews.astype(np.float64)

    bootstrap = (value_estimation_mode != "zero") and (not terminated)
    if bootstrap and discount >= 1.0:
        msg = (
            f"value_estimation_mode={value_estimation_mode!r} requires "
            f"discount < 1 to bootstrap truncated episodes, got {discount}"
        )
        raise ValueError(msg)

    val = 0.0
    if bootstrap and value_estimation_mode == "mean":
        # infinite-horizon value of the episode mean reward: r / (1 - discount)
        val = float(np.mean(rews) / (1.0 - discount))
    rtgs = discounted_cumsum(rews, discount, val)
    if bootstrap and value_estimation_mode == "scale":
        # scale R_t by 1 / (1 - discount ** (T~ - t)); exponent goes len..1
        pows = np.pow(discount, len(rews) - np.arange(len(rews)))
        rtgs = rtgs / (1.0 - pows)

    # NOTE: Using DataFrame with column rtg0 instead of Series
    # because Series lead to date as index with only one shop
    # -> groupby and apply seems to not work well if inner_fn
    # -> outputs series instead of dataframe
    return pd.DataFrame({"rtg0": rtgs}, index=df.index, dtype=out_dtype)


@pa.check_types
def enrich_rtg(
    df: DataFrame[OARSchema],
    discount: float,
    value_estimation_mode: Literal["zero", "mean", "scale"] = "scale",
    rtg_key: str | None = None,
    mix_guardrail: bool = False,
) -> DataFrame[RTGSchema]:
    r"""Enrich OAR dataframe with return-to-go.

    Return-to-go from date :math:`t` is:

    .. math::
        R_t = r_{t+1} + \rho r_{t+2} + \cdots + \rho^{T-t-1}r_T

    where :math:`\rho` is the discount, :math:`r_{t+i}` is the reward at date
    :math:`t+i`, date :math:`T` is the last date of the terminated episode.

    When the episode is truncated at date :math:`\tilde T`, the return-to-go is
    estimated with some bootstrap value :math:`v` as:

    .. math::
        \tilde R_t = r_{t+1} + \rho r_{t+2} + \cdots
        + \rho^{\tilde T-t-1}r_{\tilde T} + \rho^{\tilde T-t}v

    Episodes are identified by the index levels other than ``"date"``; when the
    index only has a ``"date"`` level, the whole dataframe is one episode.

    Args:
        df: valid OAR dataframe,
            all episodes should be truncated or all episodes should be terminated

            - when episodes are truncated, episodes are considered infinite
              and are bootstrapped.
            - when episodes are terminated, bootstrap is ignored.

        discount: discount factor
        value_estimation_mode: mode for bootstrap value estimation for
            truncated episodes,

            - if "zero": the bootstrap value is zero
            - if "mean": the bootstrap value is estimated with the reward
              mean over episode :math:`r` as :math:`\frac{r}{1-\rho}`
            - if "scale" (default): the return-to-go at :math:`t` with no
              bootstrap value is scaled with
              :math:`\frac{1}{1-\rho^{\tilde T-t}}`. This is equivalent to
              bootstrap with :math:`\frac{r}{1-\rho}` where :math:`r` is the
              discounted mean reward from :math:`t+1` to :math:`\tilde T`

        rtg_key: the key for the new rtg0 column,

            - if None, the key is "cumulative " + key of the rew1 column

        mix_guardrail: rtg can't be estimated yet with mix of terminated and
            truncated episodes. If True, raise an ``AssertionError`` when
            episodes are mixed, which is the correct alternative.
            Default to False while we seek a solution.

    Returns:
        DataFrame[RTGSchema]: a copy of the original dataframe with an
        additional "rtg0" column

    Raises:
        AssertionError: if ``mix_guardrail`` is True and the dataframe mixes
            terminated and truncated episodes.

    Examples:
        Initialize OAR dataframe

        >>> import pandas as pd
        >>> df = pd.DataFrame({
        ...     ("act0", "choice"): [1, 2, 3],
        ...     ("rew1", "money"): [1., 2., 3.],
        ...     ("term1", "fundraising end"): [False, False, True]})
        >>> df.columns.names = ("signal", "key")
        >>> df = df.set_index(
        ...     pd.MultiIndex.from_product([[0], pd.date_range("2000-01-01", "2000-01-03")],
        ...         names=["episode", "date"]))

        validate initial dataframe

        >>> from sara.oar import OARSchema, RTGSchema, enrich_rtg
        >>> OARSchema.validate(df)  # doctest: +NORMALIZE_WHITESPACE
        signal             act0  rew1           term1
        key              choice money fundraising end
        episode date
        0       2000-01-01      1   1.0           False
                2000-01-02      2   2.0           False
                2000-01-03      3   3.0            True

        enrich dataframe with return-to-go, df then validates RTGSchema

        >>> df = enrich_rtg(df, 0.9)
        >>> RTGSchema.validate(df)  # doctest: +NORMALIZE_WHITESPACE
        signal             act0  rew1           term1             rtg0
        key              choice money fundraising end cumulative money
        episode date
        0       2000-01-01      1   1.0           False             5.23
                2000-01-02      2   2.0           False             4.70
                2000-01-03      3   3.0            True             3.00

        check correct return-to-go

        >>> import numpy as np
        >>> rtg0 = df[("rtg0", "cumulative money")]
        >>> rtg0_tgt = np.array([1. + .9*2. + .9**2*3., 2. + .9*3., 3.])
        >>> all(rtg0.to_numpy() == rtg0_tgt)
        True

    """
    levels_not_date = [name for name in df.index.names if name != "date"]

    # check there is no mix of truncated and terminated episodes
    term1_key = next((k for k in df.columns if k[0] == "term1"), None)
    if mix_guardrail and term1_key is not None and levels_not_date:
        # the last term1 flag of each episode says whether it is terminated
        terminateds = df[term1_key].groupby(level=levels_not_date).tail(1)
        if bool(terminateds.any()) and not bool(terminateds.all()):
            # explicit raise instead of ``assert`` so that the guardrail is not
            # stripped under ``python -O``; same exception type as before.
            msg = (
                "return-to-go estimation is not available for mix of terminated "
                "and truncated episodes: either all episodes are terminated "
                "or all episodes are truncated."
            )
            raise AssertionError(msg)

    # compute rtgs episode by episode
    if levels_not_date:
        rtgs = df.groupby(level=levels_not_date, group_keys=False).apply(
            lambda subdf: get_rtg_by_episode(subdf, discount, value_estimation_mode),
            include_groups=False,
        )
    else:
        # ``groupby(level=[])`` fails: the whole dataframe is a single episode
        rtgs = get_rtg_by_episode(df, discount, value_estimation_mode)

    # replace key associated with 'rtg0'
    rew1_key = next(k for k in df.columns if k[0] == "rew1")
    rtg_col = (
        ("rtg0", rtg_key)
        if rtg_key is not None
        else ("rtg0", "cumulative " + rew1_key[1])
    )

    # copy to avoid inplace modification
    df_out = df.copy()
    df_out[rtg_col] = rtgs["rtg0"]
    return DataFrame[RTGSchema](df_out)