2
\$\begingroup\$

character_range aims to be an equivalent of the native range for str and bytes (currently there's no support for step and decreasing range however): lazy and intuitive. The code are fully type hinted with mypy as the primary type checker.

There are about 750 LOC. Please ignore the code style. Algorithm suggestions, optimization techniques, naming, bug hunting, general code improvements, etc. are welcome. Other files can be found at the GitHub repository.

Project structure:

.
├── CHANGELOG
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── index.rst
│       ├── map.rst
│       └── range.rst
├── pyproject.toml
├── src
│   └── character_range
│       ├── __init__.py
│       ├── character_and_byte_map.py
│       └── string_and_bytes_range.py
├── tests
│   ├── __init__.py
│   ├── test_character_and_byte_map.py
│   └── test_string_and_bytes_range.py
└── tox.ini

__init__.py:

'''
Does exactly what it says on the tin:

    >>> list(character_range('aaa', 'aba', CharacterMap.ASCII_LOWERCASE))
    ['aaa', 'aab', ..., 'aay', 'aaz', 'aba']
    >>> character_range(b'0', b'10', ByteMap.ASCII_LOWERCASE)
    [b'0', b'1', ..., b'9', b'00', b'01', ..., b'09', b'10']

'''

from .character_and_byte_map import (
    ByteInterval, ByteMap,
    CharacterInterval, CharacterMap
)
from .string_and_bytes_range import (
    BytesRange,
    character_range,
    StringRange
)


__all__ = [
    'ByteInterval', 'ByteMap', 'BytesRange',
    'CharacterInterval', 'CharacterMap', 'StringRange',
    'character_range'
]

character_and_byte_map.py:

'''
Implementation of:

* :class:`Interval`: :class:`CharacterInterval`, :class:`ByteInterval`
* :class:`IndexMap`: :class:`CharacterMap`, :class:`ByteMap`
'''

from __future__ import annotations

from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Callable, Iterable, Iterator
from dataclasses import astuple, dataclass
from functools import cache, partial
from itertools import chain
from typing import (
    Any,
    cast,
    ClassVar,
    Generic,
    overload,
    TYPE_CHECKING,
    TypeGuard,
    TypeVar
)

from typing_extensions import Self


_T = TypeVar('_T')
_Char = TypeVar('_Char', str, bytes)
_Index = int

_LookupChar = Callable[[_Char], _Index]
_LookupIndex = Callable[[_Index], _Char]


def _ascii_repr(char: str | bytes) -> str:
    if isinstance(char, str):
        char_is_ascii_printable = ' ' <= char <= '~'
    elif isinstance(char, bytes):
        char_is_ascii_printable = b' ' <= char <= b'~'
    else:
        raise RuntimeError
    
    if char in ('\\', b'\\'):
        return r'\\'
    
    if char_is_ascii_printable:
        return char.decode() if isinstance(char, bytes) else char
    
    codepoint = ord(char)
    
    if codepoint <= 0xFF:
        return fr'\x{codepoint:02X}'
    
    if codepoint <= 0xFFFF:
        return fr'\u{codepoint:04X}'
    
    return fr'\U{codepoint:08X}'


def _is_char_of_type(
    value: str | bytes, expected_type: type[_Char], /
) -> TypeGuard[_Char]:
    return isinstance(value, expected_type) and len(value) == 1


class NoIntervals(ValueError):
    '''
    Raised when no intervals are passed
    to the map constructor.
    '''
    
    pass


class OverlappingIntervals(ValueError):
    '''
    Raised when there are at least two overlapping intervals
    in the list of intervals passed to the map constructor.
    '''
    
    def __init__(self) -> None:
        super().__init__('Intervals must not overlap')


class ConfigurationConflict(ValueError):
    '''
    Raised when the map constructor is passed:
    
    * A list of intervals whose elements don't have the same type.
    * Only one lookup function but not the other.
    '''
    
    pass


class NotACharacter(ValueError):
    '''
    Raised when an object is expected to be a character
    (a :class:`str` of length 1) but it is not one.
    '''
    
    def __init__(self, actual: object) -> None:
        if isinstance(actual, str):
            value_repr = f'string of length {len(actual)}'
        else:
            value_repr = repr(actual)
        
        super().__init__(f'Expected a character, got {value_repr}')


class NotAByte(ValueError):
    '''
    Raised when an object is expected to be a byte
    (a :class:`bytes` object of length 1) but it is not one.
    '''
    
    def __init__(self, actual: object) -> None:
        if isinstance(actual, bytes):
            value_repr = f'a bytes object of length {len(actual)}'
        else:
            value_repr = repr(actual)
        
        super().__init__(f'Expected a single byte, got {value_repr!r}')


class InvalidIntervalDirection(ValueError):
    '''
    Raised when an interval constructor is passed
    a ``start`` whose value is greater than that of ``end``.
    '''
    
    def __init__(self, start: _Char, stop: _Char) -> None:
        super().__init__(
            f'Expected stop to be greater than or equals to start, '
            f'got {start!r} > {stop!r}'
        )


class InvalidIndex(LookupError):
    '''
    Raised when the index returned
    by a ``lookup_char`` function
    is not a valid index.
    '''
    
    def __init__(self, length: int, actual_index: object) -> None:
        super().__init__(
            f'Expected lookup_char to return an integer '
            f'in the interval [0, {length}], got {actual_index!r}'
        )


class InvalidChar(LookupError):
    '''
    Raised when the character returned
    by a ``lookup_index`` function
    is not of the type expected by the map.
    '''
    
    def __init__(
        self,
        actual_char: object,
        expected_type: type[str] | type[bytes]
    ) -> None:
        if issubclass(expected_type, str):
            expected_type_name = 'string'
        else:
            expected_type_name = 'bytes object'
        
        super().__init__(
            f'Expected lookup_index to return '
            f'a {expected_type_name}, got {actual_char!r}'
        )


class Interval(Generic[_Char], ABC):
    '''
    An interval (both ends inclusive) of characters,
    represented using either :class:`str` or :class:`bytes`.
    
    For a :class:`CharacterInterval`, the codepoint of
    an endpoint must not be negative or greater than
    ``0x10FFFF``. Similarly, for a :class:`ByteInterval`,
    the integral value of an endpoint must be in
    the interval ``[0, 255]``.
    '''
    
    start: _Char
    end: _Char
    
    # PyCharm can't infer that an immutable dataclass
    # already has __hash__ defined and will therefore
    # raise a warning if this is an @abstractmethod.
    # However, it outright rejects unsafe_hash = True
    # if __hash__ is also defined, regardless of the
    # TYPE_CHECKING guard.
    
    def __hash__(self) -> int:
        raise RuntimeError('Subclasses must implement __hash__')
    
    @abstractmethod
    def __iter__(self) -> Iterator[_Char]:
        '''
        Lazily yield each character or byte.
        '''
        
        raise NotImplementedError
    
    @abstractmethod
    def __getitem__(self, item: int) -> _Char:
        '''
        ``O(1)`` indexing of character or byte.
        '''
        
        raise NotImplementedError
    
    @abstractmethod
    def __add__(self, other: Self) -> IndexMap[_Char]:
        '''
        Create a new :class:`IndexMap` with both
        ``self`` and ``other`` as ``intervals``.
        '''
        
        raise NotImplementedError
    
    def __len__(self) -> int:
        '''
        The length of the interval, equivalent to
        ``codepoint(start) - codepoint(end) + 1``.
        '''
        
        return len(self.to_codepoint_range())
    
    def __contains__(self, item: Any) -> bool:
        '''
        Assert that ``item``'s codepoint is
        greater than or equals to that of ``start``
        and less than or equals to that of ``end``.
        '''
        
        if not isinstance(item, self.start.__class__) or len(item) != 1:
            return False
        
        return self.start <= item <= self.end
    
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self})'
    
    def __str__(self) -> str:
        if len(self) == 1:
            return _ascii_repr(self.start)
        
        return f'{_ascii_repr(self.start)}-{_ascii_repr(self.end)}'
    
    def __eq__(self, other: object) -> bool:
        '''
        Two intervals are equal if one is an instance of
        the other's class and their endpoints have the
        same integral values.
        '''
        
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        return self.to_codepoint_range() == other.to_codepoint_range()
    
    def __and__(self, other: Self) -> bool:
        '''
        See :meth:`Interval.intersects`.
        '''
        
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        earlier_end = min(self.end, other.end)
        later_start = max(self.start, other.start)
        
        return later_start <= earlier_end
    
    @property
    @abstractmethod
    def element_type(self) -> type[_Char]:
        '''
        A class-based property that returns
        the type of the interval's elements.
        '''
        
        raise NotImplementedError
    
    def _validate(self, *, exception_type: type[ValueError]) -> None:
        if not _is_char_of_type(self.start, self.element_type):
            raise exception_type(self.start)
        
        if not _is_char_of_type(self.end, self.element_type):
            raise exception_type(self.end)
        
        if self.start > self.end:
            raise InvalidIntervalDirection(self.start, self.end)
    
    def to_codepoint_range(self) -> range:
        '''
        Convert the interval to a native :class:`range` that
        would yield the codepoints of the elements of the interval.
        '''
        
        return range(ord(self.start), ord(self.end) + 1)
    
    def intersects(self, other: Self) -> bool:
        '''
        Whether two intervals intersect each other.
        '''
        
        return self & other


