Как я могу повысить точность 1D CNN для оценки коэффициентов концентрации напряжений?

Я пытаюсь разработать нейронную сеть, которая сможет оценить коэффициент концентрации напряжений Kt образцов с V-образным надрезом на основе сканирований профиля надреза. Сканирования были интерполированы для создания областей с равноудаленными точками, поскольку я хочу использовать 1D CNN, которая может интерпретировать значения высоты профиля. У меня около 150 образцов. Среднее значение Kt образцов составляет 2,27 со стандартным отклонением 0,17.

Пример массива, который используется в качестве входных данных до нормализации:

[4.7605     4.60461111 4.44872222 4.29283333 4.13694444 3.98105556
 3.82516667 3.66927778 3.51338889 3.3575     3.3575     3.35472643
 3.35195286 3.34917929 3.34635815 3.34346399 3.34056983 3.33767567
 3.3351553  3.33265909 3.33016288 3.32763043 3.32502569 3.32242095
 3.31981621 3.31755051 3.31533165 3.31311279 3.31102569 3.3092892
 3.3075527  3.30581621 3.30391204 3.30197054 3.30002904 3.2985
 3.2985     3.2985     3.2985     3.29807997 3.29752525 3.29697054
 3.29641217 3.29583333 3.2952545  3.29467567 3.29353409 3.29214731
 3.29076052 3.28931555 3.28728964 3.28526372 3.28323781 3.28126557
 3.27932407 3.27738258 3.27544108 3.27349958 3.27155808 3.26961658
 3.26763922 3.26561331 3.2635874  3.26156148 3.26123106 3.2609537
 3.26067635 3.26081621 3.26168445 3.2625527  3.26342095 3.26551684
 3.26773569 3.26995455 3.27229051 3.27489526 3.2775     3.28010474
 3.28238215 3.28460101 3.28681987 3.28920268 3.29209684 3.294991
 3.29788516 3.30002904 3.30197054 3.30391204 3.30592161 3.30823693
 3.31055226 3.31286759 3.31531439 3.31781061 3.32030682 3.32281621
 3.32542095 3.32802569 3.33063043 3.33316288 3.33565909 3.3381553
 3.34067567 3.34356983 3.34646399 3.34935815 3.35201136 3.35450758
 3.35700379 3.3595     3.3595     3.51516667 3.67083333 3.8265
 3.98216667 4.13783333 4.2935     4.44916667 4.60483333 4.7605    ]

Я выполнил поиск по сетке, чтобы оптимизировать свою модель, но меня не устраивает точность. Добавление дополнительных сверточных слоев не сильно повлияло на точность, равно как и включение слоя LSTM. Я также пробовал разные скалеры. Модель всегда возвращает значение Kt 2,11XXXXX, при этом меняются только последние цифры. RMSE равно 0,17, а MAE равно 0,13.

def preprocessing(specimens_list, scaler=Normalizer):
    """Build the scaled feature matrix X and target vector y for training.

    Each specimen contributes the height values of its first Keyence scan,
    resampled to equidistant points by format_for_ML; the target is the
    specimen's Kt. Missing (NaN) Kt values are imputed with the mean of
    the remaining targets before scaling.

    Returns a tuple (X_scaled, y) of numpy arrays.
    """
    # Collect one resampled height profile per specimen; the x coordinates
    # returned by format_for_ML are discarded — only heights are features.
    features = [
        list(sp.KeyenceScansValues)[0].format_for_ML(NUM1, NUM2)[1]
        for sp in specimens_list
    ]
    targets = [sp.Kt for sp in specimens_list]

    # Impute NaN targets with the NaN-ignoring mean of all targets.
    mean_kt = np.nanmean(targets)
    targets = np.where(np.isnan(targets), mean_kt, targets)

    # Fit the chosen scaler (default: sklearn Normalizer) on the full
    # feature matrix, then apply it.
    fitted_scaler = scaler().fit(features)
    features_scaled = fitted_scaler.transform(features)

    return np.asarray(features_scaled), np.asarray(targets)

def create_model(filter1: int = 32, filter2: int = 64, kernel_size1: int = 5, kernel_size2: int = 5, learning_rate: float = 0.001):
    """Build and compile a 1D-CNN regression model for Kt estimation.

    Two Conv1D + MaxPooling stages feed a dense head ending in a single
    linear unit. Compiled with Adam, MSE loss, and RMSE/MSLE/MAE metrics.
    """
    model = models.Sequential()
    # Turn the flat (NUM,) height profile into a (NUM, 1) sequence so
    # Conv1D can slide over it.
    model.add(layers.Reshape((NUM, 1), input_shape=(NUM,)))
    model.add(layers.Conv1D(filters=filter1, kernel_size=kernel_size1, activation='relu', padding='same'))
    model.add(layers.MaxPooling1D(pool_size=2, strides=2))
    model.add(layers.Conv1D(filter2, kernel_size=kernel_size2, activation='relu', padding='same'))
    model.add(layers.MaxPooling1D(pool_size=2, strides=2))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(1))  # single linear unit: regression output

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=keras.losses.MeanSquaredError(),
        metrics=[
            keras.metrics.RootMeanSquaredError(),
            keras.metrics.MeanSquaredLogarithmicError(),
            keras.metrics.MeanAbsoluteError(),
        ],
    )
    return model

def training(model, x_train, y_train, x_test, y_test, batch_size=5, epochs=10):
    """Fit *model* on the training data and return the training history.

    Validation metrics are tracked on (x_test, y_test) after every epoch;
    console output is silenced (verbose=0).
    """
    return model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
        verbose=0,
    )

