Я пытаюсь разработать нейронную сеть, которая сможет оценить коэффициент концентрации напряжений Kt образцов с V-образным надрезом на основе сканирований профиля надреза. Сканирования были интерполированы для создания областей с равноудаленными точками, поскольку я хочу использовать 1D CNN, которая может интерпретировать значения высоты профиля. У меня около 150 образцов. Среднее значение Kt образцов составляет 2,27 со стандартным отклонением 0,17.
Пример массива, который используется в качестве входных данных до нормализации:
[4.7605 4.60461111 4.44872222 4.29283333 4.13694444 3.98105556
3.82516667 3.66927778 3.51338889 3.3575 3.3575 3.35472643
3.35195286 3.34917929 3.34635815 3.34346399 3.34056983 3.33767567
3.3351553 3.33265909 3.33016288 3.32763043 3.32502569 3.32242095
3.31981621 3.31755051 3.31533165 3.31311279 3.31102569 3.3092892
3.3075527 3.30581621 3.30391204 3.30197054 3.30002904 3.2985
3.2985 3.2985 3.2985 3.29807997 3.29752525 3.29697054
3.29641217 3.29583333 3.2952545 3.29467567 3.29353409 3.29214731
3.29076052 3.28931555 3.28728964 3.28526372 3.28323781 3.28126557
3.27932407 3.27738258 3.27544108 3.27349958 3.27155808 3.26961658
3.26763922 3.26561331 3.2635874 3.26156148 3.26123106 3.2609537
3.26067635 3.26081621 3.26168445 3.2625527 3.26342095 3.26551684
3.26773569 3.26995455 3.27229051 3.27489526 3.2775 3.28010474
3.28238215 3.28460101 3.28681987 3.28920268 3.29209684 3.294991
3.29788516 3.30002904 3.30197054 3.30391204 3.30592161 3.30823693
3.31055226 3.31286759 3.31531439 3.31781061 3.32030682 3.32281621
3.32542095 3.32802569 3.33063043 3.33316288 3.33565909 3.3381553
3.34067567 3.34356983 3.34646399 3.34935815 3.35201136 3.35450758
3.35700379 3.3595 3.3595 3.51516667 3.67083333 3.8265
3.98216667 4.13783333 4.2935 4.44916667 4.60483333 4.7605 ]
Я выполнил поиск по сетке, чтобы оптимизировать свою модель, но меня не устраивает точность. Добавление дополнительных сверточных слоев не сильно повлияло на точность, равно как и включение слоя LSTM. Я также пробовал разные скалеры. Модель всегда возвращает значение Kt 2,11XXXXX, при этом меняются только последние цифры. RMSE равно 0,17, а MAE равно 0,13.
def preprocessing(specimens_list, scaler=Normalizer):
    """Build the (X, y) training arrays from a list of specimen objects.

    Parameters
    ----------
    specimens_list : iterable
        Specimen objects exposing ``KeyenceScansValues`` (scan container)
        and a scalar ``Kt`` target (may be NaN).
    scaler : sklearn-style scaler class, default ``Normalizer``
        Instantiated and fitted on the profile matrix.
        NOTE(review): ``Normalizer`` rescales each *sample* to unit norm
        independently, which erases absolute-amplitude differences between
        profiles — a per-feature scaler such as ``StandardScaler`` is
        usually what is wanted here; confirm intent.

    Returns
    -------
    tuple of np.ndarray
        ``(X_scaled, y)`` — scaled height-profile matrix and the Kt target
        vector with NaN entries imputed by the mean Kt.
    """
    X = []
    y = []
    for specimen in specimens_list:
        scan = list(specimen.KeyenceScansValues)[0]
        # format_for_ML resamples the scan onto equidistant points
        x_data, y_data = scan.format_for_ML(NUM1, NUM2)
        X.append(y_data)  # only the height (y) values are used as features
        y.append(specimen.Kt)
    # Impute missing targets with the mean Kt (NaN-aware mean)
    y = np.asarray(y, dtype=float)
    y = np.where(np.isnan(y), np.nanmean(y), y)
    # fit_transform replaces the separate fit()/transform() round trip
    X_scaled = scaler().fit_transform(X)
    return np.asarray(X_scaled), y
def create_model(filter1: int = 32, filter2: int = 64, kernel_size1: int = 5,
                 kernel_size2: int = 5, learning_rate: float = 0.001):
    """Build and compile a two-stage 1D-CNN regressor for Kt prediction.

    Two Conv1D/MaxPool stages feed a small dense head with a single
    linear output; compiled with Adam on MSE, tracking RMSE/MSLE/MAE.
    """
    net = models.Sequential()
    # Turn the flat (NUM,) height profile into a (NUM, 1) sequence for Conv1D.
    net.add(layers.Reshape((NUM, 1), input_shape=(NUM,)))
    net.add(layers.Conv1D(filters=filter1, kernel_size=kernel_size1,
                          activation='relu', padding='same'))
    net.add(layers.MaxPooling1D(pool_size=2, strides=2))
    net.add(layers.Conv1D(filters=filter2, kernel_size=kernel_size2,
                          activation='relu', padding='same'))
    net.add(layers.MaxPooling1D(pool_size=2, strides=2))
    net.add(layers.Flatten())
    net.add(layers.Dense(64, activation='relu'))
    net.add(layers.Dense(1))  # single linear unit: regression output
    net.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError(),
                 keras.metrics.MeanSquaredLogarithmicError(),
                 keras.metrics.MeanAbsoluteError()],
    )
    return net
