Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ Make a pull request (PR) from your fork into the main branch of WorkRB, followin
- [ ] Code follows project style guidelines
- [ ] Documentation updated
- [ ] No new warnings introduced
- [ ] If the rankings artifact schema changed: bumped SCHEMA_VERSION in workrb/rankings.py and updated SUPPORTED_SCHEMA_VERSIONS
```

### 4. Review Process
Expand Down
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ results = workrb.evaluate( # Returns BenchmarkResults (Pydantic model)
model,
tasks,
output_folder="results/my_model",
save_rankings=False, # Optional: store full per-target score arrays for ranking tasks
)
print(results) # Benchmark/Per-task/Per-language metrics
```
Expand Down Expand Up @@ -210,6 +211,48 @@ results/my_model/
└── config.yaml # Final benchmark configuration dump
```

If you pass `save_rankings=True` to `evaluate`, WorkRB also writes per-task,
per-dataset ranking score artifacts under a model-scoped subdirectory:

```
results/my_model/
└── rankings/
└── <model_name>/
└── <task_name>__<dataset_id>.json
```

Each JSON file has two top-level keys:

- `header`: identifies the artifact and pins it to a specific dataset
shape. Includes `schema_version`, `workrb_version`, `model_name`,
`task_name`, `dataset_id`, `split`, `num_queries`, `num_targets`,
plus four canary strings (`first_query_text`, `last_query_text`,
`first_target_text`, `last_target_text`) used to detect silent dataset
drift on replay.
- `scores`: a `{query_index: {target_index: score}}` mapping, with
keys as positional indices into the live dataset's `query_texts` /
`target_space` (stringified, as JSON requires string keys). Every
`(query, target)` cell is stored. Non-finite values (`NaN`, `+inf`, `-inf`) are
rejected at write time.

Once the artifacts are written, you can recompute metrics without rerunning the
model by pointing `evaluate_rankings` at the model-scoped directory:

```python
results = workrb.evaluate_rankings(
rankings_dir="results/my_model/rankings/my_model",
tasks=tasks,
output_folder="results/my_model_replay",
)
```

This is the recommended way to re-score after a metric definition has
changed (e.g. a new ranking metric is added in a workrb release): replay
is cheap, the model never runs again, and `validate_header` rejects any
artifact whose dataset shape no longer matches the live one. A
`workrb_version` mismatch only logs a warning; an unknown
`schema_version` is a hard reject.

To load & parse results from a run:

```python
Expand Down
11 changes: 10 additions & 1 deletion src/workrb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@

from workrb import data, metrics, models, tasks
from workrb.logging import setup_logger
from workrb.rankings import RankingsArtifactInvalid, RankingsArtifactMissing
from workrb.registry import list_available_tasks
from workrb.results import load_results
from workrb.run import evaluate, evaluate_multiple_models, get_tasks_overview
from workrb.run import (
evaluate,
evaluate_multiple_models,
evaluate_rankings,
get_tasks_overview,
)
from workrb.types import ExecutionMode, LanguageAggregationMode

