diff --git a/lib/html/pipeline/sanitization_filter.rb b/lib/html/pipeline/sanitization_filter.rb
index e804b533..b5d65cf7 100644
--- a/lib/html/pipeline/sanitization_filter.rb
+++ b/lib/html/pipeline/sanitization_filter.rb
@@ -25,15 +25,16 @@ class SanitizationFilter < Filter
# of places we're using tables to contain formatted user content (like pull
# request review comments).
TABLE_ITEMS = Set.new(%w(tr td th).freeze)
- TABLE = 'table'.freeze
+ TABLE = 'table'.freeze
+ TABLE_SECTIONS = Set.new(%w(thead tbody tfoot).freeze)
# The main sanitization whitelist. Only these elements and attributes are
# allowed through by default.
WHITELIST = {
:elements => %w(
h1 h2 h3 h4 h5 h6 h7 h8 br b i strong em a pre code img tt
- div ins del sup sub p ol ul table blockquote dl dt dd
- kbd q samp var hr ruby rt rp li tr td th
+ div ins del sup sub p ol ul table thead tbody tfoot blockquote
+ dl dt dd kbd q samp var hr ruby rt rp li tr td th
),
:remove_contents => ['script'],
:attributes => {
@@ -75,7 +76,7 @@ class SanitizationFilter < Filter
# Table child elements that are not contained by a
are removed.
lambda { |env|
name, node = env[:node_name], env[:node]
- if TABLE_ITEMS.include?(name) && !node.ancestors.any? { |n| n.name == TABLE }
+ if (TABLE_SECTIONS.include?(name) || TABLE_ITEMS.include?(name)) && !node.ancestors.any? { |n| n.name == TABLE }
node.replace(node.children)
end
}
@@ -103,4 +104,4 @@ def whitelist
end
end
end
-end
\ No newline at end of file
+end
diff --git a/test/html/pipeline/sanitization_filter_test.rb b/test/html/pipeline/sanitization_filter_test.rb
index 45d34d96..db9c98df 100644
--- a/test/html/pipeline/sanitization_filter_test.rb
+++ b/test/html/pipeline/sanitization_filter_test.rb
@@ -49,4 +49,23 @@ def test_script_contents_are_removed
orig = ''
assert_equal "", SanitizationFilter.call(orig).to_s
end
+
+ def test_table_rows_and_cells_removed_if_not_in_table
+ orig = %(Foo |
Bar | )
+ assert_equal 'FooBar', SanitizationFilter.call(orig).to_s
+ end
+
+ def test_table_sections_removed_if_not_in_table
+ orig = %(Foo |
)
+ assert_equal 'Foo', SanitizationFilter.call(orig).to_s
+ end
+
+ def test_table_sections_are_not_removed
+ orig = %()
+ assert_equal orig, SanitizationFilter.call(orig).to_s
+ end
end