@dataclass(
    eq = False, frozen = True, repr = False,
    slots = True, unsafe_hash = True
)
class CharacterInterval(Interval[str]):
    start: str
    end: str
    
    def __post_init__(self) -> None:
        self._validate(exception_type = NotACharacter)
    
    def __iter__(self) -> Iterator[str]:
        for codepoint in self.to_codepoint_range():
            yield chr(codepoint)
    
    def __getitem__(self, item: int) -> str:
        if not 0 <= item < len(self):
            raise IndexError('Index out of range')
        
        return chr(ord(self.start) + item)
    
    def __add__(self, other: Self) -> CharacterMap:
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        return CharacterMap([self, other])
    
    @property
    def element_type(self) -> type[str]:
        return str


@dataclass(
    eq = False, frozen = True, repr = False,
    slots = True, unsafe_hash = True
)
class ByteInterval(Interval[bytes]):
    start: bytes
    end: bytes
    
    def __post_init__(self) -> None:
        self._validate(exception_type = NotAByte)
    
    def __iter__(self) -> Iterator[bytes]:
        for bytes_value in self.to_codepoint_range():
            yield bytes_value.to_bytes(1, 'big')
    
    def __getitem__(self, item: int) -> bytes:
        if not isinstance(item, int):
            raise TypeError(f'Expected a non-negative integer, got {item}')
        
        if not 0 <= item < len(self):
            raise IndexError('Index out of range')
        
        return (ord(self.start) + item).to_bytes(1, 'big')
    
    def __add__(self, other: Self) -> ByteMap:
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        return ByteMap([self, other])
    
    @property
    def element_type(self) -> type[bytes]:
        return bytes


