Почему я несу огромные потери в этой модели преобразования цвета, которую я создаю?

У меня есть набор данных из более чем 200 тысяч цветовых пятен, снятых на двух разных носителях, для которых я создаю преобразование цвета. Изначально я сделал в нейронной сети прямой ввод-вывод RGB-в-RGB. Это работает достаточно хорошо, но я хотел использовать пространство яркости-цветности для выполнения сопоставления, чтобы потенциально лучше преобразовать отношения яркости и цветового контраста. Хотя изначально я делал это в CIELAB и YCbCr, преобразование набора данных в любое пространство в конечном итоге является неточным, поскольку данные представляют данные сцены HDR в логарифмическом контейнере, и ни одно из пространств не создано для представления сцены HDR. Поэтому я пытаюсь использовать пространство ICtCP Dolby, которое построено на основе неограниченной линейной информации сцены. Я выполнил преобразование в пространство и подтвердил правильность вывода и структуры массива. Однако после подачи переменных в сеть она сразу же начнет давать астрономические потери, прежде чем переключиться на потери inf или nan. Я не могу понять, в чем проблема.

Я использую библиотеку цветоведения для внутренних преобразований цвета и протестировал как пользовательскую потерю специально для пространства ICtCP, так и mse, встроенную в TF (чтобы убедиться, что это не проблема с форматированием). Оба принесли мне огромную потерю. Я также распечатал значения RGB и ICtCP в текстовые файлы, чтобы убедиться в отсутствии значений, выходящих за пределы диапазона, но проблема была не в этом. Значения RGB находятся в диапазоне 0–1, а значения ICtCp — в диапазоне I (0:1), Ct (-1:1), Cp (-1:1).

Мои функции преобразования цвета внутри и вне ICtCP

#Davinci Wide Gamut Intermediate to Dolby ICtCP HDR opponent space
def DWG_TO_ITP(rgb_values):
    cs = colour.models.RGB_COLOURSPACE_DAVINCI_WIDE_GAMUT
    
    #DWG DI to XYZ Linear
    xyzLin = colour.RGB_to_XYZ(rgb_values, cs.whitepoint, cs.whitepoint, cs.matrix_RGB_to_XYZ, cctf_decoding=cs.cctf_decoding)
    
    #XYZ to ICtCp
    ictcp = colour.XYZ_to_ICtCp(xyzLin)
    
    return ictcp

# Dolby ICtCp HDR opponent space to Davinci Wide Gamut Intermediate
def ITP_TO_DWG(itp_values):

    cs = colour.models.RGB_COLOURSPACE_DAVINCI_WIDE_GAMUT
    
    #ICtCp to XYZ
    xyzLin = colour.ICtCp_to_XYZ(itp_values)
    
    #XYZ Linear to DWG DI
    dwg = colour.XYZ_to_RGB(xyzLin, cs.whitepoint, cs.whitepoint, cs.matrix_XYZ_to_RGB, cctf_encoding=cs.cctf_encoding)
    
    return dwg

Пользовательский убыток (в настоящее время не активен)

def ITP_loss(y_true, y_pred):
    
    # Split the ICtCp values into I, T, and P components
    I_1, T_1, P_1 = tf.split(y_true, 3, axis=-1)
    I_2, T_2, P_2 = tf.split(y_pred, 3, axis=-1)

    
    # Adjust the T components as in the original delta_E_ITP function
    T_1 = T_1 * 0.5
    T_2 = T_2 * 0.5

    # Compute the squared differences
    d_E_ITP = 720 * tf.sqrt(
        tf.square(I_2 - I_1) +
        tf.square(T_2 - T_1) +
        tf.square(P_2 - P_1)
    )
    
    # Return the mean error as the loss
    return tf.reduce_mean(d_E_ITP)

Моя нейронная сеть

