Я хочу точно настроить BERT для конкретного набора данных. Моя проблема в том, что я не хочу случайным образом маскировать некоторые токены моего набора обучающих данных, но я уже выбрал, какие токены хочу замаскировать (по определенным причинам).
Для этого я создал набор данных с двумя столбцами: text
, в котором некоторые токены были заменены на [MASK]
(мне известно, что некоторые слова можно маркировать более чем одним токеном, и я позаботился об этом) и label
где у меня есть весь текст.
Теперь я хочу точно настроить модель BERT (скажем, bert-base-uncased), используя библиотеку transformers
Hugging Face, но я не хочу использовать DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.2)
, где маскирование выполняется случайным образом, и я могу контролировать только вероятность. Что я могу сделать?
возможно, вам нужна специфичная для домена адаптация BERT. до сих пор я также не смог найти настроенную маскировку. но я нашел эту статью полезной PERL: адаптация предметной области на основе Pivot для предварительно обученных глубоких Контекстуализированные модели внедрения если у кого-нибудь есть способ настроить маскировку для BertForMaskedLM пожалуйста, помогите
Это то, что я сделал, чтобы решить свою проблему. Я создал собственный класс и изменил токенизацию так, как мне было нужно (замаскировать один из числовых диапазонов во входных данных).
class CustomDataCollator(DataCollatorForLanguageModeling):
mlm: bool = True
return_tensors: str = "pt"
def __post_init__(self):
if self.mlm and self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary "
"for masked language modeling. You should pass `mlm=False` to "
"train on causal language modeling instead."
)
def torch_mask_tokens(self, inputs, special_tokens_mask):
"""
Prepare masked tokens inputs/labels for masked language modeling.
NOTE: keep `special_tokens_mask` as an argument for avoiding error
"""
# labels is batch_size x length of the sequence tensor
# with the original token id
# the length of the sequence includes the special tokens (2)
labels = inputs.clone()
batch_size = inputs.size(0)
# seq_len = inputs.size(1)
# in each seq, find the indices of the tokens that represent digits
dig_ids = [1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023]
dig_idx = torch.zeros_like(labels)
for dig_id in dig_ids:
dig_idx += (labels == dig_id)
dig_idx = dig_idx.bool()
# in each seq, find the spans of Trues using `find_spans` function
spans = []
for i in range(batch_size):
spans.append(find_spans(dig_idx[i].tolist()))
masked_indices = torch.zeros_like(labels)
# spans is a list of lists of tuples
# in each tuple, the first element is the start index
# and the second element is the length
# in each child list, choose a random tuple
for i in range(batch_size):
if len(spans[i]) > 0:
idx = torch.randint(0, len(spans[i]), (1,))
start, length = spans[i][idx[0]]
masked_indices[i, start:start + length] = 1
else:
print("No digit found in the sequence!")
masked_indices = masked_indices.bool()
# We only compute loss on masked tokens
labels[~masked_indices] = -100
# change the input's masked_indices to self.tokenizer.mask_token
inputs[masked_indices] = self.tokenizer.mask_token_id
return inputs, labels
def find_spans(lst):
spans = []
for k, g in groupby(enumerate(lst), key=itemgetter(1)):
if k:
glist = list(g)
spans.append((glist[0][0], len(glist)))
return spans