add support for multiple instances running simultaneously
This commit is contained in:
@@ -142,6 +142,7 @@ def update_config(json_request: ConfigRequestJson, authorization: Optional[str]
|
||||
@fastapi_app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
|
||||
from jesse.services.multiprocessing import process_manager
|
||||
from jesse.services.env import ENV_VALUES
|
||||
|
||||
if not authenticator.is_valid_token(token):
|
||||
return
|
||||
@@ -149,7 +150,7 @@ async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
|
||||
await websocket.accept()
|
||||
|
||||
queue = Queue()
|
||||
ch, = await async_redis.psubscribe('channel:*')
|
||||
ch, = await async_redis.psubscribe(f"{ENV_VALUES['APP_PORT']}:channel:*")
|
||||
|
||||
async def echo(q):
|
||||
while True:
|
||||
@@ -173,7 +174,7 @@ async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
|
||||
# just so WebSocketDisconnect would be raised on connection close
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
await async_redis.punsubscribe('channel:*')
|
||||
await async_redis.punsubscribe(f"{ENV_VALUES['APP_PORT']}:channel:*")
|
||||
print('Websocket disconnected')
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from jesse.services.redis import sync_publish, sync_redis
|
||||
from jesse.services.failure import terminate_session
|
||||
import jesse.helpers as jh
|
||||
from datetime import timedelta
|
||||
from jesse.services.env import ENV_VALUES
|
||||
|
||||
# set multiprocessing process type to spawn
|
||||
mp.set_start_method('spawn', force=True)
|
||||
@@ -49,9 +50,16 @@ class ProcessManager:
|
||||
self.client_id_to_pid_to_map = {}
|
||||
|
||||
@staticmethod
|
||||
def _set_process_status(pid, status):
|
||||
def _prefixed_pid(pid):
|
||||
return f"{ENV_VALUES['APP_PORT']}|{pid}"
|
||||
|
||||
@staticmethod
|
||||
def _prefixed_client_id(client_id):
|
||||
return f"{ENV_VALUES['APP_PORT']}|{client_id}"
|
||||
|
||||
def _set_process_status(self, pid, status):
|
||||
seconds = 3600 * 24 * 365 # one year
|
||||
key = f'process-status:{pid}'
|
||||
key = f"{ENV_VALUES['APP_PORT']}|process-status:{pid}"
|
||||
value = f'status:{status}'
|
||||
sync_redis.setex(key, timedelta(seconds=seconds), value)
|
||||
|
||||
@@ -59,33 +67,23 @@ class ProcessManager:
|
||||
w = Process(target=function, args=args)
|
||||
self._workers.append(w)
|
||||
w.start()
|
||||
self._pid_to_client_id_map[w.pid] = client_id
|
||||
self.client_id_to_pid_to_map[client_id] = w.pid
|
||||
|
||||
self._pid_to_client_id_map[self._prefixed_pid(w.pid)] = self._prefixed_client_id(client_id)
|
||||
self.client_id_to_pid_to_map[self._prefixed_client_id(client_id)] = self._prefixed_pid(w.pid)
|
||||
self._set_process_status(w.pid, 'started')
|
||||
|
||||
def get_client_id(self, pid):
|
||||
client_id: str = self._pid_to_client_id_map[pid]
|
||||
client_id: str = self._pid_to_client_id_map[self._prefixed_pid(pid)]
|
||||
# return after "-" because we add them before sending it to multiprocessing
|
||||
return client_id[client_id.index('-') + len('-'):]
|
||||
|
||||
def get_pid(self, client_id):
|
||||
return self.client_id_to_pid_to_map[client_id]
|
||||
return self.client_id_to_pid_to_map[self._prefixed_client_id(client_id)]
|
||||
|
||||
def cancel_process(self, client_id):
|
||||
pid = self.get_pid(client_id)
|
||||
pid = jh.string_after_character(pid, '|')
|
||||
self._set_process_status(pid, 'stopping')
|
||||
# TODO: after some time, set it to stopped?
|
||||
|
||||
# pid = self.get_pid(client_id)
|
||||
# for i, w in enumerate(self._workers):
|
||||
# if w.is_alive() and w.pid == pid:
|
||||
# del self.client_id_to_pid_to_map[client_id]
|
||||
# del self._pid_to_client_id_map[w.pid]
|
||||
# w.terminate()
|
||||
# w.join()
|
||||
# w.close()
|
||||
# del self._workers[i]
|
||||
# return
|
||||
|
||||
def flush(self):
|
||||
for w in self._workers:
|
||||
|
||||
@@ -31,7 +31,7 @@ def sync_publish(event: str, msg):
|
||||
raise EnvironmentError('sync_publish() should be NOT called during testing. There must be something wrong')
|
||||
|
||||
sync_redis.publish(
|
||||
'channel:1', json.dumps({
|
||||
f"{ENV_VALUES['APP_PORT']}:channel:1", json.dumps({
|
||||
'id': os.getpid(),
|
||||
'event': f'{jh.app_mode()}.{event}',
|
||||
'data': msg
|
||||
@@ -41,7 +41,7 @@ def sync_publish(event: str, msg):
|
||||
|
||||
async def async_publish(event: str, msg):
|
||||
await async_redis.publish(
|
||||
'channel:1', json.dumps({
|
||||
f"{ENV_VALUES['APP_PORT']}:channel:1", json.dumps({
|
||||
'id': os.getpid(),
|
||||
'event': f'{jh.app_mode()}.{event}',
|
||||
'data': msg
|
||||
@@ -56,7 +56,7 @@ def process_status(pid=None) -> str:
|
||||
if pid is None:
|
||||
pid = jh.get_pid()
|
||||
|
||||
key = f'process-status:{pid}'
|
||||
key = f"{ENV_VALUES['APP_PORT']}|process-status:{pid}"
|
||||
|
||||
res: str = jh.str_or_none(sync_redis.get(key))
|
||||
if res is None:
|
||||
|
||||
Reference in New Issue
Block a user