from rdflib import Literal, XSD
from rdflib.plugins.sparql.evalutils import _eval, NotBoundError, _val
from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.datatypes import type_promotion
from rdflib.plugins.sparql.sparql import SPARQLTypeError
from decimal import Decimal
"""
Aggregation functions
"""
class Accumulator(object):
    """Abstract base class for the different aggregation functions.

    ``set_value`` relies on a ``get_value()`` method that this class does
    not define; concrete accumulators have to provide it.
    """

    def __init__(self, aggregation):
        """Remember what to evaluate and where to store the result.

        :param aggregation: object exposing ``res`` (key the result is
            stored under), ``vars`` (expression evaluated for each row) and
            ``distinct`` (truthy when duplicate values must be ignored).
        """
        self.var = aggregation.res
        self.expr = aggregation.vars
        if not aggregation.distinct:
            # Shadow the method on the instance: without DISTINCT every row
            # is used and no bookkeeping of seen values is needed.
            self.use_row = self.dont_care
            self.distinct = False
        else:
            self.distinct = aggregation.distinct
            # values that were already aggregated (DISTINCT mode only)
            self.seen = set()

    def dont_care(self, row):
        """Skip the distinct test: every row is used."""
        return True

    def use_row(self, row):
        """Test distinctness with the ``seen`` set.

        :param row: the solution row the expression is evaluated against.
        :returns: ``True`` if the row's value was not aggregated before,
            ``False`` if it was or if the expression is unbound in *row*.
        """
        try:
            return _eval(self.expr, row) not in self.seen
        except NotBoundError:
            # An unbound expression contributes nothing to the aggregate;
            # report the row as unusable instead of aborting the whole
            # aggregation with an exception.
            return False

    def set_value(self, bindings):
        """Store the final value in *bindings* under ``self.var``."""
        bindings[self.var] = self.get_value()
class Counter(Accumulator):
    """Accumulator that counts the rows whose expression can be evaluated.

    When the expression is ``"*"`` the whole row is used as the value, so
    every row is counted.
    """

    def __init__(self, aggregation):
        """Start counting at zero; see ``Accumulator.__init__`` for *aggregation*."""
        super(Counter, self).__init__(aggregation)
        self.value = 0
        if self.expr == "*":
            # cannot eval "*" => always use the full row
            self.eval_row = self.eval_full_row

    def update(self, row, aggregator):
        """Count *row* unless its expression is unbound.

        :param row: the solution row to count.
        :param aggregator: the owning aggregator (unused here).
        """
        try:
            val = self.eval_row(row)
        except NotBoundError:
            # skip UNDEF
            return
        self.value += 1
        if self.distinct:
            # remember the value so use_row() rejects later duplicates
            self.seen.add(val)

    def get_value(self):
        """Return the current count wrapped in a ``Literal``."""
        return Literal(self.value)

    def eval_row(self, row):
        """Evaluate the counted expression against *row*."""
        return _eval(self.expr, row)

    def eval_full_row(self, row):
        """Return *row* itself; used in place of ``eval_row`` for ``"*"``."""
        return row

    def use_row(self, row):
        """Return ``True`` if the value of *row* has not been counted yet.

        Goes through ``self.eval_row`` so that the ``"*"`` case compares
        whole rows. A row whose expression is unbound is reported as not
        usable instead of letting ``NotBoundError`` escape and abort the
        aggregation.
        """
        try:
            return self.eval_row(row) not in self.seen
        except NotBoundError:
            # nothing to count for this row (e.g. an unbound variable)
            return False
def type_safe_numbers(*args):
    """Make a mix of ``float`` and ``Decimal`` values safe to add together.

    ``Decimal`` and ``float`` cannot be combined arithmetically, so when
    both kinds occur among *args* every argument is converted to ``float``.

    :param args: the numbers that are about to be combined.
    :returns: a list of floats if both ``float`` and ``Decimal`` values are
        present, otherwise *args* unchanged (a tuple).
    """
    # Test each kind with its own pass over args: a single shared
    # ``map(type, args)`` iterator is exhausted by the first membership
    # test on Python 3, which made the second test unreliable.
    has_float = any(isinstance(arg, float) for arg in args)
    has_decimal = any(isinstance(arg, Decimal) for arg in args)
    if has_float and has_decimal:
        return [float(arg) for arg in args]
    return args
