"""Observation-Action-Reward (OAR) format definition.
The OAR format organizes a dataset when it comes (or is interpreted as coming)
from a reinforcement learning environment.
This module contains `pandera <https://pandera.readthedocs.io/en/stable>`__
schemas to check and validate the structure of :class:`pd.DataFrame` in use in
:mod:`sara`.
Todo:
- impose in the OAR that data have to be upsampled with uniform frequency
(is it valid with filtering ?)
- ensure metadata has description ?
"""
from typing import Final
import pandas as pd
from pandas.api.types import is_timedelta64_ns_dtype
import pandera.pandas as pa
COL_NAMES: Final[list[str]] = ["signal", "key"]
"""A list containing the names of the levels of the OAR multi-indexed columns.
- signal: the reinforcement learning signal e.g. "act"
- key: the action in the dataset e.g. "order"
"""
SIGNALS: Final[set[str]] = {"obs0", "act0", "rew1", "obs1", "term1", "rtg0"}
r"""The set of possible signals.
- obs0: the observation at date `d`
- act0: the action taken at date `d` given obs0
- rew1: the reward received after act0 have been submitted to environment
- obs1: the next observation after act0 have been submitted to environment
- term1: the termination signal after act0 have been submitted to environment
- rtg0: the return to go from date d
"""
class OARSchema(pa.DataFrameModel):
    """The base OAR format for pandas DataFrame.

    Raises:
        SchemaError: if the input dataframe does not validate the schema

    Examples:
        Initialize a minimal OAR dataframe:

        >>> import pandas as pd
        >>> from sara.oar import OARSchema
        >>> df = pd.DataFrame({
        ...     ("act0", "choice"): [1, 2, 3],
        ...     ("rew1", "money"): [1., 2., 3.]})
        >>> df.columns.names = ("signal", "key")
        >>> df = df.set_index(
        ...     pd.MultiIndex.from_product(
        ...         [[0], pd.date_range("2000-01-01", "2000-01-03")],
        ...         names=["episode", "date"]))

        validate it against OARSchema:

        >>> OARSchema.validate(df)  # doctest: +NORMALIZE_WHITESPACE
        signal            act0 rew1
        key             choice money
        episode date
        0       2000-01-01     1  1.0
                2000-01-02     2  2.0
                2000-01-03     3  3.0

    Note:
        - episodes ending with :code:`term1 == False` are considered truncated
        - if no :code:`term1` column is present, episodes are considered truncated
        - date should be an integer or a datetime

    Todo:
        - maybe it is more general to include truncation signal, it allows to
          slice episode how we want and eventually batch them
    """

    @classmethod
    def _all_episodes(cls, df: pd.DataFrame, predicate) -> bool:
        """Return True iff ``predicate`` holds on every episode sub-frame.

        Episodes are the groups obtained by grouping on every index level
        except "date".  Checks run independently in pandera, so this guards
        against the degenerate case where "date" is the only level (a plain
        ``groupby(level=[])`` would raise ``ValueError`` instead of letting
        the dedicated index checks report the problem).
        """
        levels = [name for name in df.index.names if name != "date"]
        if not levels:
            # single-level frame: treat the whole frame as one episode
            return bool(predicate(df))
        return bool(df.groupby(level=levels).apply(predicate).all())

    # ----- EPISODE -----
    @pa.dataframe_check(
        name="more_than_two_indexes",
        error="num indexes > 1",
    )
    @classmethod
    def _more_than_two_indexes(cls, df: pd.DataFrame) -> bool:
        """The index must be a MultiIndex with at least two levels."""
        return isinstance(df.index, pd.MultiIndex) and len(df.index.names) > 1

    @pa.dataframe_check(
        name="all_index_levels_named",
        error="All index levels must have a non-empty name",
    )
    @classmethod
    def _all_index_levels_named(cls, df: pd.DataFrame) -> bool:
        """Every index level must have a non-None, non-whitespace name."""
        return all(
            name is not None and str(name).strip() != ""
            for name in df.index.names
        )

    # ----- DATE -----
    @pa.dataframe_check(
        name="has_date",
        error="Indexes should contain exactly one 'date'",
    )
    @classmethod
    def _has_date(cls, df: pd.DataFrame) -> bool:
        """Exactly one index level must be named "date"."""
        return list(df.index.names).count("date") == 1

    @pa.dataframe_check(
        name="check_date_increasing",
        error="Date should increase",
    )
    @classmethod
    def _check_date_increasing(cls, df: pd.DataFrame) -> bool:
        """Dates must be monotonically increasing within each episode."""
        return cls._all_episodes(
            df,
            lambda sdf: sdf.index.get_level_values("date").is_monotonic_increasing,
        )

    @pa.dataframe_check(
        name="check_date_dtype",
        error="Date should have datetime64, timedelta64 or integer dtype",
    )
    @classmethod
    def _check_date_dtype(cls, df: pd.DataFrame) -> bool:
        """The "date" level must be datetime-, timedelta- or integer-typed.

        ``is_datetime64_dtype`` / ``is_timedelta64_dtype`` already cover the
        "[ns]" resolutions, so the ns-specific predicates are not needed.
        """
        dtype = df.index.get_level_values("date").dtype
        return (
            pd.api.types.is_datetime64_dtype(dtype)
            or pd.api.types.is_timedelta64_dtype(dtype)
            or pd.api.types.is_integer_dtype(dtype)
        )

    # ----- SIGNAL and NAMES -----
    @pa.dataframe_check(
        name="check_column_names",
        error=f"Columns should be a MultiIndex with names {COL_NAMES}.",
    )
    @classmethod
    def _check_column_names(cls, df: pd.DataFrame) -> bool:
        """Column level names must be exactly ["signal", "key"]."""
        # list() normalizes pandas' FrozenList before comparing
        return list(df.columns.names) == COL_NAMES

    @pa.dataframe_check(name="check_signal", error=f"Columns must be in {SIGNALS}")
    @classmethod
    def _check_signal_columns(cls, df: pd.DataFrame) -> bool:
        """Every "signal" value must belong to the known SIGNALS set."""
        return set(df.columns.get_level_values("signal")) <= SIGNALS

    @pa.dataframe_check(
        name="check_keys_non_empty",
        error="All columns must have a non-empty string 'key'",
    )
    @classmethod
    def _check_keys_non_empty(cls, df: pd.DataFrame) -> bool:
        """Every "key" value must be a non-empty, non-whitespace string.

        The previous implementation skipped ``None`` keys entirely, so a
        column with a ``None`` key silently passed — contradicting the
        error message.  ``None`` now fails the ``isinstance(k, str)`` test.
        """
        keys = df.columns.get_level_values("key")
        return all(isinstance(k, str) and k.strip() != "" for k in keys)

    # ----- OBS0 -----
    # nothing

    # ----- ACT0 -----
    @pa.dataframe_check(
        name="has_at_least_one_act",
        error="Columns should contain at least one 'act0'",
    )
    @classmethod
    def _has_at_least_one_act(cls, df: pd.DataFrame) -> bool:
        """At least one column must carry the "act0" signal."""
        return "act0" in df.columns.get_level_values("signal")

    # ----- REW1 -----
    @pa.dataframe_check(
        name="has_one_rew",
        error="Columns should contain exactly one 'rew1'",
    )
    @classmethod
    def _has_one_rew(cls, df: pd.DataFrame) -> bool:
        """Exactly one column must carry the "rew1" signal."""
        return list(df.columns.get_level_values("signal")).count("rew1") == 1

    @pa.dataframe_check(
        name="check_rew_dtype",
        error="Rew1 should be of a floating dtype",
    )
    @classmethod
    def _check_rew_dtype(cls, df: pd.DataFrame) -> bool:
        """Any "rew1" column must have a floating dtype.

        The previous implementation used ``next(...)`` without a default and
        raised ``StopIteration`` when no "rew1" column existed; cardinality
        is enforced by ``has_one_rew``, so absence passes vacuously here.
        """
        rew_cols = [col for col in df.columns if col[0] == "rew1"]
        return all(pd.api.types.is_float_dtype(df[col].dtype) for col in rew_cols)

    # ----- OBS1 -----
    @pa.dataframe_check(
        name="check_obs1_same_name",
        error="Obs0 and obs1 columns should have same name",
    )
    @classmethod
    def _check_obs1_same_name(cls, df: pd.DataFrame) -> bool:
        """If present, "obs1" keys must mirror the "obs0" keys in order."""
        # NOTE: same value is not checked here because
        # it assumes same timedelta between rows
        obs0_keys = [col[1] for col in df.columns if col[0] == "obs0"]
        obs1_keys = [col[1] for col in df.columns if col[0] == "obs1"]
        if not obs1_keys:
            # omitting obs1 entirely is allowed
            return True
        return obs0_keys == obs1_keys

    @pa.dataframe_check(
        name="check_obs1_same_dtype",
        error="Obs0 and obs1 columns should have same dtype",
    )
    @classmethod
    def _check_obs1_same_dtype(cls, df: pd.DataFrame) -> bool:
        """If present, "obs1" dtypes must mirror the "obs0" dtypes in order."""
        obs0_dtypes = [df[col].dtype for col in df.columns if col[0] == "obs0"]
        obs1_dtypes = [df[col].dtype for col in df.columns if col[0] == "obs1"]
        if not obs1_dtypes:
            # omitting obs1 entirely is allowed
            return True
        return obs0_dtypes == obs1_dtypes

    @staticmethod
    def _check_obs1_same_value_in_episode(df: pd.DataFrame) -> bool:
        """Within one episode, obs1 at row i must equal obs0 at row i+1.

        NOTE(review): elementwise ``==`` treats NaN as unequal to NaN, so
        episodes containing NaN observations will fail — confirm whether
        that is intended for this dataset.
        """
        obs0_vals = [df[col].to_numpy()[1:] for col in df.columns if col[0] == "obs0"]
        obs1_vals = [df[col].to_numpy()[:-1] for col in df.columns if col[0] == "obs1"]
        if not obs1_vals:
            # omitting obs1 entirely is allowed
            return True
        if len(obs1_vals) != len(obs0_vals):
            return False
        return all(
            (e0 == e1).all() for e0, e1 in zip(obs0_vals, obs1_vals, strict=True)
        )

    @pa.dataframe_check(
        name="check_obs1_same_value",
        error="Obs0 and obs1 columns should have same value",
    )
    @classmethod
    def _check_obs1_same_value(cls, df: pd.DataFrame) -> bool:
        """The obs0/obs1 shift relation must hold inside every episode."""
        return cls._all_episodes(df, cls._check_obs1_same_value_in_episode)

    # ----- TERM1 -----
    @pa.dataframe_check(
        name="has_at_most_term1",
        error="Columns should contain at most one 'term1'",
    )
    @classmethod
    def _has_at_most_terms(cls, df: pd.DataFrame) -> bool:
        """At most one column may carry the "term1" signal."""
        return list(df.columns.get_level_values("signal")).count("term1") <= 1

    @pa.dataframe_check(
        name="check_term1_dtype",
        error="Term1 should be of boolean dtype",
    )
    @classmethod
    def _check_term1_dtype(cls, df: pd.DataFrame) -> bool:
        """Any single "term1" column must have a boolean dtype."""
        term1_cols = [col for col in df.columns if col[0] == "term1"]
        if len(term1_cols) > 1:
            return False
        return all(pd.api.types.is_bool_dtype(df[col].dtype) for col in term1_cols)

    @pa.dataframe_check(
        name="check_term1_last",
        error="only last term1 can be True",
    )
    @classmethod
    def _check_term1_last(cls, df: pd.DataFrame) -> bool:
        """Within each episode, only the final row may have term1 == True."""
        term1_cols = [col for col in df.columns if col[0] == "term1"]
        if not term1_cols:
            # term1 is optional; absent means every episode is truncated
            return True
        col = term1_cols[0]
        # .iloc makes the "all rows but the last" slice explicitly positional
        return cls._all_episodes(df, lambda sdf: bool((~sdf[col].iloc[:-1]).all()))

    # ----- RTG0 -----
    @pa.dataframe_check(
        name="check_at_most_one_rtg",
        error="Columns should contain at most one 'rtg0'",
    )
    @classmethod
    def _check_at_most_one_rtg(cls, df: pd.DataFrame) -> bool:
        """At most one column may carry the "rtg0" signal.

        The previous error message was a copy-paste of the rew1 dtype one
        ("Rew1 should be of a floating dtype"); it now describes this check.
        """
        return list(df.columns.get_level_values("signal")).count("rtg0") <= 1

    @pa.dataframe_check(
        name="check_rtg_dtype",
        error="Rtg0 should be of a floating dtype",
    )
    @classmethod
    def _check_rtg_dtype(cls, df: pd.DataFrame) -> bool:
        """Any single "rtg0" column must have a floating dtype."""
        rtg_cols = [col for col in df.columns if col[0] == "rtg0"]
        if len(rtg_cols) > 1:
            return False
        return all(pd.api.types.is_float_dtype(df[col].dtype) for col in rtg_cols)