def training(model, x_train, y_train, x_test, y_test, batch_size=5, epochs=10):
    """Fit *model* on the training split and return the Keras History.

    The test split is used as validation data during fitting; training
    runs silently (verbose=0).
    """
    return model.fit(
        x_train,
        y_train,
        validation_data=(x_test, y_test),
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
    )
def evaluate(model, train_history, x_train, y_train, x_test, y_test):
    """Print RMSE/MSLE/MAE for both splits and return the six values.

    Returns a flat list:
    [train RMSE, train MSLE, train MAE, val RMSE, val MSLE, val MAE].
    (``train_history`` is accepted for interface compatibility but unused.)
    """
    # model.evaluate returns [loss, RMSE, MSLE, MAE] in compile order.
    train_perf = model.evaluate(x_train, y_train, verbose=0)
    val_perf = model.evaluate(x_test, y_test, verbose=0)
    print(f'Training performance: RMSE = {train_perf[1]:.2f}, '
          f'MSLE = {train_perf[2]:.2f}, MAE = {train_perf[3]:.2f}')
    print(f'Validation performance: RMSE = {val_perf[1]:.2f}, '
          f'MSLE = {val_perf[2]:.2f}, MAE = {val_perf[3]:.2f}')
    return list(train_perf[1:4]) + list(val_perf[1:4])
