Skip to content

Commit

Permalink
Removed register functions from formatters manager (#3376)
Browse files Browse the repository at this point in the history
  • Loading branch information
joachimmetz authored Jan 1, 2021
1 parent 73228a7 commit 1b32b47
Show file tree
Hide file tree
Showing 16 changed files with 147 additions and 260 deletions.
2 changes: 1 addition & 1 deletion plaso/formatters/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _FormatMessage(self, format_string, event_values):
str: formatted message.
"""
if not isinstance(format_string, str):
logger.warning('Format string: {0:s} is non-Unicode.'.format(
logger.warning('Format string: {0!s} is non-Unicode.'.format(
format_string))

# Plaso code files should be in UTF-8 any thus binary strings are
Expand Down
97 changes: 12 additions & 85 deletions plaso/formatters/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,11 @@
class FormattersManager(object):
"""Class that implements the formatters manager."""

_custom_formatter_helpers = {}
_DEFAULT_FORMATTER = default.DefaultFormatter()

_formatter_classes = {}
_formatter_objects = {}
_custom_formatter_helpers = {}

# Keep track of the data types of the formatters that were read from
# file to prevent re-reading the formatter files during unit tests and
# so that the formatters manager can be reset to hardcoded formatters.
_formatters_from_file = []
_formatters = {}

@classmethod
def _ReadFormattersFile(cls, path):
Expand All @@ -44,34 +40,7 @@ def _ReadFormattersFile(cls, path):
if custom_formatter_helper:
formatter.AddHelper(custom_formatter_helper)

# TODO: refactor RegisterFormatter to only use formatter objects.
cls.RegisterFormatter(formatter)

cls._formatter_objects[data_type] = formatter

cls._formatters_from_file.append(data_type)

@classmethod
def DeregisterFormatter(cls, formatter_class):
"""Deregisters a formatter class.
The formatter classes are identified based on their lower case data type.
Args:
formatter_class (type): class of the formatter.
Raises:
KeyError: if formatter class is not set for the corresponding data type.
"""
data_type = formatter_class.DATA_TYPE.lower()
if data_type not in cls._formatter_classes:
raise KeyError('Formatter class not set for data type: {0:s}.'.format(
formatter_class.DATA_TYPE))

del cls._formatter_classes[data_type]

if data_type in cls._formatter_objects:
del cls._formatter_objects[data_type]
cls._formatters[data_type] = formatter

@classmethod
def GetFormatterObject(cls, data_type):
Expand All @@ -85,23 +54,13 @@ def GetFormatterObject(cls, data_type):
not available.
"""
data_type = data_type.lower()
if data_type not in cls._formatter_objects:
formatter_object = None

if data_type in cls._formatter_classes:
formatter_class = cls._formatter_classes[data_type]
# TODO: remove the need to instantiate the Formatter classes
# and use class methods only.
formatter_object = formatter_class()

if not formatter_object:
logger.warning('Using default formatter for data type: {0:s}'.format(
data_type))
formatter_object = default.DefaultFormatter()

cls._formatter_objects[data_type] = formatter_object
formatter_object = cls._formatters.get(data_type, None)
if not formatter_object:
logger.warning('Using default formatter for data type: {0:s}'.format(
data_type))
formatter_object = cls._DEFAULT_FORMATTER

return cls._formatter_objects[data_type]
return formatter_object

@classmethod
def ReadFormattersFromDirectory(cls, path):
Expand All @@ -115,7 +74,7 @@ def ReadFormattersFromDirectory(cls, path):
KeyError: if formatter class is already set for the corresponding
data type.
"""
if not cls._formatters_from_file:
if not cls._formatters:
for formatters_file_path in glob.glob(os.path.join(path, '*.yaml')):
cls._ReadFormattersFile(formatters_file_path)

Expand All @@ -130,7 +89,7 @@ def ReadFormattersFromFile(cls, path):
KeyError: if formatter class is already set for the corresponding
data type.
"""
if not cls._formatters_from_file:
if not cls._formatters:
cls._ReadFormattersFile(path)

@classmethod
Expand Down Expand Up @@ -171,35 +130,3 @@ def RegisterEventFormatterHelpers(cls, formatter_helper_classes):
"""
for formatter_helper_class in formatter_helper_classes:
cls.RegisterEventFormatterHelper(formatter_helper_class)

@classmethod
def RegisterFormatter(cls, formatter_class):
"""Registers a formatter class.
The formatter classes are identified based on their lower case data type.
Args:
formatter_class (type): class of the formatter.
Raises:
KeyError: if formatter class is already set for the corresponding
data type.
"""
data_type = formatter_class.DATA_TYPE.lower()
if data_type in cls._formatter_classes:
raise KeyError('Formatter class already set for data type: {0:s}.'.format(
formatter_class.DATA_TYPE))

cls._formatter_classes[data_type] = formatter_class

@classmethod
def Reset(cls):
"""Resets the manager to the hardcoded formatter classes.
This method is used during unit testing.
"""
for data_type in cls._formatters_from_file:
formatter_class = cls._formatter_objects[data_type]
cls.DeregisterFormatter(formatter_class)

cls._formatters_from_file = []
5 changes: 5 additions & 0 deletions test_data/formatters/format_test.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# YAML-based formatters file for testing format features.

type: 'basic'
data_type: 'test:event'
message: '{text}'
short_message: '{text}'
---
type: 'conditional'
data_type: 'test:fs:stat'
message:
Expand Down
22 changes: 10 additions & 12 deletions tests/formatters/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ class BrokenConditionalEventFormatter(interface.ConditionalEventFormatter):
DATA_TYPE = 'test:broken_conditional'
FORMAT_STRING_PIECES = ['{too} {many} formatting placeholders']

SOURCE_SHORT = 'LOG'
SOURCE_LONG = 'Some Text File.'


class ConditionalTestEventFormatter(interface.ConditionalEventFormatter):
"""A test conditional event formatter."""
Expand All @@ -34,8 +31,12 @@ class ConditionalTestEventFormatter(interface.ConditionalEventFormatter):
'Optional: {optional}',
'Text: {text}']

SOURCE_SHORT = 'LOG'
SOURCE_LONG = 'Some Text File.'

class TestEventFormatter(interface.EventFormatter):
"""Test event formatter."""

DATA_TYPE = 'test:event'
FORMAT_STRING = '{text}'


class WrongEventFormatter(interface.EventFormatter):
Expand All @@ -44,9 +45,6 @@ class WrongEventFormatter(interface.EventFormatter):

FORMAT_STRING = 'This format string does not match {body}.'

SOURCE_SHORT = 'FILE'
SOURCE_LONG = 'Weird Log File'


class BooleanEventFormatterHelperTest(test_lib.EventFormatterTestCase):
"""Tests for the boolean event formatter helper."""
Expand Down Expand Up @@ -109,15 +107,15 @@ class EventFormatterTest(test_lib.EventFormatterTestCase):

def testInitialization(self):
"""Tests the initialization."""
event_formatter = test_lib.TestEventFormatter()
event_formatter = TestEventFormatter()
self.assertIsNotNone(event_formatter)

# TODO: add tests for _FormatMessage
# TODO: add tests for _FormatMessages

def testGetFormatStringAttributeNames(self):
"""Tests the GetFormatStringAttributeNames function."""
event_formatter = test_lib.TestEventFormatter()
event_formatter = TestEventFormatter()

expected_attribute_names = ['text']

Expand All @@ -126,7 +124,7 @@ def testGetFormatStringAttributeNames(self):

def testGetMessage(self):
"""Tests the GetMessage function."""
event_formatter = test_lib.TestEventFormatter()
event_formatter = TestEventFormatter()

_, event_data, _ = containers_test_lib.CreateEventFromValues(
self._TEST_EVENTS[0])
Expand All @@ -139,7 +137,7 @@ def testGetMessage(self):

def testGetMessageShort(self):
"""Tests the GetMessageShort function."""
event_formatter = test_lib.TestEventFormatter()
event_formatter = TestEventFormatter()

_, event_data, _ = containers_test_lib.CreateEventFromValues(
self._TEST_EVENTS[0])
Expand Down
59 changes: 10 additions & 49 deletions tests/formatters/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from plaso.lib import definitions

from tests import test_lib as shared_test_lib
from tests.formatters import test_lib


class FormattersManagerTest(shared_test_lib.BaseTestCase):
Expand Down Expand Up @@ -80,71 +79,33 @@ def testReadFormattersFile(self):
test_file_path = self._GetTestFilePath(['formatters', 'format_test.yaml'])
self._SkipIfPathNotExists(test_file_path)

manager.FormattersManager.Reset()
number_of_formatters = len(manager.FormattersManager._formatter_classes)

manager.FormattersManager._formatters = {}
manager.FormattersManager._ReadFormattersFile(test_file_path)
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters + 1)
self.assertEqual(len(manager.FormattersManager._formatters), 2)

manager.FormattersManager.Reset()
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters)
manager.FormattersManager._formatters = {}

def testReadFormattersFromDirectory(self):
"""Tests the ReadFormattersFromDirectory function."""
test_directory_path = self._GetTestFilePath(['formatters'])
self._SkipIfPathNotExists(test_directory_path)

manager.FormattersManager.Reset()
number_of_formatters = len(manager.FormattersManager._formatter_classes)

manager.FormattersManager._formatters = {}
manager.FormattersManager.ReadFormattersFromDirectory(test_directory_path)
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters + 1)
self.assertEqual(len(manager.FormattersManager._formatters), 2)

manager.FormattersManager.Reset()
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters)
manager.FormattersManager._formatters = {}

def testReadFormattersFromFile(self):
"""Tests the ReadFormattersFromFile function."""
test_file_path = self._GetTestFilePath(['formatters', 'format_test.yaml'])
self._SkipIfPathNotExists(test_file_path)

manager.FormattersManager.Reset()
number_of_formatters = len(manager.FormattersManager._formatter_classes)

manager.FormattersManager._formatters = {}
manager.FormattersManager.ReadFormattersFromFile(test_file_path)
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters + 1)

manager.FormattersManager.Reset()
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters)

def testFormatterRegistration(self):
"""Tests the RegisterFormatter and DeregisterFormatter functions."""
number_of_formatters = len(manager.FormattersManager._formatter_classes)

manager.FormattersManager.RegisterFormatter(test_lib.TestEventFormatter)
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters + 1)

with self.assertRaises(KeyError):
manager.FormattersManager.RegisterFormatter(test_lib.TestEventFormatter)

manager.FormattersManager.DeregisterFormatter(test_lib.TestEventFormatter)
self.assertEqual(
len(manager.FormattersManager._formatter_classes),
number_of_formatters)
self.assertEqual(len(manager.FormattersManager._formatters), 2)

manager.FormattersManager._formatters = {}


if __name__ == '__main__':
Expand Down
12 changes: 0 additions & 12 deletions tests/formatters/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,9 @@

from __future__ import unicode_literals

from plaso.formatters import interface

from tests import test_lib as shared_test_lib


class TestEventFormatter(interface.EventFormatter):
"""Test event formatter."""

DATA_TYPE = 'test:event'
FORMAT_STRING = '{text}'

SOURCE_SHORT = 'FILE'
SOURCE_LONG = 'Test log file'


class EventFormatterTestCase(shared_test_lib.BaseTestCase):
"""The unit test case for an event formatter."""

Expand Down
7 changes: 4 additions & 3 deletions tests/formatters/yaml_formatters_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def testReadFromFileObject(self):
with io.open(test_file_path, 'r', encoding='utf-8') as file_object:
formatters = list(test_formatters_file._ReadFromFileObject(file_object))

self.assertEqual(len(formatters), 1)
self.assertEqual(len(formatters), 2)

def testReadFromFile(self):
"""Tests the ReadFromFile function."""
Expand All @@ -94,9 +94,10 @@ def testReadFromFile(self):

formatters = test_formatters_file.ReadFromFile(test_file_path)

self.assertEqual(len(formatters), 1)
self.assertEqual(len(formatters), 2)

self.assertEqual(formatters[0].DATA_TYPE, 'test:fs:stat')
self.assertEqual(formatters[0].DATA_TYPE, 'test:event')
self.assertEqual(formatters[1].DATA_TYPE, 'test:fs:stat')


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 1b32b47

Please sign in to comment.