def evaluate(model, train_history, x_train, y_train, x_test, y_test):
    """Print and return RMSE/MSLE/MAE on the train and test sets.

    Returns a flat list:
    [train RMSE, train MSLE, train MAE, val RMSE, val MSLE, val MAE].
    (train_history is currently unused; kept for interface compatibility.)
    """
    train_metrics = model.evaluate(x_train, y_train, verbose=0)
    val_metrics = model.evaluate(x_test, y_test, verbose=0)

    # model.evaluate returns [loss, rmse, msle, mae]; index 0 (loss) is skipped.
    print(f'Training performance: RMSE = {train_metrics[1]:.2f}, MSLE = {train_metrics[2]:.2f}, MAE = {train_metrics[3]:.2f}')
    print(f'Validation performance: RMSE = {val_metrics[1]:.2f}, MSLE = {val_metrics[2]:.2f}, MAE = {val_metrics[3]:.2f}')

    return list(train_metrics[1:4]) + list(val_metrics[1:4])

Некоторые исходные данные образца:

ID,x,y,Kt
1,[-3.6183010075758877, -0.30600000000000094, -0.28200000000000003, -0.25900000000000034, -0.2350000000000012, -0.21199999999999974, -0.1880000000000006, -0.16500000000000092, -0.14100000000000001, -0.11800000000000033, -0.0940000000000012, -0.07099999999999973, -0.0470000000000006, -0.02400000000000091, 0.0, 0.022999999999999687, 0.04699999999999882, 0.07099999999999973, 0.09399999999999942, 0.11800000000000033, 0.14100000000000001, 0.16499999999999915, 0.18799999999999883, 0.21199999999999974, 0.23499999999999943, 0.25900000000000034, 0.28200000000000003, 0.30599999999999916, 0.32899999999999885, 0.35299999999999976, 3.6653010075758865],[1.461, 0.089, 0.079, 0.069, 0.06, 0.052, 0.044, 0.035, 0.028, 0.02, 0.013, 0.007, 0.003, 0.001, 0.0, 0.0, 0.002, 0.005, 0.008, 0.014, 0.02, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.082, 0.091, 1.461],2.2766371542377914
2,[-3.7002299397640215, -0.40000000000000036, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.33000000000000007, 3.630229939764021],[1.466, 0.09899999999999998, 0.09199999999999997, 0.08399999999999996, 0.07699999999999996, 0.06999999999999995, 0.061999999999999944, 0.05399999999999994, 0.04499999999999993, 0.03700000000000003, 0.030000000000000027, 0.025000000000000022, 0.020000000000000018, 0.016000000000000014, 0.015000000000000013, 0.009000000000000008, 0.006000000000000005, 0.0010000000000000009, 0.0, 0.0030000000000000027, 0.010000000000000009, 0.015000000000000013, 0.025000000000000022, 0.030000000000000027, 0.03700000000000003, 0.04200000000000004, 0.04699999999999993, 0.051999999999999935, 0.061999999999999944, 0.07099999999999995, 0.07999999999999996, 0.08899999999999997, 0.09599999999999997, 1.466],2.616131437456064
3,[-3.621845163453171, -0.35299999999999976, -0.32899999999999885, -0.30599999999999916, -0.28200000000000003, -0.25899999999999856, -0.23499999999999943, -0.21199999999999974, -0.18799999999999883, -0.16499999999999915, -0.14100000000000001, -0.11799999999999855, -0.09399999999999942, -0.07099999999999973, -0.04699999999999882, -0.023999999999999133, 0.0, 0.023000000000001464, 0.0470000000000006, 0.07000000000000028, 0.0940000000000012, 0.11700000000000088, 0.14100000000000001, 0.16500000000000092, 0.1880000000000006, 0.21200000000000152, 0.2350000000000012, 0.25900000000000034, 0.28200000000000003, 0.30600000000000094, 0.3290000000000006, 0.35300000000000153, 0.3760000000000012, 0.40000000000000036, 3.6688451634531716],[1.453, 0.09899999999999998, 0.08899999999999997, 0.07899999999999996, 0.07000000000000006, 0.06000000000000005, 0.052000000000000046, 0.04500000000000004, 0.039000000000000035, 0.03300000000000003, 0.027000000000000024, 0.02200000000000002, 0.016000000000000014, 0.010000000000000009, 0.006000000000000005, 0.0020000000000000018, 0.0, 0.0, 0.0010000000000000009, 0.0050000000000000044, 0.007000000000000006, 0.007000000000000006, 0.01100000000000001, 0.014000000000000012, 0.018000000000000016, 0.02200000000000002, 0.028000000000000025, 0.03500000000000003, 0.04400000000000004, 0.05300000000000005, 0.062000000000000055, 0.07200000000000006, 0.08199999999999996, 0.09199999999999997, 1.453],2.142634834792794
7,[-3.643573085514529, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.6435730855145283],[1.462, 0.099, 0.09, 0.081, 0.072, 0.063, 0.053, 0.044, 0.035, 0.027, 0.02, 0.014, 0.009, 0.005, 0.002, 0.001, 0.0, 0.002, 0.006, 0.012, 0.017, 0.024, 0.03, 0.037, 0.043, 0.05, 0.057, 0.064, 0.072, 0.081, 0.089, 0.098, 1.462],2.3949189992310247
8,[-3.575987299076902, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 0.33000000000000007, 0.3540000000000001, 0.3769999999999998, 0.4009999999999998, 3.6939872990769014],[1.458, 0.094, 0.084, 0.074, 0.064, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.009, 0.003, 0.0, 0.0, 0.002, 0.006, 0.011, 0.017, 0.023, 0.028, 0.032, 0.037, 0.041, 0.045, 0.05, 0.055, 0.061, 0.069, 0.079, 0.089, 1.458],2.502224358908798
9,[-3.682058366888767, -0.3769999999999998, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 3.6110583668887672],[1.461, 0.092, 0.083, 0.074, 0.065, 0.056, 0.048, 0.042, 0.036, 0.03, 0.025, 0.02, 0.015, 0.01, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.01, 0.016, 0.024, 0.033, 0.04, 0.048, 0.056, 0.064, 0.073, 0.081, 0.091, 1.461],2.3797426230814387
10,[-3.6484015126392757, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.648401512639275],[1.459, 0.094, 0.084, 0.075, 0.066, 0.056, 0.047, 0.039, 0.031, 0.023, 0.017, 0.011, 0.008, 0.005, 0.002, 0.001, 0.0, 0.0, 0.002, 0.006, 0.01, 0.015, 0.021, 0.027, 0.033, 0.04, 0.048, 0.057, 0.066, 0.076, 0.086, 0.097, 1.459],2.1552667710791282
11,[-3.658058366888767, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.658058366888768],[1.462, 0.093, 0.084, 0.074, 0.064, 0.054, 0.045, 0.038, 0.032, 0.027, 0.023, 0.018, 0.012, 0.006, 0.003, 0.001, 0.0, 0.003, 0.007, 0.012, 0.017, 0.02, 0.024, 0.029, 0.035, 0.042, 0.051, 0.06, 0.069, 0.078, 0.088, 0.097, 1.462],2.490420148832229
12,[-3.646987299076901, -0.3539999999999992, -0.3299999999999992, -0.30599999999999916, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18899999999999917, -0.16499999999999915, -0.14199999999999946, -0.11799999999999944, -0.09499999999999975, -0.07099999999999973, -0.047999999999999154, -0.023999999999999133, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.1640000000000006, 0.1880000000000006, 0.2110000000000003, 0.23500000000000032, 0.258, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.645987299076902],[1.46, 0.096, 0.087, 0.079, 0.071, 0.062, 0.053, 0.044, 0.037, 0.029, 0.022, 0.016, 0.011, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.008, 0.013, 0.017, 0.022, 0.028, 0.034, 0.042, 0.05, 0.059, 0.068, 0.078, 0.088, 0.098, 1.46],2.334695259830575
14,[-3.6134725804511403, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16499999999999915, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28199999999999914, 0.30599999999999916, 0.32899999999999974, 0.35299999999999976, 0.37599999999999945, 3.6834725804511397],[1.466, 0.096, 0.086, 0.075, 0.064, 0.055, 0.046, 0.038, 0.031, 0.024, 0.018, 0.012, 0.006, 0.002, 0.0, 0.0, 0.002, 0.006, 0.011, 0.016, 0.021, 0.026, 0.031, 0.037, 0.044, 0.052, 0.06, 0.069, 0.078, 0.088, 0.097, 1.466],2.306403976152674
15,[-3.5941588719521564, -0.3060000000000005, -0.2820000000000005, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11799999999999944, 0.14199999999999946, 0.16599999999999948, 0.18900000000000006, 0.21300000000000008, 0.23599999999999977, 0.2599999999999998, 0.2829999999999995, 0.3069999999999995, 0.3309999999999995, 0.3540000000000001, 0.3780000000000001, 0.4009999999999998, 3.6891588719521557],[1.457, 0.095, 0.085, 0.075, 0.066, 0.057, 0.048, 0.041, 0.031, 0.023, 0.017, 0.01, 0.005, 0.002, 0.0, 0.001, 0.001, 0.001, 0.004, 0.007, 0.011, 0.016, 0.022, 0.029, 0.036, 0.043, 0.05, 0.059, 0.068, 0.077, 0.086, 0.096, 1.457],2.2766667935315708
16,[-3.6725436482630056, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23599999999999977, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14200000000000035, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 0.35299999999999976, 3.6725436482630056],[1.473, 0.098, 0.083, 0.07, 0.058, 0.048, 0.04, 0.035, 0.033, 0.031, 0.028, 0.023, 0.018, 0.012, 0.005, 0.001, 0.0, 0.002, 0.005, 0.008, 0.012, 0.016, 0.02, 0.026, 0.032, 0.039, 0.046, 0.054, 0.063, 0.072, 0.081, 0.09, 1.473],2.3957215190027896
17,[-3.598987299076901, -0.3059999999999996, -0.2829999999999995, -0.25899999999999945, -0.23499999999999943, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.0699999999999994, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.3070000000000004, 0.33000000000000007, 0.3540000000000001, 0.37700000000000067, 3.6699872990769022],[1.456, 0.092, 0.082, 0.072, 0.063, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.01, 0.005, 0.002, 0.0, 0.0, 0.002, 0.004, 0.009, 0.014, 0.019, 0.025, 0.031, 0.038, 0.045, 0.052, 0.061, 0.07, 0.079, 0.089, 0.099, 1.456],2.29778546510921
19,[-3.658058366888767, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 3.634058366888767],[1.462, 0.093, 0.083, 0.073, 0.062, 0.053, 0.044, 0.038, 0.031, 0.025, 0.02, 0.015, 0.011, 0.007, 0.003, 0.001, 0.0, 0.0, 0.003, 0.007, 0.012, 0.019, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.083, 0.093, 1.462],2.3384060717833646
20,[-3.6738157262016484, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.3049999999999997, -0.28200000000000003, -0.258, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11800000000000033, 0.14199999999999946, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 3.627815726201648],[1.459, 0.093, 0.083, 0.074, 0.065, 0.058, 0.049, 0.042, 0.036, 0.031, 0.026, 0.022, 0.022, 0.019, 0.013, 0.007, 0.002, 0.0, 0.0, 0.001, 0.009, 0.017, 0.024, 0.03, 0.036, 0.041, 0.048, 0.056, 0.065, 0.074, 0.084, 0.094, 1.459],2.8303554018395176

