Skip to content

Commit 0af8418

Browse files
committed
Fix #4
Switched implementation to mmap. All underlying reads are mmap-based so duplication of file pointers is gone and multiple threads can simultaneously access the compound document. Note that #1 is not fixed by this as we still read the entire FAT into memory upon initial access.
1 parent 5ea8fa0 commit 0af8418

File tree

1 file changed

+98
-85
lines changed

1 file changed

+98
-85
lines changed

compoundfiles/__init__.py

+98-85
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,13 @@
7676
import struct as st
7777
import logging
7878
import warnings
79+
import mmap
80+
import tempfile
81+
import shutil
7982
import datetime as dt
83+
from abc import abstractmethod
8084
from pprint import pformat
8185
from array import array
82-
from mmap import mmap
8386

8487

8588
__all__ = [
@@ -219,16 +222,6 @@ class CompoundFileStream(io.RawIOBase):
219222
directly, but are returned by the :meth:`CompoundFileReader.open` method.
220223
They support all common methods associated with read-only streams
221224
(:meth:`read`, :meth:`seek`, :meth:`tell`, and so forth).
222-
223-
.. note::
224-
225-
The implementation attempts to duplicate the parent object's file
226-
descriptor upon construction which theoretically means multiple threads
227-
can simultaneously read different files in the compound document.
228-
However, if duplication of the file descriptor fails for any reason,
229-
the implementation falls back on sharing the parent object's file
230-
descriptor. In this case, thread safety is not guaranteed. Check the
231-
:attr:`thread_safe` attribute to determine if duplication succeeded.
232225
"""
233226
def __init__(self):
234227
super(CompoundFileStream, self).__init__()
@@ -254,25 +247,9 @@ def _load_sectors(self, start, fat):
254247
raise CompoundFileError(
255248
'cyclic FAT chain found starting at %d' % start)
256249

250+
@abstractmethod
257251
def _set_pos(self, value):
258-
self._sector_index = value // self._sector_size
259-
self._sector_offset = value % self._sector_size
260-
if self._sector_index < len(self._sectors):
261-
self._file.seek(
262-
self._header_size +
263-
(self._sectors[self._sector_index] * self._sector_size) +
264-
self._sector_offset)
265-
266-
def close(self):
267-
"""
268-
Close the file pointer.
269-
"""
270-
if self.thread_safe:
271-
try:
272-
self._file.close()
273-
except AttributeError:
274-
pass
275-
self._file = None
252+
raise NotImplementedError
276253

277254
def readable(self):
278255
"""
@@ -326,6 +303,7 @@ def seek(self, offset, whence=io.SEEK_SET):
326303
self._set_pos(offset)
327304
return offset
328305

306+
@abstractmethod
329307
def read1(self, n=-1):
330308
"""
331309
Read up to *n* bytes from the stream using only a single call to the
@@ -335,28 +313,7 @@ def read1(self, n=-1):
335313
returning the content from the current position up to the end of the
336314
current sector.
337315
"""
338-
if not self.thread_safe:
339-
# If we're sharing a file-pointer with the parent object we can't
340-
# guarantee the file pointer is where we left it, so force a seek
341-
self._set_pos(self.tell())
342-
if n == -1:
343-
n = max(0, self._length - self.tell())
344-
else:
345-
n = max(0, min(n, self._length - self.tell()))
346-
n = min(n, self._sector_size - self._sector_offset)
347-
if n == 0:
348-
return b''
349-
try:
350-
result = self._file.read1(n)
351-
except AttributeError:
352-
result = self._file.read(n)
353-
assert len(result) == n
354-
# Only perform a seek to a different sector if we've crossed into one
355-
if self._sector_offset + n < self._sector_size:
356-
self._sector_offset += n
357-
else:
358-
self._set_pos(self.tell() + n)
359-
return result
316+
raise NotImplementedError
360317

361318
def read(self, n=-1):
362319
"""
@@ -372,14 +329,15 @@ def read(self, n=-1):
372329
n = max(0, self._length - self.tell())
373330
else:
374331
n = max(0, min(n, self._length - self.tell()))
375-
result = b''
376-
while n > 0:
377-
buf = self.read1(n)
332+
result = bytearray(n)
333+
i = 0
334+
while i < n:
335+
buf = self.read1(n - i)
378336
if not buf:
379337
break
380-
n -= len(buf)
381-
result += buf
382-
return result
338+
result[i:i + len(buf)] = buf
339+
i += len(buf)
340+
return bytes(result)
383341

384342

385343
class CompoundFileNormalStream(CompoundFileStream):
@@ -388,15 +346,7 @@ def __init__(self, parent, start, length=None):
388346
self._load_sectors(start, parent._normal_fat)
389347
self._sector_size = parent._normal_sector_size
390348
self._header_size = parent._header_size
391-
try:
392-
fd = os.dup(parent._file.fileno())
393-
except (AttributeError, OSError) as e:
394-
# Share the parent's _file if we fail to duplicate the descriptor
395-
self._file = parent._file
396-
self.thread_safe = False
397-
else:
398-
self._file = io.open(fd, 'rb')
399-
self.thread_safe = True
349+
self._mmap = parent._mmap
400350
min_length = (len(self._sectors) - 1) * self._sector_size
401351
max_length = len(self._sectors) * self._sector_size
402352
if length is None:
@@ -411,6 +361,29 @@ def __init__(self, parent, start, length=None):
411361
self._length = length
412362
self._set_pos(0)
413363

364+
def close(self):
365+
self._mmap = None
366+
367+
def _set_pos(self, value):
368+
self._sector_index = value // self._sector_size
369+
self._sector_offset = value % self._sector_size
370+
371+
def read1(self, n=-1):
372+
if n == -1:
373+
n = max(0, self._length - self.tell())
374+
else:
375+
n = max(0, min(n, self._length - self.tell()))
376+
n = min(n, self._sector_size - self._sector_offset)
377+
if n == 0:
378+
return b''
379+
offset = (
380+
self._header_size + (
381+
self._sectors[self._sector_index] * self._sector_size) +
382+
self._sector_offset)
383+
result = self._mmap[offset:offset + n]
384+
self._set_pos(self.tell() + n)
385+
return result
386+
414387

415388
class CompoundFileMiniStream(CompoundFileStream):
416389
def __init__(self, parent, start, length=None):
@@ -420,7 +393,6 @@ def __init__(self, parent, start, length=None):
420393
self._header_size = 0
421394
self._file = CompoundFileNormalStream(
422395
parent, parent.root._start_sector, parent.root.size)
423-
self.thread_safe = self._file.thread_safe
424396
max_length = len(self._sectors) * self._sector_size
425397
if length is not None and length > max_length:
426398
warnings.warn(
@@ -430,6 +402,37 @@ def __init__(self, parent, start, length=None):
430402
self._length = min(max_length, length or max_length)
431403
self._set_pos(0)
432404

405+
def close(self):
406+
try:
407+
self._file.close()
408+
finally:
409+
self._file = None
410+
411+
def _set_pos(self, value):
412+
self._sector_index = value // self._sector_size
413+
self._sector_offset = value % self._sector_size
414+
if self._sector_index < len(self._sectors):
415+
self._file.seek(
416+
self._header_size +
417+
(self._sectors[self._sector_index] * self._sector_size) +
418+
self._sector_offset)
419+
420+
def read1(self, n=-1):
421+
if n == -1:
422+
n = max(0, self._length - self.tell())
423+
else:
424+
n = max(0, min(n, self._length - self.tell()))
425+
n = min(n, self._sector_size - self._sector_offset)
426+
if n == 0:
427+
return b''
428+
result = self._file.read1(n)
429+
# Only perform a seek to a different sector if we've crossed into one
430+
if self._sector_offset + n < self._sector_size:
431+
self._sector_offset += n
432+
else:
433+
self._set_pos(self.tell() + n)
434+
return result
435+
433436

434437
class CompoundFileReader(object):
435438
"""
@@ -443,8 +446,7 @@ class CompoundFileReader(object):
443446
The class can be constructed with a filename or a file-like object. In the
444447
latter case, the object must support the ``read``, ``seek``, and ``tell``
445448
methods. For optimal usage, it should also provide a valid file descriptor
446-
in response to a call to ``fileno``, and provide a ``read1`` method, but
447-
these are not mandatory.
449+
in response to a call to ``fileno``, but this is not mandatory.
448450
449451
The :attr:`root` attribute represents the root storage entity in the
450452
compound document. An :meth:`open` method is provided which (given a
@@ -481,9 +483,17 @@ def __init__(self, filename_or_obj):
481483
if isinstance(filename_or_obj, (str, bytes)):
482484
self._opened = True
483485
self._file = io.open(filename_or_obj, 'rb')
484-
else:
486+
elif hasattr(filename_or_obj, 'fileno'):
485487
self._opened = False
486488
self._file = filename_or_obj
489+
else:
490+
# It's a file-like object without a valid file descriptor; copy its
491+
# content to a spooled temp file and use that for mmap
492+
filename_or_obj.seek(0)
493+
self._opened = True
494+
self._file = tempfile.SpooledTemporaryFile()
495+
shutil.copyfileobj(filename_or_obj, self._file)
496+
self._mmap = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)
487497

488498
self._master_fat = None
489499
self._normal_fat = None
@@ -507,7 +517,7 @@ def __init__(self, filename_or_obj):
507517
self._mini_sector_count,
508518
self._master_first_sector,
509519
self._master_sector_count,
510-
) = COMPOUND_HEADER.unpack(self._file.read(COMPOUND_HEADER.size))
520+
) = COMPOUND_HEADER.unpack(self._mmap[:COMPOUND_HEADER.size])
511521

