mirror of
https://github.com/transformerlab/transformerlab-api.git
synced 2025-04-19 19:36:18 +03:00
Fix db locks issue when executing remotely
This commit is contained in:
@@ -14,7 +14,7 @@ class TransformerLabClient:
|
||||
def __init__(self, server_url: str = "http://localhost:8338", sdk_version: str = "v1", log_file: str = None):
|
||||
"""Initialize the XML-RPC client"""
|
||||
server_url = server_url.rstrip("/") + f"/client/{sdk_version}/jobs"
|
||||
if not server_url.startswith("http") or not server_url.startswith("https"):
|
||||
if not server_url.startswith("http") and not server_url.startswith("https"):
|
||||
raise ValueError("Invalid server URL. Must start with http:// or https://")
|
||||
self.server = xmlrpc.client.ServerProxy(server_url)
|
||||
self.job_id = None
|
||||
@@ -23,7 +23,7 @@ class TransformerLabClient:
|
||||
self.report_interval = 1 # seconds
|
||||
self.log_file = log_file
|
||||
|
||||
def start_job(self, config):
|
||||
def start(self, config):
|
||||
"""Register job with TransformerLab and get a job ID"""
|
||||
result = self.server.start_training(json.dumps(config))
|
||||
if result["status"] == "started":
|
||||
@@ -66,7 +66,7 @@ class TransformerLabClient:
|
||||
# Still return True to continue training despite reporting error
|
||||
return True
|
||||
|
||||
def complete_job(self, message="Training completed successfully"):
|
||||
def complete(self, message="Training completed successfully"):
|
||||
"""Mark job as complete in TransformerLab"""
|
||||
if not self.job_id:
|
||||
return
|
||||
@@ -82,7 +82,7 @@ class TransformerLabClient:
|
||||
except Exception as e:
|
||||
self.log_error(f"Error completing job: {e}")
|
||||
|
||||
def stop_job(self, message="Training completed successfully"):
|
||||
def stop(self, message="Training completed successfully"):
|
||||
"""Mark job as complete in TransformerLab"""
|
||||
if not self.job_id:
|
||||
return
|
||||
|
||||
@@ -16,6 +16,7 @@ from transformerlab.shared import dirs
|
||||
from transformerlab.shared.models import models # noqa: F401
|
||||
|
||||
db = None
|
||||
db_sync = None
|
||||
DATABASE_FILE_NAME = f"{dirs.WORKSPACE_DIR}/llmlab.sqlite3"
|
||||
DATABASE_URL = f"sqlite+aiosqlite:///{DATABASE_FILE_NAME}"
|
||||
|
||||
@@ -34,6 +35,10 @@ async def init():
|
||||
global db
|
||||
os.makedirs(os.path.dirname(DATABASE_FILE_NAME), exist_ok=True)
|
||||
db = await aiosqlite.connect(DATABASE_FILE_NAME)
|
||||
global db_sync
|
||||
db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
db_sync.execute("PRAGMA journal_mode=WAL")
|
||||
db_sync.execute("PRAGMA busy_timeout = 5000")
|
||||
|
||||
# Create the tables if they don't exist
|
||||
async with async_engine.begin() as conn:
|
||||
@@ -247,14 +252,14 @@ def job_create_sync(type, status, job_data="{}", experiment_id=""):
|
||||
"""
|
||||
Synchronous version of job_create function for use with XML-RPC.
|
||||
"""
|
||||
global DATABASE_FILE_NAME
|
||||
# global DATABASE_FILE_NAME
|
||||
# check if type is allowed
|
||||
if type not in ALLOWED_JOB_TYPES:
|
||||
raise ValueError(f"Job type {type} is not allowed")
|
||||
|
||||
# Use SQLite directly in synchronous mode
|
||||
conn = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
cursor = conn.cursor()
|
||||
# conn = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
cursor = db_sync.cursor()
|
||||
|
||||
# Execute insert
|
||||
cursor.execute(
|
||||
@@ -266,8 +271,8 @@ def job_create_sync(type, status, job_data="{}", experiment_id=""):
|
||||
row_id = cursor.lastrowid
|
||||
|
||||
# Commit and close
|
||||
conn.commit()
|
||||
conn.close()
|
||||
db_sync.commit()
|
||||
cursor.close()
|
||||
|
||||
return row_id
|
||||
|
||||
@@ -396,15 +401,24 @@ async def job_update_status(job_id, status, error_msg=None):
|
||||
|
||||
|
||||
def job_update_status_sync(job_id, status, error_msg=None):
|
||||
global DATABASE_FILE_NAME
|
||||
db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
try:
|
||||
global DATABASE_FILE_NAME
|
||||
# db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
|
||||
db_sync.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id))
|
||||
db_sync.commit()
|
||||
db_sync.close()
|
||||
return
|
||||
cursor = db_sync.cursor()
|
||||
|
||||
|
||||
cursor.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id))
|
||||
db_sync.commit()
|
||||
cursor.close()
|
||||
return
|
||||
except Exception as e:
|
||||
print("Error updating job status: " + str(e))
|
||||
return
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
|
||||
async def job_update(job_id, type, status):
|
||||
await db.execute(
|
||||
"UPDATE job SET type = ?, status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (type, status, job_id)
|
||||
@@ -419,11 +433,12 @@ def job_update_sync(job_id, status):
|
||||
# which can only support sychronous functions
|
||||
# This is a hack to get around that limitation
|
||||
global DATABASE_FILE_NAME
|
||||
db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
# db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
cursor = db_sync.cursor()
|
||||
|
||||
db_sync.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id))
|
||||
cursor.execute("UPDATE job SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", (status, job_id))
|
||||
db_sync.commit()
|
||||
db_sync.close()
|
||||
cursor.close()
|
||||
return
|
||||
|
||||
|
||||
@@ -432,14 +447,14 @@ def job_mark_as_complete_if_running(job_id):
|
||||
# only marks a job as "COMPLETE" if it is currenty "RUNNING"
|
||||
# This avoids updating "stopped" jobs and marking them as complete
|
||||
global DATABASE_FILE_NAME
|
||||
db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
|
||||
db_sync.execute(
|
||||
# db_sync = sqlite3.connect(DATABASE_FILE_NAME, isolation_level=None)
|
||||
cursor = db_sync.cursor()
|
||||
cursor.execute(
|
||||
"UPDATE job SET status = 'COMPLETE', updated_at = CURRENT_TIMESTAMP WHERE id = ? AND status = 'RUNNING'",
|
||||
(job_id,),
|
||||
)
|
||||
db_sync.commit()
|
||||
db_sync.close()
|
||||
cursor.close()
|
||||
return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user