Skip to content

Commit

Permalink
Merge pull request #101 from zeya30/update-docstrings
Browse files Browse the repository at this point in the history
update docstrings
  • Loading branch information
dylanbouchard authored Jan 13, 2025
2 parents 791baca + 55cb1e4 commit cc8cf94
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
8 changes: 8 additions & 0 deletions langfair/generator/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,18 @@ async def generate_responses(
----------
dict
A dictionary with two keys: 'data' and 'metadata'.
'data' : dict
A dictionary containing the prompts and responses.
'prompt' : list
A list of prompts.
'response' : list
A list of responses corresponding to the prompts.
'metadata' : dict
A dictionary containing metadata about the generation process.
'non_completion_rate' : float
The rate at which the generation process did not complete.
'temperature' : float
Expand Down
4 changes: 4 additions & 0 deletions langfair/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,18 @@ async def generate_responses(
-------
dict
A dictionary with two keys: 'data' and 'metadata'.
'data' : dict
A dictionary containing the prompts and responses.
'prompt' : list
A list of prompts.
'response' : list
A list of responses corresponding to the prompts.
'metadata' : dict
A dictionary containing metadata about the generation process.
'non_completion_rate' : float
The rate at which the generation process did not complete.
'temperature' : float
Expand Down
21 changes: 19 additions & 2 deletions langfair/metrics/classification/metrics/baseclass/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional,List

from numpy.typing import ArrayLike

Expand All @@ -38,7 +38,24 @@ def evaluate(
pass

@staticmethod
def binary_confusion_matrix(y_true, y_pred):
def binary_confusion_matrix(y_true, y_pred) -> List[List[float]]:
"""
Method for computing binary confusion matrix
Parameters
----------
y_true : Array-like
Binary labels (ground truth values)
y_pred : Array-like
Binary model predictions
Returns
-------
List[List[float]]
2x2 confusion matrix
"""
cm = [[0, 0], [0, 0]]
for i in range(len(y_pred)):
if y_pred[i] == y_true[i]:
Expand Down

0 comments on commit cc8cf94

Please sign in to comment.