Некоторые исходные данные образца:
ID,x,y,Kt
1,[-3.6183010075758877, -0.30600000000000094, -0.28200000000000003, -0.25900000000000034, -0.2350000000000012, -0.21199999999999974, -0.1880000000000006, -0.16500000000000092, -0.14100000000000001, -0.11800000000000033, -0.0940000000000012, -0.07099999999999973, -0.0470000000000006, -0.02400000000000091, 0.0, 0.022999999999999687, 0.04699999999999882, 0.07099999999999973, 0.09399999999999942, 0.11800000000000033, 0.14100000000000001, 0.16499999999999915, 0.18799999999999883, 0.21199999999999974, 0.23499999999999943, 0.25900000000000034, 0.28200000000000003, 0.30599999999999916, 0.32899999999999885, 0.35299999999999976, 3.6653010075758865],[1.461, 0.089, 0.079, 0.069, 0.06, 0.052, 0.044, 0.035, 0.028, 0.02, 0.013, 0.007, 0.003, 0.001, 0.0, 0.0, 0.002, 0.005, 0.008, 0.014, 0.02, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.082, 0.091, 1.461],2.2766371542377914
2,[-3.7002299397640215, -0.40000000000000036, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.33000000000000007, 3.630229939764021],[1.466, 0.09899999999999998, 0.09199999999999997, 0.08399999999999996, 0.07699999999999996, 0.06999999999999995, 0.061999999999999944, 0.05399999999999994, 0.04499999999999993, 0.03700000000000003, 0.030000000000000027, 0.025000000000000022, 0.020000000000000018, 0.016000000000000014, 0.015000000000000013, 0.009000000000000008, 0.006000000000000005, 0.0010000000000000009, 0.0, 0.0030000000000000027, 0.010000000000000009, 0.015000000000000013, 0.025000000000000022, 0.030000000000000027, 0.03700000000000003, 0.04200000000000004, 0.04699999999999993, 0.051999999999999935, 0.061999999999999944, 0.07099999999999995, 0.07999999999999996, 0.08899999999999997, 0.09599999999999997, 1.466],2.616131437456064
3,[-3.621845163453171, -0.35299999999999976, -0.32899999999999885, -0.30599999999999916, -0.28200000000000003, -0.25899999999999856, -0.23499999999999943, -0.21199999999999974, -0.18799999999999883, -0.16499999999999915, -0.14100000000000001, -0.11799999999999855, -0.09399999999999942, -0.07099999999999973, -0.04699999999999882, -0.023999999999999133, 0.0, 0.023000000000001464, 0.0470000000000006, 0.07000000000000028, 0.0940000000000012, 0.11700000000000088, 0.14100000000000001, 0.16500000000000092, 0.1880000000000006, 0.21200000000000152, 0.2350000000000012, 0.25900000000000034, 0.28200000000000003, 0.30600000000000094, 0.3290000000000006, 0.35300000000000153, 0.3760000000000012, 0.40000000000000036, 3.6688451634531716],[1.453, 0.09899999999999998, 0.08899999999999997, 0.07899999999999996, 0.07000000000000006, 0.06000000000000005, 0.052000000000000046, 0.04500000000000004, 0.039000000000000035, 0.03300000000000003, 0.027000000000000024, 0.02200000000000002, 0.016000000000000014, 0.010000000000000009, 0.006000000000000005, 0.0020000000000000018, 0.0, 0.0, 0.0010000000000000009, 0.0050000000000000044, 0.007000000000000006, 0.007000000000000006, 0.01100000000000001, 0.014000000000000012, 0.018000000000000016, 0.02200000000000002, 0.028000000000000025, 0.03500000000000003, 0.04400000000000004, 0.05300000000000005, 0.062000000000000055, 0.07200000000000006, 0.08199999999999996, 0.09199999999999997, 1.453],2.142634834792794
7,[-3.643573085514529, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.6435730855145283],[1.462, 0.099, 0.09, 0.081, 0.072, 0.063, 0.053, 0.044, 0.035, 0.027, 0.02, 0.014, 0.009, 0.005, 0.002, 0.001, 0.0, 0.002, 0.006, 0.012, 0.017, 0.024, 0.03, 0.037, 0.043, 0.05, 0.057, 0.064, 0.072, 0.081, 0.089, 0.098, 1.462],2.3949189992310247
8,[-3.575987299076902, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 0.33000000000000007, 0.3540000000000001, 0.3769999999999998, 0.4009999999999998, 3.6939872990769014],[1.458, 0.094, 0.084, 0.074, 0.064, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.009, 0.003, 0.0, 0.0, 0.002, 0.006, 0.011, 0.017, 0.023, 0.028, 0.032, 0.037, 0.041, 0.045, 0.05, 0.055, 0.061, 0.069, 0.079, 0.089, 1.458],2.502224358908798
9,[-3.682058366888767, -0.3769999999999998, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 3.6110583668887672],[1.461, 0.092, 0.083, 0.074, 0.065, 0.056, 0.048, 0.042, 0.036, 0.03, 0.025, 0.02, 0.015, 0.01, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.01, 0.016, 0.024, 0.033, 0.04, 0.048, 0.056, 0.064, 0.073, 0.081, 0.091, 1.461],2.3797426230814387
10,[-3.6484015126392757, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.648401512639275],[1.459, 0.094, 0.084, 0.075, 0.066, 0.056, 0.047, 0.039, 0.031, 0.023, 0.017, 0.011, 0.008, 0.005, 0.002, 0.001, 0.0, 0.0, 0.002, 0.006, 0.01, 0.015, 0.021, 0.027, 0.033, 0.04, 0.048, 0.057, 0.066, 0.076, 0.086, 0.097, 1.459],2.1552667710791282
11,[-3.658058366888767, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.658058366888768],[1.462, 0.093, 0.084, 0.074, 0.064, 0.054, 0.045, 0.038, 0.032, 0.027, 0.023, 0.018, 0.012, 0.006, 0.003, 0.001, 0.0, 0.003, 0.007, 0.012, 0.017, 0.02, 0.024, 0.029, 0.035, 0.042, 0.051, 0.06, 0.069, 0.078, 0.088, 0.097, 1.462],2.490420148832229
12,[-3.646987299076901, -0.3539999999999992, -0.3299999999999992, -0.30599999999999916, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18899999999999917, -0.16499999999999915, -0.14199999999999946, -0.11799999999999944, -0.09499999999999975, -0.07099999999999973, -0.047999999999999154, -0.023999999999999133, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.1640000000000006, 0.1880000000000006, 0.2110000000000003, 0.23500000000000032, 0.258, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.645987299076902],[1.46, 0.096, 0.087, 0.079, 0.071, 0.062, 0.053, 0.044, 0.037, 0.029, 0.022, 0.016, 0.011, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.008, 0.013, 0.017, 0.022, 0.028, 0.034, 0.042, 0.05, 0.059, 0.068, 0.078, 0.088, 0.098, 1.46],2.334695259830575
14,[-3.6134725804511403, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16499999999999915, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28199999999999914, 0.30599999999999916, 0.32899999999999974, 0.35299999999999976, 0.37599999999999945, 3.6834725804511397],[1.466, 0.096, 0.086, 0.075, 0.064, 0.055, 0.046, 0.038, 0.031, 0.024, 0.018, 0.012, 0.006, 0.002, 0.0, 0.0, 0.002, 0.006, 0.011, 0.016, 0.021, 0.026, 0.031, 0.037, 0.044, 0.052, 0.06, 0.069, 0.078, 0.088, 0.097, 1.466],2.306403976152674
15,[-3.5941588719521564, -0.3060000000000005, -0.2820000000000005, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11799999999999944, 0.14199999999999946, 0.16599999999999948, 0.18900000000000006, 0.21300000000000008, 0.23599999999999977, 0.2599999999999998, 0.2829999999999995, 0.3069999999999995, 0.3309999999999995, 0.3540000000000001, 0.3780000000000001, 0.4009999999999998, 3.6891588719521557],[1.457, 0.095, 0.085, 0.075, 0.066, 0.057, 0.048, 0.041, 0.031, 0.023, 0.017, 0.01, 0.005, 0.002, 0.0, 0.001, 0.001, 0.001, 0.004, 0.007, 0.011, 0.016, 0.022, 0.029, 0.036, 0.043, 0.05, 0.059, 0.068, 0.077, 0.086, 0.096, 1.457],2.2766667935315708
16,[-3.6725436482630056, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23599999999999977, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14200000000000035, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 0.35299999999999976, 3.6725436482630056],[1.473, 0.098, 0.083, 0.07, 0.058, 0.048, 0.04, 0.035, 0.033, 0.031, 0.028, 0.023, 0.018, 0.012, 0.005, 0.001, 0.0, 0.002, 0.005, 0.008, 0.012, 0.016, 0.02, 0.026, 0.032, 0.039, 0.046, 0.054, 0.063, 0.072, 0.081, 0.09, 1.473],2.3957215190027896
17,[-3.598987299076901, -0.3059999999999996, -0.2829999999999995, -0.25899999999999945, -0.23499999999999943, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.0699999999999994, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.3070000000000004, 0.33000000000000007, 0.3540000000000001, 0.37700000000000067, 3.6699872990769022],[1.456, 0.092, 0.082, 0.072, 0.063, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.01, 0.005, 0.002, 0.0, 0.0, 0.002, 0.004, 0.009, 0.014, 0.019, 0.025, 0.031, 0.038, 0.045, 0.052, 0.061, 0.07, 0.079, 0.089, 0.099, 1.456],2.29778546510921
19,[-3.658058366888767, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 3.634058366888767],[1.462, 0.093, 0.083, 0.073, 0.062, 0.053, 0.044, 0.038, 0.031, 0.025, 0.02, 0.015, 0.011, 0.007, 0.003, 0.001, 0.0, 0.0, 0.003, 0.007, 0.012, 0.019, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.083, 0.093, 1.462],2.3384060717833646
20,[-3.6738157262016484, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.3049999999999997, -0.28200000000000003, -0.258, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11800000000000033, 0.14199999999999946, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 3.627815726201648],[1.459, 0.093, 0.083, 0.074, 0.065, 0.058, 0.049, 0.042, 0.036, 0.031, 0.026, 0.022, 0.022, 0.019, 0.013, 0.007, 0.002, 0.0, 0.0, 0.001, 0.009, 0.017, 0.024, 0.03, 0.036, 0.041, 0.048, 0.056, 0.065, 0.074, 0.084, 0.094, 1.459],2.8303554018395176
Спасибо за ответ @grfreitas. Я пробовал использовать традиционные модели машинного обучения, но у меня возникли проблемы с реализацией профиля высоты, поскольку каждую точку нельзя рассматривать как отдельный объект в модели ML.
Я не думаю, что вам нужно что-то реализовывать для профиля высоты — вы можете попробовать передать данные изображения непосредственно в модель (например, RandomForestRegressor
). Например, если ваше изображение (каждый образец) имеет размер 28x28 пикселей в оттенках серого, вы должны развернуть его в одномерный массив из 784 признаков (где каждый пиксель является отдельным признаком) и передать это в модель. Модель будет учитывать признаки совместно во время обучения. На странице примеров scikit-learn.org
представлены различные демонстрации использования моделей с данными изображения. Я мог бы попробовать несколько моделей, если бы были доступны образцы данных.
Спасибо за ваш комментарий @MuhammedYunus! Вероятно, это неправильный способ, но я добавил несколько необработанных образцов в исходный пост. Можете ли вы дать мне ссылку на примеры, о которых вы говорите?
Без проблем. На этой странице перечислены все примеры, а этот представляет собой пример классификации изображений с использованием случайных лесов (эта модель также доступна в качестве регрессора). Ниже я опубликовал ответ, в котором показано, как можно предоставить профили в качестве входных данных для моделей регрессии.
Из 15 предоставленных вами образцов можно увидеть тенденцию к Kt:
Более высокие значения кажутся более узкими и нерегулярными по сравнению с более низкими кривыми Kt, которые более широкие и плавные.
Я нормализовал каждый профиль по среднему значению его краев, предположив, что они служат ориентирами. Эта нормализация позволила мне отбросить крайние точки, что привело к меньшему количеству признаков и, следовательно, к меньшей вероятности переобучения модели на небольшом наборе данных.
Область выреза, примерно, отбирается с одинаковыми интервалами, хотя не все интервалы доступны для каждого образца (длины примерно от 31 до 35):
Я использовал информацию на рисунке выше, чтобы определить средние местоположения, общие для всех образцов, а затем повторно выполнил выборку данных по этой новой оси. Причина такой осторожности при повторной выборке заключается в том, что я хочу минимизировать искажение немногих доступных семплов.
После небольшой предварительной обработки и настройки я получил средний валидационный MAE при кросс-валидации около 0,005 с использованием различных линейных моделей. Это может быть очень оптимистично, поскольку я использую в основном синтетические данные. Код ниже, если вы хотите попробовать больше данных.
model mae
rank
0 pls_reg 0.004795
1 linear_reg 0.005180
2 linear_svr 0.014007
3 ridge 0.019600
4 knn 0.029414
5 gradboost 0.030936
6 randomforest 0.037000
Я создал синтетические образцы, чтобы иметь больше возможностей для экспериментов при подборе моделей. Однако вы также можете использовать его в качестве метода дополнения, чтобы уменьшить переобучение вашего набора данных.
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
# IDs of the 15 physical specimens whose scans are hard-coded below
id_list = [1, 2, 3, 7, 8, 9, 10, 11, 12, 14,15, 16, 17, 19, 20]
positions_list = [
[-3.6183010075758877, -0.30600000000000094, -0.28200000000000003, -0.25900000000000034, -0.2350000000000012, -0.21199999999999974, -0.1880000000000006, -0.16500000000000092, -0.14100000000000001, -0.11800000000000033, -0.0940000000000012, -0.07099999999999973, -0.0470000000000006, -0.02400000000000091, 0.0, 0.022999999999999687, 0.04699999999999882, 0.07099999999999973, 0.09399999999999942, 0.11800000000000033, 0.14100000000000001, 0.16499999999999915, 0.18799999999999883, 0.21199999999999974, 0.23499999999999943, 0.25900000000000034, 0.28200000000000003, 0.30599999999999916, 0.32899999999999885, 0.35299999999999976, 3.6653010075758865],
[-3.7002299397640215, -0.40000000000000036, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.33000000000000007, 3.630229939764021],
[-3.621845163453171, -0.35299999999999976, -0.32899999999999885, -0.30599999999999916, -0.28200000000000003, -0.25899999999999856, -0.23499999999999943, -0.21199999999999974, -0.18799999999999883, -0.16499999999999915, -0.14100000000000001, -0.11799999999999855, -0.09399999999999942, -0.07099999999999973, -0.04699999999999882, -0.023999999999999133, 0.0, 0.023000000000001464, 0.0470000000000006, 0.07000000000000028, 0.0940000000000012, 0.11700000000000088, 0.14100000000000001, 0.16500000000000092, 0.1880000000000006, 0.21200000000000152, 0.2350000000000012, 0.25900000000000034, 0.28200000000000003, 0.30600000000000094, 0.3290000000000006, 0.35300000000000153, 0.3760000000000012, 0.40000000000000036, 3.6688451634531716],
[-3.643573085514529, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.6435730855145283],
[-3.575987299076902, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 0.33000000000000007, 0.3540000000000001, 0.3769999999999998, 0.4009999999999998, 3.6939872990769014],
[-3.682058366888767, -0.3769999999999998, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 3.6110583668887672],
[-3.6484015126392757, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.648401512639275],
[-3.658058366888767, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.658058366888768],
[-3.646987299076901, -0.3539999999999992, -0.3299999999999992, -0.30599999999999916, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18899999999999917, -0.16499999999999915, -0.14199999999999946, -0.11799999999999944, -0.09499999999999975, -0.07099999999999973, -0.047999999999999154, -0.023999999999999133, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.1640000000000006, 0.1880000000000006, 0.2110000000000003, 0.23500000000000032, 0.258, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.645987299076902],
[-3.6134725804511403, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16499999999999915, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28199999999999914, 0.30599999999999916, 0.32899999999999974, 0.35299999999999976, 0.37599999999999945, 3.6834725804511397],
[-3.5941588719521564, -0.3060000000000005, -0.2820000000000005, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11799999999999944, 0.14199999999999946, 0.16599999999999948, 0.18900000000000006, 0.21300000000000008, 0.23599999999999977, 0.2599999999999998, 0.2829999999999995, 0.3069999999999995, 0.3309999999999995, 0.3540000000000001, 0.3780000000000001, 0.4009999999999998, 3.6891588719521557],
[-3.6725436482630056, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23599999999999977, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14200000000000035, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 0.35299999999999976, 3.6725436482630056],
[-3.598987299076901, -0.3059999999999996, -0.2829999999999995, -0.25899999999999945, -0.23499999999999943, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.0699999999999994, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.3070000000000004, 0.33000000000000007, 0.3540000000000001, 0.37700000000000067, 3.6699872990769022],
[-3.658058366888767, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 3.634058366888767],
[-3.6738157262016484, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.3049999999999997, -0.28200000000000003, -0.258, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11800000000000033, 0.14199999999999946, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 3.627815726201648],
]
# Convert each scan's x-positions to NumPy arrays for vectorised maths
positions_list = [np.array(positions) for positions in positions_list]
heights_list = [
[1.461, 0.089, 0.079, 0.069, 0.06, 0.052, 0.044, 0.035, 0.028, 0.02, 0.013, 0.007, 0.003, 0.001, 0.0, 0.0, 0.002, 0.005, 0.008, 0.014, 0.02, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.082, 0.091, 1.461],
[1.466, 0.09899999999999998, 0.09199999999999997, 0.08399999999999996, 0.07699999999999996, 0.06999999999999995, 0.061999999999999944, 0.05399999999999994, 0.04499999999999993, 0.03700000000000003, 0.030000000000000027, 0.025000000000000022, 0.020000000000000018, 0.016000000000000014, 0.015000000000000013, 0.009000000000000008, 0.006000000000000005, 0.0010000000000000009, 0.0, 0.0030000000000000027, 0.010000000000000009, 0.015000000000000013, 0.025000000000000022, 0.030000000000000027, 0.03700000000000003, 0.04200000000000004, 0.04699999999999993, 0.051999999999999935, 0.061999999999999944, 0.07099999999999995, 0.07999999999999996, 0.08899999999999997, 0.09599999999999997, 1.466],
[1.453, 0.09899999999999998, 0.08899999999999997, 0.07899999999999996, 0.07000000000000006, 0.06000000000000005, 0.052000000000000046, 0.04500000000000004, 0.039000000000000035, 0.03300000000000003, 0.027000000000000024, 0.02200000000000002, 0.016000000000000014, 0.010000000000000009, 0.006000000000000005, 0.0020000000000000018, 0.0, 0.0, 0.0010000000000000009, 0.0050000000000000044, 0.007000000000000006, 0.007000000000000006, 0.01100000000000001, 0.014000000000000012, 0.018000000000000016, 0.02200000000000002, 0.028000000000000025, 0.03500000000000003, 0.04400000000000004, 0.05300000000000005, 0.062000000000000055, 0.07200000000000006, 0.08199999999999996, 0.09199999999999997, 1.453],
[1.462, 0.099, 0.09, 0.081, 0.072, 0.063, 0.053, 0.044, 0.035, 0.027, 0.02, 0.014, 0.009, 0.005, 0.002, 0.001, 0.0, 0.002, 0.006, 0.012, 0.017, 0.024, 0.03, 0.037, 0.043, 0.05, 0.057, 0.064, 0.072, 0.081, 0.089, 0.098, 1.462],
[1.458, 0.094, 0.084, 0.074, 0.064, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.009, 0.003, 0.0, 0.0, 0.002, 0.006, 0.011, 0.017, 0.023, 0.028, 0.032, 0.037, 0.041, 0.045, 0.05, 0.055, 0.061, 0.069, 0.079, 0.089, 1.458],
[1.461, 0.092, 0.083, 0.074, 0.065, 0.056, 0.048, 0.042, 0.036, 0.03, 0.025, 0.02, 0.015, 0.01, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.01, 0.016, 0.024, 0.033, 0.04, 0.048, 0.056, 0.064, 0.073, 0.081, 0.091, 1.461],
[1.459, 0.094, 0.084, 0.075, 0.066, 0.056, 0.047, 0.039, 0.031, 0.023, 0.017, 0.011, 0.008, 0.005, 0.002, 0.001, 0.0, 0.0, 0.002, 0.006, 0.01, 0.015, 0.021, 0.027, 0.033, 0.04, 0.048, 0.057, 0.066, 0.076, 0.086, 0.097, 1.459],
[1.462, 0.093, 0.084, 0.074, 0.064, 0.054, 0.045, 0.038, 0.032, 0.027, 0.023, 0.018, 0.012, 0.006, 0.003, 0.001, 0.0, 0.003, 0.007, 0.012, 0.017, 0.02, 0.024, 0.029, 0.035, 0.042, 0.051, 0.06, 0.069, 0.078, 0.088, 0.097, 1.462],
[1.46, 0.096, 0.087, 0.079, 0.071, 0.062, 0.053, 0.044, 0.037, 0.029, 0.022, 0.016, 0.011, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.008, 0.013, 0.017, 0.022, 0.028, 0.034, 0.042, 0.05, 0.059, 0.068, 0.078, 0.088, 0.098, 1.46],
[1.466, 0.096, 0.086, 0.075, 0.064, 0.055, 0.046, 0.038, 0.031, 0.024, 0.018, 0.012, 0.006, 0.002, 0.0, 0.0, 0.002, 0.006, 0.011, 0.016, 0.021, 0.026, 0.031, 0.037, 0.044, 0.052, 0.06, 0.069, 0.078, 0.088, 0.097, 1.466],
[1.457, 0.095, 0.085, 0.075, 0.066, 0.057, 0.048, 0.041, 0.031, 0.023, 0.017, 0.01, 0.005, 0.002, 0.0, 0.001, 0.001, 0.001, 0.004, 0.007, 0.011, 0.016, 0.022, 0.029, 0.036, 0.043, 0.05, 0.059, 0.068, 0.077, 0.086, 0.096, 1.457],
[1.473, 0.098, 0.083, 0.07, 0.058, 0.048, 0.04, 0.035, 0.033, 0.031, 0.028, 0.023, 0.018, 0.012, 0.005, 0.001, 0.0, 0.002, 0.005, 0.008, 0.012, 0.016, 0.02, 0.026, 0.032, 0.039, 0.046, 0.054, 0.063, 0.072, 0.081, 0.09, 1.473],
[1.456, 0.092, 0.082, 0.072, 0.063, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.01, 0.005, 0.002, 0.0, 0.0, 0.002, 0.004, 0.009, 0.014, 0.019, 0.025, 0.031, 0.038, 0.045, 0.052, 0.061, 0.07, 0.079, 0.089, 0.099, 1.456],
[1.462, 0.093, 0.083, 0.073, 0.062, 0.053, 0.044, 0.038, 0.031, 0.025, 0.02, 0.015, 0.011, 0.007, 0.003, 0.001, 0.0, 0.0, 0.003, 0.007, 0.012, 0.019, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.083, 0.093, 1.462],
[1.459, 0.093, 0.083, 0.074, 0.065, 0.058, 0.049, 0.042, 0.036, 0.031, 0.026, 0.022, 0.022, 0.019, 0.013, 0.007, 0.002, 0.0, 0.0, 0.001, 0.009, 0.017, 0.024, 0.03, 0.036, 0.041, 0.048, 0.056, 0.065, 0.074, 0.084, 0.094, 1.459],
]
# Convert each scan's height profile to a NumPy array for vectorised maths
heights_list = [np.array(heights) for heights in heights_list]
# Target: stress concentration factor Kt per specimen (same order as the scans)
kt_list = [2.2766371542377914, 2.616131437456064, 2.142634834792794, 2.3949189992310247, 2.502224358908798, 2.3797426230814387, 2.1552667710791282, 2.490420148832229, 2.334695259830575, 2.306403976152674, 2.2766667935315708, 2.3957215190027896, 2.29778546510921, 2.3384060717833646, 2.8303554018395176,]
# View data: three panels (left edge, right edge, notch centre) with every
# profile overlaid and coloured by its Kt value on a diverging colormap.
from matplotlib.colors import Normalize, CenteredNorm
from matplotlib.cm import ScalarMappable
from matplotlib.gridspec import GridSpec

