from __future__ import annotations
from decimal import Decimal
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from rdflib.namespace import XSD
from rdflib.plugins.sparql.datatypes import type_promotion
from rdflib.plugins.sparql.evalutils import _eval, _val
from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.parserutils import CompValue
from rdflib.plugins.sparql.sparql import FrozenBindings, NotBoundError, SPARQLTypeError
from rdflib.term import BNode, Identifier, Literal, URIRef, Variable
"""
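Aggregation functions for SPARQL: one accumulator class per aggregate
(COUNT, SAMPLE, SUM, AVG, MIN, MAX, GROUP_CONCAT) and the Aggregator that
combines them.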
"""
class Accumulator:
    """Abstract base class for the different aggregation functions.

    Subclasses supply ``update`` (fold one row into the running state) and
    ``get_value`` (produce the final result); both are only declared here.
    """

    def __init__(self, aggregation: CompValue):
        """Record the target variable and expression and set up DISTINCT handling.

        :param aggregation: parsed aggregate; its ``res``, ``vars`` and
            ``distinct`` entries are read.
        """
        self.get_value: Callable[[], Optional[Literal]]
        self.update: Callable[[FrozenBindings, "Aggregator"], None]
        self.var = aggregation.res
        self.expr = aggregation.vars
        if not aggregation.distinct:
            # Without DISTINCT every row is used, so shadow ``use_row`` on the
            # instance with the unconditional variant.
            # type error: Cannot assign to a method
            self.use_row = self.dont_care  # type: ignore[method-assign]
            self.distinct = False
        else:
            self.distinct = aggregation.distinct
            # Values already aggregated; only exists for DISTINCT aggregates.
            self.seen: Set[Any] = set()

    def dont_care(self, row: FrozenBindings) -> bool:
        """skips distinct test"""
        return True

    def use_row(self, row: FrozenBindings) -> bool:
        """tests distinct with set

        :returns: ``True`` when the expression's value for ``row`` has not been
            seen yet; ``False`` when it has, or when the expression is unbound
            for this row.
        """
        try:
            return _eval(self.expr, row) not in self.seen
        except NotBoundError:
            # The ``update`` implementations skip unbound values (UNDEF), so
            # the row cannot contribute; do not let the error escape from the
            # DISTINCT check.
            return False

    def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
        """sets final value in bindings"""
        # type error: Incompatible types in assignment (expression has type "Optional[Literal]", target has type "Identifier")
        bindings[self.var] = self.get_value()  # type: ignore[assignment]
class Counter(Accumulator):
    """COUNT aggregate: counts the rows for which the expression is bound."""

    def __init__(self, aggregation: CompValue):
        """Start the count at zero; for ``*`` the whole row is the counted value."""
        super().__init__(aggregation)
        self.value = 0
        if self.expr == "*":
            # cannot eval "*" => always use the full row
            # type error: Cannot assign to a method
            self.eval_row = self.eval_full_row  # type: ignore[assignment]

    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
        """Increment the count for ``row`` unless its expression is unbound."""
        try:
            val = self.eval_row(row)
        except NotBoundError:
            # skip UNDEF
            return
        self.value += 1
        if self.distinct:
            self.seen.add(val)

    def get_value(self) -> Literal:
        """Return the count as a literal."""
        return Literal(self.value)

    def eval_row(self, row: FrozenBindings) -> Identifier:
        """Evaluate the counted expression against ``row``."""
        return _eval(self.expr, row)

    def eval_full_row(self, row: FrozenBindings) -> FrozenBindings:
        """Stand-in for ``eval_row`` when counting ``*``: the row itself."""
        return row

    def use_row(self, row: FrozenBindings) -> bool:
        """DISTINCT test: only rows whose value has not been counted yet.

        :returns: ``False`` for already-seen values and for rows in which the
            expression is unbound.
        """
        try:
            return self.eval_row(row) not in self.seen
        except NotBoundError:
            # ``update`` would skip this row anyway; without this guard the
            # error escapes, e.g. when counting an optional that never matched.
            return False
@overload
def type_safe_numbers(*args: int) -> Tuple[int, ...]:
    ...


@overload
def type_safe_numbers(
    *args: Union[Decimal, float, int]
) -> Tuple[Union[Decimal, float, int], ...]:
    ...


def type_safe_numbers(
    *args: Union[Decimal, float, int]
) -> Tuple[Union[Decimal, float, int], ...]:
    """Return ``args`` in a form that can be added together safely.

    ``Decimal`` and ``float`` cannot be mixed in arithmetic, so when both kinds
    are present every argument is converted to ``float``. Otherwise the
    arguments are returned unchanged (this includes ``Decimal`` values when no
    ``float`` is present).

    :param args: the numbers to combine.
    :returns: a tuple with one entry per argument, in the same order.
    """
    if any(isinstance(arg, float) for arg in args) and any(
        isinstance(arg, Decimal) for arg in args
    ):
        # Materialise a tuple rather than a lazy ``map`` so that both branches
        # return the same, re-iterable type the overloads promise.
        return tuple(float(arg) for arg in args)
    return args
class Sum(Accumulator):
    """SUM aggregate: running total plus the promoted datatype of its inputs."""

    def __init__(self, aggregation: CompValue):
        """Start from a total of zero with no datatype recorded yet."""
        super().__init__(aggregation)
        self.value = 0
        self.datatype: Optional[str] = None

    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
        """Add the expression's value for ``row`` to the total.

        Rows in which the expression is unbound are ignored.
        """
        try:
            term = _eval(self.expr, row)
            known = self.datatype
            # The first value fixes the datatype; later ones are promoted
            # against what has been recorded so far.
            self.datatype = (
                term.datatype
                if known is None
                # type error: Argument 1 to "type_promotion" has incompatible type "str"; expected "URIRef"
                else type_promotion(known, term.datatype)  # type: ignore[arg-type]
            )
            self.value = sum(type_safe_numbers(self.value, numeric(term)))
            if self.distinct:
                self.seen.add(term)
        except NotBoundError:
            # skip UNDEF
            return

    def get_value(self) -> Literal:
        """Return the total as a literal carrying the recorded datatype."""
        return Literal(self.value, datatype=self.datatype)
class Average(Accumulator):
    """AVG aggregate: keeps a running sum, a row count and the promoted datatype."""

    def __init__(self, aggregation: CompValue):
        """Start with an empty sum, a zero count and no datatype."""
        super().__init__(aggregation)
        self.counter = 0
        self.sum = 0
        self.datatype: Optional[str] = None

    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
        """Fold the expression's value for ``row`` into the sum and count.

        Rows that raise ``NotBoundError`` or ``SPARQLTypeError`` are ignored.
        """
        try:
            term = _eval(self.expr, row)
            known = self.datatype
            # ``numeric`` runs first so a non-numeric term is rejected before
            # any state is changed.
            self.sum = sum(type_safe_numbers(self.sum, numeric(term)))
            self.datatype = (
                term.datatype
                if known is None
                # type error: Argument 1 to "type_promotion" has incompatible type "str"; expected "URIRef"
                else type_promotion(known, term.datatype)  # type: ignore[arg-type]
            )
            if self.distinct:
                self.seen.add(term)
            self.counter += 1
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass

    def get_value(self) -> Literal:
        """Return the mean of the accumulated values.

        Yields ``Literal(0)`` when no row was counted, a plain division for
        float/double inputs and a ``Decimal`` division otherwise.
        """
        if self.counter == 0:
            return Literal(0)
        if self.datatype in (XSD.float, XSD.double):
            return Literal(self.sum / self.counter)
        return Literal(Decimal(self.sum) / Decimal(self.counter))
class Extremum(Accumulator):
    """abstract base class for Minimum and Maximum"""

    def __init__(self, aggregation: CompValue):
        """Start with no value and disable the DISTINCT row filter."""
        self.compare: Callable[[Any, Any], Any]
        super().__init__(aggregation)
        self.value: Any = None
        # DISTINCT would not change the value for MIN or MAX
        # type error: Cannot assign to a method
        self.use_row = self.dont_care  # type: ignore[method-assign]

    def set_value(self, bindings: MutableMapping[Variable, Identifier]) -> None:
        """Bind the extremum, leaving ``bindings`` untouched if none was found."""
        if self.value is None:
            # simply do not set if self.value is still None
            return
        bindings[self.var] = Literal(self.value)

    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
        """Keep whichever of the current value and this row's value wins.

        Rows that raise ``NotBoundError`` or ``SPARQLTypeError`` are ignored.
        """
        try:
            candidate = _eval(self.expr, row)
            if self.value is None:
                self.value = candidate
            else:
                # self.compare is implemented by Minimum/Maximum
                self.value = self.compare(self.value, candidate)
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass
# Constrained type variable used by Minimum.compare / Maximum.compare: both
# arguments and the return value are the same one of these four term types.
_ValueT = TypeVar("_ValueT", Variable, BNode, URIRef, Literal)
class Minimum(Extremum):
    """MIN aggregate: supplies the comparison used by ``Extremum.update``."""

    def compare(self, val1: _ValueT, val2: _ValueT) -> _ValueT:
        """Return the smaller of the two terms, ordered by ``_val``."""
        candidates = (val1, val2)
        return min(candidates, key=_val)
class Maximum(Extremum):
    """MAX aggregate: supplies the comparison used by ``Extremum.update``."""

    def compare(self, val1: _ValueT, val2: _ValueT) -> _ValueT:
        """Return the larger of the two terms, ordered by ``_val``."""
        candidates = (val1, val2)
        return max(candidates, key=_val)
class Sample(Accumulator):
    """takes the first eligible value"""

    def __init__(self, aggregation):
        """Set up the accumulator and disable the DISTINCT row filter."""
        super().__init__(aggregation)
        # DISTINCT would not change the value
        self.use_row = self.dont_care

    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
        """Bind the first available value and retire this accumulator.

        Rows in which the expression is unbound are ignored, so a later row
        can still supply the sample.
        """
        try:
            picked = _eval(self.expr, row)
        except NotBoundError:
            return
        else:
            # set the value now
            aggregator.bindings[self.var] = picked
            # and skip this accumulator for future rows
            del aggregator.accumulators[self.var]

    def get_value(self) -> None:
        """Return ``None``; only reached if no row ever supplied a value."""
        # set None if no value was set
        return None
class GroupConcat(Accumulator):
    """GROUP_CONCAT aggregate: joins the string forms of the collected values."""

    def __init__(self, aggregation: CompValue):
        """Prepare the value list and pick the separator.

        :param aggregation: parsed aggregate; its ``separator`` entry is used
            when present, otherwise a single space.
        """
        super().__init__(aggregation)
        # only GROUPCONCAT needs to have a list as accumulator
        self.value: List[Identifier] = []
        # Test for None explicitly: an empty separator is a legitimate request
        # and must not be replaced by the default.
        if aggregation.separator is None:
            self.separator = " "
        else:
            self.separator = aggregation.separator

    def update(self, row: FrozenBindings, aggregator: "Aggregator") -> None:
        """Append the expression's value for ``row`` unless it is unbound."""
        try:
            value = _eval(self.expr, row)
            # skip UNDEF
            if isinstance(value, NotBoundError):
                return
            self.value.append(value)
            if self.distinct:
                self.seen.add(value)
        # skip UNDEF
        # NOTE: It seems like this is not the way undefined values occur, they
        # come through not as exceptions but as values. This is left here
        # however as it may occur in some cases.
        # TODO: Consider removing this.
        except NotBoundError:
            pass

    def get_value(self) -> Literal:
        """Return the collected values joined with the separator."""
        return Literal(self.separator.join(str(v) for v in self.value))
class Aggregator:
    """combines different Accumulator objects"""

    # Maps the parsed aggregate's name to the accumulator class handling it.
    accumulator_classes = {
        "Aggregate_Count": Counter,
        "Aggregate_Sample": Sample,
        "Aggregate_Sum": Sum,
        "Aggregate_Avg": Average,
        "Aggregate_Min": Minimum,
        "Aggregate_Max": Maximum,
        "Aggregate_GroupConcat": GroupConcat,
    }

    def __init__(self, aggregations: List[CompValue]):
        """Create one accumulator per aggregation, keyed by its result variable.

        :raises Exception: if an aggregation's name has no registered class.
        """
        self.bindings: Dict[Variable, Identifier] = {}
        self.accumulators: Dict[str, Accumulator] = {}
        for aggregation in aggregations:
            factory = self.accumulator_classes.get(aggregation.name)
            if factory is None:
                raise Exception("Unknown aggregate function " + aggregation.name)
            self.accumulators[aggregation.res] = factory(aggregation)

    def update(self, row: FrozenBindings) -> None:
        """update all own accumulators"""
        # SAMPLE accumulators may delete themselves while we loop, so walk a
        # snapshot instead of the live dictionary view.
        snapshot = tuple(self.accumulators.values())
        for accumulator in snapshot:
            if not accumulator.use_row(row):
                continue
            accumulator.update(row, self)

    def get_bindings(self) -> Mapping[Variable, Identifier]:
        """calculate and set last values"""
        results = self.bindings
        for accumulator in self.accumulators.values():
            accumulator.set_value(results)
        return results