Просто небольшая провокация: почему вы конкретно хотите использовать модель нейронной сети с таким небольшим количеством выборок? Переобучение произойдет очень легко, и точность, вероятно, не будет лучше, чем у более простой модели.

grfreitas 30.04.2024 10:14

Спасибо за ответ @grfreitas. Я пробовал использовать традиционные модели машинного обучения, но у меня возникли проблемы с реализацией профиля высоты, поскольку каждую точку нельзя рассматривать как отдельный объект в модели ML.

Seppe Vanheulenberghe 30.04.2024 10:53

Я не думаю, что вам нужно что-то реализовывать для профиля высоты — вы можете попробовать передать данные изображения непосредственно в модель (например, RandomForestRegressor). Например, если ваше изображение (каждый образец) имеет размер 28x28 пикселей в оттенках серого, вы должны развернуть (flatten) его в 784 признака (где каждый пиксель является признаком) и передать это в модель. Модель будет учитывать признаки совместно во время обучения. На странице примеров scikit-learn.org представлены различные демонстрации использования моделей с данными изображений. Я мог бы попробовать несколько моделей, если бы были доступны образцы данных.

MuhammedYunus 01.05.2024 00:02

Спасибо за ваш комментарий @MuhammedYunus! Вероятно, это неправильный способ, но я добавил несколько необработанных образцов в исходный пост. Можете ли вы дать мне ссылку на примеры, о которых вы говорите?

