# openenv-code-debugger / inference.py
# refactor: simplify and fix efficiency issues (arnavzz, commit 3faaaa0)
"""
Inference Script — Code Debug OpenEnv
======================================
MANDATORY environment variables:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
ENV_URL URL of the running OpenEnv server (default: http://localhost:7860)
STDOUT FORMAT:
[START] task=<task_id> env=<env_url> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
"""
import os
import sys
import textwrap
from typing import Optional
import httpx
from openai import OpenAI
# ------------------------------------------------------------------
# Configuration
# ------------------------------------------------------------------
# LLM endpoint, model and credentials — all overridable via environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN", "")  # API key; main() falls back to "EMPTY" if unset
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")  # OpenEnv server base URL
TEMPERATURE = 0.2  # low temperature: code repair favors deterministic output
MAX_TOKENS = 2048  # cap on generated tokens per LLM call
SUCCESS_THRESHOLD = 1.0  # require all tests passing
# System prompt sent once per episode. It instructs the model to emit raw
# Python only; strip_fences() still defensively handles fenced replies.
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert Python debugger.
You will receive:
1. A description of the task
2. The buggy Python code
3. Descriptions of what each test checks
4. (From step 2 onwards) Test results showing which tests passed or failed,
with actual vs expected values and any error messages
Your job: return ONLY the corrected Python code with all bugs fixed.
Rules:
- Output raw Python code only — no markdown fences, no explanations
- Include the complete function definition(s), not just the changed lines
- Make sure all tests pass
""").strip()
# ------------------------------------------------------------------
# Logging helpers (strict format per submission spec)
# ------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] banner identifying task, environment URL and model."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one single-line [STEP] record per the STDOUT spec.

    The action is newline-escaped and truncated to 120 chars so the line
    stays single-line and bounded. The error is sanitized the same way —
    an exception message with embedded newlines would otherwise break the
    one-line-per-step STDOUT contract that downstream parsers rely on.
    """
    if error:
        # Escape newlines and bound length, mirroring the action handling.
        error_val = error.replace("\n", "\\n")[:200]
    else:
        error_val = "null"
    done_val = str(done).lower()
    safe_action = action.replace("\n", "\\n")[:120]
    print(
        f"[STEP] step={step} action={safe_action!r} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Emit the final [END] summary line including the full reward trace."""
    trace = ",".join(format(r, ".2f") for r in rewards)
    flag = "true" if success else "false"
    print(f"[END] success={flag} steps={steps} score={score:.3f} rewards={trace}", flush=True)
# ------------------------------------------------------------------
# Prompt builders
# ------------------------------------------------------------------
def build_initial_prompt(obs: dict) -> str:
    """Render the opening user message: task description, buggy code, and
    the list of tests the corrected code must satisfy."""
    prompt = f"## Task\n{obs['description']}"
    prompt += f"\n\n## Buggy Code\n```python\n{obs['buggy_code']}\n```"
    prompt += "\n\n## Tests that must pass"
    for item in obs.get("test_descriptions", []):
        prompt += f"\n- {item}"
    return prompt
def build_feedback_prompt(obs: dict) -> str:
lines = ["## Test Results from your last submission\n"]
for tr in obs.get("test_results", []):
status = "PASS" if tr["passed"] else "FAIL"
lines.append(f"[{status}] {tr['test_name']}")
if not tr["passed"]:
lines.append(f" expected : {tr['expected']}")
lines.append(f" actual : {tr['actual']}")
if tr.get("error"):
lines.append(f" error : {tr['error']}")
if obs.get("stderr"):
lines.append(f"\n## Stderr\n{obs['stderr'][:500]}")
lines.append("\nFix the remaining failures and return the complete corrected code.")
return "\n".join(lines)
def strip_fences(text: str) -> str:
    """Strip a surrounding markdown code fence from an LLM reply.

    The previous version only recognized the exact tag ```python, so a
    reply fenced as ```py (or any other language tag) kept the tag glued
    to the first code line and produced invalid Python. Now the entire
    opening fence line — whatever its language tag — is dropped.
    Unfenced text is returned stripped, unchanged.
    """
    text = text.strip()
    if text.startswith("```"):
        head, sep, rest = text.partition("\n")
        if sep:
            # Multi-line fenced block: discard the whole fence line,
            # including any language tag (```python, ```py, bare ```).
            text = rest
        else:
            # Single-line reply: fall back to plain prefix stripping.
            text = text.removeprefix("```python").removeprefix("```")
        text = text.strip()
    return text.removesuffix("```").strip()
# ------------------------------------------------------------------
# Single episode runner
# ------------------------------------------------------------------
def run_episode(http: httpx.Client, client: OpenAI, task_id: str) -> dict:
    """Run one debugging episode for *task_id* against the environment.

    Resets the env, then loops up to ``max_steps`` times: build a prompt
    (full task on step 1, test-result feedback afterwards — conversation
    history carries the task and prior attempts), ask the LLM for fixed
    code, submit it via /step, and log per the STDOUT spec.

    Returns a summary dict with keys ``task_id``, ``episode_id``,
    ``success``, ``score``, ``steps`` and ``error`` (None unless the
    episode failed mid-flight). A failed /reset propagates to the caller.
    """
    reset_resp = http.post("/reset", json={"task_id": task_id})
    reset_resp.raise_for_status()
    reset_data = reset_resp.json()
    episode_id = reset_data["episode_id"]
    obs = reset_data["observation"]
    max_steps = obs.get("max_steps", 5)

    log_start(task=task_id, env=ENV_URL, model=MODEL_NAME)

    messages: list[dict] = [{"role": "system", "content": SYSTEM_PROMPT}]
    rewards: list[float] = []
    steps_taken = 0
    error_msg: Optional[str] = None

    try:
        for step in range(1, max_steps + 1):
            # Build the user message for this turn.
            if step == 1:
                user_content = build_initial_prompt(obs)
            else:
                user_content = build_feedback_prompt(obs)
            messages.append({"role": "user", "content": user_content})

            # LLM call — a failure ends the episode but is logged first.
            try:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,
                    temperature=TEMPERATURE,
                    max_tokens=MAX_TOKENS,
                )
                fixed_code = strip_fences(completion.choices[0].message.content or "")
                messages.append({"role": "assistant", "content": fixed_code})
            except Exception as exc:
                error_msg = str(exc)
                log_step(step=step, action="llm_error", reward=0.0, done=False, error=error_msg)
                break

            # Submit the candidate fix to the environment.
            step_resp = http.post(
                f"/step/{episode_id}",
                json={"action": {"code": fixed_code}},
            )
            step_resp.raise_for_status()
            step_data = step_resp.json()

            obs = step_data["observation"]
            reward = step_data["reward"]
            done = step_data["done"]
            rewards.append(reward)
            steps_taken = step

            log_step(step=step, action=fixed_code, reward=reward, done=done, error=None)
            if done:
                break
    except Exception as exc:
        # Previously this swallowed the exception silently (error_msg was
        # set but never surfaced). Report it so transport/env failures are
        # visible; the episode still ends with a well-formed [END] line.
        error_msg = str(exc)
        print(f"[ERROR] episode {episode_id} failed: {error_msg}", file=sys.stderr, flush=True)

    # Score = last observed reward (result of the final submission);
    # success requires meeting the all-tests-pass threshold.
    score = rewards[-1] if rewards else 0.0
    success = score >= SUCCESS_THRESHOLD
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {
        "task_id": task_id,
        "episode_id": episode_id,
        "success": success,
        "score": score,
        "steps": steps_taken,
        "error": error_msg,
    }