class RTGSchema(OARSchema):
    """An :class:`OARSchema` with a non-optional "rtg0" signal."""

    @pa.dataframe_check(
        name="has_one_rtg",
        error="Columns should contain exactly one 'rtg0'",
    )
    @classmethod
    def _has_one_rtg(cls, df: pd.DataFrame) -> bool:
        """Exactly one column must carry the "rtg0" signal."""
        signal_level = df.columns.get_level_values("signal")
        return sum(1 for signal in signal_level if signal == "rtg0") == 1
class NeoVisionSchema(OARSchema):
    """An :class:`OARSchema` with additional conditions required for Neovision."""

    @pa.dataframe_check(
        name="has_one_obs0",
        error="Columns should contain exactly one 'obs0'",
    )
    @classmethod
    def _has_one_obs0(cls, df: pd.DataFrame) -> bool:
        """Exactly one column must carry the "obs0" signal."""
        signal_level = df.columns.get_level_values("signal")
        return sum(1 for signal in signal_level if signal == "obs0") == 1

    @pa.dataframe_check(
        name="check_metadata_description",
        error="All columns must have a metadata['description'] attribute",
    )
    @classmethod
    def _check_metadata_description(cls, df: pd.DataFrame) -> bool:
        """Every declared field must carry a truthy metadata "description".

        NOTE(review): ``cls.__fields__`` values are unpacked as 2-tuples whose
        second element exposes ``.metadata`` — this mirrors pandera internals;
        confirm against the pandera version in use.  With no declared fields
        the loop is empty and the check passes vacuously.
        """
        declared_fields = getattr(cls, "__fields__", {})
        for _, field_info in declared_fields.values():
            field_metadata = field_info.metadata or {}
            if not field_metadata.get("description"):
                return False
        return True

    @pa.dataframe_check(name="check_business_metadata")
    @classmethod
    def _check_business_metadata(cls, _: pd.DataFrame) -> bool:
        """Config metadata must be non-empty and use only the allowed keys."""
        allowed_keys = {
            "pain_points", "business_kpis", "current_baseline",
            "hard_constraints", "soft_constraints", "success_criteria",
            "business_objective", "process_to_optimize", "baseline_for_comparison",
        }
        present_keys = set(getattr(cls.Config, "metadata", {}).keys())
        # at least one key, and nothing outside the allowed vocabulary
        return bool(present_keys) and present_keys <= allowed_keys

    @pa.dataframe_check(
        name="check_schema_description",
        error="The schema Config must have a metadata['description'] attribute",
    )
    @classmethod
    def _check_schema_description(cls, _df: pd.DataFrame) -> bool:
        """The schema-level Config metadata must define a truthy "description"."""
        schema_metadata = getattr(cls.Config, "metadata", {})
        return bool(schema_metadata.get("description"))