# 1. Define the task (定义任务)
from openprompt.data_utils import InputExample

# Class names double as label indices: 0 -> 'negative', 1 -> 'positive'.
classes = ['negative', 'positive']

# Zero-shot examples: only raw text is supplied, no gold labels.
_texts = [
    "Albert Einstein was one of the greatest intellects of his time.",
    "The film was badly made.",
]
dataset = [
    InputExample(guid=i, text_a=text)
    for i, text in enumerate(_texts)
]
# 2. Define the pretrained language model (定义预训练语言模型)
# Load a pretrained masked LM (BERT) together with its tokenizer, model
# config, and the tokenizer-wrapper class OpenPrompt needs for this
# model family (used later by PromptDataLoader).
from openprompt.plms import load_plm
plm,tokenizer,model_config,WrapperClass=load_plm('bert',"bert-base-cased")
# 3. Define the prompt template (定义prompt模板)
from openprompt.prompts import ManualTemplate

# Hand-written template: {"placeholder":"text_a"} is replaced by the input
# sentence, and {"mask"} is the slot the masked LM fills in.
_template_text = '{"placeholder":"text_a"} It was {"mask"}'
promptTemplate = ManualTemplate(tokenizer=tokenizer, text=_template_text)
# 4. Define the output-to-label mapping, i.e. the verbalizer (定义输出-label映射)
from openprompt.prompts import ManualVerbalizer

# Verbalizer: maps the MLM's predicted words at the {"mask"} position
# onto the task's class labels.
_label_words = {
    'negative': ['bad'],
    'positive': ['good', 'wonderful', 'great'],
}
promptVerbalizer = ManualVerbalizer(
    tokenizer=tokenizer,
    classes=classes,
    label_words=_label_words,
)
# 5. Combine the pieces into a PromptModel (组合构建为PromptModel类)
from openprompt import PromptForClassification

# Glue the PLM, the template and the verbalizer into one classifier
# whose forward pass returns per-class logits.
promptModel = PromptForClassification(
    plm=plm,
    template=promptTemplate,
    verbalizer=promptVerbalizer,
)
# 6. Define the dataloader (定义dataloader)
from openprompt import PromptDataLoader

# Dataloader that wraps each example with the template and tokenizes it
# using the wrapper class matching the chosen PLM.
data_loader = PromptDataLoader(
    template=promptTemplate,
    dataset=dataset,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass,
)
# 7. Training / testing — here: zero-shot inference only (开始训练、测试)
# making zero-shot inference using pretrained MLM with prompt
import torch  # bug fix: torch was used below but never imported

promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)           # (batch, num_classes) scores
        preds = torch.argmax(logits, dim=-1)  # best class index per example
        # .tolist() converts the tensor to plain ints, so list indexing is
        # safe for any batch size (the original relied on a 1-element
        # tensor's __index__ and would fail for batch sizes > 1).
        for p in preds.tolist():
            print(classes[p])
# Expected predictions: 1 ('positive') for example 0, 0 ('negative') for example 1.
# References (参考): 知乎 (Zhihu)
# 版权声明: 本文为 qq_42801194 原创文章, 遵循 CC 4.0 BY-SA 版权协议, 转载请附上原文出处链接和本声明。
# (Copyright notice: original article by qq_42801194, licensed CC 4.0 BY-SA;
# reproduce with a link to the original source and this notice.)