|
|
|
|
| |
|
|
| import os |
| import gradio as gr |
| from huggingface_hub import snapshot_download |
| from prettytable import PrettyTable |
| import pandas as pd |
| import torch |
| import traceback |
|
|
# Settings handed to ``Loren(config, ...)`` further down in this file; the
# model directory entries are added after the snapshot has been downloaded.
config = dict(
    model_type="roberta",
    model_name_or_path="roberta-large",
    logic_lambda=0.5,
    prior="random",
    mask_rate=0.0,
    cand_k=1,
    max_seq1_length=256,
    max_seq2_length=128,
    max_num_questions=8,
    do_lower_case=False,
    seed=42,
    # Number of CUDA devices torch reports on this host.
    n_gpu=torch.cuda.device_count(),
)
|
|
# Clone the project repository, delete its data/results/models folders and
# move the remaining top-level entries into the working directory
# (presumably so the ``src`` package imported below can be found).
# NOTE(review): the exit status of each command is discarded, so a failing
# step does not stop the script.
for _setup_cmd in (
    'git clone https://github.com/kkpathak91/project_metch/',
    'rm -r project_metch/data/',
    'rm -r project_metch/results/',
    'rm -r project_metch/models/',
    'mv project_metch/* ./',
):
    os.system(_setup_cmd)
|
|
# Fetch the model snapshot from the Hugging Face Hub and record the
# sub-directory of each component in the shared config.
model_dir = snapshot_download('kkpathak91/FVM')
config.update(
    fc_dir=os.path.join(model_dir, 'fact_checking/roberta-large/'),
    mrc_dir=os.path.join(model_dir, 'mrc_seq2seq/bart-base/'),
    er_dir=os.path.join(model_dir, 'evidence_retrieval/'),
)
|
|
|
|
| from src.loren import Loren |
|
|
|
|
# Build the checker once at import time and push one claim through it
# (presumably a smoke test / warm-up before the UI starts serving requests).
loren = Loren(config, verbose=False)
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Still surfaces as ValueError, but the original exception is now chained
    # explicitly so it is reported as the direct cause in the traceback.
    raise ValueError(e) from e
|
|
|
|
def highlight_phrase(text, phrase):
    """Return *text* with each ``<mask>`` replaced by *phrase* in bold italics.

    The text is first run through
    ``loren.fc_client.tokenizer.clean_up_tokenization``; the substitution is
    applied to the cleaned result.
    """
    cleaned = loren.fc_client.tokenizer.clean_up_tokenization(text)
    emphasized = f'<i><b>{phrase}</b></i>'
    return cleaned.replace('<mask>', emphasized)
|
|
|
|
def highlight_entity(text, entity):
    """Return *text* with every occurrence of *entity* in bold italics.

    Each match is wrapped as ``<i><b>entity</b></i>``. An empty *entity*
    leaves *text* unchanged: ``str.replace`` with an empty search string
    would otherwise insert the markup between every character.
    """
    if not entity:
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')
|
|
|
|
def gradio_formatter(js, output_type):
    """Render one view of a ``loren.check`` result as an HTML table.

    Args:
        js: Result mapping. For ``'e'`` the keys ``'evidence'`` and
            ``'entities'`` are read; for ``'z'`` the keys
            ``'phrase_veracity'``, ``'claim_phrases'``, ``'cloze_qs'`` and
            ``'evidential'`` are read.
        output_type: ``'e'`` for the evidence table, ``'z'`` for the
            per-phrase veracity table.

    Returns:
        An HTML string: a ``<head>`` block carrying zebra-striping CSS,
        followed by the table markup.

    Raises:
        NotImplementedError: If *output_type* is neither ``'e'`` nor ``'z'``.
    """
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''

    if output_type == 'e':
        data = {
            'Evidence': [
                highlight_entity(x, e)
                for x, e in zip(js['evidence'], js['entities'])
            ],
        }
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for x in js['phrase_veracity']:
            # Emphasise the highest-scoring entry of each triple.
            max_idx = torch.argmax(torch.tensor(x)).item()
            x = ['%.4f' % xx for xx in x]
            x[max_idx] = f'<i><b>{x[max_idx]}</b></i>'
            # Index 2 feeds the SUP column, 0 the REF column, 1 the NEI column.
            p_sup.append(x[2])
            p_ref.append(x[0])
            p_nei.append(x[1])

        data = {
            'Claim Phrase': js['claim_phrases'],
            'Local Premise': [
                highlight_phrase(q, x[0])
                for q, x in zip(js['cloze_qs'], js['evidential'])
            ],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError

    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns),
                     align='l', border=True, hrules=1, vrules=1)
    for v in data.values:
        pt.add_row(v)
    # NOTE(review): 'bordercolor' is not a CSS property ('border-color' is);
    # kept verbatim so the rendered output does not change.
    html = pt.get_html_string(attributes={
        'style': 'border-width: 2px; bordercolor: black'
    }, format=True)
    html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + html
    # PrettyTable HTML-escapes cell contents, so the <i><b> highlight markup
    # inserted above comes back as &lt;i&gt;&lt;b&gt;...; turn the angle
    # brackets back into real tags. (The previous replace('<', '<') /
    # replace('>', '>') pair was a no-op.)
    html = html.replace('&lt;', '<').replace('&gt;', '>')
    return html
|
|
|
|
def run(claim):
    """Check *claim* and return ``(label, veracity_html, evidence_html)``.

    When ``loren.check`` raises, the claim and the error message with its
    traceback are written to ``loren.logger`` and a fallback message plus two
    empty strings is returned instead.
    """
    try:
        result = loren.check(claim)
    except Exception as error_msg:
        trace = traceback.format_exc()
        report = f'[Error]: {error_msg}.\n[Traceback]: {trace}'
        loren.logger.error(claim)
        loren.logger.error(report)
        return 'Oops, something went wrong.', '', ''

    verdict = result['claim_veracity']
    loren.logger.warning(verdict + str(result))
    evidence_table = gradio_formatter(result, 'e')
    phrase_table = gradio_formatter(result, 'z')
    return verdict, phrase_table, evidence_table
|
|
|
|
# Claims offered as ready-made examples in the interface.
_example_claims = [
    'Kanpur is a city in Nepal',
    'PV Sindhu is an Indian Badminton Player.',
]

# Choices presented when a user flags a result.
_flag_choices = [
    'Interesting!',
    'Error: Claim Phrase Parsing',
    'Error: Local Premise',
    'Error: Require Commonsense',
    'Error: Evidence Retrieval',
]

# One text input; the three outputs follow the order of run()'s return value.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=['text', 'html', 'html'],
    examples=_example_claims,
    title="A Framework for Data-Driven Document Evaluation and Scoring",
    layout='horizontal',
    description="[Student Name: Karan Kumar Pathak] " " [Roll No.: 2020fc04334] ",
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=_flag_choices,
    enable_queue=True,
)
iface.launch()