n_samples = len(positions_list)
vmin = min(kt_list)
vmax = max(kt_list)
norm = Normalize(vmin, vmax)  # maps Kt range to ~0-1
# Centre the colormap on the median Kt so white = median, warm = above median
centred_norm = CenteredNorm(vcenter=np.median(kt_list), halfrange=0.5 * (vmax - vmin))
colour_kt = ScalarMappable(centred_norm, 'coolwarm')
f = plt.figure(figsize=(11, 3))
gs = GridSpec(1, 3, width_ratios=[0.5, 0.5, 3], height_ratios=[1])
ax_xleft = f.add_subplot(gs[0])
ax_xright = f.add_subplot(gs[1])
ax_xcentre = f.add_subplot(gs[2])
for ax in [ax_xleft, ax_xright, ax_xcentre]:
    # Overlay every profile, coloured by its Kt
    for positions, heights, kt in zip(positions_list, heights_list, kt_list):
        # ms must be numeric (the original passed the string '10';
        # matplotlib's markersize is documented as a float)
        ax.plot(positions, heights, marker='.', ms=10, linewidth=2,
                color=colour_kt.to_rgba(kt))
    if ax is ax_xleft:
        ax.set(ylabel='height')
    if ax is ax_xright:
        ax.tick_params(labelleft=False, left=False)
        ax.spines.left.set_visible(False)
    ax.spines[['top', 'right']].set_visible(False)
    ax.set_title(
        'left' if ax is ax_xleft else ('right' if ax is ax_xright else 'centre')
    )
    ax.set_xlabel('position')
    if ax in [ax_xleft, ax_xright]:
        # Edge/reference regions of the scans
        x_lims = [-3.71, -3.55] if ax is ax_xleft else [3.55, 3.71]
        y_lims = [1.45, 1.475]
    else:
        x_lims = [-0.41, 0.42]
        y_lims = [-0.005, 0.105]  # general notch area
        # Deliberate override kept from the original: zoom onto the notch peak.
        # Comment these two lines out to view the general notch area instead.
        x_lims = [-0.2, 0.2]
        y_lims = [-0.005, 0.04]  # notch peak
    ax.set(xlim=x_lims, ylim=y_lims)
