You are here: Home Linux transaction_test.py

transaction_test.py

by Harald Hoyer last modified Mar 25, 2008 08:16 AM
UnitTests for the Transaction Class

transaction_test.py — Python Source, 13 kB (13838 bytes)

File contents

# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place - Suite 330,
# Boston, MA 02111-1307, USA.

"""\
 UnitTests for the Transaction class

 See also: http://www.harald-hoyer.de/linux/pythontransactionclass
    
 Copyright (C) 2008 Harald Hoyer <harald@redhat.com>
 Copyright (C) 2008 Red Hat, Inc.
"""

import unittest
import sys
import copy

from transaction import Transaction


class TransactionOld1(object):
    """\
    Original cookbook Transaction class, kept as a baseline for the tests.

    Taken from
    http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/284677

    ``commit`` pushes a shallow copy of the instance ``__dict__`` onto
    ``self.log``; ``rollback`` pops the newest copy and merges it back.
    """

    def __init__(self):
        # Stack of saved attribute dictionaries, newest entry last.
        self.log = []

    def commit(self, **kwargs):
        """Save a shallow snapshot of all attributes (kwargs are ignored)."""
        snapshot = dict(self.__dict__)
        self.log.append(snapshot)

    def rollback(self, **kwargs):
        """Merge the newest snapshot back; do nothing when none is left."""
        try:
            snapshot = self.log.pop()
        except IndexError:
            return
        # update() only overwrites the saved keys, so attributes that were
        # added after the commit stay on the object.
        self.__dict__.update(snapshot)

    def __repr__(self):
        return "'self.__dict__ = %s'" % (self.__dict__,)
    
class TransactionOld2(TransactionOld1):
    """TransactionOld1 variant whose commit stores a deep copy of the state."""

    def commit(self, **kwargs):
        """Push a deep copy of the attribute dictionary (kwargs are ignored)."""
        snapshot = copy.deepcopy(self.__dict__)
        self.log.append(snapshot)

class TransactionOld3(TransactionOld2):
    """TransactionOld2 variant whose rollback replaces the whole state."""

    def rollback(self, **kwargs):
        """Replace all attributes with the newest snapshot, if there is one."""
        try:
            snapshot = self.log.pop()
        except IndexError:
            # No snapshot left: leave the object untouched.
            return
        attrs = self.__dict__
        # Clearing first removes attributes created after the commit.
        attrs.clear()
        attrs.update(snapshot)

class TransactionNew1(object):
    """\
    Transaction class with deep commit and rollback.

    Snapshots are deep copies of the instance ``__dict__`` kept in a list
    stored under the ``"log"`` key of that same dictionary.  Attributes
    that are instances of the same class are treated as children and are
    committed / rolled back along with their parent unless ``deep=False``
    is passed.
    """

    def _docommit(self):
        """Push a deep copy of the current attributes onto the log."""
        attrs = self.__dict__
        # The log is created lazily on the first commit.
        history = attrs.setdefault("log", [])
        history.append(copy.deepcopy(attrs))

    def _dorollback(self):
        """Replace all attributes with the newest snapshot, if there is one."""
        attrs = self.__dict__
        if "log" not in attrs:
            return
        try:
            snapshot = attrs["log"].pop()
        except IndexError:
            # The log exists but is empty: nothing to restore.
            return
        attrs.clear()
        attrs.update(snapshot)

    def commit(self, **kwargs):
        """Commit this object first, then its children unless deep=False."""
        self._docommit()
        if not kwargs.get("deep", True):
            return
        for member in self.__dict__.values():
            if isinstance(member, self.__class__):
                member.commit()

    def rollback(self, **kwargs):
        """Roll back the children (unless deep=False), then this object."""
        go_deep = kwargs.get("deep", True)
        if go_deep:
            for member in self.__dict__.values():
                if isinstance(member, self.__class__):
                    member.rollback()
        self._dorollback()

    def __repr__(self):
        return "'self.__dict__ = %s'" % (self.__dict__,)

class TransactionNew2(TransactionNew1):
    """\
    TransactionNew1 variant that keeps the history as a chain of snapshots.

    Instead of a list, ``log`` holds the most recent snapshot dictionary,
    whose own ``"log"`` entry (when present) is the snapshot before it.
    """

    def _docommit(self):
        """Snapshot the current attributes and chain the previous snapshot."""
        # Detach the history first so it is not deep-copied with the state.
        oldstate = self.__dict__.pop("log", None)
        state = copy.deepcopy(self.__dict__)
        # Compare against None: a snapshot taken while the object had no
        # attributes is an empty (falsy) dict, but it is still a commit
        # that a later rollback must be able to reach.
        if oldstate is not None:
            state["log"] = oldstate
        self.__dict__["log"] = state

    def _dorollback(self):
        """Replace all attributes with the newest snapshot, if there is one."""
        if "log" not in self.__dict__:
            return
        # The snapshot carries the older history in its own "log" entry, so
        # installing it also steps the chain back by one commit.
        state = self.__dict__["log"]
        self.__dict__.clear()
        self.__dict__.update(state)

    def __repr__(self):
        return "'self.__dict__ = %s'" % self.__dict__

