From 67a19a83c5cfcd1efc28dff987fdb4c9e0023b1d Mon Sep 17 00:00:00 2001 From: deep1401 Date: Fri, 11 Apr 2025 15:21:48 -0400 Subject: [PATCH] Fix db locks issue when executing remotely --- .../tlab_sdk_client/client.py | 8 +-- transformerlab/db.py | 51 ++++++++++++------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/scripts/xml-rpc-client-example/tlab_sdk_client/client.py b/scripts/xml-rpc-client-example/tlab_sdk_client/client.py index 0f5ad32..40d41ba 100644 --- a/scripts/xml-rpc-client-example/tlab_sdk_client/client.py +++ b/scripts/xml-rpc-client-example/tlab_sdk_client/client.py @@ -14,7 +14,7 @@ class TransformerLabClient: def __init__(self, server_url: str = "http://localhost:8338", sdk_version: str = "v1", log_file: str = None): """Initialize the XML-RPC client""" server_url = server_url.rstrip("/") + f"/client/{sdk_version}/jobs" - if not server_url.startswith("http") or not server_url.startswith("https"): + if not server_url.startswith(("http://", "https://")): raise ValueError("Invalid server URL. 
Must start with http:// or https://") self.server = xmlrpc.client.ServerProxy(server_url) self.job_id = None @@ -23,7 +23,7 @@ class TransformerLabClient: self.report_interval = 1 # seconds self.log_file = log_file - def start_job(self, config): + def start(self, config): """Register job with TransformerLab and get a job ID""" result = self.server.start_training(json.dumps(config)) if result["status"] == "started": @@ -66,7 +66,7 @@ class TransformerLabClient: # Still return True to continue training despite reporting error return True - def complete_job(self, message="Training completed successfully"): + def complete(self, message="Training completed successfully"): """Mark job as complete in TransformerLab""" if not self.job_id: return @@ -82,7 +82,7 @@ class TransformerLabClient: except Exception as e: self.log_error(f"Error completing job: {e}") - def stop_job(self, message="Training completed successfully"): + def stop(self, message="Training stopped"): """Mark job as complete in TransformerLab""" if not self.job_id: return diff --git a/transformerlab/db.py b/transformerlab/db.py index 124eea5..992059e 100644 --- a/transformerlab/db.py +++ b/transformerlab/db.py @@ -16,6 +16,7 @@ from transformerlab.shared import dirs from transformerlab.shared.models import models # noqa: F401 db = None +db_sync = None DATABASE_FILE_NAME = f"{dirs.WORKSPACE_DIR}/llmlab.sqlite3" DATABASE_URL = f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}" @@ -34,6 +35,10 @@ async def init(): global db os.makedirs(os.path.dirname(DATABASE_FILE_NAME), exist_ok=True) db = await aiosqlite.connect(DATABASE_FILE_NAME) + global db_sync + db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None, check_same_thread=False) + db_sync.execute("PRAGMA journal_mode=WAL") + db_sync.execute("PRAGMA busy_timeout = 5000") # Create the tables if they don't exist async with async_engine.begin() as conn: @@ -247,14 +252,14 @@ def job_create_sync(type, status, job_data="{}", experiment_id=""): """ Synchronous version 
of job_create function for use with XML-RPC. """ - global DATABASE_FILE_NAME + # global DATABASE_FILE_NAME # check if type is allowed if type not in ALLOWED_JOB_TYPES: raise ValueError(f"Job type {type} is not allowed") # Use SQLite directly in synchronous mode - conn = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) - cursor = conn.cursor() + # conn = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) + cursor = db_sync.cursor() # Execute insert cursor.execute( @@ -266,8 +271,8 @@ def job_create_sync(type, status, job_data="{}", experiment_id=""): row_id = cursor.lastrowid # Commit and close - conn.commit() - conn.close() + db_sync.commit() + cursor.close() return row_id @@ -396,15 +401,24 @@ async def job_update_status(job_id, status, error_msg=None): def job_update_status_sync(job_id, status, error_msg=None): - global DATABASE_FILE_NAME - db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) + cursor = None + try: + global DATABASE_FILE_NAME + # db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) - db_sync.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id)) - db_sync.commit() - db_sync.close() - return + cursor = db_sync.cursor() + cursor.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id)) + db_sync.commit() + return + except Exception as e: + print("Error updating job status: " + str(e)) + return + finally: + if cursor: + cursor.close() + async def job_update(job_id, type, status): await db.execute( "UPDATE job SET type = ?, status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (type, status, job_id) @@ -419,11 +433,12 @@ def job_update_sync(job_id, status): # which can only support sychronous functions # This is a hack to get around that limitation global DATABASE_FILE_NAME - db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) + # db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) + 
cursor = db_sync.cursor() - db_sync.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id)) + cursor.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id)) db_sync.commit() - db_sync.close() + cursor.close() return @@ -432,14 +447,14 @@ def job_mark_as_complete_if_running(job_id): # only marks a job as "COMPLETE" if it is currenty "RUNNING" # This avoids updating "stopped" jobs and marking them as complete global DATABASE_FILE_NAME - db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) - - db_sync.execute( + # db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None) + cursor = db_sync.cursor() + cursor.execute( "UPDATE job SET status = 'COMPLETE', updated_at = CURRENT_TIMESTAMP WHERE id = ? AND status = 'RUNNING'", (job_id,), ) db_sync.commit() - db_sync.close() + cursor.close() return