@dataclass(frozen = True, slots = True)
class _Searchers(Generic[_Char]):
    lookup_char: _LookupChar[_Char] | None
    lookup_index: _LookupIndex[_Char] | None
    
    @property
    def both_given(self) -> bool:
        '''
        Whether both functions are not ``None``.
        '''
        
        return self.lookup_char is not None and self.lookup_index is not None
    
    @property
    def both_omitted(self) -> bool:
        '''
        Whether both functions are ``None``.
        '''
        
        return self.lookup_char is None and self.lookup_index is None
    
    @property
    def only_one_given(self) -> bool:
        '''
        Whether only one of the functions is ``None``.
        '''
        
        return not self.both_given and not self.both_omitted


class _RunCallbackAfterInitialization(type):
    '''
    :class:`_HasPrebuiltMembers`'s metaclass (a.k.a. metametaclass).
    Runs a callback defined at the instance's level.
    '''
    
    _callback_method_name: str
    
    def __call__(cls, *args: object, **kwargs: object) -> Any:
        class_with_prebuilt_members = super().__call__(*args, **kwargs)
        
        callback = getattr(cls, cls._callback_method_name)
        callback(class_with_prebuilt_members)
        
        return class_with_prebuilt_members


# This cannot be generic due to
# https://github.com/python/mypy/issues/11672
class _HasPrebuiltMembers(
    ABCMeta,
    metaclass = _RunCallbackAfterInitialization
):
    '''
    :class:`CharacterMap` and :class:`ByteMap`'s metaclass.
    '''
    
    _callback_method_name: ClassVar[str] = '_instantiate_members'
    
    # When the `cls` (`self`) argument is typed as 'type[_T]',
    # mypy refuses to understand that it also has the '_member_names'
    # attribute, regardless of assertions and castings.
    # The way to circumvent this is to use 'getattr()',
    # as demonstrated below in '__getitem__' and 'members'.
    _member_names: list[str]
    
    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        **kwargs: Any
    ) -> _HasPrebuiltMembers:
        new_class = super().__new__(mcs, name, bases, namespace, **kwargs)
        
        if ABC in bases:
            return new_class
        
        new_class._member_names = [
            name for name, value in new_class.__dict__.items()
            if not name.startswith('_') and not callable(value)
        ]
        
        return new_class
    
    def __getitem__(cls: type[_T], item: str) -> _T:
        member_names: list[str] = getattr(cls, '_member_names')
        
        if item not in member_names:
            raise LookupError(f'No such member: {item!r}')
        
        return cast(_T, getattr(cls, item))
    
    @property
    def members(cls: type[_T]) -> tuple[_T, ...]:
        '''
        Returns a tuple of pre-built members of the class.
        '''
        
        member_names: list[str] = getattr(cls, '_member_names')
        
        return tuple(getattr(cls, name) for name in member_names)
    
    def _instantiate_members(cls) -> None:
        for member_name in cls._member_names:
            value = getattr(cls, member_name)
            
            if isinstance(value, tuple):
                setattr(cls, member_name, cls(*value))
            else:
                setattr(cls, member_name, cls(value))