# ------------------------------------------------------------------
# Main: run all tasks
# ------------------------------------------------------------------
def main():
    """Fetch the task list from the env server and run every task once.

    Fixes over the previous version: the httpx.Client is closed via a
    context manager (it was previously leaked), and a single failing
    episode (e.g. a failed /reset) no longer aborts the whole run.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "EMPTY")
    # Context manager guarantees the HTTP connection pool is released.
    with httpx.Client(base_url=ENV_URL, timeout=60.0) as http:
        # Discover available tasks from the server.
        try:
            tasks_resp = http.get("/tasks")
            tasks_resp.raise_for_status()
            all_tasks = [t["task_id"] for t in tasks_resp.json()]
        except Exception as exc:
            print(f"[ERROR] Could not fetch task list: {exc}", file=sys.stderr, flush=True)
            sys.exit(1)

        results = []
        for task_id in all_tasks:
            # One broken task must not take down the remaining tasks.
            try:
                results.append(run_episode(http, client, task_id))
            except Exception as exc:
                print(f"[ERROR] task {task_id} aborted: {exc}", file=sys.stderr, flush=True)
                results.append({"task_id": task_id, "success": False, "score": 0.0, "steps": 0})

    # Summary
    total = len(results)
    solved = sum(1 for r in results if r["success"])
    avg = sum(r["score"] for r in results) / total if total else 0.0
    print(f"\n=== SUMMARY: solved={solved}/{total} avg_score={avg:.3f} ===", flush=True)


if __name__ == "__main__":
    main()