"""This module defines specific functions for MariaDB dialect."""
from sqlalchemy.ext.compiler import compiles
from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import compile_bin_literal
from geoalchemy2.admin.dialects.mysql import after_create # noqa
from geoalchemy2.admin.dialects.mysql import after_drop # noqa
from geoalchemy2.admin.dialects.mysql import before_create # noqa
from geoalchemy2.admin.dialects.mysql import before_drop # noqa
from geoalchemy2.admin.dialects.mysql import reflect_geometry_column # noqa
from geoalchemy2.elements import WKBElement
from geoalchemy2.elements import WKTElement
def _cast(param):
    """Normalize a bound parameter that carries WKB data.

    ``memoryview`` values are converted with ``tobytes()``, ``bytes`` values
    are wrapped in a :class:`WKBElement`, and any :class:`WKBElement` is
    replaced by the ``desc`` attribute of its ``as_wkb()`` result. Every other
    value is returned unchanged.

    Args:
        param: The parameter value to normalize.

    Returns:
        The converted value, or ``param`` itself when no conversion applies.
    """
    raw = param.tobytes() if isinstance(param, memoryview) else param
    wrapped = WKBElement(raw) if isinstance(raw, bytes) else raw
    if not isinstance(wrapped, WKBElement):
        return wrapped
    return wrapped.as_wkb().desc
[docs]
def before_cursor_execute(
    conn, cursor, statement, parameters, context, executemany, convert=True
):  # noqa: D417
    """Event handler to cast the parameters properly.

    Each top-level value of ``parameters`` is passed through ``_cast``. A
    ``tuple`` or ``list`` is rebuilt as a new ``tuple``; a ``dict`` is updated
    in place. Any other kind of ``parameters`` (including ``None``) is
    returned untouched.

    Note:
        Only top-level values are cast. Containers nested inside
        ``parameters`` are handed to ``_cast`` as a whole and therefore come
        back unchanged.

    Args:
        conn: Not used by this handler.
        cursor: Not used by this handler.
        statement: The statement, returned as is.
        parameters: The parameters to cast.
        context: Not used by this handler.
        executemany: Not used by this handler.
        convert (bool): Trigger the conversion.

    Returns:
        tuple: ``(statement, parameters)`` with the parameters possibly cast.
    """
    if not convert:
        return statement, parameters

    if isinstance(parameters, dict):
        # Keys are only re-assigned, never added or removed.
        for key, value in list(parameters.items()):
            parameters[key] = _cast(value)
    elif isinstance(parameters, (tuple, list)):
        parameters = tuple([_cast(item) for item in parameters])

    return statement, parameters
# Default function-name mappings for the "mariadb" dialect: each key is the name
# of an attribute of ``geoalchemy2.functions`` and each value is the function
# name emitted in its place in the compiled SQL.
_MARIADB_FUNCTIONS = {
    "ST_AsEWKB": "ST_AsBinary",
}
def _compiles_mariadb(cls, fn):
    """Register a "mariadb" compilation rule that renames a function.

    Args:
        cls: Name of the attribute to look up on ``geoalchemy2.functions``.
        fn: Function name emitted in the generated SQL.
    """

    def _compile_mariadb(element, compiler, **kw):
        # The arguments are rendered as usual; only the function name changes.
        arguments = compiler.process(element.clauses, **kw)
        return f"{fn}({arguments})"

    target = getattr(functions, cls)
    compiles(target, "mariadb")(_compile_mariadb)
[docs]
def register_mariadb_mapping(mapping):
    """Register compilation mappings for the given functions.

    Every ``(key, value)`` pair of ``mapping`` is forwarded to
    ``_compiles_mariadb``.

    Args:
        mapping: Should have the following form::

            {
                "function_name_1": "mariadb_function_name_1",
                "function_name_2": "mariadb_function_name_2",
                ...
            }
    """
    for pair in mapping.items():
        _compiles_mariadb(*pair)