# Colorbar on the right. add_axes (not add_subplot) is the API that accepts a
# [x0, y0, width, height] rectangle in figure coordinates.
ax_pos = ax_xcentre.get_position()
cax = f.add_axes([
    ax_pos.x0 + ax_pos.width * 1.05, ax_pos.y0, ax_pos.width / 20, ax_pos.height
])
f.colorbar(cax=cax, mappable=colour_kt, label='Kt\n(white = median Kt)')
# Synthesise data for testing
def synthesise_samples(n_samples, positions_list, heights_list, kt_list, rng=None):
    """Create synthetic profiles by mixup-style interpolation between pairs.

    For each new sample, two distinct parent profiles are drawn at random,
    the first is interpolated onto the x-axis of the second, and a random
    convex combination of the two height profiles (and of their Kt targets,
    plus a little Gaussian noise) is returned.

    Parameters
    ----------
    n_samples : int
        Number of synthetic samples to create.
    positions_list : list of np.ndarray
        x-axes of the source profiles.
    heights_list : list of np.ndarray
        Height profiles, aligned with ``positions_list``.
    kt_list : list of float
        Kt target per source profile.
    rng : np.random.Generator, optional
        Source of randomness. Defaults to the global NumPy RNG, which
        matches the original (unseeded) behaviour.

    Returns
    -------
    tuple of (positions, heights, kt) lists for the synthetic samples.
        Each new sample reuses the x-axis array of its second parent.
    """
    if rng is None:
        rng = np.random  # module-level RNG: backward-compatible default
    newsamples_positions = []
    newsamples_heights = []
    newsamples_kt = []
    sample_idxs = np.arange(len(positions_list))
    # Noise scale for the target: a small fraction of the observed Kt spread
    kt_noise_scale = np.std(kt_list, ddof=1) / 30
    for _ in range(n_samples):
        # Randomly select two distinct parent samples
        idx_i = rng.choice(sample_idxs)
        idx_j = rng.choice(sample_idxs[sample_idxs != idx_i])
        positions_i, positions_j = positions_list[idx_i], positions_list[idx_j]
        heights_i, heights_j = heights_list[idx_i], heights_list[idx_j]
        # Interpolate parent i onto the x-axis of parent j so the two
        # profiles share a common axis
        heights_interp = np.interp(positions_j, positions_i, heights_i)
        # Random convex combination of the two profiles in feature space
        alpha = rng.uniform()
        new_heights = alpha * heights_interp + (1 - alpha) * heights_j
        # Same combination for the target, with some Gaussian noise added
        new_kt = alpha * kt_list[idx_i] + (1 - alpha) * kt_list[idx_j]
        new_kt += rng.standard_normal() * kt_noise_scale
        # Store the new sample
        newsamples_positions.append(positions_j)
        newsamples_heights.append(new_heights)
        newsamples_kt.append(new_kt)
    return newsamples_positions, newsamples_heights, newsamples_kt
