# Source code for examples.provencia.env

"""Environment formatting for Provencia.

Applies the environment definition to the DataFrame produced by the ETL.
"""

import logging

import pandas as pd
import pandera.pandas as pa
from pandera.typing.pandas import DataFrame

from sara.oar import OARSchema

try:
    from .schemas import ProvenciaEnvSchema
except ImportError:
    from schemas import ProvenciaEnvSchema


# Module-level logger used by the pipeline helpers below.
# NOTE(review): `basicConfig` configures the root logger as a side effect of
# importing this module; consider moving it to the application entry point.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class EnvError(Exception):
    """Error for env formatting.

    This allows distinction between controlled errors (raised deliberately by
    the env pipeline) and uncontrolled errors.
    """

    def __init__(self, message: str) -> None:
        """Initialize the error with a message."""
        super().__init__(message)

    @staticmethod
    def check_df(
        message: str,
        df: pd.DataFrame,
        wrong_sr: pd.Series,
    ) -> None:
        """Check if any row is wrong and raise error.

        The error message is `message` on its first line, followed by the
        problematic rows of `df`.

        Args:
            message: first line of the error message
            df: the DataFrame
            wrong_sr: boolean Series, True where df is problematic

        Raises:
            EnvError: if any value of `wrong_sr` is True.
        """
        if wrong_sr.to_numpy().any():
            # header line from the caller, then the offending rows
            msg = f"{message}\n{df[wrong_sr]}"
            raise EnvError(msg)


def enrich_profit(df: pd.DataFrame, lag: int) -> pd.DataFrame:
    """Add profit column to DataFrame.

    Basically, profit is gross - total_paid. But since the flow is:
        pay order1 -> sell product0 -> receive order1 -> pay order2 -> sell product1
    there is a lag of one day between ordering and actually selling the products
    in that order. This lag can make credit assignment harder so a `lag` parameter
    is introduced.

    Args:
       df: dataframe with `total_paid` and `gross` columns, indexed by date
           (the ``freq="D"`` shift below needs a datetime-like index; rows are
           assumed to be consecutive days -- TODO confirm upstream guarantees it)
       lag: shift the gross data `lag` day before, therefore `lag=1` compares
           the order with the first gross that contains this order.

    Returns:
       pd.DataFrame: a new dataframe made of the first ``len(df) - lag`` rows
       of the input with an extra `profit` column; the input is not modified.

    Raises:
       EnvError: if `lag` is negative or if `df` has fewer than `lag` rows.
    """
    n = len(df)
    if lag < 0:
        # a negative lag would silently produce an all-NaN profit column
        msg = f"lag must be non-negative: {lag}"
        raise EnvError(msg)
    if n < lag:
        msg = f"not enough rows for rew computation: {n}"
        raise EnvError(msg)
    # profit[d] = gross[d + lag days] - total_paid[d]: the gross series is
    # relabelled `lag` days earlier so that the sum aligns on the date index
    profit = (
        -df["total_paid"].iloc[: n - lag]
        + df["gross"].shift(-lag, freq="D").iloc[lag:]
    )
    # `assign` builds a new frame instead of setting a column on a slice
    return df.iloc[: n - lag].assign(profit=profit)


def align(df: pd.DataFrame) -> pd.DataFrame:
    """Time alignment of data.

    ``sold``, ``gross`` and ``stock_evening`` are shifted one day forward and
    renamed ``sold_yd``, ``gross_yd`` and ``stock_morning``, so that they can
    be set before the day's action. They are inner-joined on the index with
    the remaining columns, which drops the first date (it has no "yesterday").
    ``stock_morning + delivered`` becomes ``stock_after_delivery``.

    Frames with fewer than 2 rows are only logged (no error is raised); the
    inner join then leaves them empty, which removes them from the output.

    Args:
        df: dataframe indexed by date (the ``freq="D"`` shift needs a
            datetime-like index) with at least the columns ``total_paid``,
            ``sold``, ``gross``, ``stock_evening`` and ``delivered``.

    Returns:
        pd.DataFrame: the aligned dataframe, without ``total_paid``, ``sold``,
        ``gross`` and ``stock_evening``.
    """
    length_df, min_length_df = len(df), 2
    if length_df < min_length_df:
        # NOTE: raising EnvError here is deliberately disabled (experiment);
        # such a frame comes out empty from the inner merge below instead.
        logger.info("not enough rows for time alignement: %s", length_df)
        logger.info("removing the subdf instead of throwing error")
    # shift some data to yesterday so that they can be set before action
    df_yd = (
        df[["sold", "gross", "stock_evening"]]
        .shift(1, freq="D")
        .rename(
            columns={
                "sold": "sold_yd",
                "gross": "gross_yd",
                "stock_evening": "stock_morning",
            },
        )
    )
    # drop the columns that are now unused or replaced by their shifted version
    df = df.drop(columns=["total_paid", "sold", "gross", "stock_evening"])
    # keep only dates present in both the base and the shifted data
    df = df.merge(df_yd, left_index=True, right_index=True, how="inner")
    # include delivery in stock_morning to make stock after delivery
    # s.t. stock 0 means nothing to sell
    df["stock_after_delivery"] = df["stock_morning"] + df["delivered"]
    return df.drop(columns="stock_morning")


