fix: align mem0 pgvector config with validation
This commit is contained in:
parent
dc2a212f35
commit
1dd45107af
@ -7,10 +7,12 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List, Optional, Literal
|
from typing import Any, Dict, List, Optional, Literal
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from urllib.parse import unquote, urlparse
|
||||||
from embedding.manager import get_model_manager
|
from embedding.manager import get_model_manager
|
||||||
import json_repair
|
import json_repair
|
||||||
from psycopg2 import pool
|
from psycopg2 import pool
|
||||||
from utils.settings import (
|
from utils.settings import (
|
||||||
|
CHECKPOINT_DB_URL,
|
||||||
EMBEDDING_API_KEY,
|
EMBEDDING_API_KEY,
|
||||||
EMBEDDING_BASE_URL,
|
EMBEDDING_BASE_URL,
|
||||||
EMBEDDING_DIMENSIONS,
|
EMBEDDING_DIMENSIONS,
|
||||||
@ -190,27 +192,68 @@ class Mem0Manager:
|
|||||||
mem0_instance: Mem0 Memory instance
|
mem0_instance: Mem0 Memory instance
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Mem0 Memory instances have a vector_store attribute of type PGVector
|
vector_store = getattr(mem0_instance, 'vector_store', None)
|
||||||
if hasattr(mem0_instance, 'vector_store'):
|
if vector_store is not None and getattr(vector_store, 'conn', None) is not None:
|
||||||
vector_store = mem0_instance.vector_store
|
try:
|
||||||
# PGVector has conn and connection_pool attributes
|
if getattr(vector_store, 'cur', None):
|
||||||
if hasattr(vector_store, 'conn') and hasattr(vector_store, 'connection_pool'):
|
vector_store.cur.close()
|
||||||
if vector_store.conn is not None and vector_store.connection_pool is not None:
|
vector_store.cur = None
|
||||||
try:
|
connection_pool = getattr(vector_store, 'connection_pool', None)
|
||||||
# Close the cursor first
|
if connection_pool is not None:
|
||||||
if hasattr(vector_store, 'cur') and vector_store.cur:
|
connection_pool.putconn(vector_store.conn)
|
||||||
vector_store.cur.close()
|
logger.debug("Successfully released Mem0 database connection back to pool")
|
||||||
vector_store.cur = None
|
else:
|
||||||
# Return the connection to the pool
|
vector_store.conn.close()
|
||||||
vector_store.connection_pool.putconn(vector_store.conn)
|
logger.debug("Successfully closed Mem0 database connection")
|
||||||
# Mark as cleaned up to prevent __del__ from releasing it again
|
vector_store.conn = None
|
||||||
vector_store.conn = None
|
except Exception as e:
|
||||||
logger.debug("Successfully released Mem0 database connection back to pool")
|
logger.warning(f"Error releasing Mem0 connection: {e}")
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error releasing Mem0 connection: {e}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error cleaning up Mem0 instance: {e}")
|
logger.warning(f"Error cleaning up Mem0 instance: {e}")
|
||||||
|
|
||||||
|
def _build_pgvector_config(self, agent_id: str) -> Dict[str, Any]:
|
||||||
|
"""Build Mem0 PGVector config using only fields accepted by mem0."""
|
||||||
|
parsed_url = urlparse(CHECKPOINT_DB_URL)
|
||||||
|
if parsed_url.scheme not in ("postgresql", "postgres"):
|
||||||
|
raise ValueError(f"Unsupported CHECKPOINT_DB_URL scheme: {parsed_url.scheme}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dbname": unquote(parsed_url.path.lstrip("/") or "postgres"),
|
||||||
|
"user": unquote(parsed_url.username or ""),
|
||||||
|
"password": unquote(parsed_url.password or ""),
|
||||||
|
"host": parsed_url.hostname or "localhost",
|
||||||
|
"port": parsed_url.port or 5432,
|
||||||
|
"collection_name": f"mem0_{agent_id}".replace("-", "_")[:50],
|
||||||
|
"embedding_model_dims": EMBEDDING_DIMENSIONS,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _attach_pool_to_vector_store(self, mem0_instance: Any) -> None:
|
||||||
|
"""Move Mem0's runtime vector store onto the shared psycopg2 pool."""
|
||||||
|
vector_store = getattr(mem0_instance, 'vector_store', None)
|
||||||
|
if vector_store is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if getattr(vector_store, 'cur', None):
|
||||||
|
vector_store.cur.close()
|
||||||
|
vector_store.cur = None
|
||||||
|
if getattr(vector_store, 'conn', None) is not None:
|
||||||
|
vector_store.conn.close()
|
||||||
|
vector_store.conn = None
|
||||||
|
vector_store.connection_pool = self._sync_pool
|
||||||
|
|
||||||
|
def _close_telemetry_vector_store(self, mem0_instance: Any) -> None:
|
||||||
|
"""Close Mem0's migration telemetry vector-store connection after init."""
|
||||||
|
vector_store = getattr(mem0_instance, '_telemetry_vector_store', None)
|
||||||
|
if vector_store is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if getattr(vector_store, 'cur', None):
|
||||||
|
vector_store.cur.close()
|
||||||
|
vector_store.cur = None
|
||||||
|
if getattr(vector_store, 'conn', None) is not None:
|
||||||
|
vector_store.conn.close()
|
||||||
|
vector_store.conn = None
|
||||||
|
|
||||||
def _ensure_connection(self, mem0_instance: Any) -> None:
|
def _ensure_connection(self, mem0_instance: Any) -> None:
|
||||||
"""Ensure a Mem0 instance has a database connection before use.
|
"""Ensure a Mem0 instance has a database connection before use.
|
||||||
|
|
||||||
@ -225,8 +268,7 @@ class Mem0Manager:
|
|||||||
if hasattr(vs, 'conn') and vs.conn is None and self._sync_pool:
|
if hasattr(vs, 'conn') and vs.conn is None and self._sync_pool:
|
||||||
vs.conn = self._sync_pool.getconn()
|
vs.conn = self._sync_pool.getconn()
|
||||||
vs.cur = vs.conn.cursor()
|
vs.cur = vs.conn.cursor()
|
||||||
# Ensure the connection_pool reference exists for later return
|
if not hasattr(vs, 'connection_pool') or vs.connection_pool is None:
|
||||||
if hasattr(vs, 'connection_pool') and vs.connection_pool is None:
|
|
||||||
vs.connection_pool = self._sync_pool
|
vs.connection_pool = self._sync_pool
|
||||||
logger.debug("Re-acquired Mem0 database connection from pool")
|
logger.debug("Re-acquired Mem0 database connection from pool")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -249,8 +291,11 @@ class Mem0Manager:
|
|||||||
if hasattr(vs, 'cur') and vs.cur:
|
if hasattr(vs, 'cur') and vs.cur:
|
||||||
vs.cur.close()
|
vs.cur.close()
|
||||||
vs.cur = None
|
vs.cur = None
|
||||||
if hasattr(vs, 'connection_pool') and vs.connection_pool is not None:
|
connection_pool = getattr(vs, 'connection_pool', None)
|
||||||
vs.connection_pool.putconn(vs.conn)
|
if connection_pool is not None:
|
||||||
|
connection_pool.putconn(vs.conn)
|
||||||
|
else:
|
||||||
|
vs.conn.close()
|
||||||
vs.conn = None
|
vs.conn = None
|
||||||
logger.debug("Released Mem0 database connection back to pool")
|
logger.debug("Released Mem0 database connection back to pool")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -336,15 +381,13 @@ class Mem0Manager:
|
|||||||
# Create a custom embedder backed by the external embedding API.
|
# Create a custom embedder backed by the external embedding API.
|
||||||
custom_embedder = CustomMem0Embedding()
|
custom_embedder = CustomMem0Embedding()
|
||||||
|
|
||||||
# Configure Mem0 to use Pgvector
|
# Configure Mem0 to use Pgvector.
|
||||||
|
# Mem0 validates this config strictly, so connection_pool is attached after creation.
|
||||||
|
pgvector_config = self._build_pgvector_config(agent_id)
|
||||||
config_dict = {
|
config_dict = {
|
||||||
"vector_store": {
|
"vector_store": {
|
||||||
"provider": "pgvector",
|
"provider": "pgvector",
|
||||||
"config": {
|
"config": pgvector_config,
|
||||||
"connection_pool": connection_pool,
|
|
||||||
"collection_name": f"mem0_{agent_id}".replace("-", "_")[:50], # Isolate by agent_id
|
|
||||||
"embedding_model_dims": EMBEDDING_DIMENSIONS,
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
# The embedder is replaced immediately after Memory is created.
|
# The embedder is replaced immediately after Memory is created.
|
||||||
"embedder": {
|
"embedder": {
|
||||||
@ -388,6 +431,8 @@ class Mem0Manager:
|
|||||||
|
|
||||||
# Create the Mem0 instance
|
# Create the Mem0 instance
|
||||||
mem = Memory.from_config(config_dict)
|
mem = Memory.from_config(config_dict)
|
||||||
|
self._attach_pool_to_vector_store(mem)
|
||||||
|
self._close_telemetry_vector_store(mem)
|
||||||
logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
|
logger.debug(f"Original embedder type: {type(mem.embedding_model).__name__}")
|
||||||
logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
|
logger.debug(f"Original embedder.embedding_dims: {getattr(mem.embedding_model, 'embedding_dims', 'N/A')}")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user