Source code for boolean_parser.mixins.sqla

# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Filename: sqla.py
# Project: mixins
# Author: Brian Cherinka
# Created: Wednesday, 13th February 2019 3:49:07 pm
# License: BSD 3-clause "New" or "Revised" License
# Copyright (c) 2019 Brian Cherinka
# Last Modified: Sunday, 17th February 2019 3:45:21 pm
# Modified By: Brian Cherinka


from __future__ import absolute_import, division, print_function

import decimal
import inspect
from datetime import date, datetime
from operator import eq, ge, gt, le, lt, ne

from sqlalchemy import bindparam, func
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.orm.util import AliasedClass
from sqlalchemy.sql import between, sqltypes

from boolean_parser.parsers.base import BooleanParserException


# Maps the textual comparison operators accepted by the parser onto the
# equivalent functions from the ``operator`` module.  Both ``=`` and ``==``
# resolve to ``eq``.
opdict = {
    '<=': le,
    '>=': ge,
    '>': gt,
    '<': lt,
    '!=': ne,
    '==': eq,
    '=': eq,
}

class SQLAMixin(object):
    ''' A Mixin class to apply SQLAlchemy filter parsing

    This mixin adds a ``filter`` method to the parsed result which converts
    the parsed string object into an appropriate SQLAlchemy filter condition
    to be used in SQLAlchemy queries.

    The methods read the following attributes from the host object:
    ``name``, ``base``, ``fullname``, ``operator``, ``value`` and, when
    present, ``value2``.
    '''

    @staticmethod
    def _is_null(value):
        ''' Check whether a value is the literal string "null"

        The comparison is case-insensitive.  Non-string values are never
        considered null, which lets natively typed values pass through the
        casting helpers without raising ``AttributeError``.

        Parameters:
            value (object):
                The value to test

        Returns:
            True if value is a string equal to "null" ignoring case
        '''
        return isinstance(value, str) and value.lower() == 'null'

    def _check_models(self, classes):
        ''' Check the input modelclass format

        Checks if input classes is a module of modelclasses, a list of
        modelclasses or a single ModelClass and returns a list of
        ModelClass objects.

        Parameters:
            classes (object):
                A ModelClass module, list of models, or single ModelClass

        Returns:
            A list of ModelClasses

        Raises:
            AssertionError: if any item is neither a declarative model class
                nor an aliased class
        '''
        if inspect.ismodule(classes):
            # an entire module of classes; keep only those declaring a table
            models = [member for _, member in inspect.getmembers(classes, inspect.isclass)
                      if hasattr(member, '__tablename__')]
        elif isinstance(classes, (list, tuple)):
            # a list of ModelClasses
            models = classes
        else:
            # assume a single ModelClass
            models = [classes]

        # check for proper modelclasses
        allmeta = all(isinstance(m, (DeclarativeMeta, AliasedClass)) for m in models)
        assert allmeta is True, 'All input classes must be of type SQLAlchemy ModelClasses'

        return models

    def _get_field(self, modelclass, field_name, base_name=None):
        ''' Return a SQLAlchemy attribute from a field name.

        Checks that a given model contains the named field.

        Parameters:
            modelclass (ModelClass):
                A SQLAlchemy ModelClass
            field_name (str):
                The database field name
            base_name (str):
                The database table name

        Returns:
            An SQLA instrumented attribute object, or None if not found
        '''
        # Handle flat field names such as 'name'
        if not base_name:
            return getattr(modelclass, field_name, None)

        # Handle hierarchical field names such as 'parent.name'
        field = None

        # Match alias name
        if isinstance(modelclass, AliasedClass) and \
                base_name == modelclass._aliased_insp.name.lower():
            field = getattr(modelclass, field_name, None)

        # Match table name; note this is a substring match on the table name
        if not field and base_name in modelclass.__tablename__:
            field = getattr(modelclass, field_name, None)

        return field

    def filter(self, modelclass):
        ''' Return the condition as an SQLalchemy query filter condition

        Loops over all models and creates a filter condition for the first
        model that provides a usable attribute for the parsed field name.

        Parameters:
            modelclass (objects):
                A set of ModelClasses to use in the filter condition

        Returns:
            A SQL query filter condition

        Raises:
            AssertionError: if modelclass is None or contains non-model items
            BooleanParserException: if no models are supplied, or the field
                is not found
        '''
        assert modelclass is not None, 'No input found'

        models = self._check_models(modelclass)

        # An empty model list previously fell through to an UnboundLocalError
        # because the loop below never bound ``model`` or ``field``.
        if not models:
            raise BooleanParserException(
                f'No input models found to search for field {self.name}')

        model = field = None
        for model in models:
            # get the SQLA instrumented attribute
            field = self._get_field(model, self.name, base_name=self.base)

            # if there is a usable attribute then stop and use that model
            if field and hasattr(field, 'type') and hasattr(field, 'ilike'):
                break

        # raise if no attribute found; the last model checked is reported
        if not field:
            raise BooleanParserException(
                f'Table {model.__tablename__} does not have field {self.name}')

        # produce the SQLA filter condition
        return self._filter_one(model, field=field, condition=None)

    def _filter_one(self, model, field=None, condition=None):
        ''' Create a single SQLAlchemy filter condition

        Parameters:
            model (ModelClass):
                The SQLAlchemy ModelClass owning the field
            field (SQLA attribute):
                The SQLA instrumented attribute to filter on
            condition (object):
                A fallback condition returned when there is no field or the
                operator is not recognised

        Returns:
            A SQLAlchemy filter condition
        '''
        # if no field present return the original condition
        if not field:
            return condition

        # Prepare field and value
        lower_field, lower_value, lower_value_2 = self._bind_and_lower_value(field)

        # Handle postgresql arrays if any
        if isinstance(field.type, postgresql.ARRAY):
            return field.any(self.value, operator=opdict[self.operator])

        # Handle scalar values.  self.name is the parameter name and
        # lower_field is Table.parameterName
        operator = self.operator
        if operator in ('<', '<=', '>', '>='):
            condition = opdict[operator](lower_field, lower_value)
        elif operator == '!=':
            raw_field = getattr(model, self.name)
            if self._is_null(self.value):
                # "!= null" maps to IS NOT NULL
                condition = raw_field.isnot(None)
            else:
                condition = lower_field != lower_value
        elif operator in ('=', '=='):
            is_string = isinstance(
                field.type, (sqltypes.TEXT, sqltypes.VARCHAR, sqltypes.String))
            raw_field = getattr(model, self.name)
            value = self.value
            if self._is_null(value):
                # "= null" maps to IS NULL for every field type
                condition = raw_field.is_(None)
            elif not is_string or operator == '==':
                # non-string fields, and strings compared with "==",
                # use straight equality
                condition = lower_field == lower_value
            elif '*' in value:
                # for strings a single "=" maps to LIKE
                #   x=5* -> x LIKE '5%' (x starts with 5)
                #   x=*5 -> x LIKE '%5' (x ends with 5)
                value = value.replace('*', '%')
                condition = lower_field.ilike(bindparam(self.fullname, value))
            else:
                #   x=5 -> x LIKE '%5%' (x contains 5)
                condition = lower_field.ilike('%' + bindparam(self.fullname, value) + '%')
        elif operator == 'between':
            # between condition
            condition = between(lower_field, lower_value, lower_value_2)
        elif operator in ['&', '|']:
            # bitwise operations
            condition = lower_field.op(operator)(lower_value) > 0

        return condition

    def _bind_and_lower_value(self, field):
        ''' Bind and lower the value based on field type

        Parameters:
            field (SQLA attribute):
                SQLA instrumented attribute

        Returns:
            A tuple of the (possibly lowercased) field, the bound value and
            the second bound value.  The second bound value is None when the
            object has no ``value2`` attribute.
        '''
        # python types whose values are bound as-is, without lowercasing
        ftypes = (float, bool, int, decimal.Decimal, date, datetime)
        fieldtype = field.type.python_type
        keep_case = fieldtype in ftypes

        # format the value and bind it to the parameter name
        value, lower_field = self._format_value(self.value, fieldtype, field)
        boundvalue = bindparam(self.fullname, value, unique=True)
        lower_value = boundvalue if keep_case else func.lower(boundvalue)

        lower_value_2 = None
        if hasattr(self, 'value2'):
            value2, lower_field = self._format_value(self.value2, fieldtype, field)
            boundvalue2 = bindparam(self.fullname, value2, unique=True)
            lower_value_2 = boundvalue2 if keep_case else func.lower(boundvalue2)

        return lower_field, lower_value, lower_value_2

    def _format_value(self, value, fieldtype, field):
        ''' Formats the value based on the fieldtype

        Formats the value to proper numerical type and lowercases the field
        for string fields.

        Parameters:
            value (str):
                The conditional value to format
            fieldtype (object):
                The python field type
            field (SQLA attribute):
                SQLA instrumented attribute

        Returns:
            The formatted value and lowercase field
        '''
        # field python type -> type the value is cast to; Decimal fields are
        # cast to float
        cast_types = {
            float: float,
            decimal.Decimal: float,
            int: int,
            bool: bool,
            date: date,
            datetime: datetime,
        }
        datatype = cast_types.get(fieldtype)

        # any other type is treated as text: lowercase the field, keep value
        if datatype is None:
            return value, func.lower(field)

        return self._cast_value(value, datatype=datatype), field

    def _to_bool(self, value):
        ''' Cast value to Boolean.

        Parameters:
            value (str|bool):
                The value to format

        Returns:
            True from inputs "true", "t", "1", "yes" and False from inputs
            "false", "f", "0", "no", ignoring case.  Booleans are returned
            unchanged.

        Raises:
            ValueError: if value is not a string or boolean, or is not one
                of the recognised strings
        '''
        valid = {
            "true": True, "t": True, "1": True, "yes": True,
            "false": False, "f": False, "0": False, "no": False,
        }

        if isinstance(value, bool):
            return value

        if not isinstance(value, str):
            raise ValueError("Invalid literal for boolean. Not a string or boolean.")

        try:
            return valid[value.lower()]
        except KeyError:
            raise ValueError('Invalid literal for boolean: "%s"' % value) from None

    def _to_date(self, value):
        ''' Cast value to Date.

        Parameters:
            value (str|date):
                The value to format.  Should be an ISO 8601 date string such
                as '2011-11-04' or '2011-11-04T00:05:23'

        Returns:
            The value as a date object.  Values that are already date
            instances are returned unchanged.

        Raises:
            ValueError: if value is not a string or date, or cannot be parsed
        '''
        if isinstance(value, date):
            return value

        if not isinstance(value, str):
            raise ValueError("Invalid literal for date. Not a string or date.")

        try:
            # When casting to date, we don't care about time, so only take
            # the first 10 characters of the string
            return date.fromisoformat(value[:10])
        except ValueError as err:
            raise ValueError('Could not parse date from string: "%s"' % value) from err

    def _to_datetime(self, value):
        ''' Cast value to Datetime.

        Parameters:
            value (str|datetime):
                The value to format.  Should be an ISO 8601 date string such
                as '2011-11-04' or '2011-11-04T00:05:23'

        Returns:
            The value as a datetime object.  Values that are already
            datetime instances are returned unchanged.

        Raises:
            ValueError: if value is not a string or datetime, or cannot be
                parsed
        '''
        if isinstance(value, datetime):
            return value

        if not isinstance(value, str):
            raise ValueError("Invalid literal for datetime. Not a string or datetime.")

        try:
            return datetime.fromisoformat(value)
        except ValueError as err:
            raise ValueError('Could not parse datetime from string: "%s"' % value) from err

    def _cast_value(self, value, datatype=float):
        ''' Cast a value to a specific Python type

        Parameters:
            value (str|int|float|bool|date|datetime):
                The value to cast
            datatype (object):
                The target type.  One of float, int, bool, date or datetime

        Returns:
            The value explicitly cast to the requested type, or the string
            'null' when value is "null" in any letter case

        Raises:
            AssertionError: if datatype is not a supported type
            BooleanParserException: if the value cannot be cast
        '''
        assert datatype in [float, int, bool, date, datetime], \
            'datatype must be one of float, int, bool, date or datetime'

        # the null literal is passed through, normalised to lowercase
        if self._is_null(value):
            return 'null'

        converters = {
            bool: self._to_bool,
            date: self._to_date,
            datetime: self._to_datetime,
        }
        cast = converters.get(datatype, datatype)

        try:
            return cast(value)
        except (ValueError, TypeError, SyntaxError) as err:
            # TypeError covers values such as None that int()/float() reject
            raise BooleanParserException(
                f'Field {self.name} expects a {datatype.__name__} value. '
                f'Received {value} instead.') from err