def transform_nn(combined_rgb_values, output_callback, epochs=10000, batch_size=32):
    source_rgb = np.vstack([rgb_pair[0] for rgb_pair in combined_rgb_values])
    target_rgb = np.vstack([rgb_pair[1] for rgb_pair in combined_rgb_values])

    source_itp = DWG_TO_ITP(source_rgb)
    target_itp = DWG_TO_ITP(target_rgb)
    
    # Neural network base model with L2 regularization
    alpha = 0  # no penalty for now
    model = keras.Sequential([
        keras.layers.Input(shape=(3,)),
        keras.layers.Dense(128, activation = 'gelu', kernel_regularizer = keras.regularizers.L2(alpha)),
        keras.layers.Dense(64, activation = 'gelu', kernel_regularizer = keras.regularizers.L2(alpha)),
        keras.layers.Dense(32, activation = 'gelu', kernel_regularizer = keras.regularizers.L2(alpha)),
        keras.layers.Dense(3,)
    ])

    # Model optimization with Adam
    adam_optimizer = keras.optimizers.Adam(learning_rate=0.001)
    model.compile(
        optimizer= adam_optimizer,
        loss= "mean_squared_error",
        metrics=['mean_squared_error'])
    
    #normal
    early_stopping_norm = EarlyStopping(
        monitor = 'val_loss',
        patience = 30,
        verbose=1,
        restore_best_weights=True
    )
    
    # Train without early stopping
    history = model.fit(x=source_itp, y=target_itp,
                        epochs=epochs, batch_size=batch_size, 
                        verbose = "auto", validation_split=0.3, 
                        callbacks=[early_stopping_norm])
    
    def interpolator(input_rgb):
        input_itp = DWG_TO_ITP(input_rgb)
        output_itp = model.predict(input_itp)
        output_rgb = ITP_TO_DWG(output_itp)
        return output_rgb
    
    return interpolator

И, наконец, тип потерь, которые я получаю. Примечание: это потери «mean_squared_error», встроенные в компилятор, но аналогичные экстремальные значения входят в состав пользовательских потерь. Я никогда не сталкивался с этой проблемой ни при реализации CIELAB, ни при реализации YCbCr.

Epoch 1/10000
  70/9078 [..............................] - ETA: 6s - loss: 19151210161612029119172287351962936121302040109299793920.0000 - mean_squared_error: 1915121016161202911917228735196293612130 150/9078 [..............................] - ETA: 6s - loss: 8937231408752302941104862160146934414914780835554000896.0000 - mean_squared_error: 89372314087523029411048621601469344149147 236/9078 [..............................] - ETA: 5s - loss: 8411239422438024050387858001836140461389620391983448064.0000 - mean_squared_error: 84112394224380240503878580018361404613896 322/9078 [>.............................] - ETA: 5s - loss: 55694874365583834449267799576553768559551931724848365789378071082067252634355826658906428848197067342214382161787617280.0000 407/9078 [>.............................] - ETA: 5s - loss: 9272320170949610945087897503565859983725183487173275717008470165482614622395441710684957926712521227412477496744314184602899 494/9078 [>.............................] - ETA: 5s - loss: inf - mean_squared_error: inf                                                                                               9078/9078 [==============================] - 7s 686us/step - loss: nan - mean_squared_error: nan - val_loss: nan - val_mean_squared_error: nan                                         
Epoch 2/10000
9078/9078 [==============================] - 6s 682us/step - loss: nan - mean_squared_error: nan - val_loss: nan - val_mean_squared_error: nan
Epoch 3/10000
8987/9078 [============================>.] - ETA: 0s - loss: nan - mean_squared_error: nan%
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
0
54
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Я подозреваю, что это потому, что преобразование ICtCp в XYZ может привести, например, к NaN(0.24, -0.42, 0.48). В этом случае S' ~= -0.15 и с помощью передаточной функции PQ (EOTF) мы пытаемся возвести это в степень (1 / 78.84375), что неясно, как с этим справиться.

Ого, окей. Даже не рассматривал такую ​​возможность. Хотя я немного озадачен тем, как была разработана эта спецификация, если такой случай мог произойти. Если только этого не произойдет с 10/12-битными целочисленными сигналами, которые фактически будут использовать его, в отличие от обработки с плавающей запятой для операций сопоставления. Тем не менее, я переключил метод в функции на использование передаточной функции HLG, и, похоже, он работает должным образом, по крайней мере, в отчетах об ошибках и ожидаемых значениях. Спасибо!

Errick Jackson 13.08.2024 20:07

Это сбило меня с толку, когда я пытался протестировать свой код с помощью двусторонних преобразований. Существует удивительное количество цветовых пространств, которые могут привести к NaN, например ICtCp, JzAzBz, CAM02, CAM16, и это не очень хорошо!

Wacton 14.08.2024 10:34

Другие вопросы по теме