# Configure 'workrb' logger to INFO level by default, by usage of package
Expand All @@ -15,9 +21,12 @@
__all__ = [
"ExecutionMode",
"LanguageAggregationMode",
"RankingsArtifactInvalid",
"RankingsArtifactMissing",
"data",
"evaluate",
"evaluate_multiple_models",
"evaluate_rankings",
"get_tasks_overview",
"list_available_tasks",
"load_results",
Expand Down
104 changes: 103 additions & 1 deletion src/workrb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,36 @@

import json
import logging
import re
import time
from collections.abc import Sequence
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _pkg_version
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import yaml

from workrb.rankings import SCHEMA_VERSION, rankings_filename
from workrb.results import BenchmarkResults
from workrb.tasks.abstract import Task

if TYPE_CHECKING:
from workrb.tasks.abstract.ranking_base import RankingDataset

logger = logging.getLogger(__name__)


def _get_workrb_version() -> str:
    """Return the installed ``workrb`` distribution version string.

    Falls back to the literal ``"unknown"`` when no package metadata for
    ``workrb`` can be found (``PackageNotFoundError``).
    """
    try:
        installed_version = _pkg_version("workrb")
    except PackageNotFoundError:
        installed_version = "unknown"
    return installed_version


@dataclass
class BenchmarkConfig:
"""
Expand Down Expand Up @@ -134,6 +149,93 @@ def get_results_path(self) -> Path:
"""Get the path where final results should be saved."""
return self.get_output_path() / "results.json"

def get_rankings_dir(self) -> Path:
    """Get the directory where per-dataset ranking artifacts are saved.

    Rankings are nested under a sanitized model-name directory so that
    running multiple models into the same ``output_folder`` cannot clobber
    each other's ranking files.

    Every run of characters outside ``[A-Za-z0-9_.-]`` in ``self.model_name``
    is collapsed to a single ``_`` and leading/trailing underscores are
    stripped. If the result is empty or one of the special path components
    ``.`` / ``..``, the fixed name ``unnamed_model`` is used instead so the
    returned path is always a real child of ``<output path>/rankings``.

    Returns:
        ``<output path>/rankings/<sanitized model name>``.
    """
    safe_model_name = re.sub(r"[^A-Za-z0-9_.-]+", "_", self.model_name).strip("_")
    # "" and "." would collapse onto the shared rankings directory and ".."
    # would point at its parent, defeating the per-model isolation.
    if safe_model_name in {"", ".", ".."}:
        safe_model_name = "unnamed_model"
    return self.get_output_path() / "rankings" / safe_model_name

def get_task_rankings_path(self, task_name: str, dataset_id: str) -> Path:
    """Return the file path of the ranking artifact for one task/dataset pair.

    The file name is produced by ``rankings_filename(task_name, dataset_id)``
    and placed inside the directory returned by ``get_rankings_dir()``.
    """
    rankings_dir = self.get_rankings_dir()
    artifact_name = rankings_filename(task_name, dataset_id)
    return rankings_dir / artifact_name

def save_rankings_artifact(
    self,
    task_name: str,
    dataset_id: str,
    split: str,
    dataset: "RankingDataset",
    prediction_matrix: np.ndarray,
) -> Path:
    """Save the prediction matrix for one ``(task, dataset_id)`` as a JSON artifact.

    Schema:
    ``{"header": {...metadata...}, "scores": {q_idx: {t_idx: score}}}``
    Query and target keys are positional indices (stringified, since JSON
    object keys must be strings); the dataset's row order at the pinned
    workrb version is the implicit ID source. Every ``(q, t)`` cell is
    stored.

    Non-finite scores (``NaN``, ``+inf``, ``-inf``) are rejected: standard
    JSON cannot represent them, and silently coercing would corrupt the
    artifact for downstream readers.

    Args:
        task_name: Task name recorded in the header and used for the file name.
        dataset_id: Dataset identifier recorded in the header and used for
            the file name.
        split: Split name recorded in the header.
        dataset: Dataset whose ``query_texts`` and ``target_space`` give the
            expected matrix dimensions and the header's first/last canary texts.
        prediction_matrix: 2-D score array with one row per query and one
            column per target.

    Returns:
        Path of the written JSON file.

    Raises:
        ValueError: If ``prediction_matrix`` is not 2-D, its shape does not
            match the dataset, the dataset has no queries or no targets, or
            the matrix contains non-finite values.
    """
    # Fail with a descriptive message instead of an opaque tuple-unpack error.
    if prediction_matrix.ndim != 2:
        raise ValueError(
            "prediction_matrix must be 2-dimensional (queries x targets), "
            f"got shape {tuple(prediction_matrix.shape)}"
        )

    num_queries, num_targets = prediction_matrix.shape
    if num_queries != len(dataset.query_texts):
        raise ValueError(
            f"prediction_matrix has {num_queries} rows but dataset has "
            f"{len(dataset.query_texts)} queries"
        )
    if num_targets != len(dataset.target_space):
        raise ValueError(
            f"prediction_matrix has {num_targets} cols but dataset has "
            f"{len(dataset.target_space)} targets"
        )
    # The header stores first/last query and target texts; without this guard
    # an empty dataset would surface as a bare IndexError further down.
    if num_queries == 0 or num_targets == 0:
        raise ValueError(
            "cannot save a ranking artifact for an empty dataset "
            f"(num_queries={num_queries}, num_targets={num_targets})"
        )

    finite_mask = np.isfinite(prediction_matrix)
    if not finite_mask.all():
        sample = np.argwhere(~finite_mask)[0]
        raise ValueError(
            f"prediction_matrix contains non-finite values (e.g. at "
            f"query_index={int(sample[0])}, target_index={int(sample[1])}); "
            "JSON cannot represent NaN/inf, refusing to write a corrupt artifact"
        )

    # tolist() yields plain Python numbers; float() additionally normalizes
    # integer-typed matrices so every stored score is a JSON float.
    scores: dict[str, dict[str, float]] = {
        str(q_idx): {str(t_idx): float(score) for t_idx, score in enumerate(row)}
        for q_idx, row in enumerate(prediction_matrix.tolist())
    }

    payload = {
        "header": {
            "schema_version": SCHEMA_VERSION,
            "workrb_version": _get_workrb_version(),
            "model_name": self.model_name,
            "task_name": task_name,
            "dataset_id": dataset_id,
            "split": split,
            "num_queries": int(num_queries),
            "num_targets": int(num_targets),
            "first_query_text": dataset.query_texts[0],
            "last_query_text": dataset.query_texts[-1],
            "first_target_text": dataset.target_space[0],
            "last_target_text": dataset.target_space[-1],
        },
        "scores": scores,
    }

    rankings_path = self.get_task_rankings_path(task_name=task_name, dataset_id=dataset_id)
    rankings_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the artifact independent of the platform locale.
    with open(rankings_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, allow_nan=False)
    logger.debug("Ranking artifact saved to %s", rankings_path)
    return rankings_path

def has_checkpoint(self) -> bool:
    """Return whether the path given by ``get_checkpoint_path()`` exists."""
    checkpoint_path = self.get_checkpoint_path()
    return checkpoint_path.exists()
Expand Down
Loading
Loading