"""Utilities."""
from typing import Literal
import numpy as np
import pandas as pd
import pandera.pandas as pa
from pandera.typing.pandas import DataFrame
from .schemas import OARSchema, RTGSchema
def discounted_cumsum(
    rews: np.ndarray,
    discount: float,
    val: float,
) -> np.ndarray:
    r"""Return-to-go estimated with a Python for loop.

    Return-to-go is defined by:

    .. math::

        R_t &= \sum_{i=t}^{T-1} \gamma^{i-t}r_{i+1} \\
            &= r_{t+1} + \gamma R_{t+1}

    where :math:`\gamma` is the discount factor and :math:`T` is the episode
    termination. If the episode is instead truncated at :math:`\tilde T`,
    return-to-go is estimated with:

    .. math::

        \tilde R_t = \sum_{i=t}^{\tilde T-1} \gamma^{i-t}r_{i+1}
            + \gamma^{\tilde T-t}V_{\tilde T}

    where :math:`V_{\tilde T}` is a bootstrap value.

    Note:
        - I tried to jit compile with numba but compilation overhead was the
          bottleneck.
        - can be batched with padding

    Args:
        rews (np.ndarray): :math:`r_1, r_2, ..., r_T` rewards along episode
        discount (float): discount factor :math:`\gamma`
        val (float): estimation of :math:`R_T` (bootstrapped value)

    Returns:
        np.ndarray: the estimated return-to-go :math:`R_0, R_1, ..., R_{T-1}`,
        with the same shape as ``rews``. Non-integer rewards keep their dtype;
        integer and boolean rewards yield a ``float64`` array so that the
        discounted values are not truncated.
    """
    rews = np.asarray(rews)
    # Discounted sums are fractional in general: an integer output buffer
    # (what ``np.empty_like`` gives for integer rewards) would truncate them.
    out_dtype = np.float64 if rews.dtype.kind in "biu" else rews.dtype
    rtgs = np.empty(rews.shape, dtype=out_dtype)
    # Backward recursion R_t = r_{t+1} + gamma * R_{t+1}, seeded with R_T = val.
    rtg = val
    for t in reversed(range(len(rews))):
        rtg = rews[t] + discount * rtg
        rtgs[t] = rtg
    return rtgs
