Michael Rabinovich Cursor commited on
Commit
1fd03de
·
1 Parent(s): b430299

Retry straggler shards on poll-deadline instead of failing the submission

Browse files

When the Space-side shard poll window elapses with shards still
non-terminal (typically the tail shard stuck QUEUED waiting for
a10g-large GPU capacity, not a compute failure), cancel and re-dispatch
those stragglers and reset the window, up to SHARD_DEADLINE_RETRY_ROUNDS
(2) rounds, before failing. All-or-nothing is preserved; shard uploads
are idempotent so a re-dispatch is safe. Fixes large (full-81) runs that
intermittently failed with "Space-side poll deadline exceeded (2700s)".

Co-authored-by: Cursor <cursoragent@cursor.com>

Files changed (1) hide show
  1. submit.py +74 -1
submit.py CHANGED
@@ -242,6 +242,14 @@ SHARD_MAX_RETRIES = 1
242
  # vs. the per-shard ceiling because queued shards (past the ~8
243
  # concurrent slots) wait their turn before their own timeout starts.
244
  SHARD_POLL_DEADLINE_SECONDS = 45 * 60
 
 
 
 
 
 
 
 
245
 
246
  # One HfApi client per process. HF_TOKEN is picked up from the env at
247
  # construction time and reused for every call.
@@ -1312,6 +1320,31 @@ def _dispatch_shard(
1312
  )
1313
 