class TransactionNew3(TransactionNew2):
    """\
    TransactionNew2 variant that guards deep commit/rollback against cycles.

    The ids of the objects already handled are collected in a set that is
    handed down to the children through a private keyword argument.
    """

    def _checksetseen(self, seen):
        """Return True if self is already in seen, otherwise record it."""
        me = id(self)
        if me not in seen:
            seen.add(me)
            return False
        sys.stderr.write("Recursion detected... ")
        return True

    def commit(self, **kwargs): # pylint: disable-msg=W0613
        """Commit self, then (unless deep=False) every child not yet seen."""
        visited = kwargs.get("_commit_seen", set())
        already_done = self._checksetseen(visited)
        if already_done:
            return
        # Parent first, children afterwards.
        self._docommit()
        if not kwargs.get("deep", True):
            return
        for member in self.__dict__.values():
            if isinstance(member, self.__class__):
                member.commit(_commit_seen=visited)

    def rollback(self, **kwargs):
        """Roll back every child not yet seen (unless deep=False), then self."""
        visited = kwargs.get("_rollback_seen", set())
        already_done = self._checksetseen(visited)
        if already_done:
            return
        # Children first, parent afterwards.
        go_deep = kwargs.get("deep", True)
        if go_deep:
            for member in self.__dict__.values():
                if isinstance(member, self.__class__):
                    member.rollback(_rollback_seen=visited)
        self._dorollback()

class TransactionImproved(Transaction):
    """Imported Transaction class with a repr that dumps the attributes."""

    def __repr__(self):
        attrs = self.__dict__
        return "'self.__dict__ = %s'" % (attrs,)

    
class TestTransaction(unittest.TestCase):
    """\
    Behavioural tests for a Transaction-like class.

    The class under test is looked up through the module-level name
    ``TestClass``, which has to be bound before the tests are run.
    """

    # Set to True to run test10 (cyclic object graph); its commit would
    # raise a recursion maximum exception on classes without a guard.
    doRecursion = False

    def test01(self):
        "simple rollback"
        a = TestClass()
        a.test = "correct"
        a.commit()
        a.test = "roll me back"
        a.rollback()
        self.assertEqual(a.test, "correct")

    def test02(self):
        """double rollback

        demonstrates how you can commit / rollback several times
        """
        a = TestClass()
        a.test = "correct"
        a.commit()
        a.test = "roll me back second"
        a.commit()
        a.test = "roll me back first"
        a.rollback()
        a.rollback()
        self.assertEqual(a.test, "correct")

    def test03(self):
        """\
        test list rollback

        showing the side effects of a non deep rollback
        """
        a = TestClass()
        a.ls = [ 0, 1, 2 ]
        a.commit(deep = False)
        # Mutate the list in place; the rollback has to undo this as well.
        a.ls.append(3)
        a.rollback(deep = False)
        self.assertEqual(a.ls, [0, 1, 2])

    def test031(self):
        """\
        non deep rollback

        showing the side effects of a non deep rollback
        """
        a = TestClass()
        a.ckls = TestClass()
        b = a.ckls
        a.ckls.newvar = "correct"
        a.commit(deep = False)
        a.ckls.newvar = "roll me back"
        a.rollback(deep = False)
        # The child was not part of the transaction, so it keeps its change.
        self.assertEqual(b.newvar, "roll me back")
        self.assertEqual(a.ckls.newvar, "roll me back")

    def test04(self):
        """commit not deep, rollback deep

        showing the side effects of a non deep commit
        """
        a = TestClass()
        a.ckls = TestClass()
        a.ckls.newvar = TestClass()
        a.ckls.newvar.text = "correct"
        b = a.ckls
        a.commit(deep = False)
        a.ckls.newvar.text = "roll me back"
        a.rollback(deep = True)
        # The children were never committed, so the deep rollback has
        # nothing to restore for them.
        self.assertEqual(b.newvar.text, "roll me back")
        self.assertEqual(a.ckls.newvar.text, "roll me back")

    def test041(self):
        """check for leftover attributes
        """
        a = TestClass()
        a.newvar = "correct"
        a.commit()
        a.shouldnotbethere = True
        a.rollback()
        # Attributes created after the commit must be gone again.
        self.assertFalse(hasattr(a, "shouldnotbethere"), a)

    def test05(self):
        """commit and rollback deep

        no more side effects
        """
        a = TestClass()
        a.ckls = TestClass()
        a.ckls.newvar = "correct"
        b = a.ckls
        a.commit()
        a.ckls.newvar = "roll me back"
        a.rollback(deep = True)
        self.assertEqual(a.ckls.newvar, "correct")
        self.assertEqual(b.newvar, "correct")
        # The child must still be the very same object, not a copy.
        self.assertEqual(id(a.ckls), id(b))

    def test06(self):
        """commit only a sub object

        though we committed only an attribute, the deep rollback
        will roll it back.
        """
        a = TestClass()
        a.ckls = TestClass()
        a.newvar = "correct"
        a.ckls.newvar = "correct"
        a.ckls.commit()
        a.newvar = "will not be rolled back"
        a.ckls.newvar = "roll me back"
        a.rollback()
        self.assertEqual(a.newvar, "will not be rolled back")
        self.assertEqual(a.ckls.newvar, "correct")

    def test07(self):
        """commit only a sub object, rollback with deep=false

        we committed only an attribute and the non deep rollback
        will not roll it back.
        """
        a = TestClass()
        a.ckls = TestClass()
        a.newvar = "correct"
        a.ckls.newvar = "correct"
        a.ckls.commit()
        a.newvar = "will not be rolled back"
        a.ckls.newvar = "will not be rolled back"
        a.rollback(deep = False)
        self.assertEqual(a.newvar, "will not be rolled back")
        self.assertEqual(a.ckls.newvar, "will not be rolled back")


    def test10(self):
        """check for the commit/rollback recursion
        """
        if not TestTransaction.doRecursion:
            sys.stderr.write("skipped .. ")
            return

        # Build a chain of ten nested objects and close it into a cycle.
        a = TestClass()
        b = a
        for i in range(10):
            b.newvar = TestClass()
            b.test = "test" + str(i)
            b = b.newvar
        b.newvar = a
        # would raise a recursion maximum exception
        a.commit(deep = True)


    def test11(self):
        """check for swapping Transaction objects
        """
        a = TestClass()
        a.t1 = TestClass()
        a.t2 = TestClass()
        a.t1.text = "test1"
        a.t2.text = "test2"
        a.commit()
        # Swap the two children, then expect the rollback to undo the swap.
        b = a.t1
        a.t1 = a.t2
        a.t2 = b
        a.rollback()
        self.assertEqual(a.t1.text, "test1")
        self.assertEqual(a.t2.text, "test2")

    def count_dict(self, seen, what):
        """\
        Count the references reached by walking ``what.__dict__`` recursively.

        ``seen`` is a set that collects the ids of the visited objects; an
        object that is already in it, or that has no ``__dict__``, counts as
        one place and is not descended into.  Returns the number of places.
        """
        if id(what) in seen:
            return 1
        seen.add(id(what))
        if not hasattr(what, "__dict__"):
            return 1
        i = 1
        for val in what.__dict__.values():
            i = i + self.count_dict(seen, val)
        return i

    def test91(self):
        """check for the stack length (deep=False)
        """
        a = TestClass()
        b = a
        for i in range(10):
            b.newvar = TestClass()
            b.test = "test"  + str(i)
            b = b.newvar
        a.commit(deep = False)
        a.commit(deep = False)
        a.commit(deep = False)

        # Informational only: report the size of the object graph.
        seen = set()
        count = self.count_dict(seen, a)
        sys.stderr.write("%d objects in %d places .. " % (len(seen), count))

    def test92(self):
        """check for the stack length (deep=True)
        """
        a = TestClass()
        b = a
        for i in range(10):
            b.newvar = TestClass()
            b.test = "test" + str(i)
            b.test2 = "test2" + str(i)
            b = b.newvar
            # One deep commit per nesting level.
            a.commit(deep = True)

        a.commit()
        a.commit()
        a.commit()

        # Informational only: report the size of the object graph.
        seen = set()
        count = self.count_dict(seen, a)
        sys.stderr.write("%d objects in %d places .. " % (len(seen), count))

        
def suite():
    """Return a TestSuite with every ``test*`` method of TestTransaction."""
    # TestLoader selects methods by the "test" prefix by default, which is
    # what the deprecated unittest.makeSuite(..., 'test') call did.
    return unittest.TestLoader().loadTestsFromTestCase(TestTransaction)

if __name__ == "__main__":
    # Each entry: (banner title, class under test, run the cyclic test10?).
    # Only the final implementation is enabled; uncomment an entry to run
    # the suite against one of the historical variants as well.
    variants = [
        # ("Old Transaction Class (original)",
        #  TransactionOld1, False),
        # ("Old Transaction Class (copy.deepcopy)",
        #  TransactionOld2, False),
        # ("Old Transaction Class (copy.deepcopy + __dict__.clear)",
        #  TransactionOld3, False),
        # ("New Transaction Class (deep commit)",
        #  TransactionNew1, False),
        # ("New Transaction Class (deep commit + stack improvement)",
        #  TransactionNew2, False),
        # ("New Transaction Class (deep commit + stack improvement"
        #  " + recursion check)",
        #  TransactionNew3, True),
        ("New Transaction Class (final)",
         TransactionImproved, True),
    ]
    banner = "**********************************************************************"

    all_ok = True
    for title, klass, do_recursion in variants:
        sys.stderr.write("%s\n%s\n%s\n\n" % (banner, title, banner))
        # The test methods instantiate whatever this module global names.
        TestClass = klass
        TestTransaction.doRecursion = do_recursion
        testrunner = unittest.TextTestRunner(verbosity=2)
        result = testrunner.run(suite())
        all_ok = all_ok and result.wasSuccessful()

    # Exit status 0 on success, 1 if any enabled variant failed.
    sys.exit(not all_ok)