512522
# Check the header for basic correctness
513523
if magic != COMPOUND_MAGIC:
@@ -571,8 +581,7 @@ def __init__(self, filename_or_obj):
571581
warnings.warn(
572582
'unused header bytes are non-zero '
573583
'(%r)' % unused, CompoundFileWarning)
574-
self._file.seek(0, io.SEEK_END)
575-
self._file_size = self._file.tell()
584+
self._file_size = self._mmap.size()
576585
self._header_size = max(self._normal_sector_size, 512)
577586
self._max_sector = (self._file_size - self._header_size) // self._normal_sector_size
578587
self._load_normal_fat(self._load_master_fat())
@@ -613,20 +622,25 @@ def open(self, filename_or_entity):
613622
filename_or_entity.size)
614623

615624
def close(self):
616-
if self._opened:
617-
self._file.close()
625+
try:
626+
self._mmap.close()
627+
if self._opened:
628+
self._file.close()
629+
finally:
630+
self._mmap = None
631+
self._file = None
618632

619633
def __enter__(self):
620634
return self
621635

622636
def __exit__(self, exc_type, exc_value, traceback):
623637
self.close()
624638

625-
def _seek_sector(self, sector):
639+
def _read_sector(self, sector):
626640
if sector > self._max_sector:
627-
raise CompoundFileError('seek to invalid sector (%d)' % sector)
628-
self._file.seek(
629-
self._header_size + (sector * self._normal_sector_size))
641+
raise CompoundFileError('read from invalid sector (%d)' % sector)
642+
offset = self._header_size + (sector * self._normal_sector_size)
643+
return self._mmap[offset:offset + self._normal_sector_size]
630644

