diff options
author | Carl Friedrich Bolz-Tereick <cfbolz@gmx.de> | 2021-02-28 14:44:08 +0100 |
---|---|---|
committer | Carl Friedrich Bolz-Tereick <cfbolz@gmx.de> | 2021-02-28 14:44:08 +0100 |
commit | efcf3c89c4a6e7d5481f9ef65dbd8e5b0c734d41 (patch) | |
tree | b0ef265d5cbbaf0a2b0c2ce9b5991d50765263cf | |
parent | Copy dummy constants from greenlet 1.0.0 (diff) | |
parent | add whatsnew (diff) | |
download | pypy-efcf3c89c4a6e7d5481f9ef65dbd8e5b0c734d41.tar.gz pypy-efcf3c89c4a6e7d5481f9ef65dbd8e5b0c734d41.tar.bz2 pypy-efcf3c89c4a6e7d5481f9ef65dbd8e5b0c734d41.zip |
merge string-algorithmic-optimizations
-rw-r--r-- | pypy/doc/whatsnew-head.rst | 5 | ||||
-rw-r--r-- | pypy/objspace/std/bytesobject.py | 22 | ||||
-rw-r--r-- | pypy/objspace/std/test/test_bytesobject.py | 4 | ||||
-rw-r--r-- | pypy/objspace/std/test/test_unicodeobject.py | 4 | ||||
-rw-r--r-- | pypy/objspace/std/unicodeobject.py | 3 | ||||
-rw-r--r-- | rpython/rlib/rstring.py | 152 | ||||
-rw-r--r-- | rpython/rlib/test/test_rstring.py | 47 | ||||
-rw-r--r-- | rpython/rtyper/lltypesystem/rstr.py | 117 |
8 files changed, 226 insertions, 128 deletions
diff --git a/pypy/doc/whatsnew-head.rst b/pypy/doc/whatsnew-head.rst index 26094691b6..4b6c1ff4c9 100644 --- a/pypy/doc/whatsnew-head.rst +++ b/pypy/doc/whatsnew-head.rst @@ -89,3 +89,8 @@ Refactor the intbound analysis in the JIT .. branch: issue-3404 Fix ``PyObject_Format`` for type objects + + +.. branch: string-algorithmic-optimizations + +Faster str.replace and bytes.replace implementations. diff --git a/pypy/objspace/std/bytesobject.py b/pypy/objspace/std/bytesobject.py index 6315c5d6cf..2316f6e513 100644 --- a/pypy/objspace/std/bytesobject.py +++ b/pypy/objspace/std/bytesobject.py @@ -690,15 +690,33 @@ class W_BytesObject(W_AbstractBytesObject): self_as_unicode._utf8.find(w_sub._utf8) >= 0) return self._StringMethods_descr_contains(space, w_sub) - _StringMethods_descr_replace = descr_replace @unwrap_spec(count=int) def descr_replace(self, space, w_old, w_new, count=-1): + from rpython.rlib.rstring import replace old_is_unicode = space.isinstance_w(w_old, space.w_unicode) new_is_unicode = space.isinstance_w(w_new, space.w_unicode) if old_is_unicode or new_is_unicode: self_as_uni = unicode_from_encoded_object(space, self, None, None) return self_as_uni.descr_replace(space, w_old, w_new, count) - return self._StringMethods_descr_replace(space, w_old, w_new, count) + + # almost copy of StringMethods.descr_replace :-( + input = self._value + + sub = self._op_val(space, w_old) + by = self._op_val(space, w_new) + # the following two lines are for being bug-to-bug compatible + # with CPython: see issue #2448 + if count >= 0 and len(input) == 0: + return self._empty() + try: + res = replace(input, sub, by, count) + except OverflowError: + raise oefmt(space.w_OverflowError, "replace string is too long") + # difference: reuse self if no replacement was done + if type(self) is W_BytesObject and res is input: + return self + + return self._new(res) _StringMethods_descr_join = descr_join def descr_join(self, space, w_list): diff --git a/pypy/objspace/std/test/test_bytesobject.py b/pypy/objspace/std/test/test_bytesobject.py index cc15f97d54..2feca7ab5e 100644 --- a/pypy/objspace/std/test/test_bytesobject.py +++ b/pypy/objspace/std/test/test_bytesobject.py @@ -342,6 +342,10 @@ class AppTestBytesObject: assert 'one'.replace(buffer('o'), buffer('n'), 1) == 'nne' assert 'one'.replace(buffer('o'), buffer('n')) == 'nne' + def test_replace_no_occurrence(self): + x = b"xyz" + assert x.replace(b"a", b"b") is x + def test_strip(self): s = " a b " assert s.strip() == "a b" diff --git a/pypy/objspace/std/test/test_unicodeobject.py b/pypy/objspace/std/test/test_unicodeobject.py index 51faff763d..6b1c7315da 100644 --- a/pypy/objspace/std/test/test_unicodeobject.py +++ b/pypy/objspace/std/test/test_unicodeobject.py @@ -1303,3 +1303,7 @@ class AppTestUnicodeString: def test_newlist_utf8_non_ascii(self): 'ä'.split("\n")[0] # does not crash + + def test_replace_no_occurrence(self): + x = u"xyz" + assert x.replace(u"a", u"b") is x diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py index 1dcd415912..4fa1a98437 100644 --- a/pypy/objspace/std/unicodeobject.py +++ b/pypy/objspace/std/unicodeobject.py @@ -880,8 +880,11 @@ class W_UnicodeObject(W_Root): count, isutf8=True) except OverflowError: raise oefmt(space.w_OverflowError, "replace string is too long") + if type(self) is W_UnicodeObject and replacements == 0: + return self newlength = self._length + replacements * (w_by._length - w_sub._length) + assert res is not None return W_UnicodeObject(res, newlength) def descr_mul(self, space, w_times): diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py index 29e1495381..b7bf0b2a16 100644 --- a/rpython/rlib/rstring.py +++ b/rpython/rlib/rstring.py @@ -101,9 +101,13 @@ def _split_by(value, by, maxsplit): start = 0 if bylen == 1: - # fast path: uses str.rfind(character) and str.count(character) + # fast path: uses str.find(character) and str.count(character) by = by[0] # annotator hack: string -> char cnt = count(value, by, 0, len(value)) + if cnt == 0: + if isinstance(value, str): + return [value] + return [value[0:len(value)]] if 0 <= maxsplit < cnt: cnt = maxsplit res = newlist_hint(cnt + 1) @@ -208,12 +212,12 @@ def _rsplit_by(value, by, maxsplit): @specialize.argtype(0, 1) @jit.elidable -def replace(input, sub, by, maxsplit=-1): - return replace_count(input, sub, by, maxsplit)[0] +def replace(input, sub, by, maxcount=-1): + return replace_count(input, sub, by, maxcount)[0] @specialize.ll_and_arg(4) @jit.elidable -def replace_count(input, sub, by, maxsplit=-1, isutf8=False): +def replace_count(input, sub, by, maxcount=-1, isutf8=False): if isinstance(input, str): Builder = StringBuilder elif isinstance(input, unicode): @@ -221,14 +225,14 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): else: assert isinstance(input, list) Builder = ByteListBuilder - if maxsplit == 0: + if maxcount == 0: return input, 0 if not sub and not isutf8: upper = len(input) - if maxsplit > 0 and maxsplit < upper + 2: - upper = maxsplit - 1 + if maxcount > 0 and maxcount < upper + 2: + upper = maxcount - 1 assert upper >= 0 try: @@ -246,17 +250,27 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): builder.append(by) builder.append_slice(input, upper, len(input)) replacements = upper + 1 + + elif isinstance(input, str) and len(sub) == 1: + if len(by) == 1: + return replace_count_str_chr_chr(input, sub[0], by[0], maxcount) + return replace_count_str_chr_str(input, sub[0], by, maxcount) + else: # First compute the exact result size if sub: cnt = count(input, sub, 0, len(input)) + if isinstance(input, str) and cnt == 0: + return input, 0 + if isinstance(input, str): + return replace_count_str_str_str(input, sub, by, cnt, maxcount) else: assert isutf8 from rpython.rlib import rutf8 cnt = rutf8.codepoints_in_utf8(input) + 1 - if cnt > maxsplit and maxsplit > 0: - cnt = maxsplit + if cnt > maxcount and maxcount > 0: + cnt = maxcount diff_len = len(by) - len(sub) try: result_size = ovfcheck(diff_len * cnt) @@ -274,26 +288,122 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): from rpython.rlib import rutf8 while True: builder.append(by) - maxsplit -= 1 - if start == len(input) or maxsplit == 0: + maxcount -= 1 + if start == len(input) or maxcount == 0: break next = rutf8.next_codepoint_pos(input, start) builder.append_slice(input, start, next) start = next else: - while maxsplit != 0: + while maxcount != 0: next = find(input, sub, start, len(input)) if next < 0: break builder.append_slice(input, start, next) builder.append(by) start = next + sublen - maxsplit -= 1 # NB. if it's already < 0, it stays < 0 + maxcount -= 1 # NB. if it's already < 0, it stays < 0 builder.append_slice(input, start, len(input)) return builder.build(), replacements +def replace_count_str_chr_chr(input, c1, c2, maxcount): + from rpython.rtyper.annlowlevel import llstr, hlstr + s = llstr(input) + length = len(s.chars) + start = find(input, c1, 0, len(input)) + if start < 0: + return input, 0 + newstr = s.malloc(length) + src = s.chars + dst = newstr.chars + s.copy_contents(s, newstr, 0, 0, len(input)) + dst[start] = c2 + count = 1 + start += 1 + maxcount -= 1 + while maxcount != 0: + next = find(input, c1, start, len(input)) + if next < 0: + break + dst[next] = c2 + start = next + 1 + maxcount -= 1 + count += 1 + + return hlstr(newstr), count + +def replace_count_str_chr_str(input, sub, by, maxcount): + from rpython.rtyper.annlowlevel import llstr, hlstr + cnt = count(input, sub, 0, len(input)) + if cnt == 0: + return input, 0 + if maxcount > 0 and cnt > maxcount: + cnt = maxcount + diff_len = len(by) - 1 + try: + result_size = ovfcheck(diff_len * cnt) + result_size = ovfcheck(result_size + len(input)) + except OverflowError: + raise + + s = llstr(input) + by_ll = llstr(by) + + newstr = s.malloc(result_size) + dst = 0 + start = 0 + while maxcount != 0: + next = find(input, sub, start, len(input)) + if next < 0: + break + s.copy_contents(s, newstr, start, dst, next - start) + dst += next - start + s.copy_contents(by_ll, newstr, 0, dst, len(by)) + dst += len(by) + + start = next + 1 + maxcount -= 1 # NB. if it's already < 0, it stays < 0 + + s.copy_contents(s, newstr, start, dst, len(input) - start) + assert dst - start + len(input) == result_size + return hlstr(newstr), cnt + +def replace_count_str_str_str(input, sub, by, cnt, maxcount): + from rpython.rtyper.annlowlevel import llstr, hlstr + if cnt > maxcount and maxcount > 0: + cnt = maxcount + diff_len = len(by) - len(sub) + try: + result_size = ovfcheck(diff_len * cnt) + result_size = ovfcheck(result_size + len(input)) + except OverflowError: + raise + + s = llstr(input) + by_ll = llstr(by) + newstr = s.malloc(result_size) + sublen = len(sub) + bylen = len(by) + inputlen = len(input) + dst = 0 + start = 0 + while maxcount != 0: + next = find(input, sub, start, inputlen) + if next < 0: + break + s.copy_contents(s, newstr, start, dst, next - start) + dst += next - start + s.copy_contents(by_ll, newstr, 0, dst, bylen) + dst += bylen + start = next + sublen + maxcount -= 1 # NB. if it's already < 0, it stays < 0 + s.copy_contents(s, newstr, start, dst, len(input) - start) + assert dst - start + len(input) == result_size + return hlstr(newstr), cnt + + def _normalize_start_end(length, start, end): if start < 0: start += length @@ -355,20 +465,26 @@ def count(value, other, start, end): return _search(value, other, start, end, SEARCH_COUNT) # -------------- substring searching helper ---------------- -# XXX a lot of code duplication with lltypesystem.rstr :-( SEARCH_COUNT = 0 SEARCH_FIND = 1 SEARCH_RFIND = 2 +@specialize.ll() def bloom_add(mask, c): return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1))) +@specialize.ll() def bloom(mask, c): return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1))) @specialize.argtype(0, 1) def _search(value, other, start, end, mode): + assert value is not None + if isinstance(value, unicode): + NUL = u'\0' + else: + NUL = '\0' if start < 0: start = 0 if end > len(value): @@ -398,7 +514,7 @@ def _search(value, other, start, end, mode): return -1 mlast = m - 1 - skip = mlast - 1 + skip = mlast mask = 0 if mode != SEARCH_RFIND: @@ -411,7 +527,7 @@ def _search(value, other, start, end, mode): i = start - 1 while i + 1 <= start + w: i += 1 - if value[i + m - 1] == other[m - 1]: + if value[i + mlast] == other[mlast]: for j in range(mlast): if value[i + j] != other[j]: break @@ -425,7 +541,7 @@ def _search(value, other, start, end, mode): if i + m < len(value): c = value[i + m] else: - c = '\0' + c = NUL if not bloom(mask, c): i += m else: @@ -434,7 +550,7 @@ def _search(value, other, start, end, mode): if i + m < len(value): c = value[i + m] else: - c = '\0' + c = NUL if not bloom(mask, c): i += m else: diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py index 18b5103e54..5fda0275e2 100644 --- a/rpython/rlib/test/test_rstring.py +++ b/rpython/rlib/test/test_rstring.py @@ -2,10 +2,12 @@ import sys, py from rpython.rlib.rstring import StringBuilder, UnicodeBuilder, split, rsplit from rpython.rlib.rstring import replace, startswith, endswith, replace_count -from rpython.rlib.rstring import find, rfind, count +from rpython.rlib.rstring import find, rfind, count, _search, SEARCH_COUNT, SEARCH_FIND from rpython.rlib.buffer import StringBuffer from rpython.rtyper.test.tool import BaseRtypingTest +from hypothesis import given, strategies as st, assume + def test_split(): def check_split(value, sub, *args, **kwargs): result = kwargs['res'] @@ -27,6 +29,11 @@ def test_split(): check_split('endcase test', 'test', res=['endcase ', '']) py.test.raises(ValueError, split, 'abc', '') +def test_split_no_occurrence(): + x = "abc" + assert x.split("d")[0] is x + assert x.rsplit("d")[0] is x + def test_split_None(): assert split("") == [] assert split(' a\ta\na b') == ['a', 'a', 'a', 'b'] @@ -164,6 +171,12 @@ def test_unicode_replace_overflow(): with py.test.raises(OverflowError): replace(s, u"a", s, len(s) - 10) +def test_replace_no_occurrence(): + s = "xyz" + assert replace(s, "a", "b") is s + s = "xyz" + assert replace(s, "abc", "b") is s + def test_startswith(): def check_startswith(value, sub, *args, **kwargs): result = kwargs['res'] @@ -240,6 +253,8 @@ def test_search(): check_search(find, 'one two three', 'ne', 5, 13, res=-1) check_search(find, 'one two three', '', 0, 13, res=0) + check_search(find, '000000p00000000', 'ap', 0, 15, res=-1) + check_search(rfind, 'one two three', 'e', 0, 13, res=12) check_search(rfind, 'one two three', 'e', 0, 1, res=-1) check_search(rfind, 'one two three', '', 0, 13, res=13) @@ -293,3 +308,33 @@ class TestTranslates(BaseRtypingTest): return res res = self.interpret(fn, []) assert res + +@given(u=st.text(), prefix=st.text(), suffix=st.text()) +def test_hypothesis_search(u, prefix, suffix): + prefix = prefix.encode("utf-8") + u = u.encode("utf-8") + suffix = suffix.encode("utf-8") + s = prefix + u + suffix + + index = _search(s, u, 0, len(s), SEARCH_FIND) + assert index == s.find(u) + assert 0 <= index <= len(prefix) + + index = _search(s, u, len(prefix), len(s) - len(suffix), SEARCH_FIND) + assert index == len(prefix) + + count = _search(s, u, 0, len(s), SEARCH_COUNT) + assert count == s.count(u) + assert 1 <= count + + +@given(st.text(), st.lists(st.text(), min_size=2), st.text(), st.integers(min_value=0, max_value=1000000)) +def test_hypothesis_search(needle, pieces, by, maxcount): + needle = needle.encode("utf-8") + pieces = [piece.encode("utf-8") for piece in pieces] + by = by.encode("utf-8") + input = needle.join(pieces) + assume(len(input) > 0) + + res = replace(input, needle, by, maxcount) + assert res == input.replace(needle, by, maxcount) diff --git a/rpython/rtyper/lltypesystem/rstr.py b/rpython/rtyper/lltypesystem/rstr.py index 0ae4c2f459..72c44ba96c 100644 --- a/rpython/rtyper/lltypesystem/rstr.py +++ b/rpython/rtyper/lltypesystem/rstr.py @@ -303,21 +303,6 @@ class UniCharRepr(AbstractUniCharRepr, UnicodeRepr): # get flowed and annotated, mostly with SomePtr. # -FAST_COUNT = 0 -FAST_FIND = 1 -FAST_RFIND = 2 - - -from rpython.rlib.rarithmetic import LONG_BIT as BLOOM_WIDTH - - -def bloom_add(mask, c): - return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1))) - - -def bloom(mask, c): - return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1))) - class LLHelpers(AbstractLLHelpers): from rpython.rtyper.annlowlevel import llstr, llunicode @@ -720,6 +705,7 @@ class LLHelpers(AbstractLLHelpers): @staticmethod @signature(types.any(), types.any(), types.int(), types.int(), returns=types.int()) def ll_find(s1, s2, start, end): + from rpython.rlib.rstring import SEARCH_FIND if start < 0: start = 0 if end > len(s1.chars): @@ -731,11 +717,12 @@ class LLHelpers(AbstractLLHelpers): if m == 1: return LLHelpers.ll_find_char(s1, s2.chars[0], start, end) - return LLHelpers.ll_search(s1, s2, start, end, FAST_FIND) + return LLHelpers.ll_search(s1, s2, start, end, SEARCH_FIND) @staticmethod @signature(types.any(), types.any(), types.int(), types.int(), returns=types.int()) def ll_rfind(s1, s2, start, end): + from rpython.rlib.rstring import SEARCH_RFIND if start < 0: start = 0 if end > len(s1.chars): @@ -747,10 +734,11 @@ class LLHelpers(AbstractLLHelpers): if m == 1: return LLHelpers.ll_rfind_char(s1, s2.chars[0], start, end) - return LLHelpers.ll_search(s1, s2, start, end, FAST_RFIND) + return LLHelpers.ll_search(s1, s2, start, end, SEARCH_RFIND) @classmethod def ll_count(cls, s1, s2, start, end): + from rpython.rlib.rstring import SEARCH_COUNT if start < 0: start = 0 if end > len(s1.chars): @@ -762,104 +750,19 @@ class LLHelpers(AbstractLLHelpers): if m == 1: return cls.ll_count_char(s1, s2.chars[0], start, end) - res = cls.ll_search(s1, s2, start, end, FAST_COUNT) + res = cls.ll_search(s1, s2, start, end, SEARCH_COUNT) assert res >= 0 return res @staticmethod - @jit.elidable def ll_search(s1, s2, start, end, mode): - count = 0 - n = end - start - m = len(s2.chars) + from rpython.rtyper.annlowlevel import hlstr, hlunicode + from rpython.rlib import rstring tp = typeOf(s1) if tp == string_repr.lowleveltype or tp == Char: - NUL = '\0' + return rstring._search(hlstr(s1), hlstr(s2), start, end, mode) else: - NUL = u'\0' - - if m == 0: - if mode == FAST_COUNT: - return end - start + 1 - elif mode == FAST_RFIND: - return end - else: - return start - - w = n - m - - if w < 0: - if mode == FAST_COUNT: - return 0 - return -1 - - mlast = m - 1 - skip = mlast - 1 - mask = 0 - - if mode != FAST_RFIND: - for i in range(mlast): - mask = bloom_add(mask, s2.chars[i]) - if s2.chars[i] == s2.chars[mlast]: - skip = mlast - i - 1 - mask = bloom_add(mask, s2.chars[mlast]) - - i = start - 1 - while i + 1 <= start + w: - i += 1 - if s1.chars[i + m - 1] == s2.chars[m - 1]: - for j in range(mlast): - if s1.chars[i + j] != s2.chars[j]: - break - else: - if mode != FAST_COUNT: - return i - count += 1 - i += mlast - continue - - if i + m < len(s1.chars): - c = s1.chars[i + m] - else: - c = NUL - if not bloom(mask, c): - i += m - else: - i += skip - else: - if i + m < len(s1.chars): - c = s1.chars[i + m] - else: - c = NUL - if not bloom(mask, c): - i += m - else: - mask = bloom_add(mask, s2.chars[0]) - for i in range(mlast, 0, -1): - mask = bloom_add(mask, s2.chars[i]) - if s2.chars[i] == s2.chars[0]: - skip = i - 1 - - i = start + w + 1 - while i - 1 >= start: - i -= 1 - if s1.chars[i] == s2.chars[0]: - for j in xrange(mlast, 0, -1): - if s1.chars[i + j] != s2.chars[j]: - break - else: - return i - if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]): - i -= m - else: - i -= skip - else: - if i - 1 >= 0 and not bloom(mask, s1.chars[i - 1]): - i -= m - - if mode != FAST_COUNT: - return -1 - return count + return rstring._search(hlunicode(s1), hlunicode(s2), start, end, mode) @staticmethod @signature(types.int(), types.any(), returns=types.any()) |