class IndexMap(Generic[_Char], ABC):
    '''
    A two-way mapping between character or byte
    to its corresponding index.
    '''
    
    __slots__ = (
        '_intervals', '_char_to_index',
        '_searchers', '_index_to_char',
        '_maps_populated', '_element_type',
        '_not_a_char_exception'
    )
    
    _intervals: tuple[Interval[_Char], ...]
    _char_to_index: dict[_Char, _Index]
    _index_to_char: dict[_Index, _Char]
    _searchers: _Searchers[_Char]
    _element_type: type[_Char]
    _maps_populated: bool
    _not_a_char_exception: type[NotACharacter] | type[NotAByte]
    
    def __init__(
        self,
        intervals: Iterable[Interval[_Char]],
        lookup_char: _LookupChar[_Char] | None = None,
        lookup_index: _LookupIndex[_Char] | None = None
    ) -> None:
        r'''
        Construct a new map from a number of intervals.
        The underlying character-to-index and
        index-to-character maps will not be populated if
        lookup functions are given.
        
        Lookup functions are expected to be the optimized
        versions of the naive, brute-force lookup algorithm.
        This relationship is similar to that of
        ``__method__``\ s and built-ins; for example,
        while ``__contains__`` is automatically
        implemented when a class defines both ``__len__``
        and ``__getitem__``, a ``__contains__`` may still
        be needed if manual iterations are too
        unperformant, unnecessary or unwanted.
        
        Lookup functions must raise either
        :class:`LookupError` or :class:`ValueError`
        if the index or character cannot be found.
        If the index returned by ``lookup_char`` is
        not in the interval ``[0, len(self) - 1]``,
        a :class:`ValueError` is raised.
        
        :raise ConfigurationConflict: \
            If only one lookup function is given.
        :raise NoIntervals: \
            If no intervals are given.
        '''
        
        self._intervals = tuple(intervals)
        self._searchers = _Searchers(lookup_char, lookup_index)
        
        self._char_to_index = {}
        self._index_to_char = {}
        
        if self._searchers.only_one_given:
            raise ConfigurationConflict(
                'The two lookup functions must be either '
                'both given or both omitted'
            )
        
        if not self._intervals:
            raise NoIntervals('At least one interval expected')
        
        self._intervals_must_have_same_type()
        self._element_type = self._intervals[0].element_type
        
        if issubclass(self._element_type, str):
            self._not_a_char_exception = NotACharacter
        elif issubclass(self._element_type, bytes):
            self._not_a_char_exception = NotAByte
        else:
            raise RuntimeError
        
        if self._searchers.both_given:
            self._intervals_must_not_overlap()
            self._maps_populated = False
            return
        
        self._populate_maps()
        self._maps_populated = True
    
    def __hash__(self) -> int:
        return hash(self._intervals)
    
    @cache
    def __len__(self) -> int:
        if self._maps_populated:
            return len(self._char_to_index)
        
        return sum(len(interval) for interval in self._intervals)
    
    @cache
    def __repr__(self) -> str:
        joined_ranges = ''.join(str(interval) for interval in self._intervals)
        
        return f'{self.__class__.__name__}({joined_ranges})'
    
    @overload
    def __getitem__(self, item: _Char) -> _Index:
        ...
    
    @overload
    def __getitem__(self, item: _Index) -> _Char:
        ...
    
    def __getitem__(self, item: _Char | _Index) -> _Index | _Char:
        '''
        Either look for the character/index in the underlying maps,
        or delegate that task to the look-up functions given.
        
        Results are cached.
        
        :raise ValueError: \
            If ``item`` is neither a character/byte nor an index.
        :raise IndexError: \
            If ``lookup_char``
        '''
        
        if isinstance(item, int):
            return self._get_char_given_index(item)
        
        if isinstance(item, self._element_type):
            return self._get_index_given_char(item)
        
        raise TypeError(f'Expected a character or an index, got {item!r}')
    
    def __contains__(self, item: object) -> bool:
        if not isinstance(item, self._element_type | int):
            return False
        
        if isinstance(item, int):
            return 0 <= item < len(self)
        
        try:
            # This is necessary for PyCharm,
            # deemed redundant by mypy,
            # and makes pyright think that 'item' is of type 'object'.
            item = cast(_Char, item)  # type: ignore
            
            _ = self._get_index_given_char(item)
        except (LookupError, ValueError):
            return False
        else:
            return True
    
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        return self._intervals == other._intervals
    
    # Needed for testing and type hint convenience
    def __add__(self, other: Self | IndexMap[_Char] | Interval[_Char]) -> Self:
        if not isinstance(other, IndexMap | Interval):
            return NotImplemented
        
        if other.element_type is not self.element_type:
            raise ConfigurationConflict('Different element types')
        
        lookup_char, lookup_index = astuple(self._searchers)
        
        if isinstance(other, Interval):
            return self.__class__(
                self._intervals + tuple([other]),
                lookup_char = lookup_char,
                lookup_index = lookup_index
            )
        
        if self._searchers != other._searchers:
            raise ConfigurationConflict(
                'Maps having different lookup functions '
                'cannot be combined'
            )
        
        return self.__class__(
            self._intervals + other._intervals,
            lookup_char = lookup_char,
            lookup_index = lookup_index
        )
    
    @property
    def intervals(self) -> tuple[Interval[_Char], ...]:
        '''
        The intervals that make up the map.
        '''
        
        return self._intervals
    
    @property
    def element_type(self) -> type[_Char]:
        '''
        The type of the map's elements.
        '''
        
        return self._element_type
    
    def _intervals_must_have_same_type(self) -> None:
        interval_types = {type(interval) for interval in self._intervals}
        
        if len(interval_types) > 1:
            raise ConfigurationConflict('Intervals must be of same types')
    
    def _intervals_must_not_overlap(self) -> None:
        seen: list[Interval[_Char]] = []
        
        for current_interval in self._intervals:
            overlapped = any(
                current_interval.intersects(seen_interval)
                for seen_interval in seen
            )
            
            if overlapped:
                raise OverlappingIntervals
            
            seen.append(current_interval)
    
    def _populate_maps(self) -> None:
        chained_intervals = chain.from_iterable(self._intervals)
        
        for index, char in enumerate(chained_intervals):
            if char in self._char_to_index:
                raise OverlappingIntervals
            
            self._char_to_index[char] = index
            self._index_to_char[index] = char
    
    def _get_char_given_index(self, index: _Index, /) -> _Char:
        if index not in self:
            raise IndexError(f'Index {index} is out of range')
        
        if index in self._index_to_char:
            return self._index_to_char[index]
        
        lookup_index = self._searchers.lookup_index
        assert lookup_index is not None
        
        result = lookup_index(index)
        
        if not _is_char_of_type(result, self._element_type):
            raise InvalidChar(result, self._element_type)
        
        self._index_to_char[index] = result
        return self._index_to_char[index]
    
    def _get_index_given_char(self, char: _Char, /) -> _Index:
        if not _is_char_of_type(char, self._element_type):
            raise self._not_a_char_exception(char)
        
        if char in self._char_to_index:
            return self._char_to_index[char]
        
        lookup_char = self._searchers.lookup_char
        
        if lookup_char is None:
            raise LookupError(f'Char {char!r} is not in the map')
        
        result = lookup_char(char)
        
        if not isinstance(result, int) or result not in self:
            raise InvalidIndex(len(self), result)
        
        self._char_to_index[char] = result
        return self._char_to_index[char]


def _ascii_index_from_char_or_byte(char_or_byte: str | bytes) -> int:
    codepoint = ord(char_or_byte)
    
    if not 0 <= codepoint <= 0xFF:
        raise ValueError('Not an ASCII character or byte')
    
    return codepoint


