Model
Bases: BaseConfig
Configurable machine learning model class with HPO options.
This class provides an interface for initializing machine learning models based on the specified learner type (e.g., random forest, logistic regression) and classification type (binary or multiclass). It supports optional hyperparameter optimization (HPO) configurations.
Inherits:
BaseConfig: Provides base configuration settings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
learner | str | The machine learning algorithm to use. Options are 'rf' (random forest), 'mlp' (multi-layer perceptron), 'xgb' (XGBoost), or 'lr' (logistic regression). | required |
classification | str | The classification type, either 'binary' or 'multiclass'. | required |
hpo | Optional[str] | The hyperparameter optimization (HPO) method to use, either 'hebo' or 'rs'. Defaults to None, in which case the HPO method must be specified in the methods that use it. | None |
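The accepted values for the three parameters can be checked up front. The sketch below is a hypothetical helper (validate_model_options is not part of periomod); whether Model itself raises ValueError on an unsupported value is not documented on this page.

```python
from typing import Optional

# Allowed values as listed in the Model parameter table.
# The helper below is hypothetical and not part of periomod.
LEARNERS = ("rf", "mlp", "xgb", "lr")
CLASSIFICATIONS = ("binary", "multiclass")
HPO_METHODS = ("hebo", "rs")


def validate_model_options(
    learner: str, classification: str, hpo: Optional[str] = None
) -> None:
    """Raise ValueError if an argument is outside the documented options."""
    if learner not in LEARNERS:
        raise ValueError(f"learner must be one of {LEARNERS}, got {learner!r}")
    if classification not in CLASSIFICATIONS:
        raise ValueError(
            f"classification must be one of {CLASSIFICATIONS}, got {classification!r}"
        )
    # hpo is optional: None is accepted, otherwise it must be a known method.
    if hpo is not None and hpo not in HPO_METHODS:
        raise ValueError(f"hpo must be None or one of {HPO_METHODS}, got {hpo!r}")


# Usage: a valid combination passes silently, an unknown learner raises.
validate_model_options("rf", "binary", hpo="hebo")
try:
    validate_model_options("svm", "binary")
except ValueError as err:
    print(err)
```

Keeping the allowed values in module-level tuples means the error messages and the checks cannot drift apart.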
Attributes:
Name | Type | Description |
---|---|---|
learner | str | The machine learning algorithm used by the model: 'rf', 'mlp', 'xgb', or 'lr'. |
classification | str | The type of classification task: 'binary' or 'multiclass'. |
hpo | Optional[str] | The hyperparameter optimization method used for tuning, if specified: 'hebo' or 'rs'. |
Methods:
Name | Description |
---|---|
get | Class method that returns a model together with its hyperparameter search space or parameter grid. |
get_model | Class method that returns only the instantiated model, without HPO options. |
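Example
from periomod.learner._learners import Model

# With hpo="hebo", get returns (model, HEBO search space, transformation function)
rf_model, search_space, transform = Model.get(learner="rf", classification="binary", hpo="hebo")
# get_model returns only the instantiated (not yet fitted) classifier
mlp_model = Model.get_model(learner="mlp", classification="multiclass")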
Source code in periomod/learner/_learners.py
__init__(learner, classification, hpo=None)
Initializes the Model with the learner type, classification type, and optional HPO method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
learner | str | The machine learning algorithm to use ('rf', 'mlp', 'xgb', or 'lr'). | required |
classification | str | The type of classification ('binary' or 'multiclass'). | required |
hpo | Optional[str] | The hyperparameter optimization method to use ('hebo' or 'rs'). Defaults to None. | None |
get(learner, classification, hpo=None)
classmethod
Return the machine learning model and parameter grid or HEBO search space.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
learner | str | The machine learning algorithm to use ('rf', 'mlp', 'xgb', or 'lr'). | required |
classification | str | The type of classification ('binary' or 'multiclass'). | required |
hpo | Optional[str] | The hyperparameter optimization method ('hebo' or 'rs'). | None |
Returns:
Type | Description |
---|---|
Tuple | If hpo is 'rs', a tuple of (model, parameter grid). If hpo is 'hebo', a tuple of (model, HEBO search space, transformation function). |
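Because get returns a tuple whose length depends on hpo, callers have to unpack it differently for 'rs' and 'hebo'. The sketch below is a hypothetical helper (unpack_get_result is not part of periomod) that maps both documented shapes onto named fields; the string and list values in the usage lines are stand-ins for a real estimator, parameter grid, and HEBO search space.

```python
from typing import Any, Dict, Tuple


def unpack_get_result(result: Tuple[Any, ...], hpo: str) -> Dict[str, Any]:
    """Map the tuple returned by Model.get onto named fields.

    Hypothetical helper; it assumes the documented return shapes:
    (model, parameter grid) for hpo="rs" and
    (model, HEBO search space, transformation function) for hpo="hebo".
    A tuple of the wrong length raises ValueError during unpacking.
    """
    if hpo == "rs":
        model, param_grid = result  # 2-tuple
        return {"model": model, "param_grid": param_grid,
                "search_space": None, "transform": None}
    if hpo == "hebo":
        model, search_space, transform = result  # 3-tuple
        return {"model": model, "param_grid": None,
                "search_space": search_space, "transform": transform}
    raise ValueError(f"unsupported hpo method: {hpo!r}")


# Stand-in values replace the real estimator and search spaces here.
rs_fields = unpack_get_result(("rf-model", {"n_estimators": [100, 200]}), "rs")
hebo_fields = unpack_get_result(("rf-model", ["space"], lambda params: params), "hebo")
```

With the real library, the call would take the form unpack_get_result(Model.get(learner, classification, hpo), hpo), assuming hpo is one of the two documented methods.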
get_model(learner, classification)
classmethod
Return only the instantiated machine learning model for the given learner and classification type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
learner | str | The machine learning algorithm to use ('rf', 'mlp', 'xgb', or 'lr'). | required |
classification | str | The type of classification ('binary' or 'multiclass'). | required |
Returns:
Name | Type | Description |
---|---|---|
model | Union[RandomForestClassifier, LogisticRegression, MLPClassifier, XGBClassifier] | The instantiated, not yet fitted, scikit-learn-compatible classifier. |
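The return type suggests that get_model dispatches from a learner code to one of four estimator classes. The sketch below shows that pattern with a dictionary of factories. It is not periomod's implementation: it uses a hypothetical StandInClassifier instead of the real scikit-learn and XGBoost classes, so it runs without those libraries. It also does not reproduce how periomod configures each estimator for binary versus multiclass tasks, because this page does not document that.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class StandInClassifier:
    """Hypothetical placeholder for RandomForestClassifier, LogisticRegression,
    MLPClassifier or XGBClassifier, so the sketch runs without those libraries."""

    estimator: str
    multiclass: bool


def _make_factory(estimator: str) -> Callable[[bool], StandInClassifier]:
    # Each factory call builds a fresh, independent instance.
    def factory(multiclass: bool) -> StandInClassifier:
        return StandInClassifier(estimator=estimator, multiclass=multiclass)

    return factory


# Learner codes from the parameter table mapped to the estimator they stand for.
MODEL_FACTORIES: Dict[str, Callable[[bool], StandInClassifier]] = {
    "rf": _make_factory("RandomForestClassifier"),
    "lr": _make_factory("LogisticRegression"),
    "mlp": _make_factory("MLPClassifier"),
    "xgb": _make_factory("XGBClassifier"),
}


def get_model_sketch(learner: str, classification: str) -> StandInClassifier:
    """Dictionary-dispatch sketch of get_model (hypothetical, not periomod code)."""
    if classification not in ("binary", "multiclass"):
        raise ValueError(f"unsupported classification: {classification!r}")
    try:
        factory = MODEL_FACTORIES[learner]
    except KeyError:
        raise ValueError(f"unsupported learner: {learner!r}") from None
    return factory(classification == "multiclass")


model = get_model_sketch("mlp", "multiclass")
```

A dictionary of factories keeps the set of supported learners in one place. Adding a learner then means adding one entry instead of another if/elif branch.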