Seppe Vanheulenberghe 01.05.2024 08:04

Без проблем. На этой странице перечислены все примеры, а этот представляет собой пример классификации изображений с использованием случайных лесов (эта модель также доступна в качестве регрессора). Ниже я опубликовал ответ, в котором показано, как можно предоставить профили в качестве входных данных для моделей регрессии.

MuhammedYunus 02.05.2024 01:32
Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
0
5
111
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Из 15 предоставленных вами образцов можно увидеть тенденцию к Kt:

Более высокие значения кажутся более узкими и нерегулярными по сравнению с более низкими кривыми Kt, которые более широкие и плавные.

Я нормализовал каждый профиль по среднему значению его краев, предположив, что они служат ориентирами. Эта нормализация позволила мне отбросить крайние точки, что привело к меньшему количеству признаков и, следовательно, к меньшей вероятности переобучения модели на небольшом наборе данных.

Область выреза, примерно, отбирается с одинаковыми интервалами, хотя не все интервалы доступны для каждого образца (длины примерно от 31 до 35):

Я использовал информацию на рисунке выше, чтобы определить средние местоположения, общие для всех образцов, а затем повторно выполнил выборку данных по этой новой оси. Причина такой осторожности при повторной выборке заключается в том, что я хочу минимизировать искажение немногих доступных семплов.

После небольшой предварительной обработки и настройки я получил среднее значение MAE при кросс-валидации около 0,005 с использованием различных линейных моделей. Это может быть очень оптимистично, поскольку я использую в основном синтетические данные. Код ниже, если вы хотите попробовать больше данных.

             model       mae
rank                        
0          pls_reg  0.004795
1       linear_reg  0.005180
2       linear_svr  0.014007
3            ridge  0.019600
4              knn  0.029414
5        gradboost  0.030936
6     randomforest  0.037000

Я создал синтетические образцы, чтобы иметь больше возможностей для экспериментов при подборе моделей. Однако вы также можете использовать его в качестве метода дополнения, чтобы уменьшить переобучение вашего набора данных.


import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Identifiers of the 15 measured specimens whose profiles are listed below.
id_list = [1, 2, 3, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 19, 20]