def get_rtg_by_episode(
    df: pd.DataFrame,
    discount: float,
    value_estimation_mode: Literal["zero", "mean", "scale"],
) -> pd.DataFrame:
    """Apply :func:`discounted_cumsum` to a single episode.

    The episode is bootstrapped only when it is not terminated, i.e. when the
    dataframe has no ``term1`` column or its last ``term1`` value is falsy.

    Args:
        df: the episodic dataframe (one group of a groupby over the index
            levels other than ``"date"``). It must have a ``"date"`` index
            level with uniformly spaced dates and a column whose first key
            element is ``"rew1"``.
        discount: the discount factor. When a truncated episode is
            bootstrapped with ``"mean"`` or ``"scale"``, ``discount`` must not
            be 1 (both modes divide by a ``1 - discount**k`` term).
        value_estimation_mode: mode for bootstrap value estimation of
            truncated episodes:

            - ``"zero"``: no bootstrap value;
            - ``"mean"``: bootstrap with ``mean(rews) / (1 - discount)``;
            - ``"scale"``: divide the unbootstrapped return-to-go at step
              ``t`` by ``1 - discount**(T - t)``, ``T`` being the episode
              length.

    Returns:
        pd.DataFrame: a single-column ``"rtg0"`` dataframe holding the
        return-to-go over the episode, with the same index as ``df`` and the
        dtype of the ``rew1`` column.

    Raises:
        ValueError: if ``value_estimation_mode`` is not a supported mode.
        NotImplementedError: if the dates are not uniformly spaced.
    """
    if value_estimation_mode not in ("zero", "mean", "scale"):
        msg = (
            "value_estimation_mode must be 'zero', 'mean' or 'scale', "
            f"got {value_estimation_mode!r}"
        )
        raise ValueError(msg)
    dates = df.index.get_level_values("date")
    diff = dates.diff()[1:]
    if len(diff) > 0 and not bool((diff == diff[0]).all()):
        msg = "Don't know what to do yet when date diff is not uniform"
        raise NotImplementedError(msg)
    rew1_key = next(k for k in df.columns if k[0] == "rew1")
    rews = df[rew1_key].to_numpy()
    # The episode is terminated when its last term1 flag is set. Without a
    # term1 column (or for an empty episode) it is treated as truncated.
    terminated = False
    term1_key = next((k for k in df.columns if k[0] == "term1"), None)
    if term1_key is not None and len(df) > 0:
        terminated = bool(df[term1_key].iloc[-1])
    val = 0.0
    if value_estimation_mode == "mean" and not terminated and len(rews) > 0:
        # Infinite-horizon bootstrap assuming future rewards equal the mean.
        val = float(np.mean(rews) / (1.0 - discount))
    rtgs = discounted_cumsum(rews, discount, val)
    if value_estimation_mode == "scale" and not terminated:
        # Step t has T - t remaining rewards. ``np.power`` rather than
        # ``np.pow``, which only exists in NumPy >= 2.0.
        pows = np.power(discount, len(rews) - np.arange(len(rews)))
        rtgs = rtgs / (1.0 - pows)
    # NOTE: a DataFrame with column "rtg0" is returned instead of a Series
    # because with only one shop the Series ends up indexed by date, and
    # groupby(...).apply does not combine well when the inner function
    # returns a Series instead of a DataFrame.
    index, dtype = df.index, df[rew1_key].dtype
    return pd.DataFrame({"rtg0": rtgs}, index=index, dtype=dtype)
