# Source code for sara.oar.schemas

"""Observation-Action-Reward (OAR) format definition.

The OAR format organizes a dataset that comes from (or is interpreted as
coming from) a reinforcement learning environment.

This module contains `pandera <https://pandera.readthedocs.io/en/stable>`__
schemas to check and validate the structure of :class:`pd.DataFrame` in use in
:mod:`sara`.

Todo:
    - impose in the OAR format that data must be upsampled to a uniform
      frequency (is this valid with filtering?)
    - ensure metadata has description ?

"""

from typing import Final

import pandas as pd
from pandas.api.types import is_timedelta64_ns_dtype
import pandera.pandas as pa

COL_NAMES: Final[list[str]] = ["signal", "key"]
"""A list containing the names of the levels of the OAR multi-indexed columns.

- signal: the reinforcement learning signal e.g. "act"
- key: the action in the dataset e.g. "order"
"""

SIGNALS: Final[set[str]] = {"obs0", "act0", "rew1", "obs1", "term1", "rtg0"}
r"""The set of possible signals.

- obs0: the observation at date `d`
- act0: the action taken at date `d` given obs0
- rew1: the reward received after act0 have been submitted to environment
- obs1: the next observation after act0 have been submitted to environment
- term1: the termination signal after act0 have been submitted to environment
- rtg0: the return to go from date d
"""


