aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'rpython/rlib')
-rw-r--r--rpython/rlib/rstring.py152
-rw-r--r--rpython/rlib/test/test_rstring.py47
2 files changed, 180 insertions, 19 deletions
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
index 29e1495381..b7bf0b2a16 100644
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -101,9 +101,13 @@ def _split_by(value, by, maxsplit):
start = 0
if bylen == 1:
- # fast path: uses str.rfind(character) and str.count(character)
+ # fast path: uses str.find(character) and str.count(character)
by = by[0] # annotator hack: string -> char
cnt = count(value, by, 0, len(value))
+ if cnt == 0:
+ if isinstance(value, str):
+ return [value]
+ return [value[0:len(value)]]
if 0 <= maxsplit < cnt:
cnt = maxsplit
res = newlist_hint(cnt + 1)
@@ -208,12 +212,12 @@ def _rsplit_by(value, by, maxsplit):
@specialize.argtype(0, 1)
@jit.elidable
-def replace(input, sub, by, maxsplit=-1):
- return replace_count(input, sub, by, maxsplit)[0]
+def replace(input, sub, by, maxcount=-1):
+ return replace_count(input, sub, by, maxcount)[0]
@specialize.ll_and_arg(4)
@jit.elidable
-def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
+def replace_count(input, sub, by, maxcount=-1, isutf8=False):
if isinstance(input, str):
Builder = StringBuilder
elif isinstance(input, unicode):
@@ -221,14 +225,14 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
else:
assert isinstance(input, list)
Builder = ByteListBuilder
- if maxsplit == 0:
+ if maxcount == 0:
return input, 0
if not sub and not isutf8:
upper = len(input)
- if maxsplit > 0 and maxsplit < upper + 2:
- upper = maxsplit - 1
+ if maxcount > 0 and maxcount < upper + 2:
+ upper = maxcount - 1
assert upper >= 0
try:
@@ -246,17 +250,27 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
builder.append(by)
builder.append_slice(input, upper, len(input))
replacements = upper + 1
+
+ elif isinstance(input, str) and len(sub) == 1:
+ if len(by) == 1:
+ return replace_count_str_chr_chr(input, sub[0], by[0], maxcount)
+ return replace_count_str_chr_str(input, sub[0], by, maxcount)
+
else:
# First compute the exact result size
if sub:
cnt = count(input, sub, 0, len(input))
+ if isinstance(input, str) and cnt == 0:
+ return input, 0
+ if isinstance(input, str):
+ return replace_count_str_str_str(input, sub, by, cnt, maxcount)
else:
assert isutf8
from rpython.rlib import rutf8
cnt = rutf8.codepoints_in_utf8(input) + 1
- if cnt > maxsplit and maxsplit > 0:
- cnt = maxsplit
+ if cnt > maxcount and maxcount > 0:
+ cnt = maxcount
diff_len = len(by) - len(sub)
try:
result_size = ovfcheck(diff_len * cnt)
@@ -274,26 +288,122 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False):
from rpython.rlib import rutf8
while True:
builder.append(by)
- maxsplit -= 1
- if start == len(input) or maxsplit == 0:
+ maxcount -= 1
+ if start == len(input) or maxcount == 0:
break
next = rutf8.next_codepoint_pos(input, start)
builder.append_slice(input, start, next)
start = next
else:
- while maxsplit != 0:
+ while maxcount != 0:
next = find(input, sub, start, len(input))
if next < 0:
break
builder.append_slice(input, start, next)
builder.append(by)
start = next + sublen
- maxsplit -= 1 # NB. if it's already < 0, it stays < 0
+ maxcount -= 1 # NB. if it's already < 0, it stays < 0
builder.append_slice(input, start, len(input))
return builder.build(), replacements
+def replace_count_str_chr_chr(input, c1, c2, maxcount):
+ from rpython.rtyper.annlowlevel import llstr, hlstr
+ s = llstr(input)
+ length = len(s.chars)
+ start = find(input, c1, 0, len(input))
+ if start < 0:
+ return input, 0
+ newstr = s.malloc(length)
+ src = s.chars
+ dst = newstr.chars
+ s.copy_contents(s, newstr, 0, 0, len(input))
+ dst[start] = c2
+ count = 1
+ start += 1
+ maxcount -= 1
+ while maxcount != 0:
+ next = find(input, c1, start, len(input))
+ if next < 0:
+ break
+ dst[next] = c2
+ start = next + 1
+ maxcount -= 1
+ count += 1
+
+ return hlstr(newstr), count
+
+def replace_count_str_chr_str(input, sub, by, maxcount):
+ from rpython.rtyper.annlowlevel import llstr, hlstr
+ cnt = count(input, sub, 0, len(input))
+ if cnt == 0:
+ return input, 0
+ if maxcount > 0 and cnt > maxcount:
+ cnt = maxcount
+ diff_len = len(by) - 1
+ try:
+ result_size = ovfcheck(diff_len * cnt)
+ result_size = ovfcheck(result_size + len(input))
+ except OverflowError:
+ raise
+
+ s = llstr(input)
+ by_ll = llstr(by)
+
+ newstr = s.malloc(result_size)
+ dst = 0
+ start = 0
+ while maxcount != 0:
+ next = find(input, sub, start, len(input))
+ if next < 0:
+ break
+ s.copy_contents(s, newstr, start, dst, next - start)
+ dst += next - start
+ s.copy_contents(by_ll, newstr, 0, dst, len(by))
+ dst += len(by)
+
+ start = next + 1
+ maxcount -= 1 # NB. if it's already < 0, it stays < 0
+
+ s.copy_contents(s, newstr, start, dst, len(input) - start)
+ assert dst - start + len(input) == result_size
+ return hlstr(newstr), cnt
+
+def replace_count_str_str_str(input, sub, by, cnt, maxcount):
+ from rpython.rtyper.annlowlevel import llstr, hlstr
+ if cnt > maxcount and maxcount > 0:
+ cnt = maxcount
+ diff_len = len(by) - len(sub)
+ try:
+ result_size = ovfcheck(diff_len * cnt)
+ result_size = ovfcheck(result_size + len(input))
+ except OverflowError:
+ raise
+
+ s = llstr(input)
+ by_ll = llstr(by)
+ newstr = s.malloc(result_size)
+ sublen = len(sub)
+ bylen = len(by)
+ inputlen = len(input)
+ dst = 0
+ start = 0
+ while maxcount != 0:
+ next = find(input, sub, start, inputlen)
+ if next < 0:
+ break
+ s.copy_contents(s, newstr, start, dst, next - start)
+ dst += next - start
+ s.copy_contents(by_ll, newstr, 0, dst, bylen)
+ dst += bylen
+ start = next + sublen
+ maxcount -= 1 # NB. if it's already < 0, it stays < 0
+ s.copy_contents(s, newstr, start, dst, len(input) - start)
+ assert dst - start + len(input) == result_size
+ return hlstr(newstr), cnt
+
+
def _normalize_start_end(length, start, end):
if start < 0:
start += length
@@ -355,20 +465,26 @@ def count(value, other, start, end):
return _search(value, other, start, end, SEARCH_COUNT)
# -------------- substring searching helper ----------------
-# XXX a lot of code duplication with lltypesystem.rstr :-(
SEARCH_COUNT = 0
SEARCH_FIND = 1
SEARCH_RFIND = 2
+@specialize.ll()
def bloom_add(mask, c):
return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1)))
+@specialize.ll()
def bloom(mask, c):
return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1)))
@specialize.argtype(0, 1)
def _search(value, other, start, end, mode):
+ assert value is not None
+ if isinstance(value, unicode):
+ NUL = u'\0'
+ else:
+ NUL = '\0'
if start < 0:
start = 0
if end > len(value):
@@ -398,7 +514,7 @@ def _search(value, other, start, end, mode):
return -1
mlast = m - 1
- skip = mlast - 1
+ skip = mlast
mask = 0
if mode != SEARCH_RFIND:
@@ -411,7 +527,7 @@ def _search(value, other, start, end, mode):
i = start - 1
while i + 1 <= start + w:
i += 1
- if value[i + m - 1] == other[m - 1]:
+ if value[i + mlast] == other[mlast]:
for j in range(mlast):
if value[i + j] != other[j]:
break
@@ -425,7 +541,7 @@ def _search(value, other, start, end, mode):
if i + m < len(value):
c = value[i + m]
else:
- c = '\0'
+ c = NUL
if not bloom(mask, c):
i += m
else:
@@ -434,7 +550,7 @@ def _search(value, other, start, end, mode):
if i + m < len(value):
c = value[i + m]
else:
- c = '\0'
+ c = NUL
if not bloom(mask, c):
i += m
else:
diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py
index 18b5103e54..5fda0275e2 100644
--- a/rpython/rlib/test/test_rstring.py
+++ b/rpython/rlib/test/test_rstring.py
@@ -2,10 +2,12 @@ import sys, py
from rpython.rlib.rstring import StringBuilder, UnicodeBuilder, split, rsplit
from rpython.rlib.rstring import replace, startswith, endswith, replace_count
-from rpython.rlib.rstring import find, rfind, count
+from rpython.rlib.rstring import find, rfind, count, _search, SEARCH_COUNT, SEARCH_FIND
from rpython.rlib.buffer import StringBuffer
from rpython.rtyper.test.tool import BaseRtypingTest
+from hypothesis import given, strategies as st, assume
+
def test_split():
def check_split(value, sub, *args, **kwargs):
result = kwargs['res']
@@ -27,6 +29,11 @@ def test_split():
check_split('endcase test', 'test', res=['endcase ', ''])
py.test.raises(ValueError, split, 'abc', '')
+def test_split_no_occurrence():
+ x = "abc"
+ assert x.split("d")[0] is x
+ assert x.rsplit("d")[0] is x
+
def test_split_None():
assert split("") == []
assert split(' a\ta\na b') == ['a', 'a', 'a', 'b']
@@ -164,6 +171,12 @@ def test_unicode_replace_overflow():
with py.test.raises(OverflowError):
replace(s, u"a", s, len(s) - 10)
+def test_replace_no_occurrence():
+ s = "xyz"
+ assert replace(s, "a", "b") is s
+ s = "xyz"
+ assert replace(s, "abc", "b") is s
+
def test_startswith():
def check_startswith(value, sub, *args, **kwargs):
result = kwargs['res']
@@ -240,6 +253,8 @@ def test_search():
check_search(find, 'one two three', 'ne', 5, 13, res=-1)
check_search(find, 'one two three', '', 0, 13, res=0)
+ check_search(find, '000000p00000000', 'ap', 0, 15, res=-1)
+
check_search(rfind, 'one two three', 'e', 0, 13, res=12)
check_search(rfind, 'one two three', 'e', 0, 1, res=-1)
check_search(rfind, 'one two three', '', 0, 13, res=13)
@@ -293,3 +308,33 @@ class TestTranslates(BaseRtypingTest):
return res
res = self.interpret(fn, [])
assert res
+
+@given(u=st.text(), prefix=st.text(), suffix=st.text())
+def test_hypothesis_search(u, prefix, suffix):
+ prefix = prefix.encode("utf-8")
+ u = u.encode("utf-8")
+ suffix = suffix.encode("utf-8")
+ s = prefix + u + suffix
+
+ index = _search(s, u, 0, len(s), SEARCH_FIND)
+ assert index == s.find(u)
+ assert 0 <= index <= len(prefix)
+
+ index = _search(s, u, len(prefix), len(s) - len(suffix), SEARCH_FIND)
+ assert index == len(prefix)
+
+ count = _search(s, u, 0, len(s), SEARCH_COUNT)
+ assert count == s.count(u)
+ assert 1 <= count
+
+
+@given(st.text(), st.lists(st.text(), min_size=2), st.text(), st.integers(min_value=0, max_value=1000000))
+def test_hypothesis_search(needle, pieces, by, maxcount):
+ needle = needle.encode("utf-8")
+ pieces = [piece.encode("utf-8") for piece in pieces]
+ by = by.encode("utf-8")
+ input = needle.join(pieces)
+ assume(len(input) > 0)
+
+ res = replace(input, needle, by, maxcount)
+ assert res == input.replace(needle, by, maxcount)