qxmt.models.hyperparameter_search.search module

class qxmt.models.hyperparameter_search.search.HyperParameterSearch(X, y, model, search_type, search_space, search_args, objective=None, logger=<Logger qxmt.models.hyperparameter_search.search (INFO)>)

Bases: object

Hyperparameter search class for machine learning models, built on Optuna. It supports grid search, random search, and TPE (Tree-structured Parzen Estimator) search for hyperparameter optimization. The user defines the search space and can customize the search through search_args. Reference: https://optuna.readthedocs.io/en/stable/reference/index.html

Example

>>> from sklearn.svm import SVC
>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> from qxmt.models.hyperparameter_search.search import HyperParameterSearch
>>> X, y = load_iris(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
>>> model = SVC()
>>> search_type = "tpe"
>>> search_space = {
...     "C": [0.1, 1.0],
...     "gamma": [0.01, 0.1]
... }
>>> search_args = {
...     "cv": 5,
...     "direction": "maximize",
...     "n_jobs": -1,
... }
>>> searcher = HyperParameterSearch(X_train, y_train, model, search_type, search_space, search_args)
>>> best_params = searcher.search()
>>> best_params
{'C': 0.8526745595533768, 'gamma': 0.01217052619278743}
Parameters:
  • X (ndarray)

  • y (ndarray)

  • model (Any)

  • search_type (str)

  • search_space (dict[str, list[Any]])

  • search_args (dict[str, Any] | None)

  • objective (Callable | None)

  • logger (Logger)

__init__(X, y, model, search_type, search_space, search_args, objective=None, logger=<Logger qxmt.models.hyperparameter_search.search (INFO)>)

Initialize the hyperparameter search class.

Parameters:
  • X (np.ndarray) – dataset for search

  • y (np.ndarray) – target values for search

  • model (Any) – model instance for hyperparameter search

  • search_type (str) – type of hyperparameter search to run: random, grid, or tpe

  • search_space (dict[str, list[Any]]) – search space for hyperparameter search

  • search_args (Optional[dict[str, Any]]) – search arguments for hyperparameter search

  • objective (Optional[Callable], optional) – objective function for hyperparameter search. Defaults to None.

  • logger (Logger, optional) – logger instance. Defaults to LOGGER.

Raises:

ValueError – if the sampler specified in search_args does not match search_type

Return type:

None

default_objective(trial)

Default objective function for hyperparameter search.

Parameters:

trial (Trial) – optuna trial instance

Returns:

score of the model

Return type:

float

search()

Search for the best hyperparameters for the model.

Returns:

best hyperparameters found by search

Return type:

dict[str, Any]
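
Because the result is a plain dictionary keyed by hyperparameter name, one way to use it with a scikit-learn estimator is `set_params`. The sketch below assumes such an estimator and hard-codes a dictionary in the shape shown in the class example rather than calling search(); the `random_state` value is an arbitrary choice made only for repeatability.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Stand-in for the dictionary returned by search(); values are copied from
# the class example and will differ from run to run in practice.
best_params = {"C": 0.8526745595533768, "gamma": 0.01217052619278743}

model = SVC()
model.set_params(**best_params)  # apply the tuned hyperparameters
model.fit(X_train, y_train)
test_accuracy = model.score(X_test, y_test)  # accuracy on the held-out split
```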