mirror of
https://github.com/transformerlab/transformerlab-api.git
synced 2025-04-19 19:36:18 +03:00
write first test
This commit is contained in:
13
.vscode/settings.json
vendored
13
.vscode/settings.json
vendored
@@ -1,6 +1,9 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.autopep8",
|
||||
"editor.formatOnSave": true
|
||||
}
|
||||
}
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.autopep8",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"python.testing.pytestArgs": ["test"],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
||||
|
||||
6
pytest.ini
Normal file
6
pytest.ini
Normal file
@@ -0,0 +1,6 @@
|
||||
# pytest.ini
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = module
|
||||
testpaths =
|
||||
test
|
||||
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
164
test/test_db.py
Normal file
164
test/test_db.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from transformerlab.db import (
|
||||
init, close, get_dataset, get_datasets, create_huggingface_dataset,
|
||||
create_local_dataset, delete_dataset, model_local_list, model_local_count,
|
||||
model_local_create, model_local_get, model_local_delete, job_create,
|
||||
jobs_get_all, jobs_get_all_by_experiment_and_type, job_get_status,
|
||||
job_get_error_msg, job_get, job_count_running, jobs_get_next_queued_job,
|
||||
job_update_status, job_update, job_update_sync, job_mark_as_complete_if_running,
|
||||
job_delete_all, job_delete, job_cancel_in_progress_jobs, job_update_job_data_insert_key_value,
|
||||
job_stop, get_training_template, get_training_template_by_name, get_training_templates,
|
||||
create_training_template, update_training_template, delete_training_template,
|
||||
training_jobs_get_all, job_get_for_template_id, export_job_create, experiment_get_all,
|
||||
experiment_create, experiment_get, experiment_get_by_name, experiment_delete,
|
||||
experiment_update, experiment_update_config, experiment_save_prompt_template,
|
||||
get_plugins, get_plugins_of_type, get_plugin, save_plugin, config_get, config_set
|
||||
)
|
||||
import pytest
|
||||
import asyncio
|
||||
import transformerlab.db as db
|
||||
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
# FILE: transformerlab/test_db.py
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
async def setup_db():
|
||||
await db.init()
|
||||
yield
|
||||
await db.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_dataset():
|
||||
# Setup code to create test_dataset
|
||||
dataset = await db.create_local_dataset("test_dataset")
|
||||
yield dataset
|
||||
# Teardown code to delete test_dataset
|
||||
await db.delete_dataset("test_dataset")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_experiment():
|
||||
# Setup code to create test_experiment
|
||||
experiment_id = await db.experiment_create("test_experiment", "{}")
|
||||
yield experiment_id
|
||||
# Teardown code to delete test_experiment
|
||||
await db.experiment_delete(experiment_id)
|
||||
|
||||
# content of test_sample.py
|
||||
|
||||
|
||||
def test_db_exists():
|
||||
global db
|
||||
assert db is not None
|
||||
|
||||
|
||||
class TestDatasets:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_dataset(self, test_dataset):
|
||||
dataset = await get_dataset("test_dataset")
|
||||
assert dataset is not None
|
||||
assert dataset["dataset_id"] == "test_dataset"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_datasets(self):
|
||||
datasets = await get_datasets()
|
||||
assert isinstance(datasets, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_dataset(self):
|
||||
await create_local_dataset("test_dataset_delete")
|
||||
await delete_dataset("test_dataset_delete")
|
||||
dataset = await get_dataset("test_dataset_delete")
|
||||
assert dataset is None
|
||||
|
||||
|
||||
class TestModels:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_model(self):
|
||||
await model_local_create("test_model", "Test Model", {})
|
||||
model = await model_local_get("test_model")
|
||||
assert model is not None
|
||||
assert model["model_id"] == "test_model"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_local_list(self):
|
||||
models = await model_local_list()
|
||||
assert isinstance(models, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_local_count(self):
|
||||
count = await model_local_count()
|
||||
assert isinstance(count, int)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_local_delete(self):
|
||||
await model_local_create("test_model_delete", "Test Model Delete", {})
|
||||
await model_local_delete("test_model_delete")
|
||||
model = await model_local_get("test_model_delete")
|
||||
assert model is None
|
||||
|
||||
|
||||
class TestJobs:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_job(self):
|
||||
job_id = await job_create("TRAIN", "QUEUED")
|
||||
job = await job_get(job_id)
|
||||
assert job is not None
|
||||
assert job["id"] == job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jobs_get_all(self):
|
||||
jobs = await jobs_get_all()
|
||||
assert isinstance(jobs, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_job_update_status(self):
|
||||
job_id = await job_create("TRAIN", "QUEUED")
|
||||
print(job_id)
|
||||
await job_update_status(job_id, "RUNNING")
|
||||
status = await job_get_status(job_id)
|
||||
assert status == "RUNNING"
|
||||
|
||||
|
||||
class TestExperiments:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_experiment(self, test_experiment):
|
||||
experiment = await experiment_get(test_experiment)
|
||||
assert experiment is not None
|
||||
assert experiment["name"] == "test_experiment"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_experiment_get_all(self):
|
||||
experiments = await experiment_get_all()
|
||||
assert isinstance(experiments, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_experiment_delete(self):
|
||||
experiment_id = await experiment_create("test_experiment_delete", "{}")
|
||||
await experiment_delete(experiment_id)
|
||||
experiment = await experiment_get(experiment_id)
|
||||
assert experiment is None
|
||||
|
||||
|
||||
class TestPlugins:
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_get_plugin(self):
|
||||
await save_plugin("test_plugin", "test_type")
|
||||
plugin = await get_plugin("test_plugin")
|
||||
assert plugin is not None
|
||||
assert plugin["name"] == "test_plugin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugins(self):
|
||||
plugins = await get_plugins()
|
||||
assert isinstance(plugins, list)
|
||||
|
||||
|
||||
class TestConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_set_and_get(self):
|
||||
await config_set("test_key", "test_value")
|
||||
value = await config_get("test_key")
|
||||
assert value == "test_value"
|
||||
@@ -338,10 +338,10 @@ async def jobs_get_all_by_experiment_and_type(experiment_id, job_type):
|
||||
|
||||
async def job_get_status(job_id):
|
||||
|
||||
cursor = await db.execute("SELECT status FROM job WHERE job_id = ?", (job_id,))
|
||||
cursor = await db.execute("SELECT status FROM job WHERE id = ?", (job_id,))
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
return row
|
||||
return row[0]
|
||||
|
||||
|
||||
async def job_get_error_msg(job_id):
|
||||
|
||||
@@ -30,6 +30,11 @@ async def job_create(type: str = 'UNDEFINED', status: str = 'CREATED', data: str
|
||||
return jobid
|
||||
|
||||
|
||||
async def job_create_task(script: str, job_data: str = '{}', experiment_id: str = '-1'):
|
||||
jobid = await db.job_create(type='UNDEFINED', status='CREATED', job_data=job_data, experiment_id=experiment_id)
|
||||
return jobid
|
||||
|
||||
|
||||
@router.get("/update/{job_id}")
|
||||
async def job_update(job_id: str, status: str):
|
||||
await db.job_update_status(job_id, status)
|
||||
|
||||
Reference in New Issue
Block a user