diff --git a/datajoint/fetch.py b/datajoint/fetch.py index 961764cd8..56c88467d 100644 --- a/datajoint/fetch.py +++ b/datajoint/fetch.py @@ -8,6 +8,7 @@ from datajoint import DataJointError from . import key as PRIMARY_KEY from collections import abc +from . import config def prepare_attributes(relation, item): @@ -35,7 +36,9 @@ def ret(*args, **kwargs): return ret + class Fetch: + def __init__(self, relation): if isinstance(relation, Fetch): # copy constructor self.behavior = dict(relation.behavior) @@ -46,7 +49,6 @@ def __init__(self, relation): ) self._relation = relation - @copy_first def order_by(self, *args): if len(args) > 0: @@ -162,6 +164,23 @@ def __getitem__(self, item): ] return return_values[0] if single_output else return_values + def __repr__(self): + limit = config['display.limit'] + width = config['display.width'] + rel = self._relation.project(*self._relation.heading.non_blobs) # project out blobs + template = '%%-%d.%ds' % (width, width) + columns = rel.heading.names + repr_string = ' '.join([template % column for column in columns]) + '\n' + repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n' + for tup in rel.fetch(limit=limit): + repr_string += ' '.join([template % column for column in tup]) + '\n' + if len(rel) > limit: + repr_string += '...\n' + repr_string += ' (%d tuples)\n' % len(rel) + return repr_string + + def __len__(self): + return len(self._relation) class Fetch1: def __init__(self, relation): diff --git a/datajoint/relation.py b/datajoint/relation.py index 1dd98f0bd..f69c86cc2 100644 --- a/datajoint/relation.py +++ b/datajoint/relation.py @@ -106,6 +106,9 @@ def descendants(self): for table in self.connection.erm.get_descendants(self.full_table_name)) return [relation for relation in relations if relation.is_declared] + def _repr_helper(self): + return "%s.%s()" % (self.__module__, self.__class__.__name__) + # --------- SQL functionality --------- # @property def is_declared(self): @@ -282,6 +285,9 @@ def __init__(self, connection, full_table_name, definition=None, context=None): self._definition = definition self._context = context + def __repr__(self): + return "FreeRelation(`%s`.`%s`)" % (self.database, self._table_name) + @property def definition(self): return self._definition diff --git a/datajoint/relational_operand.py b/datajoint/relational_operand.py index 769dd7986..581e7365c 100644 --- a/datajoint/relational_operand.py +++ b/datajoint/relational_operand.py @@ -143,7 +143,17 @@ def __sub__(self, restriction): """ inverted restriction aka antijoin """ - return self & Not(restriction) + return \ + self & Not(restriction) + + def _repr_helper(self): + return "None" + + def __repr__(self): + ret = self._repr_helper() + if self._restrictions: + ret += ' & %r' % self._restrictions + return ret # ------ data retrieval methods ----------- @@ -190,21 +200,6 @@ def cursor(self, offset=0, limit=None, order_by=None, as_dict=False): logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) - def __repr__(self): - limit = config['display.limit'] - width = config['display.width'] - rel = self.project(*self.heading.non_blobs) # project out blobs - template = '%%-%d.%ds' % (width, width) - columns = rel.heading.names - repr_string = ' '.join([template % column for column in columns]) + '\n' - repr_string += ' '.join(['+' + '-' * (width - 2) + '+' for _ in columns]) + '\n' - for tup in rel.fetch(limit=limit): - repr_string += ' '.join([template % column for column in tup]) + '\n' - if len(self) > limit: - repr_string += '...\n' - repr_string += ' (%d tuples)\n' % len(self) - return repr_string - @property def fetch1(self): return Fetch1(self) @@ -274,6 +269,9 @@ def __init__(self, arg1, arg2): self._arg2 = Subquery(arg1) if arg2.heading.computed else arg2 self._restrictions = self._arg1.restrictions + self._arg2.restrictions + def _repr_helper(self): + return "(%r) * (%r)" % (self._arg1, self._arg2) + @property def connection(self): return self._arg1.connection @@ -352,6 +350,10 @@ def _restrict(self, restriction): else: return super()._restrict(restriction) + def _repr_helper(self): + # TODO: create better repr + return "project(%r, %r)" % (self._arg, self._attributes) + class Subquery(RelationalOperand): """ @@ -379,3 +381,6 @@ def from_clause(self): @property def heading(self): return self._arg.heading.resolve() + + def _repr_helper(self): + return "%r" % self._arg diff --git a/tests/test_fetch.py b/tests/test_fetch.py index fc50becff..8706c47a0 100644 --- a/tests/test_fetch.py +++ b/tests/test_fetch.py @@ -125,3 +125,11 @@ def test_copy(self): f2 = f.order_by('name') assert_true(f.behavior['order_by'] is None and len(f2.behavior['order_by']) == 1, 'Object was not copied') + def test_repr(self): + """Test string representation of fetch, returning table preview""" + repr = self.subject.fetch.__repr__() + n = len(repr.strip().split('\n')) + limit = dj.config['display.limit'] + # 3 lines are used for headers (2) and summary statement (1) + assert_true(n - 3 <= limit) +