class OARSchema(pa.DataFrameModel):
    """The base OAR format for pandas DataFrame.

    Raises:
        SchemaError: if the input dataframe does not validate the schema

    Examples:
        Initialize minimal OAR dataframe:

        >>> import pandas as pd
        >>> from sara.oar import OARSchema
        >>> df = pd.DataFrame({
        ...     ("act0", "choice"): [1, 2, 3],
        ...     ("rew1", "money"): [1., 2., 3.]})
        >>> df.columns.names = ("signal", "key")
        >>> df = df.set_index(
        ...     pd.MultiIndex.from_product(
        ...         [[0], pd.date_range("2000-01-01", "2000-01-03")],
        ...         names=["episode", "date"]))

        validate it against OARSchema:

        >>> OARSchema.validate(df)  # doctest: +NORMALIZE_WHITESPACE
        signal              act0  rew1
        key               choice money
        episode date
        0       2000-01-01      1   1.0
                2000-01-02      2   2.0
                2000-01-03      3   3.0

    Note:
        - episodes ending with :code:`term1 == False` are considered truncated
        - if no :code:`term1` column is present, episodes are considered
          truncated
        - date should be an integer or a datetime

    Todo:
        - maybe it is more general to include truncation signal, it allows to
          slice episode how we want and eventually batch them

    """

    # ----- EPISODE -----
    @pa.dataframe_check(
        name="more_than_two_indexes",
        error="num indexes > 1",
    )
    @classmethod
    def _more_than_two_indexes(cls, df: pd.DataFrame) -> bool:
        """Check the index is a MultiIndex with at least two levels.

        NOTE(review): despite the registered name, the condition is strictly
        more than *one* level (e.g. episode + date); the name is kept so
        existing error reports stay stable.
        """
        return isinstance(df.index, pd.MultiIndex) and len(df.index.names) > 1

    @pa.dataframe_check(
        name="all_index_levels_named",
        error="All index levels must have a non-empty name",
    )
    @classmethod
    def _all_index_levels_named(cls, df: pd.DataFrame) -> bool:
        """Check every index level has a non-None, non-whitespace name."""
        return all(
            name is not None and str(name).strip() != ""
            for name in df.index.names
        )

    # ----- DATE -----
    @pa.dataframe_check(
        name="has_date",
        error="Indexes should contain exactly one 'date'",
    )
    @classmethod
    def _has_date(cls, df: pd.DataFrame) -> bool:
        """Check exactly one index level is named 'date'."""
        return list(df.index.names).count("date") == 1

    @pa.dataframe_check(
        name="check_date_increasing",
        error="Date should increase",
    )
    @classmethod
    def _check_date_increasing(cls, df: pd.DataFrame) -> bool:
        """Check dates are monotonically increasing within each episode."""
        idx_names_not_date = [idx for idx in df.index.names if idx != "date"]
        grp = df.groupby(level=idx_names_not_date)
        flag = grp.apply(
            lambda sdf: sdf.index.get_level_values("date").is_monotonic_increasing,
        )
        return bool(flag.all())

    @pa.dataframe_check(
        name="check_date_dtype",
        error="Date should have datetime64, timedelta64 or integer dtype",
    )
    @classmethod
    def _check_date_dtype(cls, df: pd.DataFrame) -> bool:
        """Check the 'date' level has a supported dtype.

        The ``*_ns_dtype`` predicates are subsets of the generic 64-bit ones,
        so only the generic checks are needed.  Timedelta has always been
        accepted by this check; the error message now says so.
        """
        dtype = df.index.get_level_values("date").dtype
        return (
            pd.api.types.is_datetime64_dtype(dtype)
            or pd.api.types.is_timedelta64_dtype(dtype)
            or pd.api.types.is_integer_dtype(dtype)
        )

    # ----- SIGNAL and NAMES -----
    @pa.dataframe_check(
        name="check_column_names",
        error=f"Columns should be a MultiIndex with names {COL_NAMES}.",
    )
    @classmethod
    def _check_column_names(cls, df: pd.DataFrame) -> bool:
        """Check columns form a MultiIndex with the OAR level names."""
        return df.columns.names == COL_NAMES

    @pa.dataframe_check(name="check_signal", error=f"Columns must be in {SIGNALS}")
    @classmethod
    def _check_signal_columns(cls, df: pd.DataFrame) -> bool:
        """Check every 'signal' label belongs to the known signal set."""
        return all(s in SIGNALS for s in df.columns.get_level_values("signal"))

    @pa.dataframe_check(
        name="check_keys_non_empty",
        error="All columns must have a non-empty string 'key'",
    )
    @classmethod
    def _check_keys_non_empty(cls, df: pd.DataFrame) -> bool:
        """Check every 'key' label is a non-empty, non-whitespace string.

        A ``None`` key now fails the check (it previously slipped through a
        ``if k is not None`` filter, contradicting the error message).
        """
        keys = df.columns.get_level_values("key")
        return all(isinstance(k, str) and len(k.strip()) > 0 for k in keys)

    # ----- OBS0 -----
    # nothing

    # ----- ACT0 -----
    @pa.dataframe_check(
        name="has_at_least_one_act",
        error="Columns should contain at least one 'act0'",
    )
    @classmethod
    def _has_at_least_one_act(cls, df: pd.DataFrame) -> bool:
        """Check at least one action column is present."""
        signals = list(df.columns.get_level_values("signal"))
        return signals.count("act0") > 0

    # ----- REW1 -----
    @pa.dataframe_check(
        name="has_one_rew",
        error="Columns should contain exactly one 'rew1'",
    )
    @classmethod
    def _has_one_rew(cls, df: pd.DataFrame) -> bool:
        """Check exactly one reward column is present."""
        signals = list(df.columns.get_level_values("signal"))
        return signals.count("rew1") == 1

    @pa.dataframe_check(
        name="check_rew_dtype",
        error="Rew1 should be of a floating dtype",
    )
    @classmethod
    def _check_rew_dtype(cls, df: pd.DataFrame) -> bool:
        """Check the reward column(s) have a floating dtype.

        A missing 'rew1' column is tolerated here so it is reported by
        ``has_one_rew`` instead of raising ``StopIteration`` (as the previous
        bare ``next(...)`` did).
        """
        rew_cols = [col for col in df.columns if col[0] == "rew1"]
        if not rew_cols:
            return True
        return all(pd.api.types.is_float_dtype(df[col].dtype) for col in rew_cols)

    # ----- OBS1 -----
    @pa.dataframe_check(
        name="check_obs1_same_name",
        error="Obs0 and obs1 columns should have same name",
    )
    @classmethod
    def _check_obs1_same_name(cls, df: pd.DataFrame) -> bool:
        """Check obs1 keys mirror obs0 keys (same names, same order)."""
        # NOTE: same value is not checked because
        # it assumes same timedelta between rows
        obs0_keys = [col[1] for col in df.columns if col[0] == "obs0"]
        obs1_keys = [col[1] for col in df.columns if col[0] == "obs1"]
        if len(obs1_keys) == 0:
            # there is no obs1 which is possible
            return True
        if len(obs1_keys) != len(obs0_keys):
            return False
        return all(e0 == e1 for e0, e1 in zip(obs0_keys, obs1_keys, strict=True))

    @pa.dataframe_check(
        name="check_obs1_same_dtype",
        error="Obs0 and obs1 columns should have same dtype",
    )
    @classmethod
    def _check_obs1_same_dtype(cls, df: pd.DataFrame) -> bool:
        """Check obs1 dtypes mirror obs0 dtypes pairwise."""
        obs0_dtypes = [df[col].dtype for col in df.columns if col[0] == "obs0"]
        obs1_dtypes = [df[col].dtype for col in df.columns if col[0] == "obs1"]
        if len(obs1_dtypes) == 0:
            # can omit obs1
            return True
        if len(obs1_dtypes) != len(obs0_dtypes):
            return False
        return all(
            e0 == e1 for e0, e1 in zip(obs0_dtypes, obs1_dtypes, strict=True)
        )

    @staticmethod
    def _check_obs1_same_value_in_episode(df: pd.DataFrame) -> bool:
        """Check, within one episode, obs1 at row ``t`` equals obs0 at ``t+1``."""
        obs0_vals = [df[col].to_numpy()[1:] for col in df.columns if col[0] == "obs0"]
        obs1_vals = [df[col].to_numpy()[:-1] for col in df.columns if col[0] == "obs1"]
        if len(obs1_vals) == 0:
            # can omit obs1
            return True
        if len(obs1_vals) != len(obs0_vals):
            return False
        return all(
            (e0 == e1).all() for e0, e1 in zip(obs0_vals, obs1_vals, strict=True)
        )

    @pa.dataframe_check(
        name="check_obs1_same_value",
        error="Obs0 and obs1 columns should have same value",
    )
    @classmethod
    def _check_obs1_same_value(cls, df: pd.DataFrame) -> bool:
        """Check the shifted obs0/obs1 equality episode by episode."""
        idx_names_not_date = [idx for idx in df.index.names if idx != "date"]
        grp = df.groupby(level=idx_names_not_date)
        # pass the staticmethod directly: the lambda wrapper was redundant
        flag = grp.apply(cls._check_obs1_same_value_in_episode)
        return bool(flag.all())

    # ----- TERM1 -----
    @pa.dataframe_check(
        name="has_at_most_term1",
        error="Columns should contain at most one 'term1'",
    )
    @classmethod
    def _has_at_most_terms(cls, df: pd.DataFrame) -> bool:
        """Check at most one termination column is present."""
        signals = list(df.columns.get_level_values("signal"))
        return signals.count("term1") <= 1

    @pa.dataframe_check(
        name="check_term1_dtype",
        error="Term1 should be of boolean dtype",
    )
    @classmethod
    def _check_term1_dtype(cls, df: pd.DataFrame) -> bool:
        """Check the (optional, unique) term1 column is boolean."""
        term1_cols = [k for k in df.columns if k[0] == "term1"]
        if len(term1_cols) == 0:
            return True
        if len(term1_cols) == 1:
            return pd.api.types.is_bool_dtype(df[term1_cols[0]].dtype)
        return False

    @pa.dataframe_check(
        name="check_term1_last",
        error="only last term1 can be True",
    )
    @classmethod
    def _check_term1_last(cls, df: pd.DataFrame) -> bool:
        """Check term1 is False everywhere except possibly the last episode row."""
        term1_cols = [k for k in df.columns if k[0] == "term1"]
        if len(term1_cols) == 0:
            return True
        idx_names_not_date = [idx for idx in df.index.names if idx != "date"]
        grp = df.groupby(level=idx_names_not_date)
        flag = grp.apply(lambda sdf: (~sdf[term1_cols[0]][:-1]).all())
        return bool(flag.all())

    # ----- RTG0 -----
    @pa.dataframe_check(
        name="check_at_most_one_rtg",
        error="Columns should contain at most one 'rtg0'",
    )
    @classmethod
    def _check_at_most_one_rtg(cls, df: pd.DataFrame) -> bool:
        """Check at most one return-to-go column is present.

        The error message used to be a copy-paste of the rew1 dtype message;
        it now describes this check.
        """
        signals = list(df.columns.get_level_values("signal"))
        return signals.count("rtg0") <= 1

    @pa.dataframe_check(
        name="check_rtg_dtype",
        error="Rtg0 should be of a floating dtype",
    )
    @classmethod
    def _check_rtg_dtype(cls, df: pd.DataFrame) -> bool:
        """Check the (optional, unique) rtg0 column has a floating dtype."""
        rtg0_cols = [k for k in df.columns if k[0] == "rtg0"]
        if len(rtg0_cols) == 0:
            return True
        if len(rtg0_cols) == 1:
            return pd.api.types.is_float_dtype(df[rtg0_cols[0]].dtype)
        return False