# Top up the 15 real profiles with synthetic ones, to 150 samples in total
positions_synth, heights_synth, kt_synth = synthesise_samples(
150 - len(positions_list), positions_list, heights_list, kt_list
)
#Combine with the original data (real samples keep the low indices)
positions_list.extend(positions_synth)
heights_list.extend(heights_synth)
kt_list.extend(kt_synth)
n_samples = len(positions_list)
#Split the data
from sklearn.model_selection import train_test_split
import pandas as pd
test_frac = 0.15
train_frac = 1 - test_frac
# Bin Kt into 5 intervals so the split is stratified over the target's range
kt_binned = pd.cut(kt_list, bins=5)
train_ixs, test_ixs = train_test_split(
np.arange(n_samples), test_size=test_frac, stratify=kt_binned, random_state=0
)
print(
f'Train/test sizes: {train_ixs.size}/{test_ixs.size}',
f'| fractions: {train_frac:.2f}/{test_frac:.2f}'
)
# Train-only views of the data.
# NOTE(review): only positions_list_trn is referenced later in this script;
# heights_list_trn and kt_list_trn appear to be unused — confirm before removing.
positions_list_trn = [positions_list[i] for i in train_ixs]
heights_list_trn = [heights_list[i] for i in train_ixs]
kt_list_trn = [kt_list[i] for i in train_ixs]
#Assess sampling consistency and resample onto a common axis
#Normalise each sample's heights using the mean of its edge points
#(the two edge points act as reference datums shared by all scans)
heights_list = [heights / np.mean(heights[[0, -1]]) for heights in heights_list]
#Visualise the distribution of x sampling
f, ax = plt.subplots(figsize=(11, 2))
ax.set_xlim(-0.45, 0.45) #look at notch
ax.set_xlabel('positions')
ax.tick_params(left=False, labelleft=False)
ax.spines[['top', 'right', 'left']].set_visible(False)
ax.spines.bottom.set_bounds(-0.4, 0.4)
ax.set_title('sampling consistency (black) and average locations (red)', fontsize=11)
# One row of tick marks per training sample (first 15 only)
for sample_idx, sample in enumerate(positions_list_trn[:15]):
    ax.scatter(sample, sample_idx * np.ones_like(sample), marker='|', c='darkslategray')
