Aggregates all v0.2.0 sprint work (GARAA-30 through GARAA-40) and fixes 2 integration tests that broke when the codebase went async (DSPyLLMAdapter and full pipeline tests now properly await coroutines). 277 tests pass (260 unit + 17 integration). Co-Authored-By: Paperclip <noreply@paperclip.ing>
87 lines
3.0 KiB
Python
87 lines
3.0 KiB
Python
"""Tests for the ground-truth dataset loader."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from prometheus.domain.entities import GroundTruthExample
|
|
from prometheus.infrastructure.dataset_loader import FileDatasetLoader
|
|
|
|
|
|
@pytest.fixture
def loader():
    """Build the ``FileDatasetLoader`` instance used by the tests below."""
    dataset_loader = FileDatasetLoader()
    return dataset_loader
|
|
|
|
|
|
class TestCsvLoader:
    """Checks for ``FileDatasetLoader.load`` when given a ``.csv`` path."""

    def test_load_csv(self, loader, tmp_path):
        """Two data rows are expected to yield two examples in file order."""
        path = tmp_path / "test.csv"
        path.write_text(
            "input,expected_output\n"
            "hello,world\n"
            "foo,bar\n"
        )

        examples = loader.load(str(path))

        assert len(examples) == 2
        first = examples[0]
        assert first.input_text == "hello"
        assert first.expected_output == "world"
        second = examples[1]
        assert second.input_text == "foo"
        assert second.expected_output == "bar"

    def test_load_csv_skips_empty_input(self, loader, tmp_path):
        """A row whose ``input`` column is empty is expected to be dropped."""
        path = tmp_path / "test.csv"
        path.write_text(
            "input,expected_output\n"
            ",bar\n"
            "hello,world\n"
        )

        examples = loader.load(str(path))

        assert len(examples) == 1
        assert examples[0].input_text == "hello"

    def test_load_csv_with_whitespace(self, loader, tmp_path):
        """Padding around cell values is expected to be stripped."""
        path = tmp_path / "test.csv"
        path.write_text(
            "input,expected_output\n"
            " hello , world \n"
        )

        examples = loader.load(str(path))

        only = examples[0]
        assert only.input_text == "hello"
        assert only.expected_output == "world"

    def test_load_csv_empty_file(self, loader, tmp_path):
        """A header-only file is expected to produce no examples."""
        path = tmp_path / "test.csv"
        path.write_text("input,expected_output\n")

        examples = loader.load(str(path))

        assert len(examples) == 0
|
|
|
|
|
|
class TestJsonLoader:
    """Checks for ``FileDatasetLoader.load`` when given a ``.json`` path.

    The assertions mirror the CSV tests so both formats are held to the
    same standard: every loaded record is verified, not only the first.
    """

    def test_load_json(self, loader, tmp_path):
        """An array of two objects is expected to yield two examples in order."""
        json_file = tmp_path / "test.json"
        data = [
            {"input": "hello", "expected_output": "world"},
            {"input": "foo", "expected_output": "bar"},
        ]
        json_file.write_text(json.dumps(data))

        result = loader.load(str(json_file))

        assert len(result) == 2
        assert result[0].input_text == "hello"
        assert result[0].expected_output == "world"
        # Verify the second record too, matching the CSV test's coverage.
        assert result[1].input_text == "foo"
        assert result[1].expected_output == "bar"

    def test_load_json_skips_empty_input(self, loader, tmp_path):
        """An object whose ``input`` is empty is expected to be dropped."""
        json_file = tmp_path / "test.json"
        data = [
            {"input": "", "expected_output": "bar"},
            {"input": "hello", "expected_output": "world"},
        ]
        json_file.write_text(json.dumps(data))

        result = loader.load(str(json_file))

        assert len(result) == 1
        # A count alone would also pass if the wrong record were kept, so
        # confirm the survivor is the one with a non-empty input.
        assert result[0].input_text == "hello"

    def test_load_json_not_array_raises(self, loader, tmp_path):
        """A top-level JSON object (not an array) is expected to raise ValueError."""
        json_file = tmp_path / "test.json"
        json_file.write_text(json.dumps({"not": "an array"}))

        with pytest.raises(ValueError, match="must be an array"):
            loader.load(str(json_file))
|
|
|
|
|
|
class TestUnsupportedFormat:
    """Checks how ``FileDatasetLoader.load`` reacts to an unknown extension."""

    def test_unsupported_extension_raises(self, loader, tmp_path):
        """Loading a ``.txt`` file is expected to raise ValueError."""
        unsupported = tmp_path / "test.txt"
        unsupported.write_text("hello")

        with pytest.raises(ValueError, match="Unsupported dataset format"):
            loader.load(str(unsupported))
|