positions_list = [
    [-3.6183010075758877, -0.30600000000000094, -0.28200000000000003, -0.25900000000000034, -0.2350000000000012, -0.21199999999999974, -0.1880000000000006, -0.16500000000000092, -0.14100000000000001, -0.11800000000000033, -0.0940000000000012, -0.07099999999999973, -0.0470000000000006, -0.02400000000000091, 0.0, 0.022999999999999687, 0.04699999999999882, 0.07099999999999973, 0.09399999999999942, 0.11800000000000033, 0.14100000000000001, 0.16499999999999915, 0.18799999999999883, 0.21199999999999974, 0.23499999999999943, 0.25900000000000034, 0.28200000000000003, 0.30599999999999916, 0.32899999999999885, 0.35299999999999976, 3.6653010075758865],
    [-3.7002299397640215, -0.40000000000000036, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.33000000000000007, 3.630229939764021],
    [-3.621845163453171, -0.35299999999999976, -0.32899999999999885, -0.30599999999999916, -0.28200000000000003, -0.25899999999999856, -0.23499999999999943, -0.21199999999999974, -0.18799999999999883, -0.16499999999999915, -0.14100000000000001, -0.11799999999999855, -0.09399999999999942, -0.07099999999999973, -0.04699999999999882, -0.023999999999999133, 0.0, 0.023000000000001464, 0.0470000000000006, 0.07000000000000028, 0.0940000000000012, 0.11700000000000088, 0.14100000000000001, 0.16500000000000092, 0.1880000000000006, 0.21200000000000152, 0.2350000000000012, 0.25900000000000034, 0.28200000000000003, 0.30600000000000094, 0.3290000000000006, 0.35300000000000153, 0.3760000000000012, 0.40000000000000036, 3.6688451634531716],
    [-3.643573085514529, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.6435730855145283],
    [-3.575987299076902, -0.28300000000000036, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 0.33000000000000007, 0.3540000000000001, 0.3769999999999998, 0.4009999999999998, 3.6939872990769014],
    [-3.682058366888767, -0.3769999999999998, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.30600000000000005, 3.6110583668887672],
    [-3.6484015126392757, -0.35300000000000065, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.02400000000000002, 0.04699999999999971, 0.07099999999999973, 0.09399999999999942, 0.11799999999999944, 0.14100000000000001, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 0.35299999999999976, 3.648401512639275],
    [-3.658058366888767, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18900000000000006, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11800000000000033, 0.14100000000000001, 0.16500000000000004, 0.1880000000000006, 0.21200000000000063, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.658058366888768],
    [-3.646987299076901, -0.3539999999999992, -0.3299999999999992, -0.30599999999999916, -0.2829999999999995, -0.25899999999999945, -0.23599999999999977, -0.21199999999999974, -0.18899999999999917, -0.16499999999999915, -0.14199999999999946, -0.11799999999999944, -0.09499999999999975, -0.07099999999999973, -0.047999999999999154, -0.023999999999999133, 0.0, 0.023000000000000576, 0.0470000000000006, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.1640000000000006, 0.1880000000000006, 0.2110000000000003, 0.23500000000000032, 0.258, 0.28200000000000003, 0.30600000000000005, 0.3290000000000006, 0.35300000000000065, 3.645987299076902],
    [-3.6134725804511403, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23600000000000065, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16499999999999915, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28199999999999914, 0.30599999999999916, 0.32899999999999974, 0.35299999999999976, 0.37599999999999945, 3.6834725804511397],
    [-3.5941588719521564, -0.3060000000000005, -0.2820000000000005, -0.25900000000000034, -0.23500000000000032, -0.21200000000000063, -0.1880000000000006, -0.16500000000000004, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.0470000000000006, -0.023000000000000576, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11799999999999944, 0.14199999999999946, 0.16599999999999948, 0.18900000000000006, 0.21300000000000008, 0.23599999999999977, 0.2599999999999998, 0.2829999999999995, 0.3069999999999995, 0.3309999999999995, 0.3540000000000001, 0.3780000000000001, 0.4009999999999998, 3.6891588719521557],
    [-3.6725436482630056, -0.35299999999999976, -0.33000000000000007, -0.30600000000000005, -0.28300000000000036, -0.25900000000000034, -0.23599999999999977, -0.21200000000000063, -0.18900000000000006, -0.16500000000000004, -0.14200000000000035, -0.11800000000000033, -0.0940000000000003, -0.07100000000000062, -0.0470000000000006, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.0699999999999994, 0.09399999999999942, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23499999999999943, 0.25899999999999945, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 0.35299999999999976, 3.6725436482630056],
    [-3.598987299076901, -0.3059999999999996, -0.2829999999999995, -0.25899999999999945, -0.23499999999999943, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11799999999999944, -0.09399999999999942, -0.0699999999999994, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.0470000000000006, 0.07100000000000062, 0.0940000000000003, 0.11800000000000033, 0.14200000000000035, 0.16500000000000004, 0.18900000000000006, 0.21200000000000063, 0.23600000000000065, 0.25900000000000034, 0.28300000000000036, 0.3070000000000004, 0.33000000000000007, 0.3540000000000001, 0.37700000000000067, 3.6699872990769022],
    [-3.658058366888767, -0.35299999999999976, -0.32899999999999974, -0.30600000000000005, -0.28200000000000003, -0.25900000000000034, -0.23500000000000032, -0.21199999999999974, -0.18799999999999972, -0.16500000000000004, -0.14100000000000001, -0.11800000000000033, -0.0940000000000003, -0.07099999999999973, -0.04699999999999971, -0.02400000000000002, 0.0, 0.022999999999999687, 0.04699999999999971, 0.07000000000000028, 0.0940000000000003, 0.11699999999999999, 0.14100000000000001, 0.16500000000000004, 0.18799999999999972, 0.21199999999999974, 0.23500000000000032, 0.25900000000000034, 0.28200000000000003, 0.30600000000000005, 0.32899999999999974, 3.634058366888767],
    [-3.6738157262016484, -0.37600000000000033, -0.35299999999999976, -0.32899999999999974, -0.3049999999999997, -0.28200000000000003, -0.258, -0.23500000000000032, -0.2110000000000003, -0.18799999999999972, -0.1639999999999997, -0.14100000000000001, -0.11699999999999999, -0.0940000000000003, -0.07000000000000028, -0.04699999999999971, -0.022999999999999687, 0.0, 0.02400000000000002, 0.04800000000000004, 0.07099999999999973, 0.09499999999999975, 0.11800000000000033, 0.14199999999999946, 0.16500000000000004, 0.18900000000000006, 0.21199999999999974, 0.23599999999999977, 0.25899999999999945, 0.2829999999999995, 0.30600000000000005, 0.33000000000000007, 3.627815726201648],
]
# Convert each per-sample position sequence to a numpy array for vector maths.
positions_list = list(map(np.array, positions_list))

