| """Model manager for stance detection model""" |
|
|
import logging
import os
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
# Module-scoped logger; records are named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class StanceModelManager:
    """Manage a Hugging Face sequence-classification model for stance detection.

    The manager starts empty. Call :meth:`load_model` once before calling
    :meth:`predict`. After a successful load, further ``load_model`` calls
    are no-ops.
    """

    def __init__(self):
        self._reset()

    def _reset(self) -> None:
        """Return the manager to its empty, not-loaded state."""
        self.model = None  # set by load_model via AutoModelForSequenceClassification
        self.tokenizer = None  # set by load_model via AutoTokenizer
        self.device = None  # torch.device chosen in load_model (cuda if available)
        self.model_loaded = False  # True only after a fully successful load

    def load_model(
        self,
        model_id: str,
        api_key: Optional[str] = None,
        *,
        trust_remote_code: bool = True,
    ) -> None:
        """Load the tokenizer and model for ``model_id`` from Hugging Face.

        Args:
            model_id: Repository id passed to ``from_pretrained``.
            api_key: Optional access token. ``None`` or ``""`` means no token.
            trust_remote_code: Forwarded to both ``from_pretrained`` calls.
                Defaults to True to keep the historical behavior. NOTE: True
                allows the repository to run its own Python code on this host.
                Only use it with repositories you trust.

        Raises:
            RuntimeError: if loading fails. The original exception is chained
                as ``__cause__`` and the manager is reset to the not-loaded
                state so the call can be retried.
        """
        if self.model_loaded:
            logger.info("Stance model already loaded")
            return

        try:
            logger.info("Loading stance model from Hugging Face: %s", model_id)

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info("Using device: %s", self.device)

            # An empty-string key is treated the same as "no key".
            token = api_key or None

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code,
            )

            logger.info("Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=trust_remote_code,
            )
            self.model.to(self.device)
            self.model.eval()  # inference only; predict() never trains
        except Exception as e:
            # Don't leave a half-initialised manager (e.g. tokenizer set, model None).
            self._reset()
            logger.exception("Error loading stance model: %s", e)
            raise RuntimeError(f"Failed to load stance model: {e}") from e

        self.model_loaded = True
        logger.info("✓ Stance model loaded successfully from Hugging Face!")

    def predict(self, topic: str, argument: str) -> dict:
        """Classify the stance of ``argument`` toward ``topic``.

        Returns:
            A dict with these keys:

            - ``predicted_stance``: ``"PRO"`` or ``"CON"``.
            - ``confidence``: softmax probability of the predicted class.
            - ``probability_con``: softmax score of class index 0.
            - ``probability_pro``: softmax score of class index 1.

        Raises:
            RuntimeError: if :meth:`load_model` has not completed successfully.
        """
        if not self.model_loaded:
            raise RuntimeError("Stance model not loaded")

        # NOTE(review): presumably this template matches the fine-tuning data;
        # confirm against the training code before changing it.
        text = f"Topic: {topic} [SEP] Argument: {argument}"

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
            # A single input string gives a batch of one; keep its row.
            probs = torch.nn.functional.softmax(logits, dim=-1)[0]
            predicted_class = torch.argmax(probs, dim=-1).item()

        prob_values = probs.tolist()

        # NOTE(review): assumes label 0 = CON and label 1 = PRO. Any other
        # index maps to "CON". Confirm against model.config.id2label.
        stance = "PRO" if predicted_class == 1 else "CON"

        return {
            "predicted_stance": stance,
            "confidence": prob_values[predicted_class],
            "probability_con": prob_values[0],
            "probability_pro": prob_values[1],
        }
|
|
|
|
| |
# Shared module-level manager instance; created empty (no model loaded yet).
stance_model_manager: StanceModelManager = StanceModelManager()
|
|
|
|