source: trunk/src/allmydata/test/common.py

Last change on this file was fec97256, checked in by Alexandre Detiste <alexandre.detiste@…>, at 2025-01-06T21:51:37Z

trim Python2 syntax

  • Property mode set to 100644
File size: 52.4 KB
Line 
1"""
2Functionality related to a lot of the test suite.
3"""
4from __future__ import annotations
5
__all__ = [
    # Test case base classes exported by this module.
    "SyncTestCase", "AsyncTestCase", "AsyncBrokenTestCase", "TrialTestCase",

    # testtools helpers re-exported for convenience.
    "flush_logged_errors", "skip", "skipIf",

    # Selected based on platform and re-exported for convenience.
    "Popen", "PIPE",
]
20
21import sys
22import os, random, struct
23from contextlib import contextmanager
24import six
25import tempfile
26from tempfile import mktemp
27from functools import partial
28from unittest import case as _case
29from socket import (
30    AF_INET,
31    SOCK_STREAM,
32    SOMAXCONN,
33    socket,
34    error as socket_error,
35)
36from errno import (
37    EADDRINUSE,
38)
39
40import attr
41
42import treq
43
44from zope.interface import implementer
45
46from testtools import (
47    TestCase,
48    skip,
49    skipIf,
50)
51from testtools.twistedsupport import (
52    SynchronousDeferredRunTest,
53    AsynchronousDeferredRunTest,
54    AsynchronousDeferredRunTestForBrokenTwisted,
55    flush_logged_errors,
56)
57
58from twisted.application import service
59from twisted.plugin import IPlugin
60from twisted.internet import defer
61from twisted.internet.defer import inlineCallbacks, returnValue
62from twisted.internet.interfaces import IPullProducer
63from twisted.python import failure
64from twisted.python.filepath import FilePath
65from twisted.web.error import Error as WebError
66from twisted.internet.interfaces import (
67    IStreamServerEndpointStringParser,
68    IReactorSocket,
69)
70from twisted.internet.endpoints import AdoptedStreamServerEndpoint
71from twisted.trial.unittest import TestCase as _TrialTestCase
72
73from allmydata import uri
74from allmydata.interfaces import (
75    IMutableFileNode,
76    IImmutableFileNode,
77    NotEnoughSharesError,
78    ICheckable,
79    IMutableUploadable,
80    SDMF_VERSION,
81    MDMF_VERSION,
82    IAddressFamily,
83    NoSpace,
84)
85from allmydata.check_results import CheckResults, CheckAndRepairResults, \
86     DeepCheckResults, DeepCheckAndRepairResults
87from allmydata.storage_client import StubServer
88from allmydata.mutable.layout import unpack_header
89from allmydata.mutable.publish import MutableData
90from allmydata.storage.mutable import MutableShareFile
91from allmydata.util import hashutil, log, iputil
92from allmydata.util.assertutil import precondition
93from allmydata.util.consumer import download_to_data
94import allmydata.test.common_util as testutil
95from allmydata.immutable.upload import Uploader
96from allmydata.client import (
97    config_from_string,
98    create_client_from_config,
99)
100from allmydata.scripts.common import (
101    write_introducer,
102    )
103
104from ..crypto import (
105    ed25519,
106    rsa,
107)
108from .eliotutil import (
109    EliotLoggedRunTest,
110)
111from .common_util import ShouldFailMixin  # noqa: F401
112
113from subprocess import (
114    Popen,
115    PIPE,
116)
117
# Does this process run as an OS user with elevated privileges (i.e. root)?
# Only POSIX platforms expose ``os.getuid``; on any other platform the
# answer is taken to be "no".
superuser = (os.getuid() == 0) if hasattr(os, "getuid") else False
121
# A client configuration built from empty configuration text, with
# "/dev/null" as its base directory and "tub.port" as the port-file name
# argument (same argument order UseNode uses below).
EMPTY_CLIENT_CONFIG = config_from_string("/dev/null", "tub.port", "")
127
def byteschr(x):
    """
    Return a one-byte ``bytes`` object whose only byte has the value ``x``.

    This is the Python 3 spelling of Python 2's ``chr`` for byte strings.
    ``x`` must be in ``range(256)``; otherwise ``ValueError`` is raised.
    """
    return bytes((x,))
130
@attr.s
class FakeDisk:
    """
    Just enough of a disk to be able to report free / used information.

    :ivar total: The total capacity of the disk, in bytes.
    :ivar used: The number of bytes currently in use.
    """
    total = attr.ib()
    used = attr.ib()

    def use(self, num_bytes):
        """
        Mark some amount of available bytes as used (and no longer available).

        :param int num_bytes: The number of bytes to use.

        :raise NoSpace: If there are fewer bytes available than ``num_bytes``.

        :return: ``None``
        """
        if num_bytes > self.available:
            raise NoSpace()
        self.used += num_bytes

    @property
    def available(self):
        """
        The number of bytes not yet used.
        """
        return self.total - self.used

    def get_disk_stats(self, whichdir, reserved_space=0):
        """
        Report disk usage as a dictionary with the same keys as
        ``allmydata.util.fileutil.get_disk_stats``.

        :param whichdir: Ignored; this fake models a single disk.
        :param int reserved_space: Bytes to hold back from the ``avail`` figure.

        :return dict: ``total``, ``free_for_root``, ``free_for_nonroot``,
            ``used`` and ``avail``.  Root and non-root free space are both the
            unused byte count.  ``avail`` is that count minus ``reserved_space``,
            clamped at zero, as the real implementation does, so reserving more
            than is free never reports a negative amount.
        """
        avail = self.available
        return {
            'total': self.total,
            'free_for_root': avail,
            'free_for_nonroot': avail,
            'used': self.used,
            'avail': max(avail - reserved_space, 0),
        }
166
167
@attr.s
class MemoryIntroducerClient:
    """
    A model-only (no behavior) stand-in for ``IntroducerClient``.

    Subscriptions and publications are only recorded, in ``subscribed_to``
    and ``published_announcements`` respectively, so tests can inspect them.
    """
    tub = attr.ib()
    introducer_furl = attr.ib()
    nickname = attr.ib()
    my_version = attr.ib()
    oldest_supported = attr.ib()
    sequencer = attr.ib()
    cache_filepath = attr.ib()

    subscribed_to : list[Subscription] = attr.ib(default=attr.Factory(list))
    published_announcements : list[Announcement] = attr.ib(default=attr.Factory(list))

    def setServiceParent(self, parent):
        """
        Accept and ignore a service parent.
        """

    def subscribe_to(self, service_name, cb, *args, **kwargs):
        """
        Record a subscription request.  ``cb`` is never invoked.
        """
        subscription = Subscription(
            service_name=service_name,
            cb=cb,
            args=args,
            kwargs=kwargs,
        )
        self.subscribed_to.append(subscription)

    def publish(self, service_name, ann, signing_key):
        """
        Record an announcement together with the serialized form of the key
        it would have been signed with.
        """
        key_bytes = ed25519.string_from_signing_key(signing_key)
        announcement = Announcement(
            service_name=service_name,
            ann=ann,
            signing_key_bytes=key_bytes,
        )
        self.published_announcements.append(announcement)
199
200
@attr.s
class Subscription:
    """
    A record of one ``subscribe_to`` call made on a fake introducer client.
    """
    # The name of the service being subscribed to.
    service_name = attr.ib()
    # The callback the subscriber supplied.
    cb = attr.ib()
    # Extra positional arguments captured from the ``subscribe_to`` call.
    args = attr.ib()
    # Extra keyword arguments captured from the ``subscribe_to`` call.
    kwargs = attr.ib()
210
211
@attr.s
class Announcement:
    """
    A record of one ``publish`` call made on a fake introducer client.

    The signing key is kept in serialized form; the ``signing_key`` property
    turns it back into a key object on demand.
    """
    service_name = attr.ib()
    ann = attr.ib()
    signing_key_bytes = attr.ib(type=bytes)

    @property
    def signing_key(self):
        """
        The deserialized signing key for this announcement.
        """
        keypair = ed25519.signing_keypair_from_string(self.signing_key_bytes)
        return keypair[0]