class RTGSchema(OARSchema):
    """An :class:`OARSchema` with a non-optional "rtg0" signal."""

    @pa.dataframe_check(
        name="has_one_rtg",
        error="Columns should contain exactly one 'rtg0'",
    )
    @classmethod
    def _has_one_rtg(cls, df: pd.DataFrame) -> bool:
        """Check exactly one 'rtg0' column is present (base schema only caps it)."""
        signal_level = df.columns.get_level_values("signal")
        return sum(1 for signal in signal_level if signal == "rtg0") == 1
class NeoVisionSchema(OARSchema):
    """An :class:`OARSchema` with additional conditions required for Neovision."""

    @pa.dataframe_check(
        name="has_one_obs0",
        error="Columns should contain exactly one 'obs0'",
    )
    @classmethod
    def _has_one_obs0(cls, df: pd.DataFrame) -> bool:
        """Check exactly one observation column is present."""
        signal_level = df.columns.get_level_values("signal")
        return sum(1 for signal in signal_level if signal == "obs0") == 1

    @pa.dataframe_check(
        name="check_metadata_description",
        error="All columns must have a metadata['description'] attribute",
    )
    @classmethod
    def _check_metadata_description(cls, df: pd.DataFrame) -> bool:
        """Check every declared schema field carries a metadata description.

        NOTE(review): this inspects the schema's ``__fields__`` (assumed to map
        to ``(..., field_info)`` pairs — confirm against pandera internals),
        not the dataframe columns themselves.
        """
        declared_fields = getattr(cls, "__fields__", {})
        for _, field_info in declared_fields.values():
            field_metadata = field_info.metadata or {}
            if not field_metadata.get("description"):
                return False
        return True

    @pa.dataframe_check(name="check_business_metadata")
    @classmethod
    def _check_business_metadata(cls, _: pd.DataFrame) -> bool:
        """Check Config metadata is non-empty and uses only known business keys."""
        allowed_keys = {
            "pain_points",
            "business_kpis",
            "current_baseline",
            "hard_constraints",
            "soft_constraints",
            "success_criteria",
            "business_objective",
            "process_to_optimize",
            "baseline_for_comparison",
        }
        declared_keys = set(getattr(cls.Config, "metadata", {}))
        # at least one key, and nothing outside the allowed business set
        return bool(declared_keys) and declared_keys <= allowed_keys

    @pa.dataframe_check(
        name="check_schema_description",
        error="The schema Config must have a metadata['description'] attribute",
    )
    @classmethod
    def _check_schema_description(cls, _df: pd.DataFrame) -> bool:
        """Check the schema-level Config metadata has a truthy description."""
        schema_metadata = getattr(cls.Config, "metadata", {})
        return bool(schema_metadata.get("description"))