@overload
def _ascii_char_or_byte_from_index(constructor: type[str], index: int) -> str:
    ...


@overload
def _ascii_char_or_byte_from_index(
    constructor: type[bytes],
    index: int
) -> bytes:
    ...


def _ascii_char_or_byte_from_index(
    constructor: type[str] | type[bytes],
    index: int
) -> str | bytes:
    if issubclass(constructor, str):
        return constructor(chr(index))
    
    if issubclass(constructor, bytes):
        # \x80 and higher would be converted
        # to two bytes with .encode() alone.
        return constructor(index.to_bytes(1, 'big'))
    
    raise RuntimeError


_ascii_char_from_index = cast(
    Callable[[int], str],
    partial(_ascii_char_or_byte_from_index, str)
)
_ascii_byte_from_index = cast(
    Callable[[int], bytes],
    partial(_ascii_char_or_byte_from_index, bytes)
)

if TYPE_CHECKING:
    class CharacterMap(IndexMap[str], metaclass = _HasPrebuiltMembers):
        # At runtime, this is a read-only class-level property.
        members: ClassVar[tuple[CharacterMap, ...]]
        
        ASCII_LOWERCASE: ClassVar[CharacterMap]
        ASCII_UPPERCASE: ClassVar[CharacterMap]
        ASCII_LETTERS: ClassVar[CharacterMap]
        
        ASCII_DIGITS: ClassVar[CharacterMap]
        
        LOWERCASE_HEX_DIGITS: ClassVar[CharacterMap]
        UPPERCASE_HEX_DIGITS: ClassVar[CharacterMap]
        
        LOWERCASE_BASE_36: ClassVar[CharacterMap]
        UPPERCASE_BASE_36: ClassVar[CharacterMap]
        
        ASCII: ClassVar[CharacterMap]
        NON_ASCII: ClassVar[CharacterMap]
        UNICODE: ClassVar[CharacterMap]
        
        # At runtime, this functionality is provided
        # using the metaclass's __getitem__.
        def __class_getitem__(cls, item: str) -> CharacterMap:
            ...
    
    
    class ByteMap(IndexMap[bytes], metaclass = _HasPrebuiltMembers):
        # At runtime, this is a read-only class-level property.
        members: ClassVar[tuple[ByteMap, ...]]
        
        ASCII_LOWERCASE: ClassVar[ByteMap]
        ASCII_UPPERCASE: ClassVar[ByteMap]
        ASCII_LETTERS: ClassVar[ByteMap]
        
        ASCII_DIGITS: ClassVar[ByteMap]
        
        LOWERCASE_HEX_DIGITS: ClassVar[ByteMap]
        UPPERCASE_HEX_DIGITS: ClassVar[ByteMap]
        
        LOWERCASE_BASE_36: ClassVar[ByteMap]
        UPPERCASE_BASE_36: ClassVar[ByteMap]
        
        ASCII: ClassVar[ByteMap]
        
        # At runtime, this functionality is provided
        # using the metaclass's __getitem__.
        def __class_getitem__(cls, item: str) -> ByteMap:
            ...

else:
    class CharacterMap(IndexMap[str], metaclass = _HasPrebuiltMembers):
        ASCII_LOWERCASE = [CharacterInterval('a', 'z')]
        ASCII_UPPERCASE = [CharacterInterval('A', 'Z')]
        ASCII_LETTERS = ASCII_LOWERCASE + ASCII_UPPERCASE
        
        ASCII_DIGITS = [CharacterInterval('0', '9')]
        
        LOWERCASE_HEX_DIGITS = ASCII_DIGITS + [CharacterInterval('a', 'f')]
        UPPERCASE_HEX_DIGITS = ASCII_DIGITS + [CharacterInterval('A', 'F')]
        
        LOWERCASE_BASE_36 = ASCII_DIGITS + ASCII_LOWERCASE
        UPPERCASE_BASE_36 = ASCII_DIGITS + ASCII_UPPERCASE
        
        ASCII = (
            [CharacterInterval('\x00', '\xFF')],
            _ascii_index_from_char_or_byte,
            _ascii_char_from_index
        )
        NON_ASCII = (
            [CharacterInterval('\u0100', '\U0010FFFF')],
            lambda char: ord(char) - 0x100,
            lambda index: chr(index + 0x100)
        )
        UNICODE = ([CharacterInterval('\x00', '\U0010FFFF')], ord, chr)
    
    
    class ByteMap(IndexMap[bytes], metaclass = _HasPrebuiltMembers):
        ASCII_LOWERCASE = [ByteInterval(b'a', b'z')]
        ASCII_UPPERCASE = [ByteInterval(b'A', b'Z')]
        ASCII_LETTERS = ASCII_LOWERCASE + ASCII_UPPERCASE
        
        ASCII_DIGITS = [ByteInterval(b'0', b'9')]
        
        LOWERCASE_HEX_DIGITS = ASCII_DIGITS + [ByteInterval(b'a', b'f')]
        UPPERCASE_HEX_DIGITS = ASCII_DIGITS + [ByteInterval(b'A', b'F')]
        
        LOWERCASE_BASE_36 = ASCII_DIGITS + ASCII_LOWERCASE
        UPPERCASE_BASE_36 = ASCII_DIGITS + ASCII_UPPERCASE
        
        ASCII = (
            [ByteInterval(b'\x00', b'\xFF')],
            _ascii_index_from_char_or_byte,
            _ascii_byte_from_index
        )

string_and_bytes_range.py:

'''
The highest-level features of the package, implemented as
:class:`_Range` and :func:`character_range`.
'''

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from enum import Enum
from functools import total_ordering
from typing import cast, Generic, overload, TypeVar

from typing_extensions import Literal, Self

from .character_and_byte_map import ByteMap, CharacterMap, IndexMap


_StrOrBytes = TypeVar('_StrOrBytes', str, bytes)

# Keep these in sync with CharacterMap and ByteMap
# TODO: Allow passing the name of a prebuilt map as an argument.
_CharacterMapName = Literal[
    'ascii_lowercase',
    'ascii_uppercase',
    'ascii_letters',
    'ascii_digits',
    'lowercase_hex_digits',
    'uppercase_hex_digits',
    'lowercase_base_36',
    'uppercase_base_36',
    'ascii',
    'non_ascii',
    'unicode'
]
_ByteMapName = Literal[
    'ascii_lowercase',
    'ascii_uppercase',
    'ascii_letters',
    'ascii_digits',
    'lowercase_hex_digits',
    'uppercase_hex_digits',
    'lowercase_base_36',
    'uppercase_base_36',
    'ascii'
]


@overload
def _get_prebuilt_map(
    map_class: type[CharacterMap],
    name: str
) -> CharacterMap:
    ...


@overload
def _get_prebuilt_map(
    map_class: type[ByteMap],
    name: str
) -> ByteMap:
    ...


def _get_prebuilt_map(
    map_class: type[CharacterMap] | type[ByteMap],
    name: str
) -> CharacterMap | ByteMap:
    try:
        member = map_class[name.upper()]
    except KeyError:
        raise _NoSuchPrebuiltMap(name)
    
    return cast(CharacterMap | ByteMap, member)


def _split(value: _StrOrBytes) -> list[_StrOrBytes]:
    if isinstance(value, str):
        return list(value)
    
    return [
        byte_as_int.to_bytes(1, 'big')
        for byte_as_int in value
    ]


# TODO: Support different types of ranges
class _RangeType(str, Enum):
    '''
    Given a range from ``aa`` to ``zz``:

    +------------+----------+----------+
    | Range type | Contains | Contains |
    | / Property | ``aa``   | ``zz``   |
    +============+==========+==========+
    | Open       |    No    |    No    |
    +------------+----------+----------+
    | Closed     |    Yes   |    Yes   |
    +------------+----------+----------+
    | Left-open  |    No    |    Yes   |
    +------------+----------+----------+
    | Right-open |    Yes   |    No    |
    +------------+----------+----------+

    These terms are taken from
    `the Wikipedia article about mathematical intervals \
    <https://en.wikipedia.org/wiki/Interval_(mathematics)>`_.

    A :class:`_Range` always represent a closed interval.
    However, for convenience, :func:`character_range`
    accepts an optional ``range_type`` argument that
    deals with these.
    '''
    
    OPEN = 'open'
    CLOSED = 'closed'
    LEFT_OPEN = 'left_open'
    RIGHT_OPEN = 'right_open'


class InvalidEndpoints(ValueError):
    '''
    Raised when the endpoints given to :class:`_Range` is either:
    
    * Empty, or
    * At least one character is not in the corresponding map.
    '''
    
    def __init__(self, *endpoints: str | bytes):
        super().__init__(', '.join(repr(endpoint) for endpoint in endpoints))


class InvalidRangeDirection(ValueError):
    '''
    Raised when ``start`` is longer than ``end`` or
    they have the same length but ``start`` is
    lexicographically "less than" end.
    '''
    
    def __init__(self, start: object, end: object) -> None:
        super().__init__(f'Start is greater than end ({start!r} > {end!r})')


class _NoSuchPrebuiltMap(ValueError):
    
    def __init__(self, name: str) -> None:
        super().__init__(f'No such prebuilt map with given name: {name!r}')


class _EmptyListOfIndices(ValueError):
    
    def __init__(self) -> None:
        super().__init__('List of indices must not be empty')


class _InvalidBase(ValueError):
    
    def __init__(self, actual: object) -> None:
        super().__init__(f'Expected a positive integer, got {actual!r}')