@pa.check_types
def enrich_rtg(
    df: DataFrame[OARSchema],
    discount: float,
    value_estimation_mode: Literal["zero", "mean", "scale"] = "scale",
    rtg_key: str | None = None,
    mix_guardrail: bool = False,
) -> DataFrame[RTGSchema]:
    r"""Enrich OAR dataframe with return-to-go.

    Return-to-go from date :math:`t` is:

    .. math:: R_t = r_{t+1} + \rho r_{t+2} + \cdots + \rho^{T-t-1}r_T

    where :math:`\rho` is the discount, :math:`r_{t+i}` is the reward at date
    :math:`t+i` and date :math:`T` is the last date of the terminated episode.
    When the episode is truncated at date :math:`\tilde T`, the return-to-go is
    estimated with some bootstrap value :math:`v` as:

    .. math:: \tilde R_t = r_{t+1} + \rho r_{t+2} + \cdots +
        \rho^{\tilde T-t-1}r_{\tilde T} + \rho^{\tilde T-t}v

    Args:
        df: valid OAR dataframe. Episodes are identified by the index levels
            other than "date"; if "date" is the only level, the whole
            dataframe is treated as a single episode. All episodes should be
            truncated, or all episodes should be terminated:

            - when episodes are truncated, they are considered infinite and
              are bootstrapped;
            - when episodes are terminated, bootstrapping is skipped.
        discount: discount factor.
        value_estimation_mode: mode for bootstrap value estimation of
            truncated episodes:

            - if "zero": the bootstrap value is zero;
            - if "mean": the bootstrap value is estimated from the reward mean
              over the episode :math:`r` as :math:`\frac{r}{1-\rho}`;
            - if "scale" (default): the return-to-go at :math:`t` without a
              bootstrap value is scaled by
              :math:`\frac{1}{1-\rho^{\tilde T-t}}`. This is equivalent to
              bootstrapping with :math:`\frac{r}{1-\rho}`, where :math:`r` is
              the discounted mean from :math:`t+1` to :math:`\tilde T`.
        rtg_key: the second key of the new "rtg0" column. If None, the key is
            "cumulative " + the second key of the rew1 column.
        mix_guardrail: rtg can't be estimated yet for a mix of terminated and
            truncated episodes. If True, check that episodes are not mixed,
            which is the correct behavior. Defaults to False while a solution
            is sought.

    Returns:
        DataFrame[RTGSchema]: a copy of the original dataframe with an
        additional "rtg0" column.

    Raises:
        AssertionError: if ``mix_guardrail`` is True and terminated and
            truncated episodes are mixed.

    Examples:
        Initialize OAR dataframe

        >>> import pandas as pd
        >>> df = pd.DataFrame({
        ...     ("act0", "choice"): [1, 2, 3],
        ...     ("rew1", "money"): [1., 2., 3.],
        ...     ("term1", "fundraising end"): [False, False, True]})
        >>> df.columns.names = ("signal", "key")
        >>> df = df.set_index(
        ...     pd.MultiIndex.from_product([[0], pd.date_range("2000-01-01", "2000-01-03")],
        ...     names=["episode", "date"]))

        Validate the initial dataframe

        >>> from sara.oar import OARSchema, RTGSchema, enrich_rtg
        >>> OARSchema.validate(df)  # doctest: +NORMALIZE_WHITESPACE
        signal               act0  rew1           term1
        key                choice money fundraising end
        episode date
        0       2000-01-01      1   1.0           False
                2000-01-02      2   2.0           False
                2000-01-03      3   3.0            True

        Enrich the dataframe with return-to-go; df then validates RTGSchema

        >>> df = enrich_rtg(df, 0.9)
        >>> RTGSchema.validate(df)  # doctest: +NORMALIZE_WHITESPACE
        signal               act0  rew1           term1             rtg0
        key                choice money fundraising end cumulative money
        episode date
        0       2000-01-01      1   1.0           False             5.23
                2000-01-02      2   2.0           False             4.70
                2000-01-03      3   3.0            True             3.00

        Check the return-to-go (up to floating-point rounding)

        >>> import numpy as np
        >>> rtg0 = df[("rtg0", "cumulative money")]
        >>> rtg0_tgt = np.array([1. + .9*2. + .9**2*3., 2. + .9*3., 3.])
        >>> np.allclose(rtg0.to_numpy(), rtg0_tgt)
        True
    """
    levels_not_date = [name for name in df.index.names if name != "date"]
    term1_key = next((k for k in df.columns if k[0] == "term1"), None)
    # Check there is no mix of truncated and terminated episodes. Use an
    # explicit raise so the guardrail also runs under ``python -O``.
    if mix_guardrail and term1_key is not None and levels_not_date:
        last_terms = df[term1_key].groupby(level=levels_not_date).tail(1)
        all_terminated = bool(last_terms.all())
        all_truncated = not bool(last_terms.any())
        if not (all_terminated or all_truncated):
            raise AssertionError(
                "return-to-go estimation is not available for mix of terminated "
                "and truncated episodes: either all episodes are terminated "
                "or all episodes are truncated."
            )
    # Compute rtgs episode by episode.
    if levels_not_date:
        rtgs = df.groupby(level=levels_not_date, group_keys=False).apply(
            lambda subdf: get_rtg_by_episode(subdf, discount, value_estimation_mode),
            include_groups=False,
        )
    else:
        # Only a "date" level: the whole dataframe is one episode, and
        # ``groupby(level=[])`` would raise "No group keys passed!".
        rtgs = get_rtg_by_episode(df, discount, value_estimation_mode)
    rew1_key = next(k for k in df.columns if k[0] == "rew1")
    out_key = (
        ("rtg0", rtg_key)
        if rtg_key is not None
        else ("rtg0", "cumulative " + rew1_key[1])
    )
    # Copy to avoid in-place modification of the caller's dataframe.
    df_out = df.copy()
    # Assignment aligns on the index, so group ordering does not matter.
    df_out[out_key] = rtgs["rtg0"]
    return DataFrame[RTGSchema](df_out)