224
225
def get_published_announcements(client):
    """
    Collect, in order, every announcement published through any of the
    introducer clients belonging to ``client``.

    :return list: A new flat list of announcements.
    """
    announcements = []
    for introducer_client in client.introducer_clients:
        announcements.extend(introducer_client.published_announcements)
    return announcements
238
239
class UseTestPlugins:
    """
    A fixture which enables loading Twisted plugins from the Tahoe-LAFS test
    suite.
    """
    @staticmethod
    def _plugins_path():
        """
        Return the path of the ``plugins`` directory beside this module.
        """
        return FilePath(__file__).sibling("plugins").path

    def setUp(self):
        """
        Put the test suite's ``plugins`` directory at the front of the
        ``twisted.plugins`` aggregate package's search path.
        """
        import twisted.plugins
        twisted.plugins.__path__.insert(0, self._plugins_path())

    def cleanUp(self):
        """
        Take the test suite's ``plugins`` directory back out of the
        ``twisted.plugins`` aggregate package's search path.
        """
        import twisted.plugins
        twisted.plugins.__path__.remove(self._plugins_path())

    def getDetails(self):
        return {}
265
266
@attr.s
class UseNode:
    """
    A fixture which creates a client node.

    :ivar dict[str, str] plugin_config: Configuration items to put in the
        storage plugin's section of the node's configuration, or ``None`` to
        omit that section.

    :ivar str storage_plugin: The name of a storage plugin to enable, or
        ``None`` to enable none.

    :ivar FilePath basedir: The base directory of the node.

    :ivar str introducer_furl: The introducer furl with which to
        configure the client.

    :ivar dict[str, str] node_config: Configuration items for the *node*
        section of the configuration.

    :ivar _Config config: The complete resulting configuration.

    :ivar reactor: The reactor handed to
        ``SameProcessStreamEndpointAssigner.assign`` when a tub port has to be
        chosen.  If it does not provide ``IReactorSocket`` (e.g. ``None``), a
        less reliable port-guessing strategy is used.
    """
    plugin_config = attr.ib()
    storage_plugin = attr.ib()
    basedir = attr.ib(validator=attr.validators.instance_of(FilePath))
    introducer_furl = attr.ib(validator=attr.validators.instance_of(str),
                              converter=six.ensure_str)
    node_config : dict[str, str] = attr.ib(default=attr.Factory(dict))

    config = attr.ib(default=None)
    reactor = attr.ib(default=None)

    def setUp(self):
        """
        Write the introducer configuration into ``basedir`` and build
        ``self.config`` from the node and plugin settings.

        :raise ValueError: If ``node_config`` sets ``tub.location`` without
            also setting ``tub.port``.
        """
        self.assigner = SameProcessStreamEndpointAssigner()
        self.assigner.setUp()

        def format_config_items(config):
            # One "key = value" line per item.  Keys and values must be str,
            # since they are joined with a str separator.
            return "\n".join(
                " = ".join((key, value))
                for (key, value)
                in config.items()
            )

        if self.plugin_config is None:
            plugin_config_section = ""
        else:
            plugin_config_section = (
                "[storageclient.plugins.{storage_plugin}]\n"
                "{config}\n").format(
                    storage_plugin=self.storage_plugin,
                    config=format_config_items(self.plugin_config),
                )

        if self.storage_plugin is None:
            plugins = ""
        else:
            plugins = "storage.plugins = {}".format(self.storage_plugin)

        write_introducer(
            self.basedir,
            "default",
            self.introducer_furl,
        )

        node_config = self.node_config.copy()
        if "tub.port" not in node_config:
            if "tub.location" in node_config:
                raise ValueError(
                    "UseNode fixture does not support specifying tub.location "
                    "without tub.port"
                )

            # Don't use the normal port auto-assignment logic.  It produces
            # collisions and makes tests fail spuriously.
            tub_location, tub_endpoint = self.assigner.assign(self.reactor)
            node_config.update({
                "tub.port": tub_endpoint,
                "tub.location": tub_location,
            })

        self.config = config_from_string(
            self.basedir.asTextMode().path,
            "tub.port",
            "[node]\n"
            "{node_config}\n"
            "\n"
            "[client]\n"
            "{plugins}\n"
            "{plugin_config_section}\n"
            .format(
                plugins=plugins,
                node_config=format_config_items(node_config),
                plugin_config_section=plugin_config_section,
            )
        )

    def create_node(self):
        """
        Create a client from ``self.config`` whose introducer clients are
        in-memory ``MemoryIntroducerClient`` models.
        """
        return create_client_from_config(
            self.config,
            _introducer_factory=MemoryIntroducerClient,
        )

    def cleanUp(self):
        """
        Release the resources held by the endpoint assigner created in
        ``setUp``.
        """
        self.assigner.tearDown()

    def getDetails(self):
        return {}
373
374
375
@implementer(IPlugin, IStreamServerEndpointStringParser)
class AdoptedServerPort:
    """
    Endpoint string parser for ``adopt-socket:<fd>`` descriptions: the
    listening TCP socket with file descriptor ``fd`` is adopted as the
    server port.
    """
    prefix = "adopt-socket"

    def parseStreamServer(self, reactor, fd): # type: ignore # https://twistedmatrix.com/trac/ticket/10134
        log.msg("Adopting {}".format(fd))
        # AdoptedStreamServerEndpoint takes ownership of the descriptor it is
        # given: it duplicates it and closes the one passed in, so a given
        # descriptor can only be adopted once.  One test stops a node and
        # starts it again, which adopts the same descriptor a second time.
        # Handing over a private duplicate keeps the original valid and
        # reusable for such later adoptions.
        private_fd = os.dup(int(fd))
        return AdoptedStreamServerEndpoint(reactor, private_fd, AF_INET)
398
399
def really_bind(s, addr, max_attempts=100):
    """
    Bind ``s`` to ``addr``, retrying while the address is reported in use.

    :param s: A socket (or anything with a compatible ``bind`` method).
    :param addr: The address to bind to, in whatever form ``s.bind`` accepts.
    :param int max_attempts: How many binds to try before giving up.  The
        default of 100 is arbitrary.  We don't want to try forever in case the
        problem is persistent, but trying is cheap, and hopefully the OS never
        needs more than a couple of attempts to hand out a port.

    :raise Exception: If every attempt fails with ``EADDRINUSE``.  The last
        such error is chained as the exception's ``__cause__``.
    :raise socket.error: Immediately, for any bind failure other than
        ``EADDRINUSE``.
    """
    last_error = None
    for _ in range(max_attempts):
        try:
            s.bind(addr)
        except socket_error as e:
            if e.errno != EADDRINUSE:
                raise
            # Keep a reference: ``e`` is unbound when the except clause ends.
            last_error = e
        else:
            return
    raise Exception("Many bind attempts failed with EADDRINUSE") from last_error
415
416
class SameProcessStreamEndpointAssigner:
    """
    A fixture which can assign streaming server endpoints for use *in this
    process only*.

    An effort is made to avoid address collisions for this port but the logic
    for doing so is platform-dependent (sorry, Windows).

    This is more reliable than trying to listen on a hard-coded non-zero port
    number.  It is at least as reliable as trying to listen on port number
    zero on Windows and more reliable than doing that on other platforms.
    """
    def setUp(self):
        """
        Make the ``adopt-socket`` endpoint type available and start tracking
        the cleanups that ``tearDown`` will run.
        """
        self._cleanups = []
        # Make sure the `adopt-socket` endpoint is recognized.  We do this
        # instead of providing a dropin because we don't want to make this
        # endpoint available to random other applications.
        f = UseTestPlugins()
        f.setUp()
        self._cleanups.append(f.cleanUp)

    def tearDown(self):
        """
        Run every registered cleanup exactly once, most recently registered
        first.

        Each cleanup is removed before it runs.  Calling ``tearDown`` again,
        for example after a cleanup raised, therefore neither repeats finished
        cleanups nor skips pending ones.
        """
        while self._cleanups:
            cleanup = self._cleanups.pop()
            cleanup()

    def assign(self, reactor):
        """
        Make a new streaming server endpoint and return its string description.

        This is intended to help write config files that will then be read and
        used in this process.

        :param reactor: The reactor which will be used to listen with the
            resulting endpoint.  If it provides ``IReactorSocket`` then
            resulting reliability will be extremely high.  If it doesn't,
            resulting reliability will be pretty alright.

        :return: A two-tuple of (location hint, port endpoint description) as
            strings.
        """
        if sys.platform != "win32" and IReactorSocket.providedBy(reactor):
            # On this platform, we can reliably pre-allocate a listening port.
            # Once it is bound we know it will not fail later with EADDRINUSE.
            s = socket(AF_INET, SOCK_STREAM)
            # We need to keep ``s`` alive as long as the file descriptor we put in
            # this string might still be used.  We could dup() the descriptor
            # instead but then we've only inverted the cleanup problem: gone from
            # don't-close-too-soon to close-just-late-enough.  So we'll leave
            # ``s`` alive and use it as the cleanup mechanism.
            self._cleanups.append(s.close)
            s.setblocking(False)
            really_bind(s, ("127.0.0.1", 0))
            s.listen(SOMAXCONN)
            host, port = s.getsockname()
            location_hint = "tcp:%s:%d" % (host, port)
            port_endpoint = "adopt-socket:fd=%d" % (s.fileno(),)
        else:
            # On other platforms, we blindly guess and hope we get lucky.
            portnum = iputil.allocate_tcp_port()
            location_hint = "tcp:127.0.0.1:%d" % (portnum,)
            port_endpoint = "tcp:%d:interface=127.0.0.1" % (portnum,)

        return location_hint, port_endpoint
