6 changes: 4 additions & 2 deletions README.md
@@ -71,12 +71,14 @@ Download or [clone](https://www.mathworks.com/help/matlab/matlab_prog/use-source
## Example: Classify Text Data Using BERT
The simplest use of a pretrained BERT model is to use it as a feature extractor. In particular, you can use the BERT model to convert documents to feature vectors which you can then use as inputs to train a deep learning classification network.

-The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports.
+The example [`ClassifyTextDataUsingBERT.m`](./ClassifyTextDataUsingBERT.m) shows how to use a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics Toolbox example [Prepare Text Data for Analysis](https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html).
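The feature-extraction workflow described above can be sketched as follows. This is a minimal illustration, not the full example: it assumes this repository is on the MATLAB path, and the helper names (`bert`, `bert.model`, the tokenizer's `encode` method) follow this repository's README and may differ across versions.

```matlab
% Sketch: use a pretrained BERT model as a feature extractor.
mdl = bert;                                 % load a pretrained BERT model
str = "The factory reported a coolant leak.";  % hypothetical document
tokens = encode(mdl.Tokenizer, str);        % token codes for the document
x = dlarray(tokens{1}, "CT");
emb = bert.model(x, mdl.Parameters);        % contextual embedding per token
feature = emb(:, 1);                        % [CLS] embedding as a document feature
```

The `[CLS]` embedding (the first token's output) is a common choice of fixed-length document feature for a downstream classifier; see `ClassifyTextDataUsingBERT.m` for the complete workflow.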

## Example: Fine-Tune Pretrained BERT Model
To get the most out of a pretrained BERT model, you can retrain and fine-tune the BERT parameter weights for your task.

-The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports.
+The example [`FineTuneBERT.m`](./FineTuneBERT.m) shows how to fine-tune a pretrained BERT model to classify failure events given a data set of factory reports. This example requires the `factoryReports.csv` data set from the Text Analytics Toolbox example [Prepare Text Data for Analysis](https://www.mathworks.com/help/textanalytics/ug/prepare-text-data-for-analysis.html).

The example [`FineTuneBERTJapanese.m`](./FineTuneBERTJapanese.m) shows the same workflow using a pretrained Japanese-BERT model. This example requires the `factoryReportsJP.csv` data set from the Text Analytics Toolbox example [Analyze Japanese Text Data](https://www.mathworks.com/help/textanalytics/ug/analyze-japanese-text.html), available in R2023a or later.
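The fine-tuning workflow can be sketched as below. This is an outline under stated assumptions, not the repository's implementation: it assumes Deep Learning Toolbox, omits mini-batching and preprocessing, and the loss-function name `modelLoss` is hypothetical.

```matlab
% Sketch: fine-tune BERT for classification (outline only).
mdl = bert;
% Attach a classifier head on top of the [CLS] embedding, then update both
% the head weights and mdl.Parameters with a gradient-based optimizer in a
% custom training loop, e.g.:
%   [loss, grads] = dlfeval(@modelLoss, x, t, parameters);
%   [parameters, trailingAvg, trailingAvgSq] = adamupdate( ...
%       parameters, grads, trailingAvg, trailingAvgSq, iteration);
% See FineTuneBERT.m for the complete training loop.
```

Updating the pretrained weights, rather than freezing them, typically improves accuracy when the task domain (here, factory reports) differs from BERT's pretraining corpus.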

## Example: Analyze Sentiment with FinBERT
FinBERT is a sentiment analysis model trained on financial text data and fine-tuned for sentiment analysis.
4 changes: 2 additions & 2 deletions predictMaskedToken.m
@@ -6,7 +6,7 @@
% replaces instances of mdl.Tokenizer.MaskToken in the string text with
% the most likely token according to the BERT model mdl.

-% Copyright 2021 The MathWorks, Inc.
+% Copyright 2021-2023 The MathWorks, Inc.
arguments
mdl {mustBeA(mdl,'struct')}
str {mustBeText}
@@ -44,7 +44,7 @@
tokens = fulltok.tokenize(pieces(i));
if ~isempty(tokens)
% "" tokenizes to empty - awkward
-        x = cat(2,x,fulltok.encode(tokens));
+        x = cat(2,x,fulltok.encode(tokens{1}));
end
if i<numel(pieces)
x = cat(2,x,maskCode);
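The fix above matters because `tokenize` returns a cell array of token lists; `encode` expects the token list itself, hence `tokens{1}`. A minimal usage sketch of the function (the mask string is taken from the tokenizer so it matches the model; no particular output is guaranteed):

```matlab
% Sketch: fill in a masked token with predictMaskedToken.
mdl = bert;   % load a pretrained BERT model
str = "The quick brown fox " + mdl.Tokenizer.MaskToken + " over the lazy dog.";
filled = predictMaskedToken(mdl, str);  % [MASK] replaced by the most likely token
```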
47 changes: 47 additions & 0 deletions test/tpredictMaskedToken.m
@@ -0,0 +1,47 @@
classdef(SharedTestFixtures={
DownloadBERTFixture, DownloadJPBERTFixture}) tpredictMaskedToken < matlab.unittest.TestCase
% tpredictMaskedToken Unit test for predictMaskedToken

% Copyright 2023 The MathWorks, Inc.

properties(TestParameter)
Models = {"tiny","japanese-base-wwm"}
ValidText = iGetValidText;
end

methods(Test)
function verifyOutputDimSizes(test, Models, ValidText)
inSize = size(ValidText);
mdl = bert("Model", Models);
outputText = predictMaskedToken(mdl,ValidText);
test.verifyEqual(size(outputText), inSize);
end

function maskTokenIsRemoved(test, Models)
text = "This has a [MASK] token.";
mdl = bert("Model", Models);
outputText = predictMaskedToken(mdl,text);
test.verifyFalse(contains(outputText, "[MASK]"));
end

function inputWithoutMASKRemainsTheSame(test, Models)
text = "This has a no mask token.";
mdl = bert("Model", Models);
outputText = predictMaskedToken(mdl,text);
test.verifyEqual(text, outputText);
end
end
end

function validText = iGetValidText
manyStrs = ["Accelerating the pace of [MASK] and science";
"The cat [MASK] soundly.";
"The [MASK] set beautifully."];
singleStr = "Artificial intelligence continues to shape the future of industries," + ...
" as innovative applications emerge in fields such as healthcare, transportation," + ...
" entertainment, and finance, driving productivity and enhancing human capabilities.";
validText = struct('StringsAsColumns',manyStrs,...
'StringsAsRows',manyStrs',...
'ManyStrings',repmat(singleStr,3),...
'SingleString',singleStr);
end