GridSearchCV работает без сбоев при скоринге='accuracy', но не при скоринге=accuracy_score

Когда я запускаю следующий фрагмент кода в блокноте Jupyter внутри кода Visual Studio, он работает без проблем.

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True, as_frame=True)

gs = GridSearchCV(estimator=KNeighborsClassifier(),
                  param_grid=[{'n_neighbors': [3]}],
                  scoring='accuracy')
#                  scoring=accuracy_score)

gs.fit(X, y)

Однако если я откомментирую закомментированную строку, закомментирую строку над ней и повторно запущу блокнот, я получу следующую ошибку. Почему?

c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 191, in wrapper
    params = func_sig.bind(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3267, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3191, in _bind
    raise TypeError(
TypeError: too many positional arguments

  warnings.warn(
c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 191, in wrapper
    params = func_sig.bind(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3267, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3191, in _bind
    raise TypeError(
TypeError: too many positional arguments

  warnings.warn(
c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 191, in wrapper
    params = func_sig.bind(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3267, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3191, in _bind
    raise TypeError(
TypeError: too many positional arguments

  warnings.warn(
c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 191, in wrapper
    params = func_sig.bind(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3267, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3191, in _bind
    raise TypeError(
TypeError: too many positional arguments

  warnings.warn(
c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py:982: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_validation.py", line 971, in _score
    scores = scorer(estimator, X_test, y_test, **score_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\utils\_param_validation.py", line 191, in wrapper
    params = func_sig.bind(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3267, in bind
    return self._bind(args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\isc\AppData\Local\Programs\Python\Python312\Lib\inspect.py", line 3191, in _bind
    raise TypeError(
TypeError: too many positional arguments

  warnings.warn(
c:\Users\isc\Documents\Python\MLClassification\.venv\Lib\site-packages\sklearn\model_selection\_search.py:1052: UserWarning: One or more of the test scores are non-finite: [nan]
  warnings.warn(
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
0
57
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Интересный вопрос, пришлось провести небольшое исследование, чтобы выяснить причину.

Согласно документу для GridSearchCV, параметр scoring действительно может принимать строку, представляющую функцию оценки, или непосредственно вызываемый объект. Таким образом, использование вызываемого объекта напрямую, как вы пытались, технически поддерживается. Однако существуют особые требования к структуре вызываемого объекта, которые могут быть не сразу понятны из документации.

Проблема в том, что вы передали функцию accuracy_score напрямую. Когда вы используете функцию оценки непосредственно в качестве вызываемого объекта, она должна соответствовать конкретным требованиям, ожидаемым GridSearchCV: вызываемый объект принимает только модель, функции тестовых данных X_test и истинные метки y_test в качестве аргументов (плюс, при необходимости, **kwargs для обработки дополнительных параметров). ).

Стандарт accuracy_score не соответствует этому шаблону напрямую, поскольку он явно ожидает два аргумента (y_true и y_pred). Эта разница в ожиданиях приводит к ошибке, которую вы получаете.

Решение: используйте make_scorer. Он принимает метрическую функцию и адаптирует ее для соответствия ожидаемой сигнатуре. make_scorer оборачивает accuracy_score (или любую другую метрическую функцию) таким образом, чтобы ее можно было использовать непосредственно GridSearchCV или аналогичными утилитами путем правильной внутренней обработки шага прогнозирования и последующей передачи y_true и y_pred в фактическую функцию оценки.

Используйте это следующим образом:

from sklearn.metrics import accuracy_score, make_scorer
accuracy_scorer = make_scorer(accuracy_score)
...
gs = GridSearchCV(estimator=KNeighborsClassifier(),
              param_grid=[{'n_neighbors': [3]}],
              scoring=accuracy_scorer )

Другие вопросы по теме