def grp_pipeline(
    df: pd.DataFrame,
    lag_profit: int,
) -> pd.DataFrame:
    """Part of the env pipeline that acts on an episodic sub-dataframe.

    The ``store_id`` and ``product_code`` index levels are dropped (presumably
    leaving only the ``date`` level -- confirm against the ETL index), then
    the profit column is added and the data is time-aligned.

    Args:
        df: the etl dataframe grouped by levels that are not `date`
        lag_profit: lag to use in :func:`enrich_profit`

    Returns:
        pd.DataFrame: the sub-dataframe returned by :func:`align`.
    """
    df = df.droplevel(["store_id", "product_code"])
    df = enrich_profit(df, lag_profit)
    return align(df)


@pa.check_types
def env_pipeline(
    df_etl: pd.DataFrame,
    lag_profit: int = 0,
) -> DataFrame[OARSchema]:
    """Transform Provencia ETL dataframe into Env dataframe with environment definition.

    Each sub-dataframe sharing the same non-``date`` index levels is processed
    by :func:`grp_pipeline`. The columns are then reordered and labelled with
    their signal, following ``ProvenciaEnvSchema``.

    Args:
        df_etl: dataframe from :func:`etl_pipeline`
        lag_profit: number of days between 'total_paid' (at 'date') and 'gross'
            (at 'date'+lag_profit) in the profit calculation for the reward.

    Returns:
        DataFrame[OARSchema]: OAR dataframe formatted with environment

    Raises:
        EnvError: if a sub-dataframe has fewer rows than ``lag_profit`` or if
            NaN values are present after the pipeline.

    Examples:
        >>> from examples.provencia import etl_pipeline
        >>> df = etl_pipeline(  # doctest: +NORMALIZE_WHITESPACE
        ...     ['JE', 'KV', 'EV'],
        ...     [2870622000000, 2870557000000, 2870549000000])
        >>> env_pipeline(df)  # doctest: +NORMALIZE_WHITESPACE
        signal                                obs0                                                          act0    rew1
        key                                sold_yd gross_yd stock_after_delivery delivered purchase_price ordered  profit
        store_id product_code  date
        EV       2870549000000 2023-12-06        0        0                    5         0            647       0     0.0
                               2023-12-07        0        0                    5         0            647       0     0.0
                               2023-12-08        0        0                    5         0            647       0  1834.0
                               2023-12-09        4     1834                    1         0            647       0     0.0
                               2023-12-10        0        0                    1         0            647       0     0.0
        ...                                    ...      ...                  ...       ...            ...     ...     ...
                 2870622000000 2024-07-28        0        0                    0         0            560       0     0.0
                               2024-07-29        0        0                    0         0            560       0     0.0
                               2024-07-30        0        0                    0         0            560       0     0.0
                               2024-07-31        0        0                    0         0            560       0     0.0
                               2024-08-01        0        0                    0         0            560       1 -1680.0
        <BLANKLINE>
        [515 rows x 7 columns]
    """
    group_levels = [lvl for lvl in df_etl.index.names if lvl != "date"]
    # one episode per group; group keys are kept as outer index levels
    df = df_etl.groupby(level=group_levels, group_keys=True).apply(
        lambda grp: grp_pipeline(grp, lag_profit),
    )
    # reorder columns in causal order; schema column names are
    # (signal, key) tuples, the frame currently only has the `key` part
    columns = list(ProvenciaEnvSchema.to_schema().columns)
    df = df.reindex(columns=[col[1] for col in columns])
    # multiindex columns with signal
    df.columns = pd.MultiIndex.from_tuples(columns, names=["signal", "key"])
    EnvError.check_df(
        "Nan introduced in env pipeline",
        df,
        df.isna().any(axis=1),
    )
    return OARSchema.validate(df)