heights_list = [
    [1.461, 0.089, 0.079, 0.069, 0.06, 0.052, 0.044, 0.035, 0.028, 0.02, 0.013, 0.007, 0.003, 0.001, 0.0, 0.0, 0.002, 0.005, 0.008, 0.014, 0.02, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.082, 0.091, 1.461],
    [1.466, 0.09899999999999998, 0.09199999999999997, 0.08399999999999996, 0.07699999999999996, 0.06999999999999995, 0.061999999999999944, 0.05399999999999994, 0.04499999999999993, 0.03700000000000003, 0.030000000000000027, 0.025000000000000022, 0.020000000000000018, 0.016000000000000014, 0.015000000000000013, 0.009000000000000008, 0.006000000000000005, 0.0010000000000000009, 0.0, 0.0030000000000000027, 0.010000000000000009, 0.015000000000000013, 0.025000000000000022, 0.030000000000000027, 0.03700000000000003, 0.04200000000000004, 0.04699999999999993, 0.051999999999999935, 0.061999999999999944, 0.07099999999999995, 0.07999999999999996, 0.08899999999999997, 0.09599999999999997, 1.466],
    [1.453, 0.09899999999999998, 0.08899999999999997, 0.07899999999999996, 0.07000000000000006, 0.06000000000000005, 0.052000000000000046, 0.04500000000000004, 0.039000000000000035, 0.03300000000000003, 0.027000000000000024, 0.02200000000000002, 0.016000000000000014, 0.010000000000000009, 0.006000000000000005, 0.0020000000000000018, 0.0, 0.0, 0.0010000000000000009, 0.0050000000000000044, 0.007000000000000006, 0.007000000000000006, 0.01100000000000001, 0.014000000000000012, 0.018000000000000016, 0.02200000000000002, 0.028000000000000025, 0.03500000000000003, 0.04400000000000004, 0.05300000000000005, 0.062000000000000055, 0.07200000000000006, 0.08199999999999996, 0.09199999999999997, 1.453],
    [1.462, 0.099, 0.09, 0.081, 0.072, 0.063, 0.053, 0.044, 0.035, 0.027, 0.02, 0.014, 0.009, 0.005, 0.002, 0.001, 0.0, 0.002, 0.006, 0.012, 0.017, 0.024, 0.03, 0.037, 0.043, 0.05, 0.057, 0.064, 0.072, 0.081, 0.089, 0.098, 1.462],
    [1.458, 0.094, 0.084, 0.074, 0.064, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.009, 0.003, 0.0, 0.0, 0.002, 0.006, 0.011, 0.017, 0.023, 0.028, 0.032, 0.037, 0.041, 0.045, 0.05, 0.055, 0.061, 0.069, 0.079, 0.089, 1.458],
    [1.461, 0.092, 0.083, 0.074, 0.065, 0.056, 0.048, 0.042, 0.036, 0.03, 0.025, 0.02, 0.015, 0.01, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.01, 0.016, 0.024, 0.033, 0.04, 0.048, 0.056, 0.064, 0.073, 0.081, 0.091, 1.461],
    [1.459, 0.094, 0.084, 0.075, 0.066, 0.056, 0.047, 0.039, 0.031, 0.023, 0.017, 0.011, 0.008, 0.005, 0.002, 0.001, 0.0, 0.0, 0.002, 0.006, 0.01, 0.015, 0.021, 0.027, 0.033, 0.04, 0.048, 0.057, 0.066, 0.076, 0.086, 0.097, 1.459],
    [1.462, 0.093, 0.084, 0.074, 0.064, 0.054, 0.045, 0.038, 0.032, 0.027, 0.023, 0.018, 0.012, 0.006, 0.003, 0.001, 0.0, 0.003, 0.007, 0.012, 0.017, 0.02, 0.024, 0.029, 0.035, 0.042, 0.051, 0.06, 0.069, 0.078, 0.088, 0.097, 1.462],
    [1.46, 0.096, 0.087, 0.079, 0.071, 0.062, 0.053, 0.044, 0.037, 0.029, 0.022, 0.016, 0.011, 0.006, 0.003, 0.001, 0.0, 0.002, 0.005, 0.008, 0.013, 0.017, 0.022, 0.028, 0.034, 0.042, 0.05, 0.059, 0.068, 0.078, 0.088, 0.098, 1.46],
    [1.466, 0.096, 0.086, 0.075, 0.064, 0.055, 0.046, 0.038, 0.031, 0.024, 0.018, 0.012, 0.006, 0.002, 0.0, 0.0, 0.002, 0.006, 0.011, 0.016, 0.021, 0.026, 0.031, 0.037, 0.044, 0.052, 0.06, 0.069, 0.078, 0.088, 0.097, 1.466],
    [1.457, 0.095, 0.085, 0.075, 0.066, 0.057, 0.048, 0.041, 0.031, 0.023, 0.017, 0.01, 0.005, 0.002, 0.0, 0.001, 0.001, 0.001, 0.004, 0.007, 0.011, 0.016, 0.022, 0.029, 0.036, 0.043, 0.05, 0.059, 0.068, 0.077, 0.086, 0.096, 1.457],
    [1.473, 0.098, 0.083, 0.07, 0.058, 0.048, 0.04, 0.035, 0.033, 0.031, 0.028, 0.023, 0.018, 0.012, 0.005, 0.001, 0.0, 0.002, 0.005, 0.008, 0.012, 0.016, 0.02, 0.026, 0.032, 0.039, 0.046, 0.054, 0.063, 0.072, 0.081, 0.09, 1.473],
    [1.456, 0.092, 0.082, 0.072, 0.063, 0.055, 0.046, 0.038, 0.03, 0.023, 0.016, 0.01, 0.005, 0.002, 0.0, 0.0, 0.002, 0.004, 0.009, 0.014, 0.019, 0.025, 0.031, 0.038, 0.045, 0.052, 0.061, 0.07, 0.079, 0.089, 0.099, 1.456],
    [1.462, 0.093, 0.083, 0.073, 0.062, 0.053, 0.044, 0.038, 0.031, 0.025, 0.02, 0.015, 0.011, 0.007, 0.003, 0.001, 0.0, 0.0, 0.003, 0.007, 0.012, 0.019, 0.026, 0.033, 0.041, 0.048, 0.056, 0.064, 0.073, 0.083, 0.093, 1.462],
    [1.459, 0.093, 0.083, 0.074, 0.065, 0.058, 0.049, 0.042, 0.036, 0.031, 0.026, 0.022, 0.022, 0.019, 0.013, 0.007, 0.002, 0.0, 0.0, 0.001, 0.009, 0.017, 0.024, 0.03, 0.036, 0.041, 0.048, 0.056, 0.065, 0.074, 0.084, 0.094, 1.459],
]
# Convert each per-sample height sequence to a numpy array for vector maths.
heights_list = list(map(np.array, heights_list))

# Target values: measured stress-concentration factor Kt, one per specimen,
# in the same order as positions_list / heights_list.
kt_list = [2.2766371542377914, 2.616131437456064, 2.142634834792794, 2.3949189992310247, 2.502224358908798, 2.3797426230814387, 2.1552667710791282, 2.490420148832229, 2.334695259830575, 2.306403976152674, 2.2766667935315708, 2.3957215190027896, 2.29778546510921, 2.3384060717833646, 2.8303554018395176,]

#View data: plot every profile coloured by its Kt, with zoomed edge/centre panels.
from matplotlib.colors import Normalize, CenteredNorm
from matplotlib.cm import ScalarMappable
from matplotlib.gridspec import GridSpec

