Prompt code example (OpenPrompt)





1. Define the task

from openprompt.data_utils import InputExample

# a binary sentiment classification task with two classes
classes = [
    'negative',
    'positive'
]

# the raw texts to classify; guid is simply an example identifier
dataset = [
    InputExample(
        guid = 0,
        text_a = "Albert Einstein was one of the greatest intellects of his time.",
    ),
    InputExample(
        guid = 1,
        text_a = "The film was badly made.",
    ),
]
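
For supervised fine-tuning (see step 7), an InputExample can also carry a gold label. A minimal sketch, assuming the same two classes; the texts and label indices here are only illustrative:

labeled_dataset = [
    InputExample(guid=0, text_a="The film was badly made.", label=0),                  # 'negative'
    InputExample(guid=1, text_a="A wonderful and moving piece of cinema.", label=1),   # 'positive'
]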



2. Define the pretrained language model

from openprompt.plms import load_plm

# load_plm returns the model, its tokenizer, the model config,
# and a tokenizer wrapper class used later by PromptDataLoader
plm, tokenizer, model_config, WrapperClass = load_plm('bert', 'bert-base-cased')
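
load_plm is not limited to BERT; other supported model families can be loaded the same way. A sketch (the checkpoint names are illustrative):

# e.g. RoBERTa or T5 instead of BERT
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-base')
# plm, tokenizer, model_config, WrapperClass = load_plm('t5', 't5-base')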



3. Define the prompt template

from openprompt.prompts import ManualTemplate

# {"placeholder":"text_a"} is filled with each example's text_a,
# and {"mask"} marks the position the masked language model predicts
promptTemplate = ManualTemplate(
    text='{"placeholder":"text_a"} It was {"mask"}',
    tokenizer=tokenizer,
)
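
To sanity-check the template, it can help to inspect one wrapped example before building the full pipeline. A small sketch, assuming the wrap_one_example helper exposed by OpenPrompt templates:

# shows how text_a and the {"mask"} token are stitched together
wrapped_example = promptTemplate.wrap_one_example(dataset[0])
print(wrapped_example)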



4. Define the output-to-label mapping (verbalizer)

from openprompt.prompts import ManualVerbalizer

# map each class to the label words the model may predict at the {"mask"} position
promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={
        'negative': ['bad'],
        'positive': ['good', 'wonderful', 'great'],
    },
    tokenizer=tokenizer,
)
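
Conceptually, the verbalizer takes the model's scores for each label word at the {"mask"} position and aggregates them into one score per class. A rough, library-independent illustration (the probabilities are made up, and OpenPrompt's internal computation differs in detail):

# hypothetical mask-position probabilities, for illustration only
word_prob = {'bad': 0.10, 'good': 0.35, 'wonderful': 0.20, 'great': 0.25}
label_words = {'negative': ['bad'], 'positive': ['good', 'wonderful', 'great']}

# average the label-word probabilities of each class
class_score = {c: sum(word_prob[w] for w in ws) / len(ws) for c, ws in label_words.items()}
print(class_score)  # 'positive' ends up with the higher score here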



5. Combine everything into a PromptForClassification model

from openprompt import PromptForClassification

# the PLM, template, and verbalizer together form one classification pipeline
promptModel = PromptForClassification(
    template=promptTemplate,
    plm=plm,
    verbalizer=promptVerbalizer,
)



6. Define the DataLoader

from openprompt import PromptDataLoader

# wraps each example with the template and tokenizes it into model-ready batches
data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)
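
PromptDataLoader also takes the usual batching options. A sketch with a few commonly used arguments (the values are illustrative; see the OpenPrompt documentation for the full list):

data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
    max_seq_length=256,  # truncate/pad the wrapped inputs to this length
    batch_size=4,
    shuffle=False,
)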



7. Training and testing

# zero-shot inference: use the pretrained MLM with the prompt, no fine-tuning
import torch

promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim=-1)
        print(classes[preds])
        # expected predictions: 1 ('positive') for the first example, 0 ('negative') for the second
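
The heading also mentions training. The loop above is pure zero-shot inference, but with labeled examples (see the labeled_dataset sketch in step 1) a standard fine-tuning loop can be run on top of PromptForClassification. A minimal sketch, assuming the batches produced by PromptDataLoader expose a 'label' field:

import torch

loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(promptModel.parameters(), lr=1e-5)

promptModel.train()
for epoch in range(3):
    for batch in data_loader:  # a loader built from the labeled examples
        logits = promptModel(batch)
        loss = loss_func(logits, batch['label'])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()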
