From 6113828efebfe8d0c0779e583133bf8b153d3dbf Mon Sep 17 00:00:00 2001 From: ram Date: Sun, 10 Jul 2022 21:13:50 -0400 Subject: [PATCH 1/3] made changes to make it compatible with python 3 --- .vscode/settings.json | 3 + build.py | 51 +- build.py.bak | 400 ++++ docs/sphinx/source/conf.py | 3 +- docs/sphinx/source/conf.py.bak | 231 ++ instrunctions.txt | 10 + main.py | 32 + main.py.bak | 1 + requirements.txt | 21 + samples/petclinic/cherrypy/controller.py | 2 +- samples/petclinic/cherrypy/controller.py.bak | 314 +++ samples/petclinic/cherrypy/view.py | 2 +- samples/petclinic/cherrypy/view.py.bak | 476 ++++ samples/petclinic/configure.py | 15 +- samples/petclinic/configure.py.bak | 91 + samples/springwiki/model.py | 3 +- samples/springwiki/model.py.bak | 641 ++++++ samples/springwiki/view.py | 2 +- samples/springwiki/view.py.bak | 143 ++ src/plugins/gen-cherrypy-app/__init__.py | 5 +- src/plugins/gen-cherrypy-app/__init__.py.bak | 48 + src/setup-template.py | 3 +- src/setup-template.py.bak | 65 + src/springpython/aop/__init__.py | 4 +- src/springpython/aop/__init__.py.bak | 250 ++ src/springpython/config/__init__.py | 9 +- src/springpython/config/__init__.py.bak | 19 + src/springpython/config/_config_base.py | 3 +- src/springpython/config/_config_base.py.bak | 291 +++ src/springpython/config/_python_config.py | 25 +- src/springpython/config/_python_config.py.bak | 148 ++ src/springpython/config/_xml_config.py | 3 +- src/springpython/config/_xml_config.py.bak | 607 +++++ src/springpython/config/_yaml_config.py | 5 +- src/springpython/config/_yaml_config.py.bak | 418 ++++ src/springpython/config/decorator.py | 193 +- src/springpython/config/decorator.py.bak | 248 ++ src/springpython/container/__init__.py | 4 +- src/springpython/container/__init__.py.bak | 143 ++ src/springpython/context/__init__.py | 117 +- src/springpython/context/__init__.py.bak | 151 ++ src/springpython/database/core.py | 14 +- src/springpython/database/core.py.bak | 233 ++ src/springpython/database/factory.py | 18 +- src/springpython/database/factory.py.bak | 164 ++ src/springpython/database/transaction.py | 12 +- src/springpython/database/transaction.py.bak | 310 +++ src/springpython/factory/__init__.py | 4 +- src/springpython/factory/__init__.py.bak | 65 + src/springpython/jms/factory.py | 300 ++- src/springpython/jms/factory.py.bak | 817 +++++++ src/springpython/jms/listener.py | 6 +- src/springpython/jms/listener.py.bak | 117 + src/springpython/remoting/hessian/__init__.py | 3 +- .../remoting/hessian/__init__.py.bak | 42 + .../remoting/hessian/hessianlib.py | 14 +- .../remoting/hessian/hessianlib.py.bak | 481 ++++ .../remoting/pyro/Pyro4DaemonHolder.py | 5 +- .../remoting/pyro/Pyro4DaemonHolder.py.bak | 132 ++ .../remoting/pyro/PyroDaemonHolder.py | 3 +- .../remoting/pyro/PyroDaemonHolder.py.bak | 121 + src/springpython/remoting/pyro/__init__.py | 5 +- .../remoting/pyro/__init__.py.bak | 173 ++ src/springpython/remoting/xmlrpc.py | 2 +- src/springpython/remoting/xmlrpc.py.bak | 215 ++ src/springpython/security/providers/Ldap.py | 5 +- .../security/providers/Ldap.py.bak | 22 + .../security/providers/_Ldap_cpython.py | 2 +- .../security/providers/_Ldap_cpython.py.bak | 208 ++ .../security/providers/_Ldap_jython.py | 11 +- .../security/providers/_Ldap_jython.py.bak | 126 ++ .../security/providers/__init__.py | 4 +- .../security/providers/__init__.py.bak | 127 ++ src/springpython/security/providers/dao.py | 8 +- .../security/providers/dao.py.bak | 199 ++ src/springpython/security/web.py | 12 +- src/springpython/security/web.py.bak | 419 ++++ src/springpython/util.py | 4 +- src/springpython/util.py.bak | 51 + test/springpythontest/allTests.py | 5 +- test/springpythontest/allTests.py.bak | 39 + test/springpythontest/contextTestCases.py | 8 +- test/springpythontest/contextTestCases.py.bak | 2004 +++++++++++++++++ .../springpythontest/databaseCoreTestCases.py | 67 +- .../databaseCoreTestCases.py.bak | 920 ++++++++ .../databaseTransactionTestCases.py | 25 +- .../databaseTransactionTestCases.py.bak | 604 +++++ .../jms_websphere_mq_test_cases.py | 22 +- .../jms_websphere_mq_test_cases.py.bak | 1696 ++++++++++++++ test/springpythontest/remoting_xmlrpc.py | 2 +- test/springpythontest/remoting_xmlrpc.py.bak | 291 +++ test/springpythontest/securityWebTestCases.py | 8 +- .../securityWebTestCases.py.bak | 233 ++ .../support/testSupportClasses.py | 10 +- .../support/testSupportClasses.py.bak | 537 +++++ test/standalone/pyro_thread_test.py | 5 +- test/standalone/pyro_thread_test.py.bak | 59 + test/standalone/test.py | 9 +- test/standalone/test.py.bak | 35 + test/standalone/xsd_test_cases.py | 2 +- test/standalone/xsd_test_cases.py.bak | 85 + 101 files changed, 15911 insertions(+), 409 deletions(-) create mode 100644 .vscode/settings.json create mode 100755 build.py.bak create mode 100644 docs/sphinx/source/conf.py.bak create mode 100644 instrunctions.txt create mode 100644 main.py create mode 100644 main.py.bak create mode 100644 requirements.txt create mode 100644 samples/petclinic/cherrypy/controller.py.bak create mode 100644 samples/petclinic/cherrypy/view.py.bak create mode 100644 samples/petclinic/configure.py.bak create mode 100644 samples/springwiki/model.py.bak create mode 100644 samples/springwiki/view.py.bak create mode 100644 src/plugins/gen-cherrypy-app/__init__.py.bak create mode 100644 src/setup-template.py.bak create mode 100644 src/springpython/aop/__init__.py.bak create mode 100644 src/springpython/config/__init__.py.bak create mode 100644 src/springpython/config/_config_base.py.bak create mode 100644 src/springpython/config/_python_config.py.bak create mode 100644 src/springpython/config/_xml_config.py.bak create mode 100644 src/springpython/config/_yaml_config.py.bak create mode 100644 src/springpython/config/decorator.py.bak create mode 100644 src/springpython/container/__init__.py.bak create mode 100644 src/springpython/context/__init__.py.bak create mode 100644 src/springpython/database/core.py.bak create mode 100644 src/springpython/database/factory.py.bak create mode 100644 src/springpython/database/transaction.py.bak create mode 100644 src/springpython/factory/__init__.py.bak create mode 100644 src/springpython/jms/factory.py.bak create mode 100644 src/springpython/jms/listener.py.bak create mode 100644 src/springpython/remoting/hessian/__init__.py.bak create mode 100644 src/springpython/remoting/hessian/hessianlib.py.bak create mode 100644 src/springpython/remoting/pyro/Pyro4DaemonHolder.py.bak create mode 100644 src/springpython/remoting/pyro/PyroDaemonHolder.py.bak create mode 100644 src/springpython/remoting/pyro/__init__.py.bak create mode 100644 src/springpython/remoting/xmlrpc.py.bak create mode 100644 src/springpython/security/providers/Ldap.py.bak create mode 100644 src/springpython/security/providers/_Ldap_cpython.py.bak create mode 100644 src/springpython/security/providers/_Ldap_jython.py.bak create mode 100644 src/springpython/security/providers/__init__.py.bak create mode 100644 src/springpython/security/providers/dao.py.bak create mode 100644 src/springpython/security/web.py.bak create mode 100644 src/springpython/util.py.bak create mode 100644 test/springpythontest/allTests.py.bak create mode 100644 test/springpythontest/contextTestCases.py.bak create mode 100644 test/springpythontest/databaseCoreTestCases.py.bak create mode 100644 test/springpythontest/databaseTransactionTestCases.py.bak create mode 100644 test/springpythontest/jms_websphere_mq_test_cases.py.bak create mode 100644 test/springpythontest/remoting_xmlrpc.py.bak create mode 100644 test/springpythontest/securityWebTestCases.py.bak create mode 100644 test/springpythontest/support/testSupportClasses.py.bak create mode 100644 test/standalone/pyro_thread_test.py.bak create mode 100644 test/standalone/test.py.bak create mode 100644 test/standalone/xsd_test_cases.py.bak diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b7368ca --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.formatting.provider": "black" +} diff --git a/build.py b/build.py index 5a80c7b..bbc1184 100755 --- a/build.py +++ b/build.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function from datetime import datetime from glob import glob import logging @@ -49,11 +50,11 @@ def load_properties(prop_dict, prop_file): "This function loads standard, java-style properties files into a dictionary." if os.path.exists(prop_file): - print "Reading property file " + prop_file + print("Reading property file " + prop_file) [prop_dict.update({prop.split("=")[0].strip(): prop.split("=")[1].strip()}) for prop in open(prop_file).readlines() if not (prop.startswith("#") or prop.strip() == "")] else: - print "Unable to read property file " + prop_file + print("Unable to read property file " + prop_file) # Override defaults with a properties file load_properties(p, "springpython.properties") @@ -65,22 +66,22 @@ def load_properties(prop_dict, prop_file): def usage(): """This function is used to print out help either by request, or if an invalid option is used.""" - print - print "Usage: python build.py [command]" - print - print "\t--help\t\t\tprint this help message" - print "\t--clean\t\t\tclean out this build by deleting the %s directory" % p["targetDir"] - print "\t--test\t\t\trun the test suite, leaving all artifacts in %s" % p["testDir"] - print "\t--suite [suite]\t\trun a specific test suite, leaving all artifacts in %s" % p["testDir"] - print "\t--coverage\t\trun the test suite with coverage analysis, leaving all artifacts in %s" % p["testDir"] - print "\t--debug-level [info|debug]\n\t\t\t\tthreshold of logging message when running tests or coverage analysis" - print "\t--package\t\tpackage everything up into a tarball for release to sourceforge in %s" % p["packageDir"] - print "\t--build-stamp [tag]\tfor --package, this specifies a special tag, generating version tag '%s.. springpython.properties can override with build.stamp'" % p["version"] - print "\t\t\t\tIf this option isn't used, default will be tag will be '%s.'" % p["version"] - print "\t--register\t\tregister this release with http://pypi.python.org/pypi" - print "\t--docs-sphinx\t\tgenerate Sphinx documentation" - print "\t--pydoc\t\t\tgenerate pydoc information" - print + print() + print("Usage: python build.py [command]") + print() + print("\t--help\t\t\tprint this help message") + print("\t--clean\t\t\tclean out this build by deleting the %s directory" % p["targetDir"]) + print("\t--test\t\t\trun the test suite, leaving all artifacts in %s" % p["testDir"]) + print("\t--suite [suite]\t\trun a specific test suite, leaving all artifacts in %s" % p["testDir"]) + print("\t--coverage\t\trun the test suite with coverage analysis, leaving all artifacts in %s" % p["testDir"]) + print("\t--debug-level [info|debug]\n\t\t\t\tthreshold of logging message when running tests or coverage analysis") + print("\t--package\t\tpackage everything up into a tarball for release to sourceforge in %s" % p["packageDir"]) + print("\t--build-stamp [tag]\tfor --package, this specifies a special tag, generating version tag '%s.. springpython.properties can override with build.stamp'" % p["version"]) + print("\t\t\t\tIf this option isn't used, default will be tag will be '%s.'" % p["version"]) + print("\t--register\t\tregister this release with http://pypi.python.org/pypi") + print("\t--docs-sphinx\t\tgenerate Sphinx documentation") + print("\t--pydoc\t\t\tgenerate pydoc information") + print() try: optlist, args = getopt.getopt(sys.argv[1:], @@ -89,7 +90,7 @@ def usage(): "register", "docs-sphinx", "pydoc"]) except getopt.GetoptError: # print help information and exit: - print "Invalid command found in %s" % sys.argv + print("Invalid command found in %s" % sys.argv) usage() sys.exit(2) @@ -105,7 +106,7 @@ def usage(): ############################################################################ def clean(dir): - print "Removing '%s' directory" % dir + print("Removing '%s' directory" % dir) if os.path.exists(".coverage"): os.remove(".coverage") if os.path.exists(dir): @@ -205,7 +206,7 @@ def package(dir, version, s3bucket, src_filename, sample_filename): os.makedirs(dir) _substitute("src/plugins/coily-template", "src/plugins/coily", [("version", version)]) - os.chmod("src/plugins/coily", 0755) + os.chmod("src/plugins/coily", 0o755) build("src", version, s3bucket, src_filename) build("samples", version, s3bucket, sample_filename) #os.remove("src/plugins/coily") @@ -232,7 +233,7 @@ def register(): def copy(src, dest, patterns): if not os.path.exists(dest): - print "+++ Creating " + dest + print("+++ Creating " + dest) os.makedirs(dest) [shutil.copy(file, dest) for pattern in patterns for file in glob(src + pattern)] @@ -317,7 +318,7 @@ def create_pydocs(): for file in os.listdir("."): if "springpython" not in file: continue - print "Altering appearance of %s" % file + print("Altering appearance of %s" % file) file_input = open(file).read() file_input = re.compile(top_color).sub("GREEN", file_input) file_input = re.compile(pkg_color).sub("GREEN", file_input) @@ -375,11 +376,11 @@ def create_pydocs(): clean(p["targetDir"]) if option[0] in ("--test"): - print "Running checkin tests..." + print("Running checkin tests...") test(p["testDir"], "checkin", debug_level) if option[0] in ("--suite"): - print "Running test suite %s..." % option[1] + print("Running test suite %s..." % option[1]) test(p["testDir"], option[1], debug_level) if option[0] in ("--coverage"): diff --git a/build.py.bak b/build.py.bak new file mode 100755 index 0000000..5a80c7b --- /dev/null +++ b/build.py.bak @@ -0,0 +1,400 @@ +#!/usr/bin/python +""" + Copyright 2006-2011 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from datetime import datetime +from glob import glob +import logging +import mimetypes +import os +import pydoc +import re +import sys +import tarfile +import getopt +import shutil + +try: + import hashlib + _sha = hashlib.sha1 +except ImportError: + import sha + _sha = sha.new + +############################################################################ +# Get external properties and load into a dictionary. NOTE: These properties +# files mimic Java props files. +############################################################################ + +p = {} + +# Default settings, before reading the properties file +p["targetDir"] = "target" +p["testDir"] = "%s/test-results/xml" % p["targetDir"] +p["packageDir"] = "%s/artifacts" % p["targetDir"] + + +def load_properties(prop_dict, prop_file): + "This function loads standard, java-style properties files into a dictionary." + if os.path.exists(prop_file): + print "Reading property file " + prop_file + [prop_dict.update({prop.split("=")[0].strip(): prop.split("=")[1].strip()}) + for prop in open(prop_file).readlines() if not (prop.startswith("#") or prop.strip() == "")] + else: + print "Unable to read property file " + prop_file + +# Override defaults with a properties file +load_properties(p, "springpython.properties") + +############################################################################ +# Read the command-line, and assemble commands. Any invalid command, print +# usage info, and EXIT. +############################################################################ + +def usage(): + """This function is used to print out help either by request, or if an invalid option is used.""" + print + print "Usage: python build.py [command]" + print + print "\t--help\t\t\tprint this help message" + print "\t--clean\t\t\tclean out this build by deleting the %s directory" % p["targetDir"] + print "\t--test\t\t\trun the test suite, leaving all artifacts in %s" % p["testDir"] + print "\t--suite [suite]\t\trun a specific test suite, leaving all artifacts in %s" % p["testDir"] + print "\t--coverage\t\trun the test suite with coverage analysis, leaving all artifacts in %s" % p["testDir"] + print "\t--debug-level [info|debug]\n\t\t\t\tthreshold of logging message when running tests or coverage analysis" + print "\t--package\t\tpackage everything up into a tarball for release to sourceforge in %s" % p["packageDir"] + print "\t--build-stamp [tag]\tfor --package, this specifies a special tag, generating version tag '%s.. springpython.properties can override with build.stamp'" % p["version"] + print "\t\t\t\tIf this option isn't used, default will be tag will be '%s.'" % p["version"] + print "\t--register\t\tregister this release with http://pypi.python.org/pypi" + print "\t--docs-sphinx\t\tgenerate Sphinx documentation" + print "\t--pydoc\t\t\tgenerate pydoc information" + print + +try: + optlist, args = getopt.getopt(sys.argv[1:], + "hct", + ["help", "clean", "test", "suite=", "debug-level=", "coverage", "package", "build-stamp=", \ + "register", "docs-sphinx", "pydoc"]) +except getopt.GetoptError: + # print help information and exit: + print "Invalid command found in %s" % sys.argv + usage() + sys.exit(2) + +############################################################################ +# Pre-generate needed values +############################################################################ + +# Default build stamp value +build_stamp = "BUILD-%s" % datetime.now().strftime("%Y%m%d%H%M%S") + +############################################################################ +# Definition of operations this script can do. +############################################################################ + +def clean(dir): + print "Removing '%s' directory" % dir + if os.path.exists(".coverage"): + os.remove(".coverage") + if os.path.exists(dir): + shutil.rmtree(dir, True) + for root, dirs, files in os.walk(".", topdown=False): + for name in files: + if name.endswith(".pyc") or name.endswith(".class"): + os.remove(os.path.join(root, name)) + +def test(dir, test_suite, debug_level): + """ + Run nose programmatically, so that it uses the same python version as this script uses + + Nose expects to receive a sys.argv, of which the first arg is the script path (usually nosetests). Since this isn't + being run that way, a filler entry was created to satisfy the library's needs. + """ + if not os.path.exists(dir): + os.makedirs(dir) + + try: + import java + if test_suite == "checkin": test_suite = "jython" + _run_nose(argv=["", "--where=test/springpythontest", test_suite], debug_level=debug_level) + except ImportError: + _run_nose(argv=["", "--with-nosexunit", "--source-folder=src", "--where=test/springpythontest", "--xml-report-folder=%s" % dir, test_suite], debug_level=debug_level) + +def test_coverage(dir, test_suite, debug_level): + """ + Run nose programmatically, so that it uses the same python version as this script uses + + Nose expects to receive a sys.argv, of which the first arg is the script path (usually nosetests). Since this isn't + being run that way, a filler entry was created to satisfy the library's needs. + """ + + if not os.path.exists(dir): + os.makedirs(dir) + + _run_nose(argv=["", "--with-nosexunit", "--source-folder=src", "--where=test/springpythontest", "--xml-report-folder=%s" % dir, "--with-coverage", "--cover-package=springpython", test_suite], debug_level=debug_level) + +def _run_nose(argv, debug_level): + logger = logging.getLogger("springpython") + loggingLevel = debug_level + logger.setLevel(loggingLevel) + ch = logging.StreamHandler() + ch.setLevel(loggingLevel) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + + # Running nose causes the stdout/stderr to get changed, and also it changes directories as well. + _stdout, _stderr, _curdir = sys.stdout, sys.stderr, os.getcwd() + + import nose + nose.run(argv=argv) + + # Restored these streams + sys.stdout, sys.stderr = _stdout, _stderr + os.chdir(_curdir) + +def _substitute(input_file, output_file, patterns_to_replace): + """Scan the input file, and do a pattern substitution, writing all results to output file.""" + input = open(input_file).read() + output = open(output_file, "w") + for pattern, replacement in patterns_to_replace: + input = re.compile(r"\$\{%s}" % pattern).sub(replacement, input) + output.write(input) + output.close() + +def build(dir, version, s3bucket, filepath): + filename = filepath.split("/")[-1] + s3key = "/".join([ p['release.type'], + p['project.key'], + filename ]) + + patterns_to_replace = [("version", version)] + + _substitute(dir + "/setup-template.py", dir + "/setup.py", patterns_to_replace) + + os.chdir(dir) + os.system("%s %s sdist" % (sys.executable, os.path.join(".", "setup.py"))) + os.chdir("..") + + dist_dir = os.path.join(os.getcwd(), dir, "dist") + + for name in os.listdir(dist_dir): + old_location = os.path.join(dist_dir,name) + new_location = "." + shutil.move(old_location, new_location) + + os.rmdir(dist_dir) + if os.path.exists(os.path.join(dir, "MANIFEST")): + os.remove(os.path.join(dir, "MANIFEST")) + +def package(dir, version, s3bucket, src_filename, sample_filename): + if not os.path.exists(dir): + os.makedirs(dir) + + _substitute("src/plugins/coily-template", "src/plugins/coily", [("version", version)]) + os.chmod("src/plugins/coily", 0755) + build("src", version, s3bucket, src_filename) + build("samples", version, s3bucket, sample_filename) + #os.remove("src/plugins/coily") + + for name in glob("*.tar.gz"): + old_location = os.path.join(".", name) + shutil.move(old_location, dir) + + curdir = os.getcwd() + os.chdir("src/plugins") + for item in os.listdir("."): + if item in ["coily-template", ".svn"]: continue + t = tarfile.open("../../%s/springpython-plugin-%s.%s.tar.gz" % (dir, item, version), "w:gz") + for path, dirs, files in os.walk(item): + if ".svn" not in path: # Don't want to include version information + t.add(path, recursive=False) + [t.add(path + "/" + file, recursive=False) for file in files] + t.close() + os.chdir(curdir) + +def register(): + os.system("cd src ; %s setup.py register sdist upload" % sys.executable) + os.system("cd samples ; %s setup.py register sdist upload" % sys.executable) + +def copy(src, dest, patterns): + if not os.path.exists(dest): + print "+++ Creating " + dest + os.makedirs(dest) + + [shutil.copy(file, dest) for pattern in patterns for file in glob(src + pattern)] + +def setup(root, stylesheets=True): + copy( + p["doc.ref.dir"]+"/src/images/", + root + "/images/", + ["*.gif", "*.svg", "*.jpg", "*.png"]) + + docbook_images_dir = p["targetDir"] + "/" + p["dist.ref.dir"] + "/images" + if not os.path.exists(docbook_images_dir): + shutil.copytree(p["doc.ref.dir"]+"/images/", docbook_images_dir) + + if stylesheets: + copy( + p["doc.ref.dir"]+"/styles/", + root, + ["*.css", "*.js"]) + +def sub_version(cur, version): + _substitute(cur + "/" + p["doc.ref.dir"] + "/src/index.xml", cur + "/" + p["doc.ref.dir"] + "/src/mangled.xml", [("version", version)]) + +def docs_sphinx(): + cur = os.getcwd() + os.chdir("docs/sphinx") + os.system("make clean html epub man") + os.chdir(cur) + shutil.copytree("docs/sphinx/build/html", "target/docs/sphinx/html") + shutil.copytree("docs/sphinx/build/man", "target/docs/sphinx/man") + shutil.copy("docs/sphinx/build/epub/SpringPython.epub", "target/docs/") + +def create_pydocs(): + sys.path.append(os.getcwd() + "/src") + import springpython + + if not os.path.exists("target/docs/pydoc"): + os.makedirs("target/docs/pydoc") + + cur = os.getcwd() + os.chdir("target/docs/pydoc") + + pydoc.writedoc("springpython") + pydoc.writedoc("springpython.aop") + pydoc.writedoc("springpython.aop.utils") + pydoc.writedoc("springpython.config") + pydoc.writedoc("springpython.config.decorator") + pydoc.writedoc("springpython.container") + pydoc.writedoc("springpython.context") + pydoc.writedoc("springpython.context.scope") + pydoc.writedoc("springpython.database") + pydoc.writedoc("springpython.database.core") + pydoc.writedoc("springpython.database.factory") + pydoc.writedoc("springpython.database.transaction") + pydoc.writedoc("springpython.factory") + pydoc.writedoc("springpython.remoting") + pydoc.writedoc("springpython.remoting.hessian") + pydoc.writedoc("springpython.remoting.hessian.hessianlib") + pydoc.writedoc("springpython.remoting.pyro") + pydoc.writedoc("springpython.remoting.pyro.PyroDaemonHolder") + pydoc.writedoc("springpython.security") + pydoc.writedoc("springpython.security.cherrypy3") + pydoc.writedoc("springpython.security.intercept") + pydoc.writedoc("springpython.security.context") + pydoc.writedoc("springpython.security.context.SecurityContextHolder") + pydoc.writedoc("springpython.security.providers") + pydoc.writedoc("springpython.security.providers.dao") + pydoc.writedoc("springpython.security.providers.encoding") + pydoc.writedoc("springpython.security.providers.Ldap") + pydoc.writedoc("springpython.security.providers._Ldap_cpython") + pydoc.writedoc("springpython.security.providers._Ldap_jython") + pydoc.writedoc("springpython.security.userdetails") + pydoc.writedoc("springpython.security.userdetails.dao") + pydoc.writedoc("springpython.security.web") + + top_color = "#7799ee" + pkg_color = "#aa55cc" + class_color = "#ee77aa" + class_highlight = "#ffc8d8" + function_color = "#eeaa77" + data_color = "#55aa55" + + for file in os.listdir("."): + if "springpython" not in file: continue + print "Altering appearance of %s" % file + file_input = open(file).read() + file_input = re.compile(top_color).sub("GREEN", file_input) + file_input = re.compile(pkg_color).sub("GREEN", file_input) + file_input = re.compile(class_color).sub("GREEN", file_input) + file_input = re.compile(class_highlight).sub("LIGHTGREEN", file_input) + file_input = re.compile(function_color).sub("LIGHTGREEN", file_input) + file_input = re.compile(data_color).sub("LIGHTGREEN", file_input) + file_output = open(file, "w") + file_output.write(file_input) + file_output.close() + + os.chdir(cur) + + +############################################################################ +# Pre-commands. Skim the options, and pick out commands the MUST be +# run before others. +############################################################################ + +debug_levels = {"info":logging.INFO, "debug":logging.DEBUG} +debug_level = debug_levels["info"] # Default debug level is INFO + +# No matter what order the command are specified in, the build-stamp must be extracted first. +for option in optlist: + if option[0] == "--build-stamp": + build_stamp = option[1] # Override build stamp with user-supplied version + + if option[0] in ("--debug-level"): + debug_level = debug_levels[option[1]] # Override with a user-supplied debug level + +# However, a springpython.properties entry can override the command-line +if "build.stamp" in p: + build_stamp = p["build.stamp"] +complete_version = p["version"] + "." + build_stamp + +# However, a springpython.properties entry can override the command-line +if "debug.level" in p: + debug_level = debug_levels[p["debug.level"]] + +# Check for help requests, which cause all other options to be ignored. Help can offer version info, which is +# why it comes as the second check +for option in optlist: + if option[0] in ("--help", "-h"): + usage() + sys.exit(1) + +############################################################################ +# Main commands. Skim the options, and run each command as its found. +# Commands are run in the order found ON THE COMMAND LINE. +############################################################################ + +# Parse the arguments, in order +for option in optlist: + if option[0] in ("--clean", "-c"): + clean(p["targetDir"]) + + if option[0] in ("--test"): + print "Running checkin tests..." + test(p["testDir"], "checkin", debug_level) + + if option[0] in ("--suite"): + print "Running test suite %s..." % option[1] + test(p["testDir"], option[1], debug_level) + + if option[0] in ("--coverage"): + test_coverage(p["testDir"], "checkin", debug_level) + + if option[0] in ("--package"): + package(p["packageDir"], complete_version, p['s3.bucket'], "springpython", "springpython-samples") + + if option[0] in ("--register"): + register() + + if option[0] in ("--docs-sphinx"): + docs_sphinx() + + if option[0] in ("--pydoc"): + create_pydocs() + + diff --git a/docs/sphinx/source/conf.py b/docs/sphinx/source/conf.py index 348b188..caeccce 100644 --- a/docs/sphinx/source/conf.py +++ b/docs/sphinx/source/conf.py @@ -11,6 +11,7 @@ # All configuration values have a default; values that are commented out # serve to show the default. +from __future__ import print_function import sys, os # Read the properties file to harvest version info. @@ -22,7 +23,7 @@ (key, value) = line.split("=") p[key] = value[:-1] -print p +print(p) # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the diff --git a/docs/sphinx/source/conf.py.bak b/docs/sphinx/source/conf.py.bak new file mode 100644 index 0000000..348b188 --- /dev/null +++ b/docs/sphinx/source/conf.py.bak @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# +# Spring Python documentation build configuration file, created by +# sphinx-quickstart on Fri Jun 4 12:32:42 2010. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys, os + +# Read the properties file to harvest version info. +p = {} +with open("../../../springpython.properties") as f: + lines = f.readlines() + for line in lines: + if "=" in line and not line.startswith("#"): + (key, value) = line.split("=") + p[key] = value[:-1] + +print p + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +#sys.path.append(os.path.abspath('.')) + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = ['sphinx.ext.autodoc', 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'Spring Python' +copyright = u'2006-2011, Greg Turnquist, Dariusz Suchojad' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = ".".join(p["version"].split(".")[0:2]) + +# The full version, including alpha/beta/rc tags. +if "build.stamp" in p: + release = p["version"] + "." + p["build.stamp"] +else: + release = p["version"] + "." + p["release.type"].upper() + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +#language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = [] + +# The reST default role (used for this markup: `text`) to use for all documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +#modindex_common_prefix = [] + + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'nature' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = '' + +# Output file base name for HTML help builder. +htmlhelp_basename = 'SpringPythondoc' + + +# -- Options for LaTeX output -------------------------------------------------- + +# The paper size ('letter' or 'a4'). +#latex_paper_size = 'letter' + +# The font size ('10pt', '11pt' or '12pt'). +#latex_font_size = '10pt' + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ('index', 'SpringPython.tex', u'Spring Python Documentation', + u'Greg Turnquist, Dariusz Suchojad', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Additional stuff for the LaTeX preamble. +#latex_preamble = '' + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'springpython', u'Spring Python Documentation', + [u'Greg Turnquist, Dariusz Suchojad'], 1) +] diff --git a/instrunctions.txt b/instrunctions.txt new file mode 100644 index 0000000..74830db --- /dev/null +++ b/instrunctions.txt @@ -0,0 +1,10 @@ +build package: +python build.py --package + +#setup file exists in src folder + +install module: +pip install src/ + +after making changes to the source, rebuild and pip install src +in the working python env. diff --git a/main.py b/main.py new file mode 100644 index 0000000..abd9316 --- /dev/null +++ b/main.py @@ -0,0 +1,32 @@ +import pdb +from src.springpython.config import PythonConfig, Object +from src.springpython.context import ApplicationContext + + +class WikiService(object): + """ + serves wiki + """ + + def __init__(self): + self.data = "wiki service" + + def get_name(self): + return "test service" + + +class WikiProductAppConfig(PythonConfig): + def __init__(self): + super(WikiProductAppConfig, self).__init__() + + @Object + def wiki_service(self): + return WikiService() + + +if __name__ == "__main__": + + ctx = ApplicationContext(WikiProductAppConfig()) + service = ctx.get_object("wiki_service") + + assert service.get_name() == "test service" diff --git a/main.py.bak b/main.py.bak new file mode 100644 index 0000000..fc3eccd --- /dev/null +++ b/main.py.bak @@ -0,0 +1 @@ +print "python 2 works" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9a997ee --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +attrs==18.2.0 +Automat==0.7.0 +backports.functools-lru-cache==1.6.4 +caniusepython3==7.3.0 +certifi==2021.10.8 +chardet==4.0.0 +constantly==15.1.0 +distlib==0.3.4 +futures==3.3.0 +hyperlink==18.0.0 +idna==2.8 +incremental==17.5.0 +nose==1.3.7 +packaging==20.9 +PyHamcrest==1.9.0 +pyparsing==2.4.7 +requests==2.27.1 +six==1.12.0 +Twisted==18.9.0 +urllib3==1.26.10 +zope.interface==4.6.0 diff --git a/samples/petclinic/cherrypy/controller.py b/samples/petclinic/cherrypy/controller.py index 455a1a2..b70bc43 100644 --- a/samples/petclinic/cherrypy/controller.py +++ b/samples/petclinic/cherrypy/controller.py @@ -181,7 +181,7 @@ def getUsername(self, id): SELECT username FROM owners WHERE id = ? - """, (id,), types.StringType) + """, (id,), bytes) def getUsers(self): """ diff --git a/samples/petclinic/cherrypy/controller.py.bak b/samples/petclinic/cherrypy/controller.py.bak new file mode 100644 index 0000000..455a1a2 --- /dev/null +++ b/samples/petclinic/cherrypy/controller.py.bak @@ -0,0 +1,314 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import cgi +import logging +import types +from datetime import date +from springpython.database.core import DaoSupport +from springpython.database.core import DatabaseTemplate +from springpython.database.core import RowMapper +from springpython.security import UsernameNotFoundException +from springpython.security.userdetails import UserDetailsService +from model import Owner +from model import Pet +from model import PetType +from model import Specialty +from model import Vet +from model import Visit +from springpython.security.vote import AccessDecisionVoter + +class PetClinicController(DaoSupport): + """This is a database-orienetd controller. Controllers are often responsible for providing data to populate views + and also interface with other subsystems. For example, another version of the PetClinic controller could interface with a + Directory Server to fetch data like telephone numbers and email addresses. + + All of the SQL operations use binding variables ("?") to prevent SQL injection attacks. This is a highly recommended + feature in order to avoid security holes in your application.""" + + def getVets(self): + """Return a list of vets from the database.""" + return self.database_template.query(""" + SELECT + id, + first_name, + last_name + FROM vets + """, rowhandler=VetRowMapper()) + + def getOwners(self, lastName = ""): + """Return a list of owners, filtered by partial lastname.""" + return self.database_template.query(""" + SELECT + id, + first_name, + last_name, + address, + city, + telephone + FROM owners + WHERE upper(last_name) like ? + """, ("%"+lastName.upper()+"%",), OwnerRowMapper()) + + def getOwner(self, id): + """Return one owner.""" + return self.database_template.query(""" + SELECT + id, + first_name, + last_name, + address, + city, + telephone + FROM owners + WHERE id = ? + """, (id,), OwnerRowMapper())[0] + + def addOwner(self, **kwargs): + """Add an owner to the database.""" + rowsAffected = self.database_template.execute(""" + INSERT INTO owners + (first_name, last_name, address, city, telephone) + VALUES + (?, ?, ?, ?, ?) + """, (kwargs["firstName"], kwargs["lastName"], kwargs["address"], kwargs["city"], kwargs["telephone"])) + return rowsAffected + + def updateOwner(self, id, address = "", city = "", telephone = ""): + """Add an owner to the database.""" + rowsAffected = self.database_template.update(""" + UPDATE owners + SET + address = ?, + city = ?, + telephone = ? + WHERE id = ? + """, (address, city, telephone, id)) + return rowsAffected + + def getPets(self, owner): + """Return pets belonging to a particular owner.""" + return self.database_template.query(""" + SELECT + pets.id, + pets.name, + pets.birth_date, + types.name + FROM pets, owners, types + WHERE owners.id = ? + AND owners.id = pets.owner_id + AND types.id = pets.type_id + """, (owner.id,), PetRowMapper()) + + def getPet(self, id): + """Return pets belonging to a particular owner.""" + return self.database_template.query(""" + SELECT + pets.id, + pets.name, + pets.birth_date, + types.name + FROM pets, types + WHERE pets.id = ? + AND types.id = pets.type_id + """, (id,), PetRowMapper())[0] + + def getVisits(self, pet): + """Return visits associated with a particular pet.""" + return self.database_template.query(""" + SELECT + visits.visit_date, + visits.description + FROM pets, visits + WHERE pets.id = ? + AND pets.id = visits.pet_id + """, (pet.id,), VisitRowMapper()) + + def addPet(self, id, name, birthDate, type): + """Store a new pet in the database.""" + rowsAffected = self.database_template.execute(""" + INSERT INTO pets + (name, birth_date, type_id, owner_id) + values + (?, ?, ?, ?) + """, (name, birthDate, type, id)) + return rowsAffected + + def getPetTypes(self): + """Return visits associated with a particular pet.""" + return self.database_template.query(""" + SELECT types.id, types.name + FROM types + """, rowhandler=PetTypeRowMapper()) + + def visitClinic(self, petId, description): + """Record a visit to the clinic.""" + rowsAffected = self.database_template.execute(""" + INSERT INTO visits + (pet_id, description, visit_date) + values + (?, ?, ?) + """, (petId, description, date.today())) + return rowsAffected + + def getVetSpecialties(self, vet): + """Look up specialties associated with a particular veterinarian.""" + return self.database_template.query(""" + SELECT + specialties.id, + specialties.name + FROM vets, vet_specialties, specialties + WHERE vets.id = vet_specialties.vet_id + AND vet_specialties.specialty_id = specialties.id + AND vets.id = ? + """, (vet.id,), SpecialtyRowMapper()) + + def getUsername(self, id): + """Look up the username associated with a user id""" + return self.database_template.query_for_object(""" + SELECT username + FROM owners + WHERE id = ? + """, (id,), types.StringType) + + def getUsers(self): + """ + This function fetches the users out of the database, so someone trying out PetClinic + can get the passwords to log in. + """ + users = self.database_template.query_for_list("select username, password, ' ', enabled from users") + for i in range(len(users)): + authorities = [row for (row,) in self.database_template.query_for_list("select authority from authorities where username = ?", (users[i][0],))] + users[i] = (users[i][0], users[i][1], authorities, users[i][3]) + return users + +class VetRowMapper(RowMapper): + """This is a row callback handler used in a database template call. It is used to process + one row of data from a Vet-oriented query by mapping a Vet-record.""" + def map_row(self, row, metadata=None): + vet = Vet() + vet.id = row[0] + vet.firstName = row[1] + vet.lastName = row[2] + return vet + +class OwnerRowMapper(RowMapper): + """This is a row callback handler used in a database template call. It is used to process + one row of data from an owner-oriented query by mapping an Owner-record.""" + def map_row(self, row, metadata=None): + owner = Owner() + owner.id = row[0] + owner.firstName = row[1] + owner.lastName = row[2] + owner.address = row[3] + owner.city = row[4] + owner.telephone = row[5] + return owner + +class PetRowMapper(RowMapper): + """This is a row callback handler used in a database template call. It is used to process + one row of data from a pet-oriented query by mapping an Pet-record.""" + def map_row(self, row, metadata=None): + pet = Pet() + pet.id = row[0] + pet.name = row[1] + pet.birthDate = row[2] + pet.type = row[3] + return pet + +class PetTypeRowMapper(RowMapper): + """This is a row callback handler used in a database template call. It is used to process + one row of data from a visit-oriented query by mapping an Visit-record.""" + def map_row(self, row, metadata=None): + petType = PetType() + petType.id = row[0] + petType.name = row[1] + return petType + +class SpecialtyRowMapper(RowMapper): + """This is a row callback handler used in a database template call. It is used to process + one row of data from a visit-oriented query by mapping an Visit-record.""" + def map_row(self, row, metadata=None): + specialty = Specialty() + specialty.id = row[0] + specialty.name = row[1] + return specialty + +class VisitRowMapper(RowMapper): + """This is a row callback handler used in a database template call. It is used to process + one row of data from a visit-oriented query by mapping an Visit-record.""" + def map_row(self, row, metadata=None): + visit = Visit() + visit.date = row[0] + visit.description = row[1] + return visit + +class OwnerVoter(AccessDecisionVoter): + def __init__(self, controller=None): + self.controller = controller + self.logger = logging.getLogger("springpython.petclinic.controller") + + def supports(self, attr): + """This voter will support a list. + """ + if isinstance(attr, list) or (attr is not None and attr == "OWNER"): + return True + else: + return False + + def vote(self, authentication, invocation, config): + """Grant access if any of the granted authorities matches any of the required + roles. + """ + results = self.ACCESS_ABSTAIN + for attribute in config: + if self.supports(attribute): + self.logger.debug("This OWNER voter will vote whether user owns this record.") + results = self.ACCESS_DENIED + id = cgi.parse_qs(invocation.environ["QUERY_STRING"])["id"][0] + if self.controller.getUsername(id) == authentication.username: + self.logger.debug("User %s owns this record. Access GRANTED!" % authentication.username) + return self.ACCESS_GRANTED + + if results == self.ACCESS_ABSTAIN: + self.logger.debug("This OWNER voter is abstaining from voting") + elif results == self.ACCESS_DENIED: + self.logger.debug("This OWNER voter did NOT own this record.") + + return results + + def __str__(self): + return "" + +class PreencodingUserDetailsService(UserDetailsService): + """ + This user details service allows passwords to be created that are un-encoded, but + will be encoded before the authentication step occurs. This is for demonstration + purposes only, specifically to show the password encoders being plugged in. + """ + def __init__(self, wrappedUserDetailsService = None, encoder = None): + UserDetailsService.__init__(self) + self.wrappedUserDetailsService = wrappedUserDetailsService + self.encoder = encoder + self.logger = logging.getLogger("springpython.petclinic.controller.PreencodingUserDetailsService") + + def load_user(self, username): + user = self.wrappedUserDetailsService.load_user(username) + user.password = self.encoder.encodePassword(user.password, None) + self.logger.debug("Pre-converting %s's password to hashed format of %s, before authentication happens." % (username, user.password)) + return user + + def __str__(self): + return "%s %s" % (self.encoder, self.wrappedUserDetailsService) diff --git a/samples/petclinic/cherrypy/view.py b/samples/petclinic/cherrypy/view.py index 63e2a48..bdf6e76 100644 --- a/samples/petclinic/cherrypy/view.py +++ b/samples/petclinic/cherrypy/view.py @@ -383,7 +383,7 @@ def login(self, fromPage="/", login="", password="", errorMsg=""): try: self.attemptAuthentication(login, password) return [self.redirectStrategy.redirect(fromPage)] - except AuthenticationException, e: + except AuthenticationException as e: return [self.redirectStrategy.redirect("?login=%s&errorMsg=Username/password failure" % login)] results = header() + """ diff --git a/samples/petclinic/cherrypy/view.py.bak b/samples/petclinic/cherrypy/view.py.bak new file mode 100644 index 0000000..63e2a48 --- /dev/null +++ b/samples/petclinic/cherrypy/view.py.bak @@ -0,0 +1,476 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import cherrypy +import logging +import re +import types +from springpython.database.core import DatabaseTemplate +from springpython.security import AuthenticationException +from springpython.security.context import SecurityContext +from springpython.security.context import SecurityContextHolder +from springpython.security.providers import UsernamePasswordAuthenticationToken + +def header(): + """Standard header used for all pages""" + return """ + + + + + PetClinic :: a Spring Python demonstration + + + + +
 
+
 
+
+

 

+
+
+
+
+ """ + +def footer(): + """Standard footer used for all pages.""" + return """ +
+ + + +
HomePetClinic :: a Spring Python demonstration (powered by CherryPy)
+ + """ + +class PetClinicView(object): + """Presentation layer of the web application.""" + + def __init__(self, filter=None, controller = None, hashedUserDetailsServiceList = None, authenticationManager=None, redirectStrategy=None): + self.filter = filter + self.controller = controller + self.hashedUserDetailsServiceList = hashedUserDetailsServiceList + self.authenticationManager = authenticationManager + self.redirectStrategy = redirectStrategy + self.httpContextFilter = None + self.logger = logging.getLogger("springpython.petclinic.view.PetClinicView") + + @cherrypy.expose + def accessDenied(self): + return header() + """ +

Access Denied

+

+ You have attempted to access a page which you are unauthorized to view. + """ + footer() + + @cherrypy.expose + def index(self): + """CherryPy will call this method for the root URI ("/") and send + its return value to the client.""" + + return header() + """ +

Welcome

+

+ Find owner +

+ Display all veterinarians +

+ Detailed description of this demo +

+ Logout + """ + footer() + + @cherrypy.expose + def findOwners(self, lastName = ""): + """Fetch owners by a partially matching against last name.""" + + results = header() + """ +

+

Find Owners:

+

+

+ Last Name: +
+

+ +

+

+
+ """ % lastName + if lastName != "": + results += """ +

Owners:

+ + + + + + """ + for owner in self.controller.getOwners(lastName): + results += """ + + + + + + + """ % (owner.id, owner.firstName, owner.lastName, owner.address, owner.city, owner.telephone) + results += """ +
NameAddressCityTelephone
%s %s%s%s%s
+

+
+ """ + results += """ + Add Owner +

+
+ Logout + """ + footer() + return results + + @cherrypy.expose + def addOwner(self, **kwargs): + """Insert a new owner into the database.""" + + results = header() + + if len(kwargs) > 0: + rowsAffected = self.controller.addOwner(**kwargs) + else: + rowsAffected = 0 + kwargs = { "firstName": "", "lastName": "", "address": "", "city": "", "telephone":"" } + + if rowsAffected > 0: + results += "

%s %s was successfully added.

" % (kwargs["firstName"], kwargs["lastName"]) + + results += """ +

+

New Owner:

+

+

+ First Name: +
+

+ + Last Name: +
+

+ + Address: +
+

+ + City: +
+

+ + Telephone: +
+

+ +

+

+
+ Logout + """ % (kwargs["firstName"], kwargs["lastName"], kwargs["address"], kwargs["city"], kwargs["telephone"]) + footer() + return results + + @cherrypy.expose + def editOwner(self, id): + """Update an existing owner""" + owner = self.controller.getOwner(id) + results = header() + """ +

+

Owner: %s %s

+

+

+ + Address: +
+

+ + City: +
+

+ + Telephone: +
+

+ +

+
+

Pets:

+

+ + + + + """ % (owner.firstName, owner.lastName, owner.id, owner.address, owner.city, owner.telephone) + for pet in self.controller.getPets(owner): + results += """ + + + + + + """ % (owner.id, pet.id, pet.name, pet.birthDate, pet.type) + results += """ +
NameBirth dateType
%s%s%s
+

+ Add Pet +

+
+ Logout + """ % id + footer() + return results + + @cherrypy.expose + def doUpdateOwner(self, id, address = "", city = "", telephone = ""): + owner = self.controller.updateOwner(id, address, city, telephone) + return self.editOwner(id) + + @cherrypy.expose + def addPet(self, id): + types = self.controller.getPetTypes() + results = header() + """ +

+

New Pet:

+

+

+ + Name: +
+

+ + Birth Date: +
+

+ + Type: +
+

+ +

+ Back to Owner +

+
+ Logout + """ % id + footer() + return results + + @cherrypy.expose + def doAddPet(self, id, name = "", birthDate = "", type = ""): + owner = self.controller.addPet(id, name, birthDate, type) + return self.addPet(id) + + @cherrypy.expose + def vetHistory(self, ownerId, petId): + """Look up history of visits for a pet.""" + pet = self.controller.getPet(petId) + + results = header() + """ +

+

History of visits for %s:

+ + + + + """ % pet.name + for visit in self.controller.getVisits(pet): + results += """ + + + + + """ % (visit.date, visit.description) + results += """ +
DateDescription
%s%s
+

+ Visit the Clinic +

+ Back to owner +
+ Logout + """ % (ownerId, pet.id, pet.name, ownerId) + footer() + return results + + @cherrypy.expose + def visitClinic(self, ownerId, petId, name): + """Look up history of visits for a pet.""" + results = header() + """ +

New visit for %s:

+
+ + + + Name: +
+

+ +

+

+ Back to history of visits +

+ Logout + """ % (name, name, ownerId, petId, ownerId, petId) + footer() + return results + + @cherrypy.expose + def doVisitClinic(self, ownerId, petId, name, description = ""): + owner = self.controller.visitClinic(petId, description) + return self.visitClinic(ownerId, petId, name) + + @cherrypy.expose + def vets(self): + """Look up all the veterinarians.""" + + results = header() + """ +

+

Veterinarians:

+ + + + + """ + for vet in self.controller.getVets(): + specialties = ",".join([specialty.name for specialty in self.controller.getVetSpecialties(vet)]) + results += """ + + + + + """ % (vet.firstName, vet.lastName, specialties) + results += """ +
NameSpecialties
%s %s%s
+

+
+ Logout + """ + footer() + return results + + @cherrypy.expose + def login(self, fromPage="/", login="", password="", errorMsg=""): + if login != "" and password != "": + try: + self.attemptAuthentication(login, password) + return [self.redirectStrategy.redirect(fromPage)] + except AuthenticationException, e: + return [self.redirectStrategy.redirect("?login=%s&errorMsg=Username/password failure" % login)] + + results = header() + """ + + %s +

+

Unhashed passwords - The following table contains a set of accounts that are stored in the clear.

+

+ + + + + + + + """ % errorMsg + + for (username, password, authorities, enabled) in self.controller.getUsers(): + results += """ + + + + + + + """ % (username, password, authorities, enabled) + + results += """ +
UsernamePasswordGranted authoritiesEnabled?
%s %s %s %s
+ """ + + # Display hard-coded, unhashed passwords. NOTE: These cannot be retrieved from + # the application context, because they are one way hashes. This must be kept + # in sync with the application context. + results += """ +

Hashed passwords - The following tables contain accounts that are stored with one-way hashes.

+

+ """ + for hashedUserDetailsService in self.hashedUserDetailsServiceList: + results += """ + %s + + + + + + + + """ % re.compile("<").sub("<", str(hashedUserDetailsService)) + for key, value in hashedUserDetailsService.wrappedUserDetailsService.user_dict.items(): + results += """ + + + + + + + """ % (key, value[0], value[1], value[2]) + results += """ +
UsernamePasswordGranted authoritiesEnabled?
%s %s %s %s
+

+ """ + + + results += """ +

+ Login:
+ Password:
+
+ +
+ + """ % (login, fromPage) + results += footer() + return [results] + + @cherrypy.expose + def logout(self): + """Replaces current authentication token, with an empty, non-authenticated one.""" + self.filter.logout() + self.httpContextFilter.saveContext() + raise cherrypy.HTTPRedirect("/") + + def attemptAuthentication(self, username, password): + """Authenticate a new username/password pair using the authentication manager.""" + self.logger.debug("Trying to authenticate %s using the authentication manager" % username) + token = UsernamePasswordAuthenticationToken(username, password) + SecurityContextHolder.getContext().authentication = self.authenticationManager.authenticate(token) + self.httpContextFilter.saveContext() + self.logger.debug(SecurityContextHolder.getContext()) diff --git a/samples/petclinic/configure.py b/samples/petclinic/configure.py index cc30c8c..a9efa03 100644 --- a/samples/petclinic/configure.py +++ b/samples/petclinic/configure.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import os import re import termios @@ -27,7 +28,7 @@ def tryMySQL(): try: import MySQLdb except: - print "You don't appear to have MySQLdb module." + print("You don't appear to have MySQLdb module.") raise NotImplementedError("Can't setup the database") useGetPass = True @@ -35,8 +36,8 @@ def tryMySQL(): if useGetPass: try: password = getpass("Mysql 'root' password: ") - except termios.error, e: - print "Okay, we can't use that mechanism to ask for your password." + except termios.error as e: + print("Okay, we can't use that mechanism to ask for your password.") useGetPass = False password = raw_input("Mysql 'root' password: ") else: @@ -48,9 +49,9 @@ def tryMySQL(): del(connection) break except: - print "!!! Bad password!" + print("!!! Bad password!") if i >= 3: - print "!!! Failed all attempts to connection to the database>" + print("!!! Failed all attempts to connection to the database>") return None subprocess.Popen([r"mysql","-uroot", "-p%s" % password], @@ -80,10 +81,10 @@ def setupDatabase(): if line.strip() != ""]: databaseTemplate.execute(sqlStatement) - print "+++ Database is setup." + print("+++ Database is setup.") def main(): - print "+++ Setting up the Spring Python demo application 'petclinic'" + print("+++ Setting up the Spring Python demo application 'petclinic'") setupDatabase() diff --git a/samples/petclinic/configure.py.bak b/samples/petclinic/configure.py.bak new file mode 100644 index 0000000..cc30c8c --- /dev/null +++ b/samples/petclinic/configure.py.bak @@ -0,0 +1,91 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import os +import re +import termios +import subprocess +from getpass import getpass +from springpython.database.core import DatabaseTemplate +from springpython.database.factory import MySQLConnectionFactory + +def tryMySQL(): + """Try to setup the database through MySQL. If it fails, return None. Otherwise, return + a handle on the database. Later on, other databases may be supported.""" + try: + import MySQLdb + except: + print "You don't appear to have MySQLdb module." + raise NotImplementedError("Can't setup the database") + + useGetPass = True + for i in [1, 2, 3]: + if useGetPass: + try: + password = getpass("Mysql 'root' password: ") + except termios.error, e: + print "Okay, we can't use that mechanism to ask for your password." + useGetPass = False + password = raw_input("Mysql 'root' password: ") + else: + password = raw_input("Mysql 'root' password: ") + + try: + connection = MySQLdb.connect(host="localhost", user="root", passwd=password, db="") + connection.close() + del(connection) + break + except: + print "!!! Bad password!" + if i >= 3: + print "!!! Failed all attempts to connection to the database>" + return None + + subprocess.Popen([r"mysql","-uroot", "-p%s" % password], + stdout=subprocess.PIPE, + stdin=file("db/mysql/dropDB.txt")).communicate()[0] + + subprocess.Popen([r"mysql","-uroot", "-p%s" % password], + stdout=subprocess.PIPE, + stdin=file("db/mysql/initDB.txt")).communicate()[0] + connectionFactory = MySQLConnectionFactory() + connectionFactory.username = "springpython" + connectionFactory.password = "springpython" + connectionFactory.hostname = "localhost" + connectionFactory.db = "petclinic" + return connectionFactory + +def setupDatabase(): + """Figure out what type of database exists, and then set it up.""" + connectionFactory = tryMySQL() + + if connectionFactory is None: + raise Exception("+++ Could not setup MySQL. We don't support any others yet.") + + databaseTemplate = DatabaseTemplate(connectionFactory) + + for sqlStatement in [line.strip() for line in open("db/populateDB.txt").readlines() + if line.strip() != ""]: + databaseTemplate.execute(sqlStatement) + + print "+++ Database is setup." + +def main(): + print "+++ Setting up the Spring Python demo application 'petclinic'" + + setupDatabase() + +if __name__ == "__main__": + main() diff --git a/samples/springwiki/model.py b/samples/springwiki/model.py index ff3e65b..804aff2 100644 --- a/samples/springwiki/model.py +++ b/samples/springwiki/model.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import re @@ -484,7 +485,7 @@ def header(self): def generateHistory(self): - print self.history + print(self.history) results = """

Diff selection: mark the radio boxes of the versions to compare and hit enter or the button at the bottom.
diff --git a/samples/springwiki/model.py.bak b/samples/springwiki/model.py.bak new file mode 100644 index 0000000..ff3e65b --- /dev/null +++ b/samples/springwiki/model.py.bak @@ -0,0 +1,641 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import re + +intrawikiR = re.compile("\[\[(?P.*?)(\|(?P.*?))?\]\](?P[a-zA-Z0-9]*)") +externalLinkR = re.compile("\[(?P.*?)\s(?P.*)\]") +header1R = re.compile("=(?P

.*)=") +header2R = re.compile("==(?P
.*)==") +header3R = re.compile("===(?P
.*)===") +header4R = re.compile("====(?P
.*)====") +header5R = re.compile("=====(?P
.*)=====") + +class Page(object): + def __init__(self, article, wikitext, controller): + self.article = article + self.wikitext = wikitext + self.controller = controller + + def makeAList(self, wikitext): + keyCharR = re.compile("(([*#]+)(.*))") + tokens = wikitext.split("\n") + bulletCount = 0 + currentLevel = 0 + + tagStack = [] + for i in range(1,len(tokens)): + if tokens[i]: + match = keyCharR.match(tokens[i]) + if match: + token = match.groups()[1] + text = match.groups()[2] + if len(token) > currentLevel: + tokens[i] = "" + for j in range(currentLevel, len(token)): + if token[j] == "*": + tagStack.append("") + tokens[i] += "
    " + elif token[j] == "#": + tagStack.append("") + tokens[i] += "
      " + tokens[i] += "
    1. " + text + "
    2. " + currentLevel = len(token) + elif len(token) < currentLevel: + tokens[i] = "" + for j in range(len(token), currentLevel): + tokens[i] += tagStack.pop() + tokens[i] += "
    3. " + text + "
    4. " + currentLevel = len(token) + else: + if token[-1] == "*": + tokens[i] = "
    5. " + text + "
    6. " + elif token[-1] == "#": + tokens[i] = "
    7. " + text + "
    8. " + else: + if currentLevel > 0: + tokens[i] = "" + for j in range(0, currentLevel): + tokens[i] += tagStack.pop() + currentLevel = 0 + return "\n".join(tokens) + + def intrawikiSubstitution(self, match): + g = match.groupdict() + + if self.controller.exists(g["link"]): + str = ' + + + + SpringWiki :: a Spring Python demonstration + + + + + +
      +
      +
      + +

      """ + self.article + """

      +
      +

      From Springwiki

      +
      + """ + + def footer(self, selected = "article"): + """Standard footer used for all pages.""" + footer = """ +
      +
      +
      +
      +
      +
      +
        + """ + + if selected == "article" and self.article.split(":")[0] == "Talk": + selected = "discussion" + + if selected == "article": + if self.controller.exists(self.article.split(":")[-1]): + footer += """
      • article
      • \n""" + else: + footer += """
      • article
      • \n""" + + if self.controller.exists("Talk:" + self.article.split(":")[-1]): + footer += """
      • discussion
      • \n""" + else: + footer += """
      • discussion
      • \n""" + + + footer += """
      • edit
      • \n""" + if self.controller.exists(self.article.split(":")[-1]): + footer += """ +
      • history
      • +
      • delete
      • + """ + elif selected == "discussion": + footer += """ +
      • article
      • +
      • discussion
      • +
      • edit
      • + """ + if self.controller.exists("Talk:" + self.article.split(":")[-1]): + footer += """ +
      • history
      • +
      • delete
      • + """ + elif selected == "edit": + if self.controller.exists(self.article.split(":")[-1]): + footer += """
      • article
      • \n""" + else: + footer += """
      • article
      • \n""" + + if self.controller.exists("Talk:" + self.article.split(":")[-1]): + footer += """
      • discussion
      • \n""" + else: + footer += """
      • discussion
      • \n""" + + footer += """ +
      • edit
      • + """ + if self.controller.exists(self.article): + footer += """ +
      • history
      • +
      • delete
      • + """ + elif selected == "history": + footer += """
      • article
      • + """ + + if self.controller.exists("Talk:" + self.article.split(":")[-1]): + footer += """
      • discussion
      • \n""" + else: + footer += """
      • discussion
      • \n""" + + footer += """ +
      • edit
      • +
      • history
      • +
      • delete
      • + """ + elif selected == "delete": + footer += """
      • article
      • + """ + if self.controller.exists("Talk:" + self.article.split(":")[-1]): + footer += """
      • discussion
      • \n""" + else: + footer += """
      • discussion
      • \n""" + footer += """ +
      • edit
      • +
      • history
      • +
      • delete
      • + """ + + footer += """ +
      +
      + """ + self.icon() + """ + """ + self.navigationHeader() + """ +
      + +
      + + +
      + + """ + return footer + + def icon(self): + return """ + + """ + + def navigationHeader(self): + """Left hand HTML""" + sidebar = self.controller.getPage("Springwiki Sidebar") + results = """ + + """ + + listStack = [] + for eachLine in sidebar.wikitext.split("\n"): + if len(eachLine) > 1 and eachLine[0:2] == "**": + temp = re.compile("\[.*\]").findall(eachLine[2:])[0] + wikiLink = re.compile("[\[]+(.*?)[\]]+").findall(temp)[0] + if temp[0:2] == "[[": + pipe = wikiLink.split("|")[-1] + link = wikiLink.split("|")[0] + results += """ +
    9. %s
    10. """ % (link, link, pipe) + else: + pipe = " ".join(wikiLink.split(" ")[1:]) + link = wikiLink.split(" ")[0] + results += """ +
    11. %s
    12. """ % (link, link, pipe) + + elif len(eachLine) > 0 and eachLine[0] == "*": + target = eachLine[1:] + if len(listStack) > 0: + results += """ +
+ + """ % listStack[-1] + listStack.pop() + results += """ +
+
%s
+ +
+
    """ % (target, target) + listStack.append(target) + + if len(listStack) > 0: + results += """ +
+
+
""" % listStack[-1] + listStack.pop() + + return results + + def wikiToHtml(self): + htmlText = self.wikitext + htmlText = header5R.sub('
\g
', htmlText) + htmlText = header4R.sub('

\g

', htmlText) + htmlText = header3R.sub('

\g

', htmlText) + htmlText = header2R.sub('

\g

', htmlText) + htmlText = header1R.sub('

\g

', htmlText) + htmlText = intrawikiR.sub(self.intrawikiSubstitution, htmlText) + htmlText = externalLinkR.sub('\g', htmlText) + htmlText = self.makeAList(htmlText) + return htmlText + + def html(self): + results = self.header() + results += """ + + """ + results += self.wikiToHtml() + results += """ + + """ + results += self.footer(selected="article") + return results + +class EditPage(Page): + def __init__(self, article, wikitext, controller): + Page.__init__(self, article, wikitext, controller) + + def header(self): + """Standard header used for all pages""" + return """ + + + + + SpringWiki :: a Spring Python demonstration + + + + + + +
+
+
+ +

Editing """ + self.article + """

+
+

From Springwiki

+
+ """ + def wikiToHtml(self): + htmlText = """ + + """ + htmlText += """ +
+ +
+ Summary: +
+
+ + + + Cancel +
+ """ + return htmlText + + def html(self): + results = self.header() + results += """ + + """ + results += self.wikiToHtml() + results += """ + + """ + results += self.footer(selected="edit") + return results + + +class PreviewEditPage(Page): + def __init__(self, article, wikitext, controller): + Page.__init__(self, article, wikitext, controller) + + def header(self): + """Standard header used for all pages""" + return """ + + + + + SpringWiki :: a Spring Python demonstration + + + + + +
+
+
+ +

Editing """ + self.article + """

+
+

From Springwiki

+
+ """ + def html(self): + results = self.header() + results += '

Preview

' + results += Page(self.article, self.wikitext, self.controller).wikiToHtml() + results += '
' + results += EditPage(self.article, self.wikitext, self.controller).wikiToHtml() + results += self.footer() + return results + +class Version(object): + def __init__(self, article, versionNumber, wikitext, summary, date, editor): + self.article = article + self.versionNumber = versionNumber + self.wikitext = wikitext + self.summary = summary + self.date = date + self.editor = editor + +class HistoryPage(Page): + def __init__(self, article, controller, history): + Page.__init__(self, article, 'History', controller) + self.history = [] + for rev in history: + self.history.append(Version(article=article, versionNumber=len(self.history), wikitext=rev[0], summary=rev[1], date=rev[2], editor=rev[3])) + + def header(self): + """Standard header used for all pages""" + return """ + + + + + SpringWiki :: a Spring Python demonstration + + + + + +
+
+
+ +

""" + self.article + """

+
+

From Springwiki

+
Revision history
+ """ + + + def generateHistory(self): + print self.history + results = """ +

+ Diff selection: mark the radio boxes of the versions to compare and hit enter or the button at the bottom.
+ Legend: (cur) = difference with current version, + (last) = difference with preceding version, M = minor edit. +

+
+ +
    + """ + + if len(self.history) > 0: + for rev in reversed(self.history): + hiddenStyle = "" + if rev == self.history[-1]: + hiddenStyle = 'style="visibility:hidden"' + + userStyle = "" + if self.controller.exists(rev.editor): + userStyle = 'class="new"' + + try: + editor = rev.editor.split(":")[1] + except: + editor = rev.editor + + results += """ +
  • + + + %s + %s + (->%s) +
  • + """ % (rev.versionNumber, hiddenStyle, rev.versionNumber, rev.article, rev.versionNumber, rev.date, + rev.editor, editor, userStyle, rev.summary) + + results += """ +
+ +
+ """ + return results + + def html(self): + results = self.header() + results += """ + + """ + results += self.generateHistory() + results += """ + + """ + results += self.footer(selected="history") + return results + +class OldPage(Page): + def __init__(self, article, wikitext, controller): + Page.__init__(self, article, wikitext, controller) + +class NoPage(Page): + def __init__(self, article, controller): + Page.__init__(self, article, "This page does not yet exist.", controller) + +class DeletePage(Page): + def __init__(self, article, controller): + Page.__init__(self, article, "This page does not yet exist.", controller) + + def header(self): + """Standard header used for all pages""" + return """ + + + + + SpringWiki :: a Spring Python demonstration + + + + + +
+
+
+ +

""" + self.article + """

+
+

From Springwiki

+
(Deleting "%s")
+ """ % self.article + + def html(self): + results = self.header() + results += """ + + """ + results += """ + Warning: The page you are about to delete has a history +

+ You are about to permanently delete a page or image along with all of its history from the database. + Please confirm that you intend to do this, that you understand the consequences.

+

+ + + + + + + + + + +
+ + + +
  + +
+
+

+

Return to %s.

+ """ % (self.article, self.article, self.article, self.article) + results += """ + + """ + results += self.footer(selected="delete") + return results + +class ActionCompletedPage(Page): + def __init__(self, article, controller): + Page.__init__(self, article, "This page does not yet exist.", controller) + + def html(self): + results = self.header() + results += """ + + """ + results += """ +

"%s" has been deleted.

+

Return to Main Page.

+ """ % self.article + + results += """ + + """ + results += self.footer(selected="article") + return results + + diff --git a/samples/springwiki/view.py b/samples/springwiki/view.py index 2c69602..34eb73b 100644 --- a/samples/springwiki/view.py +++ b/samples/springwiki/view.py @@ -113,7 +113,7 @@ def index(self, fromPage="/", login="", password="", errorMsg=""): try: self.attemptAuthentication(login, password) return [self.redirectStrategy.redirect(fromPage)] - except AuthenticationException, e: + except AuthenticationException as e: return [self.redirectStrategy.redirect("?login=%s&errorMsg=Username/password failure" % login)] results = """ diff --git a/samples/springwiki/view.py.bak b/samples/springwiki/view.py.bak new file mode 100644 index 0000000..2c69602 --- /dev/null +++ b/samples/springwiki/view.py.bak @@ -0,0 +1,143 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import cherrypy +import re +from model import ActionCompletedPage +from model import Page +from model import PreviewEditPage + +from springpython.security import AuthenticationException +from springpython.security.context import SecurityContext, SecurityContextHolder +from springpython.security.providers import UsernamePasswordAuthenticationToken + +def forward(url): + return '' + +class Springwiki(object): + """Render a RESTful wiki article for display.""" + + def __init__(self, controller = None): + """Inject a controller object in order to fetch live data.""" + self.controller = controller + + @cherrypy.expose + def index(self, article = "Main Page"): + if article: + return self.default(article) + + def addRawWikitext(self, page): + return "

Original Wikitext

" + re.sub("\n", "
", page.wikitext) + + @cherrypy.expose + def default(self, article, oldid=None): + page = self.controller.getPage(article, oldid) + return page.html() + + def addRawWikitext(self, page): + return "

Original Wikitext

" + re.sub("\n", "
", page.wikitext) + + @cherrypy.expose + def submit(self, article, wpTextbox1 = None, wpSummary = None, wpSave = None, wpMinorEdit = False, + wpPreview = None, wpDiff = None, wpMinoredit = False): + + if wpSave: + self.controller.updatePage(article, wpTextbox1, wpSummary, wpMinorEdit) + + if wpPreview: + return PreviewEditPage(article, wpTextbox1, self.controller).html() + + return forward("/" + article) + + @cherrypy.expose + def edit(self, article): + page = self.controller.getEditPage(article) + return page.html() + + @cherrypy.expose + def history(self, article): + page = self.controller.getHistory(article) + return page.html() + + @cherrypy.expose + def delete(self, article): + page = self.controller.getDeletePage(article) + return page.html() + + @cherrypy.expose + def doDelete(self, article, wpReason = None, wpConfirmB = None): + if wpConfirmB == "Confirm": + self.controller.deletePage(article, wpReason) + return forward("/actionCompleted?article=%s" % article) + + @cherrypy.expose + def actionCompleted(self, article): + page = self.controller.getActionCompletedPage(article) + return page.html() + + @cherrypy.expose + def special(self, specialPage, article = None): + return "Special page " + specialPage + " doesn't exist yet." + +class CherryPyAuthenticationForm: + """ + This is simple authentication page used to test the PetClinic app. For production use, you + should use something like SSL to keep people from eavesdropping on passwords being passed + to the application server. + """ + def __init__(self, filter=None, controller = None, authenticationManager=None, redirectStrategy=None, + httpContextFilter=None): + self.filter = filter + self.controller = controller + self.authenticationManager = authenticationManager + self.redirectStrategy = redirectStrategy + self.httpContextFilter = httpContextFilter + self.logger = logging.getLogger("springpython.springwiki.view.CherryPyAuthenticationForm") + + @cherrypy.expose + def index(self, fromPage="/", login="", password="", errorMsg=""): + if login != "" and password != "": + try: + self.attemptAuthentication(login, password) + return [self.redirectStrategy.redirect(fromPage)] + except AuthenticationException, e: + return [self.redirectStrategy.redirect("?login=%s&errorMsg=Username/password failure" % login)] + + results = """ +
+ Login:
+ Password:
+
+ +
+ + """ % (login, fromPage) + return [results] + + @cherrypy.expose + def logout(self): + """Replaces current authentication token, with an empty, non-authenticated one.""" + self.filter.logout() + self.httpContextFilter.saveContext() + raise cherrypy.HTTPRedirect("/") + + def attemptAuthentication(self, username, password): + """Authenticate a new username/password pair using the authentication manager.""" + token = UsernamePasswordAuthenticationToken(username, password) + self.logger.debug("Trying to authenticate '%s' using the authentication manager: %s" % (username, token)) + SecurityContextHolder.getContext().authentication = self.authenticationManager.authenticate(token) + self.logger.debug(SecurityContextHolder.getContext()) + self.httpContextFilter.saveContext() diff --git a/src/plugins/gen-cherrypy-app/__init__.py b/src/plugins/gen-cherrypy-app/__init__.py index 3cc4f7e..f74323e 100644 --- a/src/plugins/gen-cherrypy-app/__init__.py +++ b/src/plugins/gen-cherrypy-app/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import re import os import shutil @@ -21,7 +22,7 @@ def create(plugin_path, name): if not os.path.exists(name): - print "Creating CherryPy skeleton app %s" % name + print("Creating CherryPy skeleton app %s" % name) os.makedirs(name) # Copy/transform the template files @@ -44,5 +45,5 @@ def create(plugin_path, name): # Recursively copy other parts shutil.copytree(plugin_path + "/images", name + "/" + "images") else: - print "There is already something called %s. ABORT!" % name + print("There is already something called %s. ABORT!" % name) diff --git a/src/plugins/gen-cherrypy-app/__init__.py.bak b/src/plugins/gen-cherrypy-app/__init__.py.bak new file mode 100644 index 0000000..3cc4f7e --- /dev/null +++ b/src/plugins/gen-cherrypy-app/__init__.py.bak @@ -0,0 +1,48 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import re +import os +import shutil + +__description__ = "plugin to create skeleton CherryPy applications" + +def create(plugin_path, name): + if not os.path.exists(name): + print "Creating CherryPy skeleton app %s" % name + os.makedirs(name) + + # Copy/transform the template files + for file_name in ["cherrypy-app.py", "controller.py", "view.py", "app_context.py"]: + input_file = open(plugin_path + "/" + file_name).read() + + # Iterate over a list of patterns, performing string substitution on the input file + patterns_to_replace = [("name", name), ("properName", name[0].upper() + name[1:])] + for pattern, replacement in patterns_to_replace: + input_file = re.compile(r"\$\{%s}" % pattern).sub(replacement, input_file) + + output_filename = name + "/" + file_name + if file_name == "cherrypy-app.py": + output_filename = name + "/" + name + ".py" + + app = open(output_filename, "w") + app.write(input_file) + app.close() + + # Recursively copy other parts + shutil.copytree(plugin_path + "/images", name + "/" + "images") + else: + print "There is already something called %s. ABORT!" % name + diff --git a/src/setup-template.py b/src/setup-template.py index 0daa6a3..6f38dee 100644 --- a/src/setup-template.py +++ b/src/setup-template.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import re import sys @@ -23,7 +24,7 @@ from setuptools import setup if sys.version_info < (2, 6): - print "Spring Python only supports Python 2.6 and higher" + print("Spring Python only supports Python 2.6 and higher") sys.exit(1) setup(name='springpython', diff --git a/src/setup-template.py.bak b/src/setup-template.py.bak new file mode 100644 index 0000000..0daa6a3 --- /dev/null +++ b/src/setup-template.py.bak @@ -0,0 +1,65 @@ +#!/usr/bin/env python +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import re +import sys + +try: + from distribute.core import setup +except ImportError: + from setuptools import setup + +if sys.version_info < (2, 6): + print "Spring Python only supports Python 2.6 and higher" + sys.exit(1) + +setup(name='springpython', + version='${version}', + description='Spring Python', + long_description='Spring Python is an offshoot of the Java-based SpringFramework, targeted for Python. Spring provides many useful features, and I wanted those same features available when working with Python.', + author='Greg L. Turnquist', + author_email='greg.turnquist at springsource dot com', + url='http://springpython.webfactional.com', + platforms = ["Python >= 2.6"], + license='Apache Software License (http://www.apache.org/licenses/LICENSE-2.0)', + scripts=['plugins/coily'], + packages=['springpython', + 'springpython.aop', + 'springpython.jms', + 'springpython.config', + 'springpython.container', + 'springpython.context', + 'springpython.database', + 'springpython.factory', + 'springpython.remoting', + 'springpython.remoting.hessian', + 'springpython.remoting.pyro', + 'springpython.security', + 'springpython.security.context', + 'springpython.security.providers', + 'springpython.security.userdetails' + ], + package_data={'springpython': ["README", "COPYRIGHT", "LICENSE.txt"]}, + classifiers=["License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Development Status :: 5 - Production/Stable", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: Python", + "Operating System :: OS Independent" + ] + + ) + diff --git a/src/springpython/aop/__init__.py b/src/springpython/aop/__init__.py index faef576..57e44ac 100644 --- a/src/springpython/aop/__init__.py +++ b/src/springpython/aop/__init__.py @@ -56,7 +56,7 @@ def getInterceptor(self): def proceed(self): """This is the method every interceptor should call in order to continue down the chain of interceptors.""" - interceptor = self.iterator.next() + interceptor = next(self.iterator) self.logger.debug("Calling %s.%s(%s, %s)" % (interceptor.__class__.__name__, self.method_name, self.args, self.kwargs)) return interceptor.invoke(self) @@ -218,7 +218,7 @@ def getProxy(self): return AopProxy(self.target, self.interceptors) def __setattr__(self, name, value): - if name == "target" and type(value) == types.StringType: + if name == "target" and type(value) == bytes: value = utils.getClass(value)() elif name == "interceptors" and not isinstance(value, list): value = [value] diff --git a/src/springpython/aop/__init__.py.bak b/src/springpython/aop/__init__.py.bak new file mode 100644 index 0000000..faef576 --- /dev/null +++ b/src/springpython/aop/__init__.py.bak @@ -0,0 +1,250 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import copy +import logging +import re +import types +from springpython.aop import utils + +class Pointcut(object): + """Interface defining where to apply an aspect.""" + def class_filter(self): + raise NotImplementedError() + def method_matcher(self): + raise NotImplementedError() + +class MethodMatcher(object): + """Interface defining how to apply aspects based on methods.""" + def matches_method_and_target(self, method, targetClass, args): + raise NotImplementedError() + +class MethodInterceptor(object): + """Interface defining "around" advice.""" + def invoke(self, invocation): + raise NotImplementedError() + +class MethodInvocation(object): + """Encapsulation of invoking a method on a proxied service. It iterates throgh the list of interceptors by using + a generator.""" + def __init__(self, instance, method_name, args, kwargs, interceptors): + self.instance = instance + self.method_name = method_name + self.args = args + self.kwargs = kwargs + self.intercept_stack = copy.copy(interceptors) + self.intercept_stack.append(FinalInterceptor()) + self.logger = logging.getLogger("springpython.aop.MethodInvocation") + + def getInterceptor(self): + """This is a generator to proceed through the stack of interceptors. By using generator convention, code may + proceed in a nested fashion, versus a for-loop which would act in a chained fashion.""" + for interceptor in self.intercept_stack: + yield interceptor + + def proceed(self): + """This is the method every interceptor should call in order to continue down the chain of interceptors.""" + interceptor = self.iterator.next() + self.logger.debug("Calling %s.%s(%s, %s)" % (interceptor.__class__.__name__, self.method_name, self.args, self.kwargs)) + return interceptor.invoke(self) + + def __getattr__(self, name): + """This only deals with method invocations. Attributes are dealt with by the AopProxy, and don't every reach this + block of code.""" + self.iterator = self.getInterceptor() + self.method_name = name + return self + + def __call__ (self, *args, **kwargs): + """This method converts this from being a stored object into a callable class. This is effectively like a metaclass + that dispatches calls to proceed through a stack of interceptors.""" + self.args = args + self.kwargs = kwargs + return self.proceed() + + def dump_interceptors(self, level = logging.INFO): + """DEBUG: Method used to dump the stack of interceptors in order of execution.""" + for interceptor in self.intercept_stack: + self.logger.log(level, "Interceptor stack: %s" % interceptor.__class__.__name__) + +class RegexpMethodPointcutAdvisor(Pointcut, MethodMatcher, MethodInterceptor): + """ + This is a combination PointCut/MethodMatcher/MethodInterceptor. It allows associating one or more + defined advices with a set of regular expression patterns. + """ + def __init__(self, advice = None, patterns = None): + Pointcut.__init__(self) + MethodMatcher.__init__(self) + self.advice = advice + if not patterns: + self.patterns = [] + else: + self.patterns = patterns + self.logger = logging.getLogger("springpython.aop.RegexpMethodPointcut") + + def init_patterns(self): + """Precompile the regular expression pattern matcher list.""" + self.compiled_patterns = {} + for pattern in self.patterns: + self.compiled_patterns[pattern] = re.compile(pattern) + + def matches_method_and_target(self, method, target_class, args): + """Iterate through all patterns, checking for a match. Calls the pattern matcher against "class.method_name".""" + for pointcut_pattern in self.patterns: + if (self.matches_pattern(target_class + "." + method, pointcut_pattern)): + return True + return False + + def matches_pattern(self, method_name, pointcut_pattern): + """Uses a pre-built dictionary of regular expression patterns to check for a matcch.""" + if self.compiled_patterns[pointcut_pattern].match(method_name): + matched = True + else: + matched = False + self.logger.debug("Candidate is [%s]; pattern is [%s]; matched=%s" % (method_name, pointcut_pattern, matched)) + return matched + + def invoke(self, invocation): + """Compares "class.method" against regular expression pattern and if it passes, it will + pass through to the chain of interceptors. Otherwise, bypass interceptors and invoke class + method directly.""" + + className = invocation.instance.__class__.__name__ + + if self.matches_method_and_target(invocation.method_name, className, invocation.args): + # This constant is not class level, because it is a hack solution, and should only be used + # used here, and not known outside the scope of this block of code. --G.Turnquist (9/22/2008) + ASSUME_THIS_ADVISOR_WAS_FIRST = 1 + invocation.intercept_stack[ASSUME_THIS_ADVISOR_WAS_FIRST:ASSUME_THIS_ADVISOR_WAS_FIRST] = self.advice + + self.logger.debug("We have a match, passing through to the advice.") + invocation.dump_interceptors(logging.DEBUG) + + return invocation.proceed() + else: + self.logger.debug("No match, bypassing advice, going straight to targetClass.") + return getattr(invocation.instance, invocation.method_name)(*invocation.args, **invocation.kwargs) + + def __setattr__(self, name, value): + """If "advice", make sure it is a list. Anything else, pass through to simple assignment. + Also, if "patterns", initialize the regular expression parsers. + """ + if name == "advice" and type(value) != list: + self.__dict__[name] = [value] + else: + self.__dict__[name] = value + + if name == "patterns": + self.init_patterns() + +class FinalInterceptor(MethodInterceptor): + """ + Final interceptor is always at the bottom of interceptor stack. + It executes the actual target method on the instance. + """ + def __init__(self): + MethodInterceptor.__init__(self) + self.logger = logging.getLogger("springpython.aop.FinalInterceptor") + + def invoke(self, invocation): + return getattr(invocation.instance, invocation.method_name)(*invocation.args, **invocation.kwargs) + +class AopProxy(object): + """AopProxy acts like the target object by dispatching all method calls to the target through a MethodInvocation. + The MethodInvocation object actually deals with potential "around" advice, referred to as interceptors. Attribute + lookups are not intercepted, but instead fetched from the actual target object.""" + + def __init__(self, target, interceptors): + self.target = target + if type(interceptors) == list: + self.interceptors = interceptors + else: + self.interceptors = [interceptors] + self.logger = logging.getLogger("springpython.aop.AopProxy") + + def __getattr__(self, name): + """If any of the parameters are local objects, they are immediately retrieved. Callables cause the dispatch method + to be return, which forwards callables through the interceptor stack. Target attributes are retrieved directly from + the target object.""" + if name in ["target", "interceptors", "method_name"]: + return self.__dict__[name] + else: + attr = getattr(self.target, name) + if not callable(attr): + return attr + + def dispatch(*args, **kwargs): + """This method is returned to the caller emulating the function call being sent to the + target object. This services as a proxying agent for the target object.""" + invocation = MethodInvocation(self.target, name, args, kwargs, self.interceptors) + ############################################################################## + # NOTE: + # getattr(invocation, name) doesn't work here, because __str__ will print + # the MethodInvocation's string, instead of trigger the interceptor stack. + ############################################################################## + return invocation.__getattr__(name)(*args, **kwargs) + + return dispatch + +class ProxyFactory(object): + """This object helps to build AopProxy objects programmatically. It allows configuring advice and target objects. + Then it will produce an AopProxy when needed. To use similar behavior in an IoC environment, see ProxyFactoryObject.""" + + def __init__(self, target = None, interceptors = None): + self.logger = logging.getLogger("springpython.aop.ProxyFactory") + self.target = target + if not interceptors: + self.interceptors = [] + elif type(interceptors) == list: + self.interceptors = interceptors + else: + self.interceptors = [interceptors] + + def getProxy(self): + """Generate an AopProxy given the current target and list of interceptors. Any changes to the factory after + proxy creation do NOT propagate to the proxies.""" + return AopProxy(self.target, self.interceptors) + + def __setattr__(self, name, value): + if name == "target" and type(value) == types.StringType: + value = utils.getClass(value)() + elif name == "interceptors" and not isinstance(value, list): + value = [value] + + self.__dict__[name] = value + +class ProxyFactoryObject(ProxyFactory, AopProxy): + """This class acts as both a ProxyFactory to build and an AopProxy. It makes itself look like the target object. + Any changes to the target and list of interceptors is immediately seen when using this as a proxy.""" + def __init__(self, target = None, interceptors = None): + ProxyFactory.__init__(self, target, interceptors) + self.logger = logging.getLogger("springpython.aop.ProxyFactoryObject") + + def __str__(self): + return self.__getattr__("__str__")() + +class PerformanceMonitorInterceptor(MethodInterceptor): + def __init__(self, prefix = None, level = logging.DEBUG): + self.prefix = prefix + self.level = level + self.logger = logging.getLogger("springpython.aop") + + def invoke(self, invocation): + self.logger.log(self.level, "%s BEGIN" % (self.prefix)) + timing.start() + results = invocation.proceed() + timing.finish() + self.logger.log(self.level, "%s END => %s" % (self.prefix, timing.milli()/1000.0)) + return results diff --git a/src/springpython/config/__init__.py b/src/springpython/config/__init__.py index f7b9398..7ae384f 100644 --- a/src/springpython/config/__init__.py +++ b/src/springpython/config/__init__.py @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ -from _config_base import * -from _xml_config import * -from _yaml_config import * -from _python_config import * +from __future__ import absolute_import +from ._config_base import * +from ._xml_config import * +from ._yaml_config import * +from ._python_config import * diff --git a/src/springpython/config/__init__.py.bak b/src/springpython/config/__init__.py.bak new file mode 100644 index 0000000..f7b9398 --- /dev/null +++ b/src/springpython/config/__init__.py.bak @@ -0,0 +1,19 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from _config_base import * +from _xml_config import * +from _yaml_config import * +from _python_config import * diff --git a/src/springpython/config/_config_base.py b/src/springpython/config/_config_base.py index 8126a82..f97572e 100644 --- a/src/springpython/config/_config_base.py +++ b/src/springpython/config/_config_base.py @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import absolute_import import re import types import inspect import logging from springpython.context import scope -from decorator import decorator, partial +from .decorator import decorator, partial from springpython.context import ApplicationContextAware from springpython.factory import PythonObjectFactory from springpython.factory import ReflectiveObjectFactory diff --git a/src/springpython/config/_config_base.py.bak b/src/springpython/config/_config_base.py.bak new file mode 100644 index 0000000..8126a82 --- /dev/null +++ b/src/springpython/config/_config_base.py.bak @@ -0,0 +1,291 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import re +import types +import inspect +import logging + +from springpython.context import scope +from decorator import decorator, partial +from springpython.context import ApplicationContextAware +from springpython.factory import PythonObjectFactory +from springpython.factory import ReflectiveObjectFactory +from springpython.container import InvalidObjectScope + +def get_string(value): + """This function is used to parse text that could either be ASCII or unicode.""" + try: + return str(value) + except UnicodeEncodeError: + return unicode(value) + +class ObjectDef(object): + """ + ObjectDef is a format-neutral way of storing object definition information. It includes + a handle for the actual ObjectFactory that should be used to utilize this information when + creating an instance of a object. + """ + def __init__(self, id, props=None, factory=None, scope=scope.SINGLETON, + lazy_init=False, abstract=False, parent=None): + super(ObjectDef, self).__init__() + self.id = id + self.factory = factory + if props is None: + self.props = [] + else: + self.props = props + self.scope = scope + self.lazy_init = lazy_init + self.abstract = abstract + self.parent = parent + self.pos_constr = [] + self.named_constr = {} + + def __str__(self): + return "id=%s props=%s scope=%s factory=%s" % (self.id, self.props, self.scope, self.factory) + +class ReferenceDef(object): + """ + This class represents a definition that is referencing another object. + """ + def __init__(self, name, ref): + self.name = name + self.ref = ref + + def prefetch(self, container): + self.get_value(container) + + def get_value(self, container): + return container.get_object(self.ref) + + def set_value(self, obj, container): + setattr(obj, self.name, container.objects[self.ref]) + + def __str__(self): + return "name=%s ref=%s" % (self.name, self.ref) + +class InnerObjectDef(object): + """ + This class represents an inner object. It is optional whether or not the object + has its own name. + """ + def __init__(self, name, inner_comp): + self.name = name + self.inner_comp = inner_comp + + def prefetch(self, container): + self.get_value(container) + + def get_value(self, container): + return container.get_object(self.inner_comp.id) + + def set_value(self, obj, container): + setattr(obj, self.name, self.get_value(container)) + + def __str__(self): + return "name=%s inner_comp=%s" % (self.name, self.inner_comp) + +class ValueDef(object): + """ + This class represents a property that holds a value. The value can be simple value, or + it can be a complex container which internally holds references, inner objects, or + any other type. + """ + def __init__(self, name, value): + self.name = name + if value == "True": + self.value = True + elif value == "False": + self.value= False + else: + self.value = value + self.logger = logging.getLogger("springpython.config.ValueDef") + + def scan_value(self, container, value): + if hasattr(value, "get_value"): + return value.get_value(container) + elif isinstance(value, tuple): + new_list = [self.scan_value(container, item) for item in value] + results = tuple(new_list) + return results + elif isinstance(value, list): + new_list = [self.scan_value(container, item) for item in value] + return new_list + elif isinstance(value, set): + results = set([self.scan_value(container, item) for item in value]) + return results + elif isinstance(value, frozenset): + results = frozenset([self.scan_value(container, item) for item in value]) + return results + else: + if value == "True": + return True + elif value == "False": + return False + else: + return value + + def get_value(self, container): + val = self._replace_refs_with_actuals(self.value, container) + if val is None: + return self.value + else: + return val + + def set_value(self, obj, container): + setattr(obj, self.name, self.value) + val = self._replace_refs_with_actuals(obj, container) + + def _replace_refs_with_actuals(self, obj, container): + """Normal values do nothing for this step. However, sub-classes are defined for + the various containers, like lists, set, dictionaries, etc., to handle iterating + through and pre-fetching items.""" + pass + + def __str__(self): + return "name=%s value=%s" % (self.name, self.value) + +class DictDef(ValueDef): + """Handles behavior for a dictionary-based value.""" + def __init__(self, name, value): + super(DictDef, self).__init__(name, value) + + def _replace_refs_with_actuals(self, obj, container): + for key in self.value.keys(): + if hasattr(self.value[key], "ref"): + self.value[key] = container.get_object(self.value[key].ref) + else: + self.value[key] = self.scan_value(container, self.value[key]) + +class ListDef(ValueDef): + """Handles behavior for a list-based value.""" + def __init__(self, name, value): + super(ListDef, self).__init__(name, value) + self.logger = logging.getLogger("springpython.config.ListDef") + + def _replace_refs_with_actuals(self, obj, container): + for i in range(0, len(self.value)): + self.logger.debug("Checking out %s, wondering if I need to do any replacement..." % get_string(self.value[i])) + if hasattr(self.value[i], "ref"): + self.value[i] = container.get_object(self.value[i].ref) + else: + self.value[i] = self.scan_value(container, self.value[i]) + +class TupleDef(ValueDef): + """Handles behavior for a tuple-based value.""" + + def __init__(self, name, value): + super(TupleDef, self).__init__(name, value) + + def _replace_refs_with_actuals(self, obj, container): + new_value = list(self.value) + for i in range(0, len(new_value)): + if hasattr(new_value[i], "ref"): + new_value[i] = container.get_object(new_value[i].ref) + else: + new_value[i] = self.scan_value(container, new_value[i]) + try: + setattr(obj, self.name, tuple(new_value)) + except AttributeError: + pass + return tuple(new_value) + +class SetDef(ValueDef): + """Handles behavior for a set-based value.""" + def __init__(self, name, value): + super(SetDef, self).__init__(name, value) + self.logger = logging.getLogger("springpython.config.SetDef") + + def _replace_refs_with_actuals(self, obj, container): + self.logger.debug("Replacing refs with actuals...") + self.logger.debug("set before changes = %s" % self.value) + new_set = set() + for item in self.value: + if hasattr(item, "ref"): + self.logger.debug("Item !!!%s!!! is a ref, trying to replace with actual object !!!%s!!!" % (item, item.ref)) + #self.value.remove(item) + #self.value.add(container.get_object(item.ref)) + newly_fetched_value = container.get_object(item.ref) + new_set.add(newly_fetched_value) + self.logger.debug("Item !!!%s!!! was removed, and newly fetched value !!!%s!!! was added." % (item, newly_fetched_value)) + #new_set.add(container.get_object(item.ref)) + else: + self.logger.debug("Item !!!%s!!! is NOT a ref, trying to replace with scanned value" % get_string(item)) + #self.value.remove(item) + #self.value.add(self.scan_value(container, item)) + newly_scanned_value = self.scan_value(container, item) + new_set.add(newly_scanned_value) + self.logger.debug("Item !!!%s!!! was removed, and newly scanned value !!!%s!!! was added." % (item, newly_scanned_value)) + #new_set.add(self.scan_value(container, item)) + #self.value = new_set + self.logger.debug("set after changes = %s" % new_set) + #return self.value + try: + setattr(obj, self.name, new_set) + except AttributeError: + pass + return new_set + +class FrozenSetDef(ValueDef): + """Handles behavior for a frozen-set-based value.""" + def __init__(self, name, value): + super(FrozenSetDef, self).__init__(name, value) + self.logger = logging.getLogger("springpython.config.FrozenSetDef") + + def _replace_refs_with_actuals(self, obj, container): + self.logger.debug("Replacing refs with actuals...") + self.logger.debug("set before changes = %s" % self.value) + new_set = set() + for item in self.value: + if hasattr(item, "ref"): + self.logger.debug("Item <<<%s>>> is a ref, trying to replace with actual object <<<%s>>>" % (item, item.ref)) + #new_set.remove(item) + #debug begin + newly_fetched_value = container.get_object(item.ref) + new_set.add(newly_fetched_value) + self.logger.debug("Item <<<%s>>> was removed, and newly fetched value <<<%s>>> was added." % (item, newly_fetched_value)) + #debug end + #new_set.add(container.get_object(item.ref)) + else: + self.logger.debug("Item <<<%s>>> is NOT a ref, trying to replace with scanned value" % get_string(item)) + #new_set.remove(item) + #debug begin + newly_scanned_value = self.scan_value(container, item) + new_set.add(newly_scanned_value) + self.logger.debug("Item <<<%s>>> was removed, and newly scanned value <<<%s>>> was added." % (item, newly_scanned_value)) + #debug end + #new_set.add(self.scan_value(container, item)) + #self.logger.debug("Newly built set = %s" % new_set) + #self.value = frozenset(new_set) + new_frozen_set = frozenset(new_set) + self.logger.debug("set after changes = %s" % new_frozen_set) + #return self.value + try: + setattr(obj, self.name, new_frozen_set) + except AttributeError: + pass + except TypeError: + pass + return new_frozen_set + +class Config(object): + """ + Config is an interface that defines how to read object definitions from an input source. + """ + def read_object_defs(self): + """Abstract method definition - should return an array of Object objects""" + raise NotImplementedError() + diff --git a/src/springpython/config/_python_config.py b/src/springpython/config/_python_config.py index 2e8fb57..d00526a 100644 --- a/src/springpython/config/_python_config.py +++ b/src/springpython/config/_python_config.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import absolute_import try: import cElementTree as etree except ImportError: @@ -26,9 +27,9 @@ import inspect import logging -from _config_base import * +from ._config_base import * from springpython.context import scope -from decorator import decorator, partial +from .decorator import decorator, partial from springpython.context import ApplicationContextAware from springpython.factory import PythonObjectFactory from springpython.factory import ReflectiveObjectFactory @@ -50,14 +51,14 @@ def read_object_defs(self): for name, method in inspect.getmembers(self, inspect.ismethod): if name not in _pythonConfigMethods: try: - wrapper = method.im_func.func_globals["_call_"] + wrapper = method.__func__.__globals__["_call_"] - if wrapper.func_name.startswith("object"): + if wrapper.__name__.startswith("object"): c = ObjectDef(id=name, factory=PythonObjectFactory(method, wrapper), scope=wrapper.scope, lazy_init=wrapper.lazy_init, abstract=wrapper.abstract, parent=wrapper.parent) objects.append(c) - except KeyError, e: + except KeyError as e: pass self.logger.debug("==============================================================") return objects @@ -66,7 +67,7 @@ def set_app_context(self, app_context): super(PythonConfig, self).set_app_context(app_context) try: _object_context[(self,)]["container"] = app_context - except KeyError, e: + except KeyError as e: _object_context[(self,)] = {"container": app_context} @@ -89,7 +90,7 @@ def _object_wrapper(f, scope, parent, log_func_name, *args, **kwargs): def _deco(f, scope, parent, log_func_name, *args, **kwargs): log = logging.getLogger("springpython.config.%s%s - %s%s" % (log_func_name, f, str(args), scope)) - if f.func_name != top_func: + if f.__name__ != top_func: log.debug("This is NOT the top-level object %s, deferring to container." % top_func) container = _object_context[args]["container"] log.debug("Container = %s" % container) @@ -97,10 +98,10 @@ def _deco(f, scope, parent, log_func_name, *args, **kwargs): if parent: parent_result = container.get_object(parent, ignore_abstract=True) log.debug("This IS the top-level object, calling %s(%s)" \ - % (f.func_name, parent_result)) - results = container.get_object(f.func_name)(parent_result) + % (f.__name__, parent_result)) + results = container.get_object(f.__name__)(parent_result) else: - results = container.get_object(f.func_name) + results = container.get_object(f.__name__) log.debug("Found %s inside the container" % results) return results @@ -109,10 +110,10 @@ def _deco(f, scope, parent, log_func_name, *args, **kwargs): container = _object_context[(args[0],)]["container"] parent_result = container.get_object(parent, ignore_abstract=True) log.debug("This IS the top-level object, calling %s(%s)" \ - % (f.func_name, parent_result)) + % (f.__name__, parent_result)) results = f(container, parent_result) else: - log.debug("This IS the top-level object, calling %s()." % f.func_name) + log.debug("This IS the top-level object, calling %s()." % f.__name__) results = f(*args, **kwargs) log.debug("Found %s" % results) diff --git a/src/springpython/config/_python_config.py.bak b/src/springpython/config/_python_config.py.bak new file mode 100644 index 0000000..2e8fb57 --- /dev/null +++ b/src/springpython/config/_python_config.py.bak @@ -0,0 +1,148 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +try: + import cElementTree as etree +except ImportError: + try: + import xml.etree.ElementTree as etree + except ImportError: + from elementtree import ElementTree as etree + +import re +import types +import inspect +import logging + +from _config_base import * +from springpython.context import scope +from decorator import decorator, partial +from springpython.context import ApplicationContextAware +from springpython.factory import PythonObjectFactory +from springpython.factory import ReflectiveObjectFactory +from springpython.container import InvalidObjectScope + +class PythonConfig(Config, ApplicationContextAware): + """ + PythonConfig supports using pure python code to define objects. + """ + + def __init__(self): + self.logger = logging.getLogger("springpython.config.PythonConfig") + super(PythonConfig, self).__init__() + + def read_object_defs(self): + self.logger.debug("==============================================================") + objects = [] + self.logger.debug("Parsing %s" % self) + for name, method in inspect.getmembers(self, inspect.ismethod): + if name not in _pythonConfigMethods: + try: + wrapper = method.im_func.func_globals["_call_"] + + if wrapper.func_name.startswith("object"): + c = ObjectDef(id=name, factory=PythonObjectFactory(method, wrapper), + scope=wrapper.scope, lazy_init=wrapper.lazy_init, + abstract=wrapper.abstract, parent=wrapper.parent) + objects.append(c) + except KeyError, e: + pass + self.logger.debug("==============================================================") + return objects + + def set_app_context(self, app_context): + super(PythonConfig, self).set_app_context(app_context) + try: + _object_context[(self,)]["container"] = app_context + except KeyError, e: + _object_context[(self,)] = {"container": app_context} + + +_pythonConfigMethods = [name for (name, method) in inspect.getmembers(PythonConfig, inspect.ismethod)] + +_object_context = {} + +def _object_wrapper(f, scope, parent, log_func_name, *args, **kwargs): + """ + This function checks if the object already exists in the container. If so, + it will retrieve its results. Otherwise, it calls the function. + + For prototype objects, the function is basically a pass through, + because everytime a prototype function is called, there should be no + caching of results. + + Using the @decorator library greatly simplifies the implementation of this. + """ + + def _deco(f, scope, parent, log_func_name, *args, **kwargs): + log = logging.getLogger("springpython.config.%s%s - %s%s" % (log_func_name, + f, str(args), scope)) + if f.func_name != top_func: + log.debug("This is NOT the top-level object %s, deferring to container." % top_func) + container = _object_context[args]["container"] + log.debug("Container = %s" % container) + + if parent: + parent_result = container.get_object(parent, ignore_abstract=True) + log.debug("This IS the top-level object, calling %s(%s)" \ + % (f.func_name, parent_result)) + results = container.get_object(f.func_name)(parent_result) + else: + results = container.get_object(f.func_name) + + log.debug("Found %s inside the container" % results) + return results + else: + if parent: + container = _object_context[(args[0],)]["container"] + parent_result = container.get_object(parent, ignore_abstract=True) + log.debug("This IS the top-level object, calling %s(%s)" \ + % (f.func_name, parent_result)) + results = f(container, parent_result) + else: + log.debug("This IS the top-level object, calling %s()." % f.func_name) + results = f(*args, **kwargs) + + log.debug("Found %s" % results) + return results + + return _deco(f, scope, parent, log_func_name, *args, **kwargs) + +def Object(theScope=scope.SINGLETON, lazy_init=False, abstract=False, parent=None): + """ + This function is a wrapper around the function which returns the real decorator. + It decides, based on scope and lazy-init, which decorator to return. + """ + if type(theScope) == types.FunctionType: + return Object()(theScope) + + elif theScope == scope.SINGLETON: + log_func_name = "objectSingleton" + + elif theScope == scope.PROTOTYPE: + log_func_name = "objectPrototype" + + else: + raise InvalidObjectScope("Don't know how to handle scope %s" % theScope) + + def object_wrapper(f, *args, **kwargs): + return _object_wrapper(f, theScope, parent, log_func_name, *args, **kwargs) + + object_wrapper.lazy_init = lazy_init + object_wrapper.abstract = abstract + object_wrapper.parent = parent + object_wrapper.scope = theScope + + return decorator(object_wrapper) diff --git a/src/springpython/config/_xml_config.py b/src/springpython/config/_xml_config.py index 0e28f72..d5eb2bc 100644 --- a/src/springpython/config/_xml_config.py +++ b/src/springpython/config/_xml_config.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import absolute_import try: import cElementTree as etree @@ -27,7 +28,7 @@ import inspect import logging -from _config_base import * +from ._config_base import * from springpython.context import scope from springpython.context import ApplicationContextAware from springpython.factory import PythonObjectFactory diff --git a/src/springpython/config/_xml_config.py.bak b/src/springpython/config/_xml_config.py.bak new file mode 100644 index 0000000..0e28f72 --- /dev/null +++ b/src/springpython/config/_xml_config.py.bak @@ -0,0 +1,607 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +try: + import cElementTree as etree +except ImportError: + try: + import xml.etree.ElementTree as etree + except ImportError: + from elementtree import ElementTree as etree + +import re +import types +import inspect +import logging + +from _config_base import * +from springpython.context import scope +from springpython.context import ApplicationContextAware +from springpython.factory import PythonObjectFactory +from springpython.factory import ReflectiveObjectFactory +from springpython.container import InvalidObjectScope + + +xml_mappings = { + "str":"types.StringType", "unicode":"types.UnicodeType", + "int":"types.IntType", "long":"types.LongType", + "float":"types.FloatType", "decimal":"decimal.Decimal", + "bool":"types.BooleanType", "complex":"types.ComplexType", +} + +class PyContainerConfig(Config): + """ + PyContainerConfig supports the legacy XML dialect (PyContainer) of reading object definitions. + """ + + NS = "{http://www.springframework.org/springpython/schema/pycontainer-components}" + + def __init__(self, config_location): + if isinstance(config_location, list): + self.config_location = config_location + else: + self.config_location = [config_location] + self.logger = logging.getLogger("springpython.config.PyContainerConfig") + + def read_object_defs(self): + self.logger.debug("==============================================================") + objects = [] + for config in self.config_location: + self.logger.debug("* Parsing %s" % config) + components = etree.parse(config).getroot() + objects.extend([self._convert_component(component) for component in components]) + self.logger.debug("==============================================================") + return objects + + + def _convert_component(self, component): + "This function generates a object definition, then converts scope and property elements." + self.logger.debug("component: Processing %s" % component) + c = ObjectDef(component.get("id"), factory=ReflectiveObjectFactory(component.get("class"))) + if "scope" in component.attrib: + c.scope = scope.convert(component.get("scope")) + c.props = [self._convert_prop_def(p) for p in component.findall(self.NS+"property")] + return c + + def _convert_prop_def(self, p): + "This function translates object properties into useful dictionaries of information for the container." + if "local" in p.attrib or p.find(self.NS+"local") is not None: + if "local" in p.attrib: + return ReferenceDef(p.get("name"), p.get("local")) + else: + return ReferenceDef(p.get("name"), p.find(self.NS+"local")) + elif "list" in p.attrib or p.find(self.NS+"list") is not None: + if "list" in p.attrib: + return ListDef(p.name, [ReferenceDef(p.name + ".list", prop_list.local) for prop_list in p.list]) + else: + return ListDef(p.name, [ReferenceDef(p.name + ".list", prop_list.local) for prop_list in p.list]) + else: + self.logger.debug("py: name = %s code = %s" % (p.get("name"), p.text)) + thing = eval(str(p.text).strip()) + self.logger.debug("py: You have parsed %s" % thing) + return ValueDef(p.get("name"), eval(str(p.text).strip())) + +class SpringJavaConfig(Config): + """ + SpringJavaConfig supports current Spring Java format of XML bean definitions. + """ + NS = "{http://www.springframework.org/schema/beans}" + + def __init__(self, config_location): + if isinstance(config_location, list): + self.config_location = config_location + else: + self.config_location = [config_location] + self.logger = logging.getLogger("springpython.config.SpringJavaConfig") + + # By making this an instance-based property (instead of function local), inner object + # definitions can add themselves to the list in the midst of parsing an input. + self.objects = [] + + def read_object_defs(self): + self.logger.debug("==============================================================") + # Reset, in case the file is re-read + self.objects = [] + for config in self.config_location: + self.logger.debug("* Parsing %s" % config) + beans = etree.parse(config).getroot() + self.objects.extend([self._convert_bean(bean) for bean in beans]) + self.logger.debug("==============================================================") + return self.objects + + def _convert_bean(self, bean, prefix=""): + "This function generates a object definition, then converts scope and property elements." + if prefix != "": + if "id" in bean.attrib: + bean.set("id", prefix + bean.get("id")) + else: + bean.set("id", prefix + "") + + c = ObjectDef(bean.get("id"), factory=ReflectiveObjectFactory(bean.get("class"))) + + if "scope" in bean.attrib: + c.scope = scope.convert(bean.get("scope")) + self.logger.debug("bean: %s" % bean) + c.pos_constr = [self._convert_prop_def(bean, constr, bean.get("id") + ".constr") for constr in bean.findall(self.NS+"constructor-arg")] + self.logger.debug("Constructors = %s" % c.pos_constr) + c.props = [self._convert_prop_def(bean, p, p.get("name")) for p in bean.findall(self.NS+"property")] + + return c + + def _convert_prop_def(self, bean, p, name): + "This function translates object constructors/properties into useful collections of information for the container." + + if "ref" in p.keys() or p.find(self.NS+"ref") is not None: + if "ref" in p.keys(): + return ReferenceDef(name, p.get("ref")) + else: + return ReferenceDef(name, p.find(self.NS+"ref").get("bean")) + elif "value" in p.keys() or p.find(self.NS+"value") is not None: + if "value" in p.keys(): + return ValueDef(name, p.get("value")) + else: + return ValueDef(name, p.find(self.NS+"value").text) + elif p.find(self.NS+"map") is not None: + dict = {} + for entry in p.find(self.NS+"map"): + key = entry.find(self.NS+"key").find(self.NS+"value").text + if entry.find(self.NS+"value") is not None: + dict[key] = str(entry.find(self.NS+"value").text) + elif entry.find(self.NS+"ref") is not None: + dict[key] = ReferenceDef(key, entry.find(self.NS+"ref").get("bean")) + else: + self.logger.debug("Don't know how to handle %s" % entry) + return DictDef(name, dict) + elif p.find(self.NS+"props") is not None: + dict = {} + for prop in p.find(self.NS+"props"): + dict[prop.get("key")] = str(prop.text) + return DictDef(name, dict) + elif p.find(self.NS+"list") is not None: + list = [] + for element in p.find(self.NS+"list"): + if element.tag == self.NS+"value": + list.append(element.text) + elif element.tag == self.NS+"ref": + list.append(ReferenceDef(name + ".list", element.get("bean"))) + else: + self.logger.debug("Don't know how to handle %s" % element.tag) + return ListDef(name, list) + elif p.find(self.NS+"set") is not None: + s = set() + for element in p.find(self.NS+"set"): + if element.tag == self.NS+"value": + s.add(element.text) + elif element.tag == self.NS+"ref": + s.add(ReferenceDef(name + ".set", element.get("bean"))) + else: + self.logger.debug("Don't know how to handle %s" % element.tag) + return SetDef(name, s) + elif p.find(self.NS+"bean"): + inner_object_def = self._convert_bean(p.find(self.NS+"bean"), prefix=bean.get("id") + "." + name + ".") + self.objects.append(inner_object_def) + return InnerObjectDef(name, inner_object_def) + +class XMLConfig(Config): + """ + XMLConfig supports current Spring Python format of XML object definitions. + """ + + NS = "{http://www.springframework.org/springpython/schema/objects}" + NS_11 = "{http://www.springframework.org/springpython/schema/objects/1.1}" + + def __init__(self, config_location): + if isinstance(config_location, list): + self.config_location = config_location + else: + self.config_location = [config_location] + self.logger = logging.getLogger("springpython.config.XMLConfig") + + # By making this an instance-based property (instead of function local), inner object + # definitions can add themselves to the list in the midst of parsing an input. + self.objects = [] + + def read_object_defs(self): + self.logger.debug("==============================================================") + # Reset, in case the file is re-read + self.objects = [] + for config in self.config_location: + self.logger.debug("* Parsing %s" % config) + + # A flat list of objects, as found in the XML document. + objects = etree.parse(config).getroot() + + # We need to handle both 1.0 and 1.1 XSD schemata *and* we may be + # passed a list of config locations of different XSD versions so we + # must find out here which one is used in the current config file + # and pass the correct namespace down to other parts of XMLConfig. + ns = objects.tag[:objects.tag.find("}") + 1] + + # A dictionary of abstract objects, keyed by their IDs, used in + # traversing the hierarchies of parents; built upfront here for + # convenience. + abstract_objects = {} + for obj in objects: + if obj.get("abstract"): + abstract_objects[obj.get("id")] = obj + + for obj in objects: + if obj.get("class") is None and not obj.get("parent"): + self._map_custom_class(obj, xml_mappings, ns) + + elif obj.get("parent"): + # Children are added to self.objects during the children->abstract parents traversal. + pos_constr = self._get_pos_constr(obj, ns) + named_constr = self._get_named_constr(obj, ns) + props = self._get_props(obj, ns) + self._traverse_parents(obj, obj, ns, pos_constr, named_constr, props, abstract_objects) + continue + + self.objects.append(self._convert_object(obj, ns=ns)) + + self.logger.debug("==============================================================") + for object in self.objects: + self.logger.debug("Parsed %s" % object) + return self.objects + + def _map_custom_class(self, obj, mappings, ns): + """ Fill in the missing attributes of Python objects and make it look + to the rest of XMLConfig as if they already were in the XML config file. + """ + for class_name in mappings: + tag_no_ns = obj.tag.replace(ns, "") + if class_name == tag_no_ns: + + obj.set("class", mappings[class_name]) + constructor_arg = etree.Element("%s%s" % (ns, "constructor-arg")) + value = etree.Element("%s%s" % (ns, "value")) + value.text = obj.text + obj.append(constructor_arg) + constructor_arg.append(value) + obj.text = "" + + break + + else: + self.logger.warning("No matching type found for object %s" % obj) + + def _traverse_parents(self, leaf, child, ns, pos_constr, + named_constr, props, abstract_objects): + + parent = abstract_objects[child.get("parent")] + + # At this point we only build up the lists of parameters but we don't create + # the object yet because the current parent object may still have its + # own parent. + + # Positional constructors + + parent_pos_constrs = self._get_pos_constr(parent, ns) + + # Make sure there are as many child positional parameters as there + # are in the parent's list. + + len_pos_constr = len(pos_constr) + len_parent_pos_constrs = len(parent_pos_constrs) + + if len_pos_constr < len_parent_pos_constrs: + pos_constr.extend([None] * (len_parent_pos_constrs - len_pos_constr)) + + for idx, parent_pos_constr in enumerate(parent_pos_constrs): + if not pos_constr[idx]: + pos_constr[idx] = parent_pos_constr + + # Named constructors + child_named_constrs = named_constr + parent_named_constrs = self._get_named_constr(parent, ns) + + for parent_named_constr in parent_named_constrs: + if parent_named_constr not in child_named_constrs: + named_constr[parent_named_constr] = parent_named_constrs[parent_named_constr] + + # Properties + child_props = [prop.name for prop in props] + parent_props = self._get_props(parent, ns) + + for parent_prop in parent_props: + if parent_prop.name not in child_props: + props.append(parent_prop) + + if parent.get("parent"): + self._traverse_parents(leaf, parent, ns, pos_constr, named_constr, props, abstract_objects) + else: + # Now we know we can create an object out of all the accumulated values. + + # The object's class is its topmost parent's class. + class_ = parent.get("class") + id, factory, lazy_init, abstract, parent, scope_ = self._get_basic_object_data(leaf, class_) + + c = self._create_object(id, factory, lazy_init, abstract, parent, + scope_, pos_constr, named_constr, props) + + self.objects.append(c) + + return parent + + def _get_pos_constr(self, object, ns): + """ Returns a list of all positional constructor arguments of an object. + """ + return [self._convert_prop_def(object, constr, object.get("id") + ".constr", ns) for constr in object.findall(ns+"constructor-arg") + if not "name" in constr.attrib] + + def _get_named_constr(self, object, ns): + """ Returns a dictionary of all named constructor arguments of an object. + """ + return dict([(str(constr.get("name")), self._convert_prop_def(object, constr, object.get("id") + ".constr", ns)) + for constr in object.findall(ns+"constructor-arg") if "name" in constr.attrib]) + + def _get_props(self, object, ns): + """ Returns a list of all properties defined by an object. + """ + return [self._convert_prop_def(object, p, p.get("name"), ns) for p in object.findall(ns+"property")] + + def _create_object(self, id, factory, lazy_init, abstract, parent, + scope, pos_constr, named_constr, props): + """ A helper function which creates an object out of the supplied + arguments. + """ + + c = ObjectDef(id=id, factory=factory, lazy_init=lazy_init, + abstract=abstract, parent=parent) + + c.scope = scope + c.pos_constr = pos_constr + c.named_constr = named_constr + c.props = props + + self.logger.debug("object: props = %s" % c.props) + self.logger.debug("object: There are %s props" % len(c.props)) + + return c + + def _get_basic_object_data(self, object, class_): + """ A convenience method which creates basic object's data so that + the code is not repeated. + """ + + if "scope" in object.attrib: + scope_ = scope.convert(object.get("scope")) + else: + scope_ = scope.SINGLETON + + return(object.get("id"), ReflectiveObjectFactory(class_), + object.get("lazy-init", False), object.get("abstract", False), + object.get("parent"), scope_) + + def _convert_object(self, object, prefix="", ns=None): + """ This function collects all parameters required for an object creation + and then calls a helper function which creates it. + """ + if prefix != "": + if "id" in object.attrib: + object.set("id", prefix + "." + object.get("id")) + else: + object.set("id", prefix + ".") + + id, factory, lazy_init, abstract, parent, scope_ = self._get_basic_object_data(object, object.get("class")) + + pos_constr = self._get_pos_constr(object, ns) + named_constr = self._get_named_constr(object, ns) + props = self._get_props(object, ns) + + return self._create_object(id, factory, lazy_init, abstract, parent, + scope_, pos_constr, named_constr, props) + + def _convert_ref(self, ref_node, name): + if hasattr(ref_node, "attrib"): + results = ReferenceDef(name, ref_node.get("object")) + self.logger.debug("ref: Returning %s" % results) + return results + else: + results = ReferenceDef(name, ref_node) + self.logger.debug("ref: Returning %s" % results) + return results + + def _convert_value(self, value, id, name, ns): + if value.text is not None and value.text.strip() != "": + self.logger.debug("value: Converting a direct value <%s>" % value.text) + return value.text + else: + if value.tag == ns+"value": + self.logger.debug("value: Converting a value's children %s" % value.getchildren()[0]) + results = self._convert_value(value.getchildren()[0], id, name, ns) + self.logger.debug("value: results = %s" % str(results)) + return results + elif value.tag == ns+"tuple": + self.logger.debug("value: Converting a tuple") + return self._convert_tuple(value, id, name, ns).value + elif value.tag == ns+"list": + self.logger.debug("value: Converting a list") + return self._convert_list(value, id, name, ns).value + elif value.tag == ns+"dict": + self.logger.debug("value: Converting a dict") + return self._convert_dict(value, id, name, ns).value + elif value.tag == ns+"set": + self.logger.debug("value: Converting a set") + return self._convert_set(value, id, name, ns).value + elif value.tag == ns+"frozenset": + self.logger.debug("value: Converting a frozenset") + return self._convert_frozen_set(value, id, name, ns).value + else: + self.logger.debug("value: %s.%s Don't know how to handle %s" % (id, name, value.tag)) + + def _convert_dict(self, dict_node, id, name, ns): + dict = {} + for entry in dict_node.findall(ns+"entry"): + self.logger.debug("dict: entry = %s" % entry) + key = entry.find(ns+"key").find(ns+"value").text + self.logger.debug("dict: key = %s" % key) + if entry.find(ns+"value") is not None: + dict[key] = self._convert_value(entry.find(ns+"value"), id, "%s.dict['%s']" % (name, key), ns) + elif entry.find(ns+"ref") is not None: + dict[key] = self._convert_ref(entry.find(ns+"ref"), "%s.dict['%s']" % (name, key)) + elif entry.find(ns+"object") is not None: + self.logger.debug("dict: Parsing an inner object definition...") + dict[key] = self._convert_inner_object(entry.find(ns+"object"), id, "%s.dict['%s']" % (name, key), ns) + else: + for token in ["dict", "tuple", "set", "frozenset", "list"]: + if entry.find(ns+token) is not None: + dict[key] = self._convert_value(entry.find(ns+token), id, "%s.dict['%s']" % (name, key), ns) + break + if key not in dict: + self.logger.debug("dict: Don't know how to handle %s" % entry.tag) + + self.logger.debug("Dictionary is now %s" % dict) + return DictDef(name, dict) + + def _convert_props(self, props_node, name, ns): + dict = {} + self.logger.debug("props: Looking at %s" % props_node) + for prop in props_node: + dict[prop.get("key")] = str(prop.text) + self.logger.debug("props: Dictionary is now %s" % dict) + return DictDef(name, dict) + + def _convert_list(self, list_node, id, name, ns): + list = [] + self.logger.debug("list: Parsing %s" % list_node) + for element in list_node: + if element.tag == ns+"value": + list.append(get_string(element.text)) + elif element.tag == ns+"ref": + list.append(self._convert_ref(element, "%s.list[%s]" % (name, len(list)))) + elif element.tag == ns+"object": + self.logger.debug("list: Parsing an inner object definition...") + list.append(self._convert_inner_object(element, id, "%s.list[%s]" % (name, len(list)), ns)) + elif element.tag in [ns+token for token in ["dict", "tuple", "set", "frozenset", "list"]]: + self.logger.debug("This list has child elements of type %s." % element.tag) + list.append(self._convert_value(element, id, "%s.list[%s]" % (name, len(list)), ns)) + self.logger.debug("List is now %s" % list) + else: + self.logger.debug("list: Don't know how to handle %s" % element.tag) + self.logger.debug("List is now %s" % list) + return ListDef(name, list) + + def _convert_tuple(self, tuple_node, id, name, ns): + list = [] + self.logger.debug("tuple: Parsing %s" % tuple_node) + for element in tuple_node: + self.logger.debug("tuple: Looking at %s" % element) + if element.tag == ns+"value": + self.logger.debug("tuple: Appending %s" % element.text) + list.append(get_string(element.text)) + elif element.tag == ns+"ref": + list.append(self._convert_ref(element, "%s.tuple(%s}" % (name, len(list)))) + elif element.tag == ns+"object": + self.logger.debug("tuple: Parsing an inner object definition...") + list.append(self._convert_inner_object(element, id, "%s.tuple(%s)" % (name, len(list)), ns)) + elif element.tag in [ns+token for token in ["dict", "tuple", "set", "frozenset", "list"]]: + self.logger.debug("tuple: This tuple has child elements of type %s." % element.tag) + list.append(self._convert_value(element, id, "%s.tuple(%s)" % (name, len(list)), ns)) + self.logger.debug("tuple: List is now %s" % list) + else: + self.logger.debug("tuple: Don't know how to handle %s" % element.tag) + self.logger.debug("Tuple is now %s" % str(tuple(list))) + return TupleDef(name, tuple(list)) + + def _convert_set(self, set_node, id, name, ns): + s = set() + self.logger.debug("set: Parsing %s" % set_node) + for element in set_node: + self.logger.debug("Looking at element %s" % element) + if element.tag == ns+"value": + s.add(get_string(element.text)) + elif element.tag == ns+"ref": + s.add(self._convert_ref(element, name + ".set")) + elif element.tag == ns+"object": + self.logger.debug("set: Parsing an inner object definition...") + s.add(self._convert_inner_object(element, id, "%s.set(%s)" % (name, len(s)), ns)) + elif element.tag in [ns+token for token in ["dict", "tuple", "set", "frozenset", "list"]]: + self.logger.debug("set: This set has child elements of type %s." % element.tag) + s.add(self._convert_value(element, id, "%s.set(%s)" % (name,len(s)), ns)) + else: + self.logger.debug("set: Don't know how to handle %s" % element.tag) + self.logger.debug("Set is now %s" % s) + return SetDef(name, s) + + def _convert_frozen_set(self, frozen_set_node, id, name, ns): + item = self._convert_set(frozen_set_node, id, name, ns) + self.logger.debug("frozenset: Frozen set is now %s" % frozenset(item.value)) + return FrozenSetDef(name, frozenset(item.value)) + + def _convert_inner_object(self, object_node, id, name, ns): + inner_object_def = self._convert_object(object_node, prefix="%s.%s" % (id, name), ns=ns) + self.logger.debug("innerobj: Innerobject is now %s" % inner_object_def) + self.objects.append(inner_object_def) + return InnerObjectDef(name, inner_object_def) + + def _convert_prop_def(self, comp, p, name, ns): + "This function translates object properties into useful collections of information for the container." + #self.logger.debug("Is %s.%s a ref? %s" % (comp.get("id"), p.get("name"), p.find(ns+"ref") is not None or "ref" in p.attrib)) + #self.logger.debug("Is %s.%s a value? %s" % (comp.get("id"), p.get("name"), p.find(ns+"value") is not None or "value" in p.attrib)) + #self.logger.debug("Is %s.%s an inner object? %s" % (comp.get("id"), p.get("name"), p.find(ns+"object") is not None or "object" in p.attrib)) + #self.logger.debug("Is %s.%s a dict? %s" % (comp.get("id"), p.get("name"), p.find(ns+"dict") is not None or "dict" in p.attrib)) + #self.logger.debug("Is %s.%s a list? %s" % (comp.get("id"), p.get("name"), p.find(ns+"list") is not None or "list" in p.attrib)) + #self.logger.debug("Is %s.%s a tuple? %s" % (comp.get("id"), p.get("name"), p.find(ns+"tuple") is not None or "tuple" in p.attrib)) + #self.logger.debug("Is %s.%s a set? %s" % (comp.get("id"), p.get("name"), p.find(ns+"set") is not None or "set" in p.attrib)) + #self.logger.debug("Is %s.%s a frozenset? %s" % (comp.get("id"), p.get("name"), p.find(ns+"frozenset") is not None or "frozenset" in p.attrib)) + #self.logger.debug("") + if "ref" in p.attrib or p.find(ns+"ref") is not None: + if "ref" in p.attrib: + return self._convert_ref(p.get("ref"), name) + else: + return self._convert_ref(p.find(ns+"ref"), name) + elif "value" in p.attrib or p.find(ns+"value") is not None: + if "value" in p.attrib: + return ValueDef(name, get_string(p.get("value"))) + else: + return ValueDef(name, get_string(p.find(ns+"value").text)) + elif "dict" in p.attrib or p.find(ns+"dict") is not None: + if "dict" in p.attrib: + return self._convert_dict(p.get("dict"), comp.get("id"), name, ns) + else: + return self._convert_dict(p.find(ns+"dict"), comp.get("id"), name, ns) + elif "props" in p.attrib or p.find(ns+"props") is not None: + if "props" in p.attrib: + return self._convert_props(p.get("props"), name, ns) + else: + return self._convert_props(p.find(ns+"props"), name, ns) + elif "list" in p.attrib or p.find(ns+"list") is not None: + if "list" in p.attrib: + return self._convert_list(p.get("list"), comp.get("id"), name, ns) + else: + return self._convert_list(p.find(ns+"list"), comp.get("id"), name, ns) + elif "tuple" in p.attrib or p.find(ns+"tuple") is not None: + if "tuple" in p.attrib: + return self._convert_tuple(p.get("tuple"), comp.get("id"), name, ns) + else: + return self._convert_tuple(p.find(ns+"tuple"), comp.get("id"), name, ns) + elif "set" in p.attrib or p.find(ns+"set") is not None: + if "set" in p.attrib: + return self._convert_set(p.get("set"), comp.get("id"), name, ns) + else: + return self._convert_set(p.find(ns+"set"), comp.get("id"), name, ns) + elif "frozenset" in p.attrib or p.find(ns+"frozenset") is not None: + if "frozenset" in p.attrib: + return self._convert_frozen_set(p.get("frozenset"), comp.get("id"), name, ns) + else: + return self._convert_frozen_set(p.find(ns+"frozenset"), comp.get("id"), name, ns) + elif "object" in p.attrib or p.find(ns+"object") is not None: + if "object" in p.attrib: + return self._convert_inner_object(p.get("object"), comp.get("id"), name, ns) + else: + return self._convert_inner_object(p.find(ns+"object"), comp.get("id"), name, ns) + diff --git a/src/springpython/config/_yaml_config.py b/src/springpython/config/_yaml_config.py index 8f02902..b27f75b 100644 --- a/src/springpython/config/_yaml_config.py +++ b/src/springpython/config/_yaml_config.py @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import absolute_import import re import types import inspect import logging import collections -from _config_base import * +from ._config_base import * from springpython.context import scope -from decorator import decorator, partial +from .decorator import decorator, partial from springpython.context import ApplicationContextAware from springpython.factory import PythonObjectFactory from springpython.factory import ReflectiveObjectFactory diff --git a/src/springpython/config/_yaml_config.py.bak b/src/springpython/config/_yaml_config.py.bak new file mode 100644 index 0000000..8f02902 --- /dev/null +++ b/src/springpython/config/_yaml_config.py.bak @@ -0,0 +1,418 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import re +import types +import inspect +import logging +import collections + +from _config_base import * +from springpython.context import scope +from decorator import decorator, partial +from springpython.context import ApplicationContextAware +from springpython.factory import PythonObjectFactory +from springpython.factory import ReflectiveObjectFactory +from springpython.container import InvalidObjectScope + +yaml_mappings = { + "str":"types.StringType", "unicode":"types.UnicodeType", + "int":"types.IntType", "long":"types.LongType", + "float":"types.FloatType", "decimal":"decimal.Decimal", + "bool":"types.BooleanType", "complex":"types.ComplexType", + "list":"types.ListType", "tuple":"types.TupleType", + "dict":"types.DictType", +} + +class YamlConfig(Config): + """ + YamlConfig provides an alternative YAML-based version of objects. + """ + def __init__(self, config_location): + if isinstance(config_location, list): + self.config_location = config_location + else: + self.config_location = [config_location] + self.logger = logging.getLogger("springpython.config.YamlConfig") + + # By making this an instance-based property (instead of function local), inner object + # definitions can add themselves to the list in the midst of parsing an input. + self.objects = [] + + def read_object_defs(self): + import yaml + + self.logger.debug("==============================================================") + # Reset, in case the file is re-read + self.objects = [] + for config in self.config_location: + self.logger.debug("* Parsing %s" % config) + stream = file(config) + doc = yaml.load(stream) + self.logger.debug(doc) + + # A dictionary of abstract objects, keyed by their IDs, used in + # traversing the hierarchies of parents; built upfront here for + # convenience. + self.abstract_objects = {} + for object in doc["objects"]: + if "abstract" in object: + self.abstract_objects[object["object"]] = object + + for object in doc["objects"]: + self._print_obj(object) + self.objects.append(self._convert_object(object)) + + self.logger.debug("==============================================================") + self.logger.debug("objects = %s" % self.objects) + return self.objects + + def _map_custom_class(self, obj, mappings): + """ Enrich the object's attributes and make it look to the rest of + YamlConfig as if the object had all of them right in the definition. + """ + for class_name in mappings: + if class_name in obj: + self.logger.debug("Found a matching type: %s -> %s" % (obj["object"], + class_name)) + + obj["class"] = mappings[class_name] + obj["constructor-args"] = [obj[class_name]] + break + else: + self.logger.warning("No matching type found for object %s" % obj) + + def _convert_child_object(self, leaf, child, pos_constr, + named_constr, props): + + parent = self.abstract_objects[child["parent"]] + + # At this point we only build up the lists of parameters but we don't create + # the object yet because the current parent object may still have its + # own parent. + + # Positional constructors + + parent_pos_constrs = self._get_pos_constr(parent) + + # Make sure there are as many child positional parameters as there + # are in the parent's list. + + len_pos_constr = len(pos_constr) + len_parent_pos_constrs = len(parent_pos_constrs) + + if len_pos_constr < len_parent_pos_constrs: + pos_constr.extend([None] * (len_parent_pos_constrs - len_pos_constr)) + + for idx, parent_pos_constr in enumerate(parent_pos_constrs): + if not pos_constr[idx]: + pos_constr[idx] = parent_pos_constr + + # Named constructors + child_named_constrs = named_constr + parent_named_constrs = self._get_named_constr(parent) + + for parent_named_constr in parent_named_constrs: + if parent_named_constr not in child_named_constrs: + named_constr[parent_named_constr] = parent_named_constrs[parent_named_constr] + + # Properties + child_props = [prop.name for prop in props] + parent_props = self._get_props(parent) + + for parent_prop in parent_props: + if parent_prop.name not in child_props: + props.append(parent_prop) + + if "parent" in parent: + # Continue traversing up the parent objects + return self._convert_child_object(leaf, parent, pos_constr, named_constr, props) + else: + # Now we know we can create an object out of all the accumulated values. + + # The object's class is its topmost parent's class. + class_ = parent["class"] + id, factory, lazy_init, abstract, parent, scope_ = self._get_basic_object_data(leaf, class_) + + c = self._create_object(id, factory, lazy_init, abstract, parent, + scope_, pos_constr, named_constr, props) + + return c + + def _get_pos_constr(self, object): + """ Returns a list of all positional constructor arguments of an object. + """ + if "constructor-args" in object and isinstance(object["constructor-args"], list): + return [self._convert_prop_def(object, constr, object["object"]) for constr in object["constructor-args"]] + return [] + + def _get_named_constr(self, object): + """ Returns a dictionary of all named constructor arguments of an object. + """ + if "constructor-args" in object and isinstance(object["constructor-args"], dict): + return dict([(name, self._convert_prop_def(object, constr, object["object"])) + for (name, constr) in object["constructor-args"].items()]) + return {} + + def _get_props(self, object): + """ Returns a list of all properties defined by an object. + """ + if "properties" in object: + return [self._convert_prop_def(object, p, name) for (name, p) in object["properties"].items()] + return [] + + def _create_object(self, id, factory, lazy_init, abstract, parent, + scope, pos_constr, named_constr, props): + """ A helper function which creates an object out of the supplied + arguments. + """ + + c = ObjectDef(id=id, factory=factory, lazy_init=lazy_init, + abstract=abstract, parent=parent) + + c.scope = scope + c.pos_constr = pos_constr + c.named_constr = named_constr + c.props = props + + self.logger.debug("object: props = %s" % c.props) + self.logger.debug("object: There are %s props" % len(c.props)) + + return c + + def _get_basic_object_data(self, object, class_): + """ A convenience method which creates basic object's data so that + the code is not repeated. + """ + + if "scope" in object: + scope_ = scope.convert(object["scope"]) + else: + scope_ = scope.SINGLETON + + return(object["object"], ReflectiveObjectFactory(class_), + object.get("lazy-init", False), object.get("abstract", False), + object.get("parent"), scope_) + + def _convert_object(self, object, prefix=""): + "This function generates a object definition, then converts scope and property elements." + if prefix != "": + if "object" in object and object["object"] is not None: + object["object"] = prefix + "." + object["object"] + else: + object["object"] = prefix + "." + + if not "class" in object and "parent" not in object: + self._map_custom_class(object, yaml_mappings) + + pos_constr = self._get_pos_constr(object) + named_constr = self._get_named_constr(object) + props = self._get_props(object) + + if "parent" in object: + return self._convert_child_object(object, object, pos_constr, named_constr, props) + else: + id, factory, lazy_init, abstract, parent, scope_ = self._get_basic_object_data(object, object.get("class")) + + return self._create_object(id, factory, lazy_init, abstract, parent, + scope_, pos_constr, named_constr, props) + + def _print_obj(self, obj, level=0): + self.logger.debug("%sobject = %s" % ("\t"*level, obj["object"])) + self.logger.debug("%sclass = %s" % ("\t"*(level+1), obj.get("class"))) + + if "scope" in obj: + self.logger.debug("%sscope = %s" % ("\t"*(level+1), obj["scope"])) + else: + self.logger.debug("%sscope = singleton (default)" % ("\t"*(level+1))) + + if "properties" in obj: + self.logger.debug("%sproperties:" % ("\t"*(level+1))) + for prop in obj["properties"].keys(): + if isinstance(obj["properties"][prop], collections.Iterable) and "object" in obj["properties"][prop]: + self.logger.debug("%s%s = ..." % ("\t"*(level+2), prop)) + self._print_obj(obj["properties"][prop], level+3) + else: + self.logger.debug("%s%s = %s" % ("\t"*(level+2), prop, obj["properties"][prop])) + self.logger.debug("") + + def _convert_ref(self, ref_node, name): + self.logger.debug("ref: Parsing %s, %s" % (ref_node, name)) + if "object" in ref_node: + return ReferenceDef(name, ref_node["object"]) + else: + return ReferenceDef(name, ref_node) + + def _convert_value(self, value, id, name): + results = [] + + if isinstance(value, dict): + if "tuple" in value: + self.logger.debug("value: Converting tuple") + return self._convert_tuple(value["tuple"], id, name) + elif "list" in value: + self.logger.debug("value: Converting list") + return self._convert_list(value["list"], id, name) + elif "dict" in value: + self.logger.debug("value: Converting dict") + return self._convert_dict(value["dict"], id, name) + elif "set" in value: + self.logger.debug("value: Converting set") + return self._convert_set(value["set"], id, name) + elif "frozenset" in value: + self.logger.debug("value: Converting frozenset") + return self._convert_frozen_set(value["frozenset"], id, name) + else: + self.logger.debug("value: Plain ole value = %s" % value) + return value + + return results + + def _convert_dict(self, dict_node, id, name): + d = {} + for (k, v) in dict_node.items(): + if isinstance(v, dict): + self.logger.debug("dict: You have a special type stored at %s" % k) + if "ref" in v: + self.logger.debug("dict/ref: k,v = %s,%s" % (k, v)) + d[k] = self._convert_ref(v["ref"], "%s.dict['%s']" % (name, k)) + self.logger.debug("dict: Stored %s => %s" % (k, d[k])) + elif "tuple" in v: + self.logger.debug("dict: Converting a tuple...") + d[k] = self._convert_tuple(v["tuple"], id, "%s.dict['%s']" % (name, k)) + else: + self.logger.debug("dict: Don't know how to handle type %s" % v) + else: + self.logger.debug("dict: %s is NOT a dict, so going to convert as a value." % v) + d[k] = self._convert_value(v, id, "%s.dict['%s']" % (name, k)) + return DictDef(name, d) + + def _convert_props(self, props_node, name): + dict = {} + for prop in props_node.prop: + dict[prop.key] = str(prop) + return DictDef(name, dict) + + def _convert_list(self, list_node, id, name): + list = [] + for item in list_node: + self.logger.debug("list: Adding %s to list..." % item) + if isinstance(item, dict): + if "ref" in item: + list.append(self._convert_ref(item["ref"], "%s.list[%s]" % (name, len(list)))) + elif "object" in item: + list.append(self._convert_inner_object(item, id, "%s.list[%s]" % (name, len(list)))) + elif len(set(["dict", "tuple", "set", "frozenset", "list"]) & set(item)) > 0: + list.append(self._convert_value(item, id, "%s.list[%s]" % (name, len(list)))) + else: + self.logger.debug("list: Don't know how to handle %s" % item.keys()) + else: + list.append(item) + return ListDef(name, list) + + def _convert_tuple(self, tuple_node, id, name): + list = [] + self.logger.debug("tuple: tuple_node = %s, id = %s, name = %s" % (tuple_node, id, name)) + for item in tuple_node: + if isinstance(item, dict): + if "ref" in item: + list.append(self._convert_ref(item["ref"], name + ".tuple")) + elif "object" in item: + list.append(self._convert_inner_object(item, id, "%s.tuple[%s]" % (name, len(list)))) + elif len(set(["dict", "tuple", "set", "frozenset", "list"]) & set(item)) > 0: + list.append(self._convert_value(item, id, "%s.tuple[%s]" % (name, len(list)))) + else: + self.logger.debug("tuple: Don't know how to handle %s" % item) + else: + list.append(item) + return TupleDef(name, tuple(list)) + + def _convert_set(self, set_node, id, name): + s = set() + self.logger.debug("set: set_node = %s, id = %s, name = %s" % (set_node, id, name)) + for item in set_node: + if isinstance(item, dict): + if "ref" in item: + s.add(self._convert_ref(item["ref"], name + ".set")) + elif "object" in item: + s.add(self._convert_inner_object(item, id, "%s.set[%s]" % (name, len(s)))) + elif len(set(["dict", "tuple", "set", "frozenset", "list"]) & set(item)) > 0: + s.add(self._convert_value(item, id, "%s.set[%s]" % (name, len(s)))) + else: + self.logger.debug("set: Don't know how to handle %s" % item) + else: + s.add(item) + return SetDef(name, s) + + def _convert_frozen_set(self, frozen_set_node, id, name): + item = self._convert_set(frozen_set_node, id, name) + self.logger.debug("frozenset: Just got back converted set %s" % item) + self.logger.debug("frozenset: value is %s, which will be turned into %s" % (item.value, frozenset(item.value))) + return FrozenSetDef(name, frozenset(item.value)) + + def _convert_inner_object(self, object_node, id, name): + self.logger.debug("inner object: Converting %s" % object_node) + inner_object_def = self._convert_object(object_node, prefix="%s.%s" % (id, name)) + self.objects.append(inner_object_def) + return InnerObjectDef(name, inner_object_def) + + def _convert_prop_def(self, comp, p, name): + "This function translates object properties into useful collections of information for the container." + self.logger.debug("prop_def: Trying to read property %s -> %s" % (name, p)) + if isinstance(p, dict): + if "ref" in p: + self.logger.debug("prop_def: >>>>>>>>>>>>Call _convert_ref(%s, %s)" % (p["ref"], name)) + return self._convert_ref(p["ref"], name) + elif "tuple" in p: + self.logger.debug("prop_def: Call _convert_tuple(%s,%s,%s)" % (p["tuple"], comp["object"], name)) + return self._convert_tuple(p["tuple"], comp["object"], name) + elif "set" in p: + self.logger.debug("prop_def: Call _convert_set(%s,%s,%s)" % (p["set"], comp["object"], name)) + return self._convert_set(p["set"], comp["object"], name) + elif "frozenset" in p: + self.logger.debug("prop_def: Call _convert_frozen_set(%s,%s,%s)" % (p["frozenset"], comp["object"], name)) + return self._convert_frozen_set(p["frozenset"], comp["object"], name) + elif "object" in p: + self.logger.debug("prop_def: Call _convert_inner_object(%s,%s,%s)" % (p, comp["object"], name)) + return self._convert_inner_object(p, comp["object"], name) + else: + #self.logger.debug("prop_def: Don't know how to handle %s" % p) + return self._convert_dict(p, comp["object"], name) + elif isinstance(p, list): + return self._convert_list(p, comp["object"], name) + else: + return ValueDef(name, p) + return None + + if hasattr(p, "ref"): + return self._convert_ref(p.ref, name) + elif hasattr(p, "value"): + return ValueDef(name, str(p.value)) + elif hasattr(p, "dict"): + return self._convert_dict(p.dict, comp.id, name) + elif hasattr(p, "props"): + return self._convert_props(p.props, name) + elif hasattr(p, "list"): + return self._convert_list(p.list, comp.id, name) + elif hasattr(p, "tuple"): + return self._convert_tuple(p.tuple, comp.id, name) + elif hasattr(p, "set"): + return self._convert_set(p.set, comp.id, name) + elif hasattr(p, "frozenset"): + self.logger.debug("Converting frozenset") + return self._convert_frozen_set(p.frozenset, comp.id, name) + elif hasattr(p, "object"): + return self._convert_inner_object(p.object, comp.id, name) + diff --git a/src/springpython/config/decorator.py b/src/springpython/config/decorator.py index 5c81350..fd01aa1 100644 --- a/src/springpython/config/decorator.py +++ b/src/springpython/config/decorator.py @@ -3,12 +3,12 @@ ## Copyright (c) 2005, Michele Simionato ## All rights reserved. ## -## Redistributions of source code must retain the above copyright +## Redistributions of source code must retain the above copyright ## notice, this list of conditions and the following disclaimer. ## Redistributions in bytecode form must reproduce the above copyright ## notice, this list of conditions and the following disclaimer in ## the documentation and/or other materials provided with the -## distribution. +## distribution. ## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -27,26 +27,38 @@ Decorator module, see http://pypi.python.org/pypi/decorator for the documentation. """ +from __future__ import print_function -__all__ = ["decorator", "FunctionMaker", "partial", - "deprecated", "getinfo", "new_wrapper"] +__all__ = [ + "decorator", + "FunctionMaker", + "partial", + "deprecated", + "getinfo", + "new_wrapper", +] import os, sys, re, inspect, warnings + try: from functools import partial -except ImportError: # for Python version < 2.5 +except ImportError: # for Python version < 2.5 + class partial(object): "A simple replacement of functools.partial" + def __init__(self, func, *args, **kw): self.func = func - self.args = args + self.args = args self.keywords = kw + def __call__(self, *otherargs, **otherkw): kw = self.keywords.copy() kw.update(otherkw) return self.func(*(self.args + otherargs), **kw) -DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') + +DEF = re.compile("\s*def\s*([_\w][_\w\d]*)\s*\(") # basic functionality class FunctionMaker(object): @@ -55,19 +67,29 @@ class FunctionMaker(object): It has attributes name, doc, module, signature, defaults, dict and methods update and make. """ - def __init__(self, func=None, name=None, signature=None, - defaults=None, doc=None, module=None, funcdict=None): + + def __init__( + self, + func=None, + name=None, + signature=None, + defaults=None, + doc=None, + module=None, + funcdict=None, + ): if func: # func can also be a class or a callable, but not an instance method self.name = func.__name__ - if self.name == '': # small hack for lambda functions - self.name = '_lambda_' + if self.name == "": # small hack for lambda functions + self.name = "_lambda_" self.doc = func.__doc__ self.module = func.__module__ if inspect.isfunction(func): self.signature = inspect.formatargspec( - formatvalue=lambda val: "", *inspect.getargspec(func))[1:-1] - self.defaults = func.func_defaults + formatvalue=lambda val: "", *inspect.getargspec(func) + )[1:-1] + self.defaults = func.__defaults__ self.dict = func.__dict__.copy() if name: self.name = name @@ -82,99 +104,115 @@ def __init__(self, func=None, name=None, signature=None, if funcdict: self.dict = funcdict # check existence required attributes - assert hasattr(self, 'name') - if not hasattr(self, 'signature'): - raise TypeError('You are decorating a non function: %s' % func) + assert hasattr(self, "name") + if not hasattr(self, "signature"): + raise TypeError("You are decorating a non function: %s" % func) def update(self, func, **kw): "Update the signature of func with the data in self" func.__name__ = self.name - func.__doc__ = getattr(self, 'doc', None) - func.__dict__ = getattr(self, 'dict', {}) - func.func_defaults = getattr(self, 'defaults', ()) - callermodule = sys._getframe(3).f_globals.get('__name__', '?') - func.__module__ = getattr(self, 'module', callermodule) + func.__doc__ = getattr(self, "doc", None) + func.__dict__ = getattr(self, "dict", {}) + func.__defaults__ = getattr(self, "defaults", ()) + callermodule = sys._getframe(3).f_globals.get("__name__", "?") + func.__module__ = getattr(self, "module", callermodule) func.__dict__.update(kw) - + def make(self, src_templ, evaldict=None, addsource=False, **attrs): "Make a new function from a given template and update the signature" - src = src_templ % vars(self) # expand name and signature + src = src_templ % vars(self) # expand name and signature evaldict = evaldict or {} mo = DEF.match(src) if mo is None: - raise SyntaxError('not a valid function template\n%s' % src) - name = mo.group(1) # extract the function name - reserved_names = set([name] + [ - arg.strip(' *') for arg in self.signature.split(',')]) - for n, v in evaldict.iteritems(): + raise SyntaxError("not a valid function template\n%s" % src) + name = mo.group(1) # extract the function name + reserved_names = set( + [name] + [arg.strip(" *") for arg in self.signature.split(",")] + ) + for n, v in evaldict.items(): if n in reserved_names: - raise NameError('%s is overridden in\n%s' % (n, src)) - if not src.endswith('\n'): # add a newline just for safety - src += '\n' + raise NameError("%s is overridden in\n%s" % (n, src)) + if not src.endswith("\n"): # add a newline just for safety + src += "\n" try: - code = compile(src, '', 'single') - exec code in evaldict + code = compile(src, "", "single") + exec(code, evaldict) except: - print >> sys.stderr, 'Error in generated code:' - print >> sys.stderr, src + print("Error in generated code:", file=sys.stderr) + print(src, file=sys.stderr) raise func = evaldict[name] if addsource: - attrs['__source__'] = src + attrs["__source__"] = src self.update(func, **attrs) return func @classmethod - def create(cls, obj, body, evaldict, defaults=None, addsource=True,**attrs): + def create(cls, obj, body, evaldict, defaults=None, addsource=True, **attrs): """ Create a function from the strings name, signature and body. evaldict is the evaluation dictionary. If addsource is true an attribute __source__ is added to the result. The attributes attrs are added, if any. """ - if isinstance(obj, str): # "name(signature)" - name, rest = obj.strip().split('(', 1) + if isinstance(obj, str): # "name(signature)" + name, rest = obj.strip().split("(", 1) signature = rest[:-1] func = None - else: # a function + else: # a function name = None signature = None func = obj fun = cls(func, name, signature, defaults) - ibody = '\n'.join(' ' + line for line in body.splitlines()) - return fun.make('def %(name)s(%(signature)s):\n' + ibody, - evaldict, addsource, **attrs) - + ibody = "\n".join(" " + line for line in body.splitlines()) + return fun.make( + "def %(name)s(%(signature)s):\n" + ibody, evaldict, addsource, **attrs + ) + + def decorator(caller, func=None): """ decorator(caller) converts a caller function into a decorator; decorator(caller, func) decorates a function using a caller. """ - if func is not None: # returns a decorated function + if func is not None: # returns a decorated function return FunctionMaker.create( - func, "return _call_(_func_, %(signature)s)", - dict(_call_=caller, _func_=func), undecorated=func) - else: # returns a decorator + func, + "return _call_(_func_, %(signature)s)", + dict(_call_=caller, _func_=func), + undecorated=func, + ) + else: # returns a decorator if isinstance(caller, partial): return partial(decorator, caller) # otherwise assume caller is a function - f = inspect.getargspec(caller)[0][0] # first arg + f = inspect.getargspec(caller)[0][0] # first arg return FunctionMaker.create( - '%s(%s)' % (caller.__name__, f), - 'return decorator(_call_, %s)' % f, - dict(_call_=caller, decorator=decorator), undecorated=caller) + "%s(%s)" % (caller.__name__, f), + "return decorator(_call_, %s)" % f, + dict(_call_=caller, decorator=decorator), + undecorated=caller, + ) + ###################### deprecated functionality ######################### + @decorator def deprecated(func, *args, **kw): "A decorator for deprecated functions" warnings.warn( - ('Calling the deprecated function %r\n' - 'Downgrade to decorator 2.3 if you want to use this functionality') - % func.__name__, DeprecationWarning, stacklevel=3) + ( + "Calling the deprecated function %r\n" + "Downgrade to decorator 2.3 if you want to use this functionality" + ) + % func.__name__, + DeprecationWarning, + stacklevel=3, + ) return func(*args, **kw) + @deprecated def getinfo(func): """ @@ -186,7 +224,7 @@ def getinfo(func): - doc (the docstring : str) - module (the module name : str) - dict (the function __dict__ : str) - + >>> def f(self, x=1, y=2, *args, **kw): pass >>> info = getinfo(f) @@ -195,7 +233,7 @@ def getinfo(func): 'f' >>> info["argnames"] ['self', 'x', 'y', 'args', 'kw'] - + >>> info["defaults"] (1, 2) @@ -209,40 +247,51 @@ def getinfo(func): argnames.append(varargs) if varkwargs: argnames.append(varkwargs) - signature = inspect.formatargspec(regargs, varargs, varkwargs, defaults, - formatvalue=lambda value: "")[1:-1] - return dict(name=func.__name__, argnames=argnames, signature=signature, - defaults = func.func_defaults, doc=func.__doc__, - module=func.__module__, dict=func.__dict__, - globals=func.func_globals, closure=func.func_closure) + signature = inspect.formatargspec( + regargs, varargs, varkwargs, defaults, formatvalue=lambda value: "" + )[1:-1] + return dict( + name=func.__name__, + argnames=argnames, + signature=signature, + defaults=func.__defaults__, + doc=func.__doc__, + module=func.__module__, + dict=func.__dict__, + globals=func.__globals__, + closure=func.__closure__, + ) + @deprecated def update_wrapper(wrapper, model, infodict=None): "A replacement for functools.update_wrapper" infodict = infodict or getinfo(model) - wrapper.__name__ = infodict['name'] - wrapper.__doc__ = infodict['doc'] - wrapper.__module__ = infodict['module'] - wrapper.__dict__.update(infodict['dict']) - wrapper.func_defaults = infodict['defaults'] + wrapper.__name__ = infodict["name"] + wrapper.__doc__ = infodict["doc"] + wrapper.__module__ = infodict["module"] + wrapper.__dict__.update(infodict["dict"]) + wrapper.__defaults__ = infodict["defaults"] wrapper.undecorated = model return wrapper + @deprecated def new_wrapper(wrapper, model): """ An improvement over functools.update_wrapper. The wrapper is a generic - callable object. It works by generating a copy of the wrapper with the + callable object. It works by generating a copy of the wrapper with the right signature and by updating the copy, not the original. Moreovoer, 'model' can be a dictionary with keys 'name', 'doc', 'module', 'dict', 'defaults'. """ if isinstance(model, dict): infodict = model - else: # assume model is a function + else: # assume model is a function infodict = getinfo(model) - assert not '_wrapper_' in infodict["argnames"], ( - '"_wrapper_" is a reserved argument name!') + assert ( + not "_wrapper_" in infodict["argnames"] + ), '"_wrapper_" is a reserved argument name!' src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict funcopy = eval(src, dict(_wrapper_=wrapper)) return update_wrapper(funcopy, model, infodict) diff --git a/src/springpython/config/decorator.py.bak b/src/springpython/config/decorator.py.bak new file mode 100644 index 0000000..5c81350 --- /dev/null +++ b/src/springpython/config/decorator.py.bak @@ -0,0 +1,248 @@ +########################## LICENCE ############################### +## +## Copyright (c) 2005, Michele Simionato +## All rights reserved. +## +## Redistributions of source code must retain the above copyright +## notice, this list of conditions and the following disclaimer. +## Redistributions in bytecode form must reproduce the above copyright +## notice, this list of conditions and the following disclaimer in +## the documentation and/or other materials provided with the +## distribution. + +## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS +## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR +## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +## DAMAGE. + +""" +Decorator module, see http://pypi.python.org/pypi/decorator +for the documentation. +""" + +__all__ = ["decorator", "FunctionMaker", "partial", + "deprecated", "getinfo", "new_wrapper"] + +import os, sys, re, inspect, warnings +try: + from functools import partial +except ImportError: # for Python version < 2.5 + class partial(object): + "A simple replacement of functools.partial" + def __init__(self, func, *args, **kw): + self.func = func + self.args = args + self.keywords = kw + def __call__(self, *otherargs, **otherkw): + kw = self.keywords.copy() + kw.update(otherkw) + return self.func(*(self.args + otherargs), **kw) + +DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') + +# basic functionality +class FunctionMaker(object): + """ + An object with the ability to create functions with a given signature. + It has attributes name, doc, module, signature, defaults, dict and + methods update and make. + """ + def __init__(self, func=None, name=None, signature=None, + defaults=None, doc=None, module=None, funcdict=None): + if func: + # func can also be a class or a callable, but not an instance method + self.name = func.__name__ + if self.name == '': # small hack for lambda functions + self.name = '_lambda_' + self.doc = func.__doc__ + self.module = func.__module__ + if inspect.isfunction(func): + self.signature = inspect.formatargspec( + formatvalue=lambda val: "", *inspect.getargspec(func))[1:-1] + self.defaults = func.func_defaults + self.dict = func.__dict__.copy() + if name: + self.name = name + if signature is not None: + self.signature = signature + if defaults: + self.defaults = defaults + if doc: + self.doc = doc + if module: + self.module = module + if funcdict: + self.dict = funcdict + # check existence required attributes + assert hasattr(self, 'name') + if not hasattr(self, 'signature'): + raise TypeError('You are decorating a non function: %s' % func) + + def update(self, func, **kw): + "Update the signature of func with the data in self" + func.__name__ = self.name + func.__doc__ = getattr(self, 'doc', None) + func.__dict__ = getattr(self, 'dict', {}) + func.func_defaults = getattr(self, 'defaults', ()) + callermodule = sys._getframe(3).f_globals.get('__name__', '?') + func.__module__ = getattr(self, 'module', callermodule) + func.__dict__.update(kw) + + def make(self, src_templ, evaldict=None, addsource=False, **attrs): + "Make a new function from a given template and update the signature" + src = src_templ % vars(self) # expand name and signature + evaldict = evaldict or {} + mo = DEF.match(src) + if mo is None: + raise SyntaxError('not a valid function template\n%s' % src) + name = mo.group(1) # extract the function name + reserved_names = set([name] + [ + arg.strip(' *') for arg in self.signature.split(',')]) + for n, v in evaldict.iteritems(): + if n in reserved_names: + raise NameError('%s is overridden in\n%s' % (n, src)) + if not src.endswith('\n'): # add a newline just for safety + src += '\n' + try: + code = compile(src, '', 'single') + exec code in evaldict + except: + print >> sys.stderr, 'Error in generated code:' + print >> sys.stderr, src + raise + func = evaldict[name] + if addsource: + attrs['__source__'] = src + self.update(func, **attrs) + return func + + @classmethod + def create(cls, obj, body, evaldict, defaults=None, addsource=True,**attrs): + """ + Create a function from the strings name, signature and body. + evaldict is the evaluation dictionary. If addsource is true an attribute + __source__ is added to the result. The attributes attrs are added, + if any. + """ + if isinstance(obj, str): # "name(signature)" + name, rest = obj.strip().split('(', 1) + signature = rest[:-1] + func = None + else: # a function + name = None + signature = None + func = obj + fun = cls(func, name, signature, defaults) + ibody = '\n'.join(' ' + line for line in body.splitlines()) + return fun.make('def %(name)s(%(signature)s):\n' + ibody, + evaldict, addsource, **attrs) + +def decorator(caller, func=None): + """ + decorator(caller) converts a caller function into a decorator; + decorator(caller, func) decorates a function using a caller. + """ + if func is not None: # returns a decorated function + return FunctionMaker.create( + func, "return _call_(_func_, %(signature)s)", + dict(_call_=caller, _func_=func), undecorated=func) + else: # returns a decorator + if isinstance(caller, partial): + return partial(decorator, caller) + # otherwise assume caller is a function + f = inspect.getargspec(caller)[0][0] # first arg + return FunctionMaker.create( + '%s(%s)' % (caller.__name__, f), + 'return decorator(_call_, %s)' % f, + dict(_call_=caller, decorator=decorator), undecorated=caller) + +###################### deprecated functionality ######################### + +@decorator +def deprecated(func, *args, **kw): + "A decorator for deprecated functions" + warnings.warn( + ('Calling the deprecated function %r\n' + 'Downgrade to decorator 2.3 if you want to use this functionality') + % func.__name__, DeprecationWarning, stacklevel=3) + return func(*args, **kw) + +@deprecated +def getinfo(func): + """ + Returns an info dictionary containing: + - name (the name of the function : str) + - argnames (the names of the arguments : list) + - defaults (the values of the default arguments : tuple) + - signature (the signature : str) + - doc (the docstring : str) + - module (the module name : str) + - dict (the function __dict__ : str) + + >>> def f(self, x=1, y=2, *args, **kw): pass + + >>> info = getinfo(f) + + >>> info["name"] + 'f' + >>> info["argnames"] + ['self', 'x', 'y', 'args', 'kw'] + + >>> info["defaults"] + (1, 2) + + >>> info["signature"] + 'self, x, y, *args, **kw' + """ + assert inspect.ismethod(func) or inspect.isfunction(func) + regargs, varargs, varkwargs, defaults = inspect.getargspec(func) + argnames = list(regargs) + if varargs: + argnames.append(varargs) + if varkwargs: + argnames.append(varkwargs) + signature = inspect.formatargspec(regargs, varargs, varkwargs, defaults, + formatvalue=lambda value: "")[1:-1] + return dict(name=func.__name__, argnames=argnames, signature=signature, + defaults = func.func_defaults, doc=func.__doc__, + module=func.__module__, dict=func.__dict__, + globals=func.func_globals, closure=func.func_closure) + +@deprecated +def update_wrapper(wrapper, model, infodict=None): + "A replacement for functools.update_wrapper" + infodict = infodict or getinfo(model) + wrapper.__name__ = infodict['name'] + wrapper.__doc__ = infodict['doc'] + wrapper.__module__ = infodict['module'] + wrapper.__dict__.update(infodict['dict']) + wrapper.func_defaults = infodict['defaults'] + wrapper.undecorated = model + return wrapper + +@deprecated +def new_wrapper(wrapper, model): + """ + An improvement over functools.update_wrapper. The wrapper is a generic + callable object. It works by generating a copy of the wrapper with the + right signature and by updating the copy, not the original. + Moreovoer, 'model' can be a dictionary with keys 'name', 'doc', 'module', + 'dict', 'defaults'. + """ + if isinstance(model, dict): + infodict = model + else: # assume model is a function + infodict = getinfo(model) + assert not '_wrapper_' in infodict["argnames"], ( + '"_wrapper_" is a reserved argument name!') + src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict + funcopy = eval(src, dict(_wrapper_=wrapper)) + return update_wrapper(funcopy, model, infodict) diff --git a/src/springpython/container/__init__.py b/src/springpython/container/__init__.py index 2ddfad8..df4a1d7 100644 --- a/src/springpython/container/__init__.py +++ b/src/springpython/container/__init__.py @@ -70,7 +70,7 @@ def get_object(self, name, ignore_abstract=False): return self.objects[name] - except KeyError, e: + except KeyError as e: self.logger.debug("Did NOT find object '%s' in the singleton storage." % name) try: object_def = self.object_defs[name] @@ -89,7 +89,7 @@ def get_object(self, name, ignore_abstract=False): raise InvalidObjectScope("Don't know how to handle scope %s" % self.object_defs[name].scope) return comp - except KeyError, e: + except KeyError as e: self.logger.error("Object '%s' has no definition!" % name) raise e diff --git a/src/springpython/container/__init__.py.bak b/src/springpython/container/__init__.py.bak new file mode 100644 index 0000000..2ddfad8 --- /dev/null +++ b/src/springpython/container/__init__.py.bak @@ -0,0 +1,143 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import logging +from springpython.context import scope + +class ObjectContainer(object): + """ + ObjectContainer is a container which uses multiple Config objects to read sources of + object definitions. When a object is requested from this container, it may optionally + pull the object from a scoped cache. If there is no stored copy of the object, it + uses the scanned definition and its associated ObjectFactory to create an instance. It + can then optionally store it in a scoped cache for future usage (e.g. singleton). + + Object definitions are stored in the container in a neutral format, decoupling the + container entirely from the original source location. This means that XML, python code, + and other formats may all contain definitions. By the time they + reach this container, it doesn't matter what their original format was when a object + instance is needed. NOTE: This explicitly means that one object in one source + can refer to another object in another source OF ANY FORMAT as a property. + """ + def __init__(self, config = None): + self.logger = logging.getLogger("springpython.container.ObjectContainer") + + if config is None: + self.configs = [] + elif isinstance(config, list): + self.configs = config + else: + self.configs = [config] + + self.object_defs = {} + + for configuration in self.configs: + self.logger.debug("=== Scanning configuration %s for object definitions ===" % configuration) + for object_def in configuration.read_object_defs(): + if object_def.id not in self.object_defs: + self.logger.debug("%s object definition does not exist. Adding to list of definitions." % object_def.id) + else: + self.logger.debug("Overriding previous definition of %s" % object_def.id) + self.object_defs[object_def.id] = object_def + + self.logger.debug("=== Done reading object definitions. ===") + + self.objects = {} + + def get_object(self, name, ignore_abstract=False): + """ + This function attempts to find the object in the singleton cache. If not found, + delegates to _create_object in order to hunt for the definition, and request a + object factory to generate one. + """ + try: + object_def = self.object_defs[name] + if object_def.abstract and not ignore_abstract: + raise AbstractObjectException("Object [%s] is an abstract one." % name) + + return self.objects[name] + + except KeyError, e: + self.logger.debug("Did NOT find object '%s' in the singleton storage." % name) + try: + object_def = self.object_defs[name] + if object_def.abstract and not ignore_abstract: + raise AbstractObjectException("Object [%s] is an abstract one." % name) + + comp = self._create_object(object_def) + + # Evaluate any scopes, and store appropriately. + if self.object_defs[name].scope == scope.SINGLETON: + self.objects[name] = comp + self.logger.debug("Stored object '%s' in container's singleton storage" % name) + elif self.object_defs[name].scope == scope.PROTOTYPE: + pass + else: + raise InvalidObjectScope("Don't know how to handle scope %s" % self.object_defs[name].scope) + + return comp + except KeyError, e: + self.logger.error("Object '%s' has no definition!" % name) + raise e + + def _get_constructors_pos(self, object_def): + """ + This function iterates over the positional constructors, and assembles their values into a list. + In this situation, the order as read from the XML should be the order expected by the class + definition. + """ + return tuple([constr.get_value(self) for constr in object_def.pos_constr + if hasattr(constr, "get_value")]) + + def _get_constructors_kw(self, kwargs): + """ + This function iterates over the named constructors, and assembles their values into a list. + In this situation, each argument is associated with a name, and due to unicode format provided + by the XML parser, requires conversion into a new dictionary. + """ + return dict([(key, kwargs[key].get_value(self)) for key in kwargs + if hasattr(kwargs[key], "get_value")]) + + + def _create_object(self, object_def): + """ + If the object isn't stored in any scoped cache, and must instead be created, this method + takes all the steps to read the object's definition, res it up, and store it in the appropriate + scoped cache. + """ + self.logger.debug("Creating an instance of %s" % object_def) + + [constr.prefetch(self) for constr in object_def.pos_constr if hasattr(constr, "prefetch")] + [constr.prefetch(self) for constr in object_def.named_constr.values() if hasattr(constr, "prefetch")] + [prop.prefetch(self) for prop in object_def.props if hasattr(prop, "prefetch")] + + # Res up an instance of the object, with ONLY constructor-based properties set. + obj = object_def.factory.create_object(self._get_constructors_pos(object_def), + self._get_constructors_kw(object_def.named_constr)) + + # Fill in the other property values. + [prop.set_value(obj, self) for prop in object_def.props if hasattr(prop, "set_value")] + + return obj + + +class AbstractObjectException(Exception): + """ Raised when the user's code tries to get an abstract object from + the container. + """ + +class InvalidObjectScope(Exception): + pass diff --git a/src/springpython/context/__init__.py b/src/springpython/context/__init__.py index 790e7ee..ea59498 100644 --- a/src/springpython/context/__init__.py +++ b/src/springpython/context/__init__.py @@ -16,26 +16,36 @@ import atexit import logging +import pdb from traceback import format_exc from springpython.container import ObjectContainer + class ApplicationContext(ObjectContainer): """ ApplicationContext IS a ObjectContainer. It also has the ability to define the lifecycle of objects. """ - def __init__(self, config = None): + + def __init__(self, config=None): super(ApplicationContext, self).__init__(config) - + atexit.register(self.shutdown_hook) - + self.logger = logging.getLogger("springpython.context.ApplicationContext") - self.classnames_to_avoid = set(["PyroProxyFactory", "ProxyFactoryObject", "Pyro4ProxyFactory", "Pyro4FactoryObject"]) - + self.classnames_to_avoid = set( + [ + "PyroProxyFactory", + "ProxyFactoryObject", + "Pyro4ProxyFactory", + "Pyro4FactoryObject", + ] + ) + for object_def in self.object_defs.values(): self._apply(object_def) - + for configuration in self.configs: self._apply(configuration) @@ -44,108 +54,137 @@ def __init__(self, config = None): self.logger.debug("Eagerly fetching %s" % object_def.id) self.get_object(object_def.id, ignore_abstract=True) - post_processors = [object for object in self.objects.values() if isinstance(object, ObjectPostProcessor)] + post_processors = [ + object + for object in self.objects.values() + if isinstance(object, ObjectPostProcessor) + ] - for obj_name, obj in self.objects.iteritems(): + for obj_name, obj in self.objects.items(): if not isinstance(obj, ObjectPostProcessor): for post_processor in post_processors: - self.objects[obj_name] = post_processor.post_process_before_initialization(obj, obj_name) - + self.objects[ + obj_name + ] = post_processor.post_process_before_initialization(obj, obj_name) for object in self.objects.values(): self._apply(object) - for obj_name, obj in self.objects.iteritems(): + for obj_name, obj in self.objects.items(): if not isinstance(obj, ObjectPostProcessor): for post_processor in post_processors: - self.objects[obj_name] = post_processor.post_process_after_initialization(obj, obj_name) - + self.objects[ + obj_name + ] = post_processor.post_process_after_initialization(obj, obj_name) + def _apply(self, obj): - if not (obj.__class__.__name__ in self.classnames_to_avoid): + if not (obj.__class__.__name__ in self.classnames_to_avoid): if hasattr(obj, "after_properties_set"): obj.after_properties_set() - #if hasattr(obj, "post_process_after_initialization"): + # if hasattr(obj, "post_process_after_initialization"): # obj.post_process_after_initialization(self) if hasattr(obj, "set_app_context"): obj.set_app_context(self) - + def get_objects_by_type(self, type_, include_type=True): - """ Returns all objects which are instances of a given type. + """Returns all objects which are instances of a given type. If include_type is False then only instances of the type's subclasses will be returned. """ result = {} - for obj_name, obj in self.objects.iteritems(): + for obj_name, obj in self.objects.items(): if isinstance(obj, type_): if include_type == False and type(obj) is type_: continue result[obj_name] = obj - + return result - + def shutdown_hook(self): self.logger.debug("Invoking the destroy_method on registered objects") - - for obj_name, obj in self.objects.iteritems(): + + for obj_name, obj in self.objects.items(): if isinstance(obj, DisposableObject): try: if hasattr(obj, "destroy_method"): destroy_method_name = getattr(obj, "destroy_method") else: destroy_method_name = "destroy" - + destroy_method = getattr(obj, destroy_method_name) - - except Exception, e: - self.logger.error("Could not destroy object '%s', exception '%s'" % (obj_name, format_exc())) - + + except Exception as e: + self.logger.error( + "Could not destroy object '%s', exception '%s'" + % (obj_name, format_exc()) + ) + else: if callable(destroy_method): try: self.logger.debug("About to destroy object '%s'" % obj_name) destroy_method() - self.logger.debug("Successfully destroyed object '%s'" % obj_name) - except Exception, e: - self.logger.error("Could not destroy object '%s', exception '%s'" % (obj_name, format_exc())) + self.logger.debug( + "Successfully destroyed object '%s'" % obj_name + ) + except Exception as e: + self.logger.error( + "Could not destroy object '%s', exception '%s'" + % (obj_name, format_exc()) + ) else: - self.logger.error("Could not destroy object '%s', " \ - "the 'destroy_method' attribute it defines is not callable, " \ - "its type is '%r', value is '%r'" % (obj_name, type(destroy_method), destroy_method)) - - self.logger.debug("Successfully invoked the destroy_method on registered objects") - + self.logger.error( + "Could not destroy object '%s', " + "the 'destroy_method' attribute it defines is not callable, " + "its type is '%r', value is '%r'" + % (obj_name, type(destroy_method), destroy_method) + ) + + self.logger.debug( + "Successfully invoked the destroy_method on registered objects" + ) + + class InitializingObject(object): """This allows definition of a method which is invoked by the container after an object has had all properties set.""" + def after_properties_set(self): pass + class ObjectPostProcessor(object): def post_process_before_initialization(self, obj, obj_name): return obj + def post_process_after_initialization(self, obj, obj_name): return obj + class ApplicationContextAware(object): def __init__(self): self.app_context = None - + def set_app_context(self, app_context): self.app_context = app_context + class ObjectNameAutoProxyCreator(ApplicationContextAware, ObjectPostProcessor): """ This object will iterate over a list of objects, and automatically apply a list of advisors to every callable method. This is useful when default advice needs to be applied widely with minimal configuration. """ - def __init__(self, objectNames = [], interceptorNames = []): + + def __init__(self, objectNames=[], interceptorNames=[]): super(ObjectNameAutoProxyCreator, self).__init__() self.objectNames = objectNames self.interceptorNames = interceptorNames + class DisposableObject(object): - """ This allows definition of a method which is invoked when the + """This allows definition of a method which is invoked when the container's shutting down to release the resources held by an object. """ + def destroy(self): raise NotImplementedError("Should be overridden by subclasses") diff --git a/src/springpython/context/__init__.py.bak b/src/springpython/context/__init__.py.bak new file mode 100644 index 0000000..790e7ee --- /dev/null +++ b/src/springpython/context/__init__.py.bak @@ -0,0 +1,151 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import atexit +import logging +from traceback import format_exc + +from springpython.container import ObjectContainer + +class ApplicationContext(ObjectContainer): + """ + ApplicationContext IS a ObjectContainer. It also has the ability to define the lifecycle of + objects. + """ + def __init__(self, config = None): + super(ApplicationContext, self).__init__(config) + + atexit.register(self.shutdown_hook) + + self.logger = logging.getLogger("springpython.context.ApplicationContext") + self.classnames_to_avoid = set(["PyroProxyFactory", "ProxyFactoryObject", "Pyro4ProxyFactory", "Pyro4FactoryObject"]) + + for object_def in self.object_defs.values(): + self._apply(object_def) + + for configuration in self.configs: + self._apply(configuration) + + for object_def in self.object_defs.values(): + if not object_def.lazy_init and object_def.id not in self.objects: + self.logger.debug("Eagerly fetching %s" % object_def.id) + self.get_object(object_def.id, ignore_abstract=True) + + post_processors = [object for object in self.objects.values() if isinstance(object, ObjectPostProcessor)] + + for obj_name, obj in self.objects.iteritems(): + if not isinstance(obj, ObjectPostProcessor): + for post_processor in post_processors: + self.objects[obj_name] = post_processor.post_process_before_initialization(obj, obj_name) + + + for object in self.objects.values(): + self._apply(object) + + for obj_name, obj in self.objects.iteritems(): + if not isinstance(obj, ObjectPostProcessor): + for post_processor in post_processors: + self.objects[obj_name] = post_processor.post_process_after_initialization(obj, obj_name) + + def _apply(self, obj): + if not (obj.__class__.__name__ in self.classnames_to_avoid): + if hasattr(obj, "after_properties_set"): + obj.after_properties_set() + #if hasattr(obj, "post_process_after_initialization"): + # obj.post_process_after_initialization(self) + if hasattr(obj, "set_app_context"): + obj.set_app_context(self) + + def get_objects_by_type(self, type_, include_type=True): + """ Returns all objects which are instances of a given type. + If include_type is False then only instances of the type's subclasses + will be returned. + """ + result = {} + for obj_name, obj in self.objects.iteritems(): + if isinstance(obj, type_): + if include_type == False and type(obj) is type_: + continue + result[obj_name] = obj + + return result + + def shutdown_hook(self): + self.logger.debug("Invoking the destroy_method on registered objects") + + for obj_name, obj in self.objects.iteritems(): + if isinstance(obj, DisposableObject): + try: + if hasattr(obj, "destroy_method"): + destroy_method_name = getattr(obj, "destroy_method") + else: + destroy_method_name = "destroy" + + destroy_method = getattr(obj, destroy_method_name) + + except Exception, e: + self.logger.error("Could not destroy object '%s', exception '%s'" % (obj_name, format_exc())) + + else: + if callable(destroy_method): + try: + self.logger.debug("About to destroy object '%s'" % obj_name) + destroy_method() + self.logger.debug("Successfully destroyed object '%s'" % obj_name) + except Exception, e: + self.logger.error("Could not destroy object '%s', exception '%s'" % (obj_name, format_exc())) + else: + self.logger.error("Could not destroy object '%s', " \ + "the 'destroy_method' attribute it defines is not callable, " \ + "its type is '%r', value is '%r'" % (obj_name, type(destroy_method), destroy_method)) + + self.logger.debug("Successfully invoked the destroy_method on registered objects") + +class InitializingObject(object): + """This allows definition of a method which is invoked by the container after an object has had all properties set.""" + def after_properties_set(self): + pass + +class ObjectPostProcessor(object): + def post_process_before_initialization(self, obj, obj_name): + return obj + def post_process_after_initialization(self, obj, obj_name): + return obj + +class ApplicationContextAware(object): + def __init__(self): + self.app_context = None + + def set_app_context(self, app_context): + self.app_context = app_context + +class ObjectNameAutoProxyCreator(ApplicationContextAware, ObjectPostProcessor): + """ + This object will iterate over a list of objects, and automatically apply + a list of advisors to every callable method. This is useful when default advice + needs to be applied widely with minimal configuration. + """ + def __init__(self, objectNames = [], interceptorNames = []): + super(ObjectNameAutoProxyCreator, self).__init__() + self.objectNames = objectNames + self.interceptorNames = interceptorNames + +class DisposableObject(object): + """ This allows definition of a method which is invoked when the + container's shutting down to release the resources held by an object. + """ + def destroy(self): + raise NotImplementedError("Should be overridden by subclasses") diff --git a/src/springpython/database/core.py b/src/springpython/database/core.py index 034744e..d1f890f 100644 --- a/src/springpython/database/core.py +++ b/src/springpython/database/core.py @@ -71,13 +71,13 @@ def _execute(self, sql_statement, args = None): cursor.execute(sql_statement) rows_affected = cursor.rowcount lastrowid = cursor.lastrowid - except Exception, e: + except Exception as e: self.logger.debug("execute.execute: Trapped %s while trying to execute '%s'" % (e, sql_statement)) error = e finally: try: cursor.close() - except Exception, e: + except Exception as e: self.logger.debug("execute.close: Trapped %s, and throwing away." % e) if error: @@ -131,13 +131,13 @@ def __query_for_list(self, sql_query, args = None): cursor.execute(sql_query) results = cursor.fetchall() metadata = [{"name":row[0], "type_code":row[1], "display_size":row[2], "internal_size":row[3], "precision":row[4], "scale":row[5], "null_ok":row[6]} for row in cursor.description] - except Exception, e: + except Exception as e: self.logger.debug("query_for_list.execute: Trapped %s while trying to execute '%s'" % (e, sql_query)) error = e finally: try: cursor.close() - except Exception, e: + except Exception as e: self.logger.debug("query_for_list.close: Trapped %s, and throwing away." % e) if error: @@ -149,12 +149,12 @@ def __query_for_list(self, sql_query, args = None): def query_for_int(self, sql_query, args = None): """Execute a query that results in an int value, given static SQL. If args is provided, bind the arguments (to avoid SQL injection attacks).""" - return self.query_for_object(sql_query, args, types.IntType) + return self.query_for_object(sql_query, args, int) def query_for_long(self, sql_query, args = None): """Execute a query that results in an int value, given static SQL. If args is provided, bind the arguments (to avoid SQL injection attacks).""" - return self.query_for_object(sql_query, args, types.LongType) + return self.query_for_object(sql_query, args, int) def query_for_object(self, sql_query, args = None, required_type = None): """Execute a query that results in an int value, given static SQL. If args is provided, bind the arguments @@ -175,7 +175,7 @@ def query_for_object(self, sql_query, args = None, required_type = None): raise IncorrectResultSizeDataAccessException("Instead of getting one column, this query returned %s" % len(results[0])) equivalentTypes = [ - [types.UnicodeType, types.StringType] + [str, bytes] ] if type(results[0][0]) != required_type: foundEquivType = False diff --git a/src/springpython/database/core.py.bak b/src/springpython/database/core.py.bak new file mode 100644 index 0000000..034744e --- /dev/null +++ b/src/springpython/database/core.py.bak @@ -0,0 +1,233 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import types +from springpython.database import ArgumentMustBeNamed +from springpython.database import DataAccessException +from springpython.database import IncorrectResultSizeDataAccessException +from springpython.database import InvalidArgumentType +from springpython.database import factory + +class DaoSupport(object): + """ + Any class that extends this one will be provided with a DatabaseTemplate class + to help carry out database operations. It requires that a connection object be + provided during instantion. + """ + def __init__(self, connection_factory = None): + self.database_template = DatabaseTemplate() + self.connection_factory = connection_factory + + def __setattr__(self, name, value): + """When the connection factory is set, pass it on through to the database template.""" + self.__dict__[name] = value + if name == "connection_factory" and value: + self.__dict__["database_template"].connection_factory = value + +class DatabaseTemplate(object): + """ + This class is meant to mimic the Spring framework's JdbcTemplate class. + Since Python doesn't use JDBC, the name is generalized to "Database" + """ + def __init__(self, connection_factory = None): + self.connection_factory = connection_factory + self.logger = logging.getLogger("springpython.database.core.DatabaseTemplate") + + def __del__(self): + "When this template goes out of scope, need to close the connection it formed." + if self.connection_factory is not None: self.connection_factory.close() + + def _execute(self, sql_statement, args = None): + """Issue a single SQL execute, typically a DDL statement.""" + + if args and type(args) not in self.connection_factory.acceptable_types: + raise InvalidArgumentType(type(args), self.connection_factory.acceptable_types) + + sql_statement = self.connection_factory.convert_sql_binding(sql_statement) + + cursor = self.connection_factory.getConnection().cursor() + error = None + rows_affected = 0 + try: + try: + if args: + cursor.execute(sql_statement, args) + rows_affected = cursor.rowcount + lastrowid = cursor.lastrowid + else: + cursor.execute(sql_statement) + rows_affected = cursor.rowcount + lastrowid = cursor.lastrowid + except Exception, e: + self.logger.debug("execute.execute: Trapped %s while trying to execute '%s'" % (e, sql_statement)) + error = e + finally: + try: + cursor.close() + except Exception, e: + self.logger.debug("execute.close: Trapped %s, and throwing away." % e) + + if error: + raise DataAccessException(error) + + return {"rows_affected":rows_affected, "lastrowid":lastrowid} + + def execute(self, sql_statement, args = None): + """Execute a single SQL statement, and return the number of rows affected.""" + return self._execute(sql_statement, args)["rows_affected"] + + def insert_and_return_id(self, sql_statement, args = None): + """Execute a single INSERT statement, and return the PK of the new row.""" + return self._execute(sql_statement, args)["lastrowid"] + + def query(self, sql_query, args = None, rowhandler = None): + """Execute a query given static SQL, reading the ResultSet on a per-row basis with a RowMapper. + If args is provided, bind the arguments (to avoid SQL injection attacks).""" + + # This is the case where only two, non-named arguments were provided, the sql_query and one other. + # If the second argument was 'args', it is invalid since 'rowhandler' is required. + # It is was 'rowhandler', it shifted into 'args' position, and requires naming. + if args and not rowhandler: + raise ArgumentMustBeNamed(arg_name="rowhandler") + + results, metadata = self.__query_for_list(sql_query, args) + return [rowhandler.map_row(row, metadata) for row in results] + + def query_for_list(self, sql_query, args = None): + results, metadata = self.__query_for_list(sql_query, args) + return results + + def __query_for_list(self, sql_query, args = None): + """Execute a query for a result list, given static SQL. If args is provided, bind the arguments + (to avoid SQL injection attacks).""" + + if args and type(args) not in self.connection_factory.acceptable_types: + raise InvalidArgumentType(type(args), self.connection_factory.acceptable_types) + + sql_query = self.connection_factory.convert_sql_binding(sql_query) + + cursor = self.connection_factory.getConnection().cursor() + error = None + results = None + metadata = None + try: + try: + if args: + cursor.execute(sql_query, args) + else: + cursor.execute(sql_query) + results = cursor.fetchall() + metadata = [{"name":row[0], "type_code":row[1], "display_size":row[2], "internal_size":row[3], "precision":row[4], "scale":row[5], "null_ok":row[6]} for row in cursor.description] + except Exception, e: + self.logger.debug("query_for_list.execute: Trapped %s while trying to execute '%s'" % (e, sql_query)) + error = e + finally: + try: + cursor.close() + except Exception, e: + self.logger.debug("query_for_list.close: Trapped %s, and throwing away." % e) + + if error: + self.logger.debug("query_for_list: I thought about kicking this up the chain => %s" % error) + + # Convert multi-item tuple into list + return [result for result in results or []], metadata + + def query_for_int(self, sql_query, args = None): + """Execute a query that results in an int value, given static SQL. If args is provided, bind the arguments + (to avoid SQL injection attacks).""" + return self.query_for_object(sql_query, args, types.IntType) + + def query_for_long(self, sql_query, args = None): + """Execute a query that results in an int value, given static SQL. If args is provided, bind the arguments + (to avoid SQL injection attacks).""" + return self.query_for_object(sql_query, args, types.LongType) + + def query_for_object(self, sql_query, args = None, required_type = None): + """Execute a query that results in an int value, given static SQL. If args is provided, bind the arguments + (to avoid SQL injection attacks).""" + + # This is the case where only two, non-named arguments were provided, the sql_query and one other. + # If the second argument was 'args', it is invalid since 'required_type' is required. + # It is was 'required_type', it shifted into 'args' position, and requires naming. + if args and not required_type: + raise ArgumentMustBeNamed(arg_name="required_type") + + results = self.query_for_list(sql_query, args) + + if len(results) != 1: + raise IncorrectResultSizeDataAccessException("Instead of getting one row, this query returned %s" % len(results)) + + if len(results[0]) != 1: + raise IncorrectResultSizeDataAccessException("Instead of getting one column, this query returned %s" % len(results[0])) + + equivalentTypes = [ + [types.UnicodeType, types.StringType] + ] + if type(results[0][0]) != required_type: + foundEquivType = False + for equivType in equivalentTypes: + if type(results[0][0]) in equivType and required_type in equivType: + foundEquivType = True + break + if not foundEquivType: + raise DataAccessException("Expected %s, but instead got %s"% (required_type, type(results[0][0]))) + + return results[0][0] + + def update(self, sql_statement, args = None): + """Issue a single SQL update. If args is provided, bind the arguments + (to avoid SQL injection attacks).""" + return self.execute(sql_statement, args) + + +class RowMapper(object): + """ + This is an interface to handle one row of data. + """ + def map_row(self, row, metadata=None): + raise NotImplementedError() + +class DictionaryRowMapper(RowMapper): + """ + This row mapper converts the tuple into a dictionary using the column names as the keys. + """ + def map_row(self, row, metadata=None): + if metadata is not None: + obj = {} + for i, column in enumerate(metadata): + obj[column["name"]] = row[i] + return obj + else: + raise DataAccessException("metadata is None, unable to convert result set into a dictionary") + +class SimpleRowMapper(RowMapper): + """ + This row mapper uses convention over configuration to create and populate attributes + of an object. + """ + def __init__(self, clazz): + self.clazz = clazz + + def map_row(self, row, metadata=None): + if metadata is not None: + obj = self.clazz() + for i, column in enumerate(metadata): + setattr(obj, column["name"], row[i]) + return obj + else: + raise DataAccessException("metadata is None, unable to map result set into %s instance" % self.clazz) + diff --git a/src/springpython/database/factory.py b/src/springpython/database/factory.py index 57e280d..72107d2 100644 --- a/src/springpython/database/factory.py +++ b/src/springpython/database/factory.py @@ -62,7 +62,7 @@ def convert_sql_binding(self, sql_query): class MySQLConnectionFactory(ConnectionFactory): def __init__(self, username = None, password = None, hostname = None, db = None): - ConnectionFactory.__init__(self, [types.TupleType]) + ConnectionFactory.__init__(self, [tuple]) self.username = username self.password = password self.hostname = hostname @@ -77,11 +77,11 @@ def in_transaction(self): return True def count_type(self): - return types.LongType + return int class PgdbConnectionFactory(ConnectionFactory): def __init__(self, user = None, password = None, host = None, database = None): - ConnectionFactory.__init__(self, [types.TupleType]) + ConnectionFactory.__init__(self, [tuple]) self.user = user self.password = password self.host = host @@ -96,11 +96,11 @@ def in_transaction(self): return True def count_type(self): - return types.LongType + return int class Sqlite3ConnectionFactory(ConnectionFactory): def __init__(self, db = None, check_same_thread=True): - ConnectionFactory.__init__(self, [types.TupleType]) + ConnectionFactory.__init__(self, [tuple]) self.db = db self.check_same_thread = check_same_thread self.using_sqlite3 = True @@ -119,7 +119,7 @@ def in_transaction(self): return True def count_type(self): - return types.IntType + return int def convert_sql_binding(self, sql_query): if self.using_sqlite3: @@ -131,7 +131,7 @@ def convert_sql_binding(self, sql_query): class cxoraConnectionFactory(ConnectionFactory): def __init__(self, username = None, password = None, hostname = None, db = None): - ConnectionFactory.__init__(self, [types.DictType]) + ConnectionFactory.__init__(self, [dict]) self.username = username self.password = password self.hostname = hostname @@ -144,7 +144,7 @@ def connect(self): class SQLServerConnectionFactory(ConnectionFactory): def __init__(self, **odbc_info): - ConnectionFactory.__init__(self, [types.TupleType]) + ConnectionFactory.__init__(self, [tuple]) self.odbc_info = odbc_info def connect(self): @@ -157,7 +157,7 @@ def in_transaction(self): return True def count_type(self): - return types.IntType + return int def convert_sql_binding(self, sql_query): """SQL Server expects parameters to be passed as question marks.""" diff --git a/src/springpython/database/factory.py.bak b/src/springpython/database/factory.py.bak new file mode 100644 index 0000000..57e280d --- /dev/null +++ b/src/springpython/database/factory.py.bak @@ -0,0 +1,164 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import re +import sys +import types + +class ConnectionFactory(object): + def __init__(self, acceptable_types): + self.__db = None + self.acceptable_types = acceptable_types + + """This interface defines an object that is able to make database connections. + This allows database connections to be defined inside application contexts, and + fed to DAO and DatabaseTemplates.""" + def connect(self): + raise NotImplementedError() + + def getConnection(self): + if self.__db is None: + self.__db = self.connect() + return self.__db + + def close(self): + "Need to offer API call to close the connection to the database." + if self.__db is not None: + self.__db.close() + self.__db = None + + def commit(self): + if self.in_transaction(): + self.getConnection().commit() + + def rollback(self): + if self.in_transaction(): + self.getConnection().rollback() + + def in_transaction(self): + raise NotImplementedError() + + def count_type(self): + raise NotImplementedError() + + def convert_sql_binding(self, sql_query): + """This is to help Java users migrate to Python. Java notation defines binding variables + points with '?', while Python uses '%s', and this method will convert from one format + to the other.""" + return re.sub(pattern="\?", repl="%s", string=sql_query) + +class MySQLConnectionFactory(ConnectionFactory): + def __init__(self, username = None, password = None, hostname = None, db = None): + ConnectionFactory.__init__(self, [types.TupleType]) + self.username = username + self.password = password + self.hostname = hostname + self.db = db + + def connect(self): + """The import statement is delayed so the library is loaded ONLY if this factory is really used.""" + import MySQLdb + return MySQLdb.connect(self.hostname, self.username, self.password, self.db) + + def in_transaction(self): + return True + + def count_type(self): + return types.LongType + +class PgdbConnectionFactory(ConnectionFactory): + def __init__(self, user = None, password = None, host = None, database = None): + ConnectionFactory.__init__(self, [types.TupleType]) + self.user = user + self.password = password + self.host = host + self.database = database + + def connect(self): + """The import statement is delayed so the library is loaded ONLY if this factory is really used.""" + import pgdb + return pgdb.connect(user=self.user, password=self.password, database=self.database, host=self.host) + + def in_transaction(self): + return True + + def count_type(self): + return types.LongType + +class Sqlite3ConnectionFactory(ConnectionFactory): + def __init__(self, db = None, check_same_thread=True): + ConnectionFactory.__init__(self, [types.TupleType]) + self.db = db + self.check_same_thread = check_same_thread + self.using_sqlite3 = True + + def connect(self): + """The import statement is delayed so the library is loaded ONLY if this factory is really used.""" + try: + import sqlite3 + return sqlite3.connect(self.db, check_same_thread=self.check_same_thread) + except: + import sqlite + self.using_sqlite3 = False + return sqlite.connect(self.db, check_same_thread=self.check_same_thread) + + def in_transaction(self): + return True + + def count_type(self): + return types.IntType + + def convert_sql_binding(self, sql_query): + if self.using_sqlite3: + """sqlite3 uses the ? notation, like Java's JDBC.""" + return re.sub(pattern="%s", repl="?", string=sql_query) + else: + """Older versions of sqlite use the %s notation""" + return re.sub(pattern="\?", repl="%s", string=sql_query) + +class cxoraConnectionFactory(ConnectionFactory): + def __init__(self, username = None, password = None, hostname = None, db = None): + ConnectionFactory.__init__(self, [types.DictType]) + self.username = username + self.password = password + self.hostname = hostname + self.db = db + + def connect(self): + """The import statement is delayed so the library is loaded ONLY if this factory is really used.""" + import cx_Oracle + return cx_Oracle.connect(self.username, self.password, self.db) + +class SQLServerConnectionFactory(ConnectionFactory): + def __init__(self, **odbc_info): + ConnectionFactory.__init__(self, [types.TupleType]) + self.odbc_info = odbc_info + + def connect(self): + """The import statement is delayed so the library is loaded ONLY if this factory is really used.""" + import pyodbc + odbc_info = ";".join(["%s=%s" % (key, value) for key, value in self.odbc_info.items()]) + return pyodbc.connect(odbc_info) + + def in_transaction(self): + return True + + def count_type(self): + return types.IntType + + def convert_sql_binding(self, sql_query): + """SQL Server expects parameters to be passed as question marks.""" + return re.sub(pattern="%s", repl="?", string=sql_query) diff --git a/src/springpython/database/transaction.py b/src/springpython/database/transaction.py index 5e31401..c4bc6e9 100644 --- a/src/springpython/database/transaction.py +++ b/src/springpython/database/transaction.py @@ -147,7 +147,7 @@ def execute(self, transactionCallback): self.logger.debug("Execute the steps inside the transaction") result = transactionCallback.do_in_transaction(status) self.tx_manager.commit(status) - except Exception, e: + except Exception as e: self.logger.debug("Exception: (%s)" % e) self.tx_manager.rollback(status) raise e @@ -211,7 +211,7 @@ def do_in_transaction(s, status): self.logger.debug("Call TransactionTemplate") try: results = tx_template.execute(tx_def()) - except Exception, e: + except Exception as e: self.logger.debug("Exception => %s" % e) raise e self.logger.debug("Return from TransactionTemplate") @@ -299,11 +299,11 @@ def post_process_after_initialization(self, obj, obj_name): for name, method in inspect.getmembers(obj, inspect.ismethod): try: # If the method contains _call_, then you are looking at a wrapper... - wrapper = method.im_func.func_globals["_call_"] - if wrapper.func_name == "transactional_wrapper": # name of @transactional's wrapper method + wrapper = method.__func__.__globals__["_call_"] + if wrapper.__name__ == "transactional_wrapper": # name of @transactional's wrapper method self.logger.debug("Linking tx_manager with %s" % name) - wrapper.func_globals["tx_manager"] = self.tx_manager - except KeyError, e: # If the method is NOT wrapped, there will be no _call_ attribute + wrapper.__globals__["tx_manager"] = self.tx_manager + except KeyError as e: # If the method is NOT wrapped, there will be no _call_ attribute pass return obj diff --git a/src/springpython/database/transaction.py.bak b/src/springpython/database/transaction.py.bak new file mode 100644 index 0000000..5e31401 --- /dev/null +++ b/src/springpython/database/transaction.py.bak @@ -0,0 +1,310 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import inspect +import logging +import re +import types +from springpython.aop import MethodInterceptor +from springpython.aop import ProxyFactoryObject +from springpython.context import ObjectPostProcessor +from springpython.config.decorator import decorator + +logger = logging.getLogger("springpython.database.transaction") + +class TransactionException(Exception): + pass + +class TransactionPropagationException(TransactionException): + pass + +class TransactionStatus(object): + pass + +class DefaultTransactionStatus(TransactionStatus): + pass + +class PlatformTransactionManager(object): + """This interface is used to define the operations necessary in handling transactions.""" + def commit(self, status): + raise NotImplementedError() + + def getTransaction(self, definition): + raise NotImplementedError() + + def rollback(self, status): + raise NotImplementedError() + +class ConnectionFactoryTransactionManager(PlatformTransactionManager): + """ + This transaction manager is based upon using a connection factory to control transactions. Since + connection factories are tied to vendor-specific databases, this allows delegation of various + transactional functions on a per-vendor basis. + """ + + def __init__(self, connection_factory): + self.connection_factory = connection_factory + self.logger = logging.getLogger("springpython.database.transaction.ConnectionFactoryTransactionManager") + self.status = [] + + def getTransaction(self, definition): + """According to PEP 249, commits and rollbacks silently start new transactions. Until a more + robust transaction manager is implemented to handle save points and so forth, this must suffice.""" + + self.logger.debug("Analyzing %s" % definition.propagation) + + start_tx = False + + if definition.propagation == "PROPAGATION_REQUIRED": + if len(self.status) == 0: + self.logger.debug("There is no current transaction, and one is required, so starting one.") + start_tx = True + self.status.append(DefaultTransactionStatus()) + + elif definition.propagation == "PROPAGATION_SUPPORTS": + self.logger.debug("This code can execute inside or outside a transaction.") + + elif definition.propagation == "PROPAGATION_MANDATORY": + if len(self.status) == 0: + raise TransactionPropagationException("Trying to execute PROPAGATION_MANDATORY operation while outside TX") + self.status.append(DefaultTransactionStatus()) + + elif definition.propagation == "PROPAGATION_NEVER": + if len(self.status) != 0: + raise TransactionPropagationException("Trying to execute PROPAGATION_NEVER operation while inside TX") + + else: + raise TransactionPropagationException("Transaction propagation level %s is not supported!" % definition.start_tx) + + if start_tx: + self.logger.debug("START TRANSACTION") + self.logger.debug("Creating a transaction, propagation = %s, isolation = %s, timeout = %s, read_only = %s" % (definition.propagation, definition.isolation, definition.timeout, definition.read_only)) + self.connection_factory.commit() + + return self.status + + def commit(self, status): + self.status = status + try: + self.status.pop() + if len(self.status) == 0: + self.logger.debug("Commit the changes") + self.connection_factory.commit() + self.logger.debug("END TRANSACTION") + except IndexError: + pass + + def rollback(self, status): + self.status = status + try: + self.status.pop() + if len(self.status) == 0: + self.logger.debug("Rolling back the transaction.") + self.connection_factory.rollback() + self.logger.debug("END TRANSACTION") + except IndexError: + pass + +class TransactionDefinition(object): + def __init__(self, isolation = None, name = None, propagation = None, timeout = None, read_only = None): + self.isolation = isolation + self.name = name + self.propagation = propagation + self.timeout = timeout + self.read_only = read_only + +class DefaultTransactionDefinition(TransactionDefinition): + def __init__(self, isolation = "ISOLATION_DEFAULT", name = "", propagation = "PROPAGATION_REQUIRED", timeout = "TIMEOUT_DEFAULT", read_only = False): + TransactionDefinition.__init__(self, isolation, name, propagation, timeout, read_only) + +class TransactionTemplate(DefaultTransactionDefinition): + """This utility class is used to simplify defining transactional blocks. Any exceptions thrown inside the + transaction block will be propagated to whom ever is calling the template execute method.""" + + def __init__(self, tx_manager): + DefaultTransactionDefinition.__init__(self) + self.tx_manager = tx_manager + self.logger = logging.getLogger("springpython.database.transaction.TransactionTemplate") + + def execute(self, transactionCallback): + """Execute the action specified by the given callback object within a transaction.""" + + status = self.tx_manager.getTransaction(self) + result = None + try: + self.logger.debug("Execute the steps inside the transaction") + result = transactionCallback.do_in_transaction(status) + self.tx_manager.commit(status) + except Exception, e: + self.logger.debug("Exception: (%s)" % e) + self.tx_manager.rollback(status) + raise e + return result + + def setTxAttributes(self, tx_attributes): + for tx_def_prop in tx_attributes: + if tx_def_prop.startswith("ISOLATION"): + if tx_def_prop != self.isolation: self.isolation = tx_def_prop + elif tx_def_prop.startswith("PROPAGATION"): + if tx_def_prop != self.propagation: self.propagation = tx_def_prop + elif tx_def_prop.startswith("TIMEOUT"): + if tx_def_prop != self.timeout: self.timeout = tx_def_prop + elif tx_def_prop == "read_only": + if not self.read_only: self.read_only = True + else: + self.logger.debug("Don't know how to handle %s" % tx_def_prop) + + +class TransactionCallback(object): + """This interface defines the basic action needed to plug into the TransactionTemplate""" + def do_in_transaction(self, status): + raise NotImplementedError() + +class TransactionCallbackWithoutResult(TransactionCallback): + """This abstract class implements the TransactionCallback, but assumes no value is being returned.""" + def __init__(self): + self.logger = logging.getLogger("springpython.database.transaction.TransactionCallbackWithoutResult") + + def do_in_transaction(self, status): + self.logger.debug("Starting a transaction without result") + self.do_in_tx_without_result(status) + self.logger.debug("Completing a transaction without result") + return None + + def do_in_tx_without_result(self, status): + pass + +class TransactionalInterceptor(MethodInterceptor): + """This interceptor is used by the TransactionProxyFactoryObject in order to wrap + method calls with transactions.""" + def __init__(self, tx_manager, tx_attributes): + self.logger = logging.getLogger("springpython.database.transaction.TransactionalInterceptor") + self.tx_attributes = tx_attributes + self.tx_manager = tx_manager + + def invoke(self, invocation): + class tx_def(TransactionCallback): + def do_in_transaction(s, status): + return invocation.proceed() + + tx_template = TransactionTemplate(self.tx_manager) + + # Iterate over the tx_attributes, and when a method match is found, apply the properties + for pattern, tx_def_props in self.tx_attributes: + if re.compile(pattern).match(invocation.method_name): + self.logger.debug("%s matches pattern %s, tx attributes = %s" % (invocation.method_name, pattern, tx_def_props)) + tx_template.setTxAttributes(tx_def_props) + break + + self.logger.debug("Call TransactionTemplate") + try: + results = tx_template.execute(tx_def()) + except Exception, e: + self.logger.debug("Exception => %s" % e) + raise e + self.logger.debug("Return from TransactionTemplate") + return results + +class TransactionProxyFactoryObject(ProxyFactoryObject): + """This class acts like the target object, and routes function calls through a + transactional interceptor.""" + def __init__(self, tx_manager, target, tx_attributes): + self.logger = logging.getLogger("springpython.database.transaction.TransactionProxyFactoryObject") + ProxyFactoryObject.__init__(self, target, TransactionalInterceptor(tx_manager, tx_attributes)) + +def transactional(tx_attributes = None): + """ + This decorator is actually a utility function that returns an embedded decorator, in order + to handle whether it was called in any of the following ways: + + @transactional() + def foo(): + pass + + @transactional + def foo(): + pass + + The first two ways get parsed by Python as: + + foo = transactional("some contextual string")(foo) # first way + foo = transactional()(foo) # second way + + Since this is expected, they are granted direct access to the embedded transactional_wrapper. + + However, the third way ends up getting parsed by Python as: + + foo = Transactional(foo) + + This causes context to improperly get populated with a function instead of a string. This + requires recalling this utility like: + + return Transactional()(context) + """ + + @decorator + def transactional_wrapper(f, *args, **kwargs): + """ + transactional_wrapper is used to wrap the decorated function in a TransactionTemplate callback, + and then return the results. + """ + class tx_def(TransactionCallback): + """TransactionTemplate requires a callback defined this way.""" + def do_in_transaction(s, status): + return f(*args, **kwargs) + + try: + # Assumes tx_manager is supplied by AutoTransactionalObject + tx_template = TransactionTemplate(tx_manager) + if tx_attributes is not None: + tx_template.setTxAttributes(tx_attributes) + else: + logger.debug("There are NO tx_attributes! %s" % tx_attributes) + return tx_template.execute(tx_def()) + except NameError: + # If no AutoTransactionalObject found in IoC container, then pass straight through. + return tx_def().do_in_transaction(None) + + if type(tx_attributes) == types.FunctionType: + return transactional()(tx_attributes) + else: + return transactional_wrapper + + +class AutoTransactionalObject(ObjectPostProcessor): + """ + This object is used to automatically scan objects in an IoC container, and if @Transaction + is found applied to any of the object's methods, link it with a TransactionManager. + """ + + def __init__(self, tx_manager): + self.tx_manager = tx_manager + self.logger = logging.getLogger("springpython.database.transaction.AutoTransactionalObject") + + def post_process_after_initialization(self, obj, obj_name): + """This setup is run after all objects in the container have been created.""" + # Check every method in the object... + for name, method in inspect.getmembers(obj, inspect.ismethod): + try: + # If the method contains _call_, then you are looking at a wrapper... + wrapper = method.im_func.func_globals["_call_"] + if wrapper.func_name == "transactional_wrapper": # name of @transactional's wrapper method + self.logger.debug("Linking tx_manager with %s" % name) + wrapper.func_globals["tx_manager"] = self.tx_manager + except KeyError, e: # If the method is NOT wrapped, there will be no _call_ attribute + pass + return obj + + diff --git a/src/springpython/factory/__init__.py b/src/springpython/factory/__init__.py index b532f77..91ec3f0 100644 --- a/src/springpython/factory/__init__.py +++ b/src/springpython/factory/__init__.py @@ -49,12 +49,12 @@ def __init__(self, method, wrapper): self.wrapper = wrapper def create_object(self, constr, named_constr): - self.logger.debug("Creating an instance of %s" % self.method.func_name) + self.logger.debug("Creating an instance of %s" % self.method.__name__) # Setting wrapper's top_func can NOT be done earlier than this method call, # because it is tied to a wrapper decorator, which may not have yet been # generated. - self.wrapper.func_globals["top_func"] = self.method.func_name + self.wrapper.__globals__["top_func"] = self.method.__name__ # Because @object-based objects use direct code to specify arguments, and NOT # external configuration data, this factory doesn't care about the incoming arguments. diff --git a/src/springpython/factory/__init__.py.bak b/src/springpython/factory/__init__.py.bak new file mode 100644 index 0000000..b532f77 --- /dev/null +++ b/src/springpython/factory/__init__.py.bak @@ -0,0 +1,65 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import logging +import sys + +class ObjectFactory(object): + def create_object(self, constr, named_constr): + raise NotImplementedError() + +class ReflectiveObjectFactory(ObjectFactory): + def __init__(self, module_and_class): + self.logger = logging.getLogger("springpython.factory.ReflectiveObjectFactory") + self.module_and_class = module_and_class + + def create_object(self, constr, named_constr): + self.logger.debug("Creating an instance of %s" % self.module_and_class) + parts = self.module_and_class.split(".") + module_name = ".".join(parts[:-1]) + class_name = parts[-1] + if module_name == "": + return __import__(class_name)(*constr, **named_constr) + else: + __import__(module_name) + cls = getattr(sys.modules[module_name], class_name) + return cls(*constr, **named_constr) + + + def __str__(self): + return "ReflectiveObjectFactory(%s)" % self.module_and_class + +class PythonObjectFactory(ObjectFactory): + def __init__(self, method, wrapper): + self.logger = logging.getLogger("springpython.factory.PythonObjectFactory") + self.method = method + self.wrapper = wrapper + + def create_object(self, constr, named_constr): + self.logger.debug("Creating an instance of %s" % self.method.func_name) + + # Setting wrapper's top_func can NOT be done earlier than this method call, + # because it is tied to a wrapper decorator, which may not have yet been + # generated. + self.wrapper.func_globals["top_func"] = self.method.func_name + + # Because @object-based objects use direct code to specify arguments, and NOT + # external configuration data, this factory doesn't care about the incoming arguments. + + return self.method() + + def __str__(self): + return "PythonObjectFactory(%s)" % self.method diff --git a/src/springpython/jms/factory.py b/src/springpython/jms/factory.py index cd93a04..75ac6cd 100644 --- a/src/springpython/jms/factory.py +++ b/src/springpython/jms/factory.py @@ -37,9 +37,13 @@ from springpython.context import DisposableObject from springpython.jms.core import reserved_attributes, TextMessage from springpython.util import TRACE1, synchronized -from springpython.jms import JMSException, WebSphereMQJMSException, \ - NoMessageAvailableException, DELIVERY_MODE_NON_PERSISTENT, \ - DELIVERY_MODE_PERSISTENT +from springpython.jms import ( + JMSException, + WebSphereMQJMSException, + NoMessageAvailableException, + DELIVERY_MODE_NON_PERSISTENT, + DELIVERY_MODE_PERSISTENT, +) # Don't pollute the caller's namespace @@ -79,27 +83,37 @@ _msd.text = "jms_text" _msgbody = etree.Element("msgbody") -_msgbody.set("xmlns:xsi", "dummy") # We're using a dummy namespace +_msgbody.set("xmlns:xsi", "dummy") # We're using a dummy namespace _msgbody.set("xsi:nil", "true") _mcd.append(_msgbody) # Clean up namespace. -del(_msd, _msgbody) +del (_msd, _msgbody) def unhexlify_wmq_id(wmq_id): - """ Converts the WebSphere MQ generated identifier back to bytes, + """Converts the WebSphere MQ generated identifier back to bytes, i.e. "ID:414d5120535052494e47505954484f4ecc90674a041f0020" -> "AMQ SPRINGPYTHON\xcc\x90gJ\x04\x1f\x00 ". """ return unhexlify(wmq_id.replace(_WMQ_ID_PREFIX, "", 1)) class WebSphereMQConnectionFactory(DisposableObject): - - def __init__(self, queue_manager=None, channel=None, host=None, listener_port=None, - cache_open_send_queues=True, cache_open_receive_queues=True, - use_shared_connections=True, dynamic_queue_template="SYSTEM.DEFAULT.MODEL.QUEUE", - ssl=False, ssl_cipher_spec=None, ssl_key_repository=None, needs_mcd=True): + def __init__( + self, + queue_manager=None, + channel=None, + host=None, + listener_port=None, + cache_open_send_queues=True, + cache_open_receive_queues=True, + use_shared_connections=True, + dynamic_queue_template="SYSTEM.DEFAULT.MODEL.QUEUE", + ssl=False, + ssl_cipher_spec=None, + ssl_key_repository=None, + needs_mcd=True, + ): self.queue_manager = queue_manager self.channel = channel self.host = host @@ -112,11 +126,13 @@ def __init__(self, queue_manager=None, channel=None, host=None, listener_port=No self.ssl = ssl self.ssl_cipher_spec = ssl_cipher_spec self.ssl_key_repository = ssl_key_repository - + # WMQ >= 7.0 must not use the mcd folder self.needs_mcd = needs_mcd - self.logger = logging.getLogger("springpython.jms.factory.WebSphereMQConnectionFactory") + self.logger = logging.getLogger( + "springpython.jms.factory.WebSphereMQConnectionFactory" + ) import CMQC import pymqi @@ -146,19 +162,27 @@ def destroy(self): self._open_receive_queues_cache.clear() self._open_dynamic_queues_cache.clear() self.logger.info("Caches cleared") - except Exception, e: + except Exception as e: try: - self.logger.error("Could not clear the caches. Exception [%s]" % format_exc()) + self.logger.error( + "Could not clear the caches. Exception [%s]" % format_exc() + ) except: pass try: - self.logger.info("Disconnecting from queue manager [%s]" % self.queue_manager) + self.logger.info( + "Disconnecting from queue manager [%s]" % self.queue_manager + ) self.mgr.disconnect() - self.logger.info("Disconnected from queue manager [%s]" % self.queue_manager) - except Exception, e: + self.logger.info( + "Disconnected from queue manager [%s]" % self.queue_manager + ) + except Exception as e: try: - self.logger.error("Could not disconnect from queue manager [%s], exception [%s] " % (self.queue_manager, - format_exc())) + self.logger.error( + "Could not disconnect from queue manager [%s], exception [%s] " + % (self.queue_manager, format_exc()) + ) except Exception: pass @@ -169,7 +193,11 @@ def destroy(self): def get_connection_info(self): return "queue manager=[%s], channel=[%s], conn_name=[%s(%s)]" % ( - self.queue_manager, self.channel, self.host, self.listener_port) + self.queue_manager, + self.channel, + self.host, + self.listener_port, + ) @synchronized() def _connect(self): @@ -178,8 +206,10 @@ def _connect(self): conn_name = "%s(%s)" % (self.host, self.listener_port) - self.logger.info("Connecting to queue manager [%s], channel [%s]" \ - ", connection info [%s]" % (self.queue_manager, self.channel, conn_name)) + self.logger.info( + "Connecting to queue manager [%s], channel [%s]" + ", connection info [%s]" % (self.queue_manager, self.channel, conn_name) + ) self.mgr = self.mq.QueueManager(None) sco = self.mq.sco() @@ -190,7 +220,7 @@ def _connect(self): cd.TransportType = self.CMQC.MQXPT_TCP if self.ssl: - if not(self.ssl_cipher_spec and self.ssl_key_repository): + if not (self.ssl_cipher_spec and self.ssl_key_repository): msg = "SSL support requires setting both ssl_cipher_spec and ssl_key_repository" self.logger.error(msg) raise JMSException(msg) @@ -204,18 +234,23 @@ def _connect(self): connect_options = self.CMQC.MQCNO_HANDLE_SHARE_NONE try: - self.mgr.connectWithOptions(self.queue_manager, cd=cd, opts=connect_options, sco=sco) - except self.mq.MQMIError, e: + self.mgr.connectWithOptions( + self.queue_manager, cd=cd, opts=connect_options, sco=sco + ) + except self.mq.MQMIError as e: exc = WebSphereMQJMSException(e, e.comp, e.reason) raise exc - except Exception, e: + except Exception as e: self.logger.error("Could not connect to queue manager, e=[%s]" % e) exc = WebSphereMQJMSException(e, None, None) raise exc else: self._is_connected = True - self.logger.info("Successfully connected to queue manager [%s]" \ - ", channel [%s], connection info [%s]" % (self.queue_manager, self.channel, conn_name)) + self.logger.info( + "Successfully connected to queue manager [%s]" + ", channel [%s], connection info [%s]" + % (self.queue_manager, self.channel, conn_name) + ) def _get_queue_from_cache(self, destination, cache): lock = RLock() @@ -226,7 +261,11 @@ def _get_queue_from_cache(self, destination, cache): return cache[destination] else: self.logger.debug("Adding queue [%s] to the cache" % destination) - cache[destination] = self.mq.Queue(self.mgr, destination, self.CMQC.MQOO_INPUT_SHARED | self.CMQC.MQOO_OUTPUT) + cache[destination] = self.mq.Queue( + self.mgr, + destination, + self.CMQC.MQOO_INPUT_SHARED | self.CMQC.MQOO_OUTPUT, + ) self.logger.debug("Queue [%s] added to the cache" % destination) self.logger.log(TRACE1, "Cache contents [%s]" % cache) return cache[destination] @@ -235,7 +274,9 @@ def _get_queue_from_cache(self, destination, cache): def get_queue_for_sending(self, destination): if self.cache_open_send_queues: - queue = self._get_queue_from_cache(destination, self._open_send_queues_cache) + queue = self._get_queue_from_cache( + destination, self._open_send_queues_cache + ) else: queue = self.mq.Queue(self.mgr, destination) @@ -243,13 +284,14 @@ def get_queue_for_sending(self, destination): def get_queue_for_receiving(self, destination): if self.cache_open_receive_queues: - queue = self._get_queue_from_cache(destination, self._open_receive_queues_cache) + queue = self._get_queue_from_cache( + destination, self._open_receive_queues_cache + ) else: queue = self.mq.Queue(self.mgr, destination) return queue - def send(self, message, destination): if self._disconnecting: self.logger.info("Connection factory disconnecting, aborting receive") @@ -272,7 +314,9 @@ def send(self, message, destination): # Create MQRFH2 header now = long(time() * 1000) - mqrfh2jms = MQRFH2JMS(self.needs_mcd).build_header(message, destination, self.CMQC, now) + mqrfh2jms = MQRFH2JMS(self.needs_mcd).build_header( + message, destination, self.CMQC, now + ) buff.write(mqrfh2jms) if message.text != None: @@ -285,9 +329,11 @@ def send(self, message, destination): try: queue.put(body, md) - except self.mq.MQMIError, e: - self.logger.error("MQMIError in queue.put, e.comp [%s], e.reason [%s] " % ( - e.comp, e.reason)) + except self.mq.MQMIError as e: + self.logger.error( + "MQMIError in queue.put, e.comp [%s], e.reason [%s] " + % (e.comp, e.reason) + ) exc = WebSphereMQJMSException(e, e.comp, e.reason) raise exc @@ -302,20 +348,28 @@ def send(self, message, destination): message.JMSXAppID = md.PutApplName if md.PutDate and md.PutTime: - message.jms_timestamp = self._get_jms_timestamp_from_md(md.PutDate.strip(), md.PutTime.strip()) + message.jms_timestamp = self._get_jms_timestamp_from_md( + md.PutDate.strip(), md.PutTime.strip() + ) message.JMS_IBM_PutDate = md.PutDate.strip() message.JMS_IBM_PutTime = md.PutTime.strip() else: - self.logger.warning("No md.PutDate and md.PutTime found, md [%r]" % repr(md)) + self.logger.warning( + "No md.PutDate and md.PutTime found, md [%r]" % repr(md) + ) # queue.put has succeeded, so overwrite expiration time as well if message.jms_expiration != None: message.jms_expiration += now - self.logger.debug("Successfully sent a message [%s], connection info [%s]" % ( - message, self.get_connection_info())) + self.logger.debug( + "Successfully sent a message [%s], connection info [%s]" + % (message, self.get_connection_info()) + ) - self.logger.log(TRACE1, "message [%s], body [%r], md [%r]" % (message, body, repr(md))) + self.logger.log( + TRACE1, "message [%s], body [%r], md [%r]" % (message, body, repr(md)) + ) def receive(self, destination, wait_interval): if self._disconnecting: @@ -324,7 +378,6 @@ def receive(self, destination, wait_interval): else: self.logger.log(TRACE1, "receive -> not disconnecting") - if not self._is_connected: self.logger.log(TRACE1, "receive -> _is_connected1 %s" % self._is_connected) self._connect() @@ -345,31 +398,43 @@ def receive(self, destination, wait_interval): return self._build_text_message(md, message) - except self.mq.MQMIError, e: + except self.mq.MQMIError as e: if e.reason == self.CMQC.MQRC_NO_MSG_AVAILABLE: - text = "No message available for destination [%s], " \ + text = ( + "No message available for destination [%s], " "wait_interval [%s] ms" % (destination, wait_interval) + ) raise NoMessageAvailableException(text) else: - self.logger.log(TRACE1, "Exception caught in get, e.comp=[%s], e.reason=[%s]" % (e.comp, e.reason)) + self.logger.log( + TRACE1, + "Exception caught in get, e.comp=[%s], e.reason=[%s]" + % (e.comp, e.reason), + ) exc = WebSphereMQJMSException(e, e.comp, e.reason) raise exc - def open_dynamic_queue(self): if self._disconnecting: - self.logger.info("Connection factory disconnecting, aborting open_dynamic_queue") + self.logger.info( + "Connection factory disconnecting, aborting open_dynamic_queue" + ) return else: self.logger.log(TRACE1, "open_dynamic_queue -> not disconnecting") if not self._is_connected: - self.logger.log(TRACE1, "open_dynamic_queue -> _is_connected1 %s" % self._is_connected) + self.logger.log( + TRACE1, "open_dynamic_queue -> _is_connected1 %s" % self._is_connected + ) self._connect() - self.logger.log(TRACE1, "open_dynamic_queue -> _is_connected2 %s" % self._is_connected) + self.logger.log( + TRACE1, "open_dynamic_queue -> _is_connected2 %s" % self._is_connected + ) - dynamic_queue = self.mq.Queue(self.mgr, self.dynamic_queue_template, - self.CMQC.MQOO_INPUT_SHARED) + dynamic_queue = self.mq.Queue( + self.mgr, self.dynamic_queue_template, self.CMQC.MQOO_INPUT_SHARED + ) # A bit hackish, but there's no other way to get its name. dynamic_queue_name = dynamic_queue._Queue__qDesc.ObjectName.strip() @@ -381,24 +446,33 @@ def open_dynamic_queue(self): finally: lock.release() - self.logger.log(TRACE1, "Successfully created a dynamic queue, descriptor [%s]" % ( - dynamic_queue._Queue__qDesc)) + self.logger.log( + TRACE1, + "Successfully created a dynamic queue, descriptor [%s]" + % (dynamic_queue._Queue__qDesc), + ) return dynamic_queue_name def close_dynamic_queue(self, dynamic_queue_name): if self._disconnecting: - self.logger.info("Connection factory disconnecting, aborting close_dynamic_queue") + self.logger.info( + "Connection factory disconnecting, aborting close_dynamic_queue" + ) return else: self.logger.log(TRACE1, "close_dynamic_queue -> not disconnecting") if not self._is_connected: # If we're not connected then all dynamic queues had been already closed. - self.logger.log(TRACE1, "close_dynamic_queue -> _is_connected1 %s" % self._is_connected) + self.logger.log( + TRACE1, "close_dynamic_queue -> _is_connected1 %s" % self._is_connected + ) return else: - self.logger.log(TRACE1, "close_dynamic_queue -> _is_connected2 %s" % self._is_connected) + self.logger.log( + TRACE1, "close_dynamic_queue -> _is_connected2 %s" % self._is_connected + ) lock = RLock() lock.acquire() try: @@ -409,8 +483,10 @@ def close_dynamic_queue(self, dynamic_queue_name): self._open_send_queues_cache.pop(dynamic_queue_name, None) self._open_receive_queues_cache.pop(dynamic_queue_name, None) - self.logger.log(TRACE1, "Successfully closed a dynamic queue [%s]" % ( - dynamic_queue_name)) + self.logger.log( + TRACE1, + "Successfully closed a dynamic queue [%s]" % (dynamic_queue_name), + ) finally: lock.release() @@ -424,9 +500,10 @@ def _get_jms_timestamp_from_md(self, put_date, put_time): return long((mk - altzone + centi) * 1000.0) - def _build_text_message(self, md, message): - self.logger.log(TRACE1, "Building a text message [%r], md [%r]" % (repr(message), repr(md))) + self.logger.log( + TRACE1, "Building a text message [%r], md [%r]" % (repr(message), repr(md)) + ) mqrfh2 = MQRFH2JMS(self.needs_mcd) mqrfh2.build_folders_and_payload_from_message(message) @@ -451,28 +528,37 @@ def _build_text_message(self, md, message): if jms_folder.find("Exp") is not None: text_message.jms_expiration = long(jms_folder.find("Exp").text) else: - text_message.jms_expiration = 0 # Same as in Java + text_message.jms_expiration = 0 # Same as in Java if jms_folder.find("Cid") is not None: text_message.jms_correlation_id = jms_folder.find("Cid").text if md.Persistence == self.CMQC.MQPER_NOT_PERSISTENT: text_message.jms_delivery_mode = DELIVERY_MODE_NON_PERSISTENT - elif md.Persistence in(self.CMQC.MQPER_PERSISTENT, self.CMQC.MQPER_PERSISTENCE_AS_Q_DEF): + elif md.Persistence in ( + self.CMQC.MQPER_PERSISTENT, + self.CMQC.MQPER_PERSISTENCE_AS_Q_DEF, + ): text_message.jms_delivery_mode = DELIVERY_MODE_PERSISTENT else: - text = "Don't know how to handle md.Persistence mode [%s]" % (md.Persistence) + text = "Don't know how to handle md.Persistence mode [%s]" % ( + md.Persistence + ) self.logger.error(text) exc = WebSphereMQJMSException(text) raise exc if md.ReplyToQ.strip(): self.logger.log(TRACE1, "Found md.ReplyToQ=[%r]" % md.ReplyToQ) - text_message.jms_reply_to = "queue://" + md.ReplyToQMgr.strip() + "/" + md.ReplyToQ.strip() + text_message.jms_reply_to = ( + "queue://" + md.ReplyToQMgr.strip() + "/" + md.ReplyToQ.strip() + ) text_message.jms_priority = md.Priority text_message.jms_message_id = _WMQ_ID_PREFIX + hexlify(md.MsgId) - text_message.jms_timestamp = self._get_jms_timestamp_from_md(md.PutDate.strip(), md.PutTime.strip()) + text_message.jms_timestamp = self._get_jms_timestamp_from_md( + md.PutDate.strip(), md.PutTime.strip() + ) text_message.jms_redelivered = bool(int(md.BackoutCount)) text_message.JMSXUserID = md.UserIdentifier.strip() @@ -493,7 +579,7 @@ def _build_text_message(self, md, message): self.CMQC.MQRO_DISCARD_MSG: "Discard_Msg", } - for report_name, jms_header_name in md_report_to_jms.iteritems(): + for report_name, jms_header_name in md_report_to_jms.items(): report_value = md.Report & report_name if report_value: header_value = report_value @@ -552,7 +638,10 @@ def _build_md(self, message): elif message.jms_delivery_mode == DELIVERY_MODE_PERSISTENT: persistence = self.CMQC.MQPER_PERSISTENT else: - info = "jms_delivery_mode should be equal to DELIVERY_MODE_NON_PERSISTENT or DELIVERY_MODE_PERSISTENT, not [%s]" % message.jms_delivery_mode + info = ( + "jms_delivery_mode should be equal to DELIVERY_MODE_NON_PERSISTENT or DELIVERY_MODE_PERSISTENT, not [%s]" + % message.jms_delivery_mode + ) self.logger.error(info) exc = JMSException(info) raise exc @@ -565,8 +654,13 @@ def _build_md(self, message): if message.jms_reply_to: md.ReplyToQ = message.jms_reply_to - self.logger.log(TRACE1, ("Set jms_reply_to. md.ReplyToQ=[%r]," - " message.jms_reply_to=[%r]" % (md.ReplyToQ, message.jms_reply_to))) + self.logger.log( + TRACE1, + ( + "Set jms_reply_to. md.ReplyToQ=[%r]," + " message.jms_reply_to=[%r]" % (md.ReplyToQ, message.jms_reply_to) + ), + ) # jms_expiration is in milliseconds, md.Expiry is in centiseconds. if message.jms_expiration: @@ -590,8 +684,17 @@ def _build_md(self, message): md.GroupId = jmsxgroupid.ljust(24)[:24] md.MsgFlags |= self.CMQC.MQMF_MSG_IN_GROUP - for report_name in("Exception", "Expiration", "COA", "COD", "PAN", - "NAN", "Pass_Msg_ID", "Pass_Correl_ID", "Discard_Msg"): + for report_name in ( + "Exception", + "Expiration", + "COA", + "COD", + "PAN", + "NAN", + "Pass_Msg_ID", + "Pass_Correl_ID", + "Discard_Msg", + ): report = getattr(message, "JMS_IBM_Report_" + report_name, None) if report != None: @@ -612,8 +715,9 @@ def _build_md(self, message): return md + class MQRFH2JMS(object): - """ A class for representing a subset of MQRFH2, suitable for passing + """A class for representing a subset of MQRFH2, suitable for passing WebSphere MQ JMS headers around. """ @@ -634,32 +738,34 @@ class MQRFH2JMS(object): FOLDER_SIZE_HEADER_LENGTH = 4 def __init__(self, needs_mcd=True): - + # Whether to add the mcd folder. Needs to be False for everything to # work properly with WMQ >= 7.0 self.needs_mcd = needs_mcd - + self.folders = {} self.payload = None self.logger = logging.getLogger("springpython.jms.factory.MQRFH2JMS") def _pad_folder(self, folder): - """ Pads the folder to a multiple of 4, as required by WebSphere MQ. - """ + """Pads the folder to a multiple of 4, as required by WebSphere MQ.""" folder_len = len(folder) if folder_len % MQRFH2JMS.FOLDER_LENGTH_MULTIPLE == 0: return folder else: - padding = MQRFH2JMS.FOLDER_LENGTH_MULTIPLE - folder_len % MQRFH2JMS.FOLDER_LENGTH_MULTIPLE + padding = ( + MQRFH2JMS.FOLDER_LENGTH_MULTIPLE + - folder_len % MQRFH2JMS.FOLDER_LENGTH_MULTIPLE + ) return folder.ljust(folder_len + padding) def build_folders_and_payload_from_message(self, message): total_mqrfh2_length = unpack("!l", message[8:12])[0] - mqrfh2 = message[MQRFH2JMS.FIXED_PART_LENGTH:total_mqrfh2_length] - self.payload = message[MQRFH2JMS.FIXED_PART_LENGTH + len(mqrfh2):] + mqrfh2 = message[MQRFH2JMS.FIXED_PART_LENGTH : total_mqrfh2_length] + self.payload = message[MQRFH2JMS.FIXED_PART_LENGTH + len(mqrfh2) :] self.logger.log(TRACE1, "message [%r]" % message) self.logger.log(TRACE1, "mqrfh2 [%r]" % mqrfh2) @@ -668,12 +774,15 @@ def build_folders_and_payload_from_message(self, message): left = mqrfh2 while left: current_folder_length = unpack("!l", left[:4])[0] - raw_folder = left[MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH:MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH + current_folder_length] + raw_folder = left[ + MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH : MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH + + current_folder_length + ] self.logger.log(TRACE1, "raw_folder [%r]" % raw_folder) self.build_folder(raw_folder) - left = left[MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH + current_folder_length:] + left = left[MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH + current_folder_length :] def build_folder(self, raw_folder): @@ -683,14 +792,18 @@ def build_folder(self, raw_folder): # of any other way to work around it if we'd like to treat folders as # XML(-like) structures. - if 'xsi:nil="true"' in raw_folder and not 'xmlns' in raw_folder: - self.logger.log(TRACE1, "Binding xsi:nil to a dummy namespace [%s]" % raw_folder) - raw_folder = raw_folder.replace('xsi:nil="true"', 'xmlns:xsi="dummy" xsi:nil="true"') + if 'xsi:nil="true"' in raw_folder and not "xmlns" in raw_folder: + self.logger.log( + TRACE1, "Binding xsi:nil to a dummy namespace [%s]" % raw_folder + ) + raw_folder = raw_folder.replace( + 'xsi:nil="true"', 'xmlns:xsi="dummy" xsi:nil="true"' + ) self.logger.log(TRACE1, "raw_folder after binding [%s]" % raw_folder) folder = etree.fromstring(raw_folder) root_name = folder.tag - + root_names = ["jms", "usr"] if self.needs_mcd: root_names.append("mcd") @@ -698,18 +811,19 @@ def build_folder(self, raw_folder): if root_name in root_names: self.folders[root_name] = folder else: - self.logger.warn("Ignoring unrecognized JMS folder [%s]=[%s]" % (root_name, raw_folder)) - + self.logger.warn( + "Ignoring unrecognized JMS folder [%s]=[%s]" % (root_name, raw_folder) + ) def build_header(self, message, queue_name, CMQC, now): - + if self.needs_mcd: self.folders["mcd"] = _mcd mcd = self._pad_folder(etree.tostring(self.folders["mcd"])) mcd_len = len(mcd) else: mcd_len = 0 - + self.add_jms(message, queue_name, now) self.add_usr(message) @@ -742,11 +856,11 @@ def build_header(self, message, queue_name, CMQC, now): buff.write(CMQC.MQFMT_STRING) buff.write(_WMQ_MQRFH_NO_FLAGS_WIRE_FORMAT) buff.write(_WMQ_DEFAULT_CCSID_WIRE_FORMAT) - + if self.needs_mcd: buff.write(pack("!l", mcd_len)) buff.write(mcd) - + buff.write(pack("!l", jms_len)) buff.write(jms) @@ -771,7 +885,7 @@ def add_jms(self, message, queue_name, now): jms.append(dlv) tms.text = unicode(now) - dst.text = u"queue:///" + queue_name + dst.text = "queue:///" + queue_name dlv.text = unicode(message.jms_delivery_mode) if message.jms_expiration: diff --git a/src/springpython/jms/factory.py.bak b/src/springpython/jms/factory.py.bak new file mode 100644 index 0000000..cd93a04 --- /dev/null +++ b/src/springpython/jms/factory.py.bak @@ -0,0 +1,817 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# stdlib +import sys +import logging +from threading import RLock +from cStringIO import StringIO +from struct import pack, unpack +from xml.sax.saxutils import escape +from binascii import hexlify, unhexlify +from time import time, mktime, strptime, altzone +from traceback import format_exc + +try: + import cElementTree as etree +except ImportError: + try: + import xml.etree.ElementTree as etree + except ImportError: + from elementtree import ElementTree as etree + +# Spring Python +from springpython.context import DisposableObject +from springpython.jms.core import reserved_attributes, TextMessage +from springpython.util import TRACE1, synchronized +from springpython.jms import JMSException, WebSphereMQJMSException, \ + NoMessageAvailableException, DELIVERY_MODE_NON_PERSISTENT, \ + DELIVERY_MODE_PERSISTENT + + +# Don't pollute the caller's namespace +__all__ = ["WebSphereMQConnectionFactory"] + + +# Internal constants, don't touch. + +# Some WMQ constants are not exposed by pymqi. +_WMQ_MQRFH_VERSION_2 = "\x00\x00\x00\x02" +_WMQ_DEFAULT_ENCODING = 273 +_WMQ_DEFAULT_ENCODING_WIRE_FORMAT = pack("!l", _WMQ_DEFAULT_ENCODING) + +# 1208 = UTF-8 +_WMQ_DEFAULT_CCSID = 1208 +_WMQ_DEFAULT_CCSID_WIRE_FORMAT = pack("!l", _WMQ_DEFAULT_CCSID) + +# From cmqc.h +_WMQ_MQFMT_RF_HEADER_2 = "MQHRF2 " + +# MQRFH_NO_FLAGS_WIRE is in cmqc.h +_WMQ_MQRFH_NO_FLAGS_WIRE_FORMAT = "\x00\x00\x00\x00" + +# Java documentation says "214748364.7 seconds". +_WMQ_MAX_EXPIRY_TIME = 214748364.7 + +_WMQ_ID_PREFIX = "ID:" + +# In current implementation, an mcd JMS folder is constant for every message +# sent, so let's build it here. + +_mcd = etree.Element("mcd") +_msd = etree.Element("Msd") +_mcd.append(_msd) + +# For now, it's always a TextMessage +_msd.text = "jms_text" + +_msgbody = etree.Element("msgbody") +_msgbody.set("xmlns:xsi", "dummy") # We're using a dummy namespace +_msgbody.set("xsi:nil", "true") +_mcd.append(_msgbody) + +# Clean up namespace. +del(_msd, _msgbody) + + +def unhexlify_wmq_id(wmq_id): + """ Converts the WebSphere MQ generated identifier back to bytes, + i.e. "ID:414d5120535052494e47505954484f4ecc90674a041f0020" -> "AMQ SPRINGPYTHON\xcc\x90gJ\x04\x1f\x00 ". + """ + return unhexlify(wmq_id.replace(_WMQ_ID_PREFIX, "", 1)) + + +class WebSphereMQConnectionFactory(DisposableObject): + + def __init__(self, queue_manager=None, channel=None, host=None, listener_port=None, + cache_open_send_queues=True, cache_open_receive_queues=True, + use_shared_connections=True, dynamic_queue_template="SYSTEM.DEFAULT.MODEL.QUEUE", + ssl=False, ssl_cipher_spec=None, ssl_key_repository=None, needs_mcd=True): + self.queue_manager = queue_manager + self.channel = channel + self.host = host + self.listener_port = listener_port + + self.use_shared_connections = use_shared_connections + self.dynamic_queue_template = dynamic_queue_template + + # SSL support + self.ssl = ssl + self.ssl_cipher_spec = ssl_cipher_spec + self.ssl_key_repository = ssl_key_repository + + # WMQ >= 7.0 must not use the mcd folder + self.needs_mcd = needs_mcd + + self.logger = logging.getLogger("springpython.jms.factory.WebSphereMQConnectionFactory") + + import CMQC + import pymqi + + self.CMQC = CMQC + self.mq = pymqi + + self._open_send_queues_cache = {} + self._open_receive_queues_cache = {} + self._open_dynamic_queues_cache = {} + + self.cache_open_send_queues = cache_open_send_queues + self.cache_open_receive_queues = cache_open_receive_queues + + self._is_connected = False + self._disconnecting = False + + self.logger.log(TRACE1, "Finished __init__") + + @synchronized() + def destroy(self): + if self._is_connected: + self._disconnecting = True + try: + self.logger.info("Deleting queues from caches") + self._open_send_queues_cache.clear() + self._open_receive_queues_cache.clear() + self._open_dynamic_queues_cache.clear() + self.logger.info("Caches cleared") + except Exception, e: + try: + self.logger.error("Could not clear the caches. Exception [%s]" % format_exc()) + except: + pass + try: + self.logger.info("Disconnecting from queue manager [%s]" % self.queue_manager) + self.mgr.disconnect() + self.logger.info("Disconnected from queue manager [%s]" % self.queue_manager) + except Exception, e: + try: + self.logger.error("Could not disconnect from queue manager [%s], exception [%s] " % (self.queue_manager, + format_exc())) + except Exception: + pass + + self._is_connected = False + + else: + self.logger.debug("Not connected, skipping cleaning up the resources") + + def get_connection_info(self): + return "queue manager=[%s], channel=[%s], conn_name=[%s(%s)]" % ( + self.queue_manager, self.channel, self.host, self.listener_port) + + @synchronized() + def _connect(self): + if self._is_connected: + return + + conn_name = "%s(%s)" % (self.host, self.listener_port) + + self.logger.info("Connecting to queue manager [%s], channel [%s]" \ + ", connection info [%s]" % (self.queue_manager, self.channel, conn_name)) + self.mgr = self.mq.QueueManager(None) + + sco = self.mq.sco() + cd = self.mq.cd() + cd.ChannelName = self.channel + cd.ConnectionName = conn_name + cd.ChannelType = self.CMQC.MQCHT_CLNTCONN + cd.TransportType = self.CMQC.MQXPT_TCP + + if self.ssl: + if not(self.ssl_cipher_spec and self.ssl_key_repository): + msg = "SSL support requires setting both ssl_cipher_spec and ssl_key_repository" + self.logger.error(msg) + raise JMSException(msg) + + sco.KeyRepository = self.ssl_key_repository + cd.SSLCipherSpec = self.ssl_cipher_spec + + if self.use_shared_connections: + connect_options = self.CMQC.MQCNO_HANDLE_SHARE_BLOCK + else: + connect_options = self.CMQC.MQCNO_HANDLE_SHARE_NONE + + try: + self.mgr.connectWithOptions(self.queue_manager, cd=cd, opts=connect_options, sco=sco) + except self.mq.MQMIError, e: + exc = WebSphereMQJMSException(e, e.comp, e.reason) + raise exc + except Exception, e: + self.logger.error("Could not connect to queue manager, e=[%s]" % e) + exc = WebSphereMQJMSException(e, None, None) + raise exc + else: + self._is_connected = True + self.logger.info("Successfully connected to queue manager [%s]" \ + ", channel [%s], connection info [%s]" % (self.queue_manager, self.channel, conn_name)) + + def _get_queue_from_cache(self, destination, cache): + lock = RLock() + lock.acquire() + try: + # Will usually choose this path and find the queue here. + if destination in cache: + return cache[destination] + else: + self.logger.debug("Adding queue [%s] to the cache" % destination) + cache[destination] = self.mq.Queue(self.mgr, destination, self.CMQC.MQOO_INPUT_SHARED | self.CMQC.MQOO_OUTPUT) + self.logger.debug("Queue [%s] added to the cache" % destination) + self.logger.log(TRACE1, "Cache contents [%s]" % cache) + return cache[destination] + finally: + lock.release() + + def get_queue_for_sending(self, destination): + if self.cache_open_send_queues: + queue = self._get_queue_from_cache(destination, self._open_send_queues_cache) + else: + queue = self.mq.Queue(self.mgr, destination) + + return queue + + def get_queue_for_receiving(self, destination): + if self.cache_open_receive_queues: + queue = self._get_queue_from_cache(destination, self._open_receive_queues_cache) + else: + queue = self.mq.Queue(self.mgr, destination) + + return queue + + + def send(self, message, destination): + if self._disconnecting: + self.logger.info("Connection factory disconnecting, aborting receive") + return + else: + self.logger.log(TRACE1, "send -> not disconnecting") + + if not self._is_connected: + self.logger.log(TRACE1, "send -> _is_connected1 %s" % self._is_connected) + self._connect() + self.logger.log(TRACE1, "send -> _is_connected2 %s" % self._is_connected) + + destination = self._strip_prefixes_from_destination(destination) + + # Will consist of an MQRFH2 header and the actual business payload. + buff = StringIO() + + # Build the message descriptor (MQMD) + md = self._build_md(message) + + # Create MQRFH2 header + now = long(time() * 1000) + mqrfh2jms = MQRFH2JMS(self.needs_mcd).build_header(message, destination, self.CMQC, now) + + buff.write(mqrfh2jms) + if message.text != None: + buff.write(message.text.encode("utf-8")) + + body = buff.getvalue() + buff.close() + + queue = self.get_queue_for_sending(destination) + + try: + queue.put(body, md) + except self.mq.MQMIError, e: + self.logger.error("MQMIError in queue.put, e.comp [%s], e.reason [%s] " % ( + e.comp, e.reason)) + exc = WebSphereMQJMSException(e, e.comp, e.reason) + raise exc + + if not self.cache_open_send_queues: + queue.close() + + # Map the JMS headers overwritten by calling queue.put + message.jms_message_id = _WMQ_ID_PREFIX + hexlify(md.MsgId) + message.jms_priority = md.Priority + message.jms_correlation_id = _WMQ_ID_PREFIX + hexlify(md.CorrelId) + message.JMSXUserID = md.UserIdentifier + message.JMSXAppID = md.PutApplName + + if md.PutDate and md.PutTime: + message.jms_timestamp = self._get_jms_timestamp_from_md(md.PutDate.strip(), md.PutTime.strip()) + message.JMS_IBM_PutDate = md.PutDate.strip() + message.JMS_IBM_PutTime = md.PutTime.strip() + else: + self.logger.warning("No md.PutDate and md.PutTime found, md [%r]" % repr(md)) + + # queue.put has succeeded, so overwrite expiration time as well + if message.jms_expiration != None: + message.jms_expiration += now + + self.logger.debug("Successfully sent a message [%s], connection info [%s]" % ( + message, self.get_connection_info())) + + self.logger.log(TRACE1, "message [%s], body [%r], md [%r]" % (message, body, repr(md))) + + def receive(self, destination, wait_interval): + if self._disconnecting: + self.logger.info("Connection factory disconnecting, aborting receive") + return + else: + self.logger.log(TRACE1, "receive -> not disconnecting") + + + if not self._is_connected: + self.logger.log(TRACE1, "receive -> _is_connected1 %s" % self._is_connected) + self._connect() + self.logger.log(TRACE1, "receive -> _is_connected2 %s" % self._is_connected) + + queue = self.get_queue_for_receiving(destination) + + try: + # Default message descriptor .. + md = self.mq.md() + + # .. and custom get message options + gmo = self.mq.gmo() + gmo.Options = self.CMQC.MQGMO_WAIT | self.CMQC.MQGMO_FAIL_IF_QUIESCING + gmo.WaitInterval = wait_interval + + message = queue.get(None, md, gmo) + + return self._build_text_message(md, message) + + except self.mq.MQMIError, e: + if e.reason == self.CMQC.MQRC_NO_MSG_AVAILABLE: + text = "No message available for destination [%s], " \ + "wait_interval [%s] ms" % (destination, wait_interval) + raise NoMessageAvailableException(text) + else: + self.logger.log(TRACE1, "Exception caught in get, e.comp=[%s], e.reason=[%s]" % (e.comp, e.reason)) + exc = WebSphereMQJMSException(e, e.comp, e.reason) + raise exc + + + def open_dynamic_queue(self): + if self._disconnecting: + self.logger.info("Connection factory disconnecting, aborting open_dynamic_queue") + return + else: + self.logger.log(TRACE1, "open_dynamic_queue -> not disconnecting") + + if not self._is_connected: + self.logger.log(TRACE1, "open_dynamic_queue -> _is_connected1 %s" % self._is_connected) + self._connect() + self.logger.log(TRACE1, "open_dynamic_queue -> _is_connected2 %s" % self._is_connected) + + dynamic_queue = self.mq.Queue(self.mgr, self.dynamic_queue_template, + self.CMQC.MQOO_INPUT_SHARED) + + # A bit hackish, but there's no other way to get its name. + dynamic_queue_name = dynamic_queue._Queue__qDesc.ObjectName.strip() + + lock = RLock() + lock.acquire() + try: + self._open_dynamic_queues_cache[dynamic_queue_name] = dynamic_queue + finally: + lock.release() + + self.logger.log(TRACE1, "Successfully created a dynamic queue, descriptor [%s]" % ( + dynamic_queue._Queue__qDesc)) + + return dynamic_queue_name + + def close_dynamic_queue(self, dynamic_queue_name): + if self._disconnecting: + self.logger.info("Connection factory disconnecting, aborting close_dynamic_queue") + return + else: + self.logger.log(TRACE1, "close_dynamic_queue -> not disconnecting") + + if not self._is_connected: + # If we're not connected then all dynamic queues had been already closed. + self.logger.log(TRACE1, "close_dynamic_queue -> _is_connected1 %s" % self._is_connected) + return + else: + self.logger.log(TRACE1, "close_dynamic_queue -> _is_connected2 %s" % self._is_connected) + lock = RLock() + lock.acquire() + try: + dynamic_queue = self._open_dynamic_queues_cache[dynamic_queue_name] + dynamic_queue.close() + + self._open_dynamic_queues_cache.pop(dynamic_queue_name, None) + self._open_send_queues_cache.pop(dynamic_queue_name, None) + self._open_receive_queues_cache.pop(dynamic_queue_name, None) + + self.logger.log(TRACE1, "Successfully closed a dynamic queue [%s]" % ( + dynamic_queue_name)) + + finally: + lock.release() + + def _get_jms_timestamp_from_md(self, put_date, put_time): + pattern = "%Y%m%d%H%M%S" + centi = int(put_time[6:]) / 100.0 + + strp = strptime(put_date + put_time[:6], pattern) + mk = mktime(strp) + + return long((mk - altzone + centi) * 1000.0) + + + def _build_text_message(self, md, message): + self.logger.log(TRACE1, "Building a text message [%r], md [%r]" % (repr(message), repr(md))) + + mqrfh2 = MQRFH2JMS(self.needs_mcd) + mqrfh2.build_folders_and_payload_from_message(message) + + jms_folder = mqrfh2.folders.get("jms", None) + mcd_folder = mqrfh2.folders.get("mcd", None) + usr_folder = mqrfh2.folders.get("usr", None) + + # Create a message instance .. + text_message = TextMessage() + + if usr_folder: + for attr_name, attr_value in usr_folder.items(): + setattr(text_message, attr_name, str(attr_value)) + + # .. set its JMS properties .. + + if jms_folder: + if jms_folder.find("Dst") is not None: + text_message.jms_destination = jms_folder.find("Dst").text.strip() + + if jms_folder.find("Exp") is not None: + text_message.jms_expiration = long(jms_folder.find("Exp").text) + else: + text_message.jms_expiration = 0 # Same as in Java + + if jms_folder.find("Cid") is not None: + text_message.jms_correlation_id = jms_folder.find("Cid").text + + if md.Persistence == self.CMQC.MQPER_NOT_PERSISTENT: + text_message.jms_delivery_mode = DELIVERY_MODE_NON_PERSISTENT + elif md.Persistence in(self.CMQC.MQPER_PERSISTENT, self.CMQC.MQPER_PERSISTENCE_AS_Q_DEF): + text_message.jms_delivery_mode = DELIVERY_MODE_PERSISTENT + else: + text = "Don't know how to handle md.Persistence mode [%s]" % (md.Persistence) + self.logger.error(text) + exc = WebSphereMQJMSException(text) + raise exc + + if md.ReplyToQ.strip(): + self.logger.log(TRACE1, "Found md.ReplyToQ=[%r]" % md.ReplyToQ) + text_message.jms_reply_to = "queue://" + md.ReplyToQMgr.strip() + "/" + md.ReplyToQ.strip() + + text_message.jms_priority = md.Priority + text_message.jms_message_id = _WMQ_ID_PREFIX + hexlify(md.MsgId) + text_message.jms_timestamp = self._get_jms_timestamp_from_md(md.PutDate.strip(), md.PutTime.strip()) + text_message.jms_redelivered = bool(int(md.BackoutCount)) + + text_message.JMSXUserID = md.UserIdentifier.strip() + text_message.JMSXAppID = md.PutApplName.strip() + text_message.JMSXDeliveryCount = md.BackoutCount + text_message.JMSXGroupID = md.GroupId.strip() + text_message.JMSXGroupSeq = md.MsgSeqNumber + + md_report_to_jms = { + self.CMQC.MQRO_EXCEPTION: "Exception", + self.CMQC.MQRO_EXPIRATION: "Expiration", + self.CMQC.MQRO_COA: "COA", + self.CMQC.MQRO_COD: "COD", + self.CMQC.MQRO_PAN: "PAN", + self.CMQC.MQRO_NAN: "NAN", + self.CMQC.MQRO_PASS_MSG_ID: "Pass_Msg_ID", + self.CMQC.MQRO_PASS_CORREL_ID: "Pass_Correl_ID", + self.CMQC.MQRO_DISCARD_MSG: "Discard_Msg", + } + + for report_name, jms_header_name in md_report_to_jms.iteritems(): + report_value = md.Report & report_name + if report_value: + header_value = report_value + else: + header_value = None + + setattr(text_message, "JMS_IBM_Report_" + jms_header_name, header_value) + + text_message.JMS_IBM_MsgType = md.MsgType + text_message.JMS_IBM_Feedback = md.Feedback + text_message.JMS_IBM_Format = md.Format.strip() + text_message.JMS_IBM_PutApplType = md.PutApplType + text_message.JMS_IBM_PutDate = md.PutDate.strip() + text_message.JMS_IBM_PutTime = md.PutTime.strip() + + if md.MsgFlags & self.CMQC.MQMF_LAST_MSG_IN_GROUP: + text_message.JMS_IBM_Last_Msg_In_Group = self.CMQC.MQMF_LAST_MSG_IN_GROUP + else: + text_message.JMS_IBM_Last_Msg_In_Group = None + + # .. and its payload too. + if mqrfh2.payload: + text_message.text = mqrfh2.payload + + return text_message + + def _strip_prefixes_from_destination(self, destination): + if destination.startswith("queue:///"): + return destination.replace("queue:///", "", 1) + elif destination.startswith("queue://"): + no_qm_dest = destination.replace("queue://", "", 1) + no_qm_dest = no_qm_dest.split("/")[1:] + return "/".join(no_qm_dest) + else: + return destination + + def _build_md(self, message): + md = self.mq.md() + + md.Format = _WMQ_MQFMT_RF_HEADER_2 + md.CodedCharSetId = _WMQ_DEFAULT_CCSID + md.Encoding = _WMQ_DEFAULT_ENCODING + + # Map JMS headers to MQMD + + if message.jms_correlation_id: + if message.jms_correlation_id.startswith(_WMQ_ID_PREFIX): + md.CorrelId = unhexlify_wmq_id(message.jms_correlation_id) + else: + md.CorrelId = message.jms_correlation_id.ljust(24)[:24] + + if message.jms_delivery_mode: + + if message.jms_delivery_mode == DELIVERY_MODE_NON_PERSISTENT: + persistence = self.CMQC.MQPER_NOT_PERSISTENT + elif message.jms_delivery_mode == DELIVERY_MODE_PERSISTENT: + persistence = self.CMQC.MQPER_PERSISTENT + else: + info = "jms_delivery_mode should be equal to DELIVERY_MODE_NON_PERSISTENT or DELIVERY_MODE_PERSISTENT, not [%s]" % message.jms_delivery_mode + self.logger.error(info) + exc = JMSException(info) + raise exc + + md.Persistence = persistence + + if message.jms_priority: + md.Priority = message.jms_priority + + if message.jms_reply_to: + md.ReplyToQ = message.jms_reply_to + + self.logger.log(TRACE1, ("Set jms_reply_to. md.ReplyToQ=[%r]," + " message.jms_reply_to=[%r]" % (md.ReplyToQ, message.jms_reply_to))) + + # jms_expiration is in milliseconds, md.Expiry is in centiseconds. + if message.jms_expiration: + if message.jms_expiration / 1000 > _WMQ_MAX_EXPIRY_TIME: + md.Expiry = self.CMQC.MQEI_UNLIMITED + else: + md.Expiry = message.jms_expiration / 10 + + # WebSphere MQ provider-specific JMS headers + + jmsxgroupseq = getattr(message, "JMSXGroupSeq", None) + if jmsxgroupseq != None: + md.MsgSeqNumber = jmsxgroupseq + md.MsgFlags |= self.CMQC.MQMF_MSG_IN_GROUP + + jmsxgroupid = getattr(message, "JMSXGroupID", None) + if jmsxgroupid != None: + if jmsxgroupid.startswith(_WMQ_ID_PREFIX): + md.GroupId = unhexlify_wmq_id(jmsxgroupid) + else: + md.GroupId = jmsxgroupid.ljust(24)[:24] + md.MsgFlags |= self.CMQC.MQMF_MSG_IN_GROUP + + for report_name in("Exception", "Expiration", "COA", "COD", "PAN", + "NAN", "Pass_Msg_ID", "Pass_Correl_ID", "Discard_Msg"): + + report = getattr(message, "JMS_IBM_Report_" + report_name, None) + if report != None: + md.Report |= report + + # Doesn't make much sense to map feedback options as we're stuffed into + # request messages (MQMT_REQUEST) not report messages (MQMT_REPORT) + # but different types of messages are still possible to implement in + # the future so let's leave it. + + jms_ibm_feedback = getattr(message, "JMS_IBM_Feedback", None) + if jms_ibm_feedback != None: + md.Feedback = jms_ibm_feedback + + jms_ibm_last_msg_in_group = getattr(message, "JMS_IBM_Last_Msg_In_Group", None) + if jms_ibm_last_msg_in_group != None: + md.MsgFlags |= self.CMQC.MQMF_LAST_MSG_IN_GROUP + + return md + +class MQRFH2JMS(object): + """ A class for representing a subset of MQRFH2, suitable for passing + WebSphere MQ JMS headers around. + """ + + # 4 bytes - MQRFH_STRUC_ID + # 4 bytes - _WMQ_MQRFH_VERSION_2 + # 4 bytes - the whole MQRFH2 header length + # 4 bytes - Encoding + # 4 bytes - CodedCharacterSetId + # 8 bytes - MQFMT_STRING + # 4 bytes - MQRFH_NO_FLAGS + # 4 bytes - NameValueCCSID + FIXED_PART_LENGTH = 36 + + # MQRFH2 folder length must be a multiple of 4. + FOLDER_LENGTH_MULTIPLE = 4 + + # Size of a folder header is always 4 bytes. + FOLDER_SIZE_HEADER_LENGTH = 4 + + def __init__(self, needs_mcd=True): + + # Whether to add the mcd folder. Needs to be False for everything to + # work properly with WMQ >= 7.0 + self.needs_mcd = needs_mcd + + self.folders = {} + self.payload = None + + self.logger = logging.getLogger("springpython.jms.factory.MQRFH2JMS") + + def _pad_folder(self, folder): + """ Pads the folder to a multiple of 4, as required by WebSphere MQ. + """ + folder_len = len(folder) + + if folder_len % MQRFH2JMS.FOLDER_LENGTH_MULTIPLE == 0: + return folder + else: + padding = MQRFH2JMS.FOLDER_LENGTH_MULTIPLE - folder_len % MQRFH2JMS.FOLDER_LENGTH_MULTIPLE + return folder.ljust(folder_len + padding) + + def build_folders_and_payload_from_message(self, message): + total_mqrfh2_length = unpack("!l", message[8:12])[0] + + mqrfh2 = message[MQRFH2JMS.FIXED_PART_LENGTH:total_mqrfh2_length] + self.payload = message[MQRFH2JMS.FIXED_PART_LENGTH + len(mqrfh2):] + + self.logger.log(TRACE1, "message [%r]" % message) + self.logger.log(TRACE1, "mqrfh2 [%r]" % mqrfh2) + self.logger.log(TRACE1, "self.payload [%r]" % self.payload) + + left = mqrfh2 + while left: + current_folder_length = unpack("!l", left[:4])[0] + raw_folder = left[MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH:MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH + current_folder_length] + + self.logger.log(TRACE1, "raw_folder [%r]" % raw_folder) + self.build_folder(raw_folder) + + left = left[MQRFH2JMS.FOLDER_SIZE_HEADER_LENGTH + current_folder_length:] + + def build_folder(self, raw_folder): + + # Java JMS sends folders with unbound prefixes, i.e. + # which is in no way a valid XML so we have to insert the prefix ourselves + # in order to avoid parser bailing out with an ExpatError. I can't think + # of any other way to work around it if we'd like to treat folders as + # XML(-like) structures. + + if 'xsi:nil="true"' in raw_folder and not 'xmlns' in raw_folder: + self.logger.log(TRACE1, "Binding xsi:nil to a dummy namespace [%s]" % raw_folder) + raw_folder = raw_folder.replace('xsi:nil="true"', 'xmlns:xsi="dummy" xsi:nil="true"') + self.logger.log(TRACE1, "raw_folder after binding [%s]" % raw_folder) + + folder = etree.fromstring(raw_folder) + root_name = folder.tag + + root_names = ["jms", "usr"] + if self.needs_mcd: + root_names.append("mcd") + + if root_name in root_names: + self.folders[root_name] = folder + else: + self.logger.warn("Ignoring unrecognized JMS folder [%s]=[%s]" % (root_name, raw_folder)) + + + def build_header(self, message, queue_name, CMQC, now): + + if self.needs_mcd: + self.folders["mcd"] = _mcd + mcd = self._pad_folder(etree.tostring(self.folders["mcd"])) + mcd_len = len(mcd) + else: + mcd_len = 0 + + self.add_jms(message, queue_name, now) + self.add_usr(message) + + jms = self._pad_folder(etree.tostring(self.folders["jms"])) + + if "usr" in self.folders: + usr = self._pad_folder(etree.tostring(self.folders["usr"])) + usr_len = len(usr) + else: + usr_len = 0 + + jms_len = len(jms) + + total_header_length = 0 + total_header_length += MQRFH2JMS.FIXED_PART_LENGTH + + # Each folder has a 4-byte header describing its length, + # hence the "len(self.folders) * 4" below. + variable_part_length = len(self.folders) * 4 + mcd_len + jms_len + usr_len + + total_header_length += variable_part_length + + buff = StringIO() + + buff.write(CMQC.MQRFH_STRUC_ID) + buff.write(_WMQ_MQRFH_VERSION_2) + buff.write(pack("!l", total_header_length)) + buff.write(_WMQ_DEFAULT_ENCODING_WIRE_FORMAT) + buff.write(_WMQ_DEFAULT_CCSID_WIRE_FORMAT) + buff.write(CMQC.MQFMT_STRING) + buff.write(_WMQ_MQRFH_NO_FLAGS_WIRE_FORMAT) + buff.write(_WMQ_DEFAULT_CCSID_WIRE_FORMAT) + + if self.needs_mcd: + buff.write(pack("!l", mcd_len)) + buff.write(mcd) + + buff.write(pack("!l", jms_len)) + buff.write(jms) + + if "usr" in self.folders: + buff.write(pack("!l", usr_len)) + buff.write(usr) + + value = buff.getvalue() + buff.close() + + return value + + def add_jms(self, message, queue_name, now): + + jms = etree.Element("jms") + dst = etree.Element("Dst") + tms = etree.Element("Tms") + dlv = etree.Element("Dlv") + + jms.append(dst) + jms.append(tms) + jms.append(dlv) + + tms.text = unicode(now) + dst.text = u"queue:///" + queue_name + dlv.text = unicode(message.jms_delivery_mode) + + if message.jms_expiration: + exp = etree.Element("Exp") + exp.text = unicode(now + message.jms_expiration) + self.logger.log(TRACE1, "jms.Exp [%r]" % exp.text) + jms.append(exp) + + if message.jms_priority: + pri = etree.Element("Pri") + pri.text = unicode(message.jms_priority) + self.logger.log(TRACE1, "jms.Pri [%r]" % pri.text) + jms.append(pri) + + if message.jms_correlation_id: + cid = etree.Element("Cid") + cid.text = unicode(message.jms_correlation_id) + self.logger.log(TRACE1, "jms.Cid [%r]" % cid.text) + jms.append(cid) + + self.folders["jms"] = jms + + def add_usr(self, message): + user_attrs = set(dir(message)) - reserved_attributes + self.logger.log(TRACE1, "user_attrs [%s]" % user_attrs) + + if user_attrs: + usr = etree.Element("usr") + + for user_attr in user_attrs: + + user_attr_value = getattr(message, user_attr) + + # Some values are integers, e.g. delivery_mode + if isinstance(user_attr_value, basestring): + user_attr_value = escape(user_attr_value) + + # Create a JMS attribute and set its value. + user_attr = etree.Element(unicode(user_attr)) + user_attr.text = unicode(user_attr_value) + usr.append(user_attr) + + self.folders["usr"] = usr diff --git a/src/springpython/jms/listener.py b/src/springpython/jms/listener.py index 2469cf5..34d5c81 100644 --- a/src/springpython/jms/listener.py +++ b/src/springpython/jms/listener.py @@ -53,13 +53,13 @@ def run(self, *ignored): try: self.handlers_pool.poll() - except NoResultsPending, e: + except NoResultsPending as e: pass - except NoMessageAvailableException, e: + except NoMessageAvailableException as e: self.logger.log(TRACE1, "Consumer did not receive a message. %s" % self._get_destination_info()) - except WebSphereMQJMSException, e: + except WebSphereMQJMSException as e: self.logger.error("%s in run, e.completion_code=[%s], " "e.reason_code=[%s]" % (e.__class__.__name__, e.completion_code, e.reason_code)) raise diff --git a/src/springpython/jms/listener.py.bak b/src/springpython/jms/listener.py.bak new file mode 100644 index 0000000..2469cf5 --- /dev/null +++ b/src/springpython/jms/listener.py.bak @@ -0,0 +1,117 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# stdlib +import logging + +# Circuits +from circuits import Component, Manager, Debugger + +# ThreadPool +from threadpool import ThreadPool, WorkRequest, NoResultsPending + +# Spring Python +from springpython.util import TRACE1 +from springpython.context import InitializingObject +from springpython.jms import WebSphereMQJMSException, NoMessageAvailableException + +class MessageHandler(object): + def handle(self, message): + raise NotImplementedError("Should be overridden by subclasses.") + +class WebSphereMQListener(Component): + """ A JMS listener for receiving the messages off WebSphere MQ queues. + """ + def __init__(self): + super(Component, self).__init__() + self.logger = logging.getLogger("springpython.jms.listener.WebSphereMQListener(%s)" % (hex(id(self)))) + + def _get_destination_info(self): + return "destination=[%s], %s" % (self.destination, self.factory.get_connection_info()) + + def run(self, *ignored): + while True: + try: + message = self.factory.receive(self.destination, self.wait_interval) + self.logger.log(TRACE1, "Message received [%s]" % str(message).decode("utf-8")) + + work_request = WorkRequest(self.handler.handle, [message]) + self.handlers_pool.putRequest(work_request) + + try: + self.handlers_pool.poll() + except NoResultsPending, e: + pass + + except NoMessageAvailableException, e: + self.logger.log(TRACE1, "Consumer did not receive a message. %s" % self._get_destination_info()) + + except WebSphereMQJMSException, e: + self.logger.error("%s in run, e.completion_code=[%s], " + "e.reason_code=[%s]" % (e.__class__.__name__, e.completion_code, e.reason_code)) + raise + +class SimpleMessageListenerContainer(InitializingObject): + """ A container for individual JMS listeners. + """ + + def __init__(self, factory=None, destination=None, handler=None, + concurrent_listeners=1, handlers_per_listener=2, wait_interval=1000): + """ factory - reference a to JMS connection factory + destination - name of a queue to get the messages off + handler - reference to an object which will be passed the incoming messages + concurrent_listeners - how many concurrent JMS listeners the container + will manage + handlers_per_listener - how many handler threads each listener will receive + wait_interval - time, in milliseconds, indicating how often each JMS + listener will check for new messages + """ + + self.factory = factory + self.destination = destination + self.handler = handler + self.concurrent_listeners = concurrent_listeners + self.handlers_per_listener = handlers_per_listener + self.wait_interval = wait_interval + + self.logger = logging.getLogger("springpython.jms.listener.SimpleMessageListenerContainer") + + def after_properties_set(self): + """ Run by Spring Python after all the JMS container's properties have + been set. + """ + + for idx in range(self.concurrent_listeners): + # Create as many Circuits managers as there are JMS listeners. + manager = Manager() + manager.start() + + # A pool of handler threads for each listener. + handlers_pool = ThreadPool(self.handlers_per_listener) + + # Each manager gets assigned its own listener. + listener = WebSphereMQListener() + + # Assign the listener and a debugger component to the manager. + manager += listener + manager += Debugger(logger=self.logger) + + listener.factory = self.factory + listener.destination = self.destination + listener.handler = self.handler + listener.handlers_pool = handlers_pool + listener.wait_interval = self.wait_interval + listener.start() diff --git a/src/springpython/remoting/hessian/__init__.py b/src/springpython/remoting/hessian/__init__.py index 2f14fc2..90dd8d0 100644 --- a/src/springpython/remoting/hessian/__init__.py +++ b/src/springpython/remoting/hessian/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from future.utils import raise_ from springpython.remoting.hessian.hessianlib import Hessian class HessianProxyFactory(object): @@ -34,7 +35,7 @@ def __getattr__(self, name): if name == "service_url": return self.service_url elif name in ["post_process_before_initialization", "post_process_after_initialization"]: - raise AttributeError, name + raise_(AttributeError, name) else: if self.client_proxy is None: self.__dict__["client_proxy"] = Hessian(self.service_url) diff --git a/src/springpython/remoting/hessian/__init__.py.bak b/src/springpython/remoting/hessian/__init__.py.bak new file mode 100644 index 0000000..2f14fc2 --- /dev/null +++ b/src/springpython/remoting/hessian/__init__.py.bak @@ -0,0 +1,42 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +from springpython.remoting.hessian.hessianlib import Hessian + +class HessianProxyFactory(object): + """ + This is wrapper around a Hessian client proxy. The idea is to inject this object with a + Hessian service_url, which in turn generates a Hessian client proxy. After that, any + method calls or attribute accessses will be forwarded to the Hessian client proxy. + """ + def __init__(self): + self.__dict__["client_proxy"] = None + + def __setattr__(self, name, value): + if name == "service_url": + self.__dict__["service_url"] = value + else: + setattr(self.client_proxy, name, value) + + def __getattr__(self, name): + if name == "service_url": + return self.service_url + elif name in ["post_process_before_initialization", "post_process_after_initialization"]: + raise AttributeError, name + else: + if self.client_proxy is None: + self.__dict__["client_proxy"] = Hessian(self.service_url) + return getattr(self.client_proxy, name) + diff --git a/src/springpython/remoting/hessian/hessianlib.py b/src/springpython/remoting/hessian/hessianlib.py index 9efbc44..20fb3e6 100644 --- a/src/springpython/remoting/hessian/hessianlib.py +++ b/src/springpython/remoting/hessian/hessianlib.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # A Hessian client interface for Python. The date and long types require # Python 2.2 or later. @@ -58,6 +59,7 @@ # Credits: hessianlib.py was inspired and partially based on # xmlrpclib.py created by Fredrik Lundh at www.pythonware.org # +from future.utils import raise_ import string, time import urllib from types import * @@ -184,7 +186,7 @@ def write_object(self, value): try: f = self.dispatch[type(value)] except KeyError: - raise TypeError, "cannot write %s objects" % type(value) + raise_(TypeError, "cannot write %s objects" % type(value)) else: f(self, value) @@ -213,7 +215,7 @@ def write_reference(self, value): # check for and write circular references # returns 1 if the object should be written, i.e. not a reference i = id(value) - if self.refs.has_key(i): + if i in self.refs: self.write('R') self.write(pack(">L", self.refs[i])) return 0 @@ -413,7 +415,7 @@ def __init__(self, url): # get the uri type, uri = urllib.splittype(url) if type != "http": - raise IOError, "unsupported Hessian protocol" + raise IOError("unsupported Hessian protocol") self._host, self._uri = urllib.splithost(uri) @@ -476,6 +478,6 @@ def __getattr__(self, name): proxy = Hessian("http://hessian.caucho.com/test/test") try: - print proxy.hello() - except Error, v: - print "ERROR", v + print(proxy.hello()) + except Error as v: + print("ERROR", v) diff --git a/src/springpython/remoting/hessian/hessianlib.py.bak b/src/springpython/remoting/hessian/hessianlib.py.bak new file mode 100644 index 0000000..9efbc44 --- /dev/null +++ b/src/springpython/remoting/hessian/hessianlib.py.bak @@ -0,0 +1,481 @@ +# +# A Hessian client interface for Python. The date and long types require +# Python 2.2 or later. +# +# The Hessian proxy is used as follows: +# +# proxy = Hessian("http://hessian.caucho.com/test/basic") +# +# print proxy.hello() +# +# -------------------------------------------------------------------- +# +# The Apache Software License, Version 1.1 +# +# Copyright (c) 2001-2002 Caucho Technology, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in +# the documentation and/or other materials provided with the +# distribution. +# +# 3. The end-user documentation included with the redistribution, if +# any, must include the following acknowlegement: +# "This product includes software developed by the +# Caucho Technology (http://www.caucho.com/)." +# Alternately, this acknowlegement may appear in the software itself, +# if and wherever such third-party acknowlegements normally appear. +# +# 4. The names "Hessian", "Resin", and "Caucho" must not be used to +# endorse or promote products derived from this software without prior +# written permission. For written permission, please contact +# info@caucho.com. +# +# 5. Products derived from this software may not be called "Resin" +# nor may "Resin" appear in their names without prior written +# permission of Caucho Technology. +# +# THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED +# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL CAUCHO TECHNOLOGY OR ITS CONTRIBUTORS +# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT +# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +# BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE +# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN +# IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# -------------------------------------------------------------------- +# +# Credits: hessianlib.py was inspired and partially based on +# xmlrpclib.py created by Fredrik Lundh at www.pythonware.org +# +import string, time +import urllib +from types import * +from struct import unpack +from struct import pack + +__version__ = "0.1" + + +# -------------------------------------------------------------------- +# Exceptions + +class Error: + # base class for client errors + pass + +class ProtocolError(Error): + # Represents an HTTP protocol error + def __init__(self, url, code, message, headers): + self.url = url + self.code = code + self.message = message + self.headers = headers + + def __repr__(self): + return ( + "" % + (self.url, self.code, self.message) + ) + +class Fault(Error): + # Represents a fault from Hessian + def __init__(self, code, message, **detail): + self.code = code + self.message = message + + def __repr__(self): + return "" % (self.code, self.message) + +# -------------------------------------------------------------------- +# Wrappers for Hessian data types non-standard in Python +# + +# +# Boolean -- use the True or False constants +# +class Boolean: + def __init__(self, value = 0): + self.value = (value != 0) + + def _hessian_write(self, out): + if self.value: + out.write('T') + else: + out.write('F') + + def __repr__(self): + if self.value: + return "" % id(self) + else: + return "" % id(self) + + def __int__(self): + return self.value + + def __nonzero__(self): + return self.value + +True, False = Boolean(1), Boolean(0) + +# +# Date - wraps a time value in seconds +# +class Date: + def __init__(self, value = 0): + self.value = value + + def __repr__(self): + return ("" % + (time.asctime(time.localtime(self.value)), id(self))) + + def _hessian_write(self, out): + out.write("d") + out.write(pack(">q", self.value * 1000.0)) +# +# Binary - binary data +# + +class Binary: + def __init__(self, data=None): + self.data = data + + def _hessian_write(self, out): + out.write('B') + out.write(pack('>H', len(self.data))) + out.write(self.data) + +# -------------------------------------------------------------------- +# Marshalling and unmarshalling code + +# +# HessianWriter - writes Hessian data from Python objects +# +class HessianWriter: + dispatch = {} + + def write_call(self, method, params): + self.refs = {} + self.ref = 0 + self.__out = [] + self.write = write = self.__out.append + + write("c\x01\x00m"); + write(pack(">H", len(method))); + write(method); + for v in params: + self.write_object(v) + write("z"); + result = string.join(self.__out, "") + del self.__out, self.write, self.refs + return result + + def write_object(self, value): + try: + f = self.dispatch[type(value)] + except KeyError: + raise TypeError, "cannot write %s objects" % type(value) + else: + f(self, value) + + def write_int(self, value): + self.write('I') + self.write(pack(">l", value)) + dispatch[IntType] = write_int + + def write_long(self, value): + self.write('L') + self.write(pack(">q", value)) + dispatch[LongType] = write_long + + def write_double(self, value): + self.write('D') + self.write(pack(">d", value)) + dispatch[FloatType] = write_double + + def write_string(self, value): + self.write('S') + self.write(pack('>H', len(value))) + self.write(value) + dispatch[StringType] = write_string + + def write_reference(self, value): + # check for and write circular references + # returns 1 if the object should be written, i.e. not a reference + i = id(value) + if self.refs.has_key(i): + self.write('R') + self.write(pack(">L", self.refs[i])) + return 0 + else: + self.refs[i] = self.ref + self.ref = self.ref + 1 + return 1 + + def write_list(self, value): + if self.write_reference(value): + self.write("Vt\x00\x00I"); + self.write(pack('>l', len(value))) + for v in value: + self.__write(v) + self.write('z') + dispatch[TupleType] = write_list + dispatch[ListType] = write_list + + def write_map(self, value): + if self.write_reference(value): + self.write("Mt\x00\x00") + for k, v in value.items(): + self.__write(k) + self.__write(v) + self.write("z") + dispatch[DictType] = write_map + + def write_instance(self, value): + # check for special wrappers + if hasattr(value, "_hessian_write"): + value._hessian_write(self) + else: + fields = value.__dict__ + if self.write_reference(fields): + self.write("Mt\x00\x00") + for k, v in fields.items(): + self.__write(k) + self.__write(v) + self.write("z") + dispatch[InstanceType] = write_instance + +# +# Parses the results from the server +# +class HessianParser: + def __init__(self, f): + self._f = f + self._peek = -1 + # self.read = f.read + self._refs = [] + + def read(self, len): + if self._peek >= 0: + value = self._peek + self._peek = -1 + return value + else: + return self._f.read(len) + + def parse_reply(self): + # parse header 'c' x01 x00 'v' ... 'z' + read = self.read + if read(1) != 'r': + self.error() + major = read(1) + minor = read(1) + + value = self.parse_object() + + if read(1) == 'z': + return value + self.error() # actually a fault + + def parse_object(self): + # parse an arbitrary object based on the type in the data + return self.parse_object_code(self.read(1)) + + def parse_object_code(self, code): + # parse an object when the code is known + read = self.read + + if code == 'N': + return None + + elif code == 'T': + return True + + elif code == 'F': + return False + + elif code == 'I': + return unpack('>l', read(4))[0] + + elif code == 'L': + return unpack('>q', read(8))[0] + + elif code == 'D': + return unpack('>d', read(8))[0] + + elif code == 'd': + ms = unpack('>q', read(8))[0] + + return Date(int(ms / 1000.0)) + + elif code == 'S' or code == 'X': + return self.parse_string() + + elif code == 'B': + return Binary(self.parse_string()) + + elif code == 'V': + self.parse_type() # skip type + self.parse_length() # skip length + list = [] + self._refs.append(list) + ch = read(1) + while ch != 'z': + list.append(self.parse_object_code(ch)) + ch = read(1) + return list + + elif code == 'M': + self.parse_type() # skip type + map = {} + self._refs.append(map) + ch = read(1) + while ch != 'z': + key = self.parse_object_code(ch) + value = self.parse_object() + map[key] = value + ch = read(1) + return map + + elif code == 'R': + return self._refs[unpack('>l', read(4))[0]] + + elif code == 'r': + self.parse_type() # skip type + url = self.parse_type() # reads the url + return Hessian(url) + + else: + raise "UnknownObjectCode %d" % code + + def parse_string(self): + f = self._f + len = unpack('>H', f.read(2))[0] + return f.read(len) + + def parse_type(self): + f = self._f + code = self.read(1) + if code != 't': + self._peek = code + return "" + len = unpack('>H', f.read(2))[0] + return f.read(len) + + def parse_length(self): + f = self._f + code = self.read(1); + if code != 'l': + self._peek = code + return -1; + len = unpack('>l', f.read(4)) + return len + + def error(self): + raise "FOO" + +# +# Encapsulates the method to be called +# +class _Method: + def __init__(self, invoker, method): + self._invoker = invoker + self._method = method + + def __call__(self, *args): + return self._invoker(self._method, args) + +# -------------------------------------------------------------------- +# Hessian is the main class. A Hessian proxy is created with the URL +# and then called just as for a local method +# +# proxy = Hessian("http://www.caucho.com/hessian/test/basic") +# print proxy.hello() +# +class Hessian: + """Represents a remote object reachable by Hessian""" + + def __init__(self, url): + # Creates a Hessian proxy object + + self._url = url + + # get the uri + type, uri = urllib.splittype(url) + if type != "http": + raise IOError, "unsupported Hessian protocol" + + self._host, self._uri = urllib.splithost(uri) + + def __invoke(self, method, params): + # call a method on the remote server + + request = HessianWriter().write_call(method, params) + + import httplib + + h = httplib.HTTP(self._host) + h.putrequest("POST", self._uri) + + # required by HTTP/1.1 + h.putheader("Host", self._host) + + h.putheader("User-Agent", "hessianlib.py/%s" % __version__) + h.putheader("Content-Length", str(len(request))) + + h.endheaders() + + h.send(request) + + errcode, errmsg, headers = h.getreply() + + if errcode != 200: + raise ProtocolError(self._url, errcode, errmsg, headers) + + return self.parse_response(h.getfile()) + + def parse_response(self, f): + # read response from input file, and parse it + + parser = HessianParser(f) + value = parser.parse_reply() + f.close() + + return value + + def _hessian_write(self, out): + # marshals the proxy itself + out.write("rt\x00\x00S") + out.write(pack(">H", len(self._url))) + out.write(self._url) + + def __repr__(self): + return "" % self._url + + __str__ = __repr__ + + def __getattr__(self, name): + # encapsulate the method call + return _Method(self.__invoke, name) + +# +# Testing code. +# +if __name__ == "__main__": + + proxy = Hessian("http://hessian.caucho.com/test/test") + + try: + print proxy.hello() + except Error, v: + print "ERROR", v diff --git a/src/springpython/remoting/pyro/Pyro4DaemonHolder.py b/src/springpython/remoting/pyro/Pyro4DaemonHolder.py index d0c2b59..20a813c 100644 --- a/src/springpython/remoting/pyro/Pyro4DaemonHolder.py +++ b/src/springpython/remoting/pyro/Pyro4DaemonHolder.py @@ -66,7 +66,8 @@ def deregister(service_name, host, port): pyro_threads[(host, port)].pyro_daemon.unregister(serviceList[(service_name, host, port)]) del(serviceList[(service_name, host, port)]) - def get_address((service_name, host, port)): + def get_address(xxx_todo_changeme): + (service_name, host, port) = xxx_todo_changeme return (host, port) if len([True for x in serviceList.keys() if get_address(x) == (host, port)]) == 0: @@ -83,7 +84,7 @@ def shutdown(daemon_host, daemon_port): pyro_threads[(daemon_host, daemon_port)].shutdown() time.sleep(1.0) del(pyro_threads[(daemon_host, daemon_port)]) - except Exception, e: + except Exception as e: logger.debug("Failed to shutdown %s:%s => %s" % (daemon_host, daemon_port, e)) class _Pyro4Thread(threading.Thread): diff --git a/src/springpython/remoting/pyro/Pyro4DaemonHolder.py.bak b/src/springpython/remoting/pyro/Pyro4DaemonHolder.py.bak new file mode 100644 index 0000000..d0c2b59 --- /dev/null +++ b/src/springpython/remoting/pyro/Pyro4DaemonHolder.py.bak @@ -0,0 +1,132 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import threading +import time +import Pyro4 + +from socket import getaddrinfo, gethostbyname + +pyro_threads = {} +serviceList = {} +logger = logging.getLogger("springpython.remoting.pyro.Pyro4DaemonHolder") + +def resolve(host, port): + canonhost = gethostbyname(host) + canonport = getaddrinfo(host, port)[0][4][1] + + return canonhost, canonport + +def register(pyro_obj, service_name, host, port): + """ + Register the Pyro4 object and its service name with the daemon. + Also add the service to a dictionary of objects. This allows the + PyroDaemonHolder to intelligently know when to start and stop the + daemon thread. + """ + logger.debug("Registering %s at %s:%s with the Pyro4 server" % (service_name, host, port)) + + host, port = resolve(host, port) + + serviceList[(service_name, host, port)] = pyro_obj + + if (host, port) not in pyro_threads: + + logger.debug("Pyro4 thread needs to be started at %s:%s" % (host, port)) + + pyro_threads[(host, port)] = _Pyro4Thread(host, port) + pyro_threads[(host, port)].start() + + if not hasattr(pyro_obj, "_pyroId"): + uri = pyro_threads[(host, port)].pyro_daemon.register(pyro_obj, service_name) + +def deregister(service_name, host, port): + """ + Deregister the named service by removing it from the list of + managed services and also disconnect from the daemon. + """ + logger.debug("Deregistering %s at %s:%s with the Pyro4 server" % (service_name, host, port)) + + host, port = resolve(host, port) + + if (host, port) in pyro_threads: + pyro_threads[(host, port)].pyro_daemon.unregister(serviceList[(service_name, host, port)]) + del(serviceList[(service_name, host, port)]) + + def get_address((service_name, host, port)): + return (host, port) + + if len([True for x in serviceList.keys() if get_address(x) == (host, port)]) == 0: + shutdown(host, port) + +def shutdown(daemon_host, daemon_port): + """This provides a hook so an application can deliberately shutdown a + daemon thread.""" + logger.debug("Shutting down Pyro4 daemon at %s:%s" % (daemon_host, daemon_port)) + + daemon_host, daemon_port = resolve(daemon_host, daemon_port) + + try: + pyro_threads[(daemon_host, daemon_port)].shutdown() + time.sleep(1.0) + del(pyro_threads[(daemon_host, daemon_port)]) + except Exception, e: + logger.debug("Failed to shutdown %s:%s => %s" % (daemon_host, daemon_port, e)) + +class _Pyro4Thread(threading.Thread): + """ + This is a thread that runs the Pyro4 daemon. It is instantiated automatically + from within Pyro4ServiceExporter. + """ + + def __init__(self, host, port): + """ + When this class is created, it also created a Pyro4 core daemon to manage. + """ + threading.Thread.__init__(self) + self.host = host + self.port = port + self.logger = logging.getLogger("springpython.remoting.pyro.Pyro4DaemonHolder._Pyro4Thread") + + self.logger.debug("Creating Pyro4 daemon") + self.pyro_daemon = Pyro4.Daemon(host=host, port=port) + + def run(self): + """ + When this thread starts up, it initializes the Pyro4 server and then puts the + daemon into listen mode so it can process remote requests. + """ + self.logger.debug("Starting up Pyro4 server thread for %s:%s" % (self.host, self.port)) + self.pyro_daemon.requestLoop() + + def shutdown(self): + """ + This is a hook in order to signal the thread that its time to shutdown + the Pyro4 daemon. + """ + self.logger.debug("Signaling shutdown of Pyro4 server thread for %s:%s" % (self.host, self.port)) + class ShutdownThread(threading.Thread): + def __init__(self, pyro_daemon): + threading.Thread.__init__(self) + self.pyro_daemon = pyro_daemon + self.logger = logging.getLogger("springpython.remoting.pyro.Pyro4DaemonHolder.ShutdownThread") + def run(self): + self.logger.debug("Sending shutdown signal...") + self.pyro_daemon.shutdown() + + ShutdownThread(self.pyro_daemon).start() + + diff --git a/src/springpython/remoting/pyro/PyroDaemonHolder.py b/src/springpython/remoting/pyro/PyroDaemonHolder.py index 220bc2a..0afd1af 100644 --- a/src/springpython/remoting/pyro/PyroDaemonHolder.py +++ b/src/springpython/remoting/pyro/PyroDaemonHolder.py @@ -63,7 +63,8 @@ def deregister(service_name, host, port): pyro_threads[(host, port)].pyro_daemon.disconnect(serviceList[(service_name, host, port)]) del(serviceList[(service_name, host, port)]) - def get_address((service_name, host, port)): + def get_address(xxx_todo_changeme): + (service_name, host, port) = xxx_todo_changeme return (host, port) if len([True for x in serviceList.keys() if get_address(x) == (host, port)]) == 0: diff --git a/src/springpython/remoting/pyro/PyroDaemonHolder.py.bak b/src/springpython/remoting/pyro/PyroDaemonHolder.py.bak new file mode 100644 index 0000000..220bc2a --- /dev/null +++ b/src/springpython/remoting/pyro/PyroDaemonHolder.py.bak @@ -0,0 +1,121 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import threading +import Pyro.core, Pyro.naming + +from socket import getaddrinfo, gethostbyname + +pyro_threads = {} +serviceList = {} +logger = logging.getLogger("springpython.remoting.pyro.PyroDaemonHolder") + +def resolve(host, port): + canonhost = gethostbyname(host) + canonport = getaddrinfo(host, port)[0][4][1] + + return canonhost, canonport + +def register(pyro_obj, service_name, host, port): + """ + Register the pyro object and its service name with the daemon. + Also add the service to a dictionary of objects. This allows the + PyroDaemonHolder to intelligently know when to start and stop the + daemon thread. + """ + logger.debug("Registering %s at %s:%s with the Pyro server" % (service_name, host, port)) + + host, port = resolve(host, port) + + serviceList[(service_name, host, port)] = pyro_obj + + if (host, port) not in pyro_threads: + + logger.debug("Pyro thread needs to be started at %s:%s" % (host, port)) + + pyro_threads[(host, port)] = _PyroThread(host, port) + pyro_threads[(host, port)].start() + + pyro_threads[(host, port)].pyro_daemon.connect(pyro_obj, service_name) + +def deregister(service_name, host, port): + """ + Deregister the named service by removing it from the list of + managed services and also disconnect from the daemon. + """ + logger.debug("Deregistering %s at %s:%s with the Pyro server" % (service_name, host, port)) + + host, port = resolve(host, port) + + pyro_threads[(host, port)].pyro_daemon.disconnect(serviceList[(service_name, host, port)]) + del(serviceList[(service_name, host, port)]) + + def get_address((service_name, host, port)): + return (host, port) + + if len([True for x in serviceList.keys() if get_address(x) == (host, port)]) == 0: + logger.debug("Shutting down thread on %s:%s" % (host, port)) + shutdown(host, port) + +def shutdown(daemon_host, daemon_port): + """This provides a hook so an application can deliberately shutdown a + daemon thread.""" + logger.debug("Shutting down pyro daemon at %s:%s" % (daemon_host, daemon_port)) + + daemon_host, daemon_port = resolve(daemon_host, daemon_port) + + try: + pyro_threads[(daemon_host, daemon_port)].shutdown() + del(pyro_threads[(daemon_host, daemon_port)]) + except: + logger.debug("Failed to shutdown %s:%s" % (daemon_host, daemon_port)) + +class _PyroThread(threading.Thread): + """ + This is a thread that runs the Pyro daemon. It is instantiated automatically + from within PyroServiceExporter. + """ + + def __init__(self, host, port): + """ + When this class is created, it also created a Pyro core daemon to manage. + """ + threading.Thread.__init__(self) + self.host = host + self.port = port + self.logger = logging.getLogger("springpython.remoting.pyro.PyroDaemonHolder._PyroThread") + + self.pyro_daemon = Pyro.core.Daemon(host=host, port=port) + + def run(self): + """ + When this thread starts up, it initializes the Pyro server and then puts the + daemon into listen mode so it can process remote requests. + """ + self._running = True + self.logger.debug("Starting up Pyro server thread for %s:%s" % (self.host, self.port)) + Pyro.core.initServer() + self.pyro_daemon.requestLoop(condition = lambda:self._running) + + def shutdown(self): + """ + This is a hook in order to signal the thread that its time to shutdown + the Pyro daemon. + """ + self._running = False + self.logger.debug("Signaling shutdown of Pyro server thread for %s:%s" % (self.host, self.port)) + + diff --git a/src/springpython/remoting/pyro/__init__.py b/src/springpython/remoting/pyro/__init__.py index cfad165..22fc0b5 100644 --- a/src/springpython/remoting/pyro/__init__.py +++ b/src/springpython/remoting/pyro/__init__.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from future.utils import raise_ import logging import Pyro.core from springpython.context import InitializingObject @@ -79,7 +80,7 @@ def __getattr__(self, name): if name in ["service_url"]: return self.__dict__[name] elif name in ["post_process_before_initialization", "post_process_after_initialization"]: - raise AttributeError, name + raise_(AttributeError, name) else: if self.client_proxy is None: self.__dict__["client_proxy"] = Pyro.core.getProxyForURI(self.service_url) @@ -165,7 +166,7 @@ def __getattr__(self, name): if name in ["service_url"]: return self.__dict__[name] elif name in ["post_process_before_initialization", "post_process_after_initialization"]: - raise AttributeError, name + raise_(AttributeError, name) else: if self.client_proxy is None: self.__dict__["client_proxy"] = Pyro4.Proxy(self.service_url) diff --git a/src/springpython/remoting/pyro/__init__.py.bak b/src/springpython/remoting/pyro/__init__.py.bak new file mode 100644 index 0000000..cfad165 --- /dev/null +++ b/src/springpython/remoting/pyro/__init__.py.bak @@ -0,0 +1,173 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import Pyro.core +from springpython.context import InitializingObject +from springpython.remoting.pyro import PyroDaemonHolder + +class PyroServiceExporter(InitializingObject): + """ + This class will expose an object using Pyro. It requires that a daemon thread + be up and running in order to receive requests and allow dispatching to the exposed + object. + """ + def __init__(self, service = None, service_name = None, service_host = "localhost", service_port = 7766): + self.logger = logging.getLogger("springpython.remoting.pyro.PyroServiceExporter") + self.service = service + self.service_name = service_name + self.service_host = service_host + self.service_port = service_port + self._pyro_thread = None + + def __del__(self): + """ + When the service exporter goes out of scope and is garbage collected, the + service must be deregistered. + """ + PyroDaemonHolder.deregister(self.service_name, self.service_host, self.service_port) + + def __setattr__(self, name, value): + """ + Only the explicitly listed attributes can be assigned values. Everything else is passed through to + the actual service. + """ + if name in ["logger", "service", "service_name", "service_host", "service_port", "_pyro_thread"]: + self.__dict__[name] = value + else: + object.__setattr__(self, name, value) + + def after_properties_set(self): + if self.service is None: raise Exception("service must NOT be None") + if self.service_name is None: raise Exception("service_name must NOT be None") + if self.service_host is None: raise Exception("service_host must NOT be None") + if self.service_port is None: raise Exception("service_port must NOT be None") + self.logger.debug("Exporting %s as a Pyro service at %s:%s" % (self.service_name, self.service_host, self.service_port)) + pyro_obj = Pyro.core.ObjBase() + pyro_obj.delegateTo(self.service) + PyroDaemonHolder.register(pyro_obj, self.service_name, self.service_host, self.service_port) + +class PyroProxyFactory(object): + """ + This is wrapper around a Pyro client proxy. The idea is to inject this object with a + Pyro service_url, which in turn generates a Pyro client proxy. After that, any + method calls or attribute accessses will be forwarded to the Pyro client proxy. + """ + def __init__(self): + self.__dict__["client_proxy"] = None + self.__dict__["service_url"] = None + + def __setattr__(self, name, value): + if name == "service_url": + self.__dict__["service_url"] = value + else: + setattr(self.client_proxy, name, value) + + def __getattr__(self, name): + if name in ["service_url"]: + return self.__dict__[name] + elif name in ["post_process_before_initialization", "post_process_after_initialization"]: + raise AttributeError, name + else: + if self.client_proxy is None: + self.__dict__["client_proxy"] = Pyro.core.getProxyForURI(self.service_url) + return getattr(self.client_proxy, name) + +class Pyro4ServiceExporter(InitializingObject): + """ + This class will expose an object using Pyro. It requires that a daemon thread + be up and running in order to receive requests and allow dispatching to the exposed + object. + """ + def __init__(self, service = None, service_name = None, service_host = "localhost", service_port = 7766): + self.logger = logging.getLogger("springpython.remoting.pyro.Pyro4ServiceExporter") + self.service = service + self.service_name = service_name + self.service_host = service_host + self.service_port = service_port + self._pyro_thread = None + + def __del__(self): + """ + When the service exporter goes out of scope and is garbage collected, the + service must be deregistered. + """ + from springpython.remoting.pyro import Pyro4DaemonHolder + Pyro4DaemonHolder.deregister(self.service_name, self.service_host, self.service_port) + + def __setattr__(self, name, value): + """ + Only the explicitly listed attributes can be assigned values. Everything else is passed through to + the actual service. + """ + if name in ["logger", "service", "service_name", "service_host", "service_port", "_pyro_thread"]: + self.__dict__[name] = value + else: + object.__setattr__(self, name, value) + + def after_properties_set(self): + import Pyro4 + from springpython.remoting.pyro import Pyro4DaemonHolder + if self.service is None: raise Exception("service must NOT be None") + if self.service_name is None: raise Exception("service_name must NOT be None") + if self.service_host is None: raise Exception("service_host must NOT be None") + if self.service_port is None: raise Exception("service_port must NOT be None") + self.logger.debug("Exporting %s as a Pyro service at %s:%s" % (self.service_name, self.service_host, self.service_port)) + wrapping_obj = PyroWrapperObj(self.service) + Pyro4DaemonHolder.register(wrapping_obj, self.service_name, self.service_host, self.service_port) + +class PyroWrapperObj(object): + def __init__(self, delegate): + self.delegate = delegate + + def __getattr__(self, name): + if name in ["__pyroInvoke", "__call__", "_pyroId", "_pyroDaemon", "delegate"]: + return self.__dict__[name] + else: + return getattr(self.delegate, name) + + def __setattr__(self, name, value): + if name in ["__pyroInvoke", "__call__", "_pyroId", "_pyroDaemon", "delegate"]: + self.__dict__[name] = value + else: + setattr(self.delegate, name, value) + +class Pyro4ProxyFactory(object): + """ + This is wrapper around a Pyro client proxy. The idea is to inject this object with a + Pyro service_url, which in turn generates a Pyro client proxy. After that, any + method calls or attribute accessses will be forwarded to the Pyro client proxy. + """ + def __init__(self): + self.__dict__["client_proxy"] = None + self.__dict__["service_url"] = None + + def __setattr__(self, name, value): + if name == "service_url": + self.__dict__["service_url"] = value + else: + setattr(self.client_proxy, name, value) + + def __getattr__(self, name): + import Pyro4 + if name in ["service_url"]: + return self.__dict__[name] + elif name in ["post_process_before_initialization", "post_process_after_initialization"]: + raise AttributeError, name + else: + if self.client_proxy is None: + self.__dict__["client_proxy"] = Pyro4.Proxy(self.service_url) + return getattr(self.client_proxy, name) + diff --git a/src/springpython/remoting/xmlrpc.py b/src/springpython/remoting/xmlrpc.py index 010e4ce..0d16d6b 100644 --- a/src/springpython/remoting/xmlrpc.py +++ b/src/springpython/remoting/xmlrpc.py @@ -118,7 +118,7 @@ def verify_request(self, sock, from_addr): sock.close() return False - except Exception, e: + except Exception as e: # It was either an error on our side or the client didn't send the # certificate even though self.cert_reqs was CERT_OPTIONAL (it couldn't diff --git a/src/springpython/remoting/xmlrpc.py.bak b/src/springpython/remoting/xmlrpc.py.bak new file mode 100644 index 0000000..010e4ce --- /dev/null +++ b/src/springpython/remoting/xmlrpc.py.bak @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +""" + Copyright 2006-2011 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + + +# stdlib +import httplib +import logging +import socket +import ssl +import sys +import traceback + +from SimpleXMLRPCServer import SimpleXMLRPCServer, SimpleXMLRPCRequestHandler +from xmlrpclib import ServerProxy, Error, Transport + +# Spring Python +from springpython.remoting.http import CAValidatingHTTPS + +__all__ = ["VerificationException", "SSLServer", "SSLClient"] + +class VerificationException(Exception): + """ Raised when the verification of a certificate's fields fails. + """ + +# ############################################################################## +# Server +# ############################################################################## + +class RequestHandler(SimpleXMLRPCRequestHandler): + rpc_paths = ("/", "/RPC2",) + + def setup(self): + self.connection = self.request # for doPOST + self.rfile = socket._fileobject(self.request, "rb", self.rbufsize) + self.wfile = socket._fileobject(self.request, "wb", self.wbufsize) + +class SSLServer(object, SimpleXMLRPCServer): + def __init__(self, host=None, port=None, keyfile=None, certfile=None, + ca_certs=None, cert_reqs=ssl.CERT_NONE, ssl_version=ssl.PROTOCOL_TLSv1, + do_handshake_on_connect=True, suppress_ragged_eofs=True, ciphers=None, + log_requests=True, **kwargs): + + SimpleXMLRPCServer.__init__(self, (host, port), requestHandler=RequestHandler) + self.logger = logging.getLogger(self.__class__.__name__) + + self.keyfile = keyfile + self.certfile = certfile + self.ca_certs = ca_certs + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self.ciphers = ciphers + + # Looks awkward to use camelCase here but that's what SimpleXMLRPCRequestHandler + # expects. + self.logRequests = log_requests + + # 'verify_fields' is taken from kwargs to allow for adding more keywords + # in future versions. + self.verify_fields = kwargs.get("verify_fields") + + self.register_functions() + + def get_request(self): + """ Overridden from SocketServer.TCPServer.get_request, wraps the socket in + an SSL context. + """ + sock, from_addr = self.socket.accept() + + # 'ciphers' argument is new in 2.7 and we must support 2.6 so add it + # to kwargs conditionally, depending on the Python version. + + kwargs = {"keyfile":self.keyfile, "certfile":self.certfile, + "server_side":True, "cert_reqs":self.cert_reqs, "ssl_version":self.ssl_version, + "ca_certs":self.ca_certs, "do_handshake_on_connect":self.do_handshake_on_connect, + "suppress_ragged_eofs":self.suppress_ragged_eofs} + + if sys.version_info >= (2, 7): + kwargs["ciphers"] = self.ciphers + + sock = ssl.wrap_socket(sock, **kwargs) + + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug("get_request cert='%s', from_addr='%s'" % (sock.getpeercert(), from_addr)) + + return sock, from_addr + + def verify_request(self, sock, from_addr): + """ Overridden from SocketServer.TCPServer.verify_request, adds validation of the + other side's certificate fields. + """ + try: + if self.verify_fields: + + cert = sock.getpeercert() + if not cert: + msg = "Couldn't verify fields, peer didn't send the certificate, from_addr='%s'" % (from_addr,) + raise VerificationException(msg) + + allow_peer, reason = self.verify_peer(cert, from_addr) + if not allow_peer: + self.logger.error(reason) + sock.close() + return False + + except Exception, e: + + # It was either an error on our side or the client didn't send the + # certificate even though self.cert_reqs was CERT_OPTIONAL (it couldn't + # have been CERT_REQUIRED because we wouldn't have got so far, the + # session would've been terminated much earlier in ssl.wrap_socket call). + # Regardless of the reason we cannot accept the client in that case. + + msg = "Verification error='%s', cert='%s', from_addr='%s'" % ( + traceback.format_exc(e), sock.getpeercert(), from_addr) + self.logger.error(msg) + + sock.close() + return False + + return True + + def verify_peer(self, cert, from_addr): + """ Verifies the other side's certificate. May be overridden in subclasses + if the verification process needs to be customized. + """ + + subject = cert.get("subject") + if not subject: + msg = "Peer certificate doesn't have the 'subject' field, cert='%s'" % cert + raise VerificationException(msg) + + subject = dict(elem[0] for elem in subject) + + for verify_field in self.verify_fields: + + expected_value = self.verify_fields[verify_field] + cert_value = subject.get(verify_field, None) + + if not cert_value: + reason = "Peer didn't send the '%s' field, subject fields received '%s'" % ( + verify_field, subject) + return False, reason + + if expected_value != cert_value: + reason = "Expected the subject field '%s' to have value '%s' instead of '%s', subject='%s'" % ( + verify_field, expected_value, cert_value, subject) + return False, reason + + return True, None + + def register_functions(self): + raise NotImplementedError("Must be overridden by subclasses") + +# ############################################################################## +# Client +# ############################################################################## + +class SSLClientTransport(Transport): + """ Handles an HTTPS transaction to an XML-RPC server. + """ + + user_agent = "SSL XML-RPC Client (by http://springpython.webfactional.com)" + + def __init__(self, ca_certs=None, keyfile=None, certfile=None, cert_reqs=None, + ssl_version=None, timeout=None, strict=None): + + self.ca_certs = ca_certs + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.timeout = timeout + self.strict = strict + + Transport.__init__(self) + + def make_connection(self, host): + return CAValidatingHTTPS(host, strict=self.strict, ca_certs=self.ca_certs, + keyfile=self.keyfile, certfile=self.certfile, cert_reqs=self.cert_reqs, + ssl_version=self.ssl_version, timeout=self.timeout) + +class SSLClient(ServerProxy): + def __init__(self, uri=None, ca_certs=None, keyfile=None, certfile=None, + cert_reqs=ssl.CERT_REQUIRED, ssl_version=ssl.PROTOCOL_TLSv1, + transport=None, encoding=None, verbose=0, allow_none=0, use_datetime=0, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT, strict=None): + + if not transport: + _transport=SSLClientTransport(ca_certs, keyfile, certfile, cert_reqs, + ssl_version, timeout, strict) + else: + _transport=transport(ca_certs, keyfile, certfile, cert_reqs, + ssl_version, timeout, strict) + + + ServerProxy.__init__(self, uri, _transport, encoding, verbose, + allow_none, use_datetime) + + self.logger = logging.getLogger(self.__class__.__name__) diff --git a/src/springpython/security/providers/Ldap.py b/src/springpython/security/providers/Ldap.py index 18f5219..cb8893c 100644 --- a/src/springpython/security/providers/Ldap.py +++ b/src/springpython/security/providers/Ldap.py @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import absolute_import import sys if "java" in sys.platform.lower(): - from _Ldap_jython import * + from ._Ldap_jython import * else: - from _Ldap_cpython import * + from ._Ldap_cpython import * diff --git a/src/springpython/security/providers/Ldap.py.bak b/src/springpython/security/providers/Ldap.py.bak new file mode 100644 index 0000000..18f5219 --- /dev/null +++ b/src/springpython/security/providers/Ldap.py.bak @@ -0,0 +1,22 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import sys + +if "java" in sys.platform.lower(): + from _Ldap_jython import * +else: + from _Ldap_cpython import * + diff --git a/src/springpython/security/providers/_Ldap_cpython.py b/src/springpython/security/providers/_Ldap_cpython.py index 66331be..2fe6f23 100644 --- a/src/springpython/security/providers/_Ldap_cpython.py +++ b/src/springpython/security/providers/_Ldap_cpython.py @@ -90,7 +90,7 @@ def authenticate(self, authentication): l.simple_bind_s(dn, authentication.password) self.logger.debug("Successfully bound to server!") return (result_set[0],l) - except Exception, e: + except Exception as e: self.logger.debug("Error %s" % e) raise BadCredentialsException("Invalid password") diff --git a/src/springpython/security/providers/_Ldap_cpython.py.bak b/src/springpython/security/providers/_Ldap_cpython.py.bak new file mode 100644 index 0000000..66331be --- /dev/null +++ b/src/springpython/security/providers/_Ldap_cpython.py.bak @@ -0,0 +1,208 @@ +""" + Copyright 2006-2009 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import logging +import re +import sys +from springpython.security import AuthenticationException +from springpython.security import AuthenticationServiceException +from springpython.security import BadCredentialsException +from springpython.security import DisabledException +from springpython.security import UsernameNotFoundException +from springpython.security.providers import AuthenticationProvider +from springpython.security.providers import UsernamePasswordAuthenticationToken +from springpython.security.providers.dao import AbstractUserDetailsAuthenticationProvider +from springpython.security.providers.encoding import LdapShaPasswordEncoder + +""" +The ldap library only works with CPython. You should NOT import this library directly. +""" +import ldap + + +class DefaultSpringSecurityContextSource(object): + """ + This class is used to define the url of the ldap server. It expects a string like ldap://:/ + It provides functions to retrieve the parts + """ + + def __init__(self, url=None): + self.url = url + + def server(self): + """Extract the server's hostname/port from the url.""" + return self.url.split("ldap://")[1].split("/")[0].split(":") + + def base(self): + """Extract the baseDN from the url.""" + return self.url.split("ldap://")[1].split("/")[1] + +class BindAuthenticator(object): + """ + This ldap authenticator uses binding to confirm the user's password. This means the password encoding + depends on the ldap library's API as well as the directory server; NOT Spring Python's password + hashing algorithms. + """ + + def __init__(self, context_source=None, user_dn_patterns="uid={0},ou=people"): + self.context_source = context_source + self.user_dn_patterns = user_dn_patterns + self.logger = logging.getLogger("springpython.security.providers.Ldap.BindAuthenticator") + + def authenticate(self, authentication): + """Using the user_dn_patterns, find the user's entry, and then bind to the entry using supplied credentials.""" + + username = self.user_dn_patterns.replace("{0}", authentication.username) + baseDn = self.context_source.base() + + parts = username.split(",") + + if len(parts) > 1: + username = parts[0] + baseDn = ",".join(parts[1:]) + "," + baseDn + + (host, port) = self.context_source.server() + self.logger.debug("Opening connection to server %s/%s" % (host, int(port))) + l = ldap.open(host, int(port)) + + self.logger.debug("Searching for %s in %s" % (username, baseDn)) + result_set = l.search_s(baseDn, ldap.SCOPE_SUBTREE, username, None) + + if len(result_set) != 1: + raise BadCredentialsException("Found %s entries at %s/%s. Should only be 1." % (len(result_set), baseDn, username)) + + dn = result_set[0][0] + self.logger.debug("Attempting to bind %s" % dn) + try: + l.simple_bind_s(dn, authentication.password) + self.logger.debug("Successfully bound to server!") + return (result_set[0],l) + except Exception, e: + self.logger.debug("Error %s" % e) + raise BadCredentialsException("Invalid password") + +class PasswordComparisonAuthenticator(object): + """ + This ldap authenticator uses string comparison to confirm the user's password. This means a password encoder must + be provided, or the default LdapShaPasswordEncoder will be used. It searched for the user's entry, fetches the + password, and then does a string comparison to confirm the password. + """ + + def __init__(self, context_source=None, user_dn_patterns="uid={0},ou=people", password_attr_name="userPassword"): + self.context_source = context_source + self.user_dn_patterns = user_dn_patterns + self.password_attr_name = password_attr_name + self.encoder = LdapShaPasswordEncoder() + self.logger = logging.getLogger("springpython.security.providers.Ldap.PasswordComparisonAuthenticator") + + def authenticate(self, authentication): + """ + Using the user_dn_patterns, find the user's entry, and then retrieve the password field. Encode the supplied + password with the necessary hasher, and compare to the entry. + """ + + username = self.user_dn_patterns.replace("{0}", authentication.username) + baseDn = self.context_source.base() + + parts = username.split(",") + + if len(parts) > 1: + username = parts[0] + baseDn = ",".join(parts[1:]) + "," + baseDn + + (host, port) = self.context_source.server() + self.logger.debug("Opening connection to server %s/%s" % (host, int(port))) + l = ldap.open(host, int(port)) + + self.logger.debug("Searching for %s in %s" % (username, baseDn)) + result_set = l.search_s(baseDn, ldap.SCOPE_SUBTREE, username, None) + + if len(result_set) != 1: + raise BadCredentialsException("Found %s entries at %s/%s. Should only be 1." % (len(result_set), baseDn, username)) + + self.logger.debug("Looking for attributes...%s" % result_set[0][1]) + stored_password = result_set[0][1][self.password_attr_name.lower()][0] + self.logger.debug("Comparing passwords...") + + if self.encoder.isPasswordValid(stored_password, authentication.password, None): + self.logger.debug("Successfully matched passwords!") + return (result_set[0],l) + else: + raise BadCredentialsException("Invalid password") + +class DefaultLdapAuthoritiesPopulator(object): + """ + This ldap authorities populator follows a standard convention, where groups are created, with a member attribute, pointing + at user entries in another part of the directory structure. It then combines ROLE_ with the name of the group, and names + that as a granted role. + """ + + def __init__(self, context_source=None, group_search_base="ou=groups", group_search_filter="member={0}", group_role_attr="cn", role_prefix="ROLE_", convert_to_upper=True): + self.logger = logging.getLogger("springpython.security.providers.Ldap.DefaultLdapAuthoritiesPopulator") + self.context_source = context_source + self.group_search_base = group_search_base + self.group_search_filter = group_search_filter + self.group_role_attr = group_role_attr + self.role_prefix = role_prefix + self.convert_to_upper = convert_to_upper + + def get_granted_auths(self, user_details, l): + group_filter = self.group_search_filter.replace("{0}", user_details[0]) + baseDn = self.group_search_base + "," + self.context_source.base() + + self.logger.debug("Searching for groups for %s" % str(user_details[0])) + result_set = l.search_s(baseDn, ldap.SCOPE_SUBTREE, group_filter, None) + + auths = [] + for row in result_set: + role = self.role_prefix + row[1][self.group_role_attr][0] + if self.convert_to_upper: + auths.append(role.upper()) + else: + auths.append(role) + self.logger.debug("Authorities = %s" % auths) + return auths + +class LdapAuthenticationProvider(AuthenticationProvider): + """ + This authenticator performs two steps: + 1) Authenticate the user to confirm their credentials. + 2) Lookup roles the user has stored in the directory server. + + It is possible to inject any type of authenticator as well as roles populator. + + Spring Python includes two authenticators that perform standard binding or password comparisons. + You are able to code your own and use it instead, especially if you are using a non-conventional mechanism. + + Spring Python includes one role populator, based on the standard convention of defining groups elsewhere in + the directory server's hierarchy. However, you can inject your own if you have a non-convential structure, + such as storing the roles directly in the user's directory entry. + """ + + def __init__(self, ldap_authenticator=None, ldap_authorities_populator=None): + AuthenticationProvider.__init__(self) + self.ldap_authenticator = ldap_authenticator + self.ldap_authorities_populator = ldap_authorities_populator + self.logger = logging.getLogger("springpython.security.providers.Ldap.LdapAuthenticationProvider") + + def authenticate(self, authentication): + user_details, l = self.ldap_authenticator.authenticate(authentication) + from copy import deepcopy + results = deepcopy(authentication) + results.granted_auths = self.ldap_authorities_populator.get_granted_auths(user_details, l) + l.unbind() + return results + diff --git a/src/springpython/security/providers/_Ldap_jython.py b/src/springpython/security/providers/_Ldap_jython.py index 27fd74c..f6a6771 100644 --- a/src/springpython/security/providers/_Ldap_jython.py +++ b/src/springpython/security/providers/_Ldap_jython.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import logging import re import sys @@ -41,14 +42,14 @@ import org.springframework.security.providers.UsernamePasswordAuthenticationToken from jarray import array -print """ +print(""" WARNING WARNING WARNING WARNING =============================== This doesn't yet work. There is some issue with Jython. See http://bugs.jython.org/issue1489 and http://jira.springframework.org/browse/SESPRINGPYTHONPY-121 for more details. =============================== WARNING WARNING WARNING WARNING -""" +""") class DefaultSpringSecurityContextSource(object): def __init__(self, url): @@ -100,11 +101,11 @@ def __init__(self, context_source=None, group_search_base="ou=groups", group_sea self._populator.setGroupRoleAttribute(self.group_role_attr) self._populator.setRolePrefix(self.role_prefix) self._populator.setConvertToUpperCase(self.convert_to_upper) - print "LdapAuthoritiesPopulator class loader %s" % self._populator.getClass().getClassLoader() + print("LdapAuthoritiesPopulator class loader %s" % self._populator.getClass().getClassLoader()) def get_granted_auths(self, user_details, username): results = self._populator.getGrantedAuthorities(user_details, username) - print results + print(results) return results class LdapAuthenticationProvider(AuthenticationProvider): @@ -116,7 +117,7 @@ def __init__(self, ldap_authenticator=None, ldap_authorities_populator=None): def authenticate(self, authentication): user_details = self.ldap_authenticator.authenticate(authentication) - print "Context class loader %s" % user_details.getClass().getClassLoader() + print("Context class loader %s" % user_details.getClass().getClassLoader()) from copy import deepcopy results = deepcopy(authentication) results.granted_auths = self.ldap_authorities_populator.get_granted_auths(user_details, authentication.username) diff --git a/src/springpython/security/providers/_Ldap_jython.py.bak b/src/springpython/security/providers/_Ldap_jython.py.bak new file mode 100644 index 0000000..27fd74c --- /dev/null +++ b/src/springpython/security/providers/_Ldap_jython.py.bak @@ -0,0 +1,126 @@ +""" + Copyright 2006-2009 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import re +import sys +from springpython.security import AuthenticationException +from springpython.security import AuthenticationServiceException +from springpython.security import BadCredentialsException +from springpython.security import DisabledException +from springpython.security import UsernameNotFoundException +from springpython.security.providers import AuthenticationProvider +from springpython.security.providers import UsernamePasswordAuthenticationToken +from springpython.security.providers.dao import AbstractUserDetailsAuthenticationProvider +from springpython.security.providers.encoding import LdapShaPasswordEncoder + + +""" +The ldap library only works with Jython. You should NOT import this library directly. + +Due to the lack of a pure Python library, this version uses Spring Security/Spring LDAP jar files to perform +authentication and LDAP lookups. +""" +import java +import org.springframework.security.ldap.DefaultSpringSecurityContextSource +import org.springframework.security.ldap.populator.DefaultLdapAuthoritiesPopulator +import org.springframework.security.providers.ldap.authenticator.BindAuthenticator +import org.springframework.security.providers.ldap.authenticator.PasswordComparisonAuthenticator +import org.springframework.security.providers.UsernamePasswordAuthenticationToken +from jarray import array + +print """ +WARNING WARNING WARNING WARNING +=============================== +This doesn't yet work. There is some issue with Jython. +See http://bugs.jython.org/issue1489 and http://jira.springframework.org/browse/SESPRINGPYTHONPY-121 for more details. +=============================== +WARNING WARNING WARNING WARNING +""" + +class DefaultSpringSecurityContextSource(object): + def __init__(self, url): + self._context = org.springframework.security.ldap.DefaultSpringSecurityContextSource(url) + java.lang.Thread.currentThread().setContextClassLoader(self._context.getClass().getClassLoader()) + self._context.afterPropertiesSet() + +class BindAuthenticator(object): + def __init__(self, context_source=None, user_dn_patterns="uid={0},ou=people"): + self.context_source = context_source + self.user_dn_patterns = user_dn_patterns + self.logger = logging.getLogger("springpython.security.providers.Ldap.BindAuthenticator") + self._authenticator = None + + def authenticate(self, authentication): + if self._authenticator is None: + self._authenticator = org.springframework.security.providers.ldap.authenticator.BindAuthenticator(self.context_source._context) + self._authenticator.setUserDnPatterns(array([self.user_dn_patterns], java.lang.String)) + self._authenticator.afterPropertiesSet() + #java.lang.Thread.currentThread().setContextClassLoader(self._authenticator.getClass().getClassLoader()) + #print "BindAuthenticator class loader %s" % self._authenticator.getClass().getClassLoader() + token = org.springframework.security.providers.UsernamePasswordAuthenticationToken(authentication.username, authentication.password) + return self._authenticator.authenticate(token) + +class PasswordComparisonAuthenticator(object): + def __init__(self, context_source=None, user_dn_patterns="uid={0},ou=people", password_attr_name="userPassword"): + self.context_source = context_source + self.user_dn_patterns = user_dn_patterns + self.password_attr_name = password_attr_name + self.encoder = LdapShaPasswordEncoder() + self.logger = logging.getLogger("springpython.security.providers.Ldap.PasswordComparisonAuthenticator") + + def authenticate(self, authentication): + if jython: + raise Exception("This code doesn't work inside Jython.") + +class DefaultLdapAuthoritiesPopulator(object): + def __init__(self, context_source=None, group_search_base="ou=groups", group_search_filter="(member={0})", group_role_attr="cn", role_prefix="ROLE_", convert_to_upper=True): + self.logger = logging.getLogger("springpython.security.providers.Ldap.DefaultLdapAuthoritiesPopulator") + self.context_source = context_source + self.group_search_base = group_search_base + self.group_search_filter = group_search_filter + self.group_role_attr = group_role_attr + self.role_prefix = role_prefix + self.convert_to_upper = convert_to_upper + self._populator = org.springframework.security.ldap.populator.DefaultLdapAuthoritiesPopulator(self.context_source._context, self.group_search_base) + #java.lang.Thread.currentThread().setContextClassLoader(self._populator.getClass().getClassLoader()) + self._populator.setGroupSearchFilter(self.group_search_filter) + self._populator.setGroupRoleAttribute(self.group_role_attr) + self._populator.setRolePrefix(self.role_prefix) + self._populator.setConvertToUpperCase(self.convert_to_upper) + print "LdapAuthoritiesPopulator class loader %s" % self._populator.getClass().getClassLoader() + + def get_granted_auths(self, user_details, username): + results = self._populator.getGrantedAuthorities(user_details, username) + print results + return results + +class LdapAuthenticationProvider(AuthenticationProvider): + def __init__(self, ldap_authenticator=None, ldap_authorities_populator=None): + AuthenticationProvider.__init__(self) + self.ldap_authenticator = ldap_authenticator + self.ldap_authorities_populator = ldap_authorities_populator + self.logger = logging.getLogger("springpython.security.providers.Ldap.LdapAuthenticationProvider") + + def authenticate(self, authentication): + user_details = self.ldap_authenticator.authenticate(authentication) + print "Context class loader %s" % user_details.getClass().getClassLoader() + from copy import deepcopy + results = deepcopy(authentication) + results.granted_auths = self.ldap_authorities_populator.get_granted_auths(user_details, authentication.username) + results.setAuthenticated(True) + l.unbind() + return results + diff --git a/src/springpython/security/providers/__init__.py b/src/springpython/security/providers/__init__.py index 83ea345..ee090cb 100644 --- a/src/springpython/security/providers/__init__.py +++ b/src/springpython/security/providers/__init__.py @@ -102,9 +102,9 @@ def authenticate(self, authentication): if results: results.setAuthenticated(True) return results - except DisabledException, e: # Disabled means account found, but invalid + except DisabledException as e: # Disabled means account found, but invalid raise e - except AuthenticationException, e: + except AuthenticationException as e: authenticationException = e raise authenticationException diff --git a/src/springpython/security/providers/__init__.py.bak b/src/springpython/security/providers/__init__.py.bak new file mode 100644 index 0000000..83ea345 --- /dev/null +++ b/src/springpython/security/providers/__init__.py.bak @@ -0,0 +1,127 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +from springpython.security import AuthenticationException +from springpython.security import BadCredentialsException +from springpython.security import DisabledException + +class ProviderNotFoundException(AuthenticationException): + """ + An exception thrown when a list of providers are polled for a security decision, + and none of them supports the request. + """ + pass + +class Authentication: + """ + Abstract representation of credential data. The premise is that username and password + are populated, and after authentication this record is returned with the third attribute, + granted authorities, populated. + """ + + def __init__(self): + self.__authenticated = False + + def isAuthenticated(self): + return self.__authenticated + + def setAuthenticated(self, authenticated): + self.__authenticated = authenticated + + def getCredentials(self): + raise NotImplementedError() + + def __str__(self): + raise AuthenticationException("You should be using a concrete authentication object") + +class UsernamePasswordAuthenticationToken(Authentication): + """ + A basic concrete version of authentication. Works for most scenarios. + """ + + def __init__(self, username = None, password = None, granted_auths = None): + Authentication.__init__(self) + self.username = username + self.password = password + if granted_auths is None: + self.granted_auths = [] + else: + self.granted_auths = granted_auths + + def getCredentials(self): + return self.password + + def __str__(self): + return "[UsernamePasswordAuthenticationToken] User: [%s] Password: [PROTECTED] GrantedAuthorities: %s Authenticated: %s" % \ + (self.username, self.granted_auths, self.isAuthenticated()) + +class AuthenticationManager: + """ + Iterates an Authentication request through a list of AuthenticationProviders. + + AuthenticationProviders are tried in order until one provides a non-null response. + A non-null response indicates the provider had authority to decide on the authentication + request and no further providers are tried. If an AuthenticationException is thrown by + a provider, it is retained until subsequent providers are tried. If a subsequent provider + successfully authenticates the request, the earlier authentication exception is disregarded + and the successful authentication will be used. If no subsequent provider provides a + non-null response, or a new AuthenticationException, the last AuthenticationException + received will be used. If no provider returns a non-null response, or indicates it can + even process an Authentication, the AuthenticationManager will throw a ProviderNotFoundException. + """ + + def __init__(self, auth_providers = None): + if auth_providers is None: + self.auth_providers = [] + else: + self.auth_providers = auth_providers + self.logger = logging.getLogger("springpython.security.providers.AuthenticationManager") + + def authenticate(self, authentication): + """ + Attempts to authenticate the passed Authentication object, returning a fully + populated Authentication object (including granted authorities) if successful. + """ + authenticationException = ProviderNotFoundException() + for auth_provider in self.auth_providers: + try: + results = auth_provider.authenticate(authentication) + if results: + results.setAuthenticated(True) + return results + except DisabledException, e: # Disabled means account found, but invalid + raise e + except AuthenticationException, e: + authenticationException = e + raise authenticationException + +class AuthenticationProvider(object): + """ + Indicates a class can process a specific Authentication implementation. + """ + + def authenticate(self, authentication): + """ + Performs authentication with the same contract as AuthenticationManager.authenticate(Authentication). + """ + raise NotImplementedError() + + def supports(self, authentication): + """ + Returns true if this AuthenticationProvider supports the indicated Authentication object. + """ + raise NotImplementedError() + diff --git a/src/springpython/security/providers/dao.py b/src/springpython/security/providers/dao.py index f9374d2..53256be 100644 --- a/src/springpython/security/providers/dao.py +++ b/src/springpython/security/providers/dao.py @@ -64,7 +64,7 @@ def authenticate(self, authentication): try: user = self.retrieve_user(username, authentication) - except UsernameNotFoundException, notFound: + except UsernameNotFoundException as notFound: if self.hide_user_not_found_exceptions: raise BadCredentialsException("UsernameNotFound: Bad credentials") else: @@ -86,7 +86,7 @@ def authenticate(self, authentication): # about account status unless they presented the correct credentials try: self.additional_auth_checks(user, authentication) - except AuthenticationException, exception: + except AuthenticationException as exception: if cache_was_used: # There was a problem, so try again after checking we're using latest data (ie not from the cache) cache_was_used = False @@ -137,7 +137,7 @@ def retrieve_user(self, username, authentication): try: loaded_user = self.user_details_service.load_user(username) - except DataAccessException, repositoryProblem: + except DataAccessException as repositoryProblem: raise AuthenticationServiceException(repositoryProblem) if loaded_user is None: @@ -193,7 +193,7 @@ def get_salt(self, user): try: reflectionMethod = getattr(user, self.user_prop_to_use) return reflectionMethod() - except Exception, e: + except Exception as e: raise AuthenticationServiceException(e); diff --git a/src/springpython/security/providers/dao.py.bak b/src/springpython/security/providers/dao.py.bak new file mode 100644 index 0000000..f9374d2 --- /dev/null +++ b/src/springpython/security/providers/dao.py.bak @@ -0,0 +1,199 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +from springpython.database import DataAccessException +from springpython.security import AuthenticationException +from springpython.security import AuthenticationServiceException +from springpython.security import BadCredentialsException +from springpython.security import DisabledException +from springpython.security import UsernameNotFoundException +from springpython.security.providers import AuthenticationProvider +from springpython.security.providers import UsernamePasswordAuthenticationToken +from springpython.security.providers.encoding import PlaintextPasswordEncoder + +class UserCache(object): + def get_user(self, username): + raise NotImplementedError() + + def put_user(self, user): + raise NotImplementedError() + + def remove_user(self, username): + raise NotImplementedError() + +class NullUserCache(UserCache): + def get_user(self, username): + return None + + def put_user(self, user): + pass + + def remove_user(self, username): + pass + +class AbstractUserDetailsAuthenticationProvider(AuthenticationProvider): + def __init__(self): + super(AbstractUserDetailsAuthenticationProvider, self).__init__() + self.user_cache = NullUserCache() + self.hide_user_not_found_exceptions = True + self.force_principal_as_str = True + self.logger = logging.getLogger("springpython.security.providers.AbstractUserDetailsAuthenticationProvider") + + def authenticate(self, authentication): + # Determine username + username = authentication.username + + cache_was_used = True + user = self.user_cache.get_user(username) + + if user is None: + cache_was_used = False + + try: + user = self.retrieve_user(username, authentication) + except UsernameNotFoundException, notFound: + if self.hide_user_not_found_exceptions: + raise BadCredentialsException("UsernameNotFound: Bad credentials") + else: + raise notFound + + if user is None: + raise Exception("retrieve_user returned null - a violation of the interface contract") + + if not user.accountNonLocked: + raise LockedException("User account is locked") + + if not user.enabled: + raise DisabledException("User is disabled") + + if not user.accountNonExpired: + raise AccountExpiredException("User account has expired") + + # This check must come here, as we don't want to tell users + # about account status unless they presented the correct credentials + try: + self.additional_auth_checks(user, authentication) + except AuthenticationException, exception: + if cache_was_used: + # There was a problem, so try again after checking we're using latest data (ie not from the cache) + cache_was_used = False + user = self.retrieve_user(username, authentication) + self.additional_auth_checks(user, authentication) + else: + raise exception + + if not user.credentialsNonExpired: + raise CredentialsExpiredException("User credentials have expired") + + if not cache_was_used: + self.user_cache.put_user(user) + + principal_to_return = user + + if self.force_principal_as_str: + principal_to_return = user.username + + return self.create_success_auth(principal_to_return, authentication, user) + + def additional_auth_checks(self, user_details, authentication): + raise NotImplementedError() + + def retrieve_user(self, username, authentication): + raise NotImplementedError() + + def create_success_auth(self, principal, authentication, user): + # Ensure we return the original credentials the user supplied, + # so subsequent attempts are successful even with encoded passwords. + # Also ensure we return the original getDetails(), so that future + # authentication events after cache expiry contain the details + result = UsernamePasswordAuthenticationToken(principal, authentication.getCredentials(), user.authorities) + #result.details = authentication.details + return result + +class DaoAuthenticationProvider(AbstractUserDetailsAuthenticationProvider): + def __init__(self, user_details_service = None, password_encoder = PlaintextPasswordEncoder()): + super(DaoAuthenticationProvider, self).__init__() + self.password_encoder = password_encoder + self.salt_source = None + self.user_details_service = user_details_service + self.include_details_obj = True + self.logger = logging.getLogger("springpython.security.providers.DaoAuthenticationProvider") + + def retrieve_user(self, username, authentication): + loaded_user = None + + try: + loaded_user = self.user_details_service.load_user(username) + except DataAccessException, repositoryProblem: + raise AuthenticationServiceException(repositoryProblem) + + if loaded_user is None: + raise AuthenticationServiceException("UserDetailsService returned null, which is an interface contract violation") + + return loaded_user + + def additional_auth_checks(self, user_details, authentication): + salt = None + + if self.salt_source is not None: + salt = self.salt_source.get_salt(user_details) + + if not self.password_encoder.isPasswordValid(user_details.password, authentication.getCredentials(), salt): + raise BadCredentialsException("additional_auth_checks: Bad credentials") + +class SaltSource(object): + """Provides alternative sources of the salt to use for encoding passwords.""" + + def get_salt(self, user): + """Returns the salt to use for the indicated user.""" + raise NotImplementedError() + +class SystemWideSaltSource(SaltSource): + """ + Uses a static system-wide String as the salt. + + Does not supply a different salt for each User. This means users sharing the same + password will still have the same digested password. Of benefit is the digested passwords will at least be more protected than if stored without any salt. + """ + + def __init__(self, system_wide_salt = ""): + super(SystemWideSaltSource, self).__init__() + self.system_wide_salt = system_wide_salt + + def get_salt(self, user): + return self.system_wide_salt + +class ReflectionSaltSource(SaltSource): + """ + Obtains a salt from a specified property of the User object. + + This allows you to subclass User and provide an additional bean getter for a salt. + You should use a synthetic value that does not change, such as a database primary key. + Do not use username if it is likely to change. + """ + + def __init__(self, user_prop_to_use = ""): + super(ReflectionSaltSource, self).__init__() + self.user_prop_to_use = user_prop_to_use + + def get_salt(self, user): + try: + reflectionMethod = getattr(user, self.user_prop_to_use) + return reflectionMethod() + except Exception, e: + raise AuthenticationServiceException(e); + + diff --git a/src/springpython/security/web.py b/src/springpython/security/web.py index e51faf8..160e5cd 100644 --- a/src/springpython/security/web.py +++ b/src/springpython/security/web.py @@ -36,7 +36,7 @@ class Filter(object): def doNextFilter(self, environ, start_response): results = None try: - nextFilter = environ["SPRINGPYTHON_FILTER_CHAIN"].next() + nextFilter = next(environ["SPRINGPYTHON_FILTER_CHAIN"]) if isinstance(nextFilter, tuple): func = nextFilter[0] args = nextFilter[1] @@ -101,7 +101,7 @@ def __call__(self, environ, start_response): for filter in chainOfFilters: try: filterChain.addFilter(self.app_context.get_object(filter)) - except AttributeError, e: + except AttributeError as e: filterChain.addFilter(filter) break @@ -239,7 +239,7 @@ def __call__(self, environ, start_response): self.logger.debug("Trying to authenticate %s using the authentication manager" % token) SecurityContextHolder.getContext().authentication = self.auth_manager.authenticate(token) self.logger.debug("%s was successfully authenticated, access GRANTED." % token.username) - except AuthenticationException, e: + except AuthenticationException as e: self.logger.debug("Authentication failure, access DENIED.") raise @@ -356,10 +356,10 @@ def __init__(self, authenticationEntryPoint=None, accessDeniedHandler=None, redi def __call__(self, environ, start_response): try: return self.doNextFilter(environ, start_response) - except AuthenticationException, e: + except AuthenticationException as e: self.logger.debug("AuthenticationException => %s, redirecting through authenticationEntryPoint" % e) return self.authenticationEntryPoint(environ, start_response) - except AccessDeniedException, e: + except AccessDeniedException as e: self.logger.debug("AccessDeniedException => %s, redirect through accessDeniedHandler" % e) return self.accessDeniedHandler(environ, start_response) @@ -415,5 +415,5 @@ def __setattr__(self, name, value): self.__dict__[name] = value def __call__(self, environ, start_response): - setattr(self.middleware, self.appAttribute, environ["SPRINGPYTHON_FILTER_CHAIN"].next()) + setattr(self.middleware, self.appAttribute, next(environ["SPRINGPYTHON_FILTER_CHAIN"])) return self.middleware(environ, start_response) diff --git a/src/springpython/security/web.py.bak b/src/springpython/security/web.py.bak new file mode 100644 index 0000000..e51faf8 --- /dev/null +++ b/src/springpython/security/web.py.bak @@ -0,0 +1,419 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import Cookie +import logging +import re +import pickle +import types +from springpython.context import ApplicationContextAware +from springpython.aop import utils +from springpython.security import AccessDeniedException +from springpython.security import AuthenticationException +from springpython.security.context import SecurityContext +from springpython.security.context import SecurityContextHolder +from springpython.security.intercept import AbstractSecurityInterceptor +from springpython.security.intercept import ObjectDefinitionSource +from springpython.security.providers import UsernamePasswordAuthenticationToken + +logger = logging.getLogger("springpython.security.web") + +class Filter(object): + """This is the interface definition of a filter. It must process a request/response.""" + + def doNextFilter(self, environ, start_response): + results = None + try: + nextFilter = environ["SPRINGPYTHON_FILTER_CHAIN"].next() + if isinstance(nextFilter, tuple): + func = nextFilter[0] + args = nextFilter[1] + results = func(args) + else: + results = nextFilter(environ, start_response) + except StopIteration: + pass + + # Apparently, passing back a generator trips up CherryPy and causes it to skip + # the filters. If a generator is detected, convert it to a standard array. + if type(results) == types.GeneratorType: + results = [line for line in results] + + return results + +class FilterChain(object): + """ + Collection of WSGI filters. It allows dynamic re-chaining of filters as the situation is needed. + + In order to link in 3rd party WSGI middleware, see MiddlewareFilter. + """ + + def __init__(self): + self.chain = [] + + def addFilter(self, filter): + self.chain.append(filter) + + def getFilterChain(self): + for filter in self.chain: + yield filter + +class FilterChainProxy(Filter, ApplicationContextAware): + """ + This acts as filter, and delegates to a chain of filters. Each time a web page is called, it dynamically + assembles a FilterChain, and then iterates over it. This is different than the conventional style of + wrapping applications for WSGI, because each URL pattern might have a different chained combination + of the WSGI filters. + + Because most middleware objects define the wrapped application using __init__, Spring provides + the MiddlewareFilter, to help wrap any middleware object so that it can participate in a + FilterChain. + """ + + def __init__(self, filterInvocationDefinitionSource=None): + """This class must be application-context aware in case it is instantiated inside an IoC container.""" + ApplicationContextAware.__init__(self) + if filterInvocationDefinitionSource is None: + self.filterInvocationDefinitionSource = [] + else: + self.filterInvocationDefinitionSource = filterInvocationDefinitionSource + self.logger = logging.getLogger("springpython.security.web.FilterChainProxy") + self.application = None + + def __call__(self, environ, start_response): + """This will route all requests/responses through the chain of filters.""" + filterChain = FilterChain() + for urlPattern, chainOfFilters in self.filterInvocationDefinitionSource: + if re.compile(urlPattern).match(environ["PATH_INFO"].lower()): + self.logger.debug("We had a match of %s against %s" % (environ["PATH_INFO"], urlPattern)) + for filter in chainOfFilters: + try: + filterChain.addFilter(self.app_context.get_object(filter)) + except AttributeError, e: + filterChain.addFilter(filter) + break + + # Put the actual application on the end of the chain. + if self.application: + filterChain.addFilter(self.application) + environ["SPRINGPYTHON_FILTER_CHAIN"] = filterChain.getFilterChain() + return self.doNextFilter(environ, start_response) + +class SessionStrategy(object): + """ + This is an interface definition in defining access to session data. There may be many + ways to implement session data. This makes the mechanism pluggable. + """ + + def getHttpSession(self, environ): + raise NotImplementedError() + + def setHttpSession(self, key, value): + raise NotImplementedError() + +class HttpSessionContextIntegrationFilter(Filter): + """ + This filter is meant to pull security context information from the HttpSession, and store it in the + SecurityContextHolder. Then on the response, copy and SecurityContext information back into the HttpSession. + """ + + # Key to the SecurityContext data stored in an HttpSession dictionary. + SPRINGPYTHON_SECURITY_CONTEXT_KEY = "SPRINGPYTHON_SECURITY_CONTEXT_KEY" + + # Class name used + context = SecurityContext + + def __init__(self, sessionStrategy=None): + self.sessionStrategy = sessionStrategy + self.logger = logging.getLogger("springpython.security.web.HttpSessionContextIntegrationFilter") + + def __call__(self, environ, start_response): + """This filter copies SecurityContext information back and forth between the HttpSession and the SecurityContextHolder.""" + + httpSession = self.sessionStrategy.getHttpSession(environ) + contextWhenChainProceeded = None + + if httpSession is not None: + + contextFromSessionObject = None + if self.SPRINGPYTHON_SECURITY_CONTEXT_KEY in httpSession: + contextFromSessionObject = pickle.loads(httpSession[self.SPRINGPYTHON_SECURITY_CONTEXT_KEY]) + + if contextFromSessionObject is not None: + if isinstance(contextFromSessionObject, SecurityContext): + self.logger.debug("Obtained from SPRINGPYTHON_SECURITY_CONTEXT_KEY a valid SecurityContext and set " + + "to SecurityContextHolder: '%s'" % contextFromSessionObject) + SecurityContextHolder.setContext(contextFromSessionObject) + else: + self.logger.warn("SPRINGPYTHON_SECURITY_CONTEXT_KEY did not contain a SecurityContext but contained: '%s'" % contextFromSessionObject + + "'; are you improperly modifying the HttpSession directly (you should always use " + + "SecurityContextHolder) or using the HttpSession attribute reserved for this class? " + + "- new SecurityContext instance associated with SecurityContextHolder") + SecurityContextHolder.setContext(self.generateNewContext()) + else: + self.logger.debug("HttpSession returned null object for SPRINGPYTHON_SECURITY_CONTEXT_KEY " + + "- new SecurityContext instance associated with SecurityContextHolder") + SecurityContextHolder.setContext(self.generateNewContext()) + + else: + self.logger.debug("No HttpSession currently exists - new SecurityContext instance associated with SecurityContextHolder") + SecurityContextHolder.setContext(self.generateNewContext()) + + self.logger.debug("Setting contextWhenChainProceeded to %s" % SecurityContextHolder.getContext()) + contextWhenChainProceeded = str(SecurityContextHolder.getContext()) + + results = self.doNextFilter(environ, start_response) + + self.sessionStrategy.setHttpSession(self.SPRINGPYTHON_SECURITY_CONTEXT_KEY, + pickle.dumps(SecurityContextHolder.getContext())) + self.logger.debug("SecurityContext stored to HttpSession: '%s'" % SecurityContextHolder.getContext()) + + SecurityContextHolder.clearContext() + self.logger.debug("SecurityContextHolder cleared out, as request processing completed") + + return results + + def setContext(self, clazz): + """This is a factory setter. The context parameter is used to create new security context objects.""" + self.context = clazz + + def generateNewContext(self): + """This is a factory method that instantiates the assigned class, and populates it with an empty token.""" + context = self.context() + context.authentication = UsernamePasswordAuthenticationToken() + return context + + def saveContext(self): + self.sessionStrategy.setHttpSession(self.SPRINGPYTHON_SECURITY_CONTEXT_KEY, + pickle.dumps(SecurityContextHolder.getContext())) + +class RedirectStrategy(object): + """ + This class provides a mechanism to redirect users to another page. Currently, it returns a + standard forwarding message to the browser. This may not be the most efficient, but it guarantees + the entire WSGI stack is processed on both request and response. + """ + + def redirect(self, url): + """This is a 0-second redirect.""" + return """""" % url + +class AuthenticationProcessingFilter(Filter): + """ + This filter utilizes the authentication manager to make sure the requesting person is authenticated. + It expects the SecurityContextHolder to be populated when it runs, so it is always good to preceed it + with the HttpSessionContextIntegrationFilter. + """ + + def __init__(self, auth_manager=None, alwaysReauthenticate=False): + self.auth_manager = auth_manager + self.alwaysReauthenticate = alwaysReauthenticate + self.logger = logging.getLogger("springpython.security.web.AuthenticationProcessingFilter") + + def __call__(self, environ, start_response): + """ + Check if the user is trying to access the login url. Then see if they are already authenticated (and + alwaysReauthenticate is disabled). Finally, try to authenticate the user. If successful, stored credentials + in SecurityContextHolder. Otherwise, redirect to the login page. + """ + # If the user is already authenticated, skip this filter. + if not self.alwaysReauthenticate and SecurityContextHolder.getContext().authentication.isAuthenticated(): + self.logger.debug("You are not required to reauthenticate everytime, and appear to already be authenticted, access GRANTED.") + return self.doNextFilter(environ, start_response) + + try: + # Authenticate existing credentials using the authentication manager. + token = SecurityContextHolder.getContext().authentication + self.logger.debug("Trying to authenticate %s using the authentication manager" % token) + SecurityContextHolder.getContext().authentication = self.auth_manager.authenticate(token) + self.logger.debug("%s was successfully authenticated, access GRANTED." % token.username) + except AuthenticationException, e: + self.logger.debug("Authentication failure, access DENIED.") + raise + + return self.doNextFilter(environ, start_response) + + def logout(self): + SecurityContextHolder.getContext().authentication = UsernamePasswordAuthenticationToken() + +class FilterInvocation: + """Holds objects associated with a WSGI filter, such as environ. This is the web-application equivalent to MethodInvocation.""" + + def __init__(self, environ): + self.environ = environ + + def requestUrl(self): + return self.environ["PATH_INFO"] + +class AbstractFilterInvocationDefinitionSource(ObjectDefinitionSource): + """Abstract implementation of ObjectDefinitionSource.""" + + def get_attributes(self, obj): + try: + return self.lookupAttributes(obj.requestUrl()) + except AttributeError: + raise TypeError("obj must be a FilterInvocation") + + def lookupAttributes(self, url): + raise NotImplementedError() + +class RegExpBasedFilterInvocationDefinitionMap(AbstractFilterInvocationDefinitionSource): + """ + Maintains a list of ObjectDefinitionSource's associated with different HTTP request URL regular expression patterns. + + Regular expressions are used to match a HTTP request URL against a ConfigAttributeDefinition. The order of registering + the regular expressions is very important. The system will identify the first matching regular expression for a given + HTTP URL. It will not proceed to evaluate later regular expressions if a match has already been found. + + Accordingly, the most specific regular expressions should be registered first, with the most general regular expressions registered last. + """ + + def __init__(self, obj_def_source): + self.obj_def_source = obj_def_source + + def lookupAttributes(self, url): + if self.obj_def_source: + for rule, attr in self.obj_def_source: + if re.compile(rule).match(url): + return attr + return None + +class FilterSecurityInterceptor(Filter, AbstractSecurityInterceptor): + """ + Performs security handling of HTTP resources via a filter implementation. + + The ObjectDefinitionSource required by this security interceptor is of type AbstractFilterInvocationDefinitionSource. + + Refer to AbstractSecurityInterceptor for details on the workflow. + """ + + # Key to the FilterSecurityInterceptor's token data stored in an HttpSession dictionary. + SPRINGPYTHON_FILTER_SECURITY_INTERCEPTOR_KEY = "SPRINGPYTHON_FILTER_SECURITY_INTERCEPTOR_KEY" + + def __init__(self, auth_manager = None, access_decision_mgr = None, obj_def_source = None, sessionStrategy=None): + Filter.__init__(self) + AbstractSecurityInterceptor.__init__(self, auth_manager, access_decision_mgr, obj_def_source) + self.sessionStrategy = sessionStrategy + self.obj_def_source = obj_def_source + + def __setattr__(self, name, value): + if name == "obj_def_source" and value is not None: + self.__dict__[name] = RegExpBasedFilterInvocationDefinitionMap(value) + else: + self.__dict__[name] = value + + def obtain_obj_def_source(self): + return self.obj_def_source + + def __call__(self, environ, start_response): + httpSession = self.sessionStrategy.getHttpSession(environ) + self.logger.debug("Trying to check if you are authorized for this.") + fi = FilterInvocation(environ) + token = self.before_invocation(fi) + if httpSession is not None: + httpSession[self.SPRINGPYTHON_FILTER_SECURITY_INTERCEPTOR_KEY] = token + + return self.doNextFilter(environ, start_response) + + if httpSession is not None and self.SPRINGPYTHON_FILTER_SECURITY_INTERCEPTOR_KEY in httpSession: + token = httpSession[self.SPRINGPYTHON_FILTER_SECURITY_INTERCEPTOR_KEY] + self.after_invocation(token, None) + + return results + +class ExceptionTranslationFilter(Filter): + """ + Handles any AccessDeniedException and AuthenticationException thrown within the filter chain. + + This filter is necessary because it provides the bridge between Python exceptions and HTTP responses. + It is solely concerned with maintaining the user interface. This filter does not do any actual security enforcement. + + If an AuthenticationException is detected, the filter will launch the authenticationEntryPoint. This allows common + handling of authentication failures originating from any subclass of AuthenticationProcessingFilter. + + If an AccessDeniedException is detected, the filter will launch the accessDeniedHandler. This allows common + handling of access failures originating from any subclass of AbstractSecurityInterceptor. + """ + + def __init__(self, authenticationEntryPoint=None, accessDeniedHandler=None, redirectStrategy=None): + Filter.__init__(self) + self.authenticationEntryPoint = authenticationEntryPoint + self.accessDeniedHandler = accessDeniedHandler + self.logger = logging.getLogger("springpython.security.web.ExceptionTranslationFilter") + + def __call__(self, environ, start_response): + try: + return self.doNextFilter(environ, start_response) + except AuthenticationException, e: + self.logger.debug("AuthenticationException => %s, redirecting through authenticationEntryPoint" % e) + return self.authenticationEntryPoint(environ, start_response) + except AccessDeniedException, e: + self.logger.debug("AccessDeniedException => %s, redirect through accessDeniedHandler" % e) + return self.accessDeniedHandler(environ, start_response) + +class AuthenticationProcessingFilterEntryPoint(Filter): + """This object holds the location of the login form, and is used to commence a redirect to that form.""" + + def __init__(self, loginFormUrl=None, redirectStrategy=None): + Filter.__init__(self) + self.loginFormUrl = loginFormUrl + self.redirectStrategy = redirectStrategy + self.logger = logging.getLogger("springpython.security.web.AuthenticationProcessingFilterEntryPoint") + + def __call__(self, environ, start_response): + self.logger.debug("Redirecting to login page %s" % self.loginFormUrl) + return self.redirectStrategy.redirect(self.loginFormUrl) + +class AccessDeniedHandler(Filter): + """Used by ExceptionTranslationFilter to handle an AccessDeniedException.""" + + def __init__(self): + Filter.__init__(self) + +class SimpleAccessDeniedHandler(AccessDeniedHandler): + """A simple default implementation of the AccessDeniedHandler interface.""" + + def __init__(self, errorPage=None, redirectStrategy=None): + AccessDeniedHandler.__init__(self) + self.errorPage = errorPage + self.redirectStrategy = redirectStrategy + self.logger = logging.getLogger("springpython.security.web.SimpleAccessDeniedHandler") + + def __call__(self, environ, start_response): + self.logger.debug("Redirecting to error page %s" % self.errorPage) + return self.redirectStrategy.redirect(self.errorPage) + +class MiddlewareFilter(Filter): + """ + This filter allows you to wrap any WSGI-compatible middleware and use it as a Spring Python filter. + This is primary because lots of middleware objects requires the wrapped WSGI app to be included + in the __init__ method. Spring's IoC container currently doesn't support constructor arguments. + """ + + def __init__(self, clazz = None, appAttribute = None): + Filter.__init__(self) + self.clazz = clazz + self.appAttribute = appAttribute + + def __setattr__(self, name, value): + if name == "clazz" and value is not None: + self.__dict__[name] = value + self.middleware = utils.getClass(value)(None) + else: + self.__dict__[name] = value + + def __call__(self, environ, start_response): + setattr(self.middleware, self.appAttribute, environ["SPRINGPYTHON_FILTER_CHAIN"].next()) + return self.middleware(environ, start_response) diff --git a/src/springpython/util.py b/src/springpython/util.py index 5166172..571a222 100644 --- a/src/springpython/util.py +++ b/src/springpython/util.py @@ -19,7 +19,7 @@ try: from cStringIO import StringIO -except ImportError, e: +except ImportError as e: from StringIO import StringIO @@ -43,7 +43,7 @@ def lockedfunc(*args, **kwargs): self.logger.log(TRACE1, "Acquired lock [%s] thread [%s]" % (self.lock, currentThread())) try: return f(*args, **kwargs) - except Exception, e: + except Exception as e: raise finally: self.lock.release() diff --git a/src/springpython/util.py.bak b/src/springpython/util.py.bak new file mode 100644 index 0000000..5166172 --- /dev/null +++ b/src/springpython/util.py.bak @@ -0,0 +1,51 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import traceback +from threading import RLock, currentThread + +try: + from cStringIO import StringIO +except ImportError, e: + from StringIO import StringIO + + +TRACE1 = 6 +logging.addLevelName(TRACE1, "TRACE1") + +# Original code by Anand Balachandran Pillai (abpillai at gmail.com) +# http://code.activestate.com/recipes/533135/ +class synchronized(object): + """ Class enapsulating a lock and a function allowing it to be used as + a synchronizing decorator making the wrapped function thread-safe """ + + def __init__(self, *args): + self.lock = RLock() + self.logger = logging.getLogger("springpython.util.synchronized") + + def __call__(self, f): + def lockedfunc(*args, **kwargs): + try: + self.lock.acquire() + self.logger.log(TRACE1, "Acquired lock [%s] thread [%s]" % (self.lock, currentThread())) + try: + return f(*args, **kwargs) + except Exception, e: + raise + finally: + self.lock.release() + self.logger.log(TRACE1, "Released lock [%s] thread [%s]" % (self.lock, currentThread())) + return lockedfunc diff --git a/test/springpythontest/allTests.py b/test/springpythontest/allTests.py index 0f8c0c0..5b2f9cc 100644 --- a/test/springpythontest/allTests.py +++ b/test/springpythontest/allTests.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import logging import unittest, os import springpython @@ -20,10 +21,10 @@ if __name__ == "__main__": module_name = sys.argv[1] - print "Trying to import module %s" % module_name + print("Trying to import module %s" % module_name) mod = __import__("%s" % module_name) - print mod + print(mod) logger = logging.getLogger("springpython") loggingLevel = logging.INFO diff --git a/test/springpythontest/allTests.py.bak b/test/springpythontest/allTests.py.bak new file mode 100644 index 0000000..0f8c0c0 --- /dev/null +++ b/test/springpythontest/allTests.py.bak @@ -0,0 +1,39 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import unittest, os +import springpython +import sys + +if __name__ == "__main__": + module_name = sys.argv[1] + print "Trying to import module %s" % module_name + mod = __import__("%s" % module_name) + + print mod + + logger = logging.getLogger("springpython") + loggingLevel = logging.INFO + logger.setLevel(loggingLevel) + ch = logging.StreamHandler() + ch.setLevel(loggingLevel) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + suite = unittest.TestSuite() + suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(mod)) + unittest.TextTestRunner(verbosity=3).run(suite) diff --git a/test/springpythontest/contextTestCases.py b/test/springpythontest/contextTestCases.py index 22e87af..c58ab79 100644 --- a/test/springpythontest/contextTestCases.py +++ b/test/springpythontest/contextTestCases.py @@ -1036,7 +1036,7 @@ def test_default_mapping_error_no_type_defined(self): # Will raise KeyError: 'class' try: ApplicationContext(YamlConfig("support/contextYamlBuiltinTypesErrorNoTypeDefined.yaml")) - except KeyError, e: + except KeyError as e: # Meaning there was no 'class' key found. self.assertEqual(e.message, "class") else: @@ -1534,7 +1534,7 @@ def bar(): return Bar() # A reference to the function wrapping the actual 'foo' function. - foo_wrapper = foo.func_globals["_call_"] + foo_wrapper = foo.__globals__["_call_"] # Create an object definition, note that we're telling to return foo_object_def = ObjectDef(id="foo", @@ -1542,7 +1542,7 @@ def bar(): lazy_init=foo_wrapper.lazy_init) # A reference to the function wrapping the actual 'bar' function. - bar_wrapper = foo.func_globals["_call_"] + bar_wrapper = foo.__globals__["_call_"] bar_object_def = ObjectDef(id="foo", factory=PythonObjectFactory(bar, bar_wrapper), scope=SINGLETON, @@ -1999,6 +1999,6 @@ def invalid(self): _globals["Object"] = Object def should_raise_invalid_object_scope(): - exec invalid in _globals, _locals + exec(invalid, _globals, _locals) self.assertRaises(InvalidObjectScope, should_raise_invalid_object_scope) diff --git a/test/springpythontest/contextTestCases.py.bak b/test/springpythontest/contextTestCases.py.bak new file mode 100644 index 0000000..22e87af --- /dev/null +++ b/test/springpythontest/contextTestCases.py.bak @@ -0,0 +1,2004 @@ +# -*- coding: utf-8 -*- + +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# pmock +from pmock import * + +import sys +import atexit +import random +import unittest +from decimal import Decimal +from StringIO import StringIO + +from springpython.context import DisposableObject +from springpython.context import ApplicationContext +from springpython.context import ObjectPostProcessor +from springpython.config import PythonConfig +from springpython.config import PyContainerConfig +from springpython.config import SpringJavaConfig +from springpython.config import Object +from springpython.config import XMLConfig, xml_mappings +from springpython.config import YamlConfig, yaml_mappings +from springpython.config import Object, ObjectDef +from springpython.factory import PythonObjectFactory +from springpython.remoting.pyro import PyroProxyFactory +from springpython.security.userdetails import InMemoryUserDetailsService +from springpythontest.support import testSupportClasses +from springpython.context.scope import SINGLETON, PROTOTYPE +from springpython.container import AbstractObjectException, InvalidObjectScope + +class PyContainerTestCase(unittest.TestCase): + def testCreatingAnApplicationContext(self): + movieAppContainer = ApplicationContext(PyContainerConfig("support/contextTestPrimaryApplicationContext.xml")) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + + def testLoadingMultipleApplicationContexts(self): + """When reading multiple sources, later object definitions can override earlier ones.""" + movieAppContainer = ApplicationContext(PyContainerConfig(["support/contextTestPrimaryApplicationContext.xml", "support/contextTestSecondaryApplicationContext.xml"])) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "Sta") + + def testCreatingXmlBasedIocContainerUsingDirectFunctionCalls(self): + movieAppContainer = ApplicationContext(PyContainerConfig("support/contextSingletonPrototypeObjectContext.xml")) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext(PyContainerConfig("support/contextSingletonPrototypeObjectContext.xml")) + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + +class PurePythonContainerTestCase(unittest.TestCase): + def testCreatingDecoratorBasedIocContainerUsingAppContextCalls(self): + movieAppContainer = ApplicationContext(testSupportClasses.MovieBasedApplicationContext()) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + self.assertFalse(movieAppContainer.object_defs[u"MovieLister"].lazy_init) + self.assertTrue(movieAppContainer.object_defs[u"MovieFinder"].lazy_init) + self.assertTrue(movieAppContainer.object_defs[u"SingletonString"].lazy_init) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + def testCreatingMovieListerBeforeSingletonString(self): + movieAppContainer = ApplicationContext(testSupportClasses.MovieBasedApplicationContext()) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + singleString = movieAppContainer.get_object("SingletonString") + + # Identity test + self.assertEquals(lister.description, singleString) + + def testCreatingSingletonStringBeforeMovieLister(self): + movieAppContainer = ApplicationContext(testSupportClasses.MovieBasedApplicationContext()) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + singleString = movieAppContainer.get_object("SingletonString") + lister = movieAppContainer.get_object("MovieLister") + + # Identity test +# self.assertEquals(lister.description, singleString) + + def testCreatingDecoratorBasedIocContainerUsingDirectFunctionCalls(self): + movieAppContainer = ApplicationContext(testSupportClasses.MovieBasedApplicationContext()) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext(testSupportClasses.MovieBasedApplicationContext()) + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + +class MixedConfigurationContainerTestCase(unittest.TestCase): + def testXmlPullingPurePythonObject(self): + movieAppContainer = ApplicationContext([testSupportClasses.MixedApplicationContext(), + PyContainerConfig("support/contextMixedObjectContext.xml")]) + + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext([testSupportClasses.MixedApplicationContext(), + PyContainerConfig("support/contextMixedObjectContext.xml")]) + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + + def testPurePythonPullingXmlObject(self): + movieAppContainer = ApplicationContext([testSupportClasses.MixedApplicationContext2(), + PyContainerConfig("support/contextMixedObjectContext2.xml")]) + + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext([testSupportClasses.MixedApplicationContext2(), + PyContainerConfig("support/contextMixedObjectContext2.xml")]) + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + + def testNamedConstructorArguments(self): + ctx = ApplicationContext(testSupportClasses.ConstructorBasedContainer()) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + m = ctx.get_object("MultiValueHolder") + self.assertEquals("alt a", m.a) + self.assertEquals("alt b", m.b) + self.assertEquals("c", m.c) + + m2 = ctx.get_object("MultiValueHolder2") + self.assertEquals("a", m2.a) + self.assertEquals("alt b", m2.b) + self.assertEquals("alt c", m2.c) + +class SpringJavaConfigTestCase(unittest.TestCase): + def testPullingJavaConfig(self): + movieAppContainer = ApplicationContext(SpringJavaConfig("support/contextSpringJavaAppContext.xml")) + + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext(SpringJavaConfig("support/contextSpringJavaAppContext.xml")) + + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + + def testInnerObjects(self): + movieAppContainer = ApplicationContext(SpringJavaConfig("support/contextSpringJavaAppContext.xml")) + + lister = movieAppContainer.get_object("MovieLister2") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + lister2 = movieAppContainer.get_object("MovieLister3") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + self.assertNotEqual(lister, lister2) + + def testPrefetchingObjects(self): + movieAppContainer = ApplicationContext(SpringJavaConfig("support/contextSpringJavaAppContext.xml")) + + self.assertEqual(len(movieAppContainer.object_defs), 10) + self.assertTrue("MovieLister" in movieAppContainer.object_defs) + self.assertTrue("MovieFinder" in movieAppContainer.object_defs) + self.assertTrue("SingletonString" in movieAppContainer.object_defs) + self.assertTrue("MovieLister2" in movieAppContainer.object_defs) + self.assertTrue("MovieLister3" in movieAppContainer.object_defs) + self.assertTrue("MovieLister2.finder." in movieAppContainer.object_defs) + self.assertTrue("MovieLister3.finder.named" in movieAppContainer.object_defs) + self.assertTrue("ValueHolder" in movieAppContainer.object_defs) + self.assertTrue("AnotherSingletonString" in movieAppContainer.object_defs) + self.assertTrue("AThirdSingletonString" in movieAppContainer.object_defs) + + def testCollections(self): + ctx = ApplicationContext(SpringJavaConfig("support/contextSpringJavaAppContext.xml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + value_holder = ctx.get_object("ValueHolder") + + self.assertTrue(isinstance(value_holder.some_dict, dict)) + self.assertEquals(4, len(value_holder.some_dict)) + + self.assertEquals("Python", value_holder.some_dict["Spring"]) + self.assertEquals("World", value_holder.some_dict["Hello"]) + self.assertTrue(isinstance(value_holder.some_dict["holder"], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_dict["holder"].str) + self.assertEquals("There should only be one copy of this string", value_holder.some_dict["another copy"].str) + + # Verify they are both referencing the same StringHolder class + self.assertEquals(value_holder.some_dict["holder"], value_holder.some_dict["another copy"]) + + self.assertTrue(isinstance(value_holder.some_list, list)) + self.assertEquals(3, len(value_holder.some_list)) + self.assertEquals("Hello, world!", value_holder.some_list[0]) + self.assertTrue(isinstance(value_holder.some_list[1], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_list[1].str) + self.assertEquals("Spring Python", value_holder.some_list[2]) + + # Verify this is also using the same singleton object + self.assertEquals(value_holder.some_dict["holder"], value_holder.some_list[1]) + + self.assertTrue(isinstance(value_holder.some_props, dict)) + self.assertEquals(3, len(value_holder.some_props)) + self.assertEquals("administrator@example.org", value_holder.some_props["administrator"]) + self.assertEquals("support@example.org", value_holder.some_props["support"]) + self.assertEquals("development@example.org", value_holder.some_props["development"]) + + self.assertTrue(isinstance(value_holder.some_set, set)) + self.assertEquals(3, len(value_holder.some_set)) + self.assertTrue("Hello, world!" in value_holder.some_set) + self.assertTrue("Spring Python" in value_holder.some_set) + + foundStringHolder = False + for item in value_holder.some_set: + if isinstance(item, testSupportClasses.StringHolder): + self.assertEquals("There should only be one copy of this string", item.str) + self.assertEquals(item, value_holder.some_list[1]) + foundStringHolder = True + self.assertTrue(foundStringHolder) + + def testConstructors(self): + ctx = ApplicationContext(SpringJavaConfig("support/contextSpringJavaAppContext.xml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + another_str = ctx.get_object("AnotherSingletonString") + a_third_str = ctx.get_object("AThirdSingletonString") + + self.assertEquals("attributed value", another_str.str) + self.assertEquals("elemental value", a_third_str.str) + + value_holder = ctx.get_object("ValueHolder") + self.assertTrue(isinstance(value_holder.string_holder, testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.string_holder.str) + + single_str = ctx.get_object("SingletonString") + + self.assertEquals(single_str.str, value_holder.string_holder.str) + self.assertEquals(single_str, value_holder.string_holder) + +class XMLConfigTestCase(unittest.TestCase): + def testPullingXMLConfig(self): + movieAppContainer = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + + def testInnerObjects(self): + movieAppContainer = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + + lister = movieAppContainer.get_object("MovieLister2") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + lister2 = movieAppContainer.get_object("MovieLister3") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + self.assertNotEqual(lister, lister2) + + def testPrefetchingObjects(self): + movieAppContainer = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + + self.assertEqual(len(movieAppContainer.object_defs), 12) + self.assertTrue("MovieLister" in movieAppContainer.object_defs) + self.assertTrue("MovieFinder" in movieAppContainer.object_defs) + self.assertTrue("SingletonString" in movieAppContainer.object_defs) + self.assertTrue("MovieLister2" in movieAppContainer.object_defs) + self.assertTrue("MovieLister3" in movieAppContainer.object_defs) + self.assertTrue("MovieLister2.finder." in movieAppContainer.object_defs) + self.assertTrue("MovieLister3.finder.named" in movieAppContainer.object_defs) + self.assertTrue("ValueHolder" in movieAppContainer.object_defs) + self.assertTrue("AnotherSingletonString" in movieAppContainer.object_defs) + self.assertTrue("AThirdSingletonString" in movieAppContainer.object_defs) + self.assertTrue("MultiValueHolder" in movieAppContainer.object_defs) + self.assertTrue("MultiValueHolder2" in movieAppContainer.object_defs) + + def testCollections(self): + ctx = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + value_holder = ctx.get_object("ValueHolder") + + self.assertTrue(isinstance(value_holder.some_dict, dict)) + self.assertEquals(4, len(value_holder.some_dict)) + + self.assertEquals("Python", value_holder.some_dict["Spring"]) + self.assertEquals("World", value_holder.some_dict["Hello"]) + self.assertTrue(isinstance(value_holder.some_dict["holder"], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_dict["holder"].str) + self.assertEquals("There should only be one copy of this string", value_holder.some_dict["another copy"].str) + + # Verify they are both referencing the same StringHolder class + self.assertEquals(value_holder.some_dict["holder"], value_holder.some_dict["another copy"]) + + self.assertTrue(isinstance(value_holder.some_list, list)) + self.assertEquals(3, len(value_holder.some_list)) + self.assertEquals("Hello, world!", value_holder.some_list[0]) + self.assertTrue(isinstance(value_holder.some_list[1], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_list[1].str) + self.assertEquals("Spring Python", value_holder.some_list[2]) + + # Verify this is also using the same singleton object + self.assertEquals(value_holder.some_dict["holder"], value_holder.some_list[1]) + + self.assertTrue(isinstance(value_holder.some_props, dict)) + self.assertEquals(3, len(value_holder.some_props)) + self.assertEquals("administrator@example.org", value_holder.some_props["administrator"]) + self.assertEquals("support@example.org", value_holder.some_props["support"]) + self.assertEquals("development@example.org", value_holder.some_props["development"]) + + self.assertTrue(isinstance(value_holder.some_set, set)) + self.assertEquals(3, len(value_holder.some_set)) + self.assertTrue("Hello, world!" in value_holder.some_set) + self.assertTrue("Spring Python" in value_holder.some_set) + + self.assertTrue(isinstance(value_holder.some_frozen_set, frozenset)) + self.assertEquals(3, len(value_holder.some_frozen_set)) + self.assertTrue("Hello, world!" in value_holder.some_frozen_set) + self.assertTrue("Spring Python" in value_holder.some_frozen_set) + + self.assertTrue(isinstance(value_holder.some_tuple, tuple)) + self.assertEquals(3, len(value_holder.some_tuple)) + self.assertEquals("Hello, world!", value_holder.some_tuple[0]) + self.assertTrue(isinstance(value_holder.some_tuple[1], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_tuple[1].str) + self.assertEquals("Spring Python", value_holder.some_tuple[2]) + + foundStringHolder = False + for item in value_holder.some_set: + if isinstance(item, testSupportClasses.StringHolder): + self.assertEquals("There should only be one copy of this string", item.str) + self.assertEquals(item, value_holder.some_list[1]) + foundStringHolder = True + self.assertTrue(foundStringHolder) + + def testConstructors(self): + ctx = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + self.assertTrue(ctx.object_defs[u"SingletonString"].lazy_init) + + another_str = ctx.get_object("AnotherSingletonString") + a_third_str = ctx.get_object("AThirdSingletonString") + + self.assertEquals("attributed value", another_str.str) + self.assertEquals("elemental value", a_third_str.str) + + value_holder = ctx.get_object("ValueHolder") + self.assertTrue(isinstance(value_holder.string_holder, testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.string_holder.str) + + single_str = ctx.get_object("SingletonString") + + self.assertEquals(single_str.str, value_holder.string_holder.str) + self.assertEquals(single_str, value_holder.string_holder) + + def testNamedConstructorArguments(self): + ctx = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + m = ctx.get_object("MultiValueHolder") + self.assertEquals("alt a", m.a) + self.assertEquals("alt b", m.b) + self.assertEquals("c", m.c) + + m2 = ctx.get_object("MultiValueHolder2") + self.assertEquals("a", m2.a) + self.assertEquals("alt b", m2.b) + self.assertEquals("alt c", m2.c) + + def testGetComplexValueObject(self): + ctx1 = ApplicationContext(PyContainerConfig("support/contextComplexPyContainer.xml")) + ctx2 = ApplicationContext(XMLConfig("support/contextComplexXMLConfig.xml")) + + # This is what PyContainerConfig could handle + for ctx in [ctx1, ctx2]: + service = ctx.get_object("user_details_service") + self.assertEquals(8, len(service.user_dict)) + self.assertEquals(3, len(service.user_dict["basichiorangeuser"])) + self.assertEquals("ASSIGNED_ORANGE", service.user_dict["basichiorangeuser"][1][1]) + + service = ctx.get_object("user_details_service") + self.assertTrue(isinstance(service.user_dict, dict)) + self.assertEquals(8, len(service.user_dict)) + self.assertEquals(3, len(service.user_dict["basichiorangeuser"])) + self.assertEquals("ASSIGNED_ORANGE", service.user_dict["basichiorangeuser"][1][1]) + + # These are the other things that XMLConfig can handle + service2 = ctx2.get_object("user_details_service2") + self.assertTrue(isinstance(service2.user_dict, list)) + self.assertEquals(5, len(service2.user_dict)) + + self.assertEquals("Hello, world!", service2.user_dict[0]) + + self.assertTrue(isinstance(service2.user_dict[1], dict)) + self.assertEquals("This is working", service2.user_dict[1]["yes"]) + self.assertEquals("Maybe it's not?", service2.user_dict[1]["no"]) + + self.assertTrue(isinstance(service2.user_dict[2], tuple)) + self.assertEquals(4, len(service2.user_dict[2])) + self.assertEquals("Hello, from Spring Python!", service2.user_dict[2][0]) + + self.assertTrue(isinstance(service2.user_dict[2][2], dict)) + self.assertEquals(2, len(service2.user_dict[2][2])) + self.assertEquals("This is working", service2.user_dict[2][2]["yes"]) + self.assertEquals("Maybe it's not?", service2.user_dict[2][2]["no"]) + + self.assertTrue(isinstance(service2.user_dict[2][3], list)) + self.assertEquals(2, len(service2.user_dict[2][3])) + self.assertEquals("This is a list element inside a tuple.", service2.user_dict[2][3][0]) + self.assertEquals("And so is this :)", service2.user_dict[2][3][1]) + + self.assertTrue(isinstance(service2.user_dict[3], set)) + self.assertEquals(2, len(service2.user_dict[3])) + self.assertTrue("1" in service2.user_dict[3]) + self.assertTrue("2" in service2.user_dict[3]) + self.assertTrue("3" not in service2.user_dict[3]) + + self.assertTrue(isinstance(service2.user_dict[4], frozenset)) + self.assertEquals(2, len(service2.user_dict[4])) + self.assertTrue("a" in service2.user_dict[4]) + self.assertTrue("b" in service2.user_dict[4]) + self.assertTrue("c" not in service2.user_dict[4]) + +class YamlConfigTestCase(unittest.TestCase): + def testPullingYamlConfig(self): + movieAppContainer = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + self.assertTrue(isinstance(movieAppContainer, ApplicationContext)) + self.assertFalse(movieAppContainer.object_defs[u"MovieLister"].lazy_init) + self.assertTrue(movieAppContainer.object_defs[u"MovieFinder"].lazy_init) + self.assertTrue(movieAppContainer.object_defs[u"SingletonString"].lazy_init) + lister = movieAppContainer.get_object("MovieLister") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + # Create a separate container, which has its own instances of singletons + movieAppContainer2 = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + + self.assertTrue(isinstance(movieAppContainer2, ApplicationContext)) + lister2 = movieAppContainer2.get_object("MovieLister") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + # Create another MovieLister based on the first app context + lister3 = movieAppContainer.get_object("MovieLister") + + # Identity test. Verify objects were created in separate app contexts, and that + # singletons exist only once, while prototypes are different on a per instance + # basis. + + # The MovieLister's are prototypes, and different within and between containers. + self.assertNotEquals(lister, lister2) + self.assertNotEquals(lister, lister3) + self.assertNotEquals(lister2, lister3) + + # While the strings hold the same value... + self.assertEquals(lister.description.str, lister2.description.str) + self.assertEquals(lister2.description.str, lister3.description.str) + + # ...they are not necessarily the same object + self.assertEquals(lister.description, lister3.description) + self.assertNotEquals(lister.description, lister2.description) + + # The finder is also a singleton, only varying between containers + self.assertNotEquals(lister.finder, lister2.finder) + self.assertEquals(lister.finder, lister3.finder) + + def testInnerObjects(self): + movieAppContainer = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + + lister = movieAppContainer.get_object("MovieLister2") + movieList = lister.finder.findAll() + self.assertEquals(movieList[0], "The Count of Monte Cristo") + self.assertEquals(lister.description.str, "There should only be one copy of this string") + + lister2 = movieAppContainer.get_object("MovieLister3") + movieList2 = lister2.finder.findAll() + self.assertEquals(movieList2[0], "The Count of Monte Cristo") + self.assertEquals(lister2.description.str, "There should only be one copy of this string") + + self.assertNotEqual(lister, lister2) + + def testPrefetchingObjects(self): + movieAppContainer = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + + self.assertEqual(len(movieAppContainer.object_defs), 12) + self.assertTrue("MovieLister" in movieAppContainer.object_defs) + self.assertTrue("MovieFinder" in movieAppContainer.object_defs) + self.assertTrue("SingletonString" in movieAppContainer.object_defs) + self.assertTrue("MovieLister2" in movieAppContainer.object_defs) + self.assertTrue("MovieLister3" in movieAppContainer.object_defs) + self.assertTrue("MovieLister2.finder." in movieAppContainer.object_defs) + self.assertTrue("MovieLister3.finder.named" in movieAppContainer.object_defs) + self.assertTrue("ValueHolder" in movieAppContainer.object_defs) + self.assertTrue("AnotherSingletonString" in movieAppContainer.object_defs) + self.assertTrue("AThirdSingletonString" in movieAppContainer.object_defs) + self.assertTrue("MultiValueHolder" in movieAppContainer.object_defs) + self.assertTrue("MultiValueHolder2" in movieAppContainer.object_defs) + + def testCollections(self): + ctx = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + value_holder = ctx.get_object("ValueHolder") + + self.assertTrue(isinstance(value_holder.some_dict, dict)) + self.assertEquals(4, len(value_holder.some_dict)) + + self.assertEquals("Python", value_holder.some_dict["Spring"]) + self.assertEquals("World", value_holder.some_dict["Hello"]) + self.assertTrue(isinstance(value_holder.some_dict["holder"], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_dict["holder"].str) + self.assertEquals("There should only be one copy of this string", value_holder.some_dict["another copy"].str) + + # Verify they are both referencing the same StringHolder class + self.assertEquals(value_holder.some_dict["holder"], value_holder.some_dict["another copy"]) + + self.assertTrue(isinstance(value_holder.some_list, list)) + self.assertEquals(3, len(value_holder.some_list)) + self.assertEquals("Hello, world!", value_holder.some_list[0]) + self.assertTrue(isinstance(value_holder.some_list[1], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_list[1].str) + self.assertEquals("Spring Python", value_holder.some_list[2]) + + # Verify this is also using the same singleton object + self.assertEquals(value_holder.some_dict["holder"], value_holder.some_list[1]) + + self.assertTrue(isinstance(value_holder.some_props, dict)) + self.assertEquals(3, len(value_holder.some_props)) + self.assertEquals("administrator@example.org", value_holder.some_props["administrator"]) + self.assertEquals("support@example.org", value_holder.some_props["support"]) + self.assertEquals("development@example.org", value_holder.some_props["development"]) + + self.assertTrue(isinstance(value_holder.some_set, set)) + self.assertEquals(3, len(value_holder.some_set)) + self.assertTrue("Hello, world!" in value_holder.some_set) + self.assertTrue("Spring Python" in value_holder.some_set) + + self.assertTrue(isinstance(value_holder.some_frozen_set, frozenset)) + self.assertEquals(3, len(value_holder.some_frozen_set)) + self.assertTrue("Hello, world!" in value_holder.some_frozen_set) + self.assertTrue("Spring Python" in value_holder.some_frozen_set) + + self.assertTrue(isinstance(value_holder.some_tuple, tuple)) + self.assertEquals(3, len(value_holder.some_tuple)) + self.assertEquals("Hello, world!", value_holder.some_tuple[0]) + self.assertTrue(isinstance(value_holder.some_tuple[1], testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.some_tuple[1].str) + self.assertEquals("Spring Python", value_holder.some_tuple[2]) + + foundStringHolder = False + for item in value_holder.some_set: + if isinstance(item, testSupportClasses.StringHolder): + self.assertEquals("There should only be one copy of this string", item.str) + self.assertEquals(item, value_holder.some_list[1]) + foundStringHolder = True + self.assertTrue(foundStringHolder) + + def testConstructors(self): + ctx = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + another_str = ctx.get_object("AnotherSingletonString") + a_third_str = ctx.get_object("AThirdSingletonString") + + self.assertEquals("attributed value", another_str.str) + self.assertEquals("elemental value", a_third_str.str) + + value_holder = ctx.get_object("ValueHolder") + self.assertTrue(isinstance(value_holder.string_holder, testSupportClasses.StringHolder)) + self.assertEquals("There should only be one copy of this string", value_holder.string_holder.str) + + single_str = ctx.get_object("SingletonString") + + self.assertEquals(single_str.str, value_holder.string_holder.str) + self.assertEquals(single_str, value_holder.string_holder) + + def testNamedConstructorArguments(self): + ctx = ApplicationContext(YamlConfig("support/contextSpringPythonAppContext.yaml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + m = ctx.get_object("MultiValueHolder") + self.assertEquals("alt a", m.a) + self.assertEquals("alt b", m.b) + self.assertEquals("c", m.c) + + m2 = ctx.get_object("MultiValueHolder2") + self.assertEquals("a", m2.a) + self.assertEquals("alt b", m2.b) + self.assertEquals("alt c", m2.c) + +class YamlConfigTestCase2(unittest.TestCase): + def testAnotherComplexContainer(self): + ctx = ApplicationContext(YamlConfig("support/contextComplexYamlConfig2.yaml")) + service3 = ctx.get_object("user_details_service3") + self.assertTrue(isinstance(service3.user_dict, list)) + self.assertEquals(7, len(service3.user_dict)) + + self.assertTrue(isinstance(service3.user_dict[0], list)) + self.assertEquals(2, len(service3.user_dict[0])) + + self.assertTrue(isinstance(service3.user_dict[0][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service3.user_dict[0][0].user_dict) + self.assertEquals("Test2", service3.user_dict[0][1].user_dict) + + self.assertTrue(isinstance(service3.user_dict[1], tuple)) + self.assertEquals(2, len(service3.user_dict[1])) + + self.assertTrue(isinstance(service3.user_dict[1][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service3.user_dict[1][0].user_dict) + self.assertEquals("Test2", service3.user_dict[1][1].user_dict) + + self.assertTrue(isinstance(service3.user_dict[2], InMemoryUserDetailsService)) + self.assertEquals("Test3", service3.user_dict[2].user_dict) + + self.assertTrue(isinstance(service3.user_dict[3], set)) + self.assertEquals(2, len(service3.user_dict[3])) + self.assertTrue("Test4" in [item.user_dict for item in service3.user_dict[3]]) + self.assertTrue("Test5" in [item.user_dict for item in service3.user_dict[3]]) + + self.assertTrue(isinstance(service3.user_dict[4], frozenset)) + self.assertEquals(2, len(service3.user_dict[4])) + self.assertTrue("Test6" in [item.user_dict for item in service3.user_dict[4]]) + self.assertTrue("Test7" in [item.user_dict for item in service3.user_dict[4]]) + + self.assertTrue(isinstance(service3.user_dict[5], set)) + self.assertEquals(1, len(service3.user_dict[5])) + self.assertTrue("Test8" in [item.user_dict for item in service3.user_dict[5]]) + + self.assertTrue(isinstance(service3.user_dict[6], frozenset)) + self.assertEquals(1, len(service3.user_dict[6])) + self.assertTrue("Test9" in [item.user_dict for item in service3.user_dict[6]]) + + def testNamedConstructorArguments(self): + ctx = ApplicationContext(XMLConfig("support/contextSpringPythonAppContext.xml")) + self.assertTrue(isinstance(ctx, ApplicationContext)) + + m = ctx.get_object("MultiValueHolder") + self.assertEquals("alt a", m.a) + self.assertEquals("alt b", m.b) + self.assertEquals("c", m.c) + + m2 = ctx.get_object("MultiValueHolder2") + self.assertEquals("a", m2.a) + self.assertEquals("alt b", m2.b) + self.assertEquals("alt c", m2.c) + +class XMLConfigTestCase3(unittest.TestCase): + def testAThirdComplexContainer(self): + ctx = ApplicationContext(XMLConfig("support/contextComplexXMLConfig3.xml")) + service4 = ctx.get_object("user_details_service4") + self.assertTrue(isinstance(service4.user_dict, tuple)) + self.assertEquals(7, len(service4.user_dict)) + + self.assertTrue(isinstance(service4.user_dict[0], list)) + self.assertEquals(2, len(service4.user_dict[0])) + + self.assertTrue(isinstance(service4.user_dict[0][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service4.user_dict[0][0].user_dict) + self.assertEquals("Test2", service4.user_dict[0][1].user_dict) + + self.assertTrue(isinstance(service4.user_dict[1], tuple)) + self.assertEquals(2, len(service4.user_dict[1])) + + self.assertTrue(isinstance(service4.user_dict[1][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service4.user_dict[1][0].user_dict) + self.assertEquals("Test2", service4.user_dict[1][1].user_dict) + + self.assertTrue(isinstance(service4.user_dict[2], InMemoryUserDetailsService)) + self.assertEquals("Test3", service4.user_dict[2].user_dict) + + self.assertTrue(isinstance(service4.user_dict[3], set)) + self.assertEquals(2, len(service4.user_dict[3])) + self.assertTrue("Test4" in [item.user_dict for item in service4.user_dict[3]]) + self.assertTrue("Test5" in [item.user_dict for item in service4.user_dict[3]]) + + self.assertTrue(isinstance(service4.user_dict[4], frozenset)) + self.assertEquals(2, len(service4.user_dict[4])) + self.assertTrue("Test6" in [item.user_dict for item in service4.user_dict[4]]) + self.assertTrue("Test7" in [item.user_dict for item in service4.user_dict[4]]) + + self.assertTrue(isinstance(service4.user_dict[5], set)) + self.assertEquals(1, len(service4.user_dict[5])) + self.assertTrue("Test8" in [item.user_dict for item in service4.user_dict[5]]) + + self.assertTrue(isinstance(service4.user_dict[6], frozenset)) + self.assertEquals(1, len(service4.user_dict[6])) + self.assertTrue("Test9" in [item.user_dict for item in service4.user_dict[6]]) + +class YamlConfigTestCase3(unittest.TestCase): + def testAThirdComplexContainer(self): + ctx = ApplicationContext(YamlConfig("support/contextComplexYamlConfig3.yaml")) + service4 = ctx.get_object("user_details_service4") + self.assertTrue(isinstance(service4.user_dict, tuple)) + self.assertEquals(7, len(service4.user_dict)) + + self.assertTrue(isinstance(service4.user_dict[0], list)) + self.assertEquals(2, len(service4.user_dict[0])) + + self.assertTrue(isinstance(service4.user_dict[0][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service4.user_dict[0][0].user_dict) + self.assertEquals("Test2", service4.user_dict[0][1].user_dict) + + self.assertTrue(isinstance(service4.user_dict[1], tuple)) + self.assertEquals(2, len(service4.user_dict[1])) + + self.assertTrue(isinstance(service4.user_dict[1][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service4.user_dict[1][0].user_dict) + self.assertEquals("Test2", service4.user_dict[1][1].user_dict) + + self.assertTrue(isinstance(service4.user_dict[2], InMemoryUserDetailsService)) + self.assertEquals("Test3", service4.user_dict[2].user_dict) + + self.assertTrue(isinstance(service4.user_dict[3], set)) + self.assertEquals(2, len(service4.user_dict[3])) + self.assertTrue("Test4" in [item.user_dict for item in service4.user_dict[3]]) + self.assertTrue("Test5" in [item.user_dict for item in service4.user_dict[3]]) + + self.assertTrue(isinstance(service4.user_dict[4], frozenset)) + self.assertEquals(2, len(service4.user_dict[4])) + self.assertTrue("Test6" in [item.user_dict for item in service4.user_dict[4]]) + self.assertTrue("Test7" in [item.user_dict for item in service4.user_dict[4]]) + + self.assertTrue(isinstance(service4.user_dict[5], set)) + self.assertEquals(1, len(service4.user_dict[5])) + self.assertTrue("Test8" in [item.user_dict for item in service4.user_dict[5]]) + + self.assertTrue(isinstance(service4.user_dict[6], frozenset)) + self.assertEquals(1, len(service4.user_dict[6])) + self.assertTrue("Test9" in [item.user_dict for item in service4.user_dict[6]]) + +class XMLConfigTestCase4(unittest.TestCase): + def testAThirdComplexContainer(self): + ctx = ApplicationContext(XMLConfig("support/contextComplexXMLConfig4.xml")) + service5 = ctx.get_object("user_details_service5") + self.assertTrue(isinstance(service5.user_dict, set)) + self.assertEquals(4, len(service5.user_dict)) + + for item in service5.user_dict: + if isinstance(item, tuple): + self.assertEquals(2, len(item)) + self.assertEquals("Test1", item[0].user_dict) + self.assertEquals("Test2", item[1].user_dict) + elif isinstance(item, InMemoryUserDetailsService): + self.assertEquals("Test3", item.user_dict) + elif isinstance(item, frozenset): + if len(item) == 1: + self.assertTrue("Test9" in [i.user_dict for i in item]) + elif len(item) == 2: + values = [i.user_dict for i in item] + for test_value in ["Test6", "Test7"]: + self.assertTrue(test_value in values) + else: + self.fail("Did NOT expect a frozenset of length %s" % len(item)) + else: + self.fail("Cannot handle %s" % type(item)) + +class YamlConfigTestCase4(unittest.TestCase): + def testAThirdComplexContainer(self): + import logging + logger = logging.getLogger("springpython.yamltest") + + ctx = ApplicationContext(YamlConfig("support/contextComplexYamlConfig4.yaml")) + service5 = ctx.get_object("user_details_service5") + self.assertTrue(isinstance(service5.user_dict, set)) + self.assertEquals(4, len(service5.user_dict)) + + logger.debug("About to parse dict %s" % service5.user_dict) + + for item in service5.user_dict: + logger.debug("Looking at item %s inside user_dict" % str(item)) + logger.debug("It is a %s type object." % type(item)) + if isinstance(item, tuple): + self.assertEquals(2, len(item)) + self.assertEquals("Test1", item[0].user_dict) + self.assertEquals("Test2", item[1].user_dict) + elif isinstance(item, InMemoryUserDetailsService): + self.assertEquals("Test3", item.user_dict) + elif isinstance(item, frozenset): + if len(item) == 1: + self.assertTrue("Test9" in [i.user_dict for i in item]) + elif len(item) == 2: + values = [i.user_dict for i in item] + for test_value in ["Test6", "Test7"]: + self.assertTrue(test_value in values) + else: + self.fail("Did NOT expect a frozenset of length %s" % len(item)) + else: + self.fail("Cannot handle %s" % type(item)) + +class YamlConfigTypesCustomizing(unittest.TestCase): + """ Exercises the behaviour of customizations of types. + """ + + def test_default_mapping_ok(self): + container = ApplicationContext(YamlConfig("support/contextYamlBuiltinTypes.yaml")) + + self.assertEqual(12, len(container.objects)) + + my_string = container.get_object("MyString") + my_unicode = container.get_object("MyUnicode") + my_int = container.get_object("MyInt") + my_long = container.get_object("MyLong") + my_float = container.get_object("MyFloat") + my_decimal = container.get_object("MyDecimal") + my_boolean = container.get_object("MyBoolean") + my_complex = container.get_object("MyComplex") + my_list = container.get_object("MyList") + my_tuple = container.get_object("MyTuple") + my_dict = container.get_object("MyDict") + my_ref = container.get_object("MyRef") + + self.assertEqual(my_string, "My string") + self.assertEqual(my_unicode, u'Zażółć gęślą jaźń') + self.assertEqual(my_int, 10) + self.assertEqual(my_long, 100000000000000000000000) + self.assertEqual(my_float, 3.14) + self.assertEqual(my_decimal, Decimal("12.34")) + self.assertEqual(my_boolean, False) + self.assertEqual(my_complex, complex(10,0)) + self.assertEqual(my_list, [1, 2, 3, 4]) + self.assertEqual(my_tuple, ("a", "b", "c")) + self.assertEqual(my_dict, {1: "a", 2: "b", 3: "c"}) + self.assertEqual(my_ref, Decimal("12.34")) + + def test_default_mapping_error_no_type_defined(self): + # Will raise KeyError: 'class' + try: + ApplicationContext(YamlConfig("support/contextYamlBuiltinTypesErrorNoTypeDefined.yaml")) + except KeyError, e: + # Meaning there was no 'class' key found. + self.assertEqual(e.message, "class") + else: + self.fail("KeyError should've been raised") + + def test_default_mappings_dictionary_contents(self): + self.assertEqual(yaml_mappings, {'tuple': 'types.TupleType', + 'int': 'types.IntType', 'float': 'types.FloatType', + 'unicode': 'types.UnicodeType', + 'decimal': 'decimal.Decimal', 'list': 'types.ListType', + 'long': 'types.LongType', 'complex': 'types.ComplexType', + 'bool': 'types.BooleanType', 'str': 'types.StringType', + 'dict': 'types.DictType'}) + + def test_custom_mappings(self): + yaml_mappings.update({"interest_rate": "springpythontest.support.interest_rate.InterestRate"}) + container = ApplicationContext(YamlConfig("support/contextYamlCustomMappings.yaml")) + + self.assertEqual(1, len(container.objects)) + base_interest_rate = container.get_object("base_interest_rate") + self.assertEqual("7.35", base_interest_rate.value) + + del yaml_mappings["interest_rate"] + +class XMLConfigTestCase5(unittest.TestCase): + def testAFourthComplexContainer(self): + ctx = ApplicationContext(XMLConfig("support/contextComplexXMLConfig5.xml")) + service6 = ctx.get_object("user_details_service6") + self.assertTrue(isinstance(service6.user_dict, frozenset)) + self.assertEquals(4, len(service6.user_dict)) + + for item in service6.user_dict: + if isinstance(item, tuple): + self.assertEquals(2, len(item)) + self.assertEquals("Test1", item[0].user_dict) + self.assertEquals("Test2", item[1].user_dict) + elif isinstance(item, InMemoryUserDetailsService): + self.assertEquals("Test3", item.user_dict) + elif isinstance(item, frozenset): + if len(item) == 1: + self.assertTrue("Test9" in [i.user_dict for i in item]) + elif len(item) == 2: + values = [i.user_dict for i in item] + for test_value in ["Test6", "Test7"]: + self.assertTrue(test_value in values) + else: + self.fail("Did NOT expect a frozenset of length %s" % len(item)) + else: + self.fail("Cannot handle %s" % type(item)) + +class YamlConfigTestCase5(unittest.TestCase): + def testAFourthComplexContainer(self): + ctx = ApplicationContext(YamlConfig("support/contextComplexYamlConfig5.yaml")) + service6 = ctx.get_object("user_details_service6") + self.assertTrue(isinstance(service6.user_dict, frozenset)) + self.assertEquals(4, len(service6.user_dict)) + + for item in service6.user_dict: + if isinstance(item, tuple): + self.assertEquals(2, len(item)) + self.assertEquals("Test1", item[0].user_dict) + self.assertEquals("Test2", item[1].user_dict) + elif isinstance(item, InMemoryUserDetailsService): + self.assertEquals("Test3", item.user_dict) + elif isinstance(item, frozenset): + if len(item) == 1: + self.assertTrue("Test9" in [i.user_dict for i in item]) + elif len(item) == 2: + values = [i.user_dict for i in item] + for test_value in ["Test6", "Test7"]: + self.assertTrue(test_value in values) + else: + self.fail("Did NOT expect a frozenset of length %s" % len(item)) + else: + self.fail("Cannot handle %s" % type(item)) + +class XMLConfigTestCase6(unittest.TestCase): + def testAThirdComplexContainer(self): + ctx = ApplicationContext(XMLConfig("support/contextComplexXMLConfig6.xml")) + service4 = ctx.get_object("user_details_service4") + self.assertTrue(isinstance(service4.user_dict, dict)) + self.assertEquals(8, len(service4.user_dict)) + + self.assertTrue(isinstance(service4.user_dict["list"], list)) + self.assertEquals(2, len(service4.user_dict["list"])) + + self.assertTrue(isinstance(service4.user_dict["list"][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service4.user_dict["list"][0].user_dict) + self.assertEquals("Test2", service4.user_dict["list"][1].user_dict) + + self.assertTrue(isinstance(service4.user_dict["tuple"], tuple)) + self.assertEquals(2, len(service4.user_dict["tuple"])) + + self.assertTrue(isinstance(service4.user_dict["tuple"][0], InMemoryUserDetailsService)) + self.assertEquals("Test1", service4.user_dict["tuple"][0].user_dict) + self.assertEquals("Test2", service4.user_dict["tuple"][1].user_dict) + + self.assertTrue(isinstance(service4.user_dict["inner_object"], InMemoryUserDetailsService)) + self.assertEquals("Test3", service4.user_dict["inner_object"].user_dict) + + self.assertTrue(isinstance(service4.user_dict["set1"], set)) + self.assertEquals(2, len(service4.user_dict["set1"])) + self.assertTrue("Test4" in [item.user_dict for item in service4.user_dict["set1"]]) + self.assertTrue("Test5" in [item.user_dict for item in service4.user_dict["set1"]]) + + self.assertTrue(isinstance(service4.user_dict["frozenset1"], frozenset)) + self.assertEquals(2, len(service4.user_dict["frozenset1"])) + self.assertTrue("Test6" in [item.user_dict for item in service4.user_dict["frozenset1"]]) + self.assertTrue("Test7" in [item.user_dict for item in service4.user_dict["frozenset1"]]) + + self.assertTrue(isinstance(service4.user_dict["set2"], set)) + self.assertEquals(1, len(service4.user_dict["set2"])) + self.assertTrue("Test8" in [item.user_dict for item in service4.user_dict["set2"]]) + + self.assertTrue(isinstance(service4.user_dict["frozenset2"], frozenset)) + self.assertEquals(1, len(service4.user_dict["frozenset2"])) + self.assertTrue("Test9" in [item.user_dict for item in service4.user_dict["frozenset2"]]) + + self.assertEquals("Test10", service4.user_dict["value"]) + +class XMLConfigTypesMappingsTestCase(unittest.TestCase): + """This test case exercises the types mappings for XMLConfig""" + + def test_types_mappings(self): + self.assertEqual({'complex': 'types.ComplexType', + 'bool': 'types.BooleanType', 'unicode': 'types.UnicodeType', + 'str': 'types.StringType', 'int': 'types.IntType', + 'decimal': 'decimal.Decimal', 'float': 'types.FloatType', + 'long': 'types.LongType'}, xml_mappings) + + ctx = ApplicationContext(XMLConfig("support/contextXMLConfigTypesMappings.xml")) + self.assertEqual(8, len(ctx.objects)) + + my_string = ctx.get_object("MyString") + my_unicode = ctx.get_object("MyUnicode") + my_int = ctx.get_object("MyInt") + my_long = ctx.get_object("MyLong") + my_float = ctx.get_object("MyFloat") + my_decimal = ctx.get_object("MyDecimal") + my_bool = ctx.get_object("MyBool") + my_complex = ctx.get_object("MyComplex") + + self.assertEqual(my_string, "My string") + self.assertEqual(my_unicode, u"Zażółć gęślą jaźń") + self.assertEqual(my_int, 10) + self.assertEqual(my_long, 100000000000000000000000) + self.assertEqual(my_float, 3.14) + self.assertEqual(my_decimal, Decimal("12.34")) + self.assertEqual(my_bool, False) + self.assertEqual(my_complex, 10+0j) + +class XMLConfigMixedXSDVersionsTestCase(unittest.TestCase): + """ Exercises the XMLConfig behaviour when given XML config files of + different XSD versions. + """ + def test_mixed_xsd_versions(self): + config_files = ["support/contextXMLConfigXSD10.xml", "support/contextXMLConfigXSD11.xml"] + ctx = ApplicationContext(XMLConfig(config_files)) + + self.assertEqual(2, len(ctx.objects)) + + my_string_10 = ctx.get_object("MyString10") + my_string_11 = ctx.get_object("MyString11") + + self.assertEqual(my_string_10, "My string XSD 1.0") + self.assertEqual(my_string_11, "My string XSD 1.1") + +class XMLConfigConstructorBasedTestCase(unittest.TestCase): + """This test case exercises the constructors for XMLConfig""" + + def testUsingConstructorWithObjectReference(self): + ctx = ApplicationContext(XMLConfig("support/contextXMLConfigWithConstructorArgs.xml")) + + controller = ctx.get_object("controller-list") + self.assertTrue(isinstance(controller.executors, list)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + + controller = ctx.get_object("controller-set") + self.assertTrue(isinstance(controller.executors, set)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + + controller = ctx.get_object("controller-dict") + self.assertTrue(isinstance(controller.executors, dict)) + self.assertEquals(2, len(controller.executors)) + for key in controller.executors: + self.assertTrue(isinstance(controller.executors[key], testSupportClasses.Executor)) + + controller = ctx.get_object("controller-frozenset") + self.assertTrue(isinstance(controller.executors, frozenset)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + + controller = ctx.get_object("controller-tuple") + self.assertTrue(isinstance(controller.executors, tuple)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + +class PyroFactoryTestCase(unittest.TestCase): + def testPyroFactoryDoesntGetApply(self): + class MyPythonConfig(PythonConfig): + @Object + def my_pyrofactory(config_self): # let lambda access parent test self + ppf = PyroProxyFactory() + # small hack to make the thing testable + ppf.__dict__["after_properties_set"] = lambda: self.fail( + "after_properties_set mustn't be called on " + "PyroProxyFactory objects.") + return ppf + + ctx = ApplicationContext(MyPythonConfig()) + + +class YamlConfigConstructorBasedTestCase(unittest.TestCase): + """This test case exercises the constructors for XMLConfig""" + + def testUsingConstructorWithObjectReference(self): + ctx = ApplicationContext(YamlConfig("support/contextYamlConfigWithConstructorArgs.yaml")) + + controller = ctx.get_object("controller-list") + self.assertTrue(isinstance(controller.executors, list)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + + controller = ctx.get_object("controller-set") + self.assertTrue(isinstance(controller.executors, set)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + + controller = ctx.get_object("controller-dict") + self.assertTrue(isinstance(controller.executors, dict)) + self.assertEquals(2, len(controller.executors)) + for key in controller.executors: + self.assertTrue(isinstance(controller.executors[key], testSupportClasses.Executor)) + + controller = ctx.get_object("controller-frozenset") + self.assertTrue(isinstance(controller.executors, frozenset)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + + controller = ctx.get_object("controller-tuple") + self.assertTrue(isinstance(controller.executors, tuple)) + self.assertEquals(2, len(controller.executors)) + for executor in controller.executors: + self.assertTrue(isinstance(executor, testSupportClasses.Executor)) + +class ObjectPostProcessorsTestCase(unittest.TestCase): + """This test case exercises object post processors""" + + def testSimpleObjectPostProcessorXml(self): + ctx = ApplicationContext(XMLConfig("support/contextObjectPostProcessing.xml")) + processor = ctx.get_object("postProcessor") + self.assertTrue(isinstance(processor, ObjectPostProcessor)) + self.assertFalse(hasattr(processor, "processedBefore")) + self.assertFalse(hasattr(processor, "processedAfter")) + obj = ctx.get_object("value") + self.assertTrue(hasattr(obj, "processedBefore")) + self.assertTrue(hasattr(obj, "processedAfter")) + + def testSimpleObjectPostProcessorYaml(self): + ctx = ApplicationContext(YamlConfig("support/contextObjectPostProcessing.yaml")) + processor = ctx.get_object("postProcessor") + self.assertTrue(isinstance(processor, ObjectPostProcessor)) + self.assertFalse(hasattr(processor, "processedBefore")) + self.assertFalse(hasattr(processor, "processedAfter")) + obj = ctx.get_object("value") + self.assertTrue(hasattr(obj, "processedBefore")) + self.assertTrue(hasattr(obj, "processedAfter")) + +class DisposableObjectTestCase(MockTestCase): + """This test case exercises the DisposableObject behaviour.""" + + def _get_sample_config(self, disposable_object): + + class SampleConfig(PythonConfig): + def __init__(self): + super(SampleConfig, self).__init__() + + @Object + def my_disposable_object(self): + return disposable_object + + return SampleConfig() + + + def testDefaultDestroyMethod(self): + + class DisposableObjectWithDefaultDestroyMethod(Mock, DisposableObject): + """ A DisposableObject with a default destroy method. Note the + AttributeError in __getattribute__, it's needed because pmock would + otherwise happily return a mock 'destroy_method' regardless of + whether one had been actually defined. + """ + + def destroy(self): + self.destroy_called = True + + def __getattr__(self, attr_name): + return object.__getattribute__(self, attr_name) + + def __getattribute__(self, attr_name): + + if attr_name == "destroy_method": + raise AttributeError() + + return object.__getattribute__(self, attr_name) + + disposable_object = DisposableObjectWithDefaultDestroyMethod() + + disposable_object.stubs().after_properties_set() + disposable_object.stubs().method("set_app_context") + + ctx = ApplicationContext(self._get_sample_config(disposable_object)) + my_disposable_object = ctx.get_object("my_disposable_object") + + ctx.shutdown_hook() + + # Will raise AttributeError if 'destroy' hasn't been called. + self.assertTrue(my_disposable_object.destroy_called) + + + def testCustomDestroyMethod(self): + + class DisposableObjectWithCustomDestroyMethod(Mock, DisposableObject): + """ A DisposableObject with a custom destroy method, its name is + returned by __getattribute__, again, to prevent pmock from + returning a mock object. + """ + + def custom_destroy(self): + self.custom_destroy_called = True + + def __getattr__(self, attr_name): + return object.__getattribute__(self, attr_name) + + def __getattribute__(self, attr_name): + + if attr_name == "destroy_method": + return "custom_destroy" + + return object.__getattribute__(self, attr_name) + + disposable_object = DisposableObjectWithCustomDestroyMethod() + + disposable_object.stubs().after_properties_set() + disposable_object.stubs().method("set_app_context") + + ctx = ApplicationContext(self._get_sample_config(disposable_object)) + my_disposable_object = ctx.get_object("my_disposable_object") + + ctx.shutdown_hook() + + # Will raise AttributeError if 'custom_destroy' hasn't been called. + self.assertTrue(my_disposable_object.custom_destroy_called) + + def testShutdownHookRegisterdWithAtExit(self): + + class Dummy(DisposableObject): + def destroy(self): + pass + + ctx = ApplicationContext(self._get_sample_config(Dummy())) + + seen_shutdown_hook = False + + # We need to iterate through all registered atexit handlers, our handler + # will be will among the other handlers registered in previous tests. + # Note: we're using a private atexit API here. + for handler_info in atexit._exithandlers: + func = handler_info[0] + if func == ctx.shutdown_hook: + seen_shutdown_hook = True + + self.assertTrue(seen_shutdown_hook) + +class AppContextObjectsObjectsDefsTestCase(MockTestCase): + """This test case exercises the application contexts' .objects and + .object_defs behaviour.""" + + def _get_querying_context(self): + + class MyClass(object): + pass + + class MySubclass(MyClass): + pass + + class SampleContext(PythonConfig): + def __init__(self): + super(SampleContext, self).__init__() + + @Object + def http_port(self): + return 18000 + + @Object + def https_port(self): + return self._get_https_port() + + def _get_https_port(self): + return self.http_port() + 443 + + @Object + def my_class_object1(self): + return MyClass() + + @Object + def my_class_object2(self): + return MyClass() + + @Object + def my_subclass_object1(self): + return MySubclass() + + @Object + def my_subclass_object2(self): + return MySubclass() + + @Object + def my_subclass_object3(self): + return MySubclass() + + return ApplicationContext(SampleContext()), MyClass, MySubclass + + def _get_modifying_context(self): + + class SampleContext2(PythonConfig): + def __init__(self): + super(SampleContext2, self).__init__() + + return ApplicationContext(SampleContext2()) + + def testQuerying(self): + ctx, MyClass, MySubclass, = self._get_querying_context() + + class_instances = ctx.get_objects_by_type(MyClass) + subclass_instances = ctx.get_objects_by_type(MyClass, False) + int_instances = ctx.get_objects_by_type(int) + + self.assertTrue(isinstance(class_instances, dict)) + self.assertTrue(isinstance(subclass_instances, dict)) + self.assertTrue(isinstance(int_instances, dict)) + + self.assertEquals(5, len(class_instances)) + self.assertEquals(3, len(subclass_instances)) + self.assertEquals(2, len(int_instances)) + + for name, instance in class_instances.items(): + self.assertTrue(isinstance(instance, MyClass)) + + for name, instance in subclass_instances.items(): + self.assertTrue(isinstance(instance, MyClass) and type(instance) is not MyClass) + + for name, instance in int_instances.items(): + self.assertTrue(isinstance(instance, int)) + + self.assertTrue("http_port" in ctx.objects) + self.assertTrue("http_port" in ctx.objects) + self.assertFalse("ftp_port" in ctx.object_defs) + self.assertEqual(7, len(ctx.objects)) + + for name in ctx.objects: + self.assertTrue(isinstance(name, basestring)) + + for name in ctx.object_defs: + self.assertTrue(isinstance(name, basestring)) + + def testModifying(self): + ctx = self._get_modifying_context() + + class Foo(object): + pass + + class Bar(object): + pass + + @Object(PROTOTYPE) + def foo(): + """ Returns a new instance of Foo on each call. + """ + return Foo() + + @Object # SINGLETON is the default. + def bar(): + """ Returns a singleton Bar every time accessed. + """ + return Bar() + + # A reference to the function wrapping the actual 'foo' function. + foo_wrapper = foo.func_globals["_call_"] + + # Create an object definition, note that we're telling to return + foo_object_def = ObjectDef(id="foo", + factory=PythonObjectFactory(foo, foo_wrapper), scope=PROTOTYPE, + lazy_init=foo_wrapper.lazy_init) + + # A reference to the function wrapping the actual 'bar' function. + bar_wrapper = foo.func_globals["_call_"] + + bar_object_def = ObjectDef(id="foo", + factory=PythonObjectFactory(bar, bar_wrapper), scope=SINGLETON, + lazy_init=bar_wrapper.lazy_init) + + # No definitions at this point + self.assertEqual({}, ctx.object_defs) + + ctx.object_defs["foo"] = foo_object_def + ctx.object_defs["bar"] = bar_object_def + + # Two object defs have just been added. + self.assertEqual(2, len(ctx.object_defs)) + + for x in range(3): + foo_instance = ctx.get_object("foo") + self.assertTrue(isinstance(foo_instance, Foo)) + + # Will leak the 'bar_instance' for later use. + for x in range(3): + bar_instance = ctx.get_object("bar") + self.assertTrue(isinstance(bar_instance, Bar)) + + # 'foo' object is a PROTOTYPE and 'bar' is a SINGLETON so there must've + # been exactly one object created so far. + self.assertEqual(1, len(ctx.objects)) + + obj = ctx.objects[ctx.objects.keys()[0]] + self.assertTrue(obj is bar_instance) + +class AbstractObjectsTestCase(MockTestCase): + """Test cases related to handling of abstract container managed objects. + """ + + def _get_python_config(self): + + class Request(object): + def __init__(self, nounce=None, user=None, password=None): + self.nounce = nounce + self.user = user + self.password = password + + def __str__(self): + return "" % (hex(id(self)), self.nounce, self.user, self.password) + + class CRMService(object): + def __init__(self, ip=None, port=None, path=None): + self.ip = ip + self.port = port + + def invoke(self, request): + return "CRM OK %s" % request.nounce + + def __str__(self): + return "" % (hex(id(self)), self.ip, self.port) + + class IVRService(object): + def __init__(self, instance=None): + self.instance = instance + + def invoke(self, request): + return "IVR OK %s" % request.nounce + + def __str__(self): + return "" % (hex(id(self)), self.instance) + + class TestAbstractContext(PythonConfig): + + @Object(PROTOTYPE, lazy_init=True, abstract=True) + def request(self): + request = Request() + request.nounce = "".join([random.choice("1234567890") for x in range(16)]) + + return request + + @Object(PROTOTYPE, parent="request") + def crm_request(self, request=None): + request.user = "foo" + request.password = "bar" + + return request + + @Object(PROTOTYPE, parent="request") + def ivr_request(self, request=None): + request.user = "baz" + request.password = "frobble" + + return request + + @Object(abstract=True) + def crm_service(self): + service = CRMService() + service.ip = "192.168.1.145" + service.port = 2627 + + return service + + @Object(parent="crm_service") + def get_customer_id(self, service=None): + request = self.get_object("crm_request") + + return service.invoke(request) + + @Object(PROTOTYPE, parent="crm_service") + def get_customer_profile(self, service=None): + request = self.get_object("crm_request") + + return service.invoke(request) + + @Object(PROTOTYPE, abstract=True) + def ivr_service(self): + service = CRMService() + service.ip = "192.168.1.145" + service.port = 2627 + + return service + + @Object(PROTOTYPE, parent="ivr_service") + def get_customer_location(self, service=None): + request = self.get_object("ivr_request") + + return service.invoke(request) + + @Object(parent="ivr_service") + def get_customer_complaints(self, service=None): + request = self.get_object("ivr_request") + + return service.invoke(request) + + return TestAbstractContext + + def testPythonConfigAbstractObjects(self): + ctx_class = self._get_python_config() + container = ApplicationContext(ctx_class()) + + # Use a variety of scopes to ensure the proper handling of abstract + # objects doesn't depend on their scopes. + + # request => PROTOTYPE + # crm_request => PROTOTYPE + # ivr_request => PROTOTYPE + + # crm_service => SINGLETON + # get_customer_id => SINGLETON + # get_customer_profile => PROTOTYPE + + # ivr_service => PROTOTYPE + # get_customer_location => PROTOTYPE + # get_customer_complaints => SINGLETON + + # get_object's 'ignore_abstract' is False by default. + self.assertRaises(AbstractObjectException, container.get_object, "request") + self.assertRaises(AbstractObjectException, container.get_object, "crm_service") + self.assertRaises(AbstractObjectException, container.get_object, "ivr_service") + + # Won't raise AbstractObjectException because the 'ignore_abstract' flag is True. + request = container.get_object("request", True) + crm_service = container.get_object("crm_service", True) + ivr_service = container.get_object("ivr_service", True) + + self.assertEquals(16, len(request.nounce)) + self.assertTrue(str.isdigit(request.nounce)) + self.assertEquals(None, request.user) + self.assertEquals(None, request.password) + + crm_request = container.get_object("crm_request") + ivr_request = container.get_object("ivr_request") + + self.assertEquals(16, len(crm_request.nounce)) + self.assertTrue(str.isdigit(crm_request.nounce)) + self.assertEquals("foo", crm_request.user) + self.assertEquals("bar", crm_request.password) + + self.assertEquals(16, len(ivr_request.nounce)) + self.assertTrue(str.isdigit(ivr_request.nounce)) + self.assertEquals("baz", ivr_request.user) + self.assertEquals("frobble", ivr_request.password) + + self.assertNotEquals(crm_request.nounce, ivr_request.nounce) + + # Abstract objects may be lazily-initialized or not, and that shouldn't + # get in the way of how they're handled, AbstractObjectException shouldn't + # be raised in either case. + + get_customer_id = container.get_object("get_customer_id") + get_customer_profile = container.get_object("get_customer_profile") + + get_customer_location = container.get_object("get_customer_location") + get_customer_complaints = container.get_object("get_customer_complaints") + + def testXMLAndYamlConfigAbstractObjects(self): + + # + # There are various combinations and corner cases that need be tested + # here, depending on whether an application context uses properties only, + # properties and constructor arguments or constructor args solely. What + # also needs be taken into account is if there are any abstract objects + # and if so, how many levels of inheritance are there. + # + # 1 - uses properties only + # + # 2 - uses properties and named constructor arguments + # + # 3 - uses properties, named and positional arguments + # + # 4 - used for testing of how positional arguments are being handled + # (doesn't use properties nor named arguments) + # + # Note that some assertions are identical for both XML and Yaml config + # modes. From the user's standpoint, the only difference is that XMLConfig + # allows for defining both positional and named arguments whereas with + # Yaml config one needs to choose either positional or named constructor + # parameters. + # + + # + # Properties only + # + + xml_ctx1 = ApplicationContext(XMLConfig("support/contextXMLConfigAbstract1.xml")) + yaml_ctx1 = ApplicationContext(YamlConfig("support/contextYamlAbstract1.yaml")) + + for ctx in(xml_ctx1, yaml_ctx1): + + # There should be only two objects defined, the abstract one which + # is also lazily-initialized shouldn't have been added + # to the container. + self.assertEquals(2, len(ctx.objects)) + + # All object definitions, no matter abstract or concrete ones, + # should have been added though. + self.assertEquals(4, len(ctx.object_defs)) + self.assertEquals(["crm_service", "get_customer_id1", "get_customer_id2", "service"], + sorted(ctx.object_defs.keys())) + + get_customer_id1 = ctx.get_object("get_customer_id1") + self.assertEquals("192.168.1.153", get_customer_id1.ip) + self.assertEquals("3392", get_customer_id1.port) + self.assertEquals("/soap/invoke/get-customer-id1", get_customer_id1.path) + + get_customer_id2 = ctx.get_object("get_customer_id2") + self.assertEquals("192.168.1.153", get_customer_id2.ip) + self.assertEquals("3392", get_customer_id2.port) + self.assertEquals("/soap/invoke/get-customer-id2", get_customer_id2.path) + + get_customer_id1_def = ctx.object_defs["get_customer_id1"] + self.assertEquals(False, get_customer_id1_def.abstract) + self.assertEquals(SINGLETON, get_customer_id1_def.scope) + self.assertEquals("crm_service", get_customer_id1_def.parent) + + get_customer_id2_def = ctx.object_defs["get_customer_id2"] + self.assertEquals(False, get_customer_id2_def.abstract) + self.assertEquals(PROTOTYPE, get_customer_id2_def.scope) + self.assertEquals("crm_service", get_customer_id2_def.parent) + + # Abstract objects must not be added to the container. + self.assertRaises(KeyError, ctx.get_object, "foo_root1") + self.assertRaises(KeyError, ctx.get_object, "foo_root2") + + # + # Properties and named constructor arguments + # + + xml_ctx2 = ApplicationContext(XMLConfig("support/contextXMLConfigAbstract2.xml")) + yaml_ctx2 = ApplicationContext(YamlConfig("support/contextYamlAbstract2.yaml")) + + for ctx in(xml_ctx2, yaml_ctx2): + + foo_child1 = ctx.get_object("foo_child1") + self.assertEquals("aaa", foo_child1.a) + self.assertEquals("bbb", foo_child1.b) + self.assertEquals(None, foo_child1.c) + self.assertEquals(None, foo_child1.d) + self.assertEquals(None, foo_child1.e) + self.assertEquals(None, foo_child1.f) + self.assertEquals(None, foo_child1.g) + + foo_child2 = ctx.get_object("foo_child2") + self.assertEquals("aaa", foo_child2.a) + self.assertEquals("bbb", foo_child2.b) + self.assertEquals("ccc", foo_child2.c) + self.assertEquals(None, foo_child2.d) + self.assertEquals(None, foo_child2.e) + self.assertEquals(None, foo_child2.f) + self.assertEquals(None, foo_child2.g) + + foo_child3 = ctx.get_object("foo_child3") + self.assertEquals("aaa", foo_child3.a) + self.assertEquals("bbbb", foo_child3.b) + self.assertEquals("cccc", foo_child3.c) + self.assertEquals("dddd", foo_child3.d) + self.assertEquals("eeee", foo_child3.e) + self.assertEquals(None, foo_child3.f) + self.assertEquals(None, foo_child3.g) + + foo_child4 = ctx.get_object("foo_child4") + self.assertEquals("aaa", foo_child4.a) + self.assertEquals("bbbb", foo_child4.b) + self.assertEquals("cccc", foo_child4.c) + self.assertEquals("dddd", foo_child4.d) + self.assertEquals(None, foo_child4.e) + self.assertEquals("ffff", foo_child4.f) + self.assertEquals("MyString", foo_child4.g) + + # + # Properties, named and positional arguments + # + + xml_ctx3 = ApplicationContext(XMLConfig("support/contextXMLConfigAbstract3.xml")) + + foo_root3 = xml_ctx3.get_object("foo_root3", True) + self.assertEquals("first_pos_arg_in_foo_root3", foo_root3.a) + self.assertEquals(None, foo_root3.b) + self.assertEquals("cccccc_foo_root3", foo_root3.c) + self.assertEquals("dddddd", foo_root3.d) + self.assertEquals(None, foo_root3.e) + self.assertEquals(None, foo_root3.f) + self.assertEquals(None, foo_root3.g) + + foo_root4 = xml_ctx3.get_object("foo_root4", True) + self.assertEquals("MyString", foo_root4.a) + self.assertEquals(None, foo_root4.b) + self.assertEquals("cccccc_foo_root4", foo_root4.c) + self.assertEquals("dddddd", foo_root4.d) + self.assertEquals(None, foo_root4.e) + self.assertEquals("ffffff", foo_root4.f) + self.assertEquals("MyString", foo_root4.g) + + foo_child5 = xml_ctx3.get_object("foo_child5") + self.assertEquals("MyString", foo_child5.a) + self.assertEquals(None, foo_child5.b) + self.assertEquals("cccccc_foo_child5", foo_child5.c) + self.assertEquals("dddddd_foo_child5", foo_child5.d) + self.assertEquals(None, foo_child5.e) + self.assertEquals("ffffff", foo_child5.f) + self.assertEquals("gggggg_foo_child5", foo_child5.g) + + # + # Properties, named and positional arguments + # + + yaml_ctx3 = ApplicationContext(YamlConfig("support/contextYamlAbstract3.yaml")) + + foo_root_yaml3 = yaml_ctx3.get_object("foo_root_yaml3", True) + self.assertEquals("aaaaaa", foo_root_yaml3.a) + self.assertEquals("MyString", foo_root_yaml3.b) + self.assertEquals(None, foo_root_yaml3.c) + self.assertEquals("dddddd", foo_root_yaml3.d) + self.assertEquals(None, foo_root_yaml3.e) + self.assertEquals(None, foo_root_yaml3.f) + self.assertEquals(None, foo_root_yaml3.g) + + foo_root_yaml4 = yaml_ctx3.get_object("foo_root_yaml4", True) + self.assertEquals("aaaaaa_foo_root_yaml4", foo_root_yaml4.a) + self.assertEquals("bbbbbb", foo_root_yaml4.b) + self.assertEquals("MyString", foo_root_yaml4.c) + self.assertEquals("dddddd", foo_root_yaml4.d) + self.assertEquals("eeeeee_foo_root_yaml4", foo_root_yaml4.e) + self.assertEquals(None, foo_root_yaml4.f) + self.assertEquals(None, foo_root_yaml4.g) + + # + # Positional arguments only + # + + xml_ctx4 = ApplicationContext(XMLConfig("support/contextXMLConfigAbstract4.xml")) + yaml_ctx4 = ApplicationContext(YamlConfig("support/contextYamlAbstract4.yaml")) + + for ctx in(xml_ctx4, yaml_ctx4): + + foo_root_pos1 = ctx.get_object("foo_root_pos1", True) + self.assertEquals("a_foo_root_pos1", foo_root_pos1.a) + self.assertEquals("b_foo_root_pos1", foo_root_pos1.b) + self.assertEquals("MyString", foo_root_pos1.c) + self.assertEquals("d_foo_root_pos1", foo_root_pos1.d) + self.assertEquals("e_foo_root_pos1", foo_root_pos1.e) + self.assertEquals("f_foo_root_pos1", foo_root_pos1.f) + self.assertEquals("g_foo_root_pos1", foo_root_pos1.g) + + foo_parent_pos2 = ctx.get_object("foo_parent_pos2", True) + self.assertEquals("a_foo_parent_pos2", foo_parent_pos2.a) + self.assertEquals("b_foo_parent_pos2", foo_parent_pos2.b) + self.assertEquals("c_foo_parent_pos2", foo_parent_pos2.c) + self.assertEquals("d_foo_parent_pos2", foo_parent_pos2.d) + self.assertEquals("e_foo_parent_pos2", foo_parent_pos2.e) + self.assertEquals("f_foo_root_pos1", foo_parent_pos2.f) + self.assertEquals("g_foo_root_pos1", foo_parent_pos2.g) + + foo_parent_pos3 = ctx.get_object("foo_parent_pos3", True) + self.assertEquals("a_foo_parent_pos3", foo_parent_pos3.a) + self.assertEquals("MyString", foo_parent_pos3.b) + self.assertEquals("c_foo_parent_pos2", foo_parent_pos3.c) + self.assertEquals("d_foo_parent_pos2", foo_parent_pos3.d) + self.assertEquals("e_foo_parent_pos2", foo_parent_pos3.e) + self.assertEquals("f_foo_root_pos1", foo_parent_pos3.f) + self.assertEquals("g_foo_root_pos1", foo_parent_pos3.g) + + foo_parent_pos4 = ctx.get_object("foo_parent_pos4", True) + self.assertEquals("a_foo_parent_pos4", foo_parent_pos4.a) + self.assertEquals("b_foo_parent_pos4", foo_parent_pos4.b) + self.assertEquals("c_foo_parent_pos4", foo_parent_pos4.c) + self.assertEquals("d_foo_parent_pos2", foo_parent_pos4.d) + self.assertEquals("e_foo_parent_pos2", foo_parent_pos4.e) + self.assertEquals("f_foo_root_pos1", foo_parent_pos4.f) + self.assertEquals("g_foo_root_pos1", foo_parent_pos4.g) + + foo_child_pos5 = ctx.get_object("foo_child_pos5", True) + self.assertEquals("a_foo_child_pos5", foo_child_pos5.a) + self.assertEquals("b_foo_child_pos5", foo_child_pos5.b) + self.assertEquals("c_foo_parent_pos4", foo_child_pos5.c) + self.assertEquals("d_foo_parent_pos2", foo_child_pos5.d) + self.assertEquals("e_foo_parent_pos2", foo_child_pos5.e) + self.assertEquals("f_foo_root_pos1", foo_child_pos5.f) + self.assertEquals("g_foo_root_pos1", foo_child_pos5.g) + + +class ScopesTestCase(MockTestCase): + """Test cases related to proper handling of scopes of objects. + """ + + def test_scope(self): + + class TestContext(PythonConfig): + + @Object(PROTOTYPE) + def prototype(self): + pass + + @Object(SINGLETON) + def singleton(self): + pass + + invalid = """ +class InvalidScopeContainingContext(PythonConfig): + @Object("FOOBAR") + def invalid(self): + pass""" + + # If we pass this line then only correct scopes will have been used + # in TestContext. + container = ApplicationContext(TestContext()) + + for object_def in container.object_defs: + if object_def == "singleton": + self.assertEquals(SINGLETON, container.object_defs[object_def].scope) + elif object_def == "prototype": + self.assertEquals(PROTOTYPE, container.object_defs[object_def].scope) + else: + self.fail("Unexpected object_def [%s]" % object_def) + + _globals, _locals = {}, {} + + _globals["PythonConfig"] = PythonConfig + _globals["Object"] = Object + + def should_raise_invalid_object_scope(): + exec invalid in _globals, _locals + + self.assertRaises(InvalidObjectScope, should_raise_invalid_object_scope) diff --git a/test/springpythontest/databaseCoreTestCases.py b/test/springpythontest/databaseCoreTestCases.py index 59ddbf1..20caac6 100644 --- a/test/springpythontest/databaseCoreTestCases.py +++ b/test/springpythontest/databaseCoreTestCases.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import logging import os import sys @@ -288,7 +289,7 @@ def testProgrammaticStaticQueryForLong(self): self.mock.expects(once()).method("execute").id("#1") self.mock.expects(once()).method("fetchall").will(return_value([(4,)])).id("#2").after("#1") - count = self.databaseTemplate.query_for_object("select count(*) from animal", required_type=types.IntType) + count = self.databaseTemplate.query_for_object("select count(*) from animal", required_type=int) self.assertEquals(count, 4) def testProgrammaticQueryForLongWithBoundVariables(self): @@ -297,19 +298,19 @@ def testProgrammaticQueryForLongWithBoundVariables(self): self.mock.expects(once()).method("execute").id("#3").after("#2") self.mock.expects(once()).method("fetchall").will(return_value([(1,)])).id("#4").after("#3") - count = self.databaseTemplate.query_for_object("select count(*) from animal where name = %s", ("snake",), types.IntType) + count = self.databaseTemplate.query_for_object("select count(*) from animal where name = %s", ("snake",), int) self.assertEquals(count, 1) - count = self.databaseTemplate.query_for_object("select count(*) from animal where name = ?", ("snake",), types.IntType) + count = self.databaseTemplate.query_for_object("select count(*) from animal where name = ?", ("snake",), int) self.assertEquals(count, 1) def testProgrammaticStaticQueryForObject(self): - self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query_for_object, "select name from animal where category = 'reptile'", types.StringType) + self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query_for_object, "select name from animal where category = 'reptile'", bytes) self.mock.expects(once()).method("execute").id("#1") self.mock.expects(once()).method("fetchall").will(return_value([("snake",)])).id("#2").after("#1") - name = self.databaseTemplate.query_for_object("select name from animal where category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where category = 'reptile'", required_type=bytes) self.assertEquals(name, "snake") def testProgrammaticQueryForObjectWithBoundVariables(self): @@ -318,10 +319,10 @@ def testProgrammaticQueryForObjectWithBoundVariables(self): self.mock.expects(once()).method("execute").id("#3").after("#2") self.mock.expects(once()).method("fetchall").will(return_value([("snake",)])).id("#4").after("#3") - name = self.databaseTemplate.query_for_object("select name from animal where category = %s", ("reptile",), types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where category = %s", ("reptile",), bytes) self.assertEquals(name, "snake") - name = self.databaseTemplate.query_for_object("select name from animal where category = ?", ("reptile",), types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where category = ?", ("reptile",), bytes) self.assertEquals(name, "snake") def testProgrammaticStaticUpdate(self): @@ -333,7 +334,7 @@ def testProgrammaticStaticUpdate(self): rows = self.databaseTemplate.update("UPDATE animal SET name = 'python' WHERE name = 'snake'") self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=bytes) self.assertEquals(name, "python") def testProgrammaticUpdateWithBoundVariables(self): @@ -348,13 +349,13 @@ def testProgrammaticUpdateWithBoundVariables(self): rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = ?", ("python", "reptile")) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=bytes) self.assertEquals(name, "python") rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = %s", ("coily", "reptile")) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=bytes) self.assertEquals(name, "coily") def testProgrammaticStaticInsert(self): @@ -366,7 +367,7 @@ def testProgrammaticStaticInsert(self): rows = self.databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") def testProgrammaticStaticInsertWithInsertApi(self): @@ -378,7 +379,7 @@ def testProgrammaticStaticInsertWithInsertApi(self): id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") self.assertEquals(id, 42) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") def testProgrammaticInsertWithBoundVariables(self): @@ -393,13 +394,13 @@ def testProgrammaticInsertWithBoundVariables(self): rows = self.databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=bytes) self.assertEquals(name, "cottonmouth") def testProgrammaticInsertWithBoundVariablesWithInsertApi(self): @@ -414,13 +415,13 @@ def testProgrammaticInsertWithBoundVariablesWithInsertApi(self): id = self.databaseTemplate.insert_and_return_id ("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) self.assertEquals(id, 42) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) self.assertEquals(id, 42) - name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=bytes) self.assertEquals(name, "cottonmouth") class AbstractDatabaseTemplateTestCase(unittest.TestCase): @@ -552,36 +553,36 @@ def testProgrammaticQueryForLongWithBoundVariables(self): self.assertEquals(count, 1) def testProgrammaticStaticQueryForObject(self): - self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query_for_object, "select name from animal where category = 'reptile'", types.StringType) + self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query_for_object, "select name from animal where category = 'reptile'", bytes) - name = self.databaseTemplate.query_for_object("select name from animal where category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where category = 'reptile'", required_type=bytes) self.assertEquals(name, "snake") def testProgrammaticQueryForObjectWithBoundVariables(self): - name = self.databaseTemplate.query_for_object("select name from animal where category = %s", ("reptile",), types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where category = %s", ("reptile",), bytes) self.assertEquals(name, "snake") - name = self.databaseTemplate.query_for_object("select name from animal where category = ?", ("reptile",), types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where category = ?", ("reptile",), bytes) self.assertEquals(name, "snake") def testProgrammaticStaticUpdate(self): rows = self.databaseTemplate.update("UPDATE animal SET name = 'python' WHERE name = 'snake'") self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=bytes) self.assertEquals(name, "python") def testProgrammaticUpdateWithBoundVariables(self): rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = ?", ("python", "reptile")) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=bytes) self.assertEquals(name, "python") rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = %s", ("coily", "reptile")) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=bytes) self.assertEquals(name, "coily") def testProgrammaticStaticInsert(self): @@ -589,7 +590,7 @@ def testProgrammaticStaticInsert(self): rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") def testProgrammaticStaticInsertWithInsertApi(self): @@ -597,7 +598,7 @@ def testProgrammaticStaticInsertWithInsertApi(self): id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") self.assertEquals(id, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") def testProgrammaticInsertWithBoundVariables(self): @@ -605,13 +606,13 @@ def testProgrammaticInsertWithBoundVariables(self): rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) self.assertEquals(rows, 1) - name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=bytes) self.assertEquals(name, "cottonmouth") def testProgrammaticInsertWithBoundVariablesWithInsertApi(self): @@ -619,13 +620,13 @@ def testProgrammaticInsertWithBoundVariablesWithInsertApi(self): id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) self.assertEquals(id, 1) - name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=bytes) self.assertEquals(name, "black mamba") id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) self.assertEquals(id, 2) - name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=bytes) self.assertEquals(name, "cottonmouth") class MySQLDatabaseTemplateTestCase(AbstractDatabaseTemplateTestCase): @@ -648,7 +649,7 @@ def createTables(self): """) self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run MySQLDatabaseTemplateTestCase !!! @@ -707,7 +708,7 @@ def createTables(self): """) self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run PostGreSQLDatabaseTemplateTestCase !!! @@ -764,7 +765,7 @@ def createTables(self): """) self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run SqliteDatabaseTemplateTestCase !!! """) @@ -879,7 +880,7 @@ def createTables(self): self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run SQLServerDatabaseTemplateTestCase !!! diff --git a/test/springpythontest/databaseCoreTestCases.py.bak b/test/springpythontest/databaseCoreTestCases.py.bak new file mode 100644 index 0000000..59ddbf1 --- /dev/null +++ b/test/springpythontest/databaseCoreTestCases.py.bak @@ -0,0 +1,920 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import os +import sys +import types +import unittest +from pmock import * +from springpython.config import XMLConfig +from springpython.context import ApplicationContext +from springpython.database import ArgumentMustBeNamed +from springpython.database import DataAccessException +from springpython.database import InvalidArgumentType +from springpython.database.core import DatabaseTemplate +from springpython.database.core import DictionaryRowMapper +from springpython.database.core import SimpleRowMapper +from springpython.database import factory +from springpythontest.support import testSupportClasses + +logger = logging.getLogger("springpythontest.databaseCoreTestCases") + +class ConnectionFactoryTestCase(MockTestCase): + """Testing the connection factories requires mocking the libraries they are meant to utilize.""" + + def testConnectingToMySql(self): + sys.modules["MySQLdb"] = self.mock() + sys.modules["MySQLdb"].expects(once()).method("connect") + + connection_factory = factory.MySQLConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + connection = connection_factory.connect() + + del(sys.modules["MySQLdb"]) + + def testConnectingToPostgresQL(self): + sys.modules["pgdb"] = self.mock() + sys.modules["pgdb"].expects(once()).method("connect") + + connection_factory = factory.PgdbConnectionFactory(user="foo", password="bar", host="localhost", database="mock") + connection = connection_factory.connect() + + del(sys.modules["pgdb"]) + + def testConnectingToSqlite(self): + sys.modules["sqlite3"] = self.mock() + sys.modules["sqlite3"].expects(once()).method("connect") + + connection_factory = factory.Sqlite3ConnectionFactory(db="/tmp/foobar") + connection = connection_factory.connect() + + del(sys.modules["sqlite3"]) + + def testConnectingToSqliteWithSpecialCheck(self): + sys.modules["sqlite3"] = self.mock() + sys.modules["sqlite3"].expects(once()).method("connect") + + connection_factory = factory.Sqlite3ConnectionFactory(db="/tmp/foobar", check_same_thread=False) + connection = connection_factory.connect() + + del(sys.modules["sqlite3"]) + + def testConnectingToOracle(self): + sys.modules["cx_Oracle"] = self.mock() + sys.modules["cx_Oracle"].expects(once()).method("connect") + + connection_factory = factory.cxoraConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + connection = connection_factory.connect() + + del(sys.modules["cx_Oracle"]) + + def testQueryingOracleWithInvalidlyFormattedArguments(self): + sys.modules["cx_Oracle"] = self.mock() + + connection_factory = factory.cxoraConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + dt = DatabaseTemplate(connection_factory) + + self.assertRaises(InvalidArgumentType, dt.query, """ + SELECT + impcarrcfg.paystat_work_dir, + impcarrcfg.paystat_reload_dir, + impcarrcfg.paystat_archive_dir, + impcarrcfg.oid + FROM impcarrcfg, carr, lklabelsys + WHERE (lklabelsys.oid = impcarrcfg.lklabelsys_oid) + and (carr.oid = impcarrcfg.carr_oid ) + and (carr.oid = ? and lklabelsys.oid = ?) + """, (5, 5), testSupportClasses.ImpFilePropsRowMapper()) + + del(sys.modules["cx_Oracle"]) + + def testQueryingOracleWithValidlyFormattedArguments(self): + cursor = self.mock() + cursor.expects(once()).method("execute") + cursor.expects(once()).method("fetchall").will(return_value([("workDir", "reloadDir", "archiveDir", "oid1")])) + + conn = self.mock() + conn.expects(once()).method("cursor").will(return_value(cursor)) + conn.expects(once()).method("close") + + sys.modules["cx_Oracle"] = self.mock() + sys.modules["cx_Oracle"].expects(once()).method("connect").will(return_value(conn)) + + connection_factory = factory.cxoraConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + dt = DatabaseTemplate(connection_factory) + + dt.query(""" + SELECT + impcarrcfg.paystat_work_dir, + impcarrcfg.paystat_reload_dir, + impcarrcfg.paystat_archive_dir, + impcarrcfg.oid + FROM impcarrcfg, carr, lklabelsys + WHERE (lklabelsys.oid = impcarrcfg.lklabelsys_oid) + and (carr.oid = impcarrcfg.carr_oid ) + and (carr.oid = :carr_oid and lklabelsys.oid = :lklabelsys_oid) + """, + {'carr_oid':5, 'lklabelsys_oid':5}, + testSupportClasses.ImpFilePropsRowMapper()) + + del(sys.modules["cx_Oracle"]) + + def testInsertingIntoOracleWithInvalidlyFormattedArguments(self): + sys.modules["cx_Oracle"] = self.mock() + + connection_factory = factory.cxoraConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + dt = DatabaseTemplate(connection_factory) + + self.assertRaises(InvalidArgumentType, dt.execute, + "INSERT INTO T_UNIT (F_UNIT_PK, F_UNIT_ID, F_NAME) VALUES (?, ?, ?)", + (1,1,1)) + + del(sys.modules["cx_Oracle"]) + + def testInsertingIntoOracleWithInvalidlyFormattedArgumentsWithUpdateApi(self): + sys.modules["cx_Oracle"] = self.mock() + + connection_factory = factory.cxoraConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + dt = DatabaseTemplate(connection_factory) + + self.assertRaises(InvalidArgumentType, dt.update, + "INSERT INTO T_UNIT (F_UNIT_PK, F_UNIT_ID, F_NAME) VALUES (?, ?, ?)", + (1,1,1)) + + del(sys.modules["cx_Oracle"]) + + def testInsertingIntoOracleWithInvalidlyFormattedArgumentsWithInsertApi(self): + sys.modules["cx_Oracle"] = self.mock() + + connection_factory = factory.cxoraConnectionFactory(username="foo", password="bar", hostname="localhost", db="mock") + dt = DatabaseTemplate(connection_factory) + + self.assertRaises(InvalidArgumentType, dt.insert_and_return_id, + "INSERT INTO T_UNIT (F_UNIT_PK, F_UNIT_ID, F_NAME) VALUES (?, ?, ?)", + (1,1,1)) + + del(sys.modules["cx_Oracle"]) + +class DatabaseTemplateMockTestCase(MockTestCase): + """Testing the DatabaseTemplate utilizes stubbing and mocking, in order to isolate from different + vendor implementations. This reduces the overhead in making changes to core functionality.""" + + def setUp(self): + self.mock = self.mock() + connection_factory = testSupportClasses.StubDBFactory() + connection_factory.stubConnection.mockCursor = self.mock + self.databaseTemplate = DatabaseTemplate(connection_factory) + + def testProgrammaticallyInstantiatingAnAbstractDatabaseTemplate(self): + emptyTemplate = DatabaseTemplate() + self.assertRaises(AttributeError, emptyTemplate.query, "sql query shouldn't work", None) + + def testProgrammaticHandlingInvalidRowHandler(self): + self.mock.expects(once()).method("execute") + self.mock.expects(once()).method("fetchall").will(return_value([("me", "myphone")])) + + self.assertRaises(AttributeError, self.databaseTemplate.query, "select * from foobar", rowhandler=testSupportClasses.InvalidCallbackHandler()) + + def testProgrammaticHandlingImproperRowHandler(self): + self.mock.expects(once()).method("execute") + self.mock.expects(once()).method("fetchall").will(return_value([("me", "myphone")])) + + self.assertRaises(TypeError, self.databaseTemplate.query, "select * from foobar", rowhandler=testSupportClasses.ImproperCallbackHandler()) + + def testProgrammaticHandlingValidDuckTypedRowHandler(self): + self.mock.expects(once()).method("execute") + self.mock.expects(once()).method("fetchall").will(return_value([("me", "myphone")])) + + results = self.databaseTemplate.query("select * from foobar", rowhandler=testSupportClasses.ValidHandler()) + + def testIoCGeneralQuery(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestApplicationContext.xml")) + mockConnectionFactory = appContext.get_object("mockConnectionFactory") + mockConnectionFactory.stubConnection.mockCursor = self.mock + + self.mock.expects(once()).method("execute") + self.mock.expects(once()).method("fetchall").will(return_value([("me", "myphone")])) + + + databaseTemplate = DatabaseTemplate(connection_factory = mockConnectionFactory) + results = databaseTemplate.query("select * from foobar", rowhandler=testSupportClasses.SampleRowMapper()) + + def testProgrammaticStaticQuery(self): + self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query, "select * from animal", testSupportClasses.AnimalRowMapper()) + + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([('snake', 'reptile', 1), ('racoon', 'mammal', 1)])).id("#2").after("#1") + + animals = self.databaseTemplate.query("select * from animal", rowhandler=testSupportClasses.AnimalRowMapper()) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + self.assertEquals(animals[1].name, "racoon") + self.assertEquals(animals[1].category, "mammal") + + def testProgrammaticQueryWithBoundArguments(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([('snake', 'reptile', 1)])).id("#2").after("#1") + self.mock.expects(once()).method("execute").id("#3").after("#2") + self.mock.expects(once()).method("fetchall").will(return_value([('snake', 'reptile', 1)])).id("#4").after("#3") + + animals = self.databaseTemplate.query("select * from animal where name = %s", ("snake",), testSupportClasses.AnimalRowMapper()) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + + animals = self.databaseTemplate.query("select * from animal where name = ?", ("snake",), testSupportClasses.AnimalRowMapper()) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + + def testProgrammaticStaticQueryForList(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([('snake', 'reptile', 1), ('racoon', 'mammal', 1)])).id("#2").after("#1") + + animals = self.databaseTemplate.query_for_list("select * from animal") + self.assertEquals(animals[0][0], "snake") + self.assertEquals(animals[0][1], "reptile") + self.assertEquals(animals[1][0], "racoon") + self.assertEquals(animals[1][1], "mammal") + + def testProgrammaticQueryForListWithBoundArguments(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([('snake', 'reptile', 1)])).id("#2").after("#1") + self.mock.expects(once()).method("execute").id("#3").after("#2") + self.mock.expects(once()).method("fetchall").will(return_value([('snake', 'reptile', 1)])).id("#4").after("#3") + + animals = self.databaseTemplate.query_for_list("select * from animal where name = %s", ("snake",)) + self.assertEquals(animals[0][0], "snake") + self.assertEquals(animals[0][1], "reptile") + + animals = self.databaseTemplate.query_for_list("select * from animal where name = ?", ("snake",)) + self.assertEquals(animals[0][0], "snake") + self.assertEquals(animals[0][1], "reptile") + + def testProgrammaticQueryForListWithBoundArgumentsNotProperlyTuplized(self): + self.assertRaises(InvalidArgumentType, self.databaseTemplate.query_for_list, "select * from animal where name = %s", "snake") + self.assertRaises(InvalidArgumentType, self.databaseTemplate.query_for_list, "select * from animal where name = ?", "snake") + + def testProgrammaticStaticQueryForInt(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([(1,)])).id("#2").after("#1") + + count = self.databaseTemplate.query_for_int("select population from animal where name = 'snake'") + self.assertEquals(count, 1) + + def testProgrammaticQueryForIntWithBoundArguments(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([(1,)])).id("#2").after("#1") + self.mock.expects(once()).method("execute").id("#3").after("#2") + self.mock.expects(once()).method("fetchall").will(return_value([(1,)])).id("#4").after("#3") + + count = self.databaseTemplate.query_for_int("select population from animal where name = %s", ("snake",)) + self.assertEquals(count, 1) + + count = self.databaseTemplate.query_for_int("select population from animal where name = ?", ("snake",)) + self.assertEquals(count, 1) + + def testProgrammaticStaticQueryForLong(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([(4,)])).id("#2").after("#1") + + count = self.databaseTemplate.query_for_object("select count(*) from animal", required_type=types.IntType) + self.assertEquals(count, 4) + + def testProgrammaticQueryForLongWithBoundVariables(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([(1,)])).id("#2").after("#1") + self.mock.expects(once()).method("execute").id("#3").after("#2") + self.mock.expects(once()).method("fetchall").will(return_value([(1,)])).id("#4").after("#3") + + count = self.databaseTemplate.query_for_object("select count(*) from animal where name = %s", ("snake",), types.IntType) + self.assertEquals(count, 1) + + count = self.databaseTemplate.query_for_object("select count(*) from animal where name = ?", ("snake",), types.IntType) + self.assertEquals(count, 1) + + def testProgrammaticStaticQueryForObject(self): + self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query_for_object, "select name from animal where category = 'reptile'", types.StringType) + + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("snake",)])).id("#2").after("#1") + + name = self.databaseTemplate.query_for_object("select name from animal where category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "snake") + + def testProgrammaticQueryForObjectWithBoundVariables(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("snake",)])).id("#2").after("#1") + self.mock.expects(once()).method("execute").id("#3").after("#2") + self.mock.expects(once()).method("fetchall").will(return_value([("snake",)])).id("#4").after("#3") + + name = self.databaseTemplate.query_for_object("select name from animal where category = %s", ("reptile",), types.StringType) + self.assertEquals(name, "snake") + + name = self.databaseTemplate.query_for_object("select name from animal where category = ?", ("reptile",), types.StringType) + self.assertEquals(name, "snake") + + def testProgrammaticStaticUpdate(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("execute").id("#2").after("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("python",)])).id("#3").after("#2") + self.mock.rowcount = 1 + + rows = self.databaseTemplate.update("UPDATE animal SET name = 'python' WHERE name = 'snake'") + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "python") + + def testProgrammaticUpdateWithBoundVariables(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("execute").id("#2").after("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("python",)])).id("#3").after("#2") + self.mock.expects(once()).method("execute").id("#4").after("#3") + self.mock.expects(once()).method("execute").id("#5").after("#4") + self.mock.expects(once()).method("fetchall").will(return_value([("coily",)])).id("#6").after("#5") + self.mock.rowcount = 1 + + rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = ?", ("python", "reptile")) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "python") + + rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = %s", ("coily", "reptile")) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "coily") + + def testProgrammaticStaticInsert(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("execute").id("#2").after("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("black mamba",)])).id("#3").after("#2") + self.mock.rowcount = 1 + + rows = self.databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + def testProgrammaticStaticInsertWithInsertApi(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("execute").id("#2").after("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("black mamba",)])).id("#3").after("#2") + self.mock.lastrowid = 42 + + id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + self.assertEquals(id, 42) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + def testProgrammaticInsertWithBoundVariables(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("execute").id("#2").after("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("black mamba",)])).id("#3").after("#2") + self.mock.expects(once()).method("execute").id("#4").after("#3") + self.mock.expects(once()).method("execute").id("#5").after("#4") + self.mock.expects(once()).method("fetchall").will(return_value([("cottonmouth",)])).id("#6").after("#5") + self.mock.rowcount = 1 + + rows = self.databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + self.assertEquals(name, "cottonmouth") + + def testProgrammaticInsertWithBoundVariablesWithInsertApi(self): + self.mock.expects(once()).method("execute").id("#1") + self.mock.expects(once()).method("execute").id("#2").after("#1") + self.mock.expects(once()).method("fetchall").will(return_value([("black mamba",)])).id("#3").after("#2") + self.mock.expects(once()).method("execute").id("#4").after("#3") + self.mock.expects(once()).method("execute").id("#5").after("#4") + self.mock.expects(once()).method("fetchall").will(return_value([("cottonmouth",)])).id("#6").after("#5") + self.mock.lastrowid = 42 + + id = self.databaseTemplate.insert_and_return_id ("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) + self.assertEquals(id, 42) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) + self.assertEquals(id, 42) + + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + self.assertEquals(name, "cottonmouth") + +class AbstractDatabaseTemplateTestCase(unittest.TestCase): + def __init__(self, methodName='runTest'): + unittest.TestCase.__init__(self, methodName) + self.factory = None + self.createdTables = False + + def setUp(self): + if not self.createdTables: + self.createTables() + self.databaseTemplate = DatabaseTemplate(self.factory) + self.databaseTemplate.execute("DELETE FROM animal") + self.factory.commit() + self.assertEquals(len(self.databaseTemplate.query_for_list("SELECT * FROM animal")), 0) + self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('snake', 'reptile', 1)") + self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('racoon', 'mammal', 0)") + self.databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + self.databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('cottonmouth', 'kill_bill_viper', 1)") + self.factory.commit() + self.assertEquals(len(self.databaseTemplate.query_for_list("SELECT * FROM animal")), 4) + + def tearDown(self): + self.factory.rollback() + + def testProgrammaticallyInstantiatingAnAbstractDatabaseTemplate(self): + emptyTemplate = DatabaseTemplate() + self.assertRaises(AttributeError, emptyTemplate.query, "sql query shouldn't work", None) + + def testProgrammaticHandlingInvalidRowHandler(self): + self.assertRaises(AttributeError, self.databaseTemplate.query, "select * from animal", rowhandler=testSupportClasses.InvalidCallbackHandler()) + + def testProgrammaticHandlingImproperRowHandler(self): + self.assertRaises(TypeError, self.databaseTemplate.query, "select * from animal", rowhandler=testSupportClasses.ImproperCallbackHandler()) + + def testProgrammaticHandlingValidDuckTypedRowHandler(self): + results = self.databaseTemplate.query("select * from animal", rowhandler=testSupportClasses.ValidHandler()) + + def testProgrammaticStaticQuery(self): + self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query, "select * from animal", testSupportClasses.AnimalRowMapper()) + + animals = self.databaseTemplate.query("select name, category from animal", rowhandler=testSupportClasses.AnimalRowMapper()) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + self.assertEquals(animals[1].name, "racoon") + self.assertEquals(animals[1].category, "mammal") + + def testProgrammaticStaticQueryWithSimpleRowMapper(self): + animals = self.databaseTemplate.query("select name, category from animal", rowhandler=SimpleRowMapper(testSupportClasses.Animal)) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + self.assertEquals(animals[1].name, "racoon") + self.assertEquals(animals[1].category, "mammal") + + def testProgrammaticStaticQueryWithDictionaryRowMapper(self): + animals = self.databaseTemplate.query("select name, category from animal", rowhandler=DictionaryRowMapper()) + self.assertEquals(animals[0]["name"], "snake") + self.assertEquals(animals[0]["category"], "reptile") + self.assertEquals(animals[1]["name"], "racoon") + self.assertEquals(animals[1]["category"], "mammal") + + def testProgrammaticQueryWithBoundArguments(self): + animals = self.databaseTemplate.query("select name, category from animal where name = %s", ("snake",), testSupportClasses.AnimalRowMapper()) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + + animals = self.databaseTemplate.query("select name, category from animal where name = ?", ("snake",), testSupportClasses.AnimalRowMapper()) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + + def testProgrammaticQueryWithBoundArgumentsWithSimpleRowMapper(self): + animals = self.databaseTemplate.query("select name, category from animal where name = %s", ("snake",), SimpleRowMapper(testSupportClasses.Animal)) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + + animals = self.databaseTemplate.query("select name, category from animal where name = ?", ("snake",), SimpleRowMapper(testSupportClasses.Animal)) + self.assertEquals(animals[0].name, "snake") + self.assertEquals(animals[0].category, "reptile") + + def testProgrammaticQueryWithBoundArgumentsWithDictionaryRowMapper(self): + animals = self.databaseTemplate.query("select name, category from animal where name = %s", ("snake",), DictionaryRowMapper()) + self.assertEquals(animals[0]["name"], "snake") + self.assertEquals(animals[0]["category"], "reptile") + + animals = self.databaseTemplate.query("select name, category from animal where name = ?", ("snake",), DictionaryRowMapper()) + self.assertEquals(animals[0]["name"], "snake") + self.assertEquals(animals[0]["category"], "reptile") + + def testProgrammaticStaticQueryForList(self): + animals = self.databaseTemplate.query_for_list("select name, category from animal") + self.assertEquals(animals[0][0], "snake") + self.assertEquals(animals[0][1], "reptile") + self.assertEquals(animals[1][0], "racoon") + self.assertEquals(animals[1][1], "mammal") + + def testProgrammaticQueryForListWithBoundArguments(self): + animals = self.databaseTemplate.query_for_list("select name, category from animal where name = %s", ("snake",)) + self.assertEquals(animals[0][0], "snake") + self.assertEquals(animals[0][1], "reptile") + + animals = self.databaseTemplate.query_for_list("select name, category from animal where name = ?", ("snake",)) + self.assertEquals(animals[0][0], "snake") + self.assertEquals(animals[0][1], "reptile") + + def testProgrammaticQueryForListWithBoundArgumentsNotProperlyTuplized(self): + self.assertRaises(InvalidArgumentType, self.databaseTemplate.query_for_list, "select * from animal where name = %s", "snake") + self.assertRaises(InvalidArgumentType, self.databaseTemplate.query_for_list, "select * from animal where name = ?", "snake") + + def testProgrammaticStaticQueryForInt(self): + count = self.databaseTemplate.query_for_int("select population from animal where name = 'snake'") + self.assertEquals(count, 1) + + def testProgrammaticQueryForIntWithBoundArguments(self): + count = self.databaseTemplate.query_for_int("select population from animal where name = %s", ("snake",)) + self.assertEquals(count, 1) + + count = self.databaseTemplate.query_for_int("select population from animal where name = ?", ("snake",)) + self.assertEquals(count, 1) + + def testProgrammaticStaticQueryForLong(self): + count = self.databaseTemplate.query_for_object("select count(*) from animal", required_type=self.factory.count_type()) + self.assertEquals(count, 4) + + def testProgrammaticQueryForLongWithBoundVariables(self): + count = self.databaseTemplate.query_for_object("select count(*) from animal where name = %s", ("snake",), self.factory.count_type()) + self.assertEquals(count, 1) + + count = self.databaseTemplate.query_for_object("select count(*) from animal where name = ?", ("snake",), self.factory.count_type()) + self.assertEquals(count, 1) + + def testProgrammaticStaticQueryForObject(self): + self.assertRaises(ArgumentMustBeNamed, self.databaseTemplate.query_for_object, "select name from animal where category = 'reptile'", types.StringType) + + name = self.databaseTemplate.query_for_object("select name from animal where category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "snake") + + def testProgrammaticQueryForObjectWithBoundVariables(self): + name = self.databaseTemplate.query_for_object("select name from animal where category = %s", ("reptile",), types.StringType) + self.assertEquals(name, "snake") + + name = self.databaseTemplate.query_for_object("select name from animal where category = ?", ("reptile",), types.StringType) + self.assertEquals(name, "snake") + + def testProgrammaticStaticUpdate(self): + rows = self.databaseTemplate.update("UPDATE animal SET name = 'python' WHERE name = 'snake'") + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "python") + + def testProgrammaticUpdateWithBoundVariables(self): + rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = ?", ("python", "reptile")) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "python") + + rows = self.databaseTemplate.update("UPDATE animal SET name = ? WHERE category = %s", ("coily", "reptile")) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'reptile'", required_type=types.StringType) + self.assertEquals(name, "coily") + + def testProgrammaticStaticInsert(self): + self.databaseTemplate.execute("DELETE FROM animal") + rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + def testProgrammaticStaticInsertWithInsertApi(self): + self.databaseTemplate.execute("DELETE FROM animal") + id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + self.assertEquals(id, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + def testProgrammaticInsertWithBoundVariables(self): + self.databaseTemplate.execute("DELETE FROM animal") + rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + rows = self.databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) + self.assertEquals(rows, 1) + + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + self.assertEquals(name, "cottonmouth") + + def testProgrammaticInsertWithBoundVariablesWithInsertApi(self): + self.databaseTemplate.execute("DELETE FROM animal") + id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES (?, ?, ?)", ('black mamba', 'kill_bill_viper', 1)) + self.assertEquals(id, 1) + + name = self.databaseTemplate.query_for_object("SELECT name FROM animal WHERE category = 'kill_bill_viper'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + id = self.databaseTemplate.insert_and_return_id("INSERT INTO animal (name, category, population) VALUES (%s, %s, %s)", ('cottonmouth', 'kill_bill_viper', 1)) + self.assertEquals(id, 2) + + name = self.databaseTemplate.query_for_object("select name from animal where name = 'cottonmouth'", required_type=types.StringType) + self.assertEquals(name, "cottonmouth") + +class MySQLDatabaseTemplateTestCase(AbstractDatabaseTemplateTestCase): + def __init__(self, methodName='runTest'): + AbstractDatabaseTemplateTestCase.__init__(self, methodName) + + def createTables(self): + self.createdTables = True + try: + self.factory = factory.MySQLConnectionFactory("springpython", "springpython", "localhost", "springpython") + dt = DatabaseTemplate(self.factory) + dt.execute("DROP TABLE IF EXISTS animal") + dt.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population SMALLINT + ) ENGINE=innodb + """) + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run MySQLDatabaseTemplateTestCase !!! + + This assumes you have executed some step like: + % sudo apt-get install mysql (Ubuntu) + % apt-get install mysql (Debian) + + And then created a database for the spring python user: + % mysql -uroot + mysql> DROP DATABASE IF EXISTS springpython; + mysql> CREATE DATABASE springpython; + mysql> GRANT ALL ON springpython.* TO springpython@localhost IDENTIFIED BY 'springpython'; + + That should setup the springpython user to be able to create tables as needed for these test cases. + """) + raise e + + def testIoCGeneralQuery(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestMySQLApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + results = databaseTemplate.query("select * from animal", rowhandler=testSupportClasses.SampleRowMapper()) + + def testIoCGeneralQueryWithSimpleRowMapper(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestMySQLApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + results = databaseTemplate.query("select * from animal", rowhandler=SimpleRowMapper(testSupportClasses.Person)) + + def testIoCGeneralQueryWithDictionaryRowMapper(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestMySQLApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + results = databaseTemplate.query("select * from animal", rowhandler=DictionaryRowMapper()) + +class PostGreSQLDatabaseTemplateTestCase(AbstractDatabaseTemplateTestCase): + def __init__(self, methodName='runTest'): + AbstractDatabaseTemplateTestCase.__init__(self, methodName) + + def createTables(self): + self.createdTables = True + try: + self.factory = factory.PgdbConnectionFactory("springpython", "springpython", "localhost", "springpython") + dt = DatabaseTemplate(self.factory) + dt.execute("DROP TABLE IF EXISTS animal") + dt.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population integer + ) + """) + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run PostGreSQLDatabaseTemplateTestCase !!! + + This assumes you have executed some step like: + % sudo apt-get install postgresql (Ubuntu) + % apt-get install postgresql (Debian) + + Next, you need to let PostGreSQL's accounts be decoupled from the system accounts. + Find pg_hba.conf underneath /etc and add something like this: + # TYPE DATABASE USER IP-ADDRESS IP-MASK METHOD + host all all md5 + + Then, restart it. + % sudo /etc/init.d/postgresql restart (Ubuntu) + + Then create a user database to match this account. + % sudo -u postgres psql -f support/setupPostGreSQLSpringPython.sql + + From here on, you should be able to connect into PSQL and run SQL scripts. + """) + raise e + + def testIoCGeneralQuery(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestPGApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + results = databaseTemplate.query("select * from animal", rowhandler=testSupportClasses.SampleRowMapper()) + +class SqliteDatabaseTemplateTestCase(AbstractDatabaseTemplateTestCase): + def __init__(self, methodName='runTest'): + AbstractDatabaseTemplateTestCase.__init__(self, methodName) + self.db_filename = "springpython.db" + + def createTables(self): + self.createdTables = True + try: + try: + os.remove(self.db_filename) + except OSError: + pass + self.factory = factory.Sqlite3ConnectionFactory(self.db_filename) + dt = DatabaseTemplate(self.factory) + + dt.execute("DROP TABLE IF EXISTS animal") + + dt.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population integer + ) + """) + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run SqliteDatabaseTemplateTestCase !!! + """) + raise e + + def testIoCGeneralQuery(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestSqliteApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + + databaseTemplate.execute("DROP TABLE IF EXISTS animal") + databaseTemplate.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population integer + ) + """) + factory.commit() + databaseTemplate.execute("DELETE FROM animal") + factory.commit() + self.assertEquals(len(databaseTemplate.query_for_list("SELECT * FROM animal")), 0) + databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('snake', 'reptile', 1)") + databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('racoon', 'mammal', 0)") + databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('cottonmouth', 'kill_bill_viper', 1)") + factory.commit() + self.assertEquals(len(databaseTemplate.query_for_list("SELECT * FROM animal")), 4) + + results = databaseTemplate.query("select * from animal", rowhandler=testSupportClasses.SampleRowMapper()) + + def testIoCGeneralQueryWithSimpleRowMapper(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestSqliteApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + + databaseTemplate.execute("DROP TABLE IF EXISTS animal") + databaseTemplate.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population integer + ) + """) + factory.commit() + databaseTemplate.execute("DELETE FROM animal") + factory.commit() + self.assertEquals(len(databaseTemplate.query_for_list("SELECT * FROM animal")), 0) + databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('snake', 'reptile', 1)") + databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('racoon', 'mammal', 0)") + databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('cottonmouth', 'kill_bill_viper', 1)") + factory.commit() + self.assertEquals(len(databaseTemplate.query_for_list("SELECT * FROM animal")), 4) + + results = databaseTemplate.query("select * from animal", rowhandler=SimpleRowMapper(testSupportClasses.Person)) + + def testIoCGeneralQueryWithDictionaryRowMapper(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestSqliteApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + + databaseTemplate.execute("DROP TABLE IF EXISTS animal") + databaseTemplate.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population integer + ) + """) + factory.commit() + databaseTemplate.execute("DELETE FROM animal") + factory.commit() + self.assertEquals(len(databaseTemplate.query_for_list("SELECT * FROM animal")), 0) + databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('snake', 'reptile', 1)") + databaseTemplate.execute("INSERT INTO animal (name, category, population) VALUES ('racoon', 'mammal', 0)") + databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('black mamba', 'kill_bill_viper', 1)") + databaseTemplate.execute ("INSERT INTO animal (name, category, population) VALUES ('cottonmouth', 'kill_bill_viper', 1)") + factory.commit() + self.assertEquals(len(databaseTemplate.query_for_list("SELECT * FROM animal")), 4) + + results = databaseTemplate.query("select * from animal", rowhandler=DictionaryRowMapper()) + + +class SQLServerDatabaseTemplateTestCase(AbstractDatabaseTemplateTestCase): + def __init__(self, methodName='runTest'): + AbstractDatabaseTemplateTestCase.__init__(self, methodName) + + def createTables(self): + self.createdTables = True + try: + self.factory = factory.SQLServerConnectionFactory(DRIVER="{SQL Server}", + SERVER="localhost", DATABASE="springpython", UID="springpython", PWD="cdZS*RQRBdc9a") + dt = DatabaseTemplate(self.factory) + dt.execute("""IF EXISTS(SELECT 1 FROM sys.tables WHERE name='animal') + DROP TABLE animal""") + + dt.execute(""" + CREATE TABLE animal ( + id INTEGER IDENTITY(1,1) PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population INTEGER + ) + """) + + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run SQLServerDatabaseTemplateTestCase !!! + + This assumes you have installed pyodbc (http://code.google.com/p/pyodbc/). + + And then created an SQL Server database for the 'springpython' + login and user. + + USE master; + + IF EXISTS(SELECT 1 FROM sys.databases WHERE name='springpython') + DROP DATABASE springpython; + + IF EXISTS(SELECT 1 FROM sys.syslogins WHERE name='springpython') + DROP LOGIN springpython; + + IF EXISTS(SELECT 1 FROM sys.sysusers WHERE name='springpython') + DROP USER springpython; + + CREATE DATABASE springpython; + CREATE LOGIN springpython WITH PASSWORD='cdZS*RQRBdc9a', DEFAULT_DATABASE=springpython; + + USE springpython; + + CREATE USER springpython FOR LOGIN springpython; + EXEC sp_addrolemember 'db_owner', 'springpython'; + + From here on, you should be able to connect into SQL Server and run SQL scripts. + """) + raise e + + def testIoCGeneralQuery(self): + appContext = ApplicationContext(XMLConfig("support/databaseTestSQLServerApplicationContext.xml")) + factory = appContext.get_object("connection_factory") + + databaseTemplate = DatabaseTemplate(factory) + results = databaseTemplate.query("select * from animal", rowhandler=testSupportClasses.SampleRowMapper()) + diff --git a/test/springpythontest/databaseTransactionTestCases.py b/test/springpythontest/databaseTransactionTestCases.py index 347d4ff..7d8fb78 100644 --- a/test/springpythontest/databaseTransactionTestCases.py +++ b/test/springpythontest/databaseTransactionTestCases.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function import logging import os import subprocess @@ -61,14 +62,14 @@ def testInsertingRowsIntoTheDatabase(self): rows = self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) self.assertEquals(rows, 1) - name = self.dt.query_for_object("SELECT name FROM animal WHERE name = 'black mamba'", required_type=types.StringType) + name = self.dt.query_for_object("SELECT name FROM animal WHERE name = 'black mamba'", required_type=bytes) self.assertEquals(name, "black mamba") def testInsertingRowsIntoTheDatabaseWithInsertApi(self): id = self.dt.insert_and_return_id("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) self.assertEquals(id, 1) - name = self.dt.query_for_object("SELECT name FROM animal WHERE name = 'black mamba'", required_type=types.StringType) + name = self.dt.query_for_object("SELECT name FROM animal WHERE name = 'black mamba'", required_type=bytes) self.assertEquals(name, "black mamba") def testInsertingTwoRowsWithoutaTransactionButManuallyCommitted(self): @@ -108,7 +109,7 @@ class txDefinition(TransactionCallback): def do_in_transaction(s, status): self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('copperhead',)) - results = self.dt.query_for_object("SELECT name FROM animal WHERE name like 'c%'", required_type=types.StringType) + results = self.dt.query_for_object("SELECT name FROM animal WHERE name like 'c%'", required_type=bytes) return results self.assertEquals(self.transactionTemplate.execute(txDefinition()), "copperhead") @@ -417,7 +418,7 @@ def createTables(self): """) self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run MySQLDatabaseTemplateTestCase !!! @@ -464,7 +465,7 @@ def createTables(self): """) self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run PostGreSQLTransactionTestCase !!! @@ -499,7 +500,7 @@ def createTables(self): try: try: os.remove(self.db_filename) - except OSError, e: + except OSError as e: pass self.factory = factory.Sqlite3ConnectionFactory(self.db_filename) @@ -523,11 +524,11 @@ def createTables(self): """) self.factory.commit() - except Exception, e: - print e.message - print e.args - print type(e) - print dir(e) + except Exception as e: + print(e.message) + print(e.args) + print(type(e)) + print(dir(e)) print(""" !!! Can't run SqliteTransactionTestCase !!! @@ -571,7 +572,7 @@ def createTables(self): self.factory.commit() - except Exception, e: + except Exception as e: print(""" !!! Can't run SQLServerDatabaseTemplateTestCase !!! diff --git a/test/springpythontest/databaseTransactionTestCases.py.bak b/test/springpythontest/databaseTransactionTestCases.py.bak new file mode 100644 index 0000000..347d4ff --- /dev/null +++ b/test/springpythontest/databaseTransactionTestCases.py.bak @@ -0,0 +1,604 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import os +import subprocess +import types +import unittest +from springpython.context import ApplicationContext +from springpython.database import factory +from springpython.database import DataAccessException +from springpython.database.core import DatabaseTemplate +from springpython.database.transaction import ConnectionFactoryTransactionManager +from springpython.database.transaction import TransactionTemplate +from springpython.database.transaction import TransactionCallback +from springpython.database.transaction import TransactionCallbackWithoutResult +from springpython.database.transaction import TransactionPropagationException +from springpythontest.support.testSupportClasses import DatabaseTxTestAppContext +from springpythontest.support.testSupportClasses import DatabaseTxTestDecorativeTransactions +from springpythontest.support.testSupportClasses import DatabaseTxTestDecorativeTransactionsWithNoArguments +from springpythontest.support.testSupportClasses import DatabaseTxTestDecorativeTransactionsWithLotsOfArguments +from springpythontest.support.testSupportClasses import DatabaseTxTestAppContextWithNoAutoTransactionalObject +from springpythontest.support.testSupportClasses import BankException +from springpythontest.support.testSupportClasses import TransactionalBank + +class AbstractTransactionTestCase(unittest.TestCase): + + def __init__(self, methodName='runTest'): + unittest.TestCase.__init__(self, methodName) + self.factory = None + self.createdTables = False + + def setUp(self): + if not self.createdTables: + self.createTables() + self.createTables() + self.dt = DatabaseTemplate(self.factory) + self.dt.execute("DELETE FROM animal") + self.dt.execute("DELETE FROM account") + self.factory.commit() + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 0) + self.transactionManager = ConnectionFactoryTransactionManager(self.factory) + self.transactionTemplate = TransactionTemplate(self.transactionManager) + + def tearDown(self): + self.factory.getConnection().rollback() + + def testInsertingRowsIntoTheDatabase(self): + rows = self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.assertEquals(rows, 1) + + name = self.dt.query_for_object("SELECT name FROM animal WHERE name = 'black mamba'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + def testInsertingRowsIntoTheDatabaseWithInsertApi(self): + id = self.dt.insert_and_return_id("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.assertEquals(id, 1) + + name = self.dt.query_for_object("SELECT name FROM animal WHERE name = 'black mamba'", required_type=types.StringType) + self.assertEquals(name, "black mamba") + + def testInsertingTwoRowsWithoutaTransactionButManuallyCommitted(self): + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('copperhead',)) + self.factory.commit() + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 2) + + def testInsertingTwoRowsWithoutaTransactionButManuallyRolledBack(self): + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('copperhead',)) + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 2) + self.dt.connection_factory.getConnection().rollback() + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 0) + + def testInsertingTwoRowsWithaTransactionAndNoErrorsAndNoResults(self): + class txDefinition(TransactionCallbackWithoutResult): + def do_in_tx_without_result(s, status): + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('copperhead',)) + + self.transactionTemplate.execute(txDefinition()) + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 2) + + def testInsertingTwoRowsWithaTransactionAndAnIntermediateErrorAndNoResults(self): + class txDefinition(TransactionCallbackWithoutResult): + def do_in_tx_without_result(s, status): + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 1) + raise DataAccessException("This should break the transaction, and rollback the insert.") + + self.assertRaises(DataAccessException, self.transactionTemplate.execute, txDefinition()) + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 0) + + def testInsertingTwoRowsWithaTransactionAndNoErrorsAndResults(self): + class txDefinition(TransactionCallback): + def do_in_transaction(s, status): + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba',)) + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('copperhead',)) + results = self.dt.query_for_object("SELECT name FROM animal WHERE name like 'c%'", required_type=types.StringType) + return results + + self.assertEquals(self.transactionTemplate.execute(txDefinition()), "copperhead") + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 2) + + def testInsertingTwoRowsWithaTransactionAndAnIntermediateErrorAndResults(self): + class txDefinition(TransactionCallback): + def do_in_transaction(s, status): + self.dt.execute("INSERT INTO animal (name) VALUES (?)", ('black mamba')) + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 1) + raise DataAccessException("This should break the transaction, and rollback the insert.") + + self.assertRaises(DataAccessException, self.transactionTemplate.execute, txDefinition()) + self.assertEquals(len(self.dt.query_for_list("SELECT * FROM animal")), 0) + + def testDeclarativeTransactions(self): + appContext = ApplicationContext(DatabaseTxTestAppContext(self.factory)) + bank = appContext.get_object("bank") + + bank.open("Checking") + bank.open("Savings") + + bank.deposit(125.00, "Checking") + self.assertEquals(bank.balance("Checking"), 125.00) + + bank.deposit(250.00, "Savings") + self.assertEquals(bank.balance("Savings"), 250.00) + + bank.transfer(25.00, "Savings", "Checking") + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 150.00) + + bank.withdraw(10.00, "Checking") + self.assertEquals(bank.balance("Checking"), 140.00) + + amount = 0.0 + try: + amount = bank.withdraw(1000, "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(amount, 0.0) + + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 140.00) + + try: + bank.transfer(200, "Checking", "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + + self.assertEquals(bank.balance("Savings"), 225.00, "Bad transfer did NOT fail atomically!") + self.assertEquals(bank.balance("Checking"), 140.00, "Bad transfer did NOT fail atomically!") + + def testDecoratorBasedTransactions(self): + appContext = ApplicationContext(DatabaseTxTestDecorativeTransactions(self.factory)) + bank = appContext.get_object("bank") + + bank.open("Checking") + bank.open("Savings") + + bank.deposit(125.00, "Checking") + self.assertEquals(bank.balance("Checking"), 125.00) + + bank.deposit(250.00, "Savings") + self.assertEquals(bank.balance("Savings"), 250.00) + + bank.transfer(25.00, "Savings", "Checking") + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 150.00) + + bank.withdraw(10.00, "Checking") + self.assertEquals(bank.balance("Checking"), 140.00) + + amount = 0.0 + try: + amount = bank.withdraw(1000, "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(amount, 0.0) + + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 140.00) + + try: + bank.transfer(200, "Checking", "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(bank.balance("Savings"), 225.00, "Bad transfer did NOT fail atomically!") + self.assertEquals(bank.balance("Checking"), 140.00, "Bad transfer did NOT fail atomically!") + + def testDecoratorBasedTransactionsWithNoArguments(self): + appContext = ApplicationContext(DatabaseTxTestDecorativeTransactionsWithNoArguments(self.factory)) + bank = appContext.get_object("bank") + + bank.open("Checking") + bank.open("Savings") + + bank.deposit(125.00, "Checking") + self.assertEquals(bank.balance("Checking"), 125.00) + + bank.deposit(250.00, "Savings") + self.assertEquals(bank.balance("Savings"), 250.00) + + bank.transfer(25.00, "Savings", "Checking") + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 150.00) + + bank.withdraw(10.00, "Checking") + self.assertEquals(bank.balance("Checking"), 140.00) + + amount = 0.0 + try: + amount = bank.withdraw(1000, "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(amount, 0.0) + + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 140.00) + + try: + bank.transfer(200, "Checking", "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + + self.assertEquals(bank.balance("Savings"), 225.00, "Bad transfer did NOT fail atomically!") + self.assertEquals(bank.balance("Checking"), 140.00, "Bad transfer did NOT fail atomically!") + + def testDecoratorBasedTransactionsWithLotsOfArguments(self): + appContext = ApplicationContext(DatabaseTxTestDecorativeTransactionsWithLotsOfArguments(self.factory)) + bank = appContext.get_object("bank") + + bank.open("Checking") + bank.open("Savings") + + bank.deposit(125.00, "Checking") + self.assertEquals(bank.balance("Checking"), 125.00) + + bank.deposit(250.00, "Savings") + self.assertEquals(bank.balance("Savings"), 250.00) + + bank.transfer(25.00, "Savings", "Checking") + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 150.00) + + bank.withdraw(10.00, "Checking") + self.assertEquals(bank.balance("Checking"), 140.00) + + amount = 0.0 + try: + amount = bank.withdraw(1000, "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(amount, 0.0) + + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 140.00) + + try: + bank.transfer(200, "Checking", "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + + self.assertEquals(bank.balance("Savings"), 225.00, "Bad transfer did NOT fail atomically!") + logging.getLogger("springpythontest.databaseTransactionTestCases").debug(bank.balance("Checking")) + self.assertEquals(bank.balance("Checking"), 140.00, "Bad transfer did NOT fail atomically!") + + def testOtherPropagationLevels(self): + appContext = ApplicationContext(DatabaseTxTestDecorativeTransactionsWithLotsOfArguments(self.factory)) + bank = appContext.get_object("bank") + + # Call a mandatory operation outside a transaction, and verify it fails. + try: + bank.mandatoryOperation() + self.fail("Expected a TransactionPropagationException!") + except TransactionPropagationException: + pass + + # Call a mandatory operation from within a transactional routine, and verify it works. + bank.mandatoryOperationTransactionalWrapper() + + # Call a non-transactional operation from outside a transaction, and verify it works. + bank.nonTransactionalOperation() + + # Call a non-tranactional operation from within a transaction, and verify it fails. + try: + bank.nonTransactionalOperationTransactionalWrapper() + self.fail("Expected a TransactionPropagationException!") + except TransactionPropagationException: + pass + + def testTransactionProxyMethodFilters(self): + appContext = ApplicationContext(DatabaseTxTestAppContext(self.factory)) + bank = appContext.get_object("bank") + + bank.open("Checking") + bank.open("Savings") + + bank.deposit(125.00, "Checking") + self.assertEquals(bank.balance("Checking"), 125.00) + + bank.deposit(250.00, "Savings") + self.assertEquals(bank.balance("Savings"), 250.00) + + bank.transfer(25.00, "Savings", "Checking") + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 150.00) + + bank.withdraw(10.00, "Checking") + self.assertEquals(bank.balance("Checking"), 140.00) + + amount = 0.0 + try: + amount = bank.withdraw(1000, "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(amount, 0.0) + + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 140.00) + + try: + bank.transfer(200, "Checking", "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + + self.assertEquals(bank.balance("Savings"), 225.00, "Bad transfer did NOT fail atomically!") + self.assertEquals(bank.balance("Checking"), 140.00, "Bad transfer did NOT fail atomically!") + + def testTransactionalBankWithNoAutoTransactionalObject(self): + appContext = ApplicationContext(DatabaseTxTestAppContextWithNoAutoTransactionalObject(self.factory)) + bank = appContext.get_object("bank") + + bank.open("Checking") + bank.open("Savings") + + bank.deposit(125.00, "Checking") + self.assertEquals(bank.balance("Checking"), 125.00) + + bank.deposit(250.00, "Savings") + self.assertEquals(bank.balance("Savings"), 250.00) + + bank.transfer(25.00, "Savings", "Checking") + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 150.00) + + bank.withdraw(10.00, "Checking") + self.assertEquals(bank.balance("Checking"), 140.00) + + amount = 0.0 + try: + amount = bank.withdraw(1000, "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + self.assertEquals(amount, 0.0) + + self.assertEquals(bank.balance("Savings"), 225.00) + self.assertEquals(bank.balance("Checking"), 140.00) + + try: + bank.transfer(200, "Checking", "Nowhere") + self.fail("Expected a BankException!") + except BankException: + pass + + self.assertEquals(bank.balance("Savings"), 225.00, "Bad transfer did NOT fail atomically!") + self.assertEquals(bank.balance("Checking"), -60.00, "Bad transfer did NOT fail as expected (not atomically due to lack of AutoTransactionalObject)") + +class MySQLTransactionTestCase(AbstractTransactionTestCase): + + def __init__(self, methodName='runTest'): + AbstractTransactionTestCase.__init__(self, methodName) + + def createTables(self): + self.createdTables = True + try: + self.factory = factory.MySQLConnectionFactory("springpython", "springpython", "localhost", "springpython") + dt = DatabaseTemplate(self.factory) + dt.execute("DROP TABLE IF EXISTS animal") + dt.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population SMALLINT + ) ENGINE=innodb + """) + dt.execute("DROP TABLE IF EXISTS account") + dt.execute(""" + CREATE TABLE account ( + id INT(4) UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + account_num VARCHAR(11), + balance FLOAT(10) + ) ENGINE=innodb + """) + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run MySQLDatabaseTemplateTestCase !!! + + This assumes you have executed some step like: + % sudo apt-get install mysql (Ubuntu) + % apt-get install mysql (Debian) + + And then created a database for the spring python user: + % mysql -uroot + mysql> DROP DATABASE IF EXISTS springpython; + mysql> CREATE DATABASE springpython; + mysql> GRANT ALL ON springpython.* TO springpython@localhost IDENTIFIED BY 'springpython'; + + That should setup the springpython user to be able to create tables as needed for these test cases. + """) + raise e + + +class PostGreSQLTransactionTestCase(AbstractTransactionTestCase): + + def __init__(self, methodName='runTest'): + AbstractTransactionTestCase.__init__(self, methodName) + + def createTables(self): + self.createdTables = True + try: + self.factory = factory.PgdbConnectionFactory("springpython", "springpython", "localhost", "springpython") + dt = DatabaseTemplate(self.factory) + + dt.execute("DROP TABLE IF EXISTS animal") + dt.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11) + ) + """) + dt.execute("DROP TABLE IF EXISTS account") + dt.execute(""" + CREATE TABLE account ( + id serial PRIMARY KEY, + account_num VARCHAR(11), + balance FLOAT(10) + ) + """) + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run PostGreSQLTransactionTestCase !!! + + This assumes you have executed some step like: + % sudo apt-get install postgresql (Ubuntu) + % apt-get install postgresql (Debian) + + Next, you need to let PostGreSQL's accounts be decoupled from the system accounts. + Find pg_hba.conf underneath /etc and add something like this: + # TYPE DATABASE USER IP-ADDRESS IP-MASK METHOD + host all all md5 + + Then, restart it. + % sudo /etc/init.d/postgresql restart (Ubuntu) + + Then create a user database to match this account. + % sudo -u postgres psql -f support/setupPostGreSQLSpringPython.sql + + From here on, you should be able to connect into PSQL and run SQL scripts. + """) + raise e + + +class SqliteTransactionTestCase(AbstractTransactionTestCase): + + def __init__(self, methodName='runTest'): + AbstractTransactionTestCase.__init__(self, methodName) + self.db_filename = "springpython.db" + + def createTables(self): + self.createdTables = True + try: + try: + os.remove(self.db_filename) + except OSError, e: + pass + + self.factory = factory.Sqlite3ConnectionFactory(self.db_filename) + dt = DatabaseTemplate(self.factory) + + dt.execute("DROP TABLE IF EXISTS animal") + dt.execute("DROP TABLE IF EXISTS account") + + dt.execute(""" + CREATE TABLE animal ( + id serial PRIMARY KEY, + name VARCHAR(11) + ) + """) + dt.execute(""" + CREATE TABLE account ( + id serial PRIMARY KEY, + account_num VARCHAR(11), + balance FLOAT(10) + ) + """) + self.factory.commit() + + except Exception, e: + print e.message + print e.args + print type(e) + print dir(e) + print(""" + !!! Can't run SqliteTransactionTestCase !!! + + """) + raise e + +class SQLServerTransactionTestCase(AbstractTransactionTestCase): + + def __init__(self, methodName='runTest'): + AbstractTransactionTestCase.__init__(self, methodName) + + def createTables(self): + self.createdTables = True + try: + self.factory = factory.SQLServerConnectionFactory(DRIVER="{SQL Server}", + SERVER="localhost", DATABASE="springpython", UID="springpython", PWD="cdZS*RQRBdc9a") + dt = DatabaseTemplate(self.factory) + + dt.execute("""IF EXISTS(SELECT 1 FROM sys.tables WHERE name='animal') + DROP TABLE animal""") + + dt.execute(""" + CREATE TABLE animal ( + id INTEGER IDENTITY(1,1) PRIMARY KEY, + name VARCHAR(11), + category VARCHAR(20), + population INTEGER + ) + """) + + dt.execute("""IF EXISTS(SELECT 1 FROM sys.tables WHERE name='account') + DROP TABLE account""") + + dt.execute(""" + CREATE TABLE account ( + id INTEGER IDENTITY(1,1) PRIMARY KEY, + account_num VARCHAR(11), + balance FLOAT(10) + ) + """) + + self.factory.commit() + + except Exception, e: + print(""" + !!! Can't run SQLServerDatabaseTemplateTestCase !!! + + This assumes you have installed pyodbc (http://code.google.com/p/pyodbc/). + + And then created an SQL Server database for the 'springpython' + login and user. + + USE master; + + IF EXISTS(SELECT 1 FROM sys.databases WHERE name='springpython') + DROP DATABASE springpython; + + IF EXISTS(SELECT 1 FROM sys.syslogins WHERE name='springpython') + DROP LOGIN springpython; + + IF EXISTS(SELECT 1 FROM sys.sysusers WHERE name='springpython') + DROP USER springpython; + + CREATE DATABASE springpython; + CREATE LOGIN springpython WITH PASSWORD='cdZS*RQRBdc9a', DEFAULT_DATABASE=springpython; + + USE springpython; + + CREATE USER springpython FOR LOGIN springpython; + EXEC sp_addrolemember 'db_owner', 'springpython'; + + From here on, you should be able to connect into SQL Server and run SQL scripts. + """) + raise e diff --git a/test/springpythontest/jms_websphere_mq_test_cases.py b/test/springpythontest/jms_websphere_mq_test_cases.py index 3a65d12..d1509b9 100644 --- a/test/springpythontest/jms_websphere_mq_test_cases.py +++ b/test/springpythontest/jms_websphere_mq_test_cases.py @@ -1104,14 +1104,14 @@ def testWebSphereMQJMSException(self): try: raise WebSphereMQJMSException() - except WebSphereMQJMSException, e: + except WebSphereMQJMSException as e: self.assertEquals(e.completion_code, None) self.assertEquals(e.reason_code, None) self.assertEquals(e.message, None) try: raise WebSphereMQJMSException(message) - except WebSphereMQJMSException, e: + except WebSphereMQJMSException as e: self.assertEquals(e.completion_code, None) self.assertEquals(e.reason_code, None) self.assertEquals(e.message, message) @@ -1120,7 +1120,7 @@ def testWebSphereMQJMSException(self): try: raise WebSphereMQJMSException(completion_code=mq_exception.comp, reason_code=mq_exception.reason) - except WebSphereMQJMSException, e: + except WebSphereMQJMSException as e: self.assertEquals(e.completion_code, expected_completion_code) self.assertEquals(e.reason_code, expected_reason_code) @@ -1250,7 +1250,7 @@ def testFactoryExportingWebSphereMQConnectionFactoryOnly(self): _globals = {} _locals = {} - exec "from springpython.jms.factory import *" in _globals, _locals + exec("from springpython.jms.factory import *", _globals, _locals) self.assertEquals(1, len(_locals.keys())) self.assertEquals("WebSphereMQConnectionFactory", _locals.keys()[0]) @@ -1459,7 +1459,7 @@ def receive(self, destination, wait_interval): try: listener.run() - except WebSphereMQJMSException, e: + except WebSphereMQJMSException as e: sleep(0.1) # Allows the handler thread to process the message self.assertEquals(e.message, exception_reason) self.assertEquals(3, factory.call_count) @@ -1474,7 +1474,7 @@ def testSimpleMessageListenerContainerMessageHandler(self): try: handler.handle("foo") - except NotImplementedError, e: + except NotImplementedError as e: self.assertEquals(e.message, "Should be overridden by subclasses.") def testSSLCorrectSettings(self): @@ -1517,7 +1517,7 @@ def testSSLIncorrectSettings(self): self.assertRaises(JMSException, factory._connect) try: factory._connect() - except JMSException, e: + except JMSException as e: self.assertEquals(e.args[0], "SSL support requires setting both ssl_cipher_spec and ssl_key_repository") # ssl=True and ssl_cipher_spec only. @@ -1526,7 +1526,7 @@ def testSSLIncorrectSettings(self): self.assertRaises(JMSException, factory._connect) try: factory._connect() - except JMSException, e: + except JMSException as e: self.assertEquals(e.args[0], "SSL support requires setting both ssl_cipher_spec and ssl_key_repository") # ssl=True and ssl_key_repository only. @@ -1535,7 +1535,7 @@ def testSSLIncorrectSettings(self): self.assertRaises(JMSException, factory._connect) try: factory._connect() - except JMSException, e: + except JMSException as e: self.assertEquals(e.args[0], "SSL support requires setting both ssl_cipher_spec and ssl_key_repository") # ssl_cipher_spec only, ssl=False. @@ -1666,7 +1666,7 @@ def receive(self, destination, wait_interval): try: listener.run() - except WebSphereMQJMSException, e: + except WebSphereMQJMSException as e: sleep(0.5) # Allows the handler thread to process the message self.assertEquals(e.message, exception_reason) self.assertEquals(3, factory.call_count) @@ -1681,7 +1681,7 @@ def testSimpleMessageListenerContainerMessageHandler(self): try: handler.handle("foo") - except NotImplementedError, e: + except NotImplementedError as e: self.assertEquals(e.message, "Should be overridden by subclasses.") def testNeedsMCD(self): diff --git a/test/springpythontest/jms_websphere_mq_test_cases.py.bak b/test/springpythontest/jms_websphere_mq_test_cases.py.bak new file mode 100644 index 0000000..3a65d12 --- /dev/null +++ b/test/springpythontest/jms_websphere_mq_test_cases.py.bak @@ -0,0 +1,1696 @@ +# -*- coding: utf-8 -*- + +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# stdlib +import os +import sys +import random +import signal +import logging +import unittest +from struct import pack +from random import choice +from string import letters +from time import time, sleep +from binascii import hexlify, unhexlify +from xml.sax.saxutils import escape, unescape + +try: + import cElementTree as etree +except ImportError: + try: + import xml.etree.ElementTree as etree + except ImportError: + from elementtree import ElementTree as etree + +# Python 2.4 compat +try: + from hashlib import sha1 +except ImportError: + from sha import sha as sha1 + +# pmock +from pmock import * + +# ThreadPool +from threadpool import ThreadPool + +# pymqi +import pymqi as mq +from pymqi import CMQC + +# Spring Python +from springpython.config import XMLConfig +from springpython.context import ApplicationContext + +from springpython.jms.factory import * +from springpython.jms import JMSException, WebSphereMQJMSException, NoMessageAvailableException +from springpython.jms.factory import _WMQ_MAX_EXPIRY_TIME, _WMQ_MQRFH_VERSION_2, \ + _WMQ_MQFMT_RF_HEADER_2, _WMQ_DEFAULT_CCSID, _WMQ_DEFAULT_ENCODING, MQRFH2JMS, \ + _WMQ_DEFAULT_ENCODING_WIRE_FORMAT, _WMQ_DEFAULT_CCSID_WIRE_FORMAT, \ + _WMQ_MQRFH_NO_FLAGS_WIRE_FORMAT, _mcd, unhexlify_wmq_id, _WMQ_ID_PREFIX +from springpython.jms.core import JmsTemplate, TextMessage, MessageConverter +from springpython.jms import DELIVERY_MODE_NON_PERSISTENT, \ + DELIVERY_MODE_PERSISTENT, DEFAULT_DELIVERY_MODE, RECEIVE_TIMEOUT_INDEFINITE_WAIT, \ + RECEIVE_TIMEOUT_NO_WAIT, DEFAULT_TIME_TO_LIVE +from springpython.jms.listener import MessageHandler, SimpleMessageListenerContainer, \ + WebSphereMQListener + +random.seed() + +logger = logging.getLogger("springpythontest.jms_websphere_mq_test_cases") + +QUEUE_MANAGER = "SPRINGPYTHON1" +CHANNEL = "SPR.PY.TO.JAVA.1" +HOST = "localhost" +LISTENER_PORT = "1434" +DESTINATION = "SPRING.PYTHON.TO.JAVA.REQ.1" +PAYLOAD = "Hello from Spring Python and JMS!" + +conn_info = "%s(%s)" % (HOST, LISTENER_PORT) + +# A bit of gimmick, we can't use .eq, because timestamp will be different. +timestamp = "1247950158160" +raw_message = 'RFH \x00\x00\x00\x02\x00\x00\x00\xd8\x00\x00\x01\x11\x00\x00\x04\xb8MQSTR \x00\x00\x00\x00\x00\x00\x04\xb8\x00\x00\x00Ljms_text \x00\x00\x00`queue:///SPRING.PYTHON.TO.JAVA.REQ.1%s2 %s' % (timestamp, PAYLOAD) +raw_message_before_timestamp = raw_message[:raw_message.find(timestamp)] +raw_message_after_timestamp = raw_message[raw_message.find(timestamp) + len(timestamp):len(raw_message)] + +# Used in tests of message consumers +raw_message_for_get = 'RFH \x00\x00\x00\x02\x00\x00\x01d\x00\x00\x01\x11\x00\x00\x04\xb8MQSTR \x00\x00\x00\x00\x00\x00\x04\xb8\x00\x00\x00 jms_text \x00\x00\x00\xa8queue:///TESTqueue:///TEST125209468051912520948039756026ff99-c249-40aa-9f9e-62a7c0a004032 \x00\x00\x00l7b9be165-1151-4dec-be06-b7f876b2703bed493e59-392c-45dd-9862-07847fd202b5 b0f32f11-b531-4bbf-b985-77e795d77024' + +# Queue name may be up to 48 characters (MQCHAR48 in cmqc.h) +queue_name_length = range(1,49) + +class DummyController(object): + def shutdown(self): + pass + +def get_rand_string(length): + return "".join(choice(letters) for idx in range(length)) + +def condition_ignored(ignored): + return True + +def get_default_md(): + md = mq.md() + md.PutDate = "20091023" + md.PutTime = "19261676" + + return md + +def get_simple_message_and_jms_template(mock): + + message = TextMessage() + message.text = "Hi there." + + queue = mock() + mgr = mock() + cd = mock() + sco = mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(at_least_once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + sys.modules["pymqi"].MQMIError = mq.MQMIError + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + queue.expects(at_least_once()).put(functor(condition_ignored), functor(condition_ignored)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + return message, jms_template + +class WebSphereMQTestCase(MockTestCase): + + def _get_random_data(self): + + text = get_rand_string(101) + jms_correlation_id = get_rand_string(36) + jms_delivery_mode = choice((DELIVERY_MODE_NON_PERSISTENT, DELIVERY_MODE_PERSISTENT)) + jms_destination = get_rand_string(choice(queue_name_length)) + jms_expiration = random.randrange(int(_WMQ_MAX_EXPIRY_TIME - 2), int(_WMQ_MAX_EXPIRY_TIME + 2)) + jms_priority = choice(range(1,9)) + jms_redelivered = choice((True, False)) + jms_reply_to = get_rand_string(choice(queue_name_length)) + + return(text, jms_correlation_id, jms_delivery_mode, jms_destination, + jms_expiration, jms_priority, jms_redelivered, jms_reply_to) + + def testSendingMessagesToWebSphereMQ(self): + + # For whatever reason, pmock can't handle the following assertions on + # the same mock object though it works fine when the assertions are + # executed in isolation. That's why we need a loop below. + # + # queue.expects(once()).put(string_contains(raw_message_after_timestamp), eq(md)) + # queue.expects(once()).put(string_contains(raw_message_before_timestamp), eq(md)) + + queue1 = Mock("queue_raw_message_before_timestamp") + queue2 = Mock("queue_raw_message_after_timestamp") + + for queue in (queue1, queue2): + + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + if queue.get_name() == "queue_raw_message_before_timestamp": + queue.expects(once()).put(string_contains(raw_message_before_timestamp), eq(md)) + else: + queue.expects(once()).put(string_contains(raw_message_after_timestamp), eq(md)) + + queue.expects(once()).close() + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + text_message = TextMessage() + text_message.text = PAYLOAD + + jms_template.send(text_message, DESTINATION) + + del(sys.modules["pymqi"]) + + def testCreatingDefaultMQRFH2JMS(self): + + now = long(time() * 1000) + sleep(1.0) + + # Each folder is prepended by a 4-bytes header + folder_header_length = 4 + + destination_length = choice(queue_name_length) + destination = get_rand_string(destination_length) + message = TextMessage() + + # mcd folder is constant + mcd = """jms_text """ + mcd_len = len(mcd) + mcd_len_wire_format = pack("!l", mcd_len) + + jms = "queue:///%s12479501581602" % destination + current_jms_len = len(jms) + + # Pad to a multiple of 4. + if current_jms_len % 4 == 0: + jms_len = current_jms_len + else: + padding = 4 - (current_jms_len % 4) + jms += " " * padding + jms_len = len(jms) + + jms_len_wire_format = pack("!l", jms_len) + + total_header_length = MQRFH2JMS.FIXED_PART_LENGTH + folder_header_length + mcd_len + folder_header_length + jms_len + total_header_length_wire_format = pack("!l", total_header_length) + + mqrfh2jms = MQRFH2JMS() + header = mqrfh2jms.build_header(message, destination, CMQC, now) + + header_mqrfh_struc_id = header[:4] + header_WMQ_mqrfh_version_2 = header[4:8] + header_total_header_length = header[8:12] + header_WMQ_default_encoding_wire_format = header[12:16] + header_WMQ_default_ccsid_wire_format = header[16:20] + header_mqfmt_string = header[20:28] + header_WMQ_mqrfh_no_flags_wire_format = header[28:32] + header_WMQ_default_ccsid_wire_format = header[32:36] + header_mcd_len = header[36:40] + header_mcd = header[40:40+mcd_len] + header_jms_len = header[40+mcd_len:40+mcd_len+folder_header_length] + header_jms = header[40+mcd_len+folder_header_length:40+mcd_len+folder_header_length+jms_len] + + self.assertEqual(header_mqrfh_struc_id, CMQC.MQRFH_STRUC_ID) + self.assertEqual(header_WMQ_mqrfh_version_2, _WMQ_MQRFH_VERSION_2) + self.assertEqual(header_total_header_length, total_header_length_wire_format) + self.assertEqual(header_WMQ_default_encoding_wire_format, _WMQ_DEFAULT_ENCODING_WIRE_FORMAT) + self.assertEqual(header_WMQ_default_ccsid_wire_format, _WMQ_DEFAULT_CCSID_WIRE_FORMAT) + self.assertEqual(header_mqfmt_string, CMQC.MQFMT_STRING) + self.assertEqual(header_WMQ_mqrfh_no_flags_wire_format, _WMQ_MQRFH_NO_FLAGS_WIRE_FORMAT) + self.assertEqual(header_WMQ_default_ccsid_wire_format, _WMQ_DEFAULT_CCSID_WIRE_FORMAT) + self.assertEqual(header_mcd_len, mcd_len_wire_format) + self.assertEqual(header_mcd, mcd) + self.assertEqual(header_jms_len, jms_len_wire_format) + + # Don't compare the jms folder here - timestamps will differ, will check it below. + # self.assertEqual(header_jms, jms) + + jms = etree.fromstring(header_jms) + + self.assertEqual(jms.find("Dst").text, "queue:///" + destination) + self.assertTrue(bool((long(str(jms.find("Tms").text)) < long(time() * 1000)) is True)) + self.assertEqual(int(str(jms.find("Dlv").text)), DELIVERY_MODE_PERSISTENT) + + def testJMSAndWebSphereMQConstants(self): + self.assertEqual(_WMQ_MQRFH_VERSION_2, "\x00\x00\x00\x02") + self.assertEqual(_WMQ_DEFAULT_ENCODING, 273) + self.assertEqual(_WMQ_DEFAULT_ENCODING_WIRE_FORMAT, pack("!l", 273)) + self.assertEqual(_WMQ_DEFAULT_CCSID, 1208) + self.assertEqual(_WMQ_DEFAULT_CCSID_WIRE_FORMAT, pack("!l", 1208)) + self.assertEqual(_WMQ_MQFMT_RF_HEADER_2, "MQHRF2 ") + self.assertEqual(_WMQ_MQRFH_NO_FLAGS_WIRE_FORMAT, "\x00\x00\x00\x00") + self.assertEqual(MQRFH2JMS.FIXED_PART_LENGTH, 36) + self.assertEqual(MQRFH2JMS.FOLDER_LENGTH_MULTIPLE, 4) + self.assertEqual(_WMQ_MAX_EXPIRY_TIME, 214748364.7) + self.assertEqual(_WMQ_ID_PREFIX, "ID:") + self.assertEqual(etree.tostring(_mcd), """jms_text""") + + def testJMSConstants(self): + self.assertEqual(DELIVERY_MODE_NON_PERSISTENT, 1) + self.assertEqual(DELIVERY_MODE_PERSISTENT, 2) + self.assertEqual(DEFAULT_DELIVERY_MODE, DELIVERY_MODE_PERSISTENT) + self.assertEqual(DEFAULT_TIME_TO_LIVE, 0) + self.assertEqual(RECEIVE_TIMEOUT_INDEFINITE_WAIT, 0) + self.assertEqual(RECEIVE_TIMEOUT_NO_WAIT, -1) + + def testJmsTemplateSettingAndGettingJMSAttributes(self): + + (text, jms_correlation_id, jms_delivery_mode, jms_destination, + jms_expiration, jms_priority, jms_redelivered, + jms_reply_to) = self._get_random_data() + + message = TextMessage() + message.text = text + message.jms_correlation_id = jms_correlation_id + message.jms_delivery_mode = jms_delivery_mode + message.jms_destination = jms_destination + message.jms_expiration = jms_expiration + message.jms_priority = jms_priority + message.jms_redelivered = jms_redelivered + message.jms_reply_to = jms_reply_to + + self.assertEqual(message.text, text) + self.assertEqual(message.jms_correlation_id, jms_correlation_id) + self.assertEqual(message.jms_delivery_mode, jms_delivery_mode) + self.assertEqual(message.jms_destination, jms_destination) + self.assertEqual(message.jms_expiration, jms_expiration) + self.assertEqual(message.jms_priority, jms_priority) + self.assertEqual(message.jms_redelivered, jms_redelivered) + self.assertEqual(message.jms_reply_to, jms_reply_to) + + def testWebSphereMQJMSHeadersMappingsToMQMDAndMQRFH2ForOutgoingMessages(self): + + (text, jms_correlation_id, jms_delivery_mode, jms_destination, + jms_expiration, jms_priority, jms_redelivered, + jms_reply_to) = self._get_random_data() + + message = TextMessage() + + # Message body and standard JMS headers + message.text = text + message.jms_correlation_id = jms_correlation_id + message.jms_delivery_mode = jms_delivery_mode + message.jms_destination = jms_destination + message.jms_expiration = jms_expiration + message.jms_priority = jms_priority + message.jms_redelivered = jms_redelivered + message.jms_reply_to = jms_reply_to + + # WebSphere MQ extended JMS headers + jmsxgroupseq = 90 # Fudged. + jmsxgroupid = get_rand_string(12) + feedback = CMQC.MQFB_EXPIRATION + jms_ibm_report_exception = CMQC.MQRO_EXCEPTION_WITH_DATA + jms_ibm_report_expiration = CMQC.MQRO_EXPIRATION_WITH_FULL_DATA + jms_ibm_report_coa = CMQC.MQRO_COA + jms_ibm_report_cod = CMQC.MQRO_COD_WITH_DATA + jms_ibm_report_pan = CMQC.MQRO_PAN + jms_ibm_report_nan = CMQC.MQRO_NAN + jms_ibm_report_pass_msg_id = CMQC.MQRO_PASS_MSG_ID + jms_ibm_report_pass_correl_id = CMQC.MQRO_PASS_CORREL_ID + jms_ibm_report_discard_msg = CMQC.MQRO_DISCARD_MSG + + message.JMSXGroupSeq = jmsxgroupseq + message.JMSXGroupID = jmsxgroupid + message.JMS_IBM_Report_Exception = jms_ibm_report_exception + message.JMS_IBM_Report_Expiration = jms_ibm_report_expiration + message.JMS_IBM_Report_COA = jms_ibm_report_coa + message.JMS_IBM_Report_COD = jms_ibm_report_cod + message.JMS_IBM_Report_PAN = jms_ibm_report_pan + message.JMS_IBM_Report_NAN = jms_ibm_report_nan + message.JMS_IBM_Report_Pass_Msg_ID = jms_ibm_report_pass_msg_id + message.JMS_IBM_Report_Pass_Correl_ID = jms_ibm_report_pass_correl_id + message.JMS_IBM_Report_Discard_Msg = jms_ibm_report_discard_msg + message.JMS_IBM_Feedback = feedback + message.JMS_IBM_Last_Msg_In_Group = True + + expected_mqmd_jms_correlation_id = jms_correlation_id.ljust(24)[:24] + + def _check_md(md): + """ Verify MQMD attributes on their way to queue.put(body, md). + """ + + # DELIVERY_MODE_NON_PERSISTENT -> MQPER_NOT_PERSISTENT in cmqc.h + # DELIVERY_MODE_PERSISTENT -> MQPER_PERSISTENT in cmqc.h + + if jms_delivery_mode == DELIVERY_MODE_NON_PERSISTENT: + expected_md_persistence = CMQC.MQPER_NOT_PERSISTENT + elif jms_delivery_mode == DELIVERY_MODE_PERSISTENT: + expected_md_persistence = CMQC.MQPER_PERSISTENT + + if jms_expiration / 1000 > _WMQ_MAX_EXPIRY_TIME: + expected_md_expiry = CMQC.MQEI_UNLIMITED + else: + # JMS header is in milliseconds, MQMD one is in centiseconds. + expected_md_expiry = jms_expiration / 10 + + # Truncated or padded to 24 characters. + expected_jmsxgroupid = jmsxgroupid.ljust(24)[:24] + + expected_report = sum((jms_ibm_report_exception,jms_ibm_report_expiration, + jms_ibm_report_coa, jms_ibm_report_cod, jms_ibm_report_pan, + jms_ibm_report_nan, jms_ibm_report_pass_msg_id, + jms_ibm_report_pass_correl_id, jms_ibm_report_discard_msg)) + + # Standard MQMD headers + self.assertEqual(md.Format, _WMQ_MQFMT_RF_HEADER_2) + self.assertEqual(md.CodedCharSetId, _WMQ_DEFAULT_CCSID) + self.assertEqual(md.Encoding, _WMQ_DEFAULT_ENCODING) + + # Mapped from standard JMS headers to MQMD + self.assertEqual(md.CorrelId, expected_mqmd_jms_correlation_id, "md.CorrelId mismatch [%s] [%s]" % (md.CorrelId, expected_mqmd_jms_correlation_id)) + self.assertEqual(md.Persistence, expected_md_persistence, " md.Persistence [%s] [%s]" % (md.Persistence, expected_md_persistence)) + self.assertEqual(md.Expiry, expected_md_expiry, "md.Expiry mismatch [%s] [%s]" % (md.Expiry, expected_md_expiry)) + self.assertEqual(md.Priority, jms_priority, "md.Priority mismatch [%s] [%s]" % (md.Priority, jms_priority)) + self.assertEqual(md.ReplyToQ, jms_reply_to) + + # Extended Webpshere MQ JMS headers + self.assertEqual(md.MsgSeqNumber, jmsxgroupseq, "md.MsgSeqNumber mismatch [%s] [%s]" % (md.MsgSeqNumber, jmsxgroupseq)) + self.assertEqual(md.GroupId, expected_jmsxgroupid, "md.GroupId mismatch [%s] [%s]" % (md.GroupId, expected_jmsxgroupid)) + self.assertEqual(md.Feedback, feedback, "md.Feedback mismatch [%s] [%s]" % (md.Feedback, feedback)) + self.assertEqual(md.Report, expected_report, "md.Report mismatch [%s] [%s]" % (md.Report, expected_report)) + + self.assertTrue(((md.MsgFlags & (CMQC.MQMF_MSG_IN_GROUP) == CMQC.MQMF_MSG_IN_GROUP) is True), + "md.Flags (JMSXGroupSeq) mismatch [%s] [%s]" % (md.MsgFlags, CMQC.MQMF_MSG_IN_GROUP)) + + self.assertTrue(((md.MsgFlags & (CMQC.MQMF_LAST_MSG_IN_GROUP) == CMQC.MQMF_LAST_MSG_IN_GROUP) is True), + "md.Flags (JMS_IBM_Last_Msg_In_Group) mismatch [%s] [%s]" % (md.MsgSeqNumber, CMQC.MQMF_LAST_MSG_IN_GROUP)) + + return True + + def _check_mqrfh2(mqrfh2): + + mqrfh2_jms_start = mqrfh2.find("") + mqrfh2_jmd_end = mqrfh2.find("") + 6 + + mqrfh2_jms = str(mqrfh2[mqrfh2_jms_start:mqrfh2_jmd_end]) + jms = etree.fromstring(mqrfh2_jms) + + now = long(time() * 1000) + + self.assertEqual(str(jms.find("Pri").text), str(jms_priority)) + + # The message has been already put onto queue so its timestamp + # should be equal or earlier than now. + jms_tms = long(str(jms.find("Tms").text)) + self.assertTrue(jms_tms <= now, "jms.Tms error [%s] [%s]" % (jms_tms, now)) + + # Same as Webpshere MQ JMS Java API, jms.Dst cannnot be set manually + # by user, though docs don't mention that. + self.assertEqual(str(jms.find("Dst").text), "queue:///" + DESTINATION) + + # MQMD CorrelId is truncated to 24 characters, but MQRFH2's one isn't. + self.assertEqual(str(jms.find("Cid").text), jms_correlation_id) + + # Message has been sent already sent a couple of milliseconds ago. + jms_exp = long(str(jms.find("Exp").text)) + self.assertTrue(jms_exp - now <= jms_expiration, "jms.Exp error [%s] [%s]" % (jms_exp - now, jms_expiration)) + + return True + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(at_least_once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + queue.expects(once()).put(functor(_check_mqrfh2), functor(_check_md)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + jms_template.send(message, DESTINATION) + + del(sys.modules["pymqi"]) + + def testMessageConverterForOutgoingMessages(self): + + customer = "123" + customer_account = "456" + number = "789" + date = "20090126" + + expected_message_after_conversion = ";".join((customer, customer_account, + number, date)) + + class Invoice(object): + def __init__(self): + self.customer = customer + self.customer_account = customer_account + self.number = number + self.date = date + + class InvoiceConverter(object): + def to_message(self, invoice): + text = ";".join((invoice.customer, invoice.customer_account, + invoice.number, invoice.date)) + + return TextMessage(text) + + def _check_payload(message): + """ Business payload is the last part of a message, i.e. comes + after the MQ headers. + """ + return message.endswith(expected_message_after_conversion) + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + queue.expects(once()).put(functor(_check_payload), eq(md)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + invoice = Invoice() + + # No message converter set yet. + self.assertRaises(JMSException, jms_template.convert_and_send, invoice, DESTINATION) + + # No JMSException at this point. + jms_template.message_converter = InvoiceConverter() + jms_template.convert_and_send(invoice, DESTINATION) + + del(sys.modules["pymqi"]) + + def testSettingDefaultDestinationForOutgoingMessages(self): + + default_destination = get_rand_string(24) + + def _check_mqrfh2_destination(message): + return "queue:///" + default_destination in message + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + + queue.stubs().put(functor(_check_mqrfh2_destination), eq(md)) + + # Queue name must be equal to default destination, pmock will verify it. + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(default_destination), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + text_message = TextMessage() + text_message.text = PAYLOAD + + # No default destination set yet. + self.assertRaises(JMSException, jms_template.send, text_message) + + # No JMSException here. + jms_template.default_destination = default_destination + jms_template.send(text_message) + + def testUnhexlifyWebSphereMQIdentifiers(self): + + # Basic check. + hex_wmq_id = "ID:414d5120535052494e47505954484f4ecc90674a041f0020" + unhexlified = "AMQ SPRINGPYTHON\xcc\x90gJ\x04\x1f\x00 " + + self.assertEquals(unhexlify_wmq_id(hex_wmq_id), unhexlified) + + # Now the real message, check the unhexlifying for every relevant JMS + # and MQMD header. + + def get_expected_md_header_value(jms_header_value): + if jms_header_value.startswith("ID:"): + expected_mqmd_header = unhexlify(jms_header_value.replace("ID:", "", 1)) + else: + if len(jms_header_value) == 24: + expected_mqmd_header = jms_header_value + elif len(jms_header_value) < 24: + expected_mqmd_header = jms_header_value.ljust(24) + elif len(jms_header_value) > 24: + expected_mqmd_header = jms_header_value[:24] + + return expected_mqmd_header + + jms_to_mqmd_headers = { + "jms_correlation_id":"CorrelId", + "JMSXGroupID":"GroupId"} + + for jms_header, mqmd_header in jms_to_mqmd_headers.items(): + + jms_wmq_id = get_rand_string(24) + jms_wmq_id_header_value = "ID:" + hexlify(jms_wmq_id) + jms_non_wmq_header_short_value = get_rand_string(12) + jms_non_wmq_header_max_mqmd_length_value = get_rand_string(24) + jms_non_wmq_header_long_value = get_rand_string(36) + + for jms_header_value in(jms_wmq_id_header_value, + jms_non_wmq_header_short_value, jms_non_wmq_header_max_mqmd_length_value, + jms_non_wmq_header_long_value): + + expected_mqmd_header_value = get_expected_md_header_value(jms_header_value) + + def _check_md(md): + mqmd_header_value = getattr(md, mqmd_header) + + self.assertEquals(mqmd_header_value, expected_mqmd_header_value, + ("ID mismatch mqmd_header_value='%s' expected_mqmd_header_value='%s' " + + "jms_header='%s' mqmd_header='%s' jms_header_value='%s'") % ( + mqmd_header_value, expected_mqmd_header_value, + jms_header, mqmd_header, jms_header_value)) + return True + + message = TextMessage() + message.text = "Hi there." + + setattr(message, jms_header, jms_header_value) + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + queue.expects(once()).put(functor(condition_ignored), functor(_check_md)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + jms_template.send(message, DESTINATION) + + del(sys.modules["pymqi"]) + + def testMappingJMSHeadersOverwrittenByCallingQueuePut(self): + + now = long(time() * 1000) + + jms_expiration = 2619 + expected_jms_expiration = now + jms_expiration + + jms_message_id = get_rand_string(24) + expected_jms_message_id = "ID:" + hexlify(jms_message_id) + + priority = random.choice(range(1,9)) + + expected_jmsxuserid = get_rand_string(6) + expected_jmsxappid = get_rand_string(6) + expected_jms_ibm_putdate = "20090813" + expected_jms_ibm_puttime = "21324547" + expected_jms_priority = priority + expected_jms_timestamp = 1250199165470 + + def update_md(md): + + md.MsgId = jms_message_id + md.UserIdentifier = expected_jmsxuserid + md.PutApplName = expected_jmsxappid + md.PutDate = expected_jms_ibm_putdate + md.PutTime = expected_jms_ibm_puttime + md.Priority = priority + + return True + + message = TextMessage() + message.text = "Hi there." + message.jms_expiration = jms_expiration + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + queue.expects(once()).put(functor(condition_ignored), functor(update_md)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + jms_template.send(message, DESTINATION) + + # jms_template.send should've set those attributes, values taken from md. + self.assertEquals(message.jms_message_id, expected_jms_message_id) + self.assertTrue((expected_jms_expiration - message.jms_expiration) <= jms_expiration, + "expected_jms_expiration: '%s', message.jms_expiration: '%s', jms_expiration: '%s'" % ( + expected_jms_expiration, message.jms_expiration, jms_expiration)) + self.assertEquals(message.JMSXUserID, expected_jmsxuserid) + self.assertEquals(message.JMSXAppID, expected_jmsxappid) + self.assertEquals(message.JMS_IBM_PutDate, expected_jms_ibm_putdate) + self.assertEquals(message.JMS_IBM_PutTime, expected_jms_ibm_puttime) + self.assertEquals(message.jms_priority, expected_jms_priority) + self.assertEquals(message.jms_timestamp, expected_jms_timestamp) + self.assertEquals(message.jms_destination, DESTINATION) + + del(sys.modules["pymqi"]) + + def testRaisingJMSExceptionOnInvalidDeliveryMode(self): + + message, jms_template = get_simple_message_and_jms_template(self.mock) + + # jms_delivery_mode should be equal to DELIVERY_MODE_NON_PERSISTENT or DELIVERY_MODE_PERSISTENT + message.jms_delivery_mode = get_rand_string(10) + self.assertRaises(JMSException, jms_template.send, message, DESTINATION) + + # No JMSException here + for mode in(DELIVERY_MODE_NON_PERSISTENT, DELIVERY_MODE_PERSISTENT): + message.jms_delivery_mode = mode + jms_template.send(message, DESTINATION) + + del(sys.modules["pymqi"]) + + def testCachingOpenQueues(self): + + message = TextMessage() + message.text = "Hi there." + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + gmo = self.mock() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(at_least_once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(at_least_once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(at_least_once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(at_least_once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(at_least_once()).gmo().will(return_value(gmo)) + + sys.modules["pymqi"].expects(at_least_once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + sys.modules["pymqi"].expects(at_least_once()).Queue(same(mgr), + eq(DESTINATION)).will(return_value(queue)) + + mgr.expects(at_least_once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + queue.expects(at_least_once()).put(functor(condition_ignored), functor(condition_ignored)) + queue.expects(at_least_once()).close() + queue.set_default_stub(return_value(raw_message_for_get)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, + cache_open_send_queues=False) + jms_template = JmsTemplate(factory) + + for x in range(10): + jms_template.send(message, DESTINATION) + + self.assertTrue(DESTINATION not in factory._open_send_queues_cache) + self.assertTrue(len(factory._open_send_queues_cache) == 0) + + factory.cache_open_send_queues = True + + for x in range(10): + jms_template.send(message, DESTINATION) + + self.assertTrue(DESTINATION in factory._open_send_queues_cache) + self.assertTrue(len(factory._open_send_queues_cache), 1) + + factory2 = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template2 = JmsTemplate(factory2) + + for x in range(10): + jms_template2.send(message, DESTINATION) + self.assertTrue(DESTINATION in factory2._open_send_queues_cache) + self.assertEquals(1, len(factory2._open_send_queues_cache)) + + + # Now make sure open queues are not stored in caches. + factory3 = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + factory3.cache_open_send_queues = False + factory3.cache_open_receive_queues = False + + jms_template3 = JmsTemplate(factory3) + + for x in range(10): + jms_template3.send(message, DESTINATION) + jms_template3.receive(DESTINATION) + + self.assertTrue(DESTINATION not in factory3._open_send_queues_cache) + self.assertEquals(0, len(factory3._open_send_queues_cache)) + + self.assertTrue(DESTINATION not in factory3._open_receive_queues_cache) + self.assertEquals(0, len(factory3._open_receive_queues_cache)) + + del(sys.modules["pymqi"]) + + def testSettingUserAttributes(self): + + source = "" + preferred_provider = "" + + broker_id = get_rand_string(26) + expected_source = escape(source) + expected_preferred_provider = escape(preferred_provider) + + # 'bile' in Polish features no letters in ASCII range. + foobar = unicode('\xc5\xbc\xc3\xb3\xc5\x82\xc4\x87', "utf-8") + get_rand_string(26) + + def _check_user_attributes(mqrfh2): + + usr_start = mqrfh2.find("") + usr_end = mqrfh2.find("") + 6 + usr_str = str(mqrfh2[usr_start:usr_end]) + + usr = etree.fromstring(usr_str) + + self.assertEqual(str(usr.find("broker_id").text), broker_id) + self.assertEqual(str(usr.find("SOURCE").text), expected_source) + self.assertEqual(str(usr.find("PREFERRED_PROVIDER").text), expected_preferred_provider) + self.assertEqual(unicode(usr.find("foobar").text), foobar) + + return True + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + + queue.stubs().put(functor(_check_user_attributes), eq(md)) + + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + text_message = TextMessage() + text_message.text = PAYLOAD + text_message.broker_id = broker_id + text_message.SOURCE = source + text_message.PREFERRED_PROVIDER = preferred_provider + text_message.foobar = foobar + + jms_template.send(text_message, DESTINATION) + + def testUnicodePayload(self): + + message, jms_template = get_simple_message_and_jms_template(self.mock) + + # 'Suzuki' in Japanese + message.payload = unicode("\xe9\x88\xb4\xe6\x9c\xa8", "utf-8") + + # No exception should be raised. + jms_template.send(message, DESTINATION) + + del(sys.modules["pymqi"]) + + def testJmsTemplateSettingAndGettingJMSAttributes(self): + (text, jms_correlation_id, jms_delivery_mode, jms_destination, + jms_expiration, jms_priority, jms_redelivered, + jms_reply_to) = self._get_random_data() + + message = TextMessage() + message.text = text + message.jms_correlation_id = jms_correlation_id + message.jms_delivery_mode = jms_delivery_mode + message.jms_destination = jms_destination + message.jms_expiration = jms_expiration + message.jms_priority = jms_priority + message.jms_redelivered = jms_redelivered + message.jms_reply_to = jms_reply_to + + self.assertEqual(message.text, text) + self.assertEqual(message.jms_correlation_id, jms_correlation_id) + self.assertEqual(message.jms_delivery_mode, jms_delivery_mode) + self.assertEqual(message.jms_destination, jms_destination) + self.assertEqual(message.jms_expiration, jms_expiration) + self.assertEqual(message.jms_priority, jms_priority) + self.assertEqual(message.jms_redelivered, jms_redelivered) + self.assertEqual(message.jms_reply_to, jms_reply_to) + + def testSendingStringMessages(self): + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + string = "foo" + uni = u"bar" + + for payload in(string, uni): + + def _check_payload(message): + self.assertEquals(message[-3:], payload) + return True + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].expects(once()).sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + queue.expects(at_least_once()).put(functor(_check_payload), functor(condition_ignored)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + jms_template.send(payload, DESTINATION) + + del(sys.modules["pymqi"]) + + +################################################################################ + + + def testReceivingMessages(self): + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + gmo = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + md.PutDate = "20090813" + md.PutTime = "21324547" + + class Invoice(object): + def __init__(self, number): + self.number = number + + class InvoiceConverter(object): + def from_message(self, message): + return Invoice(message.text) + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].MQMIError = mq.MQMIError + sys.modules["pymqi"].stubs().QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].stubs().cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().sco().will(return_value(sco)) + sys.modules["pymqi"].stubs().md().will(return_value(md)) + sys.modules["pymqi"].stubs().gmo().will(return_value(gmo)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + mgr.stubs().connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + queue.set_default_stub(return_value(raw_message_for_get)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + + jms_template1 = JmsTemplate(factory) + + # No message converter set yet. + self.assertRaises(JMSException, jms_template1.receive_and_convert, DESTINATION, 300) + + jms_template1.message_converter = InvoiceConverter() + + # No JMSException at this point. + jms_template1.receive_and_convert(DESTINATION, 300) + + jms_template2 = JmsTemplate(factory) + self.assertEquals("b0f32f11-b531-4bbf-b985-77e795d77024", jms_template2.receive(DESTINATION).text) + + jms_template3 = JmsTemplate(factory) + jms_template3.default_destination = DESTINATION + self.assertEquals("b0f32f11-b531-4bbf-b985-77e795d77024", jms_template3.receive().text) + + # No destination set. + jms_template4 = JmsTemplate(factory) + self.assertRaises(JMSException, jms_template4.receive) + + del(sys.modules["pymqi"]) + +################################################################################ + + def testGetConnectionInfo(self): + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + + conn_info = factory.get_connection_info() + self.assertEquals(conn_info, "queue manager=[%s], channel=[%s], conn_name=[%s(%s)]" % ( + QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT)) + + + def testGetJMSTimestampFromMD(self): + date = "20091021" + time = "16321012" + + factory = WebSphereMQConnectionFactory() + + self.assertEquals(factory._get_jms_timestamp_from_md(date, time), 1256142730120) + + def testDynamicQueues(self): + + expected_dyn_queue_name = get_rand_string(12) + expected_dyn_queue = self.mock() + expected_dyn_queue._Queue__qDesc = self.mock() + expected_dyn_queue._Queue__qDesc.ObjectName = expected_dyn_queue_name + + mgr = self.mock() + cd = self.mock() + sco = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + payload = get_rand_string(3) + uni = u"bar" + + def _check_payload(message): + self.assertEquals(message[-3:], payload) + return True + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().sco().will(return_value(sco)) + sys.modules["pymqi"].expects(once()).md().will(return_value(md)) + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + sys.modules["pymqi"].expects(once()).Queue(same(mgr), eq("SYSTEM.DEFAULT.MODEL.QUEUE"), + eq(CMQC.MQOO_INPUT_SHARED)).will(return_value(expected_dyn_queue)) + + sys.modules["pymqi"].expects(once()).Queue(same(mgr), eq(expected_dyn_queue_name), + eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(expected_dyn_queue)) + expected_dyn_queue.set_default_stub(return_value(None)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + jms_template = JmsTemplate(factory) + + dyn_queue = jms_template.open_dynamic_queue() + + self.assertEquals(len(factory._open_dynamic_queues_cache), 1) + self.assertTrue(expected_dyn_queue_name in factory._open_dynamic_queues_cache) + self.assertEquals(factory._open_dynamic_queues_cache[expected_dyn_queue_name], expected_dyn_queue) + + jms_template.send(payload, dyn_queue) + jms_template.close_dynamic_queue(dyn_queue) + + self.assertEquals(len(factory._open_dynamic_queues_cache), 0) + self.assertTrue(expected_dyn_queue_name not in factory._open_dynamic_queues_cache) + self.assertRaises(KeyError, factory._open_dynamic_queues_cache.__getitem__, expected_dyn_queue_name) + + del(sys.modules["pymqi"]) + + def testWebSphereMQJMSException(self): + + expected_completion_code = CMQC.MQCC_FAILED + expected_reason_code = CMQC.MQRC_Q_MGR_STOPPING + + message = get_rand_string(20) + mq_exception = mq.MQMIError(expected_completion_code, expected_reason_code) + + try: + raise WebSphereMQJMSException() + except WebSphereMQJMSException, e: + self.assertEquals(e.completion_code, None) + self.assertEquals(e.reason_code, None) + self.assertEquals(e.message, None) + + try: + raise WebSphereMQJMSException(message) + except WebSphereMQJMSException, e: + self.assertEquals(e.completion_code, None) + self.assertEquals(e.reason_code, None) + self.assertEquals(e.message, message) + + #sys.modules["pymqi"].MQMIError = mq.MQMIError + + try: + raise WebSphereMQJMSException(completion_code=mq_exception.comp, reason_code=mq_exception.reason) + except WebSphereMQJMSException, e: + self.assertEquals(e.completion_code, expected_completion_code) + self.assertEquals(e.reason_code, expected_reason_code) + + def testMessageConverterRaisingNotImplementedError(self): + + converter = MessageConverter() + self.assertRaises(NotImplementedError, converter.to_message, None) + self.assertRaises(NotImplementedError, converter.from_message, None) + + def testTextMessageStringRepresentation(self): + + expected_message_sha1_sum_max_100_chars = "aa340eed9dacde39fd355c27b54b2c0f33454f97" + expected_message_sha1_sum_max_4_chars = "662a66f448dce916d0b008edfe995cb879d039bf" + expected_message_sha1_sum_no_text = "6c59aa896a27fb6e7494089bc6a6d6193129b796" + + message1 = TextMessage() + message1.text = "ZFJQ#(RAWFD" * 1000 + message1.jms_correlation_id = "APWRI!@#ffffq3rU" + message1.jms_delivery_mode = DEFAULT_DELIVERY_MODE + message1.jms_destination = "ZVCW#TRW" + message1.jms_expiration = 1252094803975 + message1.jms_priority = 6 + message1.jms_redelivered = True + message1.jms_reply_to = "SETAFJOEF" + message1.jms_message_id = "SFJW)$%@)*%@#%@" + message1.jms_correlation_id = "ARO@#$R@#%$@#RVSTYUO" + message1.jms_timestamp = 1250199165470 + message1.CsKAo9 = "ZDCVKWER@_#%LA" + + self.assertEquals(message1.max_chars_printed, 100) + self.assertEquals(sha1(str(message1)).hexdigest(), expected_message_sha1_sum_max_100_chars) + + message2 = TextMessage(max_chars_printed=4) + message2.text = "SADFK@$#RTIWA" * 1000 + message2.jms_correlation_id = "SZDFKW$#A:jms_text') + + msgbody = mqrfh2jms.folders["mcd"].find("msgbody") + + # msgbody.get will return None if such a namespace will not have been defined. + self.assertEquals("true", msgbody.get("{dummy}nil")) + + def testSimpleMessageListenerContainer(self): + + class TestMessageHandler(object): + def handle(self, message): + return 123 + + handler = TestMessageHandler() + concurrent_listeners = 4 + handlers_per_listener = 2 + wait_interval = 1300 + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + gmo = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].MQMIError = mq.MQMIError + sys.modules["pymqi"].stubs().QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].stubs().cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().md().will(return_value(md)) + sys.modules["pymqi"].stubs().gmo().will(return_value(gmo)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + mgr.stubs().connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts)) + queue.set_default_stub(return_value(raw_message_for_get)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + + smlc = SimpleMessageListenerContainer(factory, DESTINATION, handler, + concurrent_listeners, handlers_per_listener, wait_interval) + smlc.after_properties_set() + + self.assertEquals(smlc.factory, factory) + self.assertEquals(smlc.destination, DESTINATION) + self.assertEquals(smlc.handler, handler) + self.assertEquals(smlc.concurrent_listeners, concurrent_listeners) + self.assertEquals(smlc.handlers_per_listener, handlers_per_listener) + self.assertEquals(smlc.wait_interval, wait_interval) + + del(sys.modules["pymqi"]) + + def testWebSphereMQListener(self): + + message = get_rand_string(12) + exception_reason = get_rand_string(12) + "ęóąśłżźćń" + + class TestMessageHandler(object): + def __init__(self): + self.data = [] + + def __str__(self): + return "%s %s" % (hex(id(self)), str(self.data)) + + def handle(self, message): + self.data.append(message) + + class _ConnectionFactory(WebSphereMQConnectionFactory): + def __init__(self, *args): + super(_ConnectionFactory, self).__init__(*args) + self.call_count = 0 + + def receive(self, destination, wait_interval): + self.call_count += 1 + + if self.call_count == 1: + import sys + return message + elif self.call_count == 2: + raise NoMessageAvailableException() + else: + raise WebSphereMQJMSException(exception_reason, CMQC.MQCC_FAILED, CMQC.MQRC_OPTION_NOT_VALID_FOR_TYPE) + + + handler = TestMessageHandler() + handlers_per_listener = 1 + wait_interval = 1300 + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + gmo = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].MQMIError = mq.MQMIError + sys.modules["pymqi"].stubs().QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].stubs().cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().md().will(return_value(md)) + sys.modules["pymqi"].stubs().gmo().will(return_value(gmo)) + mgr.stubs().connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts)) + + factory = _ConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + + listener = WebSphereMQListener() + listener.factory = factory + listener.destination = DESTINATION + listener.wait_interval = wait_interval + listener.handler = handler + listener.handlers_pool = ThreadPool(handlers_per_listener) + + try: + listener.run() + except WebSphereMQJMSException, e: + sleep(0.1) # Allows the handler thread to process the message + self.assertEquals(e.message, exception_reason) + self.assertEquals(3, factory.call_count) + self.assertEquals(1, len(handler.data)) + self.assertEquals(message, handler.data[0]) + finally: + del(sys.modules["pymqi"]) + + def testSimpleMessageListenerContainerMessageHandler(self): + handler = MessageHandler() + self.assertRaises(NotImplementedError, handler.handle, "foo") + + try: + handler.handle("foo") + except NotImplementedError, e: + self.assertEquals(e.message, "Should be overridden by subclasses.") + + def testSSLCorrectSettings(self): + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + mgr = self.mock() + cd = self.mock() + sco = self.mock() + + cd.SSLCipherSpec = "TLS_RSA_WITH_AES_256_CBC_SHA" + sco.KeyRepository = "/tmp/foobar" + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(once()).cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().sco().will(return_value(sco)) + mgr.expects(once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, ssl=True, + ssl_cipher_spec="TLS_RSA_WITH_AES_256_CBC_SHA", + ssl_key_repository="/tmp/foobar") + factory._connect() + + del(sys.modules["pymqi"]) + + def testSSLIncorrectSettings(self): + + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + mgr = self.mock() + cd = self.mock() + sco = self.mock() + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].expects(at_least_once()).QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].expects(at_least_once()).cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().sco().will(return_value(sco)) + mgr.expects(at_least_once()).connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + # ssl=True and no ssl_cipher_spec nor ssl_key_repository. + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, ssl=True) + self.assertRaises(JMSException, factory._connect) + try: + factory._connect() + except JMSException, e: + self.assertEquals(e.args[0], "SSL support requires setting both ssl_cipher_spec and ssl_key_repository") + + # ssl=True and ssl_cipher_spec only. + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, ssl=True, + ssl_cipher_spec="TLS_RSA_WITH_AES_256_CBC_SHA") + self.assertRaises(JMSException, factory._connect) + try: + factory._connect() + except JMSException, e: + self.assertEquals(e.args[0], "SSL support requires setting both ssl_cipher_spec and ssl_key_repository") + + # ssl=True and ssl_key_repository only. + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, ssl=True, + ssl_key_repository="/tmp/foobar") + self.assertRaises(JMSException, factory._connect) + try: + factory._connect() + except JMSException, e: + self.assertEquals(e.args[0], "SSL support requires setting both ssl_cipher_spec and ssl_key_repository") + + # ssl_cipher_spec only, ssl=False. + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, + ssl_cipher_spec="TLS_RSA_WITH_AES_256_CBC_SHA") + factory._connect() + + # ssl_key_repository only, ssl=False. + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, + ssl_key_repository="/tmp/foobar") + factory._connect() + + # ssl_cipher_spec and ssl_key_repository, ssl=False. + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT, + ssl_cipher_spec="TLS_RSA_WITH_AES_256_CBC_SHA", + ssl_key_repository="/tmp/foobar") + factory._connect() + + del(sys.modules["pymqi"]) + + def testSimpleMessageListenerContainer(self): + + class TestMessageHandler(object): + def handle(self, message): + return 123 + + handler = TestMessageHandler() + concurrent_listeners = 4 + handlers_per_listener = 2 + wait_interval = 1300 + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + gmo = self.mock() + md = get_default_md() + sco = self.mock() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].MQMIError = mq.MQMIError + sys.modules["pymqi"].stubs().QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].stubs().cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().sco().will(return_value(sco)) + sys.modules["pymqi"].stubs().md().will(return_value(md)) + sys.modules["pymqi"].stubs().gmo().will(return_value(gmo)) + sys.modules["pymqi"].expects(once()).Queue(same(mgr), + eq(DESTINATION), eq(CMQC.MQOO_INPUT_SHARED | CMQC.MQOO_OUTPUT)).will(return_value(queue)) + mgr.stubs().connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + queue.set_default_stub(return_value(raw_message_for_get)) + + factory = WebSphereMQConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + + smlc = SimpleMessageListenerContainer(factory, DESTINATION, handler, + concurrent_listeners, handlers_per_listener, wait_interval) + smlc.after_properties_set() + + self.assertEquals(smlc.factory, factory) + self.assertEquals(smlc.destination, DESTINATION) + self.assertEquals(smlc.handler, handler) + self.assertEquals(smlc.concurrent_listeners, concurrent_listeners) + self.assertEquals(smlc.handlers_per_listener, handlers_per_listener) + self.assertEquals(smlc.wait_interval, wait_interval) + + del(sys.modules["pymqi"]) + + def testWebSphereMQListener(self): + + message = get_rand_string(12) + exception_reason = get_rand_string(12) + "ęóąśłżźćń" + + class TestMessageHandler(object): + def __init__(self): + self.data = [] + + def __str__(self): + return "%s %s" % (hex(id(self)), str(self.data)) + + def handle(self, message): + self.data.append(message) + + class _ConnectionFactory(WebSphereMQConnectionFactory): + def __init__(self, *args): + super(_ConnectionFactory, self).__init__(*args) + self.call_count = 0 + + def receive(self, destination, wait_interval): + self.call_count += 1 + + if self.call_count == 1: + import sys + return message + elif self.call_count == 2: + raise NoMessageAvailableException() + else: + raise WebSphereMQJMSException(exception_reason, CMQC.MQCC_FAILED, CMQC.MQRC_OPTION_NOT_VALID_FOR_TYPE) + + + handler = TestMessageHandler() + handlers_per_listener = 1 + wait_interval = 1300 + + queue = self.mock() + mgr = self.mock() + cd = self.mock() + sco = self.mock() + gmo = self.mock() + md = get_default_md() + opts = CMQC.MQCNO_HANDLE_SHARE_BLOCK + + sys.modules["pymqi"] = self.mock() + sys.modules["pymqi"].MQMIError = mq.MQMIError + sys.modules["pymqi"].stubs().QueueManager(eq(None)).will(return_value(mgr)) + sys.modules["pymqi"].stubs().cd().will(return_value(cd)) + sys.modules["pymqi"].stubs().sco().will(return_value(sco)) + sys.modules["pymqi"].stubs().md().will(return_value(md)) + sys.modules["pymqi"].stubs().gmo().will(return_value(gmo)) + mgr.stubs().connectWithOptions(eq(QUEUE_MANAGER), cd=eq(cd), opts=eq(opts), sco=eq(sco)) + + factory = _ConnectionFactory(QUEUE_MANAGER, CHANNEL, HOST, LISTENER_PORT) + + listener = WebSphereMQListener() + listener.factory = factory + listener.destination = DESTINATION + listener.wait_interval = wait_interval + listener.handler = handler + listener.handlers_pool = ThreadPool(handlers_per_listener) + + try: + listener.run() + except WebSphereMQJMSException, e: + sleep(0.5) # Allows the handler thread to process the message + self.assertEquals(e.message, exception_reason) + self.assertEquals(3, factory.call_count) + self.assertEquals(1, len(handler.data)) + self.assertEquals(message, handler.data[0]) + finally: + del(sys.modules["pymqi"]) + + def testSimpleMessageListenerContainerMessageHandler(self): + handler = MessageHandler() + self.assertRaises(NotImplementedError, handler.handle, "foo") + + try: + handler.handle("foo") + except NotImplementedError, e: + self.assertEquals(e.message, "Should be overridden by subclasses.") + + def testNeedsMCD(self): + message = TextMessage(get_rand_string(12)) + destination = get_rand_string(12) + now = long(time() * 1000) + + has_mcd = MQRFH2JMS(True).build_header(message, destination, CMQC, now) + has_no_mcd = MQRFH2JMS(False).build_header(message, destination, CMQC, now) + + self.assertTrue('mcd' in has_mcd) + self.assertTrue('mcd' not in has_no_mcd) \ No newline at end of file diff --git a/test/springpythontest/remoting_xmlrpc.py b/test/springpythontest/remoting_xmlrpc.py index ba97bde..a0afcfe 100644 --- a/test/springpythontest/remoting_xmlrpc.py +++ b/test/springpythontest/remoting_xmlrpc.py @@ -174,7 +174,7 @@ def test_import_all(self): _locals = {} _globals = {} - exec "from springpython.remoting.xmlrpc import *" in _locals, _globals + exec("from springpython.remoting.xmlrpc import *", _locals, _globals) self.assertEqual(len(_globals), 3) self.assertEqual(sorted(_globals), ["SSLClient", "SSLServer", "VerificationException"]) diff --git a/test/springpythontest/remoting_xmlrpc.py.bak b/test/springpythontest/remoting_xmlrpc.py.bak new file mode 100644 index 0000000..ba97bde --- /dev/null +++ b/test/springpythontest/remoting_xmlrpc.py.bak @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- + +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# stdlib +import socket +import ssl +import threading +import time +import unittest + +from SocketServer import StreamRequestHandler +from xmlrpclib import Transport + +# Spring Python +from springpython.remoting.xmlrpc import SSLServer, SSLClient, RequestHandler, \ + SSLClientTransport, VerificationException + +RESULT_OK = "All good" + +server_key = "./support/pki/server-key.pem" +server_cert = "./support/pki/server-cert.pem" +client_key = "./support/pki/client-key.pem" +client_cert = "./support/pki/client-cert.pem" +ca_certs = "./support/pki/ca-chain.pem" + +class MySSLServer(SSLServer): + + def test_server(self): + return RESULT_OK + + def register_functions(self): + self.register_function(self.shutdown) + self.register_function(self.test_server) + +class _DummyServer(SSLServer): + pass + +class _DummyRequest(): + def recv(self, *ignored_args, **ignored_kwargs): + pass + +class _MyClientTransport(object): + def __init__(self, ca_certs=None, keyfile=None, certfile=None, cert_reqs=None, + ssl_version=None, timeout=None, strict=None): + self.ca_certs = ca_certs + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.timeout = timeout + self.strict = strict + +class TestInitDefaultArguments(unittest.TestCase): + def test_init_default_arguments(self): + """ Tests various defaults various and those passed to __init__'s. + """ + + self.assertTrue(issubclass(VerificationException, Exception)) + self.assertEqual(RequestHandler.rpc_paths, ("/", "/RPC2")) + self.assertEqual(SSLClientTransport.user_agent, + "SSL XML-RPC Client (by http://springpython.webfactional.com)") + + server1 = MySSLServer("127.0.0.1", 8001) + + self.assertEqual(server1.keyfile, None) + self.assertEqual(server1.certfile, None) + self.assertEqual(server1.ca_certs, None) + self.assertEqual(server1.cert_reqs, ssl.CERT_NONE) + self.assertEqual(server1.ssl_version, ssl.PROTOCOL_TLSv1) + self.assertEqual(server1.do_handshake_on_connect, True) + self.assertEqual(server1.suppress_ragged_eofs, True) + self.assertEqual(server1.ciphers, None) + self.assertEqual(server1.logRequests, True) + self.assertEqual(server1.verify_fields, None) + + server_host = "127.0.0.1" + server_port = 8002 + server_keyfile = "server_keyfile" + server_certfile = "server_certfile" + server_ca_certs = "server_ca_certs" + server_cert_reqs = ssl.CERT_OPTIONAL + server_ssl_version = ssl.PROTOCOL_SSLv3 + server_do_handshake_on_connect = False + server_suppress_ragged_eofs = False + server_ciphers = "ALL" + server_log_requests = False + server_verify_fields = {"commonName": "Foo", "organizationName":"Baz"} + + server2 = MySSLServer(server_host, server_port, server_keyfile, + server_certfile, server_ca_certs, server_cert_reqs, + server_ssl_version, server_do_handshake_on_connect, + server_suppress_ragged_eofs, server_ciphers, server_log_requests, + verify_fields=server_verify_fields) + + # inherited from SocketServer.BaseServer + self.assertEqual(server2.server_address, (server_host, server_port)) + + self.assertEqual(server2.keyfile, server_keyfile) + self.assertEqual(server2.certfile, server_certfile) + self.assertEqual(server2.ca_certs, server_ca_certs) + self.assertEqual(server2.cert_reqs, server_cert_reqs) + self.assertEqual(server2.ssl_version, server_ssl_version) + self.assertEqual(server2.do_handshake_on_connect, server_do_handshake_on_connect) + self.assertEqual(server2.suppress_ragged_eofs, server_suppress_ragged_eofs) + self.assertEqual(server2.ciphers, server_ciphers) + self.assertEqual(server2.logRequests, server_log_requests) + self.assertEqual(sorted(server2.verify_fields), sorted(server_verify_fields)) + + client_uri="https://127.0.0.1:8000/RPC2" + client_ca_certs="client_ca_certs" + client_keyfile="client_keyfile" + client_certfile="client_certfile" + client_cert_reqs=ssl.CERT_OPTIONAL + client_ssl_version=ssl.PROTOCOL_SSLv23 + client_transport=_MyClientTransport + client_encoding="utf-16" + client_verbose=1 + client_allow_none=False + client_use_datetime=False + client_timeout=13 + client_strict=True + + client2 = SSLClient(client_uri, client_ca_certs, client_keyfile, + client_certfile, client_cert_reqs, client_ssl_version, + client_transport, client_encoding, client_verbose, + client_allow_none, client_use_datetime, client_timeout, + client_strict) + + self.assertEqual(client2._ServerProxy__host, "127.0.0.1:8000") + self.assertEqual(client2._ServerProxy__transport.ca_certs, client_ca_certs) + self.assertEqual(client2._ServerProxy__transport.keyfile, client_keyfile) + self.assertEqual(client2._ServerProxy__transport.certfile, client_certfile) + self.assertEqual(client2._ServerProxy__transport.cert_reqs, client_cert_reqs) + self.assertEqual(client2._ServerProxy__transport.ssl_version, client_ssl_version) + self.assertTrue(isinstance(client2._ServerProxy__transport, _MyClientTransport)) + self.assertEqual(client2._ServerProxy__encoding, client_encoding) + self.assertEqual(client2._ServerProxy__verbose, client_verbose) + self.assertEqual(client2._ServerProxy__allow_none, client_allow_none) + self.assertEqual(client2._ServerProxy__transport.timeout, client_timeout) + self.assertEqual(client2._ServerProxy__transport.strict, client_strict) + + self.assertRaises(NotImplementedError, _DummyServer, "127.0.0.1", 8003) + + def test_request_handler(self): + request = _DummyRequest() + rh = RequestHandler(request, None, None) + rh.setup() + self.assertTrue(rh.connection is request) + self.assertTrue(isinstance(rh.rfile, socket._fileobject)) + self.assertTrue(isinstance(rh.wfile, socket._fileobject)) + self.assertTrue(rh.rfile._sock is request) + self.assertEqual(rh.rfile.mode, "rb") + self.assertEqual(rh.rfile.bufsize, socket._fileobject.default_bufsize) + self.assertTrue(rh.wfile._sock is request) + self.assertEqual(rh.wfile.mode, "wb") + self.assertEqual(rh.wfile.bufsize, StreamRequestHandler.wbufsize) + + def test_import_all(self): + _locals = {} + _globals = {} + + exec "from springpython.remoting.xmlrpc import *" in _locals, _globals + + self.assertEqual(len(_globals), 3) + self.assertEqual(sorted(_globals), ["SSLClient", "SSLServer", "VerificationException"]) + +class TestSSL(unittest.TestCase): + + class _ClientServerContextManager(object): + def __init__(self, server_port, cert_reqs=ssl.CERT_NONE, verify_fields={}): + self.server_port = server_port + self.cert_reqs = cert_reqs + self.verify_fields = verify_fields + + def __enter__(self): + server = MySSLServer("127.0.0.1", self.server_port, server_key, + server_cert, ca_certs, cert_reqs=self.cert_reqs, + verify_fields=self.verify_fields) + self.server_thread = self._start_server(server) + time.sleep(0.5) + + def __exit__(self, *ignored_args): + self.server_thread.server.shutdown() + + def _start_server(self, server): + + class _ServerController(threading.Thread): + def __init__(self, server): + self.server = server + self.isDaemon = False + super(_ServerController, self).__init__() + + def run(self): + self.server.serve_forever() + + server_thread = _ServerController(server) + server_thread.start() + + return server_thread + + + def test_simple_ssl(self): + """ Server uses its cert, client uses none. + """ + server_port = 9001 + with TestSSL._ClientServerContextManager(server_port): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs) + self.assertEqual(client.test_server(), RESULT_OK) + + def test_client_cert(self): + """ Server & client use certs. + """ + server_port = 9002 + with TestSSL._ClientServerContextManager(server_port, ssl.CERT_REQUIRED): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs, + client_key, client_cert) + self.assertEqual(client.test_server(), RESULT_OK) + + def test_client_cert_verify_ok(self): + """ Server & client use certs. Server succesfully validates client certificate's fields. + """ + server_port = 9003 + verify_fields = {"commonName":"My Client", "countryName":"US", + "organizationalUnitName":"My Unit", "organizationName":"My Company", + "stateOrProvinceName":"My State"} + + with TestSSL._ClientServerContextManager(server_port, ssl.CERT_REQUIRED, verify_fields): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs, + client_key, client_cert) + self.assertEqual(client.test_server(), RESULT_OK) + + def test_client_cert_verify_failure_missing_field(self): + """ Server & client use certs. Server fails to validate client certificate's fields + (a field is missing). + """ + server_port = 9004 + verify_fields = {"commonName":"My Client", "countryName":"US", + "organizationalUnitName":"My Unit", "organizationName":"My Company", + "stateOrProvinceName":"My State", "FOO": "BAR"} + + with TestSSL._ClientServerContextManager(server_port, ssl.CERT_REQUIRED, verify_fields): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs, + client_key, client_cert) + self.assertRaises(Exception, client.test_server) + + def test_client_cert_failure_field_incorrect_value(self): + """ Server & client use certs. Server fails to validate client certificate's fields + (all fields are in place, but commonName has an incorrect value). + """ + server_port = 9005 + verify_fields = {"commonName":"Invalid"} + with TestSSL._ClientServerContextManager(server_port, ssl.CERT_REQUIRED, verify_fields): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs, + client_key, client_cert) + self.assertRaises(Exception, client.test_server) + + def test_client_cert_verify_failure_cert_optional_no_client_cert(self): + """ Server optionally requires a client to send the certificate + and validates its fields but client sends none. + """ + server_port = 9006 + verify_fields = {"commonName":"My Client"} + with TestSSL._ClientServerContextManager(server_port, ssl.CERT_OPTIONAL, verify_fields): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs) + self.assertRaises(Exception, client.test_server) + + def test_cert_required_no_client_cert(self): + """ Server requires a client to send the certificate but client sends none. + """ + server_port = 9007 + with TestSSL._ClientServerContextManager(server_port, ssl.CERT_REQUIRED): + client = SSLClient("https://localhost:%d/RPC2" % server_port, ca_certs) + self.assertRaises(ssl.SSLError, client.test_server) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/test/springpythontest/securityWebTestCases.py b/test/springpythontest/securityWebTestCases.py index 95edc3f..d7ef9c3 100644 --- a/test/springpythontest/securityWebTestCases.py +++ b/test/springpythontest/securityWebTestCases.py @@ -81,10 +81,10 @@ def testIteratingThroughASimpleFilterChain(self): filterChain.addFilter(filterSecurityInterceptor) chain = filterChain.getFilterChain() - self.assertEquals(httpSessionContextIntegrationFilter, chain.next()) - self.assertEquals(exceptionTranslationFilter, chain.next()) - self.assertEquals(authenticationProcessFilter, chain.next()) - self.assertEquals(filterSecurityInterceptor, chain.next()) + self.assertEquals(httpSessionContextIntegrationFilter, next(chain)) + self.assertEquals(exceptionTranslationFilter, next(chain)) + self.assertEquals(authenticationProcessFilter, next(chain)) + self.assertEquals(filterSecurityInterceptor, next(chain)) self.assertRaises(StopIteration, chain.next) def testHttpSessionContextIntegrationFilter(self): diff --git a/test/springpythontest/securityWebTestCases.py.bak b/test/springpythontest/securityWebTestCases.py.bak new file mode 100644 index 0000000..95edc3f --- /dev/null +++ b/test/springpythontest/securityWebTestCases.py.bak @@ -0,0 +1,233 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import pickle +import unittest +from pmock import * +from springpython.security import BadCredentialsException +from springpython.security.context import SecurityContext +from springpython.security.context import SecurityContextHolder +from springpython.security.providers import AuthenticationManager +from springpython.security.providers import UsernamePasswordAuthenticationToken +from springpython.security.providers.dao import DaoAuthenticationProvider +from springpython.security.userdetails import InMemoryUserDetailsService +from springpython.security.web import Filter +from springpython.security.web import FilterChain +from springpython.security.web import HttpSessionContextIntegrationFilter +from springpython.security.web import ExceptionTranslationFilter +from springpython.security.web import AuthenticationProcessingFilter +from springpython.security.web import FilterSecurityInterceptor +from springpython.security.web import FilterChainProxy +from springpython.security.web import SessionStrategy + +class StubSessionStrategy(SessionStrategy): + """ + This is a stand-in for any web-based HTTP Session solution. It is a simple in-memory dictionary + used to serve the role of holding session data during any tests. + """ + def __init__(self): + SessionStrategy.__init__(self) + self.sessionData = {} + + def getHttpSession(self, environ): + return self.sessionData + + def setHttpSession(self, key, value): + self.sessionData[key] = value + +class StubAuthenticationFilter(Filter): + """ + This is a pass-through filter, used to help test HttpSessionContextIntegrationFilter. That filter + expects there to be another filter in place that will authenticate credentials, and in turn modify them. + This filter checks if the incoming (default) credentials are authenticated, and if not, sets them + as such. Then it passes on to the next filter. + """ + def __call__(self, environ, start_response): + if not SecurityContextHolder.getContext().authentication.isAuthenticated(): + SecurityContextHolder.getContext().authentication.setAuthenticated(True) + return self.doNextFilter(environ, start_response) + +class WebInterfaceTestCase(unittest.TestCase): + def testSessionStrategy(self): + sessionStrategy = SessionStrategy() + environ = {} + self.assertRaises(NotImplementedError, sessionStrategy.getHttpSession, environ) + +class FilterTestCase(MockTestCase): + def testIteratingThroughASimpleFilterChain(self): + filterChain = FilterChain() + self.assertEquals(0, len(filterChain.chain)) + + httpSessionContextIntegrationFilter = HttpSessionContextIntegrationFilter() + exceptionTranslationFilter = ExceptionTranslationFilter() + authenticationProcessFilter = AuthenticationProcessingFilter() + filterSecurityInterceptor = FilterSecurityInterceptor() + + filterChain.addFilter(httpSessionContextIntegrationFilter) + filterChain.addFilter(exceptionTranslationFilter) + filterChain.addFilter(authenticationProcessFilter) + filterChain.addFilter(filterSecurityInterceptor) + + chain = filterChain.getFilterChain() + self.assertEquals(httpSessionContextIntegrationFilter, chain.next()) + self.assertEquals(exceptionTranslationFilter, chain.next()) + self.assertEquals(authenticationProcessFilter, chain.next()) + self.assertEquals(filterSecurityInterceptor, chain.next()) + self.assertRaises(StopIteration, chain.next) + + def testHttpSessionContextIntegrationFilter(self): + def start_response(): + pass + def application(environ, start_response): + return ["Success"] + + environ = {} + environ["PATH_INFO"] = "/index.html" + + sessionStrategy = StubSessionStrategy() + httpSessionContextIntegrationFilter = HttpSessionContextIntegrationFilter(sessionStrategy) + # HttpSessionContextIntegrationFilter expects another filter after it to alter the credentials. + stubAuthenticationFilter = StubAuthenticationFilter() + + filterChainProxy = FilterChainProxy() + filterChainProxy.filterInvocationDefinitionSource = [("/.*", [httpSessionContextIntegrationFilter, stubAuthenticationFilter])] + filterChainProxy.application = application + + self.assertEquals(["Success"], filterChainProxy(environ, start_response)) + self.assertEquals(["Success"], filterChainProxy(environ, start_response)) + + httpSession = sessionStrategy.getHttpSession(environ) + httpSession[httpSessionContextIntegrationFilter.SPRINGPYTHON_SECURITY_CONTEXT_KEY] = pickle.dumps("Bad credentials") + self.assertEquals(["Success"], filterChainProxy(environ, start_response)) + + def testFilterChainProxyWithMixedURLs(self): + """ + This test goes through the FilterChainProxy, and proves that it takes differing routes through filters + based on URL pattern matching. + """ + class PassthroughFilter1(Filter): + """This filter inserts a simple value to prove it was used.""" + def __call__(self, environ, start_response): + environ["PASSTHROUGH_FILTER1"] = True + return self.doNextFilter(environ, start_response) + + class PassthroughFilter2(Filter): + """This filter inserts a simple value to prove it was used.""" + def __call__(self, environ, start_response): + environ["PASSTHROUGH_FILTER2"] = True + return self.doNextFilter(environ, start_response) + + def start_response(): + pass + def application(environ, start_response): + return ["Success"] + + filterChainProxy = FilterChainProxy() + filterChainProxy.filterInvocationDefinitionSource = [("/.*html", [PassthroughFilter1()]), ("/.*jsp", [PassthroughFilter2()])] + filterChainProxy.application = application + + environ = {} + environ["PATH_INFO"] = "/index.html" + filterChainProxy(environ, start_response) + self.assertTrue("PASSTHROUGH_FILTER1" in environ) + self.assertTrue("PASSTHROUGH_FILTER2" not in environ) + + environ = {} + environ["PATH_INFO"] = "/index.jsp" + filterChainProxy(environ, start_response) + self.assertTrue("PASSTHROUGH_FILTER1" not in environ) + self.assertTrue("PASSTHROUGH_FILTER2" in environ) + + filterChainProxy2 = FilterChainProxy(filterInvocationDefinitionSource=[("/.*html", [PassthroughFilter1()]), ("/.*jsp", [PassthroughFilter2()])]) + filterChainProxy2.application = application + + environ = {} + environ["PATH_INFO"] = "/index.html" + filterChainProxy2(environ, start_response) + self.assertTrue("PASSTHROUGH_FILTER1" in environ) + self.assertTrue("PASSTHROUGH_FILTER2" not in environ) + + environ = {} + environ["PATH_INFO"] = "/index.jsp" + filterChainProxy2(environ, start_response) + self.assertTrue("PASSTHROUGH_FILTER1" not in environ) + self.assertTrue("PASSTHROUGH_FILTER2" in environ) + + def testAuthenticationProcessingFilterWithGoodPassword(self): + def start_response(): + pass + def application(environ, start_response): + return ["Success"] + + environ = {} + environ["PATH_INFO"] = "/index.html" + + inMemoryUserDetailsService = InMemoryUserDetailsService() + inMemoryUserDetailsService.user_dict = {"user1": ("good_password", ["role1", "blue"], True)} + inMemoryDaoAuthenticationProvider = DaoAuthenticationProvider() + inMemoryDaoAuthenticationProvider.user_details_service = inMemoryUserDetailsService + inMemoryDaoAuthenticationManager = AuthenticationManager([inMemoryDaoAuthenticationProvider]) + + authenticationFilter = AuthenticationProcessingFilter() + authenticationFilter.auth_manager = inMemoryDaoAuthenticationManager + authenticationFilter.alwaysReauthenticate = False + + token = UsernamePasswordAuthenticationToken("user1", "good_password", None) + self.assertFalse(token.isAuthenticated()) + + SecurityContextHolder.setContext(SecurityContext()) + SecurityContextHolder.getContext().authentication = token + + filterChainProxy = FilterChainProxy() + filterChainProxy.filterInvocationDefinitionSource = [("/.*", [authenticationFilter])] + filterChainProxy.application = application + + self.assertEquals(["Success"], filterChainProxy(environ, start_response)) + self.assertTrue(SecurityContextHolder.getContext().authentication.isAuthenticated()) + + self.assertEquals(["Success"], filterChainProxy(environ, start_response)) + self.assertTrue(SecurityContextHolder.getContext().authentication.isAuthenticated()) + + def testAuthenticationProcessingFilterWithBadPassword(self): + def start_response(): + pass + def application(environ, start_response): + return ["Success"] + + environ = {} + environ["PATH_INFO"] = "/index.html" + + inMemoryUserDetailsService = InMemoryUserDetailsService() + inMemoryUserDetailsService.user_dict = {"user1": ("good_password", ["role1", "blue"], True)} + inMemoryDaoAuthenticationProvider = DaoAuthenticationProvider() + inMemoryDaoAuthenticationProvider.user_details_service = inMemoryUserDetailsService + inMemoryDaoAuthenticationManager = AuthenticationManager([inMemoryDaoAuthenticationProvider]) + + authenticationFilter = AuthenticationProcessingFilter() + authenticationFilter.auth_manager = inMemoryDaoAuthenticationManager + authenticationFilter.alwaysReauthenticate = False + + token = UsernamePasswordAuthenticationToken("user1", "bad_password", None) + self.assertFalse(token.isAuthenticated()) + + SecurityContextHolder.setContext(SecurityContext()) + SecurityContextHolder.getContext().authentication = token + + filterChainProxy = FilterChainProxy() + filterChainProxy.filterInvocationDefinitionSource = [("/.*", [authenticationFilter])] + filterChainProxy.application = application + self.assertRaises(BadCredentialsException, filterChainProxy, environ, start_response) + self.assertFalse(SecurityContextHolder.getContext().authentication.isAuthenticated()) + diff --git a/test/springpythontest/support/testSupportClasses.py b/test/springpythontest/support/testSupportClasses.py index ee69a5e..2daf60c 100644 --- a/test/springpythontest/support/testSupportClasses.py +++ b/test/springpythontest/support/testSupportClasses.py @@ -197,7 +197,7 @@ def close(self): class StubDBFactory(ConnectionFactory): def __init__(self): - ConnectionFactory.__init__(self, [types.TupleType]) + ConnectionFactory.__init__(self, [tuple]) self.stubConnection = StubConnection() def connect(self): return self.stubConnection @@ -244,7 +244,7 @@ def withdraw(self, amount, account_num): return amount def balance(self, account_num): - return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), float) def transfer(self, amount, from_account, to_account): self.logger.debug("Transferring $%s from %s to %s." % (amount, from_account, to_account)) @@ -319,7 +319,7 @@ def withdraw(self, amount, account_num): return amount def balance(self, account_num): - return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), float) @transactional() def transfer(self, amount, from_account, to_account): @@ -371,7 +371,7 @@ def withdraw(self, amount, account_num): return amount def balance(self, account_num): - return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), float) @transactional def transfer(self, amount, from_account, to_account): @@ -428,7 +428,7 @@ def withdraw(self, amount, account_num): @transactional(["PROPAGATION_SUPPORTS","readOnly"]) def balance(self, account_num): self.logger.debug("Checking balance for %s" % account_num) - return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), float) @transactional(["PROPAGATION_REQUIRED"]) def transfer(self, amount, from_account, to_account): diff --git a/test/springpythontest/support/testSupportClasses.py.bak b/test/springpythontest/support/testSupportClasses.py.bak new file mode 100644 index 0000000..ee69a5e --- /dev/null +++ b/test/springpythontest/support/testSupportClasses.py.bak @@ -0,0 +1,537 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import logging +import types +from pmock import * +from springpython.aop import MethodInterceptor +from springpython.config import PythonConfig +from springpython.config import Object +from springpython.context import scope +from springpython.context import ObjectPostProcessor +from springpython.database.core import DaoSupport +from springpython.database.core import DatabaseTemplate +from springpython.database.core import RowMapper +from springpython.database.factory import ConnectionFactory +from springpython.database.factory import MySQLConnectionFactory +from springpython.database import transaction +from springpython.database.transaction import AutoTransactionalObject +from springpython.database.transaction import ConnectionFactoryTransactionManager +from springpython.database.transaction import TransactionTemplate +from springpython.database.transaction import TransactionCallbackWithoutResult +from springpython.database.transaction import TransactionProxyFactoryObject +from springpython.database.transaction import transactional + +class Person(object): + def __init__(self, name=None, phone=None): + self.name = name + self.phone = phone + +class Animal(object): + def __init__(self, name=None, category=None): + self.name = name + self.category = category + +class SampleRowMapper(RowMapper): + def map_row(self, row, metadata=None): + return Person(name = row[0], phone = row[1]) + +class AnimalRowMapper(RowMapper): + def map_row(self, row, metadata=None): + return Animal(name = row[0], category = row[1]) + +class InvalidCallbackHandler(object): + pass + +class ImproperCallbackHandler(object): + def map_row(self): + raise Exception("You should not have made it this far.") + +class ValidHandler(object): + def map_row(self, row, metadata=None): + return {row[0]:row[1]} + +class MovieLister(object): + def __init__ (self): + self.finder = None + +class ColonMovieFinder(object): + def __init__(self, filename = ""): + self.filename = filename + def findAll (self): + return [line.strip() for line in open(self.filename).readlines()] + +class StringHolder(object): + def __init__(self, str=""): + self.str = str + +class MovieBasedApplicationContext(PythonConfig): + """ + This is a test support class that inherits its functionality from the super class. + """ + def __init__(self): + super(MovieBasedApplicationContext, self).__init__() + + @Object(scope.PROTOTYPE) + def MovieLister(self): + lister = MovieLister() + lister.finder = self.MovieFinder() + lister.description = self.SingletonString() + self.logger.debug("Description = %s" % lister.description) + return lister + + @Object(scope.SINGLETON, True) + def MovieFinder(self): + return ColonMovieFinder(filename="support/movies1.txt") + + @Object(lazy_init=True) # scope.SINGLETON is the default + def SingletonString(self): + return StringHolder("There should only be one copy of this string") + + def NotExposed(self): + pass + +class MixedApplicationContext(PythonConfig): + def __init__(self): + super(MixedApplicationContext, self).__init__() + + @Object(scope.SINGLETON) + def MovieFinder(self): + return ColonMovieFinder(filename="support/movies1.txt") + +class MixedApplicationContext2(PythonConfig): + def __init__(self): + super(MixedApplicationContext2, self).__init__() + + @Object(scope.PROTOTYPE) + def MovieLister(self): + lister = MovieLister() + lister.finder = self.app_context.get_object("MovieFinder") + lister.description = self.SingletonString() + self.logger.debug("Description = %s" % lister.description) + return lister + + @Object # scope.SINGLETON is the default + def SingletonString(self): + return StringHolder("There should only be one copy of this string") + +class TheOtherMovieFinder(object): + def __init__(self, filename = ""): + self.filename = filename + def findAll(self): + return [line.strip()[0:3] for line in open(self.filename).readlines()] + +class SampleBlockOfData: + def __init__(self, data): + self.data = data + def getLabel(self): + return self.data + +class SampleService: + def __init__(self): + self.attribute = "sample" + def method(self, data): + return "You made it! => %s" % data + def doSomething(self): + return "Alright!" + def __str__(self): + return "This is a sample service." + +class NewStyleSampleService(object): + def __init__(self): + self.attribute = "new_sample" + def method(self, data): + return "You made it to a new style class! => %s" % data + def doSomething(self): + return "Even better!" + def __str__(self): + return "This is a new style sample service." + +class RemoteService1(object): + def getData(self, param): + return "You got remote data => %s" % param + def getMoreData(self, param): + return "You got more remote data => %s" % param + +class RemoteService2(object): + def executeOperation(self, routine): + return "Operation %s has been carried out" % routine + def executeOtherOperation(self, routine): + return "Other operation %s has been carried out" % routine + +class BeforeAndAfterInterceptor(MethodInterceptor): + def __init__(self): + self.logger = logging.getLogger("springpythontest.testSupportClasses.BeforeAndAfterInterceptor") + + def invoke(self, invocation): + results = "BEFORE => " + invocation.proceed() + " <= AFTER" + return results + +class WrappingInterceptor(MethodInterceptor): + def __init__(self): + self.logger = logging.getLogger("springpythontest.testSupportClasses.WrappingInterceptor") + + def invoke(self, invocation): + results = "" + invocation.proceed() + "" + return results + +class StubConnection(object): + def __init__(self): + self.mockCursor = None + def cursor(self): + return self.mockCursor + def close(self): + pass + +class StubDBFactory(ConnectionFactory): + def __init__(self): + ConnectionFactory.__init__(self, [types.TupleType]) + self.stubConnection = StubConnection() + def connect(self): + return self.stubConnection + def close(self): + pass + +class ImpFileProps(object): + def __init__(self, paystat_work_dir, paystat_reload_dir, paystat_archive_dir, oid): + self.paystat_work_dir = paystat_work_dir + self.paystat_reload_dir = paystat_reload_dir + self.paystat_archive_dir = paystat_archive_dir + self.oid = oid + +class ImpFilePropsRowMapper(RowMapper): + def map_row(self, row, metadata=None): + return ImpFileProps(row[0], row[1], row[2], row[3]) + +class BankException(Exception): + pass + +class Bank(object): + """This sample application can be used to demonstrate the value of atomic operations. The transfer operation + must be wrapped in a transaction in order to perform correctly. Otherwise, any errors in the deposit will + allow the from-account to leak assets.""" + def __init__(self, factory): + self.logger = logging.getLogger("springpythontest.testSupportClasses.Bank") + self.dt = DatabaseTemplate(factory) + + def open(self, account_num): + self.logger.debug("Opening account %s with $0 balance." % account_num) + self.dt.execute("INSERT INTO account (account_num, balance) VALUES (?,?)", (account_num, 0)) + + def deposit(self, amount, account_num): + self.logger.debug("Depositing $%s into %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance + ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + + def withdraw(self, amount, account_num): + self.logger.debug("Withdrawing $%s from %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance - ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + return amount + + def balance(self, account_num): + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + + def transfer(self, amount, from_account, to_account): + self.logger.debug("Transferring $%s from %s to %s." % (amount, from_account, to_account)) + self.withdraw(amount, from_account) + self.deposit(amount, to_account) + +class DatabaseTxTestAppContext(PythonConfig): + def __init__(self, factory): + super(DatabaseTxTestAppContext, self).__init__() + self.factory = factory + + @Object + def bank_target(self): + return Bank(self.factory) + + @Object + def tx_object(self): + return AutoTransactionalObject(self.tx_mgr()) + + @Object + def tx_mgr(self): + return ConnectionFactoryTransactionManager(self.factory) + + @Object + def bank(self): + transactionAttributes = [] + transactionAttributes.append((".*transfer", ["PROPAGATION_REQUIRED"])) + transactionAttributes.append((".*", ["PROPAGATION_REQUIRED","readOnly"])) + return TransactionProxyFactoryObject(self.tx_mgr(), self.bank_target(), transactionAttributes) + + +class DatabaseTxTestAppContextWithNoAutoTransactionalObject(PythonConfig): + def __init__(self, factory): + super(DatabaseTxTestAppContextWithNoAutoTransactionalObject, self).__init__() + self.factory = factory + + @Object + def bank_target(self): + return Bank(self.factory) + + @Object + def tx_mgr(self): + return ConnectionFactoryTransactionManager(self.factory) + + @Object + def bank(self): + return TransactionalBank(self.factory) + +class TransactionalBank(object): + """This sample application can be used to demonstrate the value of atomic operations. The transfer operation + must be wrapped in a transaction in order to perform correctly. Otherwise, any errors in the deposit will + allow the from-account to leak assets.""" + def __init__(self, factory): + self.logger = logging.getLogger("springpythontest.testSupportClasses.TransactionalBank") + self.dt = DatabaseTemplate(factory) + + def open(self, account_num): + self.logger.debug("Opening account %s with $0 balance." % account_num) + self.dt.execute("INSERT INTO account (account_num, balance) VALUES (?,?)", (account_num, 0)) + + def deposit(self, amount, account_num): + self.logger.debug("Depositing $%s into %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance + ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + + def withdraw(self, amount, account_num): + self.logger.debug("Withdrawing $%s from %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance - ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + return amount + + def balance(self, account_num): + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + + @transactional() + def transfer(self, amount, from_account, to_account): + self.logger.debug("Transferring $%s from %s to %s." % (amount, from_account, to_account)) + self.withdraw(amount, from_account) + self.deposit(amount, to_account) + +class DatabaseTxTestDecorativeTransactions(PythonConfig): + def __init__(self, factory): + super(DatabaseTxTestDecorativeTransactions, self).__init__() + self.factory = factory + + @Object + def tx_object(self): + return AutoTransactionalObject(self.tx_mgr()) + + @Object + def tx_mgr(self): + return ConnectionFactoryTransactionManager(self.factory) + + @Object + def bank(self): + results = TransactionalBank(self.factory) + return results + +class TransactionalBankWithNoTransactionalArguments(object): + """This sample application can be used to demonstrate the value of atomic operations. The transfer operation + must be wrapped in a transaction in order to perform correctly. Otherwise, any errors in the deposit will + allow the from-account to leak assets.""" + def __init__(self, factory): + self.logger = logging.getLogger("springpythontest.testSupportClasses.TransactionalBankWithNoTransactionalArguments") + self.dt = DatabaseTemplate(factory) + + def open(self, account_num): + self.logger.debug("Opening account %s with $0 balance." % account_num) + self.dt.execute("INSERT INTO account (account_num, balance) VALUES (?,?)", (account_num, 0)) + + def deposit(self, amount, account_num): + self.logger.debug("Depositing $%s into %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance + ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + + def withdraw(self, amount, account_num): + self.logger.debug("Withdrawing $%s from %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance - ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + return amount + + def balance(self, account_num): + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + + @transactional + def transfer(self, amount, from_account, to_account): + self.logger.debug("Transferring $%s from %s to %s." % (amount, from_account, to_account)) + self.withdraw(amount, from_account) + self.deposit(amount, to_account) + +class DatabaseTxTestDecorativeTransactionsWithNoArguments(PythonConfig): + def __init__(self, factory): + super(DatabaseTxTestDecorativeTransactionsWithNoArguments, self).__init__() + self.factory = factory + + @Object + def tx_object(self): + return AutoTransactionalObject(self.tx_mgr()) + + @Object + def tx_mgr(self): + return ConnectionFactoryTransactionManager(self.factory) + + @Object + def bank(self): + results = TransactionalBankWithNoTransactionalArguments(self.factory) + return results + +class TransactionalBankWithLotsOfTransactionalArguments(object): + """This sample application can be used to demonstrate the value of atomic operations. The transfer operation + must be wrapped in a transaction in order to perform correctly. Otherwise, any errors in the deposit will + allow the from-account to leak assets.""" + def __init__(self, factory): + self.logger = logging.getLogger("springpythontest.testSupportClasses.TransactionalBankWithLotsOfTransactionalArguments") + self.dt = DatabaseTemplate(factory) + + @transactional(["PROPAGATION_REQUIRED"]) + def open(self, account_num): + self.logger.debug("Opening account %s with $0 balance." % account_num) + self.dt.execute("INSERT INTO account (account_num, balance) VALUES (?,?)", (account_num, 0)) + + @transactional(["PROPAGATION_REQUIRED"]) + def deposit(self, amount, account_num): + self.logger.debug("Depositing $%s into %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance + ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + + @transactional(["PROPAGATION_REQUIRED"]) + def withdraw(self, amount, account_num): + self.logger.debug("Withdrawing $%s from %s" % (amount, account_num)) + rows = self.dt.execute("UPDATE account SET balance = balance - ? WHERE account_num = ?", (amount, account_num)) + if rows == 0: + raise BankException("Account %s does NOT exist" % account_num) + return amount + + @transactional(["PROPAGATION_SUPPORTS","readOnly"]) + def balance(self, account_num): + self.logger.debug("Checking balance for %s" % account_num) + return self.dt.query_for_object("SELECT balance FROM account WHERE account_num = ?", (account_num,), types.FloatType) + + @transactional(["PROPAGATION_REQUIRED"]) + def transfer(self, amount, from_account, to_account): + self.logger.debug("Transferring $%s from %s to %s." % (amount, from_account, to_account)) + self.withdraw(amount, from_account) + self.deposit(amount, to_account) + + @transactional(["PROPAGATION_NEVER"]) + def nonTransactionalOperation(self): + self.logger.debug("Executing non-transactional operation.") + + @transactional(["PROPAGATION_MANDATORY"]) + def mandatoryOperation(self): + self.logger.debug("Executing mandatory transactional operation.") + + @transactional(["PROPAGATION_REQUIRED"]) + def mandatoryOperationTransactionalWrapper(self): + self.mandatoryOperation() + self.mandatoryOperation() + + @transactional(["PROPAGATION_REQUIRED"]) + def nonTransactionalOperationTransactionalWrapper(self): + self.nonTransactionalOperation() + +class DatabaseTxTestDecorativeTransactionsWithLotsOfArguments(PythonConfig): + def __init__(self, factory): + super(DatabaseTxTestDecorativeTransactionsWithLotsOfArguments, self).__init__() + self.factory = factory + + @Object + def tx_mgr(self): + return ConnectionFactoryTransactionManager(self.factory) + + @Object + def tx_object(self): + return AutoTransactionalObject(self.tx_mgr()) + + @Object + def bank(self): + results = TransactionalBankWithLotsOfTransactionalArguments(self.factory) + return results + +class ValueHolder(object): + def __init__(self, string_holder = None): + self.some_dict = None + self.some_list = None + self.some_props = None + self.some_set = None + self.some_frozen_set = None + self.some_tuple = None + self.string_holder = string_holder + +class MultiValueHolder(object): + def __init__(self, a = "a", b = "b", c = "c"): + self.a = a + self.b = b + self.c = c + +class ConstructorBasedContainer(PythonConfig): + @Object + def MultiValueHolder(self): + return MultiValueHolder(a="alt a", b="alt b") + + @Object + def MultiValueHolder2(self): + return MultiValueHolder(c="alt c", b="alt b") + +class Controller(object): + def __init__(self, executors=None): + self.executors = executors + +class Executor(object): + pass + +class SamplePostProcessor(ObjectPostProcessor): + def post_process_after_initialization(self, obj, obj_name): + setattr(obj, "processedAfter", obj_name) + return obj + +class SamplePostProcessor2(ObjectPostProcessor): + def post_process_before_initialization(self, obj, obj_name): + setattr(obj, "processedBefore", obj_name) + return obj + +class Service(object): + def __init__(self, ip=None, port=None, path=None): + self.ip = ip + self.port = port + self.path = path + + def __str__(self): + return "" % (hex(id(self)), self.ip, self.port, self.path) + +class Foo(object): + def __init__(self, a=None, b=None, c=None, d=None, e=None, f=None, g=None): + self.a = a + self.b = b + self.c = c + self.d = d + self.e = e + self.f = f + self.g = g + + def __str__(self): + return "" % (hex(id(self)), self.a, self.b, + self.c, self.d, self.e, self.f, self.g) diff --git a/test/standalone/pyro_thread_test.py b/test/standalone/pyro_thread_test.py index fd0d6e2..ecffa9a 100644 --- a/test/standalone/pyro_thread_test.py +++ b/test/standalone/pyro_thread_test.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import print_function ######################################################################## # This is a stand-alone test, meaning it doesn't run well in automated @@ -27,7 +28,7 @@ class MySampleService(object): def hey(self): - print "You have just called the sample service!" + print("You have just called the sample service!") class MySampleServiceAppContext(PythonConfig): def __init__(self): @@ -55,5 +56,5 @@ def mySampleServiceExporter(self): logger.addHandler(ch) - print "Starting up context that exposese reported issue..." + print("Starting up context that exposese reported issue...") ctx = ApplicationContext(MySampleServiceAppContext()) diff --git a/test/standalone/pyro_thread_test.py.bak b/test/standalone/pyro_thread_test.py.bak new file mode 100644 index 0000000..fd0d6e2 --- /dev/null +++ b/test/standalone/pyro_thread_test.py.bak @@ -0,0 +1,59 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +######################################################################## +# This is a stand-alone test, meaning it doesn't run well in automated +# scenarios. This script exposed bug http://jira.springframework.org/browse/SESPRINGPYTHONPY-99, +# which showed _PyroThread having a name collisions with python2.6's threading.Thread +# class. By renaming _PyroThread's self.daemon as self.pyro_daemon, this code +# now works with python2.6. It was also used to confirm python2.5, and jython2.5.1.FINAL. +######################################################################## + +from springpython.config import Object, PythonConfig +from springpython.remoting.pyro import PyroServiceExporter + +class MySampleService(object): + def hey(self): + print "You have just called the sample service!" + +class MySampleServiceAppContext(PythonConfig): + def __init__(self): + PythonConfig.__init__(self) + + @Object + def mySampleService(self): + return MySampleService() + + @Object + def mySampleServiceExporter(self): + return PyroServiceExporter(self.mySampleService(), "service", "localhost", 7000) + +if __name__ == "__main__": + import logging + from springpython.context import ApplicationContext + + logger = logging.getLogger("springpython") + loggingLevel = logging.DEBUG + logger.setLevel(loggingLevel) + ch = logging.StreamHandler() + ch.setLevel(loggingLevel) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + + print "Starting up context that exposese reported issue..." + ctx = ApplicationContext(MySampleServiceAppContext()) diff --git a/test/standalone/test.py b/test/standalone/test.py index 5a61c6c..40567e4 100644 --- a/test/standalone/test.py +++ b/test/standalone/test.py @@ -1,3 +1,4 @@ +from __future__ import print_function ################################ # This stand-alone script is used to exercise the LDAP APIs. The intent is to get things working, and then # replace this with an effective unit test. @@ -20,16 +21,16 @@ authentication = UsernamePasswordAuthenticationToken(username="bob", password="bobspassword") -print "Input = %s" % authentication +print("Input = %s" % authentication) auth1 = authProvider.authenticate(authentication) -print "Bind output = %s" % auth1 +print("Bind output = %s" % auth1) -print "Input = %s" % authentication +print("Input = %s" % authentication) auth2 = authProvider2.authenticate(authentication) -print "PasswordComparison output = %s" % auth2 +print("PasswordComparison output = %s" % auth2) diff --git a/test/standalone/test.py.bak b/test/standalone/test.py.bak new file mode 100644 index 0000000..5a61c6c --- /dev/null +++ b/test/standalone/test.py.bak @@ -0,0 +1,35 @@ +################################ +# This stand-alone script is used to exercise the LDAP APIs. The intent is to get things working, and then +# replace this with an effective unit test. +################################ + +from springpython.security.providers import UsernamePasswordAuthenticationToken +from springpython.security.providers.Ldap import DefaultSpringSecurityContextSource +from springpython.security.providers.Ldap import BindAuthenticator +from springpython.security.providers.Ldap import PasswordComparisonAuthenticator +from springpython.security.providers.Ldap import DefaultLdapAuthoritiesPopulator +from springpython.security.providers.Ldap import LdapAuthenticationProvider + +context = DefaultSpringSecurityContextSource(url="ldap://localhost:53389/dc=springframework,dc=org") +bindAuthenticator = BindAuthenticator(context_source=context, user_dn_patterns="uid={0},ou=people") +populator = DefaultLdapAuthoritiesPopulator(context_source=context, group_search_base="ou=groups") +authProvider = LdapAuthenticationProvider(ldap_authenticator=bindAuthenticator, ldap_authorities_populator=populator) + +passwordAuthenticator = PasswordComparisonAuthenticator(context_source=context, user_dn_patterns="uid={0},ou=people") +authProvider2 = LdapAuthenticationProvider(ldap_authenticator=passwordAuthenticator, ldap_authorities_populator=populator) + +authentication = UsernamePasswordAuthenticationToken(username="bob", password="bobspassword") + +print "Input = %s" % authentication + +auth1 = authProvider.authenticate(authentication) + +print "Bind output = %s" % auth1 + + +print "Input = %s" % authentication + +auth2 = authProvider2.authenticate(authentication) + +print "PasswordComparison output = %s" % auth2 + diff --git a/test/standalone/xsd_test_cases.py b/test/standalone/xsd_test_cases.py index 21f9847..5cb111e 100644 --- a/test/standalone/xsd_test_cases.py +++ b/test/standalone/xsd_test_cases.py @@ -67,7 +67,7 @@ def test_xsd(self): try: schema.assert_(doc) - except Exception, e: + except Exception as e: logging.error("Exception in assert_, xml=[%s] e=[%s]" % (xml, e)) raise diff --git a/test/standalone/xsd_test_cases.py.bak b/test/standalone/xsd_test_cases.py.bak new file mode 100644 index 0000000..21f9847 --- /dev/null +++ b/test/standalone/xsd_test_cases.py.bak @@ -0,0 +1,85 @@ +""" + Copyright 2006-2008 SpringSource (http://springsource.com), All Rights Reserved + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +# stdlib +import glob +import logging +import unittest + +# lxml +from lxml import etree +from lxml import objectify + +logger = logging.getLogger("springpythontest.xsd_test_cases") + +class XSDTestCase(unittest.TestCase): + """Verifies whether the XMLConfig files used for the tests themselves + are valid according to the appropriate Spring Python's XSD schema.""" + + NS_10 = "http://www.springframework.org/springpython/schema/objects" + NS_11 = "http://www.springframework.org/springpython/schema/objects/1.1" + + def _get_schema(self, version): + schema_file = open("../../xml/schema/context/spring-python-context-%s.xsd" % version) + schema = etree.XMLSchema(etree.parse(schema_file)) + schema_file.close() + + return schema + + def setUp(self): + self.schema10 = self._get_schema("1.0") + self.schema11 = self._get_schema("1.1") + + def test_xsd(self): + xmls = glob.glob("../springpythontest/support/*.xml") + + if not xmls: + self.fail("No XMLs found") + + for xml in xmls: + doc = objectify.fromstring(open(xml).read()) + xmlns = doc.nsmap[None] + + # XSD v. 1.0 + if xmlns == self.NS_10: + schema = self.schema10 + + # XSD v. 1.1 + elif xmlns == self.NS_11: + schema = self.schema11 + + # Ignore any other XML files + else: + continue + + try: + schema.assert_(doc) + except Exception, e: + logging.error("Exception in assert_, xml=[%s] e=[%s]" % (xml, e)) + raise + +if __name__ == "__main__": + import logging + + loggingLevel = logging.DEBUG + logger.setLevel(loggingLevel) + ch = logging.StreamHandler() + ch.setLevel(loggingLevel) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + unittest.main() From 1c57edf7937dfdade66d3006ebbbf121f16c7a2c Mon Sep 17 00:00:00 2001 From: ram Date: Sun, 10 Jul 2022 21:42:48 -0400 Subject: [PATCH 2/3] Create ReadMe.md --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..5d140f6 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# springpython +Spring framework compatible with **python 3** + +## Install module locally for python 3 +- `git clone https://github.com/ramAdam/springpython.git` +- `pip install src/` + + From c2038c464e4afdabbacb182ec11b385f530616e9 Mon Sep 17 00:00:00 2001 From: ram Date: Thu, 14 Jul 2022 14:40:34 -0400 Subject: [PATCH 3/3] Update README.md added a step on how to build the package --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 5d140f6..e58a4e4 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ Spring framework compatible with **python 3** ## Install module locally for python 3 - `git clone https://github.com/ramAdam/springpython.git` +- `python build.py --package` - `pip install src/`