"""
HTTP client that talks to the HTTP storage server.
"""

from __future__ import annotations


from typing import (
    Optional,
    Sequence,
    Mapping,
    BinaryIO,
    cast,
    TypedDict,
    Set,
    Dict,
    Callable,
    ClassVar,
)
from base64 import b64encode
from io import BytesIO
from os import SEEK_END

from attrs import define, asdict, frozen, field
from eliot import start_action, register_exception_extractor
from eliot.twisted import DeferredContext

from pycddl import Schema
from collections_extended import RangeMap
from werkzeug.datastructures import Range, ContentRange
from twisted.web.http_headers import Headers
from twisted.web import http
from twisted.web.iweb import IPolicyForHTTPS, IResponse, IAgent
from twisted.internet.defer import Deferred, succeed
from twisted.internet.interfaces import (
    IOpenSSLClientConnectionCreator,
    IReactorTime,
    IDelayedCall,
)
from twisted.internet.ssl import CertificateOptions
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer
from hyperlink import DecodedURL
import treq
from treq.client import HTTPClient
from OpenSSL import SSL
from werkzeug.http import parse_content_range_header

from .http_common import (
    swissnum_auth_header,
    Secrets,
    get_content_type,
    CBOR_MIME_TYPE,
    get_spki_hash,
    response_is_not_html,
)
from ..interfaces import VersionMessage
from .common import si_b2a, si_to_human_readable
from ..util.hashutil import timing_safe_compare
from ..util.deferredutil import async_to_deferred
from ..util.tor_provider import _Provider as TorProvider
from ..util.cputhreadpool import defer_to_thread
from ..util.cbor import dumps

try:
    from txtorcon import Tor  # type: ignore
except ImportError:

    class Tor:  # type: ignore[no-redef]
        pass


def _encode_si(si: bytes) -> str:
    """Encode the storage index into Unicode string."""
    return str(si_b2a(si), "ascii")
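
# For reference (assuming Tahoe's usual lowercase base32 alphabet): a 16-byte
# storage index encodes to a 26-character string, e.g. 16 zero bytes encode to
# "a" * 26.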


class ClientException(Exception):
    """An unexpected response code from the server."""

    def __init__(
        self, code: int, message: Optional[str] = None, body: Optional[bytes] = None
    ):
        Exception.__init__(self, code, message, body)
        self.code = code
        self.message = message
        self.body = body


register_exception_extractor(ClientException, lambda e: {"response_code": e.code})
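
# Thanks to the extractor registered above, Eliot failure logs for
# ClientException automatically include a "response_code" field.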


# Schemas for server responses.
#
# Tags are of the form #6.nnn, where the number is documented at
# https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml. Notably, #6.258
# indicates a set.
_SCHEMAS: Mapping[str, Schema] = {
    "get_version": Schema(
        # Note that the single-quoted (`'`) string keys in this schema
        # represent *byte* strings - per the CDDL specification.  Text strings
        # are represented using strings with *double* quotes (`"`).
        """
        response = {'http://allmydata.org/tahoe/protocols/storage/v1' => {
                 'maximum-immutable-share-size' => uint
                 'maximum-mutable-share-size' => uint
                 'available-space' => uint
                 }
                 'application-version' => bstr
              }
    """
    ),
    "allocate_buckets": Schema(
        """
    response = {
      already-have: #6.258([0*256 uint])
      allocated: #6.258([0*256 uint])
    }
    """
    ),
    "immutable_write_share_chunk": Schema(
        """
    response = {
      required: [0* {begin: uint, end: uint}]
    }
    """
    ),
    "list_shares": Schema(
        """
    response = #6.258([0*256 uint])
    """
    ),
    "mutable_read_test_write": Schema(
        """
        response = {
          "success": bool,
          "data": {0*256 share_number: [0* bstr]}
        }
        share_number = uint
        """
    ),
    "mutable_list_shares": Schema(
        """
        response = #6.258([0*256 uint])
        """
    ),
}
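
# Note on usage: pycddl schemas both validate and decode. For example,
# _SCHEMAS["list_shares"].validate_cbor(cbor_bytes, True) raises if the bytes
# don't match the CDDL, and otherwise returns the decoded Python object; see
# decode_cbor() below.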


@define
class _LengthLimitedCollector:
    """
    Collect data using ``treq.collect()``, with limited length.
    """

    remaining_length: int
    timeout_on_silence: IDelayedCall
    f: BytesIO = field(factory=BytesIO)

    def __call__(self, data: bytes) -> None:
        self.timeout_on_silence.reset(60)
        self.remaining_length -= len(data)
        if self.remaining_length < 0:
            raise ValueError("Response length was too long")
        self.f.write(data)


def limited_content(
    response: IResponse,
    clock: IReactorTime,
    max_length: int = 30 * 1024 * 1024,
) -> Deferred[BinaryIO]:
    """
    Like ``treq.content()``, but limit data read from the response to a set
    length.  If the response is longer than the max allowed length, the result
    fails with a ``ValueError``.

    A potentially useful future improvement would be using a temporary file to
    store the content: filesystem buffering means that would use memory for
    small responses and disk for large responses.

    This will time out if no data is received for 60 seconds; so long as a
    trickle of data continues to arrive, it will continue to run.
    """
    result_deferred = succeed(None)

    # Sadly, addTimeout() won't work because we need access to the IDelayedCall
    # in order to reset it on each data chunk received.
    timeout = clock.callLater(60, result_deferred.cancel)
    collector = _LengthLimitedCollector(max_length, timeout)

    with start_action(
        action_type="allmydata:storage:http-client:limited-content",
        max_length=max_length,
    ).context():
        d = DeferredContext(result_deferred)

    # Make really sure everything gets called in the Deferred context; treq
    # might call the collector directly...
    d.addCallback(lambda _: treq.collect(response, collector))

    def done(_: object) -> BytesIO:
        timeout.cancel()
        collector.f.seek(0)
        return collector.f

    def failed(f):
        if timeout.active():
            timeout.cancel()
        return f

    result = d.addCallbacks(done, failed)
    return result.addActionFinish()
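
# A minimal usage sketch (hypothetical values): given an IResponse and a
# reactor, limited_content(response, reactor, max_length=1024) returns a
# Deferred that fires with a BytesIO positioned at the start of the body, or
# fails with ValueError once more than 1024 bytes have been received.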


@define
class ImmutableCreateResult:
    """Result of creating a storage index for an immutable."""

    already_have: set[int]
    allocated: set[int]


class _TLSContextFactory(CertificateOptions):
    """
    Create a context that validates the way Tahoe-LAFS wants to: based on a
    pinned certificate hash, rather than a certificate authority.

    Originally implemented as part of Foolscap.  To comply with the license,
    here's the original licensing terms:

    Copyright (c) 2006-2008 Brian Warner

    Permission is hereby granted, free of charge, to any person obtaining a
    copy of this software and associated documentation files (the "Software"),
    to deal in the Software without restriction, including without limitation
    the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the
    Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
    THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    DEALINGS IN THE SOFTWARE.
    """

    def __init__(self, expected_spki_hash: bytes):
        self.expected_spki_hash = expected_spki_hash
        CertificateOptions.__init__(self)

    def getContext(self) -> SSL.Context:
        def always_validate(conn, cert, errno, depth, preverify_ok):
            # This function is called to validate the certificate received by
            # the other end. OpenSSL calls it multiple times, once per errno
            # per certificate.

            # We do not care about certificate authorities or revocation
            # lists, we just want to know that the certificate has a valid
            # signature and follows the chain back to one which is
            # self-signed. We need to protect against forged signatures, but
            # not the usual TLS concerns about invalid CAs or revoked
            # certificates.
            things_are_ok = (
                SSL.X509VerificationCodes.OK,
                SSL.X509VerificationCodes.ERR_CERT_NOT_YET_VALID,
                SSL.X509VerificationCodes.ERR_CERT_HAS_EXPIRED,
                SSL.X509VerificationCodes.ERR_DEPTH_ZERO_SELF_SIGNED_CERT,
                SSL.X509VerificationCodes.ERR_SELF_SIGNED_CERT_IN_CHAIN,
            )
            # TODO can we do this once instead of multiple times?
            if errno in things_are_ok and timing_safe_compare(
                get_spki_hash(cert.to_cryptography()), self.expected_spki_hash
            ):
                return 1
            # TODO: log the details of the error, because otherwise they get
            # lost in the PyOpenSSL exception that will eventually be raised
            # (possibly OpenSSL.SSL.Error: certificate verify failed)
            return 0

        ctx = CertificateOptions.getContext(self)

        # VERIFY_PEER means we ask the other end for their certificate.
        ctx.set_verify(SSL.VERIFY_PEER, always_validate)
        return ctx


@implementer(IPolicyForHTTPS)
@implementer(IOpenSSLClientConnectionCreator)
@define
class _StorageClientHTTPSPolicy:
    """
    An HTTPS policy that ensures the SPKI hash of the public key matches a known
    hash, i.e. pinning-based validation.
    """

    expected_spki_hash: bytes

    # IPolicyForHTTPS
    def creatorForNetloc(self, hostname: str, port: int) -> _StorageClientHTTPSPolicy:
        return self

    # IOpenSSLClientConnectionCreator
    def clientConnectionForTLS(
        self, tlsProtocol: TLSMemoryBIOProtocol
    ) -> SSL.Connection:
        return SSL.Connection(
            _TLSContextFactory(self.expected_spki_hash).getContext(), None
        )
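
# Note: this policy doubles as the connection creator, so it is passed
# directly as the TLS context factory when constructing an Agent, e.g.
# Agent(reactor, _StorageClientHTTPSPolicy(expected_spki_hash=hash_from_nurl), pool=pool)
# as done in StorageClientFactory._create_agent() below (hash_from_nurl is
# illustrative; the real value comes from the NURL's user field).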


@define
class StorageClientFactory:
    """
    Create ``StorageClient`` instances, using an appropriate
    ``twisted.web.iweb.IAgent`` for different connection methods: normal TCP,
    Tor, and eventually I2P.

    There is some caching involved since there might be shared setup work, e.g.
    connecting to the local Tor service only needs to happen once.
    """

    _default_connection_handlers: dict[str, str]
    _tor_provider: Optional[TorProvider]
    # Cache the Tor instance created by the provider, if relevant.
    _tor_instance: Optional[Tor] = None

    # If set, we're doing unit testing and we should call this with any
    # HTTPConnectionPool that gets passed to or created by ``_create_agent()``.
    TEST_MODE_REGISTER_HTTP_POOL: ClassVar[
        Optional[Callable[[HTTPConnectionPool], None]]
    ] = None

    @classmethod
    def start_test_mode(cls, callback: Callable[[HTTPConnectionPool], None]) -> None:
        """Switch to testing mode.

        In testing mode we register the pool with the test system using the
        given callback so it can Do Things, most notably killing off idle HTTP
        connections at test shutdown and, in some tests, in the middle of the
        test.
        """
        cls.TEST_MODE_REGISTER_HTTP_POOL = callback

    @classmethod
    def stop_test_mode(cls) -> None:
        """Stop testing mode."""
        cls.TEST_MODE_REGISTER_HTTP_POOL = None

    async def _create_agent(
        self,
        nurl: DecodedURL,
        reactor: object,
        tls_context_factory: IPolicyForHTTPS,
        pool: HTTPConnectionPool,
    ) -> IAgent:
        """Create a new ``IAgent``, possibly using Tor."""
        if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
            self.TEST_MODE_REGISTER_HTTP_POOL(pool)

        # TODO default_connection_handlers should really be an object, not a
        # dict, so we can ask "is this using Tor" without poking at a
        # dictionary with arbitrary strings... See
        # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4032
        handler = self._default_connection_handlers["tcp"]

        if handler == "tcp":
            return Agent(reactor, tls_context_factory, pool=pool)
        if handler == "tor" or nurl.scheme == "pb+tor":
            assert self._tor_provider is not None
            if self._tor_instance is None:
                self._tor_instance = await self._tor_provider.get_tor_instance(reactor)
            return self._tor_instance.web_agent(
                pool=pool, tls_context_factory=tls_context_factory
            )
        else:
            # I2P support will be added here. See
            # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4037
            raise RuntimeError(f"Unsupported tcp connection handler: {handler}")

    async def create_storage_client(
        self,
        nurl: DecodedURL,
        reactor: IReactorTime,
        pool: Optional[HTTPConnectionPool] = None,
    ) -> StorageClient:
        """Create a new ``StorageClient`` for the given NURL."""
        assert nurl.fragment == "v=1"
        assert nurl.scheme in ("pb", "pb+tor")
        if pool is None:
            pool = HTTPConnectionPool(reactor)
            pool.maxPersistentPerHost = 10

        certificate_hash = nurl.user.encode("ascii")
        agent = await self._create_agent(
            nurl,
            reactor,
            _StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
            pool,
        )
        treq_client = HTTPClient(agent)
        https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port)
        swissnum = nurl.path[0].encode("ascii")
        response_check = lambda _: None
        if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
            response_check = response_is_not_html

        return StorageClient(
            https_url,
            swissnum,
            treq_client,
            pool,
            reactor,
            response_check,
        )
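
# A minimal usage sketch (assuming a valid NURL and a running reactor; attrs
# strips the leading underscore from init arguments):
#
#     factory = StorageClientFactory(
#         default_connection_handlers={"tcp": "tcp"}, tor_provider=None
#     )
#     client = await factory.create_storage_client(
#         DecodedURL.from_text(nurl_text), reactor
#     )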


@define(hash=True)
class StorageClient:
    """
    Low-level HTTP client that talks to the HTTP storage server.

    Create using a ``StorageClientFactory`` instance.
    """

    # The URL should be an HTTPS URL ("https://...")
    _base_url: DecodedURL
    _swissnum: bytes
    _treq: HTTPClient
    _pool: HTTPConnectionPool
    _clock: IReactorTime
    # Callback run on each response; unit tests use it to check that
    # responses aren't HTML.
    _analyze_response: Callable[[IResponse], None] = lambda _: None

    def relative_url(self, path: str) -> DecodedURL:
        """Get a URL relative to the base URL."""
        return self._base_url.click(path)

    def _get_headers(self, headers: Optional[Headers]) -> Headers:
        """Return the basic headers to be used by default."""
        if headers is None:
            headers = Headers()
        headers.addRawHeader(
            "Authorization",
            swissnum_auth_header(self._swissnum),
        )
        return headers
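
    # The header value produced by swissnum_auth_header() should be of the
    # form "Tahoe-LAFS <base64(swissnum)>", so every request carries
    # swissnum-based authorization.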

    @async_to_deferred
    async def request(
        self,
        method: str,
        url: DecodedURL,
        lease_renew_secret: Optional[bytes] = None,
        lease_cancel_secret: Optional[bytes] = None,
        upload_secret: Optional[bytes] = None,
        write_enabler_secret: Optional[bytes] = None,
        headers: Optional[Headers] = None,
        message_to_serialize: object = None,
        timeout: float = 60,
        **kwargs,
    ) -> IResponse:
        """
        Like ``treq.request()``, but with optional secrets that get translated
        into corresponding HTTP headers.

        If ``message_to_serialize`` is set, it will be serialized (by default
        with CBOR) and set as the request body.  It should not be mutated
        during execution of this function!

        Default timeout is 60 seconds.
        """
        with start_action(
            action_type="allmydata:storage:http-client:request",
            method=method,
            url=url.to_text(),
            timeout=timeout,
        ) as ctx:
            response = await self._request(
                method,
                url,
                lease_renew_secret,
                lease_cancel_secret,
                upload_secret,
                write_enabler_secret,
                headers,
                message_to_serialize,
                timeout,
                **kwargs,
            )
            ctx.add_success_fields(response_code=response.code)
            return response

    async def _request(
        self,
        method: str,
        url: DecodedURL,
        lease_renew_secret: Optional[bytes] = None,
        lease_cancel_secret: Optional[bytes] = None,
        upload_secret: Optional[bytes] = None,
        write_enabler_secret: Optional[bytes] = None,
        headers: Optional[Headers] = None,
        message_to_serialize: object = None,
        timeout: float = 60,
        **kwargs,
    ) -> IResponse:
        """The implementation of request()."""
        headers = self._get_headers(headers)

        # Add secrets:
        for secret, value in [
            (Secrets.LEASE_RENEW, lease_renew_secret),
            (Secrets.LEASE_CANCEL, lease_cancel_secret),
            (Secrets.UPLOAD, upload_secret),
            (Secrets.WRITE_ENABLER, write_enabler_secret),
        ]:
            if value is None:
                continue
            headers.addRawHeader(
                "X-Tahoe-Authorization",
                b"%s %s" % (secret.value.encode("ascii"), b64encode(value).strip()),
            )

        # Note we can accept CBOR:
        headers.addRawHeader("Accept", CBOR_MIME_TYPE)

        # If there's a request message, serialize it and set the Content-Type
        # header:
        if message_to_serialize is not None:
            if "data" in kwargs:
                raise TypeError(
                    "Can't use both `message_to_serialize` and `data` "
                    "as keyword arguments at the same time"
                )
            kwargs["data"] = await defer_to_thread(dumps, message_to_serialize)
            headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)

        response = await self._treq.request(
            method, url, headers=headers, timeout=timeout, **kwargs
        )
        self._analyze_response(response)

        return response

    async def decode_cbor(self, response: IResponse, schema: Schema) -> object:
        """Given HTTP response, return decoded CBOR body."""
        with start_action(action_type="allmydata:storage:http-client:decode-cbor"):
            if 200 <= response.code < 300:
                content_type = get_content_type(response.headers)
                if content_type == CBOR_MIME_TYPE:
                    f = await limited_content(response, self._clock)
                    data = f.read()

                    def validate_and_decode():
                        return schema.validate_cbor(data, True)

                    return await defer_to_thread(validate_and_decode)
                else:
                    raise ClientException(
                        -1,
                        "Server didn't send CBOR, content type is {}".format(
                            content_type
                        ),
                    )
            else:
                data = (
                    await limited_content(response, self._clock, max_length=10_000)
                ).read()
                raise ClientException(response.code, response.phrase, data)

    def shutdown(self) -> Deferred[object]:
        """Shutdown any connections."""
        return self._pool.closeCachedConnections()


@define(hash=True)
class StorageClientGeneral:
    """
    High-level HTTP APIs that aren't immutable- or mutable-specific.
    """

    _client: StorageClient

    @async_to_deferred
    async def get_version(self) -> VersionMessage:
        """
        Return the version metadata for the server.
        """
        with start_action(
            action_type="allmydata:storage:http-client:get-version",
        ):
            return await self._get_version()

    async def _get_version(self) -> VersionMessage:
        """Implementation of get_version()."""
        url = self._client.relative_url("/storage/v1/version")
        response = await self._client.request("GET", url)
        decoded_response = cast(
            Dict[bytes, object],
            await self._client.decode_cbor(response, _SCHEMAS["get_version"]),
        )
        # Add some features we know are true because the HTTP API
        # specification requires them and because other parts of the storage
        # client implementation assume they will be present.
        cast(
            Dict[bytes, object],
            decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"],
        ).update(
            {
                b"tolerates-immutable-read-overrun": True,
                b"delete-mutable-shares-with-zero-length-writev": True,
                b"fills-holes-with-zero-bytes": True,
                b"prevents-read-past-end-of-share-data": True,
            }
        )
        return decoded_response
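
    # Note the byte-string keys in the decoded response, per the CDDL schema:
    # decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"] is
    # itself a dict with keys like b"maximum-immutable-share-size".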

    @async_to_deferred
    async def add_or_renew_lease(
        self, storage_index: bytes, renew_secret: bytes, cancel_secret: bytes
    ) -> None:
        """
        Add or renew a lease.

        If the renewal secret matches an existing lease, it is renewed.
        Otherwise a new lease is added.
        """
        with start_action(
            action_type="allmydata:storage:http-client:add-or-renew-lease",
            storage_index=si_to_human_readable(storage_index),
        ):
            return await self._add_or_renew_lease(
                storage_index, renew_secret, cancel_secret
            )

    async def _add_or_renew_lease(
        self, storage_index: bytes, renew_secret: bytes, cancel_secret: bytes
    ) -> None:
        url = self._client.relative_url(
            "/storage/v1/lease/{}".format(_encode_si(storage_index))
        )
        response = await self._client.request(
            "PUT",
            url,
            lease_renew_secret=renew_secret,
            lease_cancel_secret=cancel_secret,
        )

        if response.code == http.NO_CONTENT:
            return
        else:
            raise ClientException(response.code)


@define
class UploadProgress:
    """
    Progress of immutable upload, per the server.
    """

    # True when upload has finished.
    finished: bool
    # Remaining ranges to upload.
    required: RangeMap


@async_to_deferred
async def read_share_chunk(
    client: StorageClient,
    share_type: str,
    storage_index: bytes,
    share_number: int,
    offset: int,
    length: int,
) -> bytes:
    """
    Download a chunk of data from a share.

    TODO https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3857 Failed downloads
    should be transparently retried and redownloaded by the implementation a
    few times so that if a failure percolates up, the caller can assume the
    failure isn't a short-term blip.

    NOTE: the underlying HTTP protocol is somewhat more flexible than this API,
    insofar as it doesn't always require a range.  In practice a range is
    always provided by the current callers.
    """
    url = client.relative_url(
        "/storage/v1/{}/{}/{}".format(
            share_type, _encode_si(storage_index), share_number
        )
    )
    # The default 60 second timeout is for getting the response, so it doesn't
    # include the time it takes to download the body... so we will deal with
    # that later, via limited_content().
    response = await client.request(
        "GET",
        url,
        headers=Headers(
            # Ranges in HTTP are _inclusive_; Python's convention is exclusive,
            # but the Range constructor does the conversion for us.
            {"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
        ),
        unbuffered=True,  # Don't buffer the response in memory.
    )
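    # For example (hypothetical values), offset=0 and length=10 produce the
    # header "Range: bytes=0-9", since HTTP ranges are inclusive.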

    if response.code == http.NO_CONTENT:
        return b""

    content_type = get_content_type(response.headers)
    if content_type != "application/octet-stream":
        raise ValueError(
            f"Content-type was wrong: {content_type}, should be application/octet-stream"
        )

    if response.code == http.PARTIAL_CONTENT:
        content_range = parse_content_range_header(
            response.headers.getRawHeaders("content-range", [""])[0]
        )
        if (
            content_range is None
            or content_range.stop is None
            or content_range.start is None
        ):
            raise ValueError(
                "Content-Range was missing, invalid, or in format we don't support"
            )
        supposed_length = content_range.stop - content_range.start
        if supposed_length > length:
            raise ValueError("Server sent more than we asked for?!")
        # It might also send less than we asked for. That's (probably) OK, e.g.
        # if we went past the end of the file.
        body = await limited_content(response, client._clock, supposed_length)
        body.seek(0, SEEK_END)
        actual_length = body.tell()
        if actual_length != supposed_length:
            # Most likely a mutable that got changed out from under us, but
            # conceivably could be a bug...
            raise ValueError(
                f"Length of response sent from server ({actual_length}) "
                + f"didn't match Content-Range header ({supposed_length})"
            )
        body.seek(0)
        return body.read()
    else:
        # Technically HTTP allows sending an OK with full body under these
        # circumstances, but the server is not designed to do that so we ignore
        # that possibility for now...
        raise ClientException(response.code)


@async_to_deferred
async def advise_corrupt_share(
    client: StorageClient,
    share_type: str,
    storage_index: bytes,
    share_number: int,
    reason: str,
) -> None:
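    """Tell the server that a share has been corrupted, with a human-readable reason."""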
    assert isinstance(reason, str)
    url = client.relative_url(
        "/storage/v1/{}/{}/{}/corrupt".format(
            share_type, _encode_si(storage_index), share_number
        )
    )
    message = {"reason": reason}
    response = await client.request("POST", url, message_to_serialize=message)
    if response.code == http.OK:
        return
    else:
        raise ClientException(
            response.code,
        )


@define(hash=True)
class StorageClientImmutables:
    """
    APIs for interacting with immutables.
    """

    _client: StorageClient

    @async_to_deferred
    async def create(
        self,
        storage_index: bytes,
        share_numbers: set[int],
        allocated_size: int,
        upload_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
    ) -> ImmutableCreateResult:
        """
        Create a new storage index for an immutable.

        TODO https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3857 retry
        internally on failure, to ensure the operation fully succeeded.  If a
        sufficient number of failures occur, the result may fire with an
        error, but there's no expectation that user code needs to have a
        recovery codepath; it will most likely just report an error to the
        user.

        Result fires when creating the storage index succeeds; if creating the
        storage index fails, the result will fire with an exception.
        """
        with start_action(
            action_type="allmydata:storage:http-client:immutable:create",
            storage_index=si_to_human_readable(storage_index),
            share_numbers=share_numbers,
            allocated_size=allocated_size,
        ) as ctx:
            result = await self._create(
                storage_index,
                share_numbers,
                allocated_size,
                upload_secret,
                lease_renew_secret,
                lease_cancel_secret,
            )
            ctx.add_success_fields(
                already_have=result.already_have, allocated=result.allocated
            )
            return result

    async def _create(
        self,
        storage_index: bytes,
        share_numbers: set[int],
        allocated_size: int,
        upload_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
    ) -> ImmutableCreateResult:
        """Implementation of create()."""
        url = self._client.relative_url(
            "/storage/v1/immutable/" + _encode_si(storage_index)
        )
        message = {"share-numbers": share_numbers, "allocated-size": allocated_size}

        response = await self._client.request(
            "POST",
            url,
            lease_renew_secret=lease_renew_secret,
            lease_cancel_secret=lease_cancel_secret,
            upload_secret=upload_secret,
            message_to_serialize=message,
        )
        decoded_response = cast(
            Mapping[str, Set[int]],
            await self._client.decode_cbor(response, _SCHEMAS["allocate_buckets"]),
        )
        return ImmutableCreateResult(
            already_have=decoded_response["already-have"],
            allocated=decoded_response["allocated"],
        )

    @async_to_deferred
    async def abort_upload(
        self, storage_index: bytes, share_number: int, upload_secret: bytes
    ) -> None:
        """Abort the upload."""
        with start_action(
            action_type="allmydata:storage:http-client:immutable:abort-upload",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
        ):
            return await self._abort_upload(storage_index, share_number, upload_secret)

    async def _abort_upload(
        self, storage_index: bytes, share_number: int, upload_secret: bytes
    ) -> None:
        """Implementation of ``abort_upload()``."""
        url = self._client.relative_url(
            "/storage/v1/immutable/{}/{}/abort".format(
                _encode_si(storage_index), share_number
            )
        )
        response = await self._client.request(
            "PUT",
            url,
            upload_secret=upload_secret,
        )

        if response.code == http.OK:
            return
        else:
            raise ClientException(
                response.code,
            )

    @async_to_deferred
    async def write_share_chunk(
        self,
        storage_index: bytes,
        share_number: int,
        upload_secret: bytes,
        offset: int,
        data: bytes,
    ) -> UploadProgress:
        """
        Upload a chunk of data for a specific share.

        TODO https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3857 The
        implementation should retry failed uploads transparently a number of
        times, so that if a failure percolates up, the caller can assume the
        failure isn't a short-term blip.

        Result fires when the upload succeeds, with an ``UploadProgress`` whose
        ``finished`` flag indicates whether the _complete_ share (i.e. all
        chunks, not just this one) has been uploaded.
        """
        with start_action(
            action_type="allmydata:storage:http-client:immutable:write-share-chunk",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            offset=offset,
            data_len=len(data),
        ) as ctx:
            result = await self._write_share_chunk(
                storage_index, share_number, upload_secret, offset, data
            )
            ctx.add_success_fields(finished=result.finished)
            return result

    async def _write_share_chunk(
        self,
        storage_index: bytes,
        share_number: int,
        upload_secret: bytes,
        offset: int,
        data: bytes,
    ) -> UploadProgress:
        """Implementation of ``write_share_chunk()``."""
        url = self._client.relative_url(
            "/storage/v1/immutable/{}/{}".format(
                _encode_si(storage_index), share_number
            )
        )
        response = await self._client.request(
            "PATCH",
            url,
            upload_secret=upload_secret,
            data=data,
            headers=Headers(
                {
                    "content-range": [
                        ContentRange("bytes", offset, offset + len(data)).to_header()
                    ]
                }
            ),
        )
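
        # Note: werkzeug's ContentRange takes an exclusive stop (matching
        # Python slice conventions) and renders the inclusive HTTP form, so
        # offset=0 with 3 bytes of data should render as
        # "Content-Range: bytes 0-2/*".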

        if response.code == http.OK:
            # Upload is still unfinished.
            finished = False
        elif response.code == http.CREATED:
            # Upload is done!
            finished = True
        else:
            raise ClientException(
                response.code,
            )
        body = cast(
            Mapping[str, Sequence[Mapping[str, int]]],
            await self._client.decode_cbor(
                response, _SCHEMAS["immutable_write_share_chunk"]
            ),
        )
        remaining = RangeMap()
        for chunk in body["required"]:
            remaining.set(True, chunk["begin"], chunk["end"])
        return UploadProgress(finished=finished, required=remaining)

    @async_to_deferred
    async def read_share_chunk(
        self, storage_index: bytes, share_number: int, offset: int, length: int
    ) -> bytes:
        """
        Download a chunk of data from a share.
        """
        with start_action(
            action_type="allmydata:storage:http-client:immutable:read-share-chunk",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            offset=offset,
            length=length,
        ) as ctx:
            result = await read_share_chunk(
                self._client, "immutable", storage_index, share_number, offset, length
            )
            ctx.add_success_fields(data_len=len(result))
            return result

    @async_to_deferred
    async def list_shares(self, storage_index: bytes) -> Set[int]:
        """
        Return the set of shares for a given storage index.
        """
        with start_action(
            action_type="allmydata:storage:http-client:immutable:list-shares",
            storage_index=si_to_human_readable(storage_index),
        ) as ctx:
            result = await self._list_shares(storage_index)
            ctx.add_success_fields(shares=result)
            return result

    async def _list_shares(self, storage_index: bytes) -> Set[int]:
        """Implementation of ``list_shares()``."""
        url = self._client.relative_url(
            "/storage/v1/immutable/{}/shares".format(_encode_si(storage_index))
        )
        response = await self._client.request(
            "GET",
            url,
        )
        if response.code == http.OK:
            return cast(
                Set[int],
                await self._client.decode_cbor(response, _SCHEMAS["list_shares"]),
            )
        else:
            raise ClientException(response.code)

    @async_to_deferred
    async def advise_corrupt_share(
        self,
        storage_index: bytes,
        share_number: int,
        reason: str,
    ) -> None:
        """Indicate a share has been corrupted, with a human-readable message."""
        with start_action(
            action_type="allmydata:storage:http-client:immutable:advise-corrupt-share",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            reason=reason,
        ):
            await advise_corrupt_share(
                self._client, "immutable", storage_index, share_number, reason
            )


@frozen
class WriteVector:
    """Data to write to a chunk."""

    offset: int
    data: bytes


@frozen
class TestVector:
    """Checks to make on a chunk before writing to it."""

    offset: int
    size: int
    specimen: bytes


@frozen
class ReadVector:
    """
    Reads to do on chunks, as part of a read/test/write operation.
    """

    offset: int
    size: int


@frozen
class TestWriteVectors:
    """Test and write vectors for a specific share."""

    test_vectors: Sequence[TestVector] = field(factory=list)
    write_vectors: Sequence[WriteVector] = field(factory=list)
    new_length: Optional[int] = None

    def asdict(self) -> dict:
        """Return dictionary suitable for sending over CBOR."""
        d = asdict(self)
        d["test"] = d.pop("test_vectors")
        d["write"] = d.pop("write_vectors")
        d["new-length"] = d.pop("new_length")
        return d
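
    # For example (hypothetical values):
    #     TestWriteVectors(write_vectors=[WriteVector(offset=0, data=b"abc")]).asdict()
    # returns {"test": [], "write": [{"offset": 0, "data": b"abc"}], "new-length": None}.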


@frozen
class ReadTestWriteResult:
    """Result of sending read-test-write vectors."""

    success: bool
    # Map share numbers to reads corresponding to the request's list of
    # ReadVectors:
    reads: Mapping[int, Sequence[bytes]]


# Result type for mutable read/test/write HTTP response. Can't just use
# dict[int,list[bytes]] because on Python 3.8 that will error out.
MUTABLE_RTW = TypedDict(
    "MUTABLE_RTW", {"success": bool, "data": Mapping[int, Sequence[bytes]]}
)


@frozen
class StorageClientMutables:
    """
    APIs for interacting with mutables.
    """

    _client: StorageClient

    @async_to_deferred
    async def read_test_write_chunks(
        self,
        storage_index: bytes,
        write_enabler_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
        testwrite_vectors: dict[int, TestWriteVectors],
        read_vector: list[ReadVector],
    ) -> ReadTestWriteResult:
        """
        Read, test, and possibly write chunks to a particular mutable storage
        index.

        Reads are done before writes.

        Given a mapping between share numbers and test/write vectors, the tests
        are done and if they are valid the writes are done.
        """
        with start_action(
            action_type="allmydata:storage:http-client:mutable:read-test-write",
            storage_index=si_to_human_readable(storage_index),
        ):
            return await self._read_test_write_chunks(
                storage_index,
                write_enabler_secret,
                lease_renew_secret,
                lease_cancel_secret,
                testwrite_vectors,
                read_vector,
            )

    async def _read_test_write_chunks(
        self,
        storage_index: bytes,
        write_enabler_secret: bytes,
        lease_renew_secret: bytes,
        lease_cancel_secret: bytes,
        testwrite_vectors: dict[int, TestWriteVectors],
        read_vector: list[ReadVector],
    ) -> ReadTestWriteResult:
        """Implementation of ``read_test_write_chunks()``."""
        url = self._client.relative_url(
            "/storage/v1/mutable/{}/read-test-write".format(_encode_si(storage_index))
        )
        message = {
            "test-write-vectors": {
                share_number: twv.asdict()
                for (share_number, twv) in testwrite_vectors.items()
            },
            "read-vector": [asdict(r) for r in read_vector],
        }
        response = await self._client.request(
            "POST",
            url,
            write_enabler_secret=write_enabler_secret,
            lease_renew_secret=lease_renew_secret,
            lease_cancel_secret=lease_cancel_secret,
            message_to_serialize=message,
        )
        if response.code == http.OK:
            result = cast(
                MUTABLE_RTW,
                await self._client.decode_cbor(
                    response, _SCHEMAS["mutable_read_test_write"]
                ),
            )
            return ReadTestWriteResult(success=result["success"], reads=result["data"])
        else:
            raise ClientException(response.code, (await response.content()))

    @async_to_deferred
    async def read_share_chunk(
        self,
        storage_index: bytes,
        share_number: int,
        offset: int,
        length: int,
    ) -> bytes:
        """
        Download a chunk of data from a share.
        """
        with start_action(
            action_type="allmydata:storage:http-client:mutable:read-share-chunk",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            offset=offset,
            length=length,
        ) as ctx:
            result = await read_share_chunk(
                self._client, "mutable", storage_index, share_number, offset, length
            )
            ctx.add_success_fields(data_len=len(result))
            return result

    @async_to_deferred
    async def list_shares(self, storage_index: bytes) -> Set[int]:
        """
        List the share numbers for a given storage index.
        """
        with start_action(
            action_type="allmydata:storage:http-client:mutable:list-shares",
            storage_index=si_to_human_readable(storage_index),
        ) as ctx:
            result = await self._list_shares(storage_index)
            ctx.add_success_fields(shares=result)
            return result

    async def _list_shares(self, storage_index: bytes) -> Set[int]:
        """Implementation of ``list_shares()``."""
        url = self._client.relative_url(
            "/storage/v1/mutable/{}/shares".format(_encode_si(storage_index))
        )
        response = await self._client.request("GET", url)
        if response.code == http.OK:
            return cast(
                Set[int],
                await self._client.decode_cbor(
                    response,
                    _SCHEMAS["mutable_list_shares"],
                ),
            )
        else:
            raise ClientException(response.code)

    @async_to_deferred
    async def advise_corrupt_share(
        self,
        storage_index: bytes,
        share_number: int,
        reason: str,
    ) -> None:
        """Indicate a share has been corrupted, with a human-readable message."""
        with start_action(
            action_type="allmydata:storage:http-client:mutable:advise-corrupt-share",
            storage_index=si_to_human_readable(storage_index),
            share_number=share_number,
            reason=reason,
        ):
            await advise_corrupt_share(
                self._client, "mutable", storage_index, share_number, reason
            )
