qxmt.models.hyperparameter_search.search module
- class qxmt.models.hyperparameter_search.search.HyperParameterSearch(X, y, model, sampler_type, search_space, search_args, objective=None, logger=<Logger qxmt.models.hyperparameter_search.search (INFO)>)
Bases: object
Hyperparameter search class for machine learning models, built on Optuna. It supports grid search, random search, and TPE (Tree-structured Parzen Estimator) search for hyperparameter optimization. The user defines the search space and can customize the search arguments. Reference: https://optuna.readthedocs.io/en/stable/reference/index.html
Example
>>> from sklearn.svm import SVC
>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> from qxmt.models.hyperparameter_search.search import HyperParameterSearch
>>> X, y = load_iris(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
>>> model = SVC()
>>> sampler_type = "tpe"
>>> search_space = {
...     "C": [0.1, 1.0],
...     "gamma": [0.01, 0.1]
... }
>>> search_args = {
...     "cv": 5,
...     "direction": "maximize",
...     "n_jobs": -1,
... }
>>> searcher = HyperParameterSearch(X_train, y_train, model, sampler_type, search_space, search_args)
>>> best_params = searcher.search()
>>> best_params
{'C': 0.8526745595533768, 'gamma': 0.01217052619278743}
- Parameters:
X (ndarray)
y (ndarray)
model (Any)
sampler_type (str)
search_space (dict[str, list[Any]])
search_args (dict[str, Any] | None)
objective (Callable | None)
logger (Logger)
- __init__(X, y, model, sampler_type, search_space, search_args, objective=None, logger=<Logger qxmt.models.hyperparameter_search.search (INFO)>)
Initialize the hyperparameter search class.
- Parameters:
X (np.ndarray) – dataset for search
y (np.ndarray) – target values for search
model (Any) – model instance for hyperparameter search
sampler_type (str) – sampler type for hyperparameter search (random, grid, tpe)
search_space (dict[str, list[Any]]) – search space for hyperparameter search
search_args (Optional[dict[str, Any]]) – search arguments for hyperparameter search
objective (Optional[Callable], optional) – objective function for hyperparameter search. Defaults to None.
logger (Logger, optional) – logger instance. Defaults to LOGGER.
- Raises:
ValueError – if the sampler given in search_args does not match the search type
- Return type:
None
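The difference between the grid and random sampler types can be illustrated without any library. The sketch below is only an illustration, not QXMT's or Optuna's implementation: the helper names grid_candidates and random_candidates are hypothetical, and it assumes that a grid search treats each list in search_space as discrete candidate values while a random search treats it as [low, high] bounds. The TPE result in the class example lies between the listed values, which is consistent with the bounds reading, but how QXMT actually interprets the lists may differ.

```python
import itertools
import random
from typing import Any


def grid_candidates(search_space: dict[str, list[Any]]) -> list[dict[str, Any]]:
    """Enumerate every combination of the listed values (grid search)."""
    names = list(search_space)
    return [dict(zip(names, combo)) for combo in itertools.product(*search_space.values())]


def random_candidates(
    search_space: dict[str, list[float]], n_trials: int, seed: int = 0
) -> list[dict[str, float]]:
    """Draw n_trials points uniformly, reading each list as [low, high] bounds (assumption)."""
    rng = random.Random(seed)
    return [
        {name: rng.uniform(min(values), max(values)) for name, values in search_space.items()}
        for _ in range(n_trials)
    ]


search_space = {"C": [0.1, 1.0], "gamma": [0.01, 0.1]}
grid = grid_candidates(search_space)
rand = random_candidates(search_space, n_trials=5)
print(len(grid))  # 4 combinations: 2 values of C x 2 values of gamma
print(grid[0])    # {'C': 0.1, 'gamma': 0.01}
```

A grid search visits a fixed, finite set of points, so its cost grows multiplicatively with each added parameter. A random search spends a fixed budget of trials regardless of how many parameters are tuned.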
- default_objective(trial)
Default objective function for hyperparameter search.
- Parameters:
trial (Trial) – optuna trial instance
- Returns:
score of the model
- Return type:
float
- search()
Search for the best hyperparameters for the model.
- Returns:
best hyperparameters found by search
- Return type:
dict[str, Any]
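Because the result is a plain dictionary keyed by parameter name, it can be applied to a scikit-learn style estimator with set_params. The sketch below shows one way to do that, assuming the model follows the scikit-learn estimator interface. The best_params values are copied from the class example above and stand in for a real search() result, which will differ from run to run.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Placeholder result copied from the class example; a real search will differ.
best_params = {"C": 0.8526745595533768, "gamma": 0.01217052619278743}

model = SVC()
model.set_params(**best_params)  # apply the tuned values to the estimator
model.fit(X_train, y_train)      # refit on the full training split
test_accuracy = model.score(X_test, y_test)  # accuracy on held-out data
```

Evaluating on a split that was not used during the search gives a less optimistic estimate than the cross-validation score that drove the optimization.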