Skip to content

Commit 8bddeec

Browse files
authored
Add custom formatter for Fire result (#345)
Fixes #344 (see issue for more details) This lets you define a function that will take the result from the Fire component and allows the user to alter it before fire looks at it to render it.
1 parent 8469e48 commit 8bddeec

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

fire/core.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def main(argv):
7878
import asyncio # pylint: disable=import-error,g-import-not-at-top # pytype: disable=import-error
7979

8080

81-
def Fire(component=None, command=None, name=None):
81+
def Fire(component=None, command=None, name=None, serialize=None):
8282
"""This function, Fire, is the main entrypoint for Python Fire.
8383
8484
Executes a command either from the `command` argument or from sys.argv by
@@ -164,7 +164,7 @@ def Fire(component=None, command=None, name=None):
164164
raise FireExit(0, component_trace)
165165

166166
# The command succeeded normally; print the result.
167-
_PrintResult(component_trace, verbose=component_trace.verbose)
167+
_PrintResult(component_trace, verbose=component_trace.verbose, serialize=serialize)
168168
result = component_trace.GetResult()
169169
return result
170170

@@ -241,12 +241,19 @@ def _IsHelpShortcut(component_trace, remaining_args):
241241
return show_help
242242

243243

244-
def _PrintResult(component_trace, verbose=False):
244+
def _PrintResult(component_trace, verbose=False, serialize=None):
245245
"""Prints the result of the Fire call to stdout in a human readable way."""
246246
# TODO(dbieber): Design human readable deserializable serialization method
247247
# and move serialization to its own module.
248248
result = component_trace.GetResult()
249249

250+
# Allow users to modify the return value of the component and provide
251+
# custom formatting.
252+
if serialize:
253+
if not callable(serialize):
254+
raise FireError("serialize argument {} must be empty or callable.".format(serialize))
255+
result = serialize(result)
256+
250257
if value_types.HasCustomStr(result):
251258
# If the object has a custom __str__ method, rather than one inherited from
252259
# object, then we use that to serialize the object.

fire/core_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,30 @@ def testClassMethod(self):
194194
7,
195195
)
196196

197+
def testCustomSerialize(self):
198+
def serialize(x):
199+
if isinstance(x, list):
200+
return ', '.join(str(xi) for xi in x)
201+
if isinstance(x, dict):
202+
return ', '.join('{}={!r}'.format(k, v) for k, v in x.items())
203+
if x == 'special':
204+
return ['SURPRISE!!', "I'm a list!"]
205+
return x
206+
207+
ident = lambda x: x
208+
209+
with self.assertOutputMatches(stdout='a, b', stderr=None):
210+
result = core.Fire(ident, command=['[a,b]'], serialize=serialize)
211+
with self.assertOutputMatches(stdout='a=5, b=6', stderr=None):
212+
result = core.Fire(ident, command=['{a:5,b:6}'], serialize=serialize)
213+
with self.assertOutputMatches(stdout='asdf', stderr=None):
214+
result = core.Fire(ident, command=['asdf'], serialize=serialize)
215+
with self.assertOutputMatches(stdout="SURPRISE!!\nI'm a list!\n", stderr=None):
216+
result = core.Fire(ident, command=['special'], serialize=serialize)
217+
with self.assertRaises(core.FireError):
218+
core.Fire(ident, command=['asdf'], serialize=55)
219+
220+
197221
@testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.')
198222
def testLruCacheDecoratorBoundArg(self):
199223
self.assertEqual(

0 commit comments

Comments
 (0)