#Find the average sampling positions, in order to resample onto them.
#Edge points are excluded throughout via [1:-1].
avg_step_size = np.mean(
    [delta for sample in positions_list_trn for delta in np.diff(sample[1:-1])]
)
#The start and end points of the axis to resample onto. Dropping the edges.
start_pos = min(pos for sample in positions_list_trn for pos in sample[1:-1])
end_pos = max(pos for sample in positions_list_trn for pos in sample[1:-1])
max_seq_len = max(map(len, positions_list_trn))
# Fine grid used to record where each sample's points fall
pos_fine = np.linspace(start_pos, end_pos, num=max_seq_len * 100)
# NaN-initialised matrix: one row per sample, one column per fine-grid cell
positions_interp = np.empty([n_samples, pos_fine.size]) * np.nan
for sample_idx, positions in enumerate(positions_list):
    for pos_idx, pos in enumerate(positions[1:-1]):
        # Snap each observed position onto the nearest fine-grid cell
        match_idx = np.argmin(np.abs(pos_fine - pos))
        positions_interp[sample_idx, match_idx] = pos
# Average the observed positions per fine-grid cell, then drop empty cells
avg_positions = np.nanmean(positions_interp, axis=0)
avg_positions = avg_positions[~np.isnan(avg_positions)]
# Merge near-duplicates: keep only positions further apart than 1/4 avg step
# (prepend=1e3 forces the first diff to be large so the first point survives)
avg_positions = avg_positions[
    np.argwhere(np.abs(np.diff(avg_positions, prepend=1e3)) > avg_step_size / 4)
].ravel()
#Visualise results
[ax.axvline(pos, ymax=0.05, color='red', linewidth=3, alpha=0.4) for pos in avg_positions];
#Resample heights onto the average sampling positions
heights_resampled = np.zeros([n_samples, avg_positions.size])
for sample_idx, (positions, heights) in enumerate(zip(positions_list, heights_list)):
    heights_resampled[sample_idx, :] = np.interp(avg_positions, positions, heights)