631645
def _load_master_fat(self):
632646
# Note: when reading the master-FAT we deliberately disregard the
@@ -643,9 +657,9 @@ def _load_master_fat(self):
643657

644658
# Special case: the first 109 entries are stored at the end of the file
645659
# header and the next sector of the master-FAT is stored in the header
646-
self._file.seek(COMPOUND_HEADER.size)
660+
offset = COMPOUND_HEADER.size
647661
self._master_fat.extend(
648-
st.unpack(b'<109L', self._file.read(109 * 4)))
662+
st.unpack(b'<109L', self._mmap[offset:offset + (109 * 4)]))
649663
sector = self._master_first_sector
650664
if count == 0 and sector == FREE_SECTOR:
651665
warnings.warn(
@@ -689,10 +703,9 @@ def _load_master_fat(self):
689703
# last value
690704
count -= 1
691705
sectors.add(sector)
692-
self._seek_sector(sector)
693706
self._master_fat.extend(
694707
self._normal_sector_format.unpack(
695-
self._file.read(self._normal_sector_format.size)))
708+
self._read_sector(sector)))
696709
# Guard against malicious files which could cause excessive memory
697710
# allocation when reading the normal-FAT. If the normal-FAT alone
698711
# would exceed 100Mb of RAM, raise an error
@@ -732,10 +745,9 @@ def _load_normal_fat(self, master_sectors):
732745
# of contiguous sectors? Or make the array lazy-read whenever a block
733746
# needs filling?
734747
for sector in self._master_fat:
735-
self._seek_sector(sector)
736748
self._normal_fat.extend(
737749
self._normal_sector_format.unpack(
738-
self._file.read(self._normal_sector_format.size)))
750+
self._read_sector(sector)))
739751

740752
# The following simply verifies that all normal-FAT and master-FAT
741753
# sectors are marked appropriately in the normal-FAT
@@ -1031,3 +1043,4 @@ def __repr__(self):
10311043
if self.isdir else
10321044
"<CompoundFileEntry ???>"
10331045
)
1046+

0 commit comments

Comments
 (0)