class Sum(Accumulator):
    """Accumulator that adds up the numeric values of an expression."""

    def __init__(self, aggregation):
        """Start with a total of zero and no datatype."""
        super(Sum, self).__init__(aggregation)
        self.value = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Add the value of *row* to the running total.

        :param row: the solution row the expression is evaluated against.
        :param aggregator: the owning aggregator (unused here).
        """
        try:
            term = _eval(self.expr, row)
            # The first value fixes the datatype; later values promote it.
            if self.datatype is None:
                self.datatype = term.datatype
            else:
                self.datatype = type_promotion(self.datatype, term.datatype)
            operands = type_safe_numbers(self.value, numeric(term))
            self.value = sum(operands)
            if self.distinct:
                self.seen.add(term)
        except NotBoundError:
            # skip UNDEF
            pass

    def get_value(self):
        """Return the total as a ``Literal`` carrying the promoted datatype."""
        return Literal(self.value, datatype=self.datatype)
class Average(Accumulator):
    """Accumulator that computes the mean of the numeric values of an expression."""

    def __init__(self, aggregation):
        """Start with an empty sum, a zero count and no datatype."""
        super(Average, self).__init__(aggregation)
        self.counter = 0
        self.sum = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Fold the value of *row* into the running sum and count.

        :param row: the solution row the expression is evaluated against.
        :param aggregator: the owning aggregator (unused here).
        """
        try:
            term = _eval(self.expr, row)
            previous = self.datatype
            # NOTE(review): the sum is updated before the datatype is
            # promoted; if type_promotion raised, the sum would already
            # contain this value while the counter would not.
            self.sum = sum(type_safe_numbers(self.sum, numeric(term)))
            if previous is None:
                self.datatype = term.datatype
            else:
                self.datatype = type_promotion(previous, term.datatype)
            if self.distinct:
                self.seen.add(term)
            self.counter += 1
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass

    def get_value(self):
        """Return the mean as a ``Literal``; ``Literal(0)`` if nothing was counted."""
        if self.counter == 0:
            return Literal(0)
        if self.datatype in (XSD.float, XSD.double):
            return Literal(self.sum / self.counter)
        # every other datatype is averaged with Decimal arithmetic
        return Literal(Decimal(self.sum) / Decimal(self.counter))
class Extremum(Accumulator):
    """Abstract base class for Minimum and Maximum.

    Subclasses supply ``compare(val1, val2)``, which picks the value to keep.
    """

    def __init__(self, aggregation):
        """Start without a value and disable the distinct test."""
        super(Extremum, self).__init__(aggregation)
        self.value = None
        # DISTINCT would not change the value for MIN or MAX
        self.use_row = self.dont_care

    def set_value(self, bindings):
        """Store the extremum in *bindings*; leave it unset if no value was seen."""
        if self.value is None:
            return
        bindings[self.var] = Literal(self.value)

    def update(self, row, aggregator):
        """Keep the value of *row* if it beats the current extremum.

        :param row: the solution row the expression is evaluated against.
        :param aggregator: the owning aggregator (unused here).
        """
        try:
            # the first value is taken as is; afterwards self.compare
            # (implemented by Minimum/Maximum) decides which one to keep
            self.value = (
                _eval(self.expr, row)
                if self.value is None
                else self.compare(self.value, _eval(self.expr, row))
            )
        except (NotBoundError, SPARQLTypeError):
            # skip UNDEF or BNode => SPARQLTypeError
            pass
class Minimum(Extremum):
    """Extremum that keeps the smaller value."""

    def compare(self, val1, val2):
        """Return whichever of *val1* and *val2* has the smaller ``_val`` key."""
        candidates = (val1, val2)
        return min(candidates, key=_val)
class Maximum(Extremum):
    """Extremum that keeps the larger value."""

    def compare(self, val1, val2):
        """Return whichever of *val1* and *val2* has the larger ``_val`` key."""
        candidates = (val1, val2)
        return max(candidates, key=_val)
class Sample(Accumulator):
    """Takes the first eligible value."""

    def __init__(self, aggregation):
        """Disable the distinct test; it cannot affect the sampled value."""
        super(Sample, self).__init__(aggregation)
        # DISTINCT would not change the value
        self.use_row = self.dont_care

    def update(self, row, aggregator):
        """Bind the value of *row* and retire this accumulator.

        :param row: the solution row the expression is evaluated against.
        :param aggregator: the owning aggregator; its ``bindings`` receive
            the value and this accumulator is removed from its
            ``accumulators``.
        """
        try:
            sampled = _eval(self.expr, row)
            # set the value now ...
            aggregator.bindings[self.var] = sampled
            # ... and skip this accumulator for future rows
            del aggregator.accumulators[self.var]
        except NotBoundError:
            # unbound in this row => try again with the next one
            pass

    def get_value(self):
        """Return ``None``: only reached when no row ever supplied a value."""
        return None
class GroupConcat(Accumulator):
    """Accumulator that joins the values of an expression into one string."""

    def __init__(self, aggregation):
        """Prepare an empty value list and pick the separator.

        :param aggregation: as for ``Accumulator.__init__``; additionally
            its ``separator`` attribute is read. ``None`` selects the
            default separator of a single space.
        """
        super(GroupConcat, self).__init__(aggregation)
        # only GROUPCONCAT needs to have a list as accumulator
        self.value = []
        # Compare against None rather than using ``or``: an explicitly
        # requested empty separator must not be replaced by the default.
        if aggregation.separator is None:
            self.separator = " "
        else:
            self.separator = aggregation.separator

    def update(self, row, aggregator):
        """Append the value of *row* to the collected values.

        :param row: the solution row the expression is evaluated against.
        :param aggregator: the owning aggregator (unused here).
        """
        try:
            value = _eval(self.expr, row)
            self.value.append(value)
            if self.distinct:
                self.seen.add(value)
        except NotBoundError:
            # skip UNDEF
            pass

    def get_value(self):
        """Return the collected values joined by the separator as a ``Literal``."""
        # ``u"%s" % (v,)`` produces text on both Python 2 and Python 3; the
        # ``unicode`` builtin used previously does not exist on Python 3.
        return Literal(self.separator.join(u"%s" % (v,) for v in self.value))
class Aggregator(object):
    """Combines different Accumulator objects."""

    # maps the name of an aggregation to the class implementing it
    accumulator_classes = {
        "Aggregate_Count": Counter,
        "Aggregate_Sample": Sample,
        "Aggregate_Sum": Sum,
        "Aggregate_Avg": Average,
        "Aggregate_Min": Minimum,
        "Aggregate_Max": Maximum,
        "Aggregate_GroupConcat": GroupConcat,
    }

    def __init__(self, aggregations):
        """Create one accumulator per aggregation.

        :param aggregations: iterable of objects exposing ``name`` (a key
            of ``accumulator_classes``) and ``res`` (key the result is
            stored under).
        :raises Exception: if an aggregation has an unknown ``name``.
        """
        self.bindings = {}
        self.accumulators = {}
        for a in aggregations:
            accumulator_class = self.accumulator_classes.get(a.name)
            if accumulator_class is None:
                raise Exception("Unknown aggregate function " + a.name)
            self.accumulators[a.res] = accumulator_class(a)

    def update(self, row):
        """Update all own accumulators with *row*."""
        # SAMPLE accumulators may delete themselves from self.accumulators
        # while being updated => iterate over a snapshot list. On Python 3
        # ``values()`` is a live view and deleting during iteration raises
        # RuntimeError.
        for acc in list(self.accumulators.values()):
            if acc.use_row(row):
                acc.update(row, self)

    def get_bindings(self):
        """Calculate and set the last values, then return the bindings."""
        # ``values()`` exists on Python 2 and 3 alike (``itervalues()`` is
        # Python 2 only); nothing is deleted here, so no copy is needed.
        for acc in self.accumulators.values():
            acc.set_value(self.bindings)
        return self.bindings