n_samples = len(positions_list)
vmin = min(kt_list)
vmax = max(kt_list)
#Centre the colour map on the median Kt so white = median in 'coolwarm'.
#(A plain Normalize(vmin, vmax) is not needed for the plot below.)
centred_norm = CenteredNorm(vcenter=np.median(kt_list), halfrange=0.5 * (vmax - vmin))

colour_kt = ScalarMappable(centred_norm, 'coolwarm')

f = plt.figure(figsize=(11, 3))
gs = GridSpec(1, 3, width_ratios=[0.5, 0.5, 3], height_ratios=[1])
ax_xleft = f.add_subplot(gs[0])
ax_xright = f.add_subplot(gs[1])
ax_xcentre = f.add_subplot(gs[2])

for ax in [ax_xleft, ax_xright, ax_xcentre]:
    #Every panel shows all samples; only the axis limits differ per panel.
    for positions, heights, kt in zip(positions_list, heights_list, kt_list):
        #ms (markersize) must be numeric, not the string '10'
        ax.plot(positions, heights, marker='.', ms=10, linewidth=2,
                color=colour_kt.to_rgba(kt))

    if ax is ax_xleft:
        ax.set(ylabel='height')
    if ax is ax_xright:
        ax.tick_params(labelleft=False, left=False)
        ax.spines.left.set_visible(False)

    ax.spines[['top', 'right']].set_visible(False)
    ax.set_title(
        'left' if ax is ax_xleft else ('right' if ax is ax_xright else 'centre')
    )
    ax.set_xlabel('position')

    if ax in [ax_xleft, ax_xright]:
        #Zoom onto the tall reference edges of the scans
        x_lims = [-3.71, -3.55] if ax is ax_xleft else [3.55, 3.71]
        y_lims = [1.45, 1.475]
    else:
        #Alternative wider view of the general notch area:
        #x_lims, y_lims = [-0.41, 0.42], [-0.005, 0.105]
        x_lims = [-0.2, 0.2]
        y_lims = [-0.005, 0.04] #notch peak
    ax.set(xlim=x_lims, ylim=y_lims)

#Colorbar on right. A [left, bottom, width, height] rect must go to
#Figure.add_axes, not add_subplot (which takes a SubplotSpec/position code).
ax_pos = ax_xcentre.get_position()
cax = f.add_axes([
    ax_pos.x0 + ax_pos.width * 1.05, ax_pos.y0, ax_pos.width / 20, ax_pos.height
])
f.colorbar(cax=cax, mappable=colour_kt, label='Kt\n(white = median Kt)')

#Synthesise data for testing
def synthesise_samples(n_samples, positions_list, heights_list, kt_list, rng=None):
    """Generate synthetic samples by mixup-style interpolation of random pairs.

    Each new sample is a random convex combination (alpha drawn uniformly
    from [0, 1)) of two distinct existing profiles, after resampling the
    first profile onto the second profile's x-axis.  The target Kt is mixed
    with the same alpha, plus Gaussian noise whose std is 1/30 of the sample
    std of ``kt_list``.

    Parameters
    ----------
    n_samples : int
        Number of synthetic samples to generate.
    positions_list, heights_list : list of np.ndarray
        Per-sample x positions (must be increasing, as required by
        ``np.interp``) and the corresponding heights.
    kt_list : sequence of float
        Per-sample target Kt values.
    rng : optional
        Randomness source exposing ``choice``, ``uniform`` and ``normal``
        (e.g. ``np.random.default_rng(...)``).  Defaults to the legacy
        global ``np.random`` state, so existing ``np.random.seed`` calls
        keep working.

    Returns
    -------
    tuple of three lists
        (new_positions, new_heights, new_kt), each of length ``n_samples``.
    """
    if rng is None:
        rng = np.random  # backward compatible: honours np.random.seed(...)

    new_positions = []
    new_heights_list = []
    new_kts = []

    sample_idxs = np.arange(len(positions_list))
    kt_noise_scale = np.std(kt_list, ddof=1) / 30  # loop-invariant, hoisted

    for _ in range(n_samples):
        #Randomly select two distinct samples
        idx_i = rng.choice(sample_idxs)
        idx_j = rng.choice(sample_idxs[sample_idxs != idx_i])

        #Resample sample i onto sample j's axis so the two can be mixed
        heights_interp = np.interp(
            positions_list[idx_j], positions_list[idx_i], heights_list[idx_i]
        )

        #Random convex combination of the two profiles in feature space
        alpha = rng.uniform()
        mixed_heights = alpha * heights_interp + (1 - alpha) * heights_list[idx_j]

        #Same combination for the target, with a little noise added
        mixed_kt = alpha * kt_list[idx_i] + (1 - alpha) * kt_list[idx_j]
        mixed_kt += rng.normal() * kt_noise_scale

        #Store the new sample (it lives on sample j's axis)
        new_positions.append(positions_list[idx_j])
        new_heights_list.append(mixed_heights)
        new_kts.append(mixed_kt)

    return new_positions, new_heights_list, new_kts

# Pad the dataset out to 150 samples with synthetic ones (matching the size
# of the asker's real dataset).
positions_synth, heights_synth, kt_synth = synthesise_samples(
    150 - len(positions_list), positions_list, heights_list, kt_list
)

#Combine with the original data (mutates the original lists in place)
positions_list.extend(positions_synth)
heights_list.extend(heights_synth)
kt_list.extend(kt_synth)

n_samples = len(positions_list)

#Split the data
from sklearn.model_selection import train_test_split
import pandas as pd

test_frac = 0.15
train_frac = 1 - test_frac

# Stratify on Kt binned into 5 intervals so train and test cover the same
# range of target values.
kt_binned = pd.cut(kt_list, bins=5)
train_ixs, test_ixs = train_test_split(
    np.arange(n_samples), test_size=test_frac, stratify=kt_binned, random_state=0
)