@total_ordering
class _IncrementableIndexCollection:
    '''
    A collection of indices of a :class:`IndexMap`
    that can be incremented one by one.
    
    :meth:`_MonotonicIndexCollection.increment`
    works in an index-wise manner::
    
        >>> c = _IncrementableIndexCollection([1], 2)
        >>> c.increment()
        _MonotonicIndexCollection([0, 0], base = 2)
    '''
    
    __slots__ = ('_inverted_indices', '_base')
    
    _inverted_indices: list[int]
    _base: int
    
    def __init__(self, indices: Iterable[int], /, base: int) -> None:
        self._inverted_indices = list(indices)[::-1]
        self._base = base
        
        if not self._inverted_indices:
            raise _EmptyListOfIndices
        
        if not isinstance(base, int) or base < 1:
            raise _InvalidBase(base)
    
    def __repr__(self) -> str:
        indices, base = self._indices, self._base
        
        return f'{self.__class__.__name__}({indices!r}, {base = !r})'
    
    def __index__(self) -> int:
        '''
        The integeral value computed by interpreting
        the indices as the digits of a base-*n* integer.
        '''
        
        total = 0
        
        for order_of_magnitude, index in enumerate(self._inverted_indices):
            total += index * self._base ** order_of_magnitude
        
        return total
    
    def __len__(self) -> int:
        '''
        The number of indices the collection currently holds.
        '''
        
        return len(self._inverted_indices)
    
    def __iter__(self) -> Iterator[int]:
        '''
        Lazily yield the elements this collection currently holds.
        '''
        
        yield from reversed(self._inverted_indices)
    
    def __lt__(self, other: Self) -> bool:
        '''
        Whether ``other``'s length is greater than ``self``'s
        or the lengths are equals but the integral value of
        ``other`` is greater than that of ``self``.
        '''
        
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        if len(self) < len(other):
            return True
        
        return len(self) == len(other) and self._indices < other._indices
    
    def __eq__(self, other: object) -> bool:
        '''
        Whether two collections have the same base and elements.
        '''
        
        if not isinstance(other, self.__class__):
            return NotImplemented
        
        return self._base == other._base and self._indices == other._indices
    
    @property
    def _indices(self) -> tuple[int, ...]:
        '''
        The current indices of the collection.
        '''
        
        return tuple(self)
    
    @property
    def base(self) -> int:
        '''
        The maximum value of an index, plus 1.
        '''
        
        return self._base
    
    def increment(self) -> Self:
        '''
        Add 1 to the last index. If the new value is
        equals to ``base``, that index will become 0
        and the process continues at the next index.
        If the last index is reached, a new index (0)
        is added to the list.
        
        This is equivalent to C/C++'s pre-increment
        operator, in that it returns the original value
        after modification.
        
        Examples::
        
            [0, 0] -> [0, 1]
            [0, 1] -> [1, 0]
            [1, 1] -> [0, 0, 0]
        '''
        
        for order_of_magnitude in range(len(self._inverted_indices)):
            self._inverted_indices[order_of_magnitude] += 1
            
            if self._inverted_indices[order_of_magnitude] < self._base:
                return self
            
            self._inverted_indices[order_of_magnitude] = 0
        
        self._inverted_indices.append(0)
        
        return self


class _Range(Generic[_StrOrBytes], ABC):
    '''
    Represents a range between two
    string or bytes object endpoints.
    
    A range of this type is always a closed interval:
    both endpoints are inclusive. This goes in line
    with how regex character ranges work, even though
    those only ever support single characters::
    
        >>> list(StringRange('a', 'c', CharacterMap.ASCII_LOWERCASE))
        ['a', 'b', 'c']
        >>> list(StringRange('aa', 'ac', CharacterMap.ASCII_LOWERCASE))
        ['aa', 'ab', 'ac']
    
    For :class:`BytesRange`, each byte of the yielded
    :class:`bytes` objects will have the corresponding
    integral values ranging from 0 through 0xFF::
    
        >>> list(BytesRange(b'0xFE', b'0x81', ByteMap.ASCII))
        [b'0xFE', b'0xFF', b'0x80', b'0x81']
    
    Also note that the next value after
    ``[base - 1]`` is ``[0, 0]``, not ``[1, 0]``::
    
        >>> list(StringRange('0', '19', CharacterMap.ASCII_DIGITS))
        [
          '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
          '00', '01', '02', '03', '04', '05', '06', '07', '08', '09',
          '10', '11', '12', '13', '14', '15', '16', '17', '18', '19'
        ]
    
    See also :class:`_IncrementableIndexCollection`.
    '''
    
    __slots__ = ('_start', '_end', '_map')
    
    _start: _StrOrBytes
    _end: _StrOrBytes
    _map: IndexMap[_StrOrBytes]
    
    def __init__(
        self, start: _StrOrBytes, end: _StrOrBytes, /,
        index_map: IndexMap[_StrOrBytes]
    ) -> None:
        self._start = start
        self._end = end
        self._map = index_map
        
        start_is_valid = self._is_valid_endpoint(start)
        end_is_valid = self._is_valid_endpoint(end)
        
        if not start_is_valid or not end_is_valid:
            raise InvalidEndpoints(start, end)
        
        if len(start) > len(end) or len(start) == len(end) and start > end:
            raise InvalidRangeDirection(start, end)
    
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self._start!r}, {self._end!r})'
    
    def __iter__(self) -> Iterator[_StrOrBytes]:
        '''
        Lazily yield the elements.
        '''
        
        current = self._make_collection(self._start)
        end = self._make_collection(self._end)
        
        # https://github.com/python/mypy/issues/16711
        while current <= end:  # type: ignore
            yield self._make_element(current)
            current.increment()
    
    def __len__(self) -> int:
        '''
        The number of elements the range would yield,
        calculated mathematically.
        '''
        
        # Realistic example:
        # start = 'y'; end = 'aaac'
        # base = len('a'-'z') = 26
        #
        # len = (
        #     (len('a'-'z') + len('aa'-'zz') + len('aaa'-'zzz')) +
        #     len('aaaa'-'aaac') -
        #     (len('a'-'y') - len('y'-'y')
        # )
        # len = (base ** 1 + base ** 2 + base ** 3) + 3 - (25 - 1)
        # len = (26 ** 1 + 26 ** 2 + 26 ** 3) + 3 - 24
        
        start, end = self._start, self._end
        base = len(self._map)
        
        from_len_start_up_to_len_end: int = sum(
            base ** width
            for width in range(len(start), len(end))
        )
        
        from_len_start_through_start = int(self._make_collection(start))
        from_len_end_through_end = int(self._make_collection(end))
        
        result = from_len_start_up_to_len_end
        result += from_len_end_through_end
        result -= from_len_start_through_start
        result += 1
        
        return result
    
    @property
    def _base(self) -> int:
        return len(self._map)
    
    @property
    def start(self) -> _StrOrBytes:
        '''
        The starting endpoint of the range.
        '''
        
        return self._start
    
    @property
    def end(self) -> _StrOrBytes:
        '''
        The ending endpoint of the range.
        '''
        
        return self._end
    
    @property
    def map(self) -> IndexMap[_StrOrBytes]:
        '''
        The map to look up the available characters or bytes.
        '''
        
        return self._map
    
    @property
    def element_type(self) -> type[_StrOrBytes]:
        '''
        The element type of :meth:`map`.
        
        See :meth:`IndexMap.element_type`.
        '''
        
        return self._map.element_type
    
    @abstractmethod
    def _make_element(
        self, indices: _IncrementableIndexCollection, /
    ) -> _StrOrBytes:
        raise NotImplementedError
    
    def _is_valid_endpoint(self, value: _StrOrBytes) -> bool:
        return (
            len(value) > 0 and
            all(char in self._map for char in _split(value))
        )
    
    def _make_collection(
        self, value: _StrOrBytes, /
    ) -> _IncrementableIndexCollection:
        indices = (self._map[char] for char in _split(value))
        
        return _IncrementableIndexCollection(indices, len(self._map))