480
@implementer(IPullProducer)
class DummyProducer:
    """
    An ``IPullProducer`` whose methods do nothing at all.
    """
    def resumeProducing(self):
        """
        Do nothing.
        """

    def stopProducing(self):
        """
        Do nothing.
        """
488
@implementer(IImmutableFileNode)
class FakeCHKFileNode(object):  # type: ignore # incomplete implementation
    """
    I provide IImmutableFileNode, but my data lives in the ``all_contents``
    dictionary handed to my constructor, keyed by the serialized cap, rather
    than on storage servers.  Literal caps carry their data inline and need
    no entry there.
    """

    def __init__(self, filecap, all_contents):
        """
        :param filecap: A ``uri.CHKFileURI`` or ``uri.LiteralFileURI``.
        :param dict all_contents: Mapping from serialized cap to file contents.
        """
        precondition(isinstance(filecap, (uri.CHKFileURI, uri.LiteralFileURI)), filecap)
        self.all_contents = all_contents
        self.my_uri = filecap
        self.storage_index = self.my_uri.get_storage_index()

    def get_uri(self):
        return self.my_uri.to_string()
    def get_write_uri(self):
        return None
    def get_readonly_uri(self):
        return self.my_uri.to_string()
    def get_cap(self):
        return self.my_uri
    def get_verify_cap(self):
        return self.my_uri.get_verify_cap()
    def get_repair_cap(self):
        return self.my_uri.get_verify_cap()
    def get_storage_index(self):
        return self.storage_index

    def check(self, monitor, verify=False, add_lease=False):
        """
        Pretend to check the file.  The result always describes a healthy
        3-of-10 file; ``monitor``, ``verify`` and ``add_lease`` are ignored.

        :return: A Deferred firing with ``CheckResults``.
        """
        s = StubServer(b"\x00"*20)
        r = CheckResults(self.my_uri, self.storage_index,
                         healthy=True, recoverable=True,
                         count_happiness=10,
                         count_shares_needed=3,
                         count_shares_expected=10,
                         count_shares_good=10,
                         count_good_share_hosts=10,
                         count_recoverable_versions=1,
                         count_unrecoverable_versions=0,
                         servers_responding=[s],
                         sharemap={1: [s]},
                         count_wrong_shares=0,
                         list_corrupt_shares=[],
                         count_corrupt_shares=0,
                         list_incompatible_shares=[],
                         count_incompatible_shares=0,
                         summary="",
                         report=[],
                         share_problems=[],
                         servermap=None)
        return defer.succeed(r)
    def check_and_repair(self, monitor, verify=False, add_lease=False):
        """
        Run ``check`` and report its results as both the pre- and post-repair
        results.

        :return: A Deferred firing with ``CheckAndRepairResults``.
        """
        # Forward every argument in its proper position; ``verify`` is not a
        # monitor.
        d = self.check(monitor, verify, add_lease)
        def _got(cr):
            r = CheckAndRepairResults(self.storage_index)
            r.pre_repair_results = r.post_repair_results = cr
            return r
        d.addCallback(_got)
        return d

    def is_mutable(self):
        return False
    def is_readonly(self):
        return True
    def is_unknown(self):
        return False
    def is_allowed_in_immutable_directory(self):
        return True
    def raise_error(self):
        pass

    def get_size(self):
        """
        Return the size of the file in bytes.

        :raise NotEnoughSharesError: If this is a CHK file whose contents are
            missing from ``all_contents``.
        """
        if isinstance(self.my_uri, uri.LiteralFileURI):
            return self.my_uri.get_size()
        try:
            data = self.all_contents[self.my_uri.to_string()]
        except KeyError as le:
            raise NotEnoughSharesError(le, 0, 3) from le
        return len(data)
    def get_current_size(self):
        return defer.succeed(self.get_size())

    def read(self, consumer, offset=0, size=None):
        """
        Write ``size`` bytes starting at ``offset`` (or everything from
        ``offset`` on, if ``size`` is None) to ``consumer``.

        :return: A Deferred firing with ``consumer``.
        """
        # we don't bother to call registerProducer/unregisterProducer,
        # because it's a hassle to write a dummy Producer that does the right
        # thing (we have to make sure that DummyProducer.resumeProducing
        # writes the data into the consumer immediately, otherwise it will
        # loop forever).

        d = defer.succeed(None)
        d.addCallback(self._read, consumer, offset, size)
        return d

    def _read(self, ignored, consumer, offset, size):
        if isinstance(self.my_uri, uri.LiteralFileURI):
            data = self.my_uri.data
        else:
            if self.my_uri.to_string() not in self.all_contents:
                raise NotEnoughSharesError(None, 0, 3)
            data = self.all_contents[self.my_uri.to_string()]
        start = offset
        if size is not None:
            end = offset + size
        else:
            end = len(data)
        consumer.write(data[start:end])
        return consumer


    def get_best_readable_version(self):
        return defer.succeed(self)


    def download_to_data(self):
        return download_to_data(self)


    download_best_version = download_to_data


    def get_size_of_best_version(self):
        """
        :return: A Deferred firing with the file size in bytes.
        """
        # Call get_size: the Deferred must carry the size, not the method.
        return defer.succeed(self.get_size())
609
610
def make_chk_file_cap(size):
    """
    Build a CHK file cap for a ``size``-byte file, with a random key and
    random URI-extension hash, using 3-of-10 encoding parameters.
    """
    key = os.urandom(16)
    ueb_hash = os.urandom(32)
    return uri.CHKFileURI(
        key=key,
        uri_extension_hash=ueb_hash,
        needed_shares=3,
        total_shares=10,
        size=size,
    )
def make_chk_file_uri(size):
    """
    Return the serialized form of a fresh random CHK cap for a ``size``-byte
    file.
    """
    cap = make_chk_file_cap(size)
    return cap.to_string()
619
def create_chk_filenode(contents, all_contents):
    """
    Create a ``FakeCHKFileNode`` whose data is ``contents``.

    The contents are stored in ``all_contents`` under the new cap's
    serialized form, which is where the fake node looks them up.
    """
    cap = make_chk_file_cap(len(contents))
    node = FakeCHKFileNode(cap, all_contents)
    all_contents[cap.to_string()] = contents
    return node
