write first test

This commit is contained in:
ali asaria
2025-01-31 10:14:12 -05:00
committed by sanjaycal
parent fc2b235ebf
commit e1e3af1352
6 changed files with 185 additions and 7 deletions

13
.vscode/settings.json vendored
View File

@@ -1,6 +1,9 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8",
"editor.formatOnSave": true
}
}
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8",
"editor.formatOnSave": true
},
"python.testing.pytestArgs": ["test"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

6
pytest.ini Normal file
View File

@@ -0,0 +1,6 @@
# pytest.ini
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = module
testpaths =
test

0
test/__init__.py Normal file
View File

164
test/test_db.py Normal file
View File

@@ -0,0 +1,164 @@
from transformerlab.db import (
init, close, get_dataset, get_datasets, create_huggingface_dataset,
create_local_dataset, delete_dataset, model_local_list, model_local_count,
model_local_create, model_local_get, model_local_delete, job_create,
jobs_get_all, jobs_get_all_by_experiment_and_type, job_get_status,
job_get_error_msg, job_get, job_count_running, jobs_get_next_queued_job,
job_update_status, job_update, job_update_sync, job_mark_as_complete_if_running,
job_delete_all, job_delete, job_cancel_in_progress_jobs, job_update_job_data_insert_key_value,
job_stop, get_training_template, get_training_template_by_name, get_training_templates,
create_training_template, update_training_template, delete_training_template,
training_jobs_get_all, job_get_for_template_id, export_job_create, experiment_get_all,
experiment_create, experiment_get, experiment_get_by_name, experiment_delete,
experiment_update, experiment_update_config, experiment_save_prompt_template,
get_plugins, get_plugins_of_type, get_plugin, save_plugin, config_get, config_set
)
import pytest
import asyncio
import transformerlab.db as db
pytest_plugins = ('pytest_asyncio',)
# FILE: transformerlab/test_db.py
@pytest.fixture(scope="module", autouse=True)
async def setup_db():
await db.init()
yield
await db.close()
@pytest.fixture
async def test_dataset():
# Setup code to create test_dataset
dataset = await db.create_local_dataset("test_dataset")
yield dataset
# Teardown code to delete test_dataset
await db.delete_dataset("test_dataset")
@pytest.fixture
async def test_experiment():
# Setup code to create test_experiment
experiment_id = await db.experiment_create("test_experiment", "{}")
yield experiment_id
# Teardown code to delete test_experiment
await db.experiment_delete(experiment_id)
# content of test_sample.py
def test_db_exists():
global db
assert db is not None
class TestDatasets:
@pytest.mark.asyncio
async def test_create_and_get_dataset(self, test_dataset):
dataset = await get_dataset("test_dataset")
assert dataset is not None
assert dataset["dataset_id"] == "test_dataset"
@pytest.mark.asyncio
async def test_get_datasets(self):
datasets = await get_datasets()
assert isinstance(datasets, list)
@pytest.mark.asyncio
async def test_delete_dataset(self):
await create_local_dataset("test_dataset_delete")
await delete_dataset("test_dataset_delete")
dataset = await get_dataset("test_dataset_delete")
assert dataset is None
class TestModels:
@pytest.mark.asyncio
async def test_create_and_get_model(self):
await model_local_create("test_model", "Test Model", {})
model = await model_local_get("test_model")
assert model is not None
assert model["model_id"] == "test_model"
@pytest.mark.asyncio
async def test_model_local_list(self):
models = await model_local_list()
assert isinstance(models, list)
@pytest.mark.asyncio
async def test_model_local_count(self):
count = await model_local_count()
assert isinstance(count, int)
@pytest.mark.asyncio
async def test_model_local_delete(self):
await model_local_create("test_model_delete", "Test Model Delete", {})
await model_local_delete("test_model_delete")
model = await model_local_get("test_model_delete")
assert model is None
class TestJobs:
@pytest.mark.asyncio
async def test_create_and_get_job(self):
job_id = await job_create("TRAIN", "QUEUED")
job = await job_get(job_id)
assert job is not None
assert job["id"] == job_id
@pytest.mark.asyncio
async def test_jobs_get_all(self):
jobs = await jobs_get_all()
assert isinstance(jobs, list)
@pytest.mark.asyncio
async def test_job_update_status(self):
job_id = await job_create("TRAIN", "QUEUED")
print(job_id)
await job_update_status(job_id, "RUNNING")
status = await job_get_status(job_id)
assert status == "RUNNING"
class TestExperiments:
@pytest.mark.asyncio
async def test_create_and_get_experiment(self, test_experiment):
experiment = await experiment_get(test_experiment)
assert experiment is not None
assert experiment["name"] == "test_experiment"
@pytest.mark.asyncio
async def test_experiment_get_all(self):
experiments = await experiment_get_all()
assert isinstance(experiments, list)
@pytest.mark.asyncio
async def test_experiment_delete(self):
experiment_id = await experiment_create("test_experiment_delete", "{}")
await experiment_delete(experiment_id)
experiment = await experiment_get(experiment_id)
assert experiment is None
class TestPlugins:
@pytest.mark.asyncio
async def test_save_and_get_plugin(self):
await save_plugin("test_plugin", "test_type")
plugin = await get_plugin("test_plugin")
assert plugin is not None
assert plugin["name"] == "test_plugin"
@pytest.mark.asyncio
async def test_get_plugins(self):
plugins = await get_plugins()
assert isinstance(plugins, list)
class TestConfig:
@pytest.mark.asyncio
async def test_config_set_and_get(self):
await config_set("test_key", "test_value")
value = await config_get("test_key")
assert value == "test_value"

View File

@@ -338,10 +338,10 @@ async def jobs_get_all_by_experiment_and_type(experiment_id, job_type):
async def job_get_status(job_id):
cursor = await db.execute("SELECT status FROM job WHERE job_id = ?", (job_id,))
cursor = await db.execute("SELECT status FROM job WHERE id = ?", (job_id,))
row = await cursor.fetchone()
await cursor.close()
return row
return row[0]
async def job_get_error_msg(job_id):

View File

@@ -30,6 +30,11 @@ async def job_create(type: str = 'UNDEFINED', status: str = 'CREATED', data: str
return jobid
async def job_create_task(script: str, job_data: str = '{}', experiment_id: str = '-1'):
jobid = await db.job_create(type='UNDEFINED', status='CREATED', job_data=job_data, experiment_id=experiment_id)
return jobid
@router.get("/update/{job_id}")
async def job_update(job_id: str, status: str):
await db.job_update_status(job_id, status)