Skip to content

Commit 3001b07

Browse files
dmccrystals0h3yl
authored and committed
feat: add sentiment_analysis functionality
GitOrigin-RevId: 121a1e26b0802319f3e88382e7df8b6e9fbc7947
1 parent 0453022 commit 3001b07

File tree

5 files changed

+190
-19
lines changed

5 files changed

+190
-19
lines changed

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,41 @@ config=aai.TranscriptionConfig(
356356
)
357357
```
358358

359+
</details>
360+
<details>
361+
<summary>Analyze the Sentiment of Sentences in a Transcript</summary>
362+
363+
```python
364+
import assemblyai as aai
365+
366+
transcriber = aai.Transcriber()
367+
transcript = transcriber.transcribe(
368+
"https://example.org/audio.mp3",
369+
config=aai.TranscriptionConfig(sentiment_analysis=True)
370+
)
371+
372+
for sentiment_result in transcript.sentiment_analysis_results:
373+
print(sentiment_result.text)
374+
print(sentiment_result.sentiment) # POSITIVE, NEUTRAL, or NEGATIVE
375+
print(sentiment_result.confidence)
376+
print(f"Timestamp: {sentiment_result.start} - {sentiment_result.end}")
377+
```
378+
379+
If `speaker_labels` is also enabled, then each sentiment analysis result will also include a `speaker` field.
380+
381+
```python
382+
# ...
383+
384+
config = aai.TranscriptionConfig(sentiment_analysis=True, speaker_labels=True)
385+
386+
# ...
387+
388+
for sentiment_result in transcript.sentiment_analysis_results:
389+
print(sentiment_result.speaker)
390+
```
391+
392+
[Read more about sentiment analysis here.](https://www.assemblyai.com/docs/Models/sentiment_analysis)
393+
359394
</details>
360395

361396
---

assemblyai/transcriber.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ def chapters(self) -> Optional[List[types.Chapter]]:
218218
def content_safety_labels(self) -> Optional[types.ContentSafetyResponse]:
219219
return self._impl.transcript.content_safety_labels
220220

221+
@property
def sentiment_analysis_results(self) -> Optional[List[types.Sentiment]]:
    "The Sentiment Analysis results stored on the underlying transcript, if any"

    raw_transcript = self._impl.transcript
    return raw_transcript.sentiment_analysis_results
224+
221225
@property
222226
def status(self) -> types.TranscriptStatus:
223227
"The current status of the transcript"

assemblyai/types.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,8 @@ class RawTranscriptionConfig(BaseModel):
354354
disfluencies: Optional[bool]
355355
"Transcribe Filler Words, like 'umm', in your media file."
356356

357-
# sentiment_analysis: bool = False
358-
# "Enable Sentiment Analysis."
357+
sentiment_analysis: Optional[bool]
358+
"Enable Sentiment Analysis."
359359

360360
auto_chapters: Optional[bool]
361361
"Enable Auto Chapters."
@@ -418,7 +418,7 @@ def __init__(
418418
# iab_categories: bool = False,
419419
custom_spelling: Optional[Dict[str, Union[str, Sequence[str]]]] = None,
420420
disfluencies: Optional[bool] = None,
421-
# sentiment_analysis: bool = False,
421+
sentiment_analysis: Optional[bool] = None,
422422
auto_chapters: Optional[bool] = None,
423423
# entity_detection: bool = False,
424424
summarization: Optional[bool] = None,
@@ -494,7 +494,7 @@ def __init__(
494494
# self.iab_categories = iab_categories
495495
self.set_custom_spelling(custom_spelling, override=True)
496496
self.disfluencies = disfluencies
497-
# self.sentiment_analysis = sentiment_analysis
497+
self.sentiment_analysis = sentiment_analysis
498498
self.auto_chapters = auto_chapters
499499
# self.entity_detection = entity_detection
500500
self.set_summarize(
@@ -733,17 +733,17 @@ def disfluencies(self, enable: Optional[bool]) -> None:
733733

734734
return self
735735

736-
# @property
737-
# def sentiment_analysis(self) -> bool:
738-
# "Returns the status of the Sentiment Analysis feature."
736+
@property
737+
def sentiment_analysis(self) -> Optional[bool]:
738+
"Returns the status of the Sentiment Analysis feature."
739739

740-
# return self._raw_transcription_config.sentiment_analysis
740+
return self._raw_transcription_config.sentiment_analysis
741741

742-
# @sentiment_analysis.setter
743-
# def sentiment_analysis(self, enable: bool) -> None:
744-
# "Enable Sentiment Analysis."
742+
@sentiment_analysis.setter
743+
def sentiment_analysis(self, enable: Optional[bool]) -> None:
744+
"Enable Sentiment Analysis."
745745

746-
# self._raw_transcription_config.sentiment_analysis = enable
746+
self._raw_transcription_config.sentiment_analysis = enable
747747

748748
@property
749749
def auto_chapters(self) -> bool:
@@ -752,7 +752,7 @@ def auto_chapters(self) -> bool:
752752
return self._raw_transcription_config.auto_chapters
753753

754754
@auto_chapters.setter
755-
def auto_chapters(self, enable: bool) -> None:
755+
def auto_chapters(self, enable: Optional[bool]) -> None:
756756
"Enable Auto Chapters."
757757

758758
# Validate required params are also set
@@ -1243,6 +1243,7 @@ class IABResponse(BaseModel):
12431243

12441244
class Sentiment(Word):
12451245
sentiment: SentimentType
1246+
speaker: Optional[str]
12461247

12471248

12481249
class Entity(BaseModel):
@@ -1363,8 +1364,8 @@ class BaseTranscript(BaseModel):
13631364
disfluencies: Optional[bool]
13641365
"Transcribe Filler Words, like 'umm', in your media file."
13651366

1366-
# sentiment_analysis: bool = False
1367-
# "Enable Sentiment Analysis."
1367+
sentiment_analysis: Optional[bool]
1368+
"Enable Sentiment Analysis."
13681369

13691370
auto_chapters: Optional[bool]
13701371
"Enable Auto Chapters."
@@ -1451,10 +1452,10 @@ class TranscriptResponse(BaseTranscript):
14511452
# "The list of results when Topic Detection is enabled"
14521453

14531454
chapters: Optional[List[Chapter]]
1454-
# "When Auto Chapters is enabled, the list of Auto Chapters results"
1455+
"When Auto Chapters is enabled, the list of Auto Chapters results"
14551456

1456-
# sentiment_analysis_results: Optional[List[Sentiment]] = None
1457-
# "When Sentiment Analysis is enabled, the list of Sentiment Analysis results"
1457+
sentiment_analysis_results: Optional[List[Sentiment]]
1458+
"When Sentiment Analysis is enabled, the list of Sentiment Analysis results"
14581459

14591460
# entities: Optional[List[Entity]] = None
14601461
# "When Entity Detection is enabled, the list of detected Entities"

tests/unit/test_content_safety.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def __submit_mock_request(
7777
Helper function to abstract mock transcriber calls with given `TranscriptionConfig`,
7878
and perform some common assertions.
7979
"""
80-
print(mock_response)
8180

8281
mock_transcript_id = mock_response.get("id", "mock_id")
8382

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import json
2+
from typing import Any, Dict, Tuple
3+
4+
import factory
5+
import httpx
6+
from pytest_httpx import HTTPXMock
7+
8+
import assemblyai as aai
9+
from tests.unit import factories
10+
11+
aai.settings.api_key = "test"
12+
13+
14+
class SentimentFactory(factories.WordFactory):
    """
    Factory for a single sentiment analysis result: everything `WordFactory`
    declares, plus a `sentiment` drawn from `aai.types.SentimentType` and a
    fake `speaker` name.
    """

    # Faker's "enum" provider picks a member of the given enum class
    sentiment = factory.Faker(
        "enum",
        enum_cls=aai.types.SentimentType,
    )
    speaker = factory.Faker("name")
17+
18+
19+
class SentimentAnalysisResponseFactory(factories.TranscriptCompletedResponseFactory):
    """
    Factory for a completed transcript response that additionally carries a
    one-element `sentiment_analysis_results` list built by `SentimentFactory`.
    """

    sentiment_analysis_results = factory.List(
        [
            factory.SubFactory(SentimentFactory),
        ]
    )
21+
22+
23+
def __submit_mock_request(
    httpx_mock: HTTPXMock,
    mock_response: Dict[str, Any],
    config: aai.TranscriptionConfig,
) -> Tuple[Dict[str, Any], aai.Transcript]:
    """
    Runs one mocked transcription round-trip through the SDK.

    Registers a mocked POST (submission) and a mocked GET (polling) response,
    calls `aai.Transcriber().transcribe` with the given `config`, and asserts
    that exactly two HTTP requests were recorded.

    Args:
        httpx_mock: the `pytest_httpx` mock used to register and inspect requests
        mock_response: JSON body returned by the mocked polling request
        config: the `TranscriptionConfig` handed to the transcriber

    Returns:
        A tuple of the decoded JSON body of the first recorded request
        (the submission) and the `Transcript` returned by the SDK.
    """
    transcript_id = mock_response.get("id", "mock_id")
    transcript_endpoint = f"{aai.settings.base_url}/transcript"

    # Submission: answer with a "processing" transcript that carries the same
    # ID as the main mock response, so the follow-up poll hits the URL below.
    processing_payload = factories.generate_dict_factory(
        factories.TranscriptProcessingResponseFactory
    )()
    processing_payload = {**processing_payload, "id": transcript_id}
    httpx_mock.add_response(
        url=transcript_endpoint,
        status_code=httpx.codes.OK,
        method="POST",
        json=processing_payload,
    )

    # Polling: answer with the caller-supplied (completed) transcript
    httpx_mock.add_response(
        url=f"{transcript_endpoint}/{transcript_id}",
        status_code=httpx.codes.OK,
        method="GET",
        json=mock_response,
    )

    # == Make API request via SDK ==
    transcript = aai.Transcriber().transcribe(
        data="https://example.org/audio.wav",
        config=config,
    )

    # One submission request plus one polling request must have been made
    recorded_requests = httpx_mock.get_requests()
    assert len(recorded_requests) == 2

    # The first recorded request is the submission; decode its JSON body
    submission_request = recorded_requests[0]
    request_body = json.loads(submission_request.content.decode())

    return request_body, transcript
72+
73+
74+
def test_sentiment_analysis_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that a default `TranscriptionConfig` sends no `sentiment_analysis`
    value in the request body, and that the resulting transcript exposes no
    sentiment analysis results
    """
    completed_response_factory = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )

    request_body, transcript = __submit_mock_request(
        httpx_mock,
        mock_response=completed_response_factory(),
        config=aai.TranscriptionConfig(),
    )

    # Absent from the request (or null), and nothing parsed from the response
    assert request_body.get("sentiment_analysis") is None
    assert transcript.sentiment_analysis_results is None
88+
89+
90+
def test_sentiment_analysis_enabled(httpx_mock: HTTPXMock):
    """
    Tests that including `sentiment_analysis=True` in the `TranscriptionConfig`
    will result in `sentiment_analysis=True` in the request body, and that the
    response is properly parsed into a `Transcript` object
    """
    mock_response = factories.generate_dict_factory(SentimentAnalysisResponseFactory)()
    request_body, transcript = __submit_mock_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(sentiment_analysis=True),
    )

    # Check that request body was properly defined.
    # Identity check (not `== True`) so that a truthy non-boolean such as `1`
    # in the serialized request would fail the test.
    assert request_body.get("sentiment_analysis") is True

    # Check that transcript was properly parsed from JSON response
    assert transcript.error is None

    expected_results = mock_response["sentiment_analysis_results"]
    parsed_results = transcript.sentiment_analysis_results

    assert parsed_results is not None
    assert len(parsed_results) > 0
    assert len(parsed_results) == len(expected_results)

    # Compare every parsed result field-by-field with the mocked JSON payload
    for expected, parsed in zip(expected_results, parsed_results):
        assert parsed.text == expected["text"]
        assert parsed.start == expected["start"]
        assert parsed.end == expected["end"]
        assert parsed.confidence == expected["confidence"]
        # `sentiment` is parsed into an enum; the JSON carries its raw value
        assert parsed.sentiment.value == expected["sentiment"]
        assert parsed.speaker == expected["speaker"]

0 commit comments

Comments
 (0)