1314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1315
  def _poll_shards_until_done(
1316
  submission_id: str,
1317
  submission_blob_url: str,
@@ -1327,8 +1360,16 @@ def _poll_shards_until_done(
1327
  after their retries (empty list means every shard COMPLETED).
1328
  Transient ``inspect_job`` failures retry up to
1329
  :data:`JOB_POLL_MAX_CONSECUTIVE_ERRORS` before raising.
 
 
 
 
 
 
 
1330
  """
1331
  deadline = time.monotonic() + SHARD_POLL_DEADLINE_SECONDS
 
1332
  consecutive_errors = 0
1333
  last_done = -1
1334
  total = len(shards)
@@ -1393,12 +1434,44 @@ def _poll_shards_until_done(
1393
  )
1394
 
1395
  if time.monotonic() >= deadline:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1396
  for shard_id, st in shards.items():
1397
  if st["stage"] not in ("COMPLETED", "FAILED"):
1398
  st["stage"] = "FAILED"
1399
  st["message"] = (
1400
  f"Space-side poll deadline exceeded "
1401
- f"({SHARD_POLL_DEADLINE_SECONDS}s)"
 
1402
  )
1403
  break
1404
  time.sleep(JOB_POLL_INTERVAL_SECONDS)
 
242
  # vs. the per-shard ceiling because queued shards (past the ~8
243
  # concurrent slots) wait their turn before their own timeout starts.
244
  SHARD_POLL_DEADLINE_SECONDS = 45 * 60
245
+ # When the poll window elapses with shards still non-terminal — typically the
246
+ # tail shard stuck QUEUED waiting for GPU (a10g-large) capacity rather than a
247
+ # compute failure — re-dispatch those stragglers and reset the window, up to
248
+ # this many rounds, before giving up. A fresh dispatch after the window can land
249
+ # a freed slot; shard uploads are idempotent (each rewrites its own staging
250
+ # prefix), so a re-dispatch is safe. Worst-case total wait is roughly
251
+ # SHARD_POLL_DEADLINE_SECONDS * (1 + SHARD_DEADLINE_RETRY_ROUNDS).
252
+ SHARD_DEADLINE_RETRY_ROUNDS = 2
253
 
254
  # One HfApi client per process. HF_TOKEN is picked up from the env at
255
  # construction time and reused for every call.
 
1320
  )
1321
 
1322
 
1323
+ def _cancel_shard_job(state: dict[str, Any]) -> None:
1324
+ """Best-effort cancel of a shard's in-flight job before re-dispatch.
1325
+
1326
+ Used on the poll-deadline retry path so a straggler that is still
1327
+ QUEUED/RUNNING releases its slot and does not keep writing its staging
1328
+ prefix once a replacement is dispatched. Best-effort: a failure is
1329
+ logged and ignored, since shard uploads are idempotent (a stale job
1330
+ only ever overwrites its own prefix with an equivalent result).
1331
+ """
1332
+ job_id = state.get("job_id")
1333
+ if not job_id:
1334
+ return
1335
+ try:
1336
+ from huggingface_hub import cancel_job
1337
+
1338
+ cancel_job(
1339
+ job_id=job_id,
1340
+ namespace=EVAL_JOB_NAMESPACE,
1341
+ token=_jobs_token(),
1342
+ )
1343
+ logger.info("Cancelled straggler shard job %s before retry", job_id)
1344
+ except Exception as e: # noqa: BLE001 - cancel is best-effort
1345
+ logger.warning("Could not cancel shard job %s: %s", job_id, e)
1346
+
1347
+
1348
  def _poll_shards_until_done(
1349
  submission_id: str,
1350
  submission_blob_url: str,
 
1360
  after their retries (empty list means every shard COMPLETED).
1361
  Transient ``inspect_job`` failures retry up to
1362
  :data:`JOB_POLL_MAX_CONSECUTIVE_ERRORS` before raising.
1363
+
1364
+ If the :data:`SHARD_POLL_DEADLINE_SECONDS` window elapses with shards
1365
+ still non-terminal (the GPU-capacity-starvation case, where a tail
1366
+ shard sits QUEUED), those stragglers are cancelled and re-dispatched
1367
+ and the window resets, up to :data:`SHARD_DEADLINE_RETRY_ROUNDS`
1368
+ rounds, before the submission is finally failed. All-or-nothing is
1369
+ preserved: the list is non-empty unless every shard COMPLETED.
1370
  """
1371
  deadline = time.monotonic() + SHARD_POLL_DEADLINE_SECONDS
1372
+ deadline_rounds_left = SHARD_DEADLINE_RETRY_ROUNDS
1373
  consecutive_errors = 0
1374
  last_done = -1
1375
  total = len(shards)
 
1434
  )
1435
 
1436
  if time.monotonic() >= deadline:
1437
+ stragglers = [
1438
+ sid for sid, st in shards.items()
1439
+ if st["stage"] not in ("COMPLETED", "FAILED")
1440
+ ]
1441
+ if stragglers and deadline_rounds_left > 0:
1442
+ deadline_rounds_left -= 1
1443
+ logger.warning(
1444
+ "Poll deadline (%ds) hit for %s with %d straggler shard(s) "
1445
+ "%s; re-dispatching (%d round(s) left).",
1446
+ SHARD_POLL_DEADLINE_SECONDS, submission_id,
1447
+ len(stragglers), stragglers, deadline_rounds_left,
1448
+ )
1449
+ for sid in stragglers:
1450
+ st = shards[sid]
1451
+ _cancel_shard_job(st)
1452
+ # Give the replacement a fresh ERROR-retry budget too.
1453
+ st["attempts"] = 0
1454
+ _dispatch_shard(
1455
+ submission_id, submission_blob_url, sid, st,
1456
+ )
1457
+ progress.publish(
1458
+ submission_id,
1459
+ progress.RUNNING,
1460
+ f"GPU capacity was tight — retrying {len(stragglers)} "
1461
+ f"straggler chunk(s) (round "
1462
+ f"{SHARD_DEADLINE_RETRY_ROUNDS - deadline_rounds_left} of "
1463
+ f"{SHARD_DEADLINE_RETRY_ROUNDS})…",
1464
+ )
1465
+ deadline = time.monotonic() + SHARD_POLL_DEADLINE_SECONDS
1466
+ last_done = -1 # force a progress republish on the next sweep
1467
+ continue
1468
  for shard_id, st in shards.items():
1469
  if st["stage"] not in ("COMPLETED", "FAILED"):
1470
  st["stage"] = "FAILED"
1471
  st["message"] = (
1472
  f"Space-side poll deadline exceeded "
1473
+ f"({SHARD_POLL_DEADLINE_SECONDS}s) after "
1474
+ f"{SHARD_DEADLINE_RETRY_ROUNDS} retry round(s)"
1475
  )
1476
  break
1477
  time.sleep(JOB_POLL_INTERVAL_SECONDS)