BaseModelInference

Bases: BaseConfig, ABC

Abstract base class for performing model inference and jackknife resampling.

This class defines methods for generating predictions, preparing data for inference, and implementing jackknife resampling with confidence intervals. It is designed to handle binary and multiclass classification tasks and allows encoding configurations for model compatibility.

Inherits
  • BaseConfig: Provides configuration settings for data processing.
  • ABC: Specifies abstract methods for subclasses to implement.

Parameters:
  • classification (str, required): The type of classification task, either 'binary' or 'multiclass', used to configure the inference process.
  • model (Any, required): A trained model instance that implements a predict_proba method for generating class probabilities.
  • verbose (bool, required): If True, enables detailed logging of inference steps.

Attributes:
  • classification (str): Stores the classification type ('binary' or 'multiclass') for model compatibility.
  • model: The trained model used to make predictions during inference.
  • verbose (bool): Indicates if verbose logging is enabled during inference.

Methods:
  • predict: Run predictions on a batch of input data, returning predicted classes and probabilities.
  • create_predict_data: Prepare and encode data for inference based on raw data and patient data, supporting one-hot or target encoding formats.
  • prepare_inference: Prepares data for inference, performing any necessary preprocessing and scaling.
  • patient_inference: Runs predictions on specific patient data, returning results with predicted classes and probabilities.
  • process_patient: Processes a patient’s data for jackknife resampling, retraining the model while excluding the patient from training.

Abstract Methods
  • jackknife_resampling: Performs jackknife resampling by retraining the model on various patient subsets.
  • jackknife_confidence_intervals: Computes confidence intervals based on jackknife resampling results.
  • plot_jackknife_intervals: Visualizes jackknife confidence intervals for predictions.
  • jackknife_inference: Executes full jackknife inference, including interval computation and optional plotting.
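Because the jackknife methods are abstract, the base class cannot be instantiated directly; a subclass has to override every one of them. The sketch below illustrates that mechanism with a stand-in base class reduced to two abstract methods. `InferenceBase`, `DummyInference`, and `can_instantiate` are hypothetical names used only for illustration and are not part of periomod.

```python
from abc import ABC, abstractmethod


class InferenceBase(ABC):
    """Hypothetical stand-in for BaseModelInference (not the periomod class)."""

    def __init__(self, classification: str, model, verbose: bool):
        self.classification = classification
        self.model = model
        self.verbose = verbose

    @abstractmethod
    def jackknife_resampling(self, train_rows, patient_rows):
        """Subclasses decide how the model is retrained per patient."""

    @abstractmethod
    def jackknife_confidence_intervals(self, jackknife_results, alpha):
        """Subclasses decide how intervals are derived."""


class DummyInference(InferenceBase):
    """Hypothetical concrete subclass that overrides every abstract method."""

    def jackknife_resampling(self, train_rows, patient_rows):
        return []

    def jackknife_confidence_intervals(self, jackknife_results, alpha):
        return {}


def can_instantiate(cls) -> bool:
    """Return True if cls can be constructed, False if it is still abstract."""
    try:
        cls(classification="binary", model=None, verbose=False)
    except TypeError:
        # ABCs raise TypeError when abstract methods are left unimplemented.
        return False
    return True
```

Calling `can_instantiate(InferenceBase)` fails at construction time, while the subclass that implements both methods can be created normally.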
Source code in periomod/inference/_baseinference.py
class BaseModelInference(BaseConfig, ABC):
    """Abstract base class for performing model inference and jackknife resampling.

    This class defines methods for generating predictions, preparing data for
    inference, and implementing jackknife resampling with confidence intervals.
    It is designed to handle binary and multiclass classification tasks and
    allows encoding configurations for model compatibility.

    Inherits:
        - `BaseConfig`: Provides configuration settings for data processing.
        - `ABC`: Specifies abstract methods for subclasses to implement.

    Args:
        classification (str): The type of classification task, either 'binary'
            or 'multiclass', used to configure the inference process.
        model: A trained model instance that implements a `predict_proba` method
            for generating class probabilities.
        verbose (bool): If True, enables detailed logging of inference steps.

    Attributes:
        classification (str): Stores the classification type ('binary' or 'multiclass')
            for model compatibility.
        model: The trained model used to make predictions during inference.
        verbose (bool): Indicates if verbose logging is enabled during inference.

    Methods:
        predict: Run predictions on a batch of input data, returning
            predicted classes and probabilities.
        create_predict_data: Prepare and encode data for inference based on raw data
            and patient data, supporting one-hot or target encoding formats.
        prepare_inference: Prepares data for inference, performing any
            necessary preprocessing and scaling.
        patient_inference: Runs predictions on specific patient data,
            returning results with predicted classes and probabilities.
        process_patient: Processes a patient’s data for jackknife resampling,
            retraining the model while excluding the patient from training.

    Abstract Methods:
        - `jackknife_resampling`: Performs jackknife resampling by retraining
          the model on various patient subsets.
        - `jackknife_confidence_intervals`: Computes confidence intervals
          based on jackknife resampling results.
        - `plot_jackknife_intervals`: Visualizes jackknife confidence intervals
          for predictions.
        - `jackknife_inference`: Executes full jackknife inference, including
          interval computation and optional plotting.
    """

    def __init__(self, classification: str, model: Any, verbose: bool):
        """Initialize BaseModelInference with a trained model."""
        super().__init__()
        self.classification = classification
        self.model = model
        self.verbose = verbose

    def predict(self, input_data: pd.DataFrame) -> pd.DataFrame:
        """Run prediction on a batch of input data.

        Args:
            input_data (pd.DataFrame): DataFrame containing feature values.

        Returns:
            probs_df: DataFrame with predictions and probabilities for each class.
        """
        probs = self.model.predict_proba(input_data)

        if (
            self.classification == "binary"
            and getattr(self.model, "best_threshold", None) is not None
        ):
            # Apply the tuned decision threshold to the positive-class probability.
            preds = (probs[:, 1] >= self.model.best_threshold).astype(int)
        else:
            # Fall back to the model's default decision rule.
            preds = self.model.predict(input_data)
        classes = [str(cls) for cls in self.model.classes_]
        probs_df = pd.DataFrame(probs, columns=classes, index=input_data.index)
        probs_df["prediction"] = preds
        return probs_df

    def create_predict_data(
        self,
        raw_data: pd.DataFrame,
        patient_data: pd.DataFrame,
        encoding: str,
    ) -> pd.DataFrame:
        """Creates prediction data for model inference.

        Args:
            raw_data (pd.DataFrame): The raw, preprocessed data.
            patient_data (pd.DataFrame): Original patient data before preprocessing.
            encoding (str): Type of encoding used ('one_hot' or 'target').

        Returns:
            predict_data: A DataFrame containing the prepared data for model prediction.
        """
        base_data = raw_data.copy()

        if encoding == "one_hot":
            drop_columns = self.cat_vars + self.infect_vars
            base_data = base_data.drop(columns=drop_columns, errors="ignore")
            encoded_data = pd.DataFrame(index=base_data.index)

            for tooth_num in range(11, 49):
                if tooth_num % 10 == 0 or tooth_num % 10 == 9:
                    continue
                encoded_data[f"tooth_{tooth_num}"] = 0

            for feature, max_val in self.cat_map.items():
                for i in range(0, max_val + 1):
                    encoded_data[f"{feature}_{i}"] = 0

            for idx, row in patient_data.iterrows():
                encoded_data.at[idx, f"tooth_{row['tooth']}"] = 1
                for feature in self.cat_map:
                    encoded_data.at[idx, f"{feature}_{row[feature]}"] = 1

            # Concatenate once, after all rows have been encoded.
            complete_data = pd.concat(
                [
                    base_data.reset_index(drop=True),
                    encoded_data.reset_index(drop=True),
                ],
                axis=1,
            )

            # Report duplicate columns before dropping them; checking after the
            # drop would never find any.
            duplicates = complete_data.columns[
                complete_data.columns.duplicated()
            ].unique()
            if len(duplicates) > 0:
                print("Duplicate columns found:", duplicates)
            complete_data = complete_data.loc[:, ~complete_data.columns.duplicated()]

        elif encoding == "target":
            complete_data = base_data.copy()
            for column in self.target_cols:
                if column in patient_data.columns:
                    complete_data[column] = patient_data[column].values
        else:
            raise ValueError(f"Unsupported encoding type: {encoding}")

        if hasattr(self.model, "get_booster"):
            model_features = self.model.get_booster().feature_names
        elif hasattr(self.model, "feature_names_in_"):
            model_features = self.model.feature_names_in_
        else:
            raise ValueError("Model type not supported for feature extraction")

        for feature in model_features:
            if feature not in complete_data.columns:
                complete_data[feature] = 0

        predict_data = complete_data[model_features]

        return predict_data

    def prepare_inference(
        self,
        task: str,
        patient_data: pd.DataFrame,
        encoding: str,
        X_train: pd.DataFrame,
        y_train: pd.Series,
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Prepares the data for inference.

        Args:
            task (str): The task name for which the model was trained.
            patient_data (pd.DataFrame): The patient's data as a DataFrame.
            encoding (str): Encoding type ("one_hot" or "target").
            X_train (pd.DataFrame): Training features for target encoding.
            y_train (pd.Series): Training target for target encoding.

        Returns:
            Tuple: Transformed patient data for prediction and patient data.
        """
        if patient_data.empty:
            raise ValueError(
                "Patient data empty. Please submit data before running inference."
            )
        if self.verbose:
            print("Patient Data Received for Inference:\n", patient_data)

        engine = StaticProcessEngine()
        dataloader = ProcessedDataLoader(task, encoding)
        patient_data[self.group_col] = "inference_patient"
        raw_data = engine.create_tooth_features(
            df=patient_data, neighbors=True, patient_id=False
        )

        if encoding == "target":
            raw_data = dataloader.encode_categorical_columns(df=raw_data)
            resampler = Resampler(self.classification, encoding)
            _, raw_data = resampler.apply_target_encoding(
                X=X_train, X_val=raw_data, y=y_train
            )

            for key in raw_data.columns:
                if key not in self.cat_vars and key in patient_data.columns:
                    raw_data[key] = patient_data[key].values
        else:
            raw_data = self.create_predict_data(
                raw_data=raw_data, patient_data=patient_data, encoding=encoding
            )

        predict_data = self.create_predict_data(
            raw_data=raw_data, patient_data=patient_data, encoding=encoding
        )
        predict_data = dataloader.scale_numeric_columns(df=predict_data)

        return predict_data, patient_data

    def patient_inference(
        self, predict_data: pd.DataFrame, patient_data: pd.DataFrame
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Run inference on the patient's data.

        Args:
            predict_data (pd.DataFrame): Transformed patient data for prediction.
            patient_data (pd.DataFrame): The patient's data as a DataFrame.

        Returns:
            Tuple:
                - predict_data: Transformed patient data for prediction.
                - output_data: DataFrame with columns "tooth", "side",
                transformed "prediction", and "probability".
                - results: Original results from the model inference.
        """
        results = self.predict(predict_data)
        output_data = patient_data[["tooth", "side"]].copy()
        output_data["prediction"] = results["prediction"]
        output_data["probability"] = results.drop(columns=["prediction"]).max(axis=1)
        return predict_data, output_data, results

    def process_patient(
        self,
        patient_id: int,
        train_df: pd.DataFrame,
        patient_data: pd.DataFrame,
        encoding: str,
        model_params: dict,
        resampler: Resampler,
    ) -> pd.DataFrame:
        """Processes a single patient's data for jackknife resampling.

        Args:
            patient_id (int): ID of the patient to exclude from training.
            train_df (pd.DataFrame): Full training dataset.
            patient_data (pd.DataFrame): The data for the patient(s) to predict on.
            encoding (str): Encoding type used ('one_hot' or 'target').
            model_params (dict): Parameters for the model initialization.
            resampler (Resampler): Instance of the Resampler class for encoding.

        Returns:
            predictions_df: DataFrame containing patient predictions and probabilities.
        """
        with warnings.catch_warnings():
            # The filters only apply inside this block, so the whole
            # fit/predict cycle has to run within it.
            warnings.filterwarnings("ignore", category=UserWarning)
            warnings.filterwarnings("ignore", category=ConvergenceWarning)

            train_data = train_df[train_df[self.group_col] != patient_id]
            X_train = train_data.drop(columns=[self.y])
            y_train = train_data[self.y]

            if encoding == "target":
                X_train = X_train.drop(columns=[self.group_col], errors="ignore")
                X_train_enc, _ = resampler.apply_target_encoding(
                    X=X_train, X_val=None, y=y_train, jackknife=True
                )
            else:
                X_train_enc = X_train.drop(columns=[self.group_col], errors="ignore")

            predictor = clone(self.model)
            predictor.set_params(**model_params)
            predictor.fit(X_train_enc, y_train)

            if self.classification == "binary" and hasattr(
                predictor, "best_threshold"
            ):
                probs = get_probs(
                    model=predictor,
                    classification=self.classification,
                    X=patient_data,
                )
                if probs is not None:
                    val_pred_classes = (probs >= predictor.best_threshold).astype(int)
                else:
                    val_pred_classes = predictor.predict(patient_data)
            else:
                val_pred_classes = predictor.predict(patient_data)
                probs = predictor.predict_proba(patient_data)

            predictions_df = pd.DataFrame(
                probs,
                columns=[str(cls) for cls in predictor.classes_],
                index=patient_data.index,
            )
            return predictions_df.assign(
                prediction=val_pred_classes,
                iteration=patient_id,
                data_index=patient_data.index,
            )

    @abstractmethod
    def jackknife_resampling(
        self,
        train_df: pd.DataFrame,
        patient_data: pd.DataFrame,
        encoding: str,
        model_params: dict,
        sample_fraction: float,
        n_jobs: int,
    ):
        """Perform jackknife resampling with retraining for each patient.

        Args:
            train_df (pd.DataFrame): Full training dataset.
            patient_data (pd.DataFrame): The data for the patient(s) to predict on.
            encoding (str): Encoding type used ('one_hot' or 'target').
            model_params (dict): Parameters for the model initialization.
            sample_fraction (float): Proportion of patient IDs to use for
                jackknife resampling.
            n_jobs (int): Number of jobs to run in parallel.
        """

    @abstractmethod
    def jackknife_confidence_intervals(
        self, jackknife_results: pd.DataFrame, alpha: float
    ):
        """Compute confidence intervals from jackknife results.

        Args:
            jackknife_results (pd.DataFrame): DataFrame with jackknife predictions.
            alpha (float): Significance level for confidence intervals.
        """

    @abstractmethod
    def plot_jackknife_intervals(
        self,
        ci_dict: Dict[int, Dict[str, Dict[str, float]]],
        data_indices: List[int],
        original_preds: pd.DataFrame,
    ):
        """Plot Jackknife confidence intervals.

        Args:
            ci_dict (Dict[int, Dict[str, Dict[str, float]]]): Confidence intervals for
                each data index and class.
            data_indices (List[int]): List of data indices to plot.
            original_preds (pd.DataFrame): DataFrame containing original predictions and
                probabilities for each data point.
        """

    @abstractmethod
    def jackknife_inference(
        self,
        model: Any,
        train_df: pd.DataFrame,
        patient_data: pd.DataFrame,
        encoding: str,
        inference_results: pd.DataFrame,
        alpha: float,
        sample_fraction: float,
        n_jobs: int,
        max_plots: int,
    ):
        """Run jackknife inference and generate confidence intervals and plots.

        Args:
            model (Any): Trained model instance.
            train_df (pd.DataFrame): Training DataFrame.
            patient_data (pd.DataFrame): Patient data to predict on.
            encoding (str): Encoding type.
            inference_results (pd.DataFrame): Original inference results.
            alpha (float): Significance level for confidence intervals.
            sample_fraction (float): Fraction of patient IDs for jackknife.
            n_jobs (int): Number of parallel jobs.
            max_plots (int): Maximum number of plots for jackknife intervals.
        """

__init__(classification, model, verbose)

Initialize BaseModelInference with a trained model.

Source code in periomod/inference/_baseinference.py
def __init__(self, classification: str, model: Any, verbose: bool):
    """Initialize BaseModelInference with a trained model."""
    super().__init__()
    self.classification = classification
    self.model = model
    self.verbose = verbose

create_predict_data(raw_data, patient_data, encoding)

Creates prediction data for model inference.

Parameters:
  • raw_data (DataFrame, required): The raw, preprocessed data.
  • patient_data (DataFrame, required): Original patient data before preprocessing.
  • encoding (str, required): Type of encoding used ('one_hot' or 'target').

Returns:
  • predict_data (DataFrame): A DataFrame containing the prepared data for model prediction.

Source code in periomod/inference/_baseinference.py
def create_predict_data(
    self,
    raw_data: pd.DataFrame,
    patient_data: pd.DataFrame,
    encoding: str,
) -> pd.DataFrame:
    """Creates prediction data for model inference.

    Args:
        raw_data (pd.DataFrame): The raw, preprocessed data.
        patient_data (pd.DataFrame): Original patient data before preprocessing.
        encoding (str): Type of encoding used ('one_hot' or 'target').

    Returns:
        predict_data: A DataFrame containing the prepared data for model prediction.
    """
    base_data = raw_data.copy()

    if encoding == "one_hot":
        drop_columns = self.cat_vars + self.infect_vars
        base_data = base_data.drop(columns=drop_columns, errors="ignore")
        encoded_data = pd.DataFrame(index=base_data.index)

        for tooth_num in range(11, 49):
            if tooth_num % 10 == 0 or tooth_num % 10 == 9:
                continue
            encoded_data[f"tooth_{tooth_num}"] = 0

        for feature, max_val in self.cat_map.items():
            for i in range(0, max_val + 1):
                encoded_data[f"{feature}_{i}"] = 0

        for idx, row in patient_data.iterrows():
            encoded_data.at[idx, f"tooth_{row['tooth']}"] = 1
            for feature in self.cat_map:
                encoded_data.at[idx, f"{feature}_{row[feature]}"] = 1

        # Concatenate once, after all rows have been encoded.
        complete_data = pd.concat(
            [
                base_data.reset_index(drop=True),
                encoded_data.reset_index(drop=True),
            ],
            axis=1,
        )

        # Report duplicate columns before dropping them; checking after the
        # drop would never find any.
        duplicates = complete_data.columns[
            complete_data.columns.duplicated()
        ].unique()
        if len(duplicates) > 0:
            print("Duplicate columns found:", duplicates)
        complete_data = complete_data.loc[:, ~complete_data.columns.duplicated()]

    elif encoding == "target":
        complete_data = base_data.copy()
        for column in self.target_cols:
            if column in patient_data.columns:
                complete_data[column] = patient_data[column].values
    else:
        raise ValueError(f"Unsupported encoding type: {encoding}")

    if hasattr(self.model, "get_booster"):
        model_features = self.model.get_booster().feature_names
    elif hasattr(self.model, "feature_names_in_"):
        model_features = self.model.feature_names_in_
    else:
        raise ValueError("Model type not supported for feature extraction")

    for feature in model_features:
        if feature not in complete_data.columns:
            complete_data[feature] = 0

    predict_data = complete_data[model_features]

    return predict_data
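
The tooth-column loop above produces one column per valid two-digit tooth number: 11 to 18, 21 to 28, 31 to 38, and 41 to 48, which matches the FDI numbering scheme (the source does not name the scheme, so that reading is an assumption). The following stdlib-only sketch reproduces that enumeration and encodes a single row; `tooth_columns` and `one_hot_row` are hypothetical helper names, not periomod functions.

```python
def tooth_columns() -> list[str]:
    """List the one-hot tooth columns the same way the loop above does."""
    columns = []
    for tooth_num in range(11, 49):
        # Numbers ending in 0 or 9 are skipped; they are not tooth positions.
        if tooth_num % 10 in (0, 9):
            continue
        columns.append(f"tooth_{tooth_num}")
    return columns


def one_hot_row(tooth: int) -> dict[str, int]:
    """Build a single encoded row: all zeros except the given tooth."""
    row = {name: 0 for name in tooth_columns()}
    key = f"tooth_{tooth}"
    if key not in row:
        raise ValueError(f"Unknown tooth number: {tooth}")
    row[key] = 1
    return row
```

Each of the four quadrants contributes eight teeth, so the enumeration yields 32 columns, and exactly one of them is set per row.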

jackknife_confidence_intervals(jackknife_results, alpha) abstractmethod

Compute confidence intervals from jackknife results.

Parameters:
  • jackknife_results (DataFrame, required): DataFrame with jackknife predictions.
  • alpha (float, required): Significance level for confidence intervals.
Source code in periomod/inference/_baseinference.py
@abstractmethod
def jackknife_confidence_intervals(
    self, jackknife_results: pd.DataFrame, alpha: float
):
    """Compute confidence intervals from jackknife results.

    Args:
        jackknife_results (pd.DataFrame): DataFrame with jackknife predictions.
        alpha (float): Significance level for confidence intervals.
    """

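The method above is abstract, so the source does not show how the intervals are actually computed. As a rough illustration only, the sketch below applies the textbook jackknife standard error with a normal approximation to a list of leave-one-out probability estimates. The function name, the returned keys, and the clipping to [0, 1] are assumptions made for this example and may differ from the concrete periomod implementation.

```python
from math import sqrt
from statistics import NormalDist, mean


def jackknife_interval(estimates: list[float], alpha: float = 0.05) -> dict[str, float]:
    """Normal-approximation interval from leave-one-out estimates."""
    n = len(estimates)
    if n < 2:
        raise ValueError("At least two jackknife estimates are required.")
    center = mean(estimates)
    # Jackknife standard error: sqrt((n - 1) / n * sum of squared deviations).
    std_error = sqrt((n - 1) / n * sum((e - center) ** 2 for e in estimates))
    # Two-sided critical value for the requested significance level.
    z = NormalDist().inv_cdf(1 - alpha / 2)
    # Probabilities live in [0, 1], so the bounds are clipped to that range.
    return {
        "mean": center,
        "lower": max(0.0, center - z * std_error),
        "upper": min(1.0, center + z * std_error),
    }
```

A smaller `alpha` widens the interval, since the critical value grows as the requested coverage increases.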
jackknife_inference(model, train_df, patient_data, encoding, inference_results, alpha, sample_fraction, n_jobs, max_plots) abstractmethod

Run jackknife inference and generate confidence intervals and plots.

Parameters:
  • model (Any, required): Trained model instance.
  • train_df (DataFrame, required): Training DataFrame.
  • patient_data (DataFrame, required): Patient data to predict on.
  • encoding (str, required): Encoding type.
  • inference_results (DataFrame, required): Original inference results.
  • alpha (float, required): Significance level for confidence intervals.
  • sample_fraction (float, required): Fraction of patient IDs for jackknife.
  • n_jobs (int, required): Number of parallel jobs.
  • max_plots (int, required): Maximum number of plots for jackknife intervals.
Source code in periomod/inference/_baseinference.py
@abstractmethod
def jackknife_inference(
    self,
    model: Any,
    train_df: pd.DataFrame,
    patient_data: pd.DataFrame,
    encoding: str,
    inference_results: pd.DataFrame,
    alpha: float,
    sample_fraction: float,
    n_jobs: int,
    max_plots: int,
):
    """Run jackknife inference and generate confidence intervals and plots.

    Args:
        model (Any): Trained model instance.
        train_df (pd.DataFrame): Training DataFrame.
        patient_data (pd.DataFrame): Patient data to predict on.
        encoding (str): Encoding type.
        inference_results (pd.DataFrame): Original inference results.
        alpha (float): Significance level for confidence intervals.
        sample_fraction (float): Fraction of patient IDs for jackknife.
        n_jobs (int): Number of parallel jobs.
        max_plots (int): Maximum number of plots for jackknife intervals.
    """

jackknife_resampling(train_df, patient_data, encoding, model_params, sample_fraction, n_jobs) abstractmethod

Perform jackknife resampling with retraining for each patient.

Parameters:
  • train_df (DataFrame, required): Full training dataset.
  • patient_data (DataFrame, required): The data for the patient(s) to predict on.
  • encoding (str, required): Encoding type used ('one_hot' or 'target').
  • model_params (dict, required): Parameters for the model initialization.
  • sample_fraction (float, required): Proportion of patient IDs to use for jackknife resampling.
  • n_jobs (int, required): Number of jobs to run in parallel.
Source code in periomod/inference/_baseinference.py
@abstractmethod
def jackknife_resampling(
    self,
    train_df: pd.DataFrame,
    patient_data: pd.DataFrame,
    encoding: str,
    model_params: dict,
    sample_fraction: float,
    n_jobs: int,
):
    """Perform jackknife resampling with retraining for each patient.

    Args:
        train_df (pd.DataFrame): Full training dataset.
        patient_data (pd.DataFrame): The data for the patient(s) to predict on.
        encoding (str): Encoding type used ('one_hot' or 'target').
        model_params (dict): Parameters for the model initialization.
        sample_fraction (float): Proportion of patient IDs to use for
            jackknife resampling.
        n_jobs (int): Number of jobs to run in parallel.
    """

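Concrete subclasses are expected to supply the resampling loop. One plausible shape for it, sketched here with plain Python lists instead of DataFrames, is to sample a fraction of the patient IDs and build one training set per sampled patient with that patient's rows removed. The `patient_id` key, the `seed` argument, and the function name are assumptions for illustration; parallel execution via `n_jobs` is left out.

```python
import random


def leave_one_patient_out(rows, sample_fraction=1.0, seed=0):
    """Yield (patient_id, training_rows) with that patient's rows removed.

    rows: list of dicts, each carrying a "patient_id" key (hypothetical name).
    """
    patient_ids = sorted({row["patient_id"] for row in rows})
    if not patient_ids:
        return
    # Always keep at least one patient so the loop produces a fold.
    n_sampled = max(1, int(len(patient_ids) * sample_fraction))
    sampled = random.Random(seed).sample(patient_ids, n_sampled)
    for patient_id in sampled:
        training_rows = [row for row in rows if row["patient_id"] != patient_id]
        yield patient_id, training_rows
```

Each yielded training set would then be passed to a retraining step such as process_patient, which fits a fresh copy of the model without the excluded patient.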
patient_inference(predict_data, patient_data)

Run inference on the patient's data.

Parameters:
  • predict_data (DataFrame, required): Transformed patient data for prediction.
  • patient_data (DataFrame, required): The patient's data as a DataFrame.

Returns:
  Tuple[DataFrame, DataFrame, DataFrame] containing:
  • predict_data: Transformed patient data for prediction.
  • output_data: DataFrame with columns "tooth", "side", transformed "prediction", and "probability".
  • results: Original results from the model inference.
Source code in periomod/inference/_baseinference.py
def patient_inference(
    self, predict_data: pd.DataFrame, patient_data: pd.DataFrame
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Run inference on the patient's data.

    Args:
        predict_data (pd.DataFrame): Transformed patient data for prediction.
        patient_data (pd.DataFrame): The patient's data as a DataFrame.

    Returns:
        Tuple:
            - predict_data: Transformed patient data for prediction.
            - output_data: DataFrame with columns "tooth", "side",
            transformed "prediction", and "probability".
            - results: Original results from the model inference.
    """
    results = self.predict(predict_data)
    output_data = patient_data[["tooth", "side"]].copy()
    output_data["prediction"] = results["prediction"]
    output_data["probability"] = results.drop(columns=["prediction"]).max(axis=1)
    return predict_data, output_data, results
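
The "probability" column is the row-wise maximum over the class-probability columns, independent of the "prediction" column. A small stdlib-only sketch of that reduction, using dictionaries in place of DataFrame rows (`summarize` is a hypothetical helper name):

```python
def summarize(results: list[dict]) -> list[dict]:
    """Reduce per-class probabilities to a prediction and its top probability."""
    summary = []
    for row in results:
        # Every key except "prediction" is treated as a class probability.
        class_probs = {k: v for k, v in row.items() if k != "prediction"}
        summary.append(
            {
                "prediction": row["prediction"],
                "probability": max(class_probs.values()),
            }
        )
    return summary
```

One consequence worth keeping in mind: when a tuned threshold below 0.5 assigns the positive class, the reported maximum can belong to the other class, so "probability" is not always the probability of the predicted label.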

plot_jackknife_intervals(ci_dict, data_indices, original_preds) abstractmethod

Plot Jackknife confidence intervals.

Parameters:
  • ci_dict (Dict[int, Dict[str, Dict[str, float]]], required): Confidence intervals for each data index and class.
  • data_indices (List[int], required): List of data indices to plot.
  • original_preds (DataFrame, required): DataFrame containing original predictions and probabilities for each data point.
Source code in periomod/inference/_baseinference.py
@abstractmethod
def plot_jackknife_intervals(
    self,
    ci_dict: Dict[int, Dict[str, Dict[str, float]]],
    data_indices: List[int],
    original_preds: pd.DataFrame,
):
    """Plot Jackknife confidence intervals.

    Args:
        ci_dict (Dict[int, Dict[str, Dict[str, float]]]): Confidence intervals for
            each data index and class.
        data_indices (List[int]): List of data indices to plot.
        original_preds (pd.DataFrame): DataFrame containing original predictions and
            probabilities for each data point.
    """

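The nested `ci_dict` type maps a data index to a class label and then to named interval values. The source does not state which keys the innermost dictionary uses, so the sketch below assumes `"lower"` and `"upper"` purely for illustration, and prints the intervals as text rather than plotting them. `format_intervals` is a hypothetical helper.

```python
def format_intervals(ci_dict, data_indices):
    """Render intervals as text lines, one per data index and class."""
    lines = []
    for idx in data_indices:
        for cls, bounds in ci_dict[idx].items():
            # "lower" and "upper" are assumed key names, not confirmed by the source.
            lines.append(
                f"index {idx}, class {cls}: "
                f"[{bounds['lower']:.2f}, {bounds['upper']:.2f}]"
            )
    return lines
```

Restricting the loop to `data_indices` mirrors how the plotting method receives an explicit list of indices rather than drawing every entry in the dictionary.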
predict(input_data)

Run prediction on a batch of input data.

Parameters:
  • input_data (DataFrame, required): DataFrame containing feature values.

Returns:
  • probs_df (DataFrame): DataFrame with predictions and probabilities for each class.

Source code in periomod/inference/_baseinference.py
def predict(self, input_data: pd.DataFrame) -> pd.DataFrame:
    """Run prediction on a batch of input data.

    Args:
        input_data (pd.DataFrame): DataFrame containing feature values.

    Returns:
        probs_df: DataFrame with predictions and probabilities for each class.
    """
    probs = self.model.predict_proba(input_data)

    if (
        self.classification == "binary"
        and getattr(self.model, "best_threshold", None) is not None
    ):
        # Apply the tuned decision threshold to the positive-class probability.
        preds = (probs[:, 1] >= self.model.best_threshold).astype(int)
    else:
        # Fall back to the model's default decision rule.
        preds = self.model.predict(input_data)
    classes = [str(cls) for cls in self.model.classes_]
    probs_df = pd.DataFrame(probs, columns=classes, index=input_data.index)
    probs_df["prediction"] = preds
    return probs_df
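
The intended decision rule, a tuned threshold on the positive-class probability for binary tasks and the model's default rule otherwise, can be illustrated without pandas or a fitted model. In this sketch `predict_classes` is a hypothetical helper, and picking the most probable class stands in for `model.predict`, which is an assumption that holds for many but not all classifiers.

```python
def predict_classes(probs, classification, best_threshold=None):
    """Turn rows of class probabilities into class indices."""
    if classification == "binary" and best_threshold is not None:
        # Compare the positive-class probability (column 1) with the threshold.
        return [int(row[1] >= best_threshold) for row in probs]
    # Otherwise pick the most probable class for each row.
    return [max(range(len(row)), key=row.__getitem__) for row in probs]
```

With a threshold of 0.25, a row such as `[0.7, 0.3]` is labeled positive even though the negative class is more probable, which is the reason a tuned threshold has to take precedence over the default rule.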

prepare_inference(task, patient_data, encoding, X_train, y_train)

Prepares the data for inference.

Parameters:
  • task (str, required): The task name for which the model was trained.
  • patient_data (DataFrame, required): The patient's data as a DataFrame.
  • encoding (str, required): Encoding type ("one_hot" or "target").
  • X_train (DataFrame, required): Training features for target encoding.
  • y_train (Series, required): Training target for target encoding.

Returns:
  • Tuple[DataFrame, DataFrame]: Transformed patient data for prediction and patient data.

Source code in periomod/inference/_baseinference.py
def prepare_inference(
    self,
    task: str,
    patient_data: pd.DataFrame,
    encoding: str,
    X_train: pd.DataFrame,
    y_train: pd.Series,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Prepares the data for inference.

    Args:
        task (str): The task name for which the model was trained.
        patient_data (pd.DataFrame): The patient's data as a DataFrame.
        encoding (str): Encoding type ("one_hot" or "target").
        X_train (pd.DataFrame): Training features for target encoding.
        y_train (pd.Series): Training target for target encoding.

    Returns:
        Tuple: Transformed patient data for prediction and patient data.
    """
    if patient_data.empty:
        raise ValueError(
            "Patient data empty. Please submit data before running inference."
        )
    if self.verbose:
        print("Patient Data Received for Inference:\n", patient_data)

    engine = StaticProcessEngine()
    dataloader = ProcessedDataLoader(task, encoding)
    patient_data[self.group_col] = "inference_patient"
    raw_data = engine.create_tooth_features(
        df=patient_data, neighbors=True, patient_id=False
    )

    if encoding == "target":
        raw_data = dataloader.encode_categorical_columns(df=raw_data)
        resampler = Resampler(self.classification, encoding)
        _, raw_data = resampler.apply_target_encoding(
            X=X_train, X_val=raw_data, y=y_train
        )

        for key in raw_data.columns:
            if key not in self.cat_vars and key in patient_data.columns:
                raw_data[key] = patient_data[key].values
    else:
        raw_data = self.create_predict_data(
            raw_data=raw_data, patient_data=patient_data, encoding=encoding
        )

    predict_data = self.create_predict_data(
        raw_data=raw_data, patient_data=patient_data, encoding=encoding
    )
    predict_data = dataloader.scale_numeric_columns(df=predict_data)

    return predict_data, patient_data
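
The reason prepare_inference needs X_train and y_train for target encoding is that the category-to-value mapping is learned from training data and only then applied to the patient rows. The sketch below illustrates that idea with plain mean target encoding in pandas. It is an assumption-laden illustration, not the Resampler implementation: the helper names, the "toothtype" column, and the fallback to the global mean for unseen categories are all hypothetical.

```python
import pandas as pd


def fit_target_encoding(X_train: pd.DataFrame, y_train: pd.Series, column: str):
    """Learn the mean target per category, plus a global fallback value."""
    means = y_train.groupby(X_train[column]).mean()
    return means.to_dict(), float(y_train.mean())


def apply_target_encoding(
    df: pd.DataFrame, column: str, mapping: dict, fallback: float
) -> pd.DataFrame:
    """Replace categories by their learned means; unseen ones get the fallback."""
    out = df.copy()  # work on a copy so the caller's frame is left untouched
    out[column] = out[column].map(mapping).fillna(fallback).astype(float)
    return out


# Toy training data (hypothetical column name and values).
X_train = pd.DataFrame({"toothtype": ["molar", "molar", "incisor", "incisor"]})
y_train = pd.Series([1, 0, 1, 1])

mapping, global_mean = fit_target_encoding(X_train, y_train, "toothtype")

# The patient frame contains one category never seen in training.
patient = pd.DataFrame({"toothtype": ["molar", "premolar"]})
encoded = apply_target_encoding(patient, "toothtype", mapping, global_mean)
```

Here "molar" is encoded as 0.5 (its mean target in the toy training set) and the unseen "premolar" falls back to the global mean of 0.75.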

process_patient(patient_id, train_df, patient_data, encoding, model_params, resampler)

Processes a single patient's data for jackknife resampling.

Parameters:

Name Type Description Default
patient_id int

ID of the patient to exclude from training.

required
train_df DataFrame

Full training dataset.

required
patient_data DataFrame

The data for the patient(s) to predict on.

required
encoding str

Encoding type used ('one_hot' or 'target').

required
model_params dict

Parameters for the model initialization.

required
resampler Resampler

Instance of the Resampler class for encoding.

required

Returns:

Name Type Description
predictions_df DataFrame

DataFrame containing patient predictions and probabilities.

Source code in periomod/inference/_baseinference.py
def process_patient(
    self,
    patient_id: int,
    train_df: pd.DataFrame,
    patient_data: pd.DataFrame,
    encoding: str,
    model_params: dict,
    resampler: Resampler,
) -> pd.DataFrame:
    """Processes a single patient's data for jackknife resampling.

    Args:
        patient_id (int): ID of the patient to exclude from training.
        train_df (pd.DataFrame): Full training dataset.
        patient_data (pd.DataFrame): The data for the patient(s) to predict on.
        encoding (str): Encoding type used ('one_hot' or 'target').
        model_params (dict): Parameters for the model initialization.
        resampler (Resampler): Instance of the Resampler class for encoding.

    Returns:
        predictions_df: DataFrame containing patient predictions and probabilities.
    """
    # The filters are restored as soon as the `with` block exits, so fitting
    # and predicting must happen inside it.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=ConvergenceWarning)

        train_data = train_df[train_df[self.group_col] != patient_id]
        X_train = train_data.drop(columns=[self.y])
        y_train = train_data[self.y]

        if encoding == "target":
            X_train = X_train.drop(columns=[self.group_col], errors="ignore")
            X_train_enc, _ = resampler.apply_target_encoding(
                X=X_train, X_val=None, y=y_train, jackknife=True
            )
        else:
            X_train_enc = X_train.drop(columns=[self.group_col], errors="ignore")

        predictor = clone(self.model)
        predictor.set_params(**model_params)
        predictor.fit(X_train_enc, y_train)

        if self.classification == "binary" and hasattr(predictor, "best_threshold"):
            probs = get_probs(
                model=predictor, classification=self.classification, X=patient_data
            )
            if probs is not None:
                val_pred_classes = (probs >= predictor.best_threshold).astype(int)
            else:
                val_pred_classes = predictor.predict(patient_data)
        else:
            val_pred_classes = predictor.predict(patient_data)
            probs = predictor.predict_proba(patient_data)

    predictions_df = pd.DataFrame(
        probs,
        columns=[str(cls) for cls in predictor.classes_],
        index=patient_data.index,
    )
    return predictions_df.assign(
        prediction=val_pred_classes,
        iteration=patient_id,
        data_index=patient_data.index,
    )
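
The leave-one-patient-out loop that process_patient serves can be sketched end to end on toy data. This is a simplified illustration under stated assumptions, not the library's jackknife_resampling: PriorModel is a hypothetical estimator that predicts the training class frequencies, copy.deepcopy stands in for sklearn's clone, and the column names "id_patient", "x" and "y" are made up.

```python
import copy

import numpy as np
import pandas as pd


class PriorModel:
    """Hypothetical estimator: predicts the training class frequencies."""

    def fit(self, X: pd.DataFrame, y: pd.Series) -> "PriorModel":
        self.classes_ = np.sort(y.unique())
        self.priors_ = np.array([(y == c).mean() for c in self.classes_])
        return self

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        return np.tile(self.priors_, (len(X), 1))

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        return np.full(len(X), self.classes_[self.priors_.argmax()])


def leave_one_patient_out(model, train_df, patient_data, group_col, target):
    """Refit once per excluded patient and predict on patient_data each time."""
    frames = []
    for patient_id in train_df[group_col].unique():
        # Drop every row that belongs to the excluded patient.
        subset = train_df[train_df[group_col] != patient_id]
        X = subset.drop(columns=[target, group_col])
        y = subset[target]
        fitted = copy.deepcopy(model).fit(X, y)  # fresh copy per iteration
        probs = pd.DataFrame(
            fitted.predict_proba(patient_data),
            columns=[str(c) for c in fitted.classes_],
            index=patient_data.index,
        )
        frames.append(
            probs.assign(
                prediction=fitted.predict(patient_data), iteration=patient_id
            )
        )
    return pd.concat(frames)


train_df = pd.DataFrame(
    {
        "id_patient": [1, 1, 2, 2, 3, 3],
        "x": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
        "y": [1, 1, 0, 0, 1, 0],
    }
)
patient_data = pd.DataFrame({"x": [0.5]})
results = leave_one_patient_out(
    PriorModel(), train_df, patient_data, "id_patient", "y"
)
```

Each row of results corresponds to one excluded patient, and the spread of the probability columns across iterations is what a jackknife confidence interval would later summarize.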