class StringRange(_Range[str]):
    
    def _make_element(self, indices: _IncrementableIndexCollection, /) -> str:
        return ''.join(self._map[index] for index in indices)


class BytesRange(_Range[bytes]):
    
    def _make_element(self, indices: _IncrementableIndexCollection, /) -> bytes:
        return b''.join(self._map[index] for index in indices)


@overload
def character_range(
    start: str, end: str, /,
    index_map: IndexMap[str]
) -> StringRange:
    ...


@overload
def character_range(
    start: bytes, end: bytes, /,
    index_map: IndexMap[bytes]
) -> BytesRange:
    ...


# TODO: Design a more intuitive signature for this function
# Example: A map parser than takes in a string of form r'a-zB-J&$\-\x00-\x12'
def character_range(
    start: str | bytes,
    end: str | bytes, /,
    index_map: IndexMap[str] | IndexMap[bytes]
) -> StringRange | BytesRange:
    '''
    ``range``-lookalike alias for
    :class:`StringRange` and :class:`BytesRange`.
    
    ``start`` and ``end`` must be of the same type,
    either :class:`str` or :class:`bytes`.
    ``index_map`` must contain all elements of
    both of them.
    '''
    
    map_class: type[CharacterMap] | type[ByteMap] | None = None
    range_class: type[StringRange] | type[BytesRange] | None = None
    
    if isinstance(start, str) and isinstance(end, str):
        map_class = CharacterMap
        range_class = StringRange
    
    if isinstance(start, bytes) and isinstance(end, bytes):
        map_class = ByteMap
        range_class = BytesRange
    
    if map_class is not None and range_class is not None:
        # Either mypy isn't yet smart enough to figure out
        # that this will not cause errors, or I'm not smart
        # enough to figure out all cases.
        return range_class(start, end, index_map)  # type: ignore
    
    raise TypeError(
        f'Expected two strings or two bytes objects, got '
        f'{type(start).__name__} and {type(end).__name__}'
    )

\$\endgroup\$
3
  • 1
    \$\begingroup\$ Please ignore the code style is not how Code Review works. Any insightful observation is on-topic. \$\endgroup\$ Commented Dec 27, 2023 at 13:41
  • \$\begingroup\$ @Reinderien That was a copy & paste error. Edited. \$\endgroup\$ Commented Dec 27, 2023 at 14:02
  • \$\begingroup\$ Regarding the code style, if you want to include that in your answer, sure, you can, but I won't need it. \$\endgroup\$ Commented Dec 27, 2023 at 14:19

1 Answer 1

2
\$\begingroup\$

There are many marks of a vaguely good library here - multiple modules, docs, tests, types, __all__, custom exceptions.

Does exactly what it says on the tin

is debatable. The moment a Python developer sees range, it carries implications: we expect a half-open interval. That applies to (for example) the built-in range, randrange, and Numpy's arange. Also note that all three have the parameters start, stop, step. Strongly consider following this signature and making the range half-open.

Good job with the types. _Index = int is good, but could be strengthened with NewType.

In _ascii_repr, raise RuntimeError is not appropriate and should be replaced with raise TypeError().

Don't triple-apostrophe docstrings; use triple quotes. You've said that you're using PyCharm, which means it will already have warned you about this, assuming that you have a sane configuration.

Otherwise, it just... seems like a tonne of code for what is a very simple operation. At the very least, I'd drop support for bytes entirely, since it's trivial for the user to encode and decode themselves.

About your tests:

Most test cases are dynamically generated. (Perfect Heisenbug environment, I know.)

So... you know the problem. Either do plain-old hard-coded tests, or use a testing library like hypothesis that takes stochastic testing seriously.

Is

    >>> character_range(b'0', b'10', ByteMap.ASCII_LOWERCASE)
    [b'0', b'1', ..., b'9', b'00', b'01', ..., b'09', b'10']

accurate? Sure seems like it isn't, and should show ASCII_DIGITS instead.

\$\endgroup\$
1
  • \$\begingroup\$ Range types (open, closed, left-open, right-open) and step have yet to be implemented (see TODO comments). _Index cannot be a NewType since index_map[0] would not pass type checking. RuntimeError raises mark intended dead branches (I want to use elif instead of just else). Regardless, I upvoted this answer based solely on the suggestion of Hypothesis (and the typo); exactly what I would have used had I known of it. \$\endgroup\$ Commented Dec 27, 2023 at 14:47

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.