625
626
@implementer(IMutableFileNode, ICheckable)
class FakeMutableFileNode(object):  # type: ignore # incomplete implementation
    """
    I provide IMutableFileNode, but my data lives in the ``all_contents``
    dictionary handed to my constructor, keyed by storage index, rather than
    on storage servers.
    """

    MUTABLE_SIZELIMIT = 10000

    _public_key: rsa.PublicKey | None
    _private_key: rsa.PrivateKey | None

    def __init__(self,
                 storage_broker,
                 secret_holder,
                 default_encoding_parameters,
                 history,
                 all_contents,
                 keypair: tuple[rsa.PublicKey, rsa.PrivateKey] | None
                ):
        """
        :param storage_broker: Unused by this fake.
        :param secret_holder: Unused by this fake.
        :param dict default_encoding_parameters: Must contain ``k`` and
            ``max_segment_size``.
        :param history: Unused by this fake.
        :param dict all_contents: Mapping from storage index to file contents;
            this node reads and writes its data there.
        :param keypair: The keypair the initial SDMF write cap is derived
            from, or ``None`` for a randomly generated cap.
        """
        self.all_contents = all_contents
        self.file_types: dict[bytes, int] = {} # storage index => MDMF_VERSION or SDMF_VERSION
        self.init_from_cap(make_mutable_file_cap(keypair))
        self._k = default_encoding_parameters['k']
        self._segsize = default_encoding_parameters['max_segment_size']
        if keypair is None:
            self._public_key = self._private_key = None
        else:
            self._public_key, self._private_key = keypair

    def create(self, contents, version=SDMF_VERSION):
        """
        Store the initial contents of this node.

        :param contents: ``None`` (empty), an ``IMutableUploadable``, or a
            callable that takes this node and returns one.
        :param version: ``SDMF_VERSION`` or ``MDMF_VERSION``.  Requesting MDMF
            while holding an SDMF cap first switches to a fresh random MDMF cap.

        :return: A Deferred firing with this node.
        """
        if version == MDMF_VERSION and \
            isinstance(self.my_uri, (uri.ReadonlySSKFileURI,
                                 uri.WriteableSSKFileURI)):
            self.init_from_cap(make_mdmf_mutable_file_cap())
        self.file_types[self.storage_index] = version
        initial_contents = self._get_initial_contents(contents)
        data = initial_contents.read(initial_contents.get_size())
        data = b"".join(data)
        self.all_contents[self.storage_index] = data
        return defer.succeed(self)
    def _get_initial_contents(self, contents):
        if contents is None:
            return MutableData(b"")

        if IMutableUploadable.providedBy(contents):
            return contents

        assert callable(contents), "%s should be callable, not %s" % \
               (contents, type(contents))
        return contents(self)
    def init_from_cap(self, filecap):
        """
        Adopt ``filecap`` (an SSK or MDMF cap, read-only or writeable) as this
        node's cap, and record its format under its storage index.

        :return: This node.
        """
        assert isinstance(filecap, (uri.WriteableSSKFileURI,
                                    uri.ReadonlySSKFileURI,
                                    uri.WriteableMDMFFileURI,
                                    uri.ReadonlyMDMFFileURI))
        self.my_uri = filecap
        self.storage_index = self.my_uri.get_storage_index()
        if isinstance(filecap, (uri.WriteableMDMFFileURI,
                                uri.ReadonlyMDMFFileURI)):
            self.file_types[self.storage_index] = MDMF_VERSION

        else:
            self.file_types[self.storage_index] = SDMF_VERSION

        return self
    def get_cap(self):
        return self.my_uri
    def get_readcap(self):
        return self.my_uri.get_readonly()
    def get_uri(self):
        return self.my_uri.to_string()
    def get_write_uri(self):
        if self.is_readonly():
            return None
        return self.my_uri.to_string()
    def get_readonly(self):
        return self.my_uri.get_readonly()
    def get_readonly_uri(self):
        return self.my_uri.get_readonly().to_string()
    def get_verify_cap(self):
        return self.my_uri.get_verify_cap()
    def get_repair_cap(self):
        if self.my_uri.is_readonly():
            return None
        return self.my_uri
    def is_readonly(self):
        return self.my_uri.is_readonly()
    def is_mutable(self):
        return self.my_uri.is_mutable()
    def is_unknown(self):
        return False
    def is_allowed_in_immutable_directory(self):
        return not self.my_uri.is_mutable()
    def raise_error(self):
        pass
    def get_writekey(self):
        return b"\x00"*16
    def get_size(self):
        return len(self.all_contents[self.storage_index])
    def get_current_size(self):
        return self.get_size_of_best_version()
    def get_size_of_best_version(self):
        return defer.succeed(len(self.all_contents[self.storage_index]))

    def get_storage_index(self):
        return self.storage_index

    def get_servermap(self, mode):
        return defer.succeed(None)

    def get_version(self):
        assert self.storage_index in self.file_types
        return self.file_types[self.storage_index]

    def check(self, monitor, verify=False, add_lease=False):
        """
        Pretend to check the file.  The result always describes a healthy
        3-of-10 file; ``monitor``, ``verify`` and ``add_lease`` are ignored.

        :return: A Deferred firing with ``CheckResults``.
        """
        s = StubServer(b"\x00"*20)
        r = CheckResults(self.my_uri, self.storage_index,
                         healthy=True, recoverable=True,
                         count_happiness=10,
                         count_shares_needed=3,
                         count_shares_expected=10,
                         count_shares_good=10,
                         count_good_share_hosts=10,
                         count_recoverable_versions=1,
                         count_unrecoverable_versions=0,
                         servers_responding=[s],
                         sharemap={b"seq1-abcd-sh0": [s]},
                         count_wrong_shares=0,
                         list_corrupt_shares=[],
                         count_corrupt_shares=0,
                         list_incompatible_shares=[],
                         count_incompatible_shares=0,
                         summary="",
                         report=[],
                         share_problems=[],
                         servermap=None)
        return defer.succeed(r)

    def check_and_repair(self, monitor, verify=False, add_lease=False):
        """
        Run ``check`` and report its results as both the pre- and post-repair
        results.
        """
        # Forward every argument in its proper position; ``verify`` is not a
        # monitor.
        d = self.check(monitor, verify, add_lease)
        def _got(cr):
            r = CheckAndRepairResults(self.storage_index)
            r.pre_repair_results = r.post_repair_results = cr
            return r
        d.addCallback(_got)
        return d

    def deep_check(self, verify=False, add_lease=False):
        """
        Wrap the result of ``check`` on this single node in
        ``DeepCheckResults``.
        """
        # No monitor is available here; pass None rather than ``verify``.
        d = self.check(None, verify, add_lease)
        def _done(r):
            dr = DeepCheckResults(self.storage_index)
            dr.add_check(r, [])
            return dr
        d.addCallback(_done)
        return d

    def deep_check_and_repair(self, verify=False, add_lease=False):
        """
        Wrap the result of ``check_and_repair`` on this single node in
        ``DeepCheckAndRepairResults``.
        """
        # No monitor is available here; pass None rather than ``verify``.
        d = self.check_and_repair(None, verify, add_lease)
        def _done(r):
            dr = DeepCheckAndRepairResults(self.storage_index)
            dr.add_check(r, [])
            return dr
        d.addCallback(_done)
        return d

    def download_best_version(self):
        return defer.succeed(self._download_best_version())


    def _download_best_version(self, ignored=None):
        if isinstance(self.my_uri, uri.LiteralFileURI):
            return self.my_uri.data
        if self.storage_index not in self.all_contents:
            raise NotEnoughSharesError(None, 0, 3)
        return self.all_contents[self.storage_index]


    def overwrite(self, new_contents):
        assert not self.is_readonly()
        new_data = new_contents.read(new_contents.get_size())
        new_data = b"".join(new_data)
        self.all_contents[self.storage_index] = new_data
        return defer.succeed(None)
    def modify(self, modifier):
        # this does not implement FileTooLargeError, but the real one does
        return defer.maybeDeferred(self._modify, modifier)
    def _modify(self, modifier):
        assert not self.is_readonly()
        old_contents = self.all_contents[self.storage_index]
        new_data = modifier(old_contents, None, True)
        # Per the IMutableFileNode.modify contract, a modifier may return None
        # to mean "no change"; never store None as the file's contents.
        if new_data is not None:
            self.all_contents[self.storage_index] = new_data
        return None

    # As actually implemented, MutableFilenode and MutableFileVersion
    # are distinct. However, nothing in the webapi uses (yet) that
    # distinction -- it just uses the unified download interface
    # provided by get_best_readable_version and read. When we start
    # doing cooler things like LDMF, we will want to revise this code to
    # be less simplistic.
    def get_best_readable_version(self):
        return defer.succeed(self)


    def get_best_mutable_version(self):
        return defer.succeed(self)

    # Ditto for this, which is an implementation of IWriteable.
    # XXX: Declare that the same is implemented.
    def update(self, data, offset):
        assert not self.is_readonly()
        def modifier(old, servermap, first_time):
            new = old[:offset] + b"".join(data.read(data.get_size()))
            new += old[len(new):]
            return new
        return self.modify(modifier)


    def read(self, consumer, offset=0, size=None):
        """
        Write ``size`` bytes starting at ``offset`` (or everything from
        ``offset`` on, if ``size`` is None) to ``consumer``.

        :return: A Deferred firing with ``consumer``.
        """
        data = self._download_best_version()
        # Test ``size`` against None explicitly: a size of 0 means zero bytes,
        # and ``offset`` must be honored even when no size is given.
        if size is None:
            data = data[offset:]
        else:
            data = data[offset:offset+size]
        consumer.write(data)
        return defer.succeed(consumer)