# Use leave-one-out CV to score several regression models on the resampled
# profiles, then display a ranked table and a bar chart of the MAE scores.
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.svm import LinearSVR
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.cross_decomposition import PLSRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_validate, LeaveOneOut

np.random.seed(0)
# Feature matrix / target vector for the train and test partitions
X_train, X_test = heights_resampled[train_ixs], heights_resampled[test_ixs]
kt_arr = np.array(kt_list)
y_train, y_test = kt_arr[train_ixs], kt_arr[test_ixs]
models_dict = {
    'linear_reg': LinearRegression(),
    'ridge': Ridge(),
    'linear_svr': LinearSVR(C=0.05, dual='auto', max_iter=2000),
    'randomforest': RandomForestRegressor(min_samples_split=8),
    'gradboost': GradientBoostingRegressor(),
    'pls_reg': PLSRegression(n_components=15),
    'knn': KNeighborsRegressor(weights='distance'),
}
# Score each model with leave-one-out CV on standardised features
score_rows = []
for model_name, estimator in models_dict.items():
    print(model_name, '...')
    cv_results = cross_validate(
        make_pipeline(StandardScaler(), estimator),
        X_train, y_train,
        scoring='neg_mean_absolute_error',
        cv=LeaveOneOut(),
        n_jobs=-1,
    )
    # Negate: sklearn reports MAE as a negative score
    score_rows.append({'model': model_name, 'mae': -cv_results['test_score'].mean()})
# Rank models from best (lowest MAE) to worst
results_df = (
    pd.DataFrame(score_rows)
    .sort_values(by='mae')
    .reset_index(drop=True)
    .rename_axis(index='rank')
)
display(
    results_df
    .style
    .format(precision=4)
    .background_gradient(subset=['mae'], cmap='plasma')
    .set_caption('CV validation scores')
)
ax = results_df.plot(kind='bar', x='model', ylabel='mae', legend=False)
ax.figure.set_size_inches(4, 2)
Спасибо за такой развернутый ответ, это именно то, что мне нужно!
С удовольствием @SeppeVanheulenberghe. Любые вопросы просто задавайте.
Небольшой встречный вопрос: почему вы хотите использовать именно нейронную сеть при таком небольшом количестве образцов? Переобучение произойдёт очень легко, а точность, вероятно, не будет лучше, чем у более простой модели.