# Install the default mappings as soon as this module is imported.
register_mariadb_mapping(_MARIADB_FUNCTIONS)
def _compile_GeomFromText_MariaDB(element, compiler, **kw):
    """Compile a geometry-from-text call as ``ST_GeomFromText``.

    The SRID is read from a :class:`WKTElement` built from the value of the
    first clause. When that SRID is not positive, or when reading it fails for
    any reason, ``element.type.srid`` is used instead. The SRID is appended as
    an extra argument only when it is positive.

    Args:
        element: The function element being compiled.
        compiler: The compiler used to render the clauses.
        **kw: Extra keyword arguments forwarded to ``compiler.process``.

    Returns:
        str: The rendered ``ST_GeomFromText(...)`` expression.
    """
    rendered_args = compiler.process(element.clauses, **kw)

    try:
        first_clause = list(element.clauses)[0]
        srid = WKTElement(first_clause.value).srid
        if srid <= 0:
            srid = element.type.srid
    except Exception:
        # Best effort: whatever went wrong above, fall back to the type's SRID.
        srid = element.type.srid

    if srid > 0:
        return f"ST_GeomFromText({rendered_args}, {srid})"
    return f"ST_GeomFromText({rendered_args})"
def _compile_GeomFromWKB_MariaDB(element, compiler, **kw):
    """Compile a geometry-from-WKB call as ``ST_GeomFromWKB(unhex(...))``.

    The SRID is the ``value`` of the second clause when available, otherwise
    ``element.type.srid``. It is appended as an extra argument only when it is
    positive. When ``literal_binds`` is truthy in ``kw``, the first clause is
    passed through ``compile_bin_literal`` before being rendered.

    Args:
        element: The function element being compiled.
        compiler: The compiler used to render the binary clause.
        **kw: Extra keyword arguments forwarded to ``compiler.process``.

    Returns:
        str: The rendered ``ST_GeomFromWKB(...)`` expression.
    """
    args = list(element.clauses)

    # An explicit SRID clause takes precedence over the SRID of the type.
    try:
        srid = args[1].value
    except (IndexError, TypeError, ValueError):
        srid = element.type.srid

    wkb_arg = args[0]
    if kw.get("literal_binds", False):
        wkb_arg = compile_bin_literal(wkb_arg)

    rendered = f"unhex({compiler.process(wkb_arg, **kw)})"
    srid_part = f", {srid}" if srid > 0 else ""
    return f"ST_GeomFromWKB({rendered}{srid_part})"
@compiles(functions.ST_GeomFromText, "mariadb")  # type: ignore
def _MariaDB_ST_GeomFromText(element, compiler, **kw):
    """Compile ``ST_GeomFromText`` for the "mariadb" dialect.

    Delegates to ``_compile_GeomFromText_MariaDB``.
    """
    return _compile_GeomFromText_MariaDB(element, compiler, **kw)
@compiles(functions.ST_GeomFromEWKT, "mariadb")  # type: ignore
def _MariaDB_ST_GeomFromEWKT(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKT`` for the "mariadb" dialect.

    Delegates to ``_compile_GeomFromText_MariaDB``, i.e. it is compiled the
    same way as ``ST_GeomFromText``.
    """
    return _compile_GeomFromText_MariaDB(element, compiler, **kw)
@compiles(functions.ST_GeomFromWKB, "mariadb")  # type: ignore
def _MariaDB_ST_GeomFromWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromWKB`` for the "mariadb" dialect.

    Delegates to ``_compile_GeomFromWKB_MariaDB``.
    """
    return _compile_GeomFromWKB_MariaDB(element, compiler, **kw)
@compiles(functions.ST_GeomFromEWKB, "mariadb")  # type: ignore
def _MariaDB_ST_GeomFromEWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKB`` for the "mariadb" dialect.

    Delegates to ``_compile_GeomFromWKB_MariaDB``, i.e. it is compiled the
    same way as ``ST_GeomFromWKB``.
    """
    return _compile_GeomFromWKB_MariaDB(element, compiler, **kw)