849
850
def make_mutable_file_cap(
        keypair: tuple[rsa.PublicKey, rsa.PrivateKey] | None = None,
) -> uri.WriteableSSKFileURI:
    """
    Create a local representation of a mutable (SDMF) object.

    :param keypair: The (public, private) keypair of the object; its writekey
        and fingerprint are derived from the DER encodings of these keys.  If
        ``None``, a random writekey and fingerprint are used instead.
    """
    if keypair is not None:
        public_key, private_key = keypair
        writekey = hashutil.ssk_writekey_hash(
            rsa.der_string_from_signing_key(private_key))
        fingerprint = hashutil.ssk_pubkey_fingerprint_hash(
            rsa.der_string_from_verifying_key(public_key))
    else:
        writekey = os.urandom(16)
        fingerprint = os.urandom(32)

    return uri.WriteableSSKFileURI(writekey=writekey, fingerprint=fingerprint)
873
def make_mdmf_mutable_file_cap():
    """
    Return a writeable MDMF cap with a random writekey and fingerprint.
    """
    writekey = os.urandom(16)
    fingerprint = os.urandom(32)
    return uri.WriteableMDMFFileURI(writekey=writekey, fingerprint=fingerprint)
877
def make_mutable_file_uri(mdmf=False):
    """
    Return the serialized form of a fresh random writeable mutable cap.

    :param bool mdmf: Make an MDMF cap if true, otherwise an SDMF one.
    """
    # Deliberately not named ``uri``: that would shadow the ``allmydata.uri``
    # module throughout this function.
    cap = make_mdmf_mutable_file_cap() if mdmf else make_mutable_file_cap()
    return cap.to_string()
885
def make_verifier_uri():
    """
    Return the serialized form of an SSK verifier cap with a random storage
    index and fingerprint.
    """
    verifier = uri.SSKVerifierURI(
        storage_index=os.urandom(16),
        fingerprint=os.urandom(32),
    )
    return verifier.to_string()
889
def create_mutable_filenode(contents, mdmf=False, all_contents=None):
    """
    Create a ``FakeMutableFileNode`` holding ``contents``.

    :param bytes contents: The initial file contents.
    :param bool mdmf: Create an MDMF file if true, otherwise an SDMF one.
    :param dict all_contents: The storage-index-keyed mapping the fake node
        stores its data in.  If omitted, a fresh private dict is used.

    :return: The new ``FakeMutableFileNode``.
    """
    # XXX: All of these arguments are kind of stupid.
    if all_contents is None:
        # The fake node writes its contents into this mapping, so None cannot
        # be passed through to it.
        all_contents = {}

    if mdmf:
        cap = make_mdmf_mutable_file_cap()
        version = MDMF_VERSION
    else:
        cap = make_mutable_file_cap()
        version = SDMF_VERSION

    encoding_params = {
        'k': 3,
        'max_segment_size': 128*1024,
    }

    filenode = FakeMutableFileNode(None, None, encoding_params, None,
                                   all_contents, None)
    filenode.init_from_cap(cap)
    # FakeMutableFileNode.create returns an already-fired Deferred, so the
    # contents are stored by the time it returns.
    filenode.create(MutableData(contents), version=version)
    return filenode
909
910
class LoggingServiceParent(service.MultiService):
    """
    A ``MultiService`` with a ``log`` method that forwards every call to
    ``log.msg`` and returns its result.
    """
    def log(self, *a, **kw):
        return log.msg(*a, **kw)
914
915
# One byte longer than ``Uploader.URI_LIT_SIZE_THRESHOLD`` -- presumably so
# that uploads of it are not turned into literal caps; confirm with users.
TEST_DATA = b"\x02" * (Uploader.URI_LIT_SIZE_THRESHOLD + 1)
917
918
class WebErrorMixin:
    """
    Test-case mixin with helpers for checking HTTP error results.

    The helpers call ``self.fail``, ``self.failUnless``,
    ``self.failUnlessEqual``, ``self.assertEqual`` and ``self.assertIn``, so
    this must be mixed into a test case class that provides them.
    """

    def explain_web_error(self, f):
        """
        Errback that prints the response body of a ``WebError`` failure.

        An error on the server side causes the client-side getPage() to
        return a failure(t.web.error.Error), and its str() doesn't show the
        response body, which is where the useful information lives.  Attach
        this method as an errback handler and it will reveal the hidden
        message.  Failures of any other type are re-raised by ``f.trap``.

        :return: ``f`` unchanged, so the failure keeps propagating.
        """
        f.trap(WebError)
        print("Web Error:", f.value, ":", f.value.response)
        return f

    def _shouldHTTPError(self, res, which, validator):
        """
        Callback/errback for ``shouldHTTPError``: hand a ``WebError`` failure
        to ``validator`` (returning its result), or fail the test if ``res``
        is a success.
        """
        if isinstance(res, failure.Failure):
            res.trap(WebError)
            return validator(res)
        else:
            self.fail("%s was supposed to Error, not get '%s'" % (which, res))

    def shouldHTTPError(self, which,
                        code=None, substring=None, response_substring=None,
                        callable=None, *args, **kwargs):
        """
        Call ``callable(*args, **kwargs)`` and check that it fails with a
        ``WebError``.

        :param which: Label used in failure messages.
        :param code: If not None, the expected status as an int; it is
            compared against ``f.value.status`` rendered as ``b"%d"``.
        :param substring: If given (bytes are decoded as ASCII), must appear
            in ``str()`` of the failure.
        :param response_substring: If given (str is encoded as ASCII), must
            appear in the failure's response body.
        :param callable: The function to call; required.

        :return: A ``Deferred`` that fires with the response body.
        """
        if isinstance(substring, bytes):
            substring = str(substring, "ascii")
        if isinstance(response_substring, str):
            response_substring = response_substring.encode("ascii")
        assert substring is None or isinstance(substring, str)
        assert response_substring is None or isinstance(response_substring, bytes)
        assert callable
        def _validate(f):
            if code is not None:
                self.failUnlessEqual(f.value.status, b"%d" % code, which)
            if substring:
                code_string = str(f)
                self.failUnless(substring in code_string,
                                "%s: substring '%s' not in '%s'"
                                % (which, substring, code_string))
            response_body = f.value.response
            if response_substring:
                self.failUnless(response_substring in response_body,
                                "%r: response substring %r not in %r"
                                % (which, response_substring, response_body))
            return response_body
        d = defer.maybeDeferred(callable, *args, **kwargs)
        d.addBoth(self._shouldHTTPError, which, _validate)
        return d

    @inlineCallbacks
    def assertHTTPError(self, url, code, response_substring,
                        method="get", persistent=False,
                        **args):
        """
        Make an HTTP request with ``treq`` and check the response status.

        :param url: The URL to request.
        :param code: The expected integer status code.
        :param response_substring: If not None, must appear in the response
            body (a str is encoded as UTF-8 first).
        :param method: HTTP method name passed to ``treq.request``.
        :param persistent: Passed to ``treq.request``.
        :param args: Extra keyword arguments for ``treq.request``.

        :return: A ``Deferred`` that fires with the response body bytes.
        """
        response = yield treq.request(method, url, persistent=persistent,
                                      **args)
        body = yield response.content()
        # ``assertEqual`` rather than the deprecated ``assertEquals`` alias,
        # which the standard library removed in Python 3.12.
        self.assertEqual(response.code, code)
        if response_substring is not None:
            if isinstance(response_substring, str):
                response_substring = response_substring.encode("utf-8")
            self.assertIn(response_substring, body)
        # A plain ``return`` works in ``inlineCallbacks`` generators on
        # Python 3; ``returnValue`` is deprecated.
        return body
979
class ErrorMixin(WebErrorMixin):
    """
    ``WebErrorMixin`` plus an errback that reports the first sub-failure
    wrapped by a ``defer.FirstError``.
    """
    def explain_error(self, f):
        """
        Print the wrapped sub-failure if ``f`` is a ``defer.FirstError``.

        :return: ``f`` unchanged, so the failure keeps propagating.
        """
        is_first_error = f.check(defer.FirstError)
        if is_first_error:
            print("First Error:", f.value.subFailure)
        return f
985
def corrupt_field(data, offset, size, debug=False):
    """
    Corrupt the field ``data[offset:offset+size]``.

    With equal probability, either a single bit inside the field is flipped
    (via ``testutil.flip_one_bit``) or the whole field is replaced with
    ``testutil.insecurerandstr(size)``.

    :param data: The bytes to corrupt.
    :param offset: Start of the field.
    :param size: Length of the field in bytes.
    :param debug: If true, log the field before and after corruption.

    :return: A copy of ``data`` with the field corrupted.
    """
    field_end = offset + size
    if random.random() >= 0.5:
        newval = testutil.insecurerandstr(size)
        if debug:
            log.msg("testing: corrupting offset %d, size %d randomizing field, orig: %r, newval: %r" % (offset, size, data[offset:field_end], newval))
        return data[:offset] + newval + data[field_end:]

    newdata = testutil.flip_one_bit(data, offset, size)
    if debug:
        log.msg("testing: corrupting offset %d, size %d flipping one bit orig: %r, newdata: %r" % (offset, size, data[offset:field_end], newdata[offset:field_end]))
    return newdata
997
998def _corrupt_nothing(data, debug=False):
999    """Leave the data pristine. """
1000    return data
1001
def _corrupt_file_version_number(data, debug=False):
    """Scramble the file data -- the 4-byte share file version number at
    file offset 0x00 will have one bit flipped or else will be changed to a
    random value."""
    # Forward ``debug`` so a caller asking for debug logging gets it.
    return corrupt_field(data, 0x00, 4, debug=debug)
1006
def _corrupt_size_of_file_data(data, debug=False):
    """Scramble the file data -- the 4-byte field at file offset 0x04 showing
    the size of the share data within the file will have one bit flipped or
    else will be changed to a random value."""
    # Forward ``debug`` so a caller asking for debug logging gets it.
    return corrupt_field(data, 0x04, 4, debug=debug)
1011
def _corrupt_sharedata_version_number(data, debug=False):
    """Scramble the file data -- the share data version number will have one
    bit flipped or else will be changed to a random value, but not 1 or 2.

    The version number is the 4-byte big-endian field at file offset 0x0c.
    Flipping a single bit of 1 or 2 never yields 1 or 2, but a random
    replacement could, so corruption is retried until the new version is
    neither.
    """
    while True:
        newdata = corrupt_field(data, 0x0c, 4, debug=debug)
        newsharevernum = struct.unpack(">L", newdata[0x0c:0x0c+4])[0]
        if newsharevernum not in (1, 2):
            return newdata
1023
1024def _corrupt_sharedata_version_number_to_plausible_version(data, debug=False):
1025    """Scramble the file data -- the share data version number will be
1026    changed to 2 if it is 1 or else to 1 if it is 2."""
1027    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
1028    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
1029    if sharevernum == 1:
1030        newsharevernum = 2
1031    else:
1032        newsharevernum = 1
1033    newsharevernumbytes = struct.pack(">L", newsharevernum)
1034    return data[:0x0c] + newsharevernumbytes + data[0x0c+4:]
1035
def _corrupt_segment_size(data, debug=False):
    """Scramble the file data -- the field showing the size of the segment
    will have one bit flipped or else be changed to a random value.

    The field is at share offset 0x04 (file offset 0x0c+0x04); it is 4 bytes
    wide in v1 shares and 8 bytes wide in v2 shares."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` instead of hard-coding False.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x04, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x04, 8, debug=debug)
1045
def _corrupt_size_of_sharedata(data, debug=False):
    """Scramble the file data -- the field showing the size of the data
    within the share data will have one bit flipped or else will be changed
    to a random value.

    The field is at share offset 0x08 (4 bytes) in v1 shares and at share
    offset 0x0c (8 bytes) in v2 shares; share offsets are relative to file
    offset 0x0c."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` so a caller asking for debug logging gets it.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x08, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x0c, 8, debug=debug)
1056
def _corrupt_offset_of_sharedata(data, debug=False):
    """Scramble the file data -- the field showing the offset of the data
    within the share data will have one bit flipped or else be changed to a
    random value.

    The field is at share offset 0x0c (4 bytes) in v1 shares and at share
    offset 0x14 (8 bytes) in v2 shares; share offsets are relative to file
    offset 0x0c."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` so a caller asking for debug logging gets it.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x0c, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x14, 8, debug=debug)
1067
def _corrupt_offset_of_ciphertext_hash_tree(data, debug=False):
    """Scramble the file data -- the field showing the offset of the
    ciphertext hash tree within the share data will have one bit flipped or
    else be changed to a random value.

    The field is at share offset 0x14 (4 bytes) in v1 shares and at share
    offset 0x24 (8 bytes) in v2 shares; share offsets are relative to file
    offset 0x0c.
    """
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` instead of hard-coding False.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x14, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x24, 8, debug=debug)
1079
def _corrupt_offset_of_block_hashes(data, debug=False):
    """Scramble the file data -- the field showing the offset of the block
    hash tree within the share data will have one bit flipped or else will be
    changed to a random value.

    The field is at share offset 0x18 (4 bytes) in v1 shares and at share
    offset 0x2c (8 bytes) in v2 shares; share offsets are relative to file
    offset 0x0c."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` so a caller asking for debug logging gets it.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x18, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x2c, 8, debug=debug)
1090
def _corrupt_offset_of_block_hashes_to_truncate_crypttext_hashes(data, debug=False):
    """Scramble the file data -- the field showing the offset of the block
    hash tree within the share data is replaced by a smaller whole multiple
    of the hash size, so the downloader fetches an incomplete crypttext hash
    tree.

    The new value is ``k * hashutil.CRYPTO_VAL_SIZE`` with ``k`` drawn from
    ``range(max(1, (curval // CRYPTO_VAL_SIZE) // 2))``.
    """
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # File position, struct format and width of the block-hash offset field.
    if sharevernum == 1:
        fieldpos, fmt, width = 0x0c+0x18, ">L", 4
    else:
        fieldpos, fmt, width = 0x0c+0x2c, ">Q", 8
    curval = struct.unpack(fmt, data[fieldpos:fieldpos+width])[0]
    hashsize = hashutil.CRYPTO_VAL_SIZE
    multiple = random.randrange(0, max(1, (curval//hashsize)//2))
    newval = multiple * hashsize
    return data[:fieldpos] + struct.pack(fmt, newval) + data[fieldpos+width:]
1108
def _corrupt_offset_of_share_hashes(data, debug=False):
    """Scramble the file data -- the field showing the offset of the share
    hash tree within the share data will have one bit flipped or else will be
    changed to a random value.

    The field is at share offset 0x1c (4 bytes) in v1 shares and at share
    offset 0x34 (8 bytes) in v2 shares; share offsets are relative to file
    offset 0x0c."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` so a caller asking for debug logging gets it.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x1c, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x34, 8, debug=debug)
1119
def _corrupt_offset_of_uri_extension(data, debug=False):
    """Scramble the file data -- the field showing the offset of the uri
    extension will have one bit flipped or else will be changed to a random
    value.

    The field is at share offset 0x20 (4 bytes) in v1 shares and at share
    offset 0x3c (8 bytes) in v2 shares; share offsets are relative to file
    offset 0x0c."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Forward ``debug`` so a caller asking for debug logging gets it.
    if sharevernum == 1:
        return corrupt_field(data, 0x0c+0x20, 4, debug=debug)
    else:
        return corrupt_field(data, 0x0c+0x3c, 8, debug=debug)
1130
1131def _corrupt_offset_of_uri_extension_to_force_short_read(data, debug=False):
1132    """Scramble the file data -- the field showing the offset of the uri
1133    extension will be set to the size of the file minus 3. This means when
1134    the client tries to read the length field from that location it will get
1135    a short read -- the result string will be only 3 bytes long, not the 4 or
1136    8 bytes necessary to do a successful struct.unpack."""
1137    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
1138    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
1139    # The "-0x0c" in here is to skip the server-side header in the share
1140    # file, which the client doesn't see when seeking and reading.
1141    if sharevernum == 1:
1142        if debug:
1143            log.msg("testing: corrupting offset %d, size %d, changing %d to %d (len(data) == %d)" % (0x2c, 4, struct.unpack(">L", data[0x2c:0x2c+4])[0], len(data)-0x0c-3, len(data)))
1144        return data[:0x2c] + struct.pack(">L", len(data)-0x0c-3) + data[0x2c+4:]
1145    else:
1146        if debug:
1147            log.msg("testing: corrupting offset %d, size %d, changing %d to %d (len(data) == %d)" % (0x48, 8, struct.unpack(">Q", data[0x48:0x48+8])[0], len(data)-0x0c-3, len(data)))
1148        return data[:0x48] + struct.pack(">Q", len(data)-0x0c-3) + data[0x48+8:]
1149
def _corrupt_mutable_share_data(data, debug=False):
    """Scramble the file data of a v0 SDMF mutable share -- the bytes from
    the header's "share_data" offset up to its "enc_privkey" offset will have
    one bit flipped or else be changed to a random value.
    """
    prefix = data[:32]
    assert MutableShareFile.is_valid_header(prefix), "This function is designed to corrupt mutable shares of v1, and the magic number doesn't look right: %r vs %r" % (prefix, MutableShareFile.MAGIC)
    data_offset = MutableShareFile.DATA_OFFSET
    sharetype = data[data_offset:data_offset+1]
    assert sharetype == b"\x00", "non-SDMF mutable shares not supported"
    (version, ig_seqnum, ig_roothash, ig_IV, ig_k, ig_N, ig_segsize,
     ig_datalen, offsets) = unpack_header(data[data_offset:])
    assert version == 0, "this function only handles v0 SDMF files"
    # The header offsets were parsed from ``data[data_offset:]``, so add
    # ``data_offset`` to turn them into positions within ``data``.
    start = data_offset + offsets["share_data"]
    length = data_offset + offsets["enc_privkey"] - start
    # Forward ``debug`` so a caller asking for debug logging gets it.
    return corrupt_field(data, start, length, debug=debug)
1162
def _corrupt_share_data(data, debug=False):
    """Scramble the file data -- the field containing the share data itself
    will have one bit flipped or else will be changed to a random value.

    The share data follows the share data header (0x24 bytes in v1, 0x44
    bytes in v2), which starts at file offset 0x0c.  Its length is read from
    the header's data-size field (4 bytes at share offset 0x08 in v1, 8
    bytes at share offset 0x0c in v2)."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways, not v%d." % sharevernum
    if sharevernum == 1:
        sharedatasize = struct.unpack(">L", data[0x0c+0x08:0x0c+0x08+4])[0]

        return corrupt_field(data, 0x0c+0x24, sharedatasize, debug=debug)
    else:
        # Read exactly the 8-byte size field.  The previous slice
        # ``[0x0c+0x08:0x0c+0x0c+8]`` was 12 bytes long, so unpack raised.
        sharedatasize = struct.unpack(">Q", data[0x0c+0x0c:0x0c+0x0c+8])[0]

        return corrupt_field(data, 0x0c+0x44, sharedatasize, debug=debug)
1176
1177def _corrupt_share_data_last_byte(data, debug=False):
1178    """Scramble the file data -- flip all bits of the last byte."""
1179    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
1180    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways, not v%d." % sharevernum
1181    if sharevernum == 1:
1182        sharedatasize = struct.unpack(">L", data[0x0c+0x08:0x0c+0x08+4])[0]
1183        offset = 0x0c+0x24+sharedatasize-1
1184    else:
1185        sharedatasize = struct.unpack(">Q", data[0x0c+0x08:0x0c+0x0c+8])[0]
1186        offset = 0x0c+0x44+sharedatasize-1
1187
1188    newdata = data[:offset] + byteschr(ord(data[offset:offset+1])^0xFF) + data[offset+1:]
1189    if debug:
1190        log.msg("testing: flipping all bits of byte at offset %d: %r, newdata: %r" % (offset, data[offset], newdata[offset]))
1191    return newdata
1192
def _corrupt_crypttext_hash_tree(data, debug=False):
    """Scramble the file data -- the crypttext hash tree (the bytes from its
    offset up to the block hashes offset, both read from the share header)
    will have one bit flipped or else will be changed to a random value.
    """
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    # Struct format, field width, and the file positions of the crypttext
    # hash tree offset field and the block hashes offset field.
    if sharevernum == 1:
        fmt, width, treepos, blockpos = ">L", 4, 0x0c+0x14, 0x0c+0x18
    else:
        fmt, width, treepos, blockpos = ">Q", 8, 0x0c+0x24, 0x0c+0x2c
    crypttexthashtreeoffset = struct.unpack(fmt, data[treepos:treepos+width])[0]
    blockhashesoffset = struct.unpack(fmt, data[blockpos:blockpos+width])[0]

    treelen = blockhashesoffset - crypttexthashtreeoffset
    return corrupt_field(data, 0x0c+crypttexthashtreeoffset, treelen, debug=debug)
1207
1208def _corrupt_crypttext_hash_tree_byte_x221(data, debug=False):
1209    """Scramble the file data -- the byte at offset 0x221 will have its 7th
1210    (b1) bit flipped.
1211    """
1212    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
1213    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
1214    if debug:
1215        log.msg("original data: %r" % (data,))
1216    return data[:0x0c+0x221] + byteschr(ord(data[0x0c+0x221:0x0c+0x221+1])^0x02) + data[0x0c+0x2210+1:]
1217
def _corrupt_block_hashes(data, debug=False):
    """Scramble the file data -- the block hash tree (the bytes from its
    offset up to the share hashes offset, both read from the share header)
    will have one bit flipped or else will be changed to a random value.
    """
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    if sharevernum == 1:
        blockhashesoffset = struct.unpack(">L", data[0x0c+0x18:0x0c+0x18+4])[0]
        sharehashesoffset = struct.unpack(">L", data[0x0c+0x1c:0x0c+0x1c+4])[0]
    else:
        blockhashesoffset = struct.unpack(">Q", data[0x0c+0x2c:0x0c+0x2c+8])[0]
        sharehashesoffset = struct.unpack(">Q", data[0x0c+0x34:0x0c+0x34+8])[0]

    # Header offsets are relative to the share data, which starts at 0x0c.
    # Forward ``debug`` so a caller asking for debug logging gets it.
    return corrupt_field(data, 0x0c+blockhashesoffset, sharehashesoffset-blockhashesoffset, debug=debug)
1232
def _corrupt_share_hashes(data, debug=False):
    """Scramble the file data -- the share hash chain (the bytes from its
    offset up to the uri extension offset, both read from the share header)
    will have one bit flipped or else will be changed to a random value.
    """
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    if sharevernum == 1:
        sharehashesoffset = struct.unpack(">L", data[0x0c+0x1c:0x0c+0x1c+4])[0]
        uriextoffset = struct.unpack(">L", data[0x0c+0x20:0x0c+0x20+4])[0]
    else:
        sharehashesoffset = struct.unpack(">Q", data[0x0c+0x34:0x0c+0x34+8])[0]
        uriextoffset = struct.unpack(">Q", data[0x0c+0x3c:0x0c+0x3c+8])[0]

    # Header offsets are relative to the share data, which starts at 0x0c.
    # Forward ``debug`` so a caller asking for debug logging gets it.
    return corrupt_field(data, 0x0c+sharehashesoffset, uriextoffset-sharehashesoffset, debug=debug)
1247
def _corrupt_length_of_uri_extension(data, debug=False):
    """Scramble the file data -- the field showing the length of the uri
    extension will have one bit flipped or else will be changed to a random
    value.

    The uri extension offset read from the header is relative to the share
    data, which starts 0x0c bytes into the file, so 0x0c is added for both
    share versions.  The length field is 4 bytes in v1 shares and 8 bytes in
    v2 shares."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    if sharevernum == 1:
        uriextoffset = struct.unpack(">L", data[0x0c+0x20:0x0c+0x20+4])[0]
        # Previously missing the 0x0c file-header adjustment, so this
        # corrupted 12 bytes before the length field.
        return corrupt_field(data, 0x0c+uriextoffset, 4, debug=debug)
    else:
        uriextoffset = struct.unpack(">Q", data[0x0c+0x3c:0x0c+0x3c+8])[0]
        return corrupt_field(data, 0x0c+uriextoffset, 8, debug=debug)
1260
def _corrupt_uri_extension(data, debug=False):
    """Scramble the file data -- the field containing the uri extension will
    have one bit flipped or else will be changed to a random value."""
    sharevernum = struct.unpack(">L", data[0x0c:0x0c+4])[0]
    assert sharevernum in (1, 2), "This test is designed to corrupt immutable shares of v1 or v2 in specific ways."
    if sharevernum == 1:
        uriextoffset = struct.unpack(">L", data[0x0c+0x20:0x0c+0x20+4])[0]
        uriextlen = struct.unpack(">L", data[0x0c+uriextoffset:0x0c+uriextoffset+4])[0]
    else:
        uriextoffset = struct.unpack(">Q", data[0x0c+0x3c:0x0c+0x3c+8])[0]
        uriextlen = struct.unpack(">Q", data[0x0c+uriextoffset:0x0c+uriextoffset+8])[0]

    # NOTE(review): the corrupted range starts at the length prefix that
    # ``uriextlen`` was just read from and is ``uriextlen`` bytes long.  It
    # therefore covers the prefix plus all but the last 4 (v1) / 8 (v2) bytes
    # of the extension body.  Confirm whether the start should skip the
    # prefix.
    # Forward ``debug`` so a caller asking for debug logging gets it.
    return corrupt_field(data, 0x0c+uriextoffset, uriextlen, debug=debug)
1274
1275
1276
@attr.s
@implementer(IAddressFamily)
class ConstantAddresses:
    """
    Pretend to provide support for some address family but just hand out
    canned responses.

    ``get_listener`` returns the ``_listener`` attribute and
    ``get_client_endpoint`` returns the ``_handler`` attribute.  Each raises
    ``Exception`` if its attribute is ``None`` (the default).
    """
    _listener = attr.ib(default=None)
    _handler = attr.ib(default=None)

    def get_listener(self):
        """Return the canned listener, or raise if none was supplied."""
        if self._listener is None:
            # Interpolate ``self``; the message previously contained a
            # literal "{!r}".
            raise Exception("{!r} has no listener.".format(self))
        return self._listener

    def get_client_endpoint(self):
        """Return the canned client endpoint, or raise if none was
        supplied."""
        if self._handler is None:
            raise Exception("{!r} has no client endpoint.".format(self))
        return self._handler
1296
@contextmanager
def disable_modules(*names):
    """
    A context manager which makes modules appear to be missing while it is
    active.

    Each name is mapped to ``None`` in ``sys.modules`` so that importing it
    fails.  On exit, including exit through an exception, each entry is
    restored to its previous value, or removed if it was absent before.

    :param *names: The names of the modules to disappear.  Only top-level
        modules are supported (that is, "." is not allowed in any names).
        This is an implementation shortcoming which could be lifted if
        desired.

    :raise ValueError: If any name contains ".".
    """
    if any("." in name for name in names):
        raise ValueError("Names containing '.' are not supported.")
    missing = object()
    modules = list(sys.modules.get(n, missing) for n in names)
    try:
        for n in names:
            sys.modules[n] = None
        yield
    finally:
        # Always restore, even if the body raised.  ``pop`` tolerates a name
        # listed twice (or never set) where ``del`` would raise KeyError.
        for n, original in zip(names, modules):
            if original is missing:
                sys.modules.pop(n, None)
            else:
                sys.modules[n] = original
1320
1321class _TestCaseMixin:
1322    """
1323    A mixin for ``TestCase`` which collects helpful behaviors for subclasses.
1324
1325    Those behaviors are:
1326
1327    * All of the features of testtools TestCase.
1328    * Each test method will be run in a unique Eliot action context which
1329      identifies the test and collects all Eliot log messages emitted by that
1330      test (including setUp and tearDown messages).
1331    * trial-compatible mktemp method
1332    * unittest2-compatible assertRaises helper
1333    * Automatic cleanup of tempfile.tempdir mutation (once pervasive through
1334      the Tahoe-LAFS test suite, perhaps gone now but someone should verify
1335      this).
1336    """
1337    def setUp(self):
1338        # Restore the original temporary directory.  Node ``init_tempdir``
1339        # mangles it and many tests manage to get that method called.
1340        self.addCleanup(
1341            partial(setattr, tempfile, "tempdir", tempfile.tempdir),
1342        )
1343        return super(_TestCaseMixin, self).setUp()
1344
1345    class _DummyCase(_case.TestCase):
1346        def dummy(self):
1347            pass
1348    _dummyCase = _DummyCase("dummy")
1349
1350    def mktemp(self):
1351        return mktemp()
1352
1353    def assertRaises(self, *a, **kw):
1354        return self._dummyCase.assertRaises(*a, **kw)
1355
1356    def failUnless(self, *args, **kwargs):
1357        """Backwards compatibility method."""
1358        self.assertTrue(*args, **kwargs)
1359
1360    def failIf(self, *args, **kwargs):
1361        """Backwards compatibility method."""
1362        self.assertFalse(*args, **kwargs)
1363
1364    def failIfEqual(self, *args, **kwargs):
1365        """Backwards compatibility method."""
1366        self.assertNotEqual(*args, **kwargs)
1367
1368    def failUnlessEqual(self, *args, **kwargs):
1369        """Backwards compatibility method."""
1370        self.assertEqual(*args, **kwargs)
1371
1372    def failUnlessReallyEqual(self, *args, **kwargs):
1373        """Backwards compatibility method."""
1374        self.assertReallyEqual(*args, **kwargs)
1375
1376
class SyncTestCase(_TestCaseMixin, TestCase):
    """
    Test case for tests that may return a ``Deferred`` which has already
    fired.  Tests run under ``SynchronousDeferredRunTest`` inside
    ``EliotLoggedRunTest``.
    """
    run_tests_with = EliotLoggedRunTest.make_factory(SynchronousDeferredRunTest)
1385
1386
class AsyncTestCase(_TestCaseMixin, TestCase):
    """
    Test case for tests that may return a ``Deferred`` which only fires once
    the global reactor is running.  Tests run under
    ``AsynchronousDeferredRunTest`` (timeout 60 seconds) inside
    ``EliotLoggedRunTest``.
    """
    run_tests_with = EliotLoggedRunTest.make_factory(
        AsynchronousDeferredRunTest.make_factory(timeout=60.0)
    )
1395
1396
class AsyncBrokenTestCase(_TestCaseMixin, TestCase):
    """
    Like ``AsyncTestCase``, but runs under
    ``AsynchronousDeferredRunTestForBrokenTwisted`` (timeout 60 seconds),
    which spins the reactor a little longer than apparently necessary to
    clean out lingering unaccounted-for event sources.

    Tests which need this are broken and should be fixed so they pass with
    ``AsyncTestCase``.
    """
    run_tests_with = EliotLoggedRunTest.make_factory(
        AsynchronousDeferredRunTestForBrokenTwisted.make_factory(timeout=60.0)
    )
1409
1410
class TrialTestCase(_TrialTestCase):
    """
    A ``twisted.trial.unittest.TestCase`` used by the Tahoe-LAFS test suite.

    It currently adds no behaviour of its own.  The Python 2 workaround it
    used to document (passing a bytes ``msg`` to ``fail``) no longer
    applies, and ``fail`` simply delegates to trial.
    """

    def fail(self, msg):
        """
        Fail the current test with ``msg``.

        This delegates unchanged to trial's ``fail``; the override is kept
        so that the signature (``msg`` is required) stays as before.
        """
        return super().fail(msg)
Note: See TracBrowser for help on using the repository browser.