print(
    f'Train/test sizes: {train_ixs.size}/{test_ixs.size}',
    f'| fractions: {train_frac:.2f}/{test_frac:.2f}'
)

# Training-only views of the data, used below so the common resampling axis
# is derived from the training samples alone (no test-set leakage).
positions_list_trn = [positions_list[i] for i in train_ixs]
heights_list_trn = [heights_list[i] for i in train_ixs]
kt_list_trn = [kt_list[i] for i in train_ixs]

#Assess sampling consistency and resample onto a common axis
#Normalise each sample's heights using the mean of its edge points
# NOTE(review): this rebinds heights_list to new normalised arrays, so
# heights_list_trn (built above) still holds the un-normalised heights.
# heights_list_trn appears unused below, so this looks harmless — confirm.
heights_list = [heights / np.mean(heights[[0, -1]]) for heights in heights_list]

#Visualise the distribution of x sampling
f, ax = plt.subplots(figsize=(11, 2))
ax.set_xlim(-0.45, 0.45) #look at notch
ax.set_xlabel('positions')
ax.tick_params(left=False, labelleft=False)
ax.spines[['top', 'right', 'left']].set_visible(False)
ax.spines.bottom.set_bounds(-0.4, 0.4)
ax.set_title('sampling consistency (black) and average locations (red)', fontsize=11)

# One row of tick marks per training sample (first 15 only)
for sample_idx, sample in enumerate(positions_list_trn[:15]):
    ax.scatter(sample, sample_idx * np.ones_like(sample), marker='|', c='darkslategray')

#Find the average sampling positions, in order to resample onto them
# Mean spacing between consecutive interior points, pooled over all
# training samples (edges excluded via sample[1:-1]).
avg_step_size = np.mean(
    [delta for sample in positions_list_trn for delta in np.diff(sample[1:-1])]
)

#The start and end points of the axis to resample onto. Dropping the edges.
start_pos = min(pos for sample in positions_list_trn for pos in sample[1:-1])
end_pos = max(pos for sample in positions_list_trn for pos in sample[1:-1])

max_seq_len = max(map(len, positions_list_trn))
pos_fine = np.linspace(start_pos, end_pos, num=max_seq_len * 100)

# NaN-filled grid: each sample's interior positions are snapped to their
# nearest fine-grid column, so nanmean per column averages only the samples
# that actually have a point near that location.
positions_interp = np.empty([n_samples, pos_fine.size]) * np.nan

for sample_idx, positions in enumerate(positions_list):
    for pos_idx, pos in enumerate(positions[1:-1]):
        match_idx = np.argmin(np.abs(pos_fine - pos))
        positions_interp[sample_idx, match_idx] = pos

# Column-wise averages, then drop empty columns and merge near-duplicate
# locations (closer than a quarter of the average step size).
avg_positions = np.nanmean(positions_interp, axis=0)
avg_positions = avg_positions[~np.isnan(avg_positions)]
avg_positions = avg_positions[
    np.argwhere(np.abs(np.diff(avg_positions, prepend=1e3)) > avg_step_size / 4)
].ravel()

#Visualise results
[ax.axvline(pos, ymax=0.05, color='red', linewidth=3, alpha=0.4) for pos in avg_positions];

#Resample heights onto the average sampling positions
# Every sample now becomes a fixed-length feature vector of interpolated heights.
heights_resampled = np.zeros([n_samples, avg_positions.size])
for sample_idx, (positions, heights) in enumerate(zip(positions_list, heights_list)):
    heights_resampled[sample_idx, :] = np.interp(avg_positions, positions, heights)

#Use CV to assess various models and view results
from sklearn.preprocessing import StandardScaler

from sklearn.linear_model import LinearRegression, Ridge
from sklearn.svm import LinearSVR
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.cross_decomposition import PLSRegression
from sklearn.neighbors import KNeighborsRegressor

from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_validate, LeaveOneOut

np.random.seed(0)

# Feature matrix / targets indexed by the stratified split indices
X_train, X_test = heights_resampled[train_ixs], heights_resampled[test_ixs]
y_train, y_test = np.array(kt_list)[train_ixs], np.array(kt_list)[test_ixs]

# Candidate regressors, each fitted inside a StandardScaler pipeline
models_dict = {
    'linear_reg': LinearRegression(),
    'ridge': Ridge(),
    'linear_svr': LinearSVR(C=0.05, dual='auto', max_iter=2000),
    'randomforest': RandomForestRegressor(min_samples_split=8),
    'gradboost': GradientBoostingRegressor(),
    'pls_reg': PLSRegression(n_components=15),
    'knn': KNeighborsRegressor(weights='distance'),
}

# Leave-one-out CV: collect one (model, mean MAE) row per candidate
score_rows = []
for name, model in models_dict.items():
    print(name, '...')
    cv_results = cross_validate(
        make_pipeline(StandardScaler(), model),
        X_train, y_train,
        scoring='neg_mean_absolute_error',
        cv=LeaveOneOut(),
        n_jobs=-1,
    )
    score_rows.append({'model': name, 'mae': -cv_results['test_score'].mean()})

# Rank models from best (lowest MAE) to worst
results_df = (
    pd.DataFrame(score_rows)
    .sort_values(by='mae')
    .reset_index(drop=True)
    .rename_axis(index='rank')
)

display(
    results_df
    .style
    .format(precision=4)
    .background_gradient(subset=['mae'], cmap='plasma')
    .set_caption('CV validation scores')
)

ax = results_df.plot(kind='bar', x='model', ylabel='mae', legend=False)
ax.figure.set_size_inches(4, 2)

Спасибо за такой развернутый ответ, это именно то, что мне нужно!

Seppe Vanheulenberghe 04.05.2024 07:18

С удовольствием @SeppeVanheulenberghe. Любые вопросы просто задавайте.

MuhammedYunus 04.05.2024 11:20

Другие вопросы по теме