"""Tests for the ground-truth dataset loader.""" from __future__ import annotations import json import os import tempfile import pytest from prometheus.domain.entities import GroundTruthExample from prometheus.infrastructure.dataset_loader import FileDatasetLoader @pytest.fixture def loader(): return FileDatasetLoader() class TestCsvLoader: def test_load_csv(self, loader, tmp_path): csv_file = tmp_path / "test.csv" csv_file.write_text("input,expected_output\nhello,world\nfoo,bar\n") result = loader.load(str(csv_file)) assert len(result) == 2 assert result[0].input_text == "hello" assert result[0].expected_output == "world" assert result[1].input_text == "foo" assert result[1].expected_output == "bar" def test_load_csv_skips_empty_input(self, loader, tmp_path): csv_file = tmp_path / "test.csv" csv_file.write_text("input,expected_output\n,bar\nhello,world\n") result = loader.load(str(csv_file)) assert len(result) == 1 assert result[0].input_text == "hello" def test_load_csv_with_whitespace(self, loader, tmp_path): csv_file = tmp_path / "test.csv" csv_file.write_text("input,expected_output\n hello , world \n") result = loader.load(str(csv_file)) assert result[0].input_text == "hello" assert result[0].expected_output == "world" def test_load_csv_empty_file(self, loader, tmp_path): csv_file = tmp_path / "test.csv" csv_file.write_text("input,expected_output\n") result = loader.load(str(csv_file)) assert len(result) == 0 class TestJsonLoader: def test_load_json(self, loader, tmp_path): json_file = tmp_path / "test.json" data = [ {"input": "hello", "expected_output": "world"}, {"input": "foo", "expected_output": "bar"}, ] json_file.write_text(json.dumps(data)) result = loader.load(str(json_file)) assert len(result) == 2 assert result[0].input_text == "hello" assert result[0].expected_output == "world" def test_load_json_skips_empty_input(self, loader, tmp_path): json_file = tmp_path / "test.json" data = [ {"input": "", "expected_output": "bar"}, {"input": "hello", "expected_output": "world"}, ] json_file.write_text(json.dumps(data)) result = loader.load(str(json_file)) assert len(result) == 1 def test_load_json_not_array_raises(self, loader, tmp_path): json_file = tmp_path / "test.json" json_file.write_text(json.dumps({"not": "an array"})) with pytest.raises(ValueError, match="must be an array"): loader.load(str(json_file)) class TestUnsupportedFormat: def test_unsupported_extension_raises(self, loader, tmp_path): txt_file = tmp_path / "test.txt" txt_file.write_text("hello") with pytest.raises(ValueError, match="Unsupported dataset format"): loader.load(str(txt_file))