aboutsummaryrefslogtreecommitdiff
path: root/py
diff options
context:
space:
mode:
authorRonan Lamy <ronan.lamy@gmail.com>2016-11-15 01:46:49 +0000
committerRonan Lamy <ronan.lamy@gmail.com>2016-11-15 01:46:49 +0000
commitc4d582fc64bcec47f47825bafc85abb9ed175c42 (patch)
tree68b430e6f82cfbcae037ed63d2ad1d5d2c9ac506 /py
parentreinstate pytest_cov.py (diff)
downloadpypy-c4d582fc64bcec47f47825bafc85abb9ed175c42.tar.gz
pypy-c4d582fc64bcec47f47825bafc85abb9ed175c42.tar.bz2
pypy-c4d582fc64bcec47f47825bafc85abb9ed175c42.zip
copy upstream pytest-2.9.2 and py-1.4.29
Diffstat (limited to 'py')
-rw-r--r--py/__init__.py4
-rw-r--r--py/_apipkg.py18
-rw-r--r--py/_builtin.py4
-rw-r--r--py/_code/code.py73
-rw-r--r--py/_code/source.py36
-rw-r--r--py/_io/terminalwriter.py17
-rw-r--r--py/_path/common.py38
-rw-r--r--py/_path/local.py138
-rw-r--r--py/_process/forkedfunc.py56
9 files changed, 238 insertions, 146 deletions
diff --git a/py/__init__.py b/py/__init__.py
index c94f0699c5..fec8803828 100644
--- a/py/__init__.py
+++ b/py/__init__.py
@@ -6,9 +6,9 @@ and classes. The initpkg-dictionary below specifies
name->value mappings where value can be another namespace
dictionary or an import path.
-(c) Holger Krekel and others, 2004-2013
+(c) Holger Krekel and others, 2004-2014
"""
-__version__ = '1.4.20'
+__version__ = '1.4.29'
from py import _apipkg
diff --git a/py/_apipkg.py b/py/_apipkg.py
index 4907f6bfca..a73b8f6d0b 100644
--- a/py/_apipkg.py
+++ b/py/_apipkg.py
@@ -17,6 +17,7 @@ def _py_abspath(path):
that will leave paths from jython jars alone
"""
if path.startswith('__pyclasspath__'):
+
return path
else:
return os.path.abspath(path)
@@ -41,7 +42,7 @@ def initpkg(pkgname, exportdefs, attr=dict()):
if hasattr(oldmod, "__dict__"):
oldmod.__dict__.update(d)
mod = ApiModule(pkgname, exportdefs, implprefix=pkgname, attr=d)
- sys.modules[pkgname] = mod
+ sys.modules[pkgname] = mod
def importobj(modpath, attrname):
module = __import__(modpath, None, None, ['__doc__'])
@@ -72,11 +73,11 @@ class ApiModule(ModuleType):
self.__implprefix__ = implprefix or name
if attr:
for name, val in attr.items():
- #print "setting", self.__name__, name, val
+ # print "setting", self.__name__, name, val
setattr(self, name, val)
for name, importspec in importspec.items():
if isinstance(importspec, dict):
- subname = '%s.%s'%(self.__name__, name)
+ subname = '%s.%s' % (self.__name__, name)
apimod = ApiModule(subname, importspec, implprefix)
sys.modules[subname] = apimod
setattr(self, name, apimod)
@@ -88,7 +89,7 @@ class ApiModule(ModuleType):
modpath = implprefix + modpath
if not attrname:
- subname = '%s.%s'%(self.__name__, name)
+ subname = '%s.%s' % (self.__name__, name)
apimod = AliasModule(subname, modpath)
sys.modules[subname] = apimod
if '.' not in name:
@@ -108,7 +109,7 @@ class ApiModule(ModuleType):
def __makeattr(self, name):
"""lazily compute value for name or raise AttributeError if unknown."""
- #print "makeattr", self.__name__, name
+ # print "makeattr", self.__name__, name
target = None
if '__onfirstaccess__' in self.__map__:
target = self.__map__.pop('__onfirstaccess__')
@@ -126,7 +127,7 @@ class ApiModule(ModuleType):
try:
del self.__map__[name]
except KeyError:
- pass # in a recursive-import situation a double-del can happen
+ pass # in a recursive-import situation a double-del can happen
return result
__getattr__ = __makeattr
@@ -166,7 +167,10 @@ def AliasModule(modname, modpath, attrname=None):
return '<AliasModule %r for %r>' % (modname, x)
def __getattribute__(self, name):
- return getattr(getmod(), name)
+ try:
+ return getattr(getmod(), name)
+ except ImportError:
+ return None
def __setattr__(self, name, value):
setattr(getmod(), name, value)
diff --git a/py/_builtin.py b/py/_builtin.py
index fa3b797c44..52ee9d79ca 100644
--- a/py/_builtin.py
+++ b/py/_builtin.py
@@ -220,11 +220,11 @@ else:
locals = globals
exec2(obj, globals, locals)
-if sys.version_info >= (3,0):
+if sys.version_info >= (3, 0):
def _reraise(cls, val, tb):
__tracebackhide__ = True
assert hasattr(val, '__traceback__')
- raise val
+ raise cls.with_traceback(val, tb)
else:
exec ("""
def _reraise(cls, val, tb):
diff --git a/py/_code/code.py b/py/_code/code.py
index aa60da8017..f14c562a29 100644
--- a/py/_code/code.py
+++ b/py/_code/code.py
@@ -133,12 +133,17 @@ class Frame(object):
class TracebackEntry(object):
""" a single entry in a traceback """
+ _repr_style = None
exprinfo = None
def __init__(self, rawentry):
self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1
+ def set_repr_style(self, mode):
+ assert mode in ("short", "long")
+ self._repr_style = mode
+
@property
def frame(self):
return py.code.Frame(self._rawentry.tb_frame)
@@ -465,22 +470,22 @@ class FormattedExcinfo(object):
def get_source(self, source, line_index=-1, excinfo=None, short=False):
""" return formatted and marked up source lines. """
lines = []
- if source is None:
+ if source is None or line_index >= len(source.lines):
source = py.code.Source("???")
line_index = 0
if line_index < 0:
line_index += len(source)
- for i in range(len(source)):
- if i == line_index:
- prefix = self.flow_marker + " "
- else:
- if short:
- continue
- prefix = " "
- line = prefix + source[i]
- lines.append(line)
+ space_prefix = " "
+ if short:
+ lines.append(space_prefix + source.lines[line_index].strip())
+ else:
+ for line in source.lines[:line_index]:
+ lines.append(space_prefix + line)
+ lines.append(self.flow_marker + " " + source.lines[line_index])
+ for line in source.lines[line_index+1:]:
+ lines.append(space_prefix + line)
if excinfo is not None:
- indent = self._getindent(source)
+ indent = 4 if short else self._getindent(source)
lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
return lines
@@ -520,7 +525,6 @@ class FormattedExcinfo(object):
return ReprLocals(lines)
def repr_traceback_entry(self, entry, excinfo=None):
- # excinfo is not None if this is the last tb entry
source = self._getentrysource(entry)
if source is None:
source = py.code.Source("???")
@@ -530,11 +534,12 @@ class FormattedExcinfo(object):
line_index = entry.lineno - max(entry.getfirstlinesource(), 0)
lines = []
- if self.style in ("short", "long"):
- short = self.style == "short"
- reprargs = None
- if not short:
- reprargs = self.repr_args(entry)
+ style = entry._repr_style
+ if style is None:
+ style = self.style
+ if style in ("short", "long"):
+ short = style == "short"
+ reprargs = self.repr_args(entry) if not short else None
s = self.get_source(source, line_index, excinfo, short=short)
lines.extend(s)
if short:
@@ -546,10 +551,10 @@ class FormattedExcinfo(object):
localsrepr = None
if not short:
localsrepr = self.repr_locals(entry.locals)
- return ReprEntry(lines, reprargs, localsrepr, filelocrepr, short)
+ return ReprEntry(lines, reprargs, localsrepr, filelocrepr, style)
if excinfo:
lines.extend(self.get_exconly(excinfo, indent=4))
- return ReprEntry(lines, None, None, None, False)
+ return ReprEntry(lines, None, None, None, style)
def _makepath(self, path):
if not self.abspath:
@@ -567,7 +572,8 @@ class FormattedExcinfo(object):
traceback = traceback.filter()
recursionindex = None
if excinfo.errisinstance(RuntimeError):
- recursionindex = traceback.recursionindex()
+ if "maximum recursion depth exceeded" in str(excinfo.value):
+ recursionindex = traceback.recursionindex()
last = traceback[-1]
entries = []
extraline = None
@@ -628,14 +634,18 @@ class ReprTraceback(TerminalRepr):
self.style = style
def toterminal(self, tw):
- sepok = False
- for entry in self.reprentries:
- if self.style == "long":
- if sepok:
- tw.sep(self.entrysep)
+ # the entries might have different styles
+ last_style = None
+ for i, entry in enumerate(self.reprentries):
+ if entry.style == "long":
tw.line("")
- sepok = True
entry.toterminal(tw)
+ if i < len(self.reprentries) - 1:
+ next_entry = self.reprentries[i+1]
+ if entry.style == "long" or \
+ entry.style == "short" and next_entry.style == "long":
+ tw.sep(self.entrysep)
+
if self.extraline:
tw.line(self.extraline)
@@ -646,6 +656,8 @@ class ReprTracebackNative(ReprTraceback):
self.extraline = None
class ReprEntryNative(TerminalRepr):
+ style = "native"
+
def __init__(self, tblines):
self.lines = tblines
@@ -655,15 +667,15 @@ class ReprEntryNative(TerminalRepr):
class ReprEntry(TerminalRepr):
localssep = "_ "
- def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, short):
+ def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style):
self.lines = lines
self.reprfuncargs = reprfuncargs
self.reprlocals = reprlocals
self.reprfileloc = filelocrepr
- self.short = short
+ self.style = style
def toterminal(self, tw):
- if self.short:
+ if self.style == "short":
self.reprfileloc.toterminal(tw)
for line in self.lines:
red = line.startswith("E ")
@@ -680,7 +692,8 @@ class ReprEntry(TerminalRepr):
tw.line("")
self.reprlocals.toterminal(tw)
if self.reprfileloc:
- tw.line("")
+ if self.lines:
+ tw.line("")
self.reprfileloc.toterminal(tw)
def __str__(self):
diff --git a/py/_code/source.py b/py/_code/source.py
index e17bc1cd35..18709af0bf 100644
--- a/py/_code/source.py
+++ b/py/_code/source.py
@@ -291,14 +291,10 @@ def deindent(lines, offset=None):
while True:
yield ''
- r = readline_generator(lines)
- try:
- readline = r.next
- except AttributeError:
- readline = r.__next__
+ it = readline_generator(lines)
try:
- for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(readline):
+ for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(lambda: next(it)):
if sline > len(lines):
break # End of input reached
if sline > len(newlines):
@@ -317,12 +313,14 @@ def deindent(lines, offset=None):
newlines.extend(lines[len(newlines):])
return newlines
+
def get_statement_startend(lineno, nodelist):
from bisect import bisect_right
# lineno starts at 0
nextlineno = None
while 1:
lineno_list = [x.lineno-1 for x in nodelist] # ast indexes start at 1
+ #print lineno_list, [vars(x) for x in nodelist]
insert_index = bisect_right(lineno_list, lineno)
if insert_index >= len(nodelist):
insert_index -= 1
@@ -341,7 +339,6 @@ def get_statement_startend(lineno, nodelist):
start, end = nextnode.lineno-1, nextlineno
start = min(lineno, start)
assert start <= lineno and (end is None or lineno < end)
- #print "returning", start, end
return start, end
def getnodelist(node):
@@ -355,7 +352,6 @@ def getnodelist(node):
l.extend(attr)
elif hasattr(attr, "lineno"):
l.append(attr)
- #print "returning nodelist", l
return l
def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
@@ -373,17 +369,35 @@ def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
# - ast-parsing strips comments
# - else statements do not have a separate lineno
# - there might be empty lines
+ # - we might have lesser indented code blocks at the end
if end is None:
end = len(source.lines)
+
+ if end > start + 1:
+ # make sure we don't span differently indented code blocks
+ # by using the BlockFinder helper used which inspect.getsource() uses itself
+ block_finder = inspect.BlockFinder()
+ # if we start with an indented line, put blockfinder to "started" mode
+ block_finder.started = source.lines[start][0].isspace()
+ it = ((x + "\n") for x in source.lines[start:end])
+ try:
+ for tok in tokenize.generate_tokens(lambda: next(it)):
+ block_finder.tokeneater(*tok)
+ except (inspect.EndOfBlock, IndentationError) as e:
+ end = block_finder.last + start
+ #except Exception:
+ # pass
+
+ # the end might still point to a comment, correct it
while end:
- line = source.lines[end-1].lstrip()
- if (not line or line.startswith("#") or line.startswith("else:") or
- line.startswith("finally:")):
+ line = source.lines[end - 1].lstrip()
+ if line.startswith("#"):
end -= 1
else:
break
return astnode, start, end
+
def getstatementrange_old(lineno, source, assertion=False):
""" return (start, end) tuple which spans the minimal
statement region which containing the given lineno.
diff --git a/py/_io/terminalwriter.py b/py/_io/terminalwriter.py
index 6aca55d570..cef1ff5809 100644
--- a/py/_io/terminalwriter.py
+++ b/py/_io/terminalwriter.py
@@ -31,17 +31,24 @@ def _getdimensions():
def get_terminal_width():
+ height = width = 0
try:
height, width = _getdimensions()
except py.builtin._sysex:
raise
except:
- # FALLBACK
+ # pass to fallback below
+ pass
+
+ if width == 0:
+ # FALLBACK:
+ # * some exception happened
+ # * or this is emacs terminal which reports (0,0)
width = int(os.environ.get('COLUMNS', 80))
- else:
- # XXX the windows getdimensions may be bogus, let's sanify a bit
- if width < 40:
- width = 80
+
+ # XXX the windows getdimensions may be bogus, let's sanify a bit
+ if width < 40:
+ width = 80
return width
terminal_width = get_terminal_width()
diff --git a/py/_path/common.py b/py/_path/common.py
index aeb374c087..80907306a6 100644
--- a/py/_path/common.py
+++ b/py/_path/common.py
@@ -1,8 +1,11 @@
"""
"""
-import os, sys
+import os, sys, posixpath
import py
+# Moved from local.py.
+iswin32 = sys.platform == "win32" or (getattr(os, '_name', False) == 'nt')
+
class Checkers:
_depend_on_existence = 'exists', 'link', 'dir', 'file'
@@ -110,22 +113,24 @@ class PathBase(object):
ext = property(ext, None, None, ext.__doc__)
def dirpath(self, *args, **kwargs):
- """ return the directory Path of the current Path joined
- with any given path arguments.
- """
+ """ return the directory path joined with any given path arguments. """
return self.new(basename='').join(*args, **kwargs)
+ def read_binary(self):
+ """ read and return a bytestring from reading the path. """
+ with self.open('rb') as f:
+ return f.read()
+
+ def read_text(self, encoding):
+ """ read and return a Unicode string from reading the path. """
+ with self.open("r", encoding=encoding) as f:
+ return f.read()
+
+
def read(self, mode='r'):
""" read and return a bytestring from reading the path. """
- if sys.version_info < (2,3):
- for x in 'u', 'U':
- if x in mode:
- mode = mode.replace(x, '')
- f = self.open(mode)
- try:
+ with self.open(mode) as f:
return f.read()
- finally:
- f.close()
def readlines(self, cr=1):
""" read and return a list of lines from the path. if cr is False, the
@@ -379,6 +384,15 @@ class FNMatcher:
def __call__(self, path):
pattern = self.pattern
+
+ if (pattern.find(path.sep) == -1 and
+ iswin32 and
+ pattern.find(posixpath.sep) != -1):
+ # Running on Windows, the pattern has no Windows path separators,
+ # and the pattern has one or more Posix path separators. Replace
+ # the Posix path separators with the Windows path separator.
+ pattern = pattern.replace(posixpath.sep, path.sep)
+
if pattern.find(path.sep) == -1:
name = path.basename
else:
diff --git a/py/_path/local.py b/py/_path/local.py
index af09f43014..c1f7248add 100644
--- a/py/_path/local.py
+++ b/py/_path/local.py
@@ -1,15 +1,16 @@
"""
local path implementation.
"""
+from __future__ import with_statement
+
from contextlib import contextmanager
-import sys, os, re, atexit
+import sys, os, re, atexit, io
import py
from py._path import common
+from py._path.common import iswin32
from stat import S_ISLNK, S_ISDIR, S_ISREG
-from os.path import abspath, normpath, isabs, exists, isdir, isfile, islink
-
-iswin32 = sys.platform == "win32" or (getattr(os, '_name', False) == 'nt')
+from os.path import abspath, normpath, isabs, exists, isdir, isfile, islink, dirname
if sys.version_info > (3,0):
def map_as_list(func, iter):
@@ -151,7 +152,7 @@ class LocalPath(FSBase):
elif isinstance(path, py.builtin._basestring):
if expanduser:
path = os.path.expanduser(path)
- self.strpath = abspath(normpath(path))
+ self.strpath = abspath(path)
else:
raise ValueError("can only pass None, Path instances "
"or non-empty strings to LocalPath")
@@ -303,6 +304,16 @@ class LocalPath(FSBase):
raise ValueError("invalid part specification %r" % name)
return res
+ def dirpath(self, *args, **kwargs):
+ """ return the directory path joined with any given path arguments. """
+ if not kwargs:
+ path = object.__new__(self.__class__)
+ path.strpath = dirname(self.strpath)
+ if args:
+ path = path.join(*args)
+ return path
+ return super(LocalPath, self).dirpath(*args, **kwargs)
+
def join(self, *args, **kwargs):
""" return a new path by appending all 'args' as path
components. if abs=1 is used restart from root if any
@@ -330,13 +341,15 @@ class LocalPath(FSBase):
obj.strpath = normpath(strpath)
return obj
- def open(self, mode='r', ensure=False):
+ def open(self, mode='r', ensure=False, encoding=None):
""" return an opened file with the given mode.
If ensure is True, create parent directories if needed.
"""
if ensure:
self.dirpath().ensure(dir=1)
+ if encoding:
+ return py.error.checked_call(io.open, self.strpath, mode, encoding=encoding)
return py.error.checked_call(open, self.strpath, mode)
def _fastjoin(self, name):
@@ -434,6 +447,24 @@ class LocalPath(FSBase):
py.error.checked_call(os.mkdir, getattr(p, "strpath", p))
return p
+ def write_binary(self, data, ensure=False):
+ """ write binary data into path. If ensure is True create
+ missing parent directories.
+ """
+ if ensure:
+ self.dirpath().ensure(dir=1)
+ with self.open('wb') as f:
+ f.write(data)
+
+ def write_text(self, data, encoding, ensure=False):
+ """ write text data into path using the specified encoding.
+ If ensure is True create missing parent directories.
+ """
+ if ensure:
+ self.dirpath().ensure(dir=1)
+ with self.open('w', encoding=encoding) as f:
+ f.write(data)
+
def write(self, data, mode='w', ensure=False):
""" write data into path. If ensure is True create
missing parent directories.
@@ -549,35 +580,6 @@ class LocalPath(FSBase):
""" return string representation of the Path. """
return self.strpath
- def pypkgpath(self, pkgname=None):
- """ return the Python package path by looking for a
- pkgname. If pkgname is None look for the last
- directory upwards which still contains an __init__.py
- and whose basename is python-importable.
- Return None if a pkgpath can not be determined.
- """
- pkgpath = None
- for parent in self.parts(reverse=True):
- if pkgname is None:
- if parent.check(file=1):
- continue
- if not isimportable(parent.basename):
- break
- if parent.join('__init__.py').check():
- pkgpath = parent
- continue
- return pkgpath
- else:
- if parent.basename == pkgname:
- return parent
- return pkgpath
-
- def _prependsyspath(self, path):
- s = str(path)
- if s != sys.path[0]:
- #print "prepending to sys.path", s
- sys.path.insert(0, s)
-
def chmod(self, mode, rec=0):
""" change permissions to the given mode. If mode is an
integer it directly encodes the os-specific modes.
@@ -590,33 +592,61 @@ class LocalPath(FSBase):
py.error.checked_call(os.chmod, str(x), mode)
py.error.checked_call(os.chmod, str(self), mode)
+ def pypkgpath(self):
+ """ return the Python package path by looking for the last
+ directory upwards which still contains an __init__.py.
+ Return None if a pkgpath can not be determined.
+ """
+ pkgpath = None
+ for parent in self.parts(reverse=True):
+ if parent.isdir():
+ if not parent.join('__init__.py').exists():
+ break
+ if not isimportable(parent.basename):
+ break
+ pkgpath = parent
+ return pkgpath
+
+ def _ensuresyspath(self, ensuremode, path):
+ if ensuremode:
+ s = str(path)
+ if ensuremode == "append":
+ if s not in sys.path:
+ sys.path.append(s)
+ else:
+ if s != sys.path[0]:
+ sys.path.insert(0, s)
+
def pyimport(self, modname=None, ensuresyspath=True):
""" return path as an imported python module.
- if modname is None, look for the containing package
- and construct an according module name.
- The module will be put/looked up in sys.modules.
+
+ If modname is None, look for the containing package
+ and construct an according module name.
+ The module will be put/looked up in sys.modules.
+ if ensuresyspath is True then the root dir for importing
+ the file (taking __init__.py files into account) will
+ be prepended to sys.path if it isn't there already.
+ If ensuresyspath=="append" the root dir will be appended
+ if it isn't already contained in sys.path.
+ if ensuresyspath is False no modification of syspath happens.
"""
if not self.check():
raise py.error.ENOENT(self)
- #print "trying to import", self
+
pkgpath = None
if modname is None:
pkgpath = self.pypkgpath()
if pkgpath is not None:
- if ensuresyspath:
- self._prependsyspath(pkgpath.dirpath())
- __import__(pkgpath.basename)
- pkg = sys.modules[pkgpath.basename]
- names = self.new(ext='').relto(pkgpath.dirpath())
- names = names.split(self.sep)
- if names and names[-1] == "__init__":
+ pkgroot = pkgpath.dirpath()
+ names = self.new(ext="").relto(pkgroot).split(self.sep)
+ if names[-1] == "__init__":
names.pop()
modname = ".".join(names)
else:
- # no package scope, still make it possible
- if ensuresyspath:
- self._prependsyspath(self.dirpath())
+ pkgroot = self.dirpath()
modname = self.purebasename
+
+ self._ensuresyspath(ensuresyspath, pkgroot)
__import__(modname)
mod = sys.modules[modname]
if self.basename == "__init__.py":
@@ -630,7 +660,6 @@ class LocalPath(FSBase):
if modfile.endswith(os.path.sep + "__init__.py"):
if self.basename != "__init__.py":
modfile = modfile[:-12]
-
try:
issame = self.samefile(modfile)
except py.error.ENOENT:
@@ -700,9 +729,10 @@ class LocalPath(FSBase):
for path in paths]
else:
paths = py.std.os.environ['PATH'].split(':')
- tryadd = ['']
+ tryadd = []
if iswin32:
tryadd += os.environ['PATHEXT'].split(os.pathsep)
+ tryadd.append("")
for x in paths:
for addext in tryadd:
@@ -876,8 +906,6 @@ def copychunked(src, dest):
fsrc.close()
def isimportable(name):
- if name:
- if not (name[0].isalpha() or name[0] == '_'):
- return False
- name= name.replace("_", '')
+ if name and (name[0].isalpha() or name[0] == '_'):
+ name = name.replace("_", '')
return not name or name.isalnum()
diff --git a/py/_process/forkedfunc.py b/py/_process/forkedfunc.py
index 604412c2e5..1c28530688 100644
--- a/py/_process/forkedfunc.py
+++ b/py/_process/forkedfunc.py
@@ -3,8 +3,6 @@
ForkedFunc provides a way to run a function in a forked process
and get at its return value, stdout and stderr output as well
as signals and exitstatusus.
-
- XXX see if tempdir handling is sane
"""
import py
@@ -12,9 +10,26 @@ import os
import sys
import marshal
-class ForkedFunc(object):
+
+def get_unbuffered_io(fd, filename):
+ f = open(str(filename), "w")
+ if fd != f.fileno():
+ os.dup2(f.fileno(), fd)
+ class AutoFlush:
+ def write(self, data):
+ f.write(data)
+ f.flush()
+ def __getattr__(self, name):
+ return getattr(f, name)
+ return AutoFlush()
+
+
+class ForkedFunc:
EXITSTATUS_EXCEPTION = 3
- def __init__(self, fun, args=None, kwargs=None, nice_level=0):
+
+
+ def __init__(self, fun, args=None, kwargs=None, nice_level=0,
+ child_on_start=None, child_on_exit=None):
if args is None:
args = []
if kwargs is None:
@@ -28,35 +43,32 @@ class ForkedFunc(object):
self.STDERR = tempdir.ensure('stderr')
pid = os.fork()
- if pid: # in parent process
+ if pid: # in parent process
self.pid = pid
- else: # in child process
- self._child(nice_level)
+ else: # in child process
+ self.pid = None
+ self._child(nice_level, child_on_start, child_on_exit)
- def _child(self, nice_level):
+ def _child(self, nice_level, child_on_start, child_on_exit):
# right now we need to call a function, but first we need to
# map all IO that might happen
- # make sure sys.stdout points to file descriptor one
- sys.stdout = stdout = self.STDOUT.open('w')
- sys.stdout.flush()
- fdstdout = stdout.fileno()
- if fdstdout != 1:
- os.dup2(fdstdout, 1)
- sys.stderr = stderr = self.STDERR.open('w')
- fdstderr = stderr.fileno()
- if fdstderr != 2:
- os.dup2(fdstderr, 2)
+ sys.stdout = stdout = get_unbuffered_io(1, self.STDOUT)
+ sys.stderr = stderr = get_unbuffered_io(2, self.STDERR)
retvalf = self.RETVAL.open("wb")
EXITSTATUS = 0
try:
if nice_level:
os.nice(nice_level)
try:
+ if child_on_start is not None:
+ child_on_start()
retval = self.fun(*self.args, **self.kwargs)
retvalf.write(marshal.dumps(retval))
+ if child_on_exit is not None:
+ child_on_exit()
except:
excinfo = py.code.ExceptionInfo()
- stderr.write(excinfo.exconly())
+ stderr.write(str(excinfo._getreprcrash()))
EXITSTATUS = self.EXITSTATUS_EXCEPTION
finally:
stdout.close()
@@ -73,8 +85,6 @@ class ForkedFunc(object):
exitstatus = os.WTERMSIG(systemstatus) + 128
else:
exitstatus = os.WEXITSTATUS(systemstatus)
- #raise ExecutionFailed(status, systemstatus, cmd,
- # ''.join(out), ''.join(err))
else:
exitstatus = 0
signal = systemstatus & 0x7f
@@ -97,7 +107,9 @@ class ForkedFunc(object):
self.tempdir.remove()
def __del__(self):
- self._removetemp()
+ if self.pid is not None: # only clean up in main process
+ self._removetemp()
+
class Result(object):
def __init__(self, exitstatus, signal, retval, stdout, stderr):