Точная настройка BERT с детерминированной маскировкой вместо случайной маскировки

Я хочу точно настроить BERT для конкретного набора данных. Моя проблема в том, что я не хочу случайным образом маскировать некоторые токены моего набора обучающих данных, но я уже выбрал, какие токены хочу замаскировать (по определенным причинам).

Для этого я создал набор данных с двумя столбцами: text, в котором некоторые токены были заменены на [MASK] (мне известно, что некоторые слова можно маркировать более чем одним токеном, и я позаботился об этом) и label где у меня есть весь текст.

Теперь я хочу точно настроить модель BERT (скажем, bert-base-uncased), используя библиотеку transformers Hugging Face, но я не хочу использовать DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.2), где маскирование выполняется случайным образом, и я могу контролировать только вероятность. Что я могу сделать?

Тонкая настройка GPT-3 с помощью Anaconda
Тонкая настройка GPT-3 с помощью Anaconda
Зарегистрируйте аккаунт Open ai, а затем получите ключ API ниже.
0
0
124
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

возможно, вам нужна специфичная для домена адаптация BERT. до сих пор я также не смог найти настроенную маскировку. но я нашел эту статью полезной PERL: адаптация предметной области на основе Pivot для предварительно обученных глубоких Контекстуализированные модели внедрения если у кого-нибудь есть способ настроить маскировку для BertForMaskedLM пожалуйста, помогите

Ответ принят как подходящий

Это то, что я сделал, чтобы решить свою проблему. Я создал собственный класс и изменил токенизацию так, как мне было нужно (замаскировать один из числовых диапазонов во входных данных).

class CustomDataCollator(DataCollatorForLanguageModeling):

    mlm: bool = True
    return_tensors: str = "pt"

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary "
                "for masked language modeling. You should pass `mlm=False` to "
                "train on causal language modeling instead."
            )

    def torch_mask_tokens(self, inputs, special_tokens_mask):
        """
        Prepare masked tokens inputs/labels for masked language modeling.
        NOTE: keep `special_tokens_mask` as an argument for avoiding error
        """

        # labels is batch_size x length of the sequence tensor
        # with the original token id
        # the length of the sequence includes the special tokens (2)
        labels = inputs.clone()

        batch_size = inputs.size(0)
        # seq_len = inputs.size(1)
        # in each seq, find the indices of the tokens that represent digits
        dig_ids = [1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023]
        dig_idx = torch.zeros_like(labels)
        for dig_id in dig_ids:
            dig_idx += (labels == dig_id)
        dig_idx = dig_idx.bool()
        # in each seq, find the spans of Trues using `find_spans` function
        spans = []
        for i in range(batch_size):
            spans.append(find_spans(dig_idx[i].tolist()))
        masked_indices = torch.zeros_like(labels)
        # spans is a list of lists of tuples
        # in each tuple, the first element is the start index
        # and the second element is the length
        # in each child list, choose a random tuple
        for i in range(batch_size):
            if len(spans[i]) > 0:
                idx = torch.randint(0, len(spans[i]), (1,))
                start, length = spans[i][idx[0]]
                masked_indices[i, start:start + length] = 1
            else:
                print("No digit found in the sequence!")
        masked_indices = masked_indices.bool()

        # We only compute loss on masked tokens
        labels[~masked_indices] = -100

        # change the input's masked_indices to self.tokenizer.mask_token
        inputs[masked_indices] = self.tokenizer.mask_token_id

        return inputs, labels

def find_spans(lst):
    spans = []
    for k, g in groupby(enumerate(lst), key=itemgetter(1)):
        if k:
            glist = list(g)
            spans.append((glist[0][0], len(glist)))

    return spans

Другие вопросы по теме