Source code for rdflib.plugins.sparql.aggregates

from decimal import Decimal

from rdflib import Literal, XSD
from rdflib.plugins.sparql.datatypes import type_promotion
from rdflib.plugins.sparql.evalutils import _eval, NotBoundError, _val
from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.sparql import SPARQLTypeError
from rdflib.term import Identifier

"""
Aggregation functions
"""


class Accumulator(object):
    """Abstract base class for the different aggregation functions.

    Subclasses provide ``update(row, aggregator)`` and ``get_value()``; this
    base class holds the target variable and expression, and does the
    DISTINCT bookkeeping.
    """

    def __init__(self, aggregation):
        # variable the aggregate result is bound to
        self.var = aggregation.res
        # expression evaluated against each solution row
        self.expr = aggregation.vars
        if not aggregation.distinct:
            # no DISTINCT: every row is eligible, skip the membership test
            self.use_row = self.dont_care
            self.distinct = False
        else:
            self.distinct = aggregation.distinct
            # values already aggregated; subclasses add to it in update()
            self.seen = set()

    def dont_care(self, row):
        """Skip the distinct test: every row is used."""
        return True

    def use_row(self, row):
        """Return True if the row's value has not been aggregated yet (DISTINCT).

        A row whose expression is unbound is rejected here. The ``update``
        methods that rely on this test skip UNDEF values anyway, and letting
        ``NotBoundError`` escape from this pre-check would abort the whole
        aggregation instead of skipping the row.
        """
        try:
            return _eval(self.expr, row) not in self.seen
        except NotBoundError:
            return False

    def set_value(self, bindings):
        """Store the final value (from ``get_value``) in ``bindings``."""
        bindings[self.var] = self.get_value()
class Counter(Accumulator):
    """COUNT aggregate: counts the rows whose expression is bound.

    ``COUNT(*)`` uses the whole row as the counted value, so that
    ``COUNT(DISTINCT *)`` counts distinct solution rows.
    """

    def __init__(self, aggregation):
        super(Counter, self).__init__(aggregation)
        self.value = 0
        if self.expr == "*":
            # cannot eval "*" => always use the full row
            self.eval_row = self.eval_full_row

    def update(self, row, aggregator):
        """Count ``row`` unless its expression is unbound."""
        try:
            val = self.eval_row(row)
        except NotBoundError:
            # skip UNDEF
            return
        self.value += 1
        if self.distinct:
            self.seen.add(val)

    def get_value(self):
        """Return the count as a Literal."""
        return Literal(self.value)

    def eval_row(self, row):
        """Evaluate the counted expression against ``row``."""
        return _eval(self.expr, row)

    def eval_full_row(self, row):
        """Used for ``COUNT(*)``: the row itself is the counted value."""
        return row

    def use_row(self, row):
        """DISTINCT test; rows whose expression is unbound are not counted."""
        try:
            return self.eval_row(row) not in self.seen
        except NotBoundError:
            # e.g. COUNT(DISTINCT ?x) over OPTIONAL rows where ?x is unbound:
            # update() would skip the row anyway, so reject it here instead
            # of letting the error abort the aggregation
            return False
def type_safe_numbers(*args):
    """Make the arguments safe to add together with ``sum``.

    Python cannot add ``float`` and ``Decimal`` directly. When both kinds are
    present, every argument is converted with ``float`` and returned as a lazy
    ``map``. Otherwise the argument tuple is returned untouched.
    """
    saw_float = False
    saw_decimal = False
    for arg in args:
        saw_float = saw_float or isinstance(arg, float)
        saw_decimal = saw_decimal or isinstance(arg, Decimal)
    if saw_float and saw_decimal:
        return map(float, args)
    return args
class Sum(Accumulator):
    """SUM aggregate over numeric literals, tracking the promoted datatype."""

    def __init__(self, aggregation):
        super(Sum, self).__init__(aggregation)
        self.value = 0
        # promoted datatype of the values summed so far (None until the first)
        self.datatype = None

    def update(self, row, aggregator):
        """Add the row's numeric value to the running sum; UNDEF is skipped."""
        try:
            value = _eval(self.expr, row)
            # Validate first. numeric() rejects non-numeric terms with
            # SPARQLTypeError; calling it before value.datatype / type_promotion
            # means such a term fails with that error and leaves
            # self.datatype and self.value untouched.
            number = numeric(value)
            dt = self.datatype
            if dt is None:
                dt = value.datatype
            else:
                dt = type_promotion(dt, value.datatype)
            self.datatype = dt
            self.value = sum(type_safe_numbers(self.value, number))
            if self.distinct:
                self.seen.add(value)
        except NotBoundError:
            # skip UNDEF
            pass

    def get_value(self):
        """Return the sum as a Literal typed with the promoted datatype."""
        return Literal(self.value, datatype=self.datatype)
class Average(Accumulator):
    """AVG aggregate: arithmetic mean of the numeric values in the group.

    UNDEF values and values rejected with SPARQLTypeError are skipped and do
    not count towards the denominator.
    """

    def __init__(self, aggregation):
        super(Average, self).__init__(aggregation)
        self.counter = 0
        self.sum = 0
        self.datatype = None

    def update(self, row, aggregator):
        """Fold one row's value into the running sum and count."""
        try:
            value = _eval(self.expr, row)
            # Compute everything before committing, so a failure part-way
            # (e.g. type_promotion rejecting the datatype pair) cannot leave
            # sum, datatype and counter out of step with each other.
            new_sum = sum(type_safe_numbers(self.sum, numeric(value)))
            dt = self.datatype
            if dt is None:
                dt = value.datatype
            else:
                dt = type_promotion(dt, value.datatype)
            self.sum = new_sum
            self.datatype = dt
            if self.distinct:
                self.seen.add(value)
            self.counter += 1
        # skip UNDEF or BNode => SPARQLTypeError
        except NotBoundError:
            pass
        except SPARQLTypeError:
            pass

    def get_value(self):
        """Return the mean; an empty group averages to 0."""
        if self.counter == 0:
            return Literal(0)
        if self.datatype in (XSD.float, XSD.double):
            # Pass the tracked datatype: Literal() of a bare Python float is
            # typed xsd:double, which would silently widen an xsd:float average.
            return Literal(self.sum / self.counter, datatype=self.datatype)
        else:
            return Literal(Decimal(self.sum) / Decimal(self.counter))
class Extremum(Accumulator):
    """Abstract base class for Minimum and Maximum.

    Subclasses implement ``compare(val1, val2)`` returning the value to keep.
    """

    def __init__(self, aggregation):
        super(Extremum, self).__init__(aggregation)
        self.value = None
        # DISTINCT would not change the value for MIN or MAX
        self.use_row = self.dont_care

    def set_value(self, bindings):
        """Bind the extremum; leave the variable unbound if no value was seen."""
        value = self.value
        if value is None:
            # simply do not set if self.value is still None
            return
        # Bind RDF terms unchanged: wrapping an IRI or blank node in Literal()
        # would turn MIN/MAX over resources into a plain string literal.
        if not isinstance(value, Identifier):
            value = Literal(value)
        bindings[self.var] = value

    def update(self, row, aggregator):
        """Keep the better of the current value and the row's value."""
        try:
            if self.value is None:
                self.value = _eval(self.expr, row)
            else:
                # self.compare is implemented by Minimum/Maximum
                self.value = self.compare(self.value, _eval(self.expr, row))
        # skip UNDEF or BNode => SPARQLTypeError
        except NotBoundError:
            pass
        except SPARQLTypeError:
            pass
class Minimum(Extremum):
    """MIN aggregate: keeps the smaller value under the ``_val`` ordering key."""

    def compare(self, val1, val2):
        """Return the smaller of the two values; ``val1`` wins ties."""
        key1 = _val(val1)
        key2 = _val(val2)
        return val2 if key2 < key1 else val1
class Maximum(Extremum):
    """MAX aggregate: keeps the larger value under the ``_val`` ordering key."""

    def compare(self, val1, val2):
        """Return the larger of the two values; ``val1`` wins ties."""
        key1 = _val(val1)
        key2 = _val(val2)
        return val2 if key2 > key1 else val1
class Sample(Accumulator):
    """SAMPLE aggregate: keeps the first value that evaluates while bound."""

    def __init__(self, aggregation):
        super(Sample, self).__init__(aggregation)
        # any value will do, so DISTINCT is irrelevant
        self.use_row = self.dont_care

    def update(self, row, aggregator):
        """Bind the first bound value and retire this accumulator."""
        try:
            picked = _eval(self.expr, row)
        except NotBoundError:
            # unbound in this row: wait for a later one
            return
        aggregator.bindings[self.var] = picked
        # a value has been found: remove this accumulator so that later rows
        # and the final get_bindings() pass leave the binding alone
        del aggregator.accumulators[self.var]

    def get_value(self):
        """Only reached when no row produced a value: bind None."""
        return None
class GroupConcat(Accumulator):
    """GROUP_CONCAT aggregate: joins the string forms of the group's values."""

    def __init__(self, aggregation):
        super(GroupConcat, self).__init__(aggregation)
        # only GROUPCONCAT needs to have a list as accumulator
        self.value = []
        # Compare against None rather than relying on truthiness: an explicit
        # empty separator (separator="") is falsy and must not be replaced by
        # the default single space.
        if aggregation.separator is None:
            self.separator = " "
        else:
            self.separator = aggregation.separator

    def update(self, row, aggregator):
        """Collect the row's value; UNDEF is skipped."""
        try:
            value = _eval(self.expr, row)
            # NOTE(review): defensive; assumes _eval may hand back an unbound
            # error as a value instead of raising it. Such an object must not
            # end up in the concatenated string.
            if isinstance(value, NotBoundError):
                return
            self.value.append(value)
            if self.distinct:
                self.seen.add(value)
        # skip UNDEF
        except NotBoundError:
            pass

    def get_value(self):
        """Return the collected values' str() forms joined by the separator."""
        return Literal(self.separator.join(str(v) for v in self.value))
class Aggregator(object):
    """Runs every aggregate of a group side by side, one Accumulator each."""

    # maps the algebra's aggregate name to its Accumulator implementation
    accumulator_classes = {
        "Aggregate_Count": Counter,
        "Aggregate_Sample": Sample,
        "Aggregate_Sum": Sum,
        "Aggregate_Avg": Average,
        "Aggregate_Min": Minimum,
        "Aggregate_Max": Maximum,
        "Aggregate_GroupConcat": GroupConcat,
    }

    def __init__(self, aggregations):
        self.bindings = {}
        self.accumulators = {}
        for aggregation in aggregations:
            factory = self.accumulator_classes.get(aggregation.name)
            if factory is None:
                raise Exception("Unknown aggregate function " + aggregation.name)
            self.accumulators[aggregation.res] = factory(aggregation)

    def update(self, row):
        """Feed one solution row to every accumulator that accepts it."""
        # Sample deletes itself from self.accumulators during update(), so
        # iterate over a snapshot rather than the live dict view.
        snapshot = tuple(self.accumulators.values())
        for accumulator in snapshot:
            if not accumulator.use_row(row):
                continue
            accumulator.update(row, self)

    def get_bindings(self):
        """Let each remaining accumulator write its result; return the bindings."""
        for accumulator in self.accumulators.values():
            accumulator.set_value(self.bindings)
        return self.bindings