So here is FetchAccumulator. It sits at the boundary just above the database calls and returns tidy, regular data to its callers.
class FetchAccumulator(object):
    """Run one SQL statement on successive cursors and pool what they return.

    An instance is handed a cursor once per call to ``fetch``; everything it
    keeps is appended to ``self.results``, and iterating the instance iterates
    that list.

    Attributes are plain and public:
      results   -- list of everything collected so far
      sql, args -- passed unchanged to ``cursor.execute``
      fetch_per -- when equal to 1, ``fetchone()`` is used for each cursor
      limit     -- when positive, the number of rows still wanted; it is
                   decremented by the size of every batch that is kept
    """

    def __init__(self, sql, args=None, fetch_per=-1, limit=-1):
        """Store the query and the fetch settings; no database work happens here."""
        self.results = []
        self.sql = sql
        self.args = args
        self.fetch_per = fetch_per
        self.limit = limit

    def fetch(self, cursor):
        """Execute the stored query on ``cursor`` and keep its results.

        The fetch call used depends on the settings: ``fetchone()`` when
        ``fetch_per == 1``, otherwise ``fetchmany(limit)`` when ``limit`` is
        positive, otherwise ``fetchall()``.

        A result that is missing, empty, or made up only of falsy entries is
        discarded and leaves ``results`` and ``limit`` untouched.

        Raises:
            DoneApply: when, after keeping a batch, ``limit`` is exactly zero.
        """
        cursor.execute(self.sql, self.args)
        if self.fetch_per == 1:
            results = cursor.fetchone()
            if results is None:
                # DB-API cursors report "no row" as None; calling len() on it
                # would raise TypeError, so treat it like any empty result.
                return
            # The single-row path expects a row of at most one column,
            # because the row itself is what gets spliced into self.results.
            assert len(results) <= 1, results
        elif self.limit > 0:
            results = cursor.fetchmany(self.limit)
            assert len(results) <= self.limit, (len(results), self.limit)
        else:
            results = cursor.fetchall()

        # any() is False for an empty result and for one holding only falsy
        # entries. The earlier filter(None, ...) test is always truthy on
        # Python 3 (filter returns a lazy object), so it never skipped anything.
        if not any(results):
            return

        self.results.extend(results)
        self.limit -= len(results)
        if not self.limit:  # we fetched our limit
            raise DoneApply()

    def __iter__(self):
        """Iterate over everything collected so far."""
        return iter(self.results)
This makes the other functions much, much simpler. Here are four database query functions that use FetchAccumulator. Seventy lines are now twenty.
class ShardCursor(cursor.BaseCursor):
    """Cursor whose query methods run on the shards from ``valid_shards(self._shard)``.

    Every public method builds a ``FetchAccumulator`` configured for that
    kind of query, drives it with ``apply_all``, and returns the accumulator
    itself, which callers can iterate.
    """

    def _collect_from_shards(self, collector):
        """Pass ``collector.fetch`` to ``apply_all`` over the valid shards; return ``collector``."""
        apply_all(valid_shards(self._shard), collector.fetch)
        return collector

    def selectOne(self, sql, args=None):
        """Query with an accumulator set to ``fetch_per=1`` and ``limit=1``."""
        return self._collect_from_shards(
            FetchAccumulator(sql, args, fetch_per=1, limit=1))

    def selectMany(self, sql, args=None, size=-1):
        """Query with an accumulator whose ``limit`` is ``size``."""
        return self._collect_from_shards(
            FetchAccumulator(sql, args, limit=size))

    def selectAll(self, sql, args=None):
        """Query with an accumulator left at its default settings."""
        return self._collect_from_shards(FetchAccumulator(sql, args))

    def countOne(self, sql, args=None):
        """Query with an accumulator set to ``fetch_per=1`` and no limit."""
        return self._collect_from_shards(
            FetchAccumulator(sql, args, fetch_per=1))
Of course these functions now have their own code smell -- they vary only in their accumulator, so they could be collapsed into a single function. That would require refactoring all the calling code, which is a bigger project than I wanted to take on.
The apply_all function grew a proper exception to allow callers to bail out of the loop early.
class DoneApply(Exception):
    """Raised by a callback given to apply_all to stop the shard loop early."""
def apply_all(shards, func):
    """Call ``func`` with a fresh cursor for each shard, in order.

    For every shard a connection is obtained with ``establishConnection()``,
    a cursor from it is passed to ``func``, and the connection is closed
    afterwards whether or not ``func`` raised.

    If ``func`` raises ``DoneApply`` the remaining shards are skipped and the
    function returns normally; any other exception propagates to the caller.
    """
    for shard in shards:
        connection = shard.establishConnection()
        try:
            func(connection.cursor())
        except DoneApply:
            # The callback has everything it needs; nothing follows the loop,
            # so returning here is the same as breaking out of it.
            return
        finally:
            connection.close()
I'll omit the unit tests. The original project had no unit tests for this code so I had to write some to make sure my refactoring wasn't breaking anything.
No comments:
Post a Comment