Я пытаюсь воспроизвести это решение Python pandas: как запустить множественную одномерную регрессию по группе, но используя линейную регрессию sklearn вместо статистических моделей.
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
'y': np.random.randn(20),
'x1': np.random.randn(20),
'x2': np.random.randn(20),
'grp': ['a', 'b'] * 10})
def ols_res(x, y):
return pd.Series(LinearRegression.fit(x,y).predict(x))
results = df.groupby('grp').apply(lambda x : x[['x1', 'x2']].apply(ols_res, y=x['y']))
print(results)
Я получил:
TypeError: ("fit() missing 1 required positional argument: 'y'", 'occurred at index x1')
Результаты должны быть такими же, как в статье, на которую я ссылаюсь, а именно:
x1 x2
grp
a 0 -0.102766 -0.205196
1 -0.073282 -0.102290
2 0.023832 0.033228
3 0.059369 -0.017519
4 0.003281 -0.077150
... ...
b 5 0.072874 -0.002919
6 0.180362 0.000502
7 0.005274 0.050313
8 -0.065506 -0.005163
9 0.003419 -0.013829






В вашем коде есть две незначительные проблемы:
Вы не создаете экземпляр LinearRegressionобъект, поэтому ваш код фактически пытается вызвать несвязанный fit метод LinearRegressionсорт.
Даже если вы это исправите, экземпляр LinearRegression не сможет выполнять fit и transform, потому что он ожидает 2D-массив, а получает 1D-массив. Соответственно, вам также нужно изменить форму массива, содержащегося в каждом Series.
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
df = pd.DataFrame({
'y': np.random.randn(20),
'x1': np.random.randn(20),
'x2': np.random.randn(20),
'grp': ['a', 'b'] * 10})
def ols_res(x, y):
x_2d = x.values.reshape(len(x), -1)
return pd.Series(LinearRegression().fit(x_2d, y).predict(x_2d))
results = df.groupby('grp').apply(lambda df: df[['x1', 'x2']].apply(ols_res, y=df['y']))
print(results)
Выход:
x1 x2
grp
a 0 -0.126680 0.137907
1 -0.441300 -0.595972
2 -0.285903 -0.385033
3 -0.252434 0.560938
4 -0.046632 -0.718514
5 -0.267396 -0.693155
6 -0.364425 -0.476643
7 -0.221493 -0.779082
8 -0.203781 0.722860
9 -0.106912 -0.090262
b 0 -0.015384 0.092137
1 0.478447 0.032881
2 0.366102 0.059832
3 -0.055907 0.055388
4 -0.221876 0.013941
5 -0.054299 0.048263
6 0.043979 0.024594
7 -0.307831 0.059972
8 -0.226570 -0.024809
9 0.394460 0.038921