aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pypy/doc/whatsnew-head.rst5
-rw-r--r--pypy/objspace/std/bytesobject.py22
-rw-r--r--pypy/objspace/std/test/test_bytesobject.py4
-rw-r--r--pypy/objspace/std/test/test_unicodeobject.py4
-rw-r--r--pypy/objspace/std/unicodeobject.py3
-rw-r--r--rpython/rlib/rstring.py152
-rw-r--r--rpython/rlib/test/test_rstring.py47
-rw-r--r--rpython/rtyper/lltypesystem/rstr.py117
8 files changed, 226 insertions, 128 deletions
diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst
index 26094691b6..4b6c1ff4c9 100644
--- a/pypy/doc/whatsnew-head.rst
+++ b/pypy/doc/whatsnew-head.rst
@@ -89,3 +89,8 @@ Refactor the intbound analysis in the JIT
.. branch: issue-3404
Fix ``PyObject_Format`` for type objects
+
+
+.. branch: string-algorithmic-optimizations
+
+Faster str.replace and bytes.replace implementations.
diff --git a/pypy/objspace/std/bytesobject.py b/pypy/objspace/std/bytesobject.py
index 6315c5d6cf..2316f6e513 100644
--- a/pypy/objspace/std/bytesobject.py
+++ b/pypy/objspace/std/bytesobject.py
@@ -690,15 +690,33 @@ class W_BytesObject(W_AbstractBytesObject):
self_as_unicode._utf8.find(w_sub._utf8) >= 0)
return self._StringMethods_descr_contains(space, w_sub)
- _StringMethods_descr_replace = descr_replace
@unwrap_spec(count=int)
def descr_replace(self, space, w_old, w_new, count=-1):
+ from rpython.rlib.rstring import replace
old_is_unicode = space.isinstance_w(w_old, space.w_unicode)
new_is_unicode = space.isinstance_w(w_new, space.w_unicode)
if old_is_unicode or new_is_unicode:
self_as_uni = unicode_from_encoded_object(space, self, None, None)
return self_as_uni.descr_replace(space, w_old, w_new, count)
- return self._StringMethods_descr_replace(space, w_old, w_new, count)
+
+ # almost copy of StringMethods.descr_replace :-(
+ input = self._value
+
+ sub = self._op_val(space, w_old)
+ by = self._op_val(space, w_new)
+ # the following two lines are for being bug-to-bug compatible
+ # with CPython: see issue #2448
+ if count >= 0 and len(input) == 0:
+ return self._empty()
+ try:
+ res = replace(input, sub, by, count)
+ except OverflowError:
+ raise oefmt(space.w_OverflowError, "replace string is too long")
+ # difference: reuse self if no replacement was done
+ if type(self) is W_BytesObject and res is input:
+ return self
+
+ return self._new(res)
_StringMethods_descr_join = descr_join
def descr_join(self, space, w_list):
diff --git a/pypy/objspace/std/test/test_bytesobject.py b/pypy/objspace/std/test/test_bytesobject.py
index cc15f97d54..2feca7ab5e 100644
--- a/pypy/objspace/std/test/test_bytesobject.py
+++ b/pypy/objspace/std/test/test_bytesobject.py
@@ -342,6 +342,10 @@ class AppTestBytesObject:
assert 'one'.replace(buffer('o'), buffer('n'), 1) == 'nne'
assert 'one'.replace(buffer('o'), buffer('n')) == 'nne'
+ def test_replace_no_occurrence(self):
+ x = b"xyz"
+ assert x.replace(b"a", b"b") is x
+
def test_strip(self):
s = " a b "
assert s.strip() == "a b"
diff --git a/pypy/objspace/std/test/test_unicodeobject.py b/pypy/objspace/std/test/test_unicodeobject.py
index 51faff763d..6b1c7315da 100644
--- a/pypy/objspace/std/test/test_unicodeobject.py
+++ b/pypy/objspace/std/test/test_unicodeobject.py
@@ -1303,3 +1303,7 @@ class AppTestUnicodeString:
def test_newlist_utf8_non_ascii(self):
'ä'.split("\n")[0] # does not crash
+
+ def test_replace_no_occurrence(self):
+ x = u"xyz"
+ assert x.replace(u"a", u"b") is x
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
index 1dcd415912..4fa1a98437 100644
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -880,8 +880,11 @@ class W_UnicodeObject(W_Root):
count, isutf8=True)
except OverflowError:
raise oefmt(space.w_OverflowError, "replace string is too long")
+ if type(self) is W_UnicodeObject and replacements == 0:
+ return self
newlength = self._length + replacements * (w_by._length - w_sub._length)
+ assert res is not None
return W_UnicodeObject(res, newlength)
def descr_mul(self, space, w_times):
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
index 29e1495381..b7bf0b2a16 100644
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -101,9 +101,13 @@ def _split_by(value, by, maxsplit):
start = 0
if bylen == 1:
- # fast path: uses str.rfind(character) and str.count(character)
+ # fast path: uses str.find(character) and str.count(character)
by = by[0] # annotator hack: string -> char
cnt = count(value, by, 0, len(value))
+ if cnt == 0:
+ if isinstance(value, str):
+ return [value]
+ return [value[0:len(value)]]
if 0 <= maxsplit < cnt:
cnt = maxsplit
res = newlist_hint(cnt + 1)
@@ -208,12 +212,12 @@ def _rsplit_by(value, by, maxsplit):
@specialize.argtype(0, 1)
@jit.elidable
-def replace(input, sub, by, maxsplit=-1):
- return replace_count(input, sub, by, maxsplit)[0]
+def replace(input, sub, by, maxcount=-1):
+ return replace_count(input, sub, by, maxcount)[0]
@specialize.ll_and_arg(4)
@jit.elidable
-def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
+def replace_count(input, sub, by, maxcount=-1, isutf8=False):
if isinstance(input, str):
Builder = StringBuilder
elif isinstance(input, unicode):
@@ -221,14 +225,14 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
else:
assert isinstance(input, list)
Builder = ByteListBuilder
- if maxsplit == 0:
+ if maxcount == 0:
return input, 0
if not sub and not isutf8:
upper = len(input)
- if maxsplit > 0 and maxsplit < upper + 2:
- upper = maxsplit - 1
+ if maxcount > 0 and maxcount < upper + 2:
+ upper = maxcount - 1
assert upper >= 0
try:
@@ -246,17 +250,27 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
builder.append(by)
builder.append_slice(input, upper, len(input))
replacements = upper + 1
+
+ elif isinstance(input, str) and len(sub) == 1:
+ if len(by) == 1:
+ return replace_count_str_chr_chr(input, sub[0], by[0], maxcount)
+ return replace_count_str_chr_str(input, sub[0], by, maxcount)
+
else:
# First compute the exact result size
if sub:
cnt = count(input, sub, 0, len(input))
+ if isinstance(input, str) and cnt == 0:
+ return input, 0
+ if isinstance(input, str):
+ return replace_count_str_str_str(input, sub, by, cnt, maxcount)
else:
assert isutf8
from rpython.rlib import rutf8
cnt = rutf8.codepoints_in_utf8(input) + 1
- if cnt > maxsplit and maxsplit > 0:
- cnt = maxsplit
+ if cnt > maxcount and maxcount > 0:
+ cnt = maxcount
diff_len = len(by) - len(sub)
try:
result_size = ovfcheck(diff_len * cnt)
@@ -274,26 +288,122 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
from rpython.rlib import rutf8
while True:
builder.append(by)
- maxsplit -= 1
- if start == len(input) or maxsplit == 0:
+ maxcount -= 1
+ if start == len(input) or maxcount == 0:
break
next = rutf8.next_codepoint_pos(input, start)
builder.append_slice(input, start, next)
start = next
else:
- while maxsplit != 0:
+ while maxcount != 0:
next = find(input, sub, start, len(input))
if next < 0:
break
builder.append_slice(input, start, next)
builder.append(by)
start = next + sublen
- maxsplit -= 1 # NB. if it's already < 0, it stays < 0
+ maxcount -= 1 # NB. if it's already < 0, it stays < 0
builder.append_slice(input, start, len(input))
return builder.build(), replacements
+def replace_count_str_chr_chr(input, c1, c2, maxcount):
+ from rpython.rtyper.annlowlevel import llstr, hlstr
+ s = llstr(input)
+ length = len(s.chars)
+ start = find(input, c1, 0, len(input))
+ if start < 0:
+ return input, 0
+ newstr = s.malloc(length)
+ src = s.chars
+ dst = newstr.chars
+ s.copy_contents(s, newstr, 0, 0, len(input))
+ dst[start] = c2
+ count = 1
+ start += 1
+ maxcount -= 1
+ while maxcount != 0:
+ next = find(input, c1, start, len(input))
+ if next < 0:
+ break
+ dst[next] = c2
+ start = next + 1
+ maxcount -= 1
+ count += 1
+
+ return hlstr(newstr), count
+
+def replace_count_str_chr_str(input, sub, by, maxcount):
+ from rpython.rtyper.annlowlevel import llstr, hlstr
+ cnt = count(input, sub, 0, len(input))
+ if cnt == 0:
+ return input, 0
+ if maxcount > 0 and cnt > maxcount:
+ cnt = maxcount
+ diff_len = len(by) - 1
+ try:
+ result_size = ovfcheck(diff_len * cnt)
+ result_size = ovfcheck(result_size + len(input))
+ except OverflowError:
+ raise
+
+ s = llstr(input)
+ by_ll = llstr(by)
+
+ newstr = s.malloc(result_size)
+ dst = 0
+ start = 0
+ while maxcount != 0:
+ next = find(input, sub, start, len(input))
+ if next < 0:
+ break
+ s.copy_contents(s, newstr, start, dst, next - start)
+ dst += next - start
+ s.copy_contents(by_ll, newstr, 0, dst, len(by))
+ dst += len(by)
+
+ start = next + 1
+ maxcount -= 1 # NB. if it's already < 0, it stays < 0
+
+ s.copy_contents(s, newstr, start, dst, len(input) - start)
+ assert dst - start + len(input) == result_size
+ return hlstr(newstr), cnt
+
+def replace_count_str_str_str(input, sub, by, cnt, maxcount):
+ from rpython.rtyper.annlowlevel import llstr, hlstr
+ if cnt > maxcount and maxcount > 0:
+ cnt = maxcount
+ diff_len = len(by) - len(sub)
+ try:
+ result_size = ovfcheck(diff_len * cnt)
+ result_size = ovfcheck(result_size + len(input))
+ except OverflowError:
+ raise
+
+ s = llstr(input)
+ by_ll = llstr(by)
+ newstr = s.malloc(result_size)
+ sublen = len(sub)
+ bylen = len(by)
+ inputlen = len(input)
+ dst = 0
+ start = 0
+ while maxcount != 0:
+ next = find(input, sub, start, inputlen)
+ if next < 0:
+ break
+ s.copy_contents(s, newstr, start, dst, next - start)
+ dst += next - start
+ s.copy_contents(by_ll, newstr, 0, dst, bylen)
+ dst += bylen
+ start = next + sublen
+ maxcount -= 1 # NB. if it's already < 0, it stays < 0
+ s.copy_contents(s, newstr, start, dst, len(input) - start)
+ assert dst - start + len(input) == result_size
+ return hlstr(newstr), cnt
+
+
def _normalize_start_end(length, start, end):
if start < 0:
start += length
@@ -355,20 +465,26 @@ def count(value, other, start, end):
return _search(value, other, start, end, SEARCH_COUNT)
# -------------- substring searching helper ----------------
-# XXX a lot of code duplication with lltypesystem.rstr :-(
SEARCH_COUNT = 0
SEARCH_FIND = 1
SEARCH_RFIND = 2
+@specialize.ll()
def bloom_add(mask, c):
return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+@specialize.ll()
def bloom(mask, c):
return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
@specialize.argtype(0, 1)
def _search(value, other, start, end, mode):
+ assert value is not None
+ if isinstance(value, unicode):
+ NUL = u'\0'
+ else:
+ NUL = '\0'
if start < 0:
start = 0
if end > len(value):
@@ -398,7 +514,7 @@ def _search(value, other, start, end, mode):
return -1
mlast = m - 1
- skip = mlast - 1
+ skip = mlast
mask = 0
if mode != SEARCH_RFIND:
@@ -411,7 +527,7 @@ def _search(value, other, start, end, mode):
i = start - 1
while i + 1 <= start + w:
i += 1
- if value[i + m - 1] == other[m - 1]:
+ if value[i + mlast] == other[mlast]:
for j in range(mlast):
if value[i + j] != other[j]:
break
@@ -425,7 +541,7 @@ def _search(value, other, start, end, mode):
if i + m < len(value):
c = value[i + m]
else:
- c = '\0'
+ c = NUL
if not bloom(mask, c):
i += m
else:
@@ -434,7 +550,7 @@ def _search(value, other, start, end, mode):
if i + m < len(value):
c = value[i + m]
else:
- c = '\0'
+ c = NUL
if not bloom(mask, c):
i += m
else:
diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py
index 18b5103e54..5fda0275e2 100644
--- a/rpython/rlib/test/test_rstring.py
+++ b/rpython/rlib/test/test_rstring.py
@@ -2,10 +2,12 @@ import sys, py
from rpython.rlib.rstring import StringBuilder, UnicodeBuilder, split, rsplit
from rpython.rlib.rstring import replace, startswith, endswith, replace_count
-from rpython.rlib.rstring import find, rfind, count
+from rpython.rlib.rstring import find, rfind, count, _search, SEARCH_COUNT, SEARCH_FIND
from rpython.rlib.buffer import StringBuffer
from rpython.rtyper.test.tool import BaseRtypingTest
+from hypothesis import given, strategies as st, assume
+
def test_split():
def check_split(value, sub, *args, **kwargs):
result = kwargs['res']
@@ -27,6 +29,11 @@ def test_split():
check_split('endcase test', 'test', res=['endcase ', ''])
py.test.raises(ValueError, split, 'abc', '')
+def test_split_no_occurrence():
+ x = "abc"
+ assert x.split("d")[0] is x
+ assert x.rsplit("d")[0] is x
+
def test_split_None():
assert split("") == []
assert split(' a\ta\na b') == ['a', 'a', 'a', 'b']
@@ -164,6 +171,12 @@ def test_unicode_replace_overflow():
with py.test.raises(OverflowError):
replace(s, u"a", s, len(s) - 10)
+def test_replace_no_occurrence():
+ s = "xyz"
+ assert replace(s, "a", "b") is s
+ s = "xyz"
+ assert replace(s, "abc", "b") is s
+
def test_startswith():
def check_startswith(value, sub, *args, **kwargs):
result = kwargs['res']
@@ -240,6 +253,8 @@ def test_search():
check_search(find, 'one two three', 'ne', 5, 13, res=-1)
check_search(find, 'one two three', '', 0, 13, res=0)
+ check_search(find, '000000p00000000', 'ap', 0, 15, res=-1)
+
check_search(rfind, 'one two three', 'e', 0, 13, res=12)
check_search(rfind, 'one two three', 'e', 0, 1, res=-1)
check_search(rfind, 'one two three', '', 0, 13, res=13)
@@ -293,3 +308,33 @@ class TestTranslates(BaseRtypingTest):
return res
res = self.interpret(fn, [])
assert res
+
+@given(u=st.text(), prefix=st.text(), suffix=st.text())
+def test_hypothesis_search(u, prefix, suffix):
+ prefix = prefix.encode("utf-8")
+ u = u.encode("utf-8")
+ suffix = suffix.encode("utf-8")
+ s = prefix + u + suffix
+
+ index = _search(s, u, 0, len(s), SEARCH_FIND)
+ assert index == s.find(u)
+ assert 0 <= index <= len(prefix)
+
+ index = _search(s, u, len(prefix), len(s) - len(suffix), SEARCH_FIND)
+ assert index == len(prefix)
+
+ count = _search(s, u, 0, len(s), SEARCH_COUNT)
+ assert count == s.count(u)
+ assert 1 <= count
+
+
+@given(st.text(), st.lists(st.text(), min_size=2), st.text(), st.integers(min_value=0, max_value=1000000))
+def test_hypothesis_search(needle, pieces, by, maxcount):
+ needle = needle.encode("utf-8")
+ pieces = [piece.encode("utf-8") for piece in pieces]
+ by = by.encode("utf-8")
+ input = needle.join(pieces)
+ assume(len(input) > 0)
+
+ res = replace(input, needle, by, maxcount)
+ assert res == input.replace(needle, by, maxcount)
diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py
index 0ae4c2f459..72c44ba96c 100644
--- a/rpython/rtyper/lltypesystem/rstr.py
+++ b/rpython/rtyper/lltypesystem/rstr.py
@@ -303,21 +303,6 @@ class UniCharRepr(AbstractUniCharRepr, UnicodeRepr):
# get flowed and annotated, mostly with SomePtr.
#
-FAST_COUNT = 0
-FAST_FIND = 1
-FAST_RFIND = 2
-
-
-from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH
-
-
-def bloom_add(mask, c):
- return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
-
-def bloom(mask, c):
- return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
-
class LLHelpers(AbstractLLHelpers):
from rpython.rtyper.annlowlevel import llstr, llunicode
@@ -720,6 +705,7 @@ class LLHelpers(AbstractLLHelpers):
@staticmethod
@signature(types.any(), types.any(), types.int(), types.int(), returns=types.int())
def ll_find(s1, s2, start, end):
+ from rpython.rlib.rstring import SEARCH_FIND
if start < 0:
start = 0
if end > len(s1.chars):
@@ -731,11 +717,12 @@ class LLHelpers(AbstractLLHelpers):
if m == 1:
return LLHelpers.ll_find_char(s1, s2.chars[0], start, end)
- return LLHelpers.ll_search(s1, s2, start, end, FAST_FIND)
+ return LLHelpers.ll_search(s1, s2, start, end, SEARCH_FIND)
@staticmethod
@signature(types.any(), types.any(), types.int(), types.int(), returns=types.int())
def ll_rfind(s1, s2, start, end):
+ from rpython.rlib.rstring import SEARCH_RFIND
if start < 0:
start = 0
if end > len(s1.chars):
@@ -747,10 +734,11 @@ class LLHelpers(AbstractLLHelpers):
if m == 1:
return LLHelpers.ll_rfind_char(s1, s2.chars[0], start, end)
- return LLHelpers.ll_search(s1, s2, start, end, FAST_RFIND)
+ return LLHelpers.ll_search(s1, s2, start, end, SEARCH_RFIND)
@classmethod
def ll_count(cls, s1, s2, start, end):
+ from rpython.rlib.rstring import SEARCH_COUNT
if start < 0:
start = 0
if end > len(s1.chars):
@@ -762,104 +750,19 @@ class LLHelpers(AbstractLLHelpers):
if m == 1:
return cls.ll_count_char(s1, s2.chars[0], start, end)
- res = cls.ll_search(s1, s2, start, end, FAST_COUNT)
+ res = cls.ll_search(s1, s2, start, end, SEARCH_COUNT)
assert res >= 0
return res
@staticmethod
- @jit.elidable
def ll_search(s1, s2, start, end, mode):
- count = 0
- n = end - start
- m = len(s2.chars)
+ from rpython.rtyper.annlowlevel import hlstr, hlunicode
+ from rpython.rlib import rstring
tp = typeOf(s1)
if tp == string_repr.lowleveltype or tp == Char:
- NUL = '\0'
+ return rstring._search(hlstr(s1), hlstr(s2), start, end, mode)
else:
- NUL = u'\0'
-
- if m == 0:
- if mode == FAST_COUNT:
- return end - start + 1
- elif mode == FAST_RFIND:
- return end
- else:
- return start
-
- w = n - m
-
- if w < 0:
- if mode == FAST_COUNT:
- return 0
- return -1
-
- mlast = m - 1
- skip = mlast - 1
- mask = 0
-
- if mode != FAST_RFIND:
- for i in range(mlast):
- mask = bloom_add(mask, s2.chars[i])
- if s2.chars[i] == s2.chars[mlast]:
- skip = mlast - i - 1
- mask = bloom_add(mask, s2.chars[mlast])
-
- i = start - 1
- while i + 1 <= start + w:
- i += 1
- if s1.chars[i + m - 1] == s2.chars[m - 1]:
- for j in range(mlast):
- if s1.chars[i + j] != s2.chars[j]:
- break
- else:
- if mode != FAST_COUNT:
- return i
- count += 1
- i += mlast
- continue
-
- if i + m < len(s1.chars):
- c = s1.chars[i + m]
- else:
- c = NUL
- if not bloom(mask, c):
- i += m
- else:
- i += skip
- else:
- if i + m < len(s1.chars):
- c = s1.chars[i + m]
- else:
- c = NUL
- if not bloom(mask, c):
- i += m
- else:
- mask = bloom_add(mask, s2.chars[0])
- for i in range(mlast, 0, -1):
- mask = bloom_add(mask, s2.chars[i])
- if s2.chars[i] == s2.chars[0]:
- skip = i - 1
-
- i = start + w + 1
- while i - 1 >= start:
- i -= 1
- if s1.chars[i] == s2.chars[0]:
- for j in xrange(mlast, 0, -1):
- if s1.chars[i + j] != s2.chars[j]:
- break
- else:
- return i
- if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
- i -= m
- else:
- i -= skip
- else:
- if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]):
- i -= m
-
- if mode != FAST_COUNT:
- return -1
- return count
+ return rstring._search(hlunicode(s1), hlunicode(s2), start, end, mode)
@staticmethod
@signature(types.int(), types.any(), returns=types.any())