Code
February 21, 2025
Pandera schemas are a useful tool to make sure input data is as expected. If you’ve ever used dbt before, they’re just like schema tests or great-expectations.
Data validation is extremely important for just about everyone in the data science space. For instance, your model may not be expecting null values, or maybe it needs data to be iid. This is where Pandera can come in and excel, especially when you don’t have control over the input data.
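As a quick illustration, a minimal sketch of catching a null with Pandera before it reaches a model (the `metric` column here is made up, not from the project):

import pandas as pd
import pandera as pa
from pandera.typing import Series

class MetricSchema(pa.DataFrameModel):
    # Hypothetical schema: a single non-nullable float column
    metric: Series[float] = pa.Field(nullable=False)

df = pd.DataFrame({"metric": [1.0, 2.0, None]})
MetricSchema.validate(df)  # raises a SchemaError because of the null value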
I recently had a project where I wasn’t responsible for, and had no control over, the input data. In particular, many different datasets would be sent to different models for inference, and I couldn’t always rely on the data being properly structured.
I was also working with a range of different models, including time-varying survival and time series models. I wanted to reuse as much code as possible across models, which is where Pandera comes in.
Let’s start by looking at some panel data:
import pandas as pd

panel = pd.DataFrame({
    "entity_id": [1, 1, 2],
    "date": ["2023-01-01", "2023-02-01", "2023-01-01"],
})
panel.assign(metric=[0, 1, 1])
| | entity_id | date | metric |
|---|---|---|---|
| 0 | 1 | 2023-01-01 | 0 |
| 1 | 1 | 2023-02-01 | 1 |
| 2 | 2 | 2023-01-01 | 1 |
It turns out time series data is typically panel data (although not always). Time-varying survival data is also a special case of panel data, where each entity (for instance a customer’s subscription ID) starts at some point and there are consistently spaced observations.
Class inheritance is a perfect pattern to reflect this. We can do something like the following to reuse schema checks and save a lot of lines of code.
class PanelSchema:
    ...

class TimeSeriesSchema(PanelSchema):
    ...

class TimevaryingSurvivalSchema(PanelSchema):
    ...
For example, we can make a panel schema with some obvious checks:
import pandera as pa
from pandera.typing import DataFrame, DateTime, Series

class BasePanelSchema(pa.DataFrameModel):
    @classmethod
    def infer_frequency(cls, df: pd.DataFrame) -> str:
        """Identifies the frequency of dates in the panel dataset"""
        ...

    @classmethod
    def validate(cls, df: pd.DataFrame, *args, **kwargs):
        # Run and save dataset frequency before other validations
        cls.Config.metadata["freq"] = cls.infer_frequency(df)
        return super().validate(df, *args, **kwargs)
class PanelSchema(BasePanelSchema):
    """Panel Data schema. This assumes that Panel Data has 1 entry per entity per date, and that all
    dates between the min and the max date for an entity should exist as records, evenly spaced.
    """

    entity_id: Series[str] = pa.Field(coerce=True, nullable=False)
    date: Series[DateTime] = pa.Field(coerce=True, nullable=False)

    class Config:
        unique = ["entity_id", "date"]
        strict = False
        metadata: dict = {}

    @pa.dataframe_check
    def validate_frequency(cls, df: pd.DataFrame) -> bool:
        """Ensures that every entity has the same frequency between dates of consecutive records"""
        # Use the inferred frequency for some validation
        ...

    @pa.dataframe_check
    def check_complete_date_index_per_entity(cls, df: DataFrame) -> Series[bool]:
        """Ensures that all dates (at the applicable frequency) between the min and the max date for
        an entity should exist as records.
        """
        ...
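Both `infer_frequency` and the dataframe checks are left as pseudocode above. As a rough sketch of how they might be filled in (my own assumptions: the frequency comes from `pd.infer_freq` on a single entity’s sorted dates, and the standalone completeness check below hard-codes that frequency rather than reading it back from `Config.metadata`):

import pandas as pd
import pandera as pa
from pandera.typing import DateTime, Series

def infer_frequency(df: pd.DataFrame) -> str:
    """One way to identify the date frequency: inspect a single entity's sorted dates."""
    first_entity = df.loc[df["entity_id"] == df["entity_id"].iloc[0]]
    dates = pd.DatetimeIndex(first_entity["date"]).sort_values()
    freq = pd.infer_freq(dates)  # e.g. "MS" for month-start data; needs at least 3 dates
    if freq is None:
        raise ValueError("Could not infer a regular date frequency")
    return freq

class CompletePanelSketch(pa.DataFrameModel):
    """Standalone sketch of the completeness check; the class name is illustrative."""

    entity_id: Series[str] = pa.Field(coerce=True, nullable=False)
    date: Series[DateTime] = pa.Field(coerce=True, nullable=False)

    @pa.dataframe_check
    def check_complete_date_index_per_entity(cls, df: pd.DataFrame) -> Series[bool]:
        # Assume monthly data here; the real schema would use the inferred frequency
        freq = "MS"
        complete = df.groupby("entity_id")["date"].transform(
            lambda d: len(d) == len(pd.date_range(d.min(), d.max(), freq=freq))
        )
        return complete.astype(bool)

validate_frequency could work along the same lines: compute the gap between consecutive dates within each entity and check that it matches the inferred frequency.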
As a starting point, this makes it really easy to extend to time series models.
| | entity_id | date | y |
|---|---|---|---|
| 0 | 1 | 2023-01-01 | 100 |
| 1 | 1 | 2023-02-01 | 105 |
I can take my PanelSchema and inherit from it for different data models. A simple example: let’s say I need a non-null outcome:
class UnivariateTimeSeriesSchema(PanelSchema):
    y: Series[float] = pa.Field(coerce=True, nullable=False)
and just like that, I have all of those previous data checks extended to my time series model.
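Assuming the elided checks above are filled in, usage is just a single call (the data here is made up to match the table above):

import pandas as pd

ts = pd.DataFrame({
    "entity_id": ["1", "1"],
    "date": ["2023-01-01", "2023-02-01"],
    "y": [100, 105],
})

# Returns the coerced dataframe, or raises a SchemaError if e.g. `y` has
# nulls or an entity's dates have gaps at the inferred frequency
validated = UnivariateTimeSeriesSchema.validate(ts)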
There are some downsides though. First, you can’t parameterize field names; for instance, your outcome needs to be named `y`. I’d also need an `entity_id` column for this to work because the PanelSchema requires it - but that’s simple enough: you could just add an `entity_id` column containing, say, the product you’re forecasting.
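For instance (an illustrative rename, not from the original project), you could map your own column names onto the ones the schema expects before validating:

import pandas as pd

# Hypothetical raw data with differently named columns
raw = pd.DataFrame({
    "product_id": ["widget", "widget"],
    "date": ["2023-01-01", "2023-02-01"],
    "sales": [100, 105],
})

# Rename to the columns UnivariateTimeSeriesSchema expects before validating
renamed = raw.rename(columns={"product_id": "entity_id", "sales": "y"})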
A time-varying survival schema can be defined as:
class TimeVaryingSurvivalSchema(PanelSchema):
    """A schema for Time Varying Survival Analysis, should be a unique and complete panel dataset
    on the `entity_id` and `date` columns.
    """

    tenure: Series[int] = pa.Field(ge=0, coerce=True, nullable=False)
    event: Series[int] = pa.Field(isin=[0, 1], coerce=True, nullable=False)
    exposure: Series[float] = pa.Field(le=1, ge=0, coerce=True, nullable=False)

    @pa.dataframe_check
    def check_tenure_is_consecutive(cls, df: DataFrame) -> Series[bool]:
        ...

    @pa.dataframe_check
    def check_max_one_event_per_entity(cls, df: DataFrame) -> Series[bool]:
        ...

    @pa.dataframe_check
    def check_event_is_last_obs_per_entity(cls, df: DataFrame) -> Series[bool]:
        ...
In this case, it inherits all of the previous schema checks from PanelSchema, and we extend it further with some survival-analysis-specific checks - for instance, for each `entity_id` we expect a `tenure` column indicating each consecutive time period. A common example of tenure is the number of months since a customer started their subscription.
| | entity_id | date | tenure | exposure | event |
|---|---|---|---|---|---|
| 0 | 1 | 2023-01-01 | 0 | 1.00 | 0 |
| 1 | 1 | 2023-02-01 | 1 | 0.50 | 1 |
| 2 | 2 | 2023-01-01 | 0 | 0.25 | 1 |
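The survival check bodies are elided above too. As a rough sketch of how a couple of them might be implemented (my own assumption, written here as plain functions that would sit under `@pa.dataframe_check` in the schema):

import pandas as pd

def tenure_is_consecutive(df: pd.DataFrame) -> pd.Series:
    """Within each entity (ordered by date), tenure should increase by exactly 1."""
    ordered = df.sort_values(["entity_id", "date"])
    gaps = ordered.groupby("entity_id")["tenure"].diff().fillna(1)
    return gaps.eq(1).reindex(df.index)

def max_one_event_per_entity(df: pd.DataFrame) -> pd.Series:
    """Each entity should have at most one record with event == 1."""
    counts = df.groupby("entity_id")["event"].transform("sum")
    return counts.le(1)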
What’s great about this is that it’s all easy to read and mirrors how we think about it conceptually. Time series and time-varying survival are just sub-categories of panel data, and our code mimics that perfectly.
Pandera has some nice use cases, particularly when you’re not sure what data is going to get thrown at you. The readability is fantastic, and it can save a ton of code.
That said, there were some caveats I’ve run into (some of which I mentioned above):

- Field names are fixed, so for example you couldn’t call your event column `churn` instead of `event`.

Overall, it’s definitely a useful tool to keep in your back pocket, but I have two pieces of advice:
Note: I may update this later to have actual working code to demo, but for now it’s pseudocode.