Commit cb2f4aa
feat: LLM - Added the seed parameter to the TextGenerationModel's `predict` methods

Copybara import of the project:

--
6e23d68 by Murat Eken <[email protected]>:

fix: missing request parameters

--
cf2bff5 by Murat Eken <[email protected]>:

removing the echo parameter, this fix only includes seed

COPYBARA_INTEGRATE_REVIEW=#3186 from meken:main 32877b4
PiperOrigin-RevId: 638066958
meken authored and Copybara-Service committed May 28, 2024
1 parent 3e4fc18 commit cb2f4aa
Showing 2 changed files with 44 additions and 0 deletions.
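
A minimal usage sketch of the new parameter, assuming the SDK is installed and configured; the project, location, model version, and prompt below are placeholders:

import vertexai
from vertexai.language_models import TextGenerationModel

vertexai.init(project="my-project", location="us-central1")  # placeholder project/location
model = TextGenerationModel.from_pretrained("text-bison@002")  # placeholder model version

# With the same seed and the same non-zero temperature, repeated calls add the
# same noise to the logits, so the sampled output becomes reproducible.
response = model.predict(
    "Write a haiku about determinism.",
    temperature=0.7,
    seed=42,
)
print(response.text)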
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -1920,6 +1920,7 @@ def test_text_generation_ga(self):
presence_penalty=1.0,
frequency_penalty=1.0,
logit_bias={1: 100.0, 2: -100.0},
seed=42,
)

expected_errors = (100,)
@@ -1933,6 +1934,7 @@ def test_text_generation_ga(self):
assert prediction_parameters["presencePenalty"] == 1.0
assert prediction_parameters["frequencyPenalty"] == 1.0
assert prediction_parameters["logitBias"] == {1: 100.0, 2: -100.0}
assert prediction_parameters["seed"] == 42
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
assert response.errors == expected_errors

42 changes: 42 additions & 0 deletions vertexai/language_models/_language_models.py
@@ -1355,6 +1355,7 @@ def predict(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
seed: Optional[int] = None,
) -> "MultiCandidateTextGenerationResponse":
"""Gets model response for a single prompt.
@@ -1387,6 +1388,12 @@
Larger positive bias increases the probability of choosing the token.
Smaller negative bias decreases the probability of choosing the token.
Range: [-100.0, 100.0]
seed:
The decoder generates random noise with a pseudo-random number generator (PRNG); temperature * noise is
added to the logits before sampling. The PRNG produces the same output for the same seed, so setting a
seed makes the generated noise, and therefore the sampled output, deterministic; leaving seed unset makes
the noise non-deterministic.
Returns:
A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
@@ -1404,6 +1411,7 @@
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
seed=seed,
)

prediction_response = self._endpoint.predict(
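
The seed docstring above describes noise-perturbed sampling. A minimal sketch of that mechanism, assuming Gumbel-distributed noise (the docstring does not name the distribution) and NumPy; illustrative only, not the actual server-side decoder:

from typing import Optional

import numpy as np

def sample_token(logits: np.ndarray, temperature: float, seed: Optional[int] = None) -> int:
    """Samples a token id by perturbing logits with seeded noise."""
    rng = np.random.default_rng(seed)         # same seed -> same noise sequence
    noise = rng.gumbel(size=logits.shape)     # pseudo-random noise from the PRNG
    perturbed = logits + temperature * noise  # temperature * noise added to the logits
    # Gumbel-max trick: for temperature > 0 this argmax is equivalent to sampling
    # from softmax(logits / temperature); at temperature 0 it reduces to greedy.
    return int(np.argmax(perturbed))

With seed=None the generator is seeded from OS entropy, so two calls differ; with a fixed seed they return the same token, which is the behavior the docstring promises.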
@@ -1436,6 +1444,7 @@ async def predict_async(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
seed: Optional[int] = None,
) -> "MultiCandidateTextGenerationResponse":
"""Asynchronously gets model response for a single prompt.
@@ -1468,6 +1477,12 @@
Larger positive bias increases the probability of choosing the token.
Smaller negative bias decreases the probability of choosing the token.
Range: [-100.0, 100.0]
seed:
The decoder generates random noise with a pseudo-random number generator (PRNG); temperature * noise is
added to the logits before sampling. The PRNG produces the same output for the same seed, so setting a
seed makes the generated noise, and therefore the sampled output, deterministic; leaving seed unset makes
the noise non-deterministic.
Returns:
A `MultiCandidateTextGenerationResponse` object that contains the text produced by the model.
@@ -1485,6 +1500,7 @@
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
seed=seed,
)

prediction_response = await self._endpoint.predict_async(
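
The async variant takes the same parameter; a brief sketch, again with placeholder model and prompt, and vertexai.init(...) assumed to have run:

import asyncio

from vertexai.language_models import TextGenerationModel

async def main() -> None:
    model = TextGenerationModel.from_pretrained("text-bison@002")  # placeholder
    response = await model.predict_async(
        "Summarize the benefits of seeded sampling.",
        temperature=0.8,
        seed=7,
    )
    print(response.text)

asyncio.run(main())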
@@ -1509,6 +1525,7 @@ def predict_streaming(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
seed: Optional[int] = None,
) -> Iterator[TextGenerationResponse]:
"""Gets a streaming model response for a single prompt.
@@ -1541,6 +1558,12 @@ def predict_streaming(
Larger positive bias increases the probability of choosing the token.
Smaller negative bias decreases the probability of choosing the token.
Range: [-100.0, 100.0]
seed:
The decoder generates random noise with a pseudo-random number generator (PRNG); temperature * noise is
added to the logits before sampling. The PRNG produces the same output for the same seed, so setting a
seed makes the generated noise, and therefore the sampled output, deterministic; leaving seed unset makes
the noise non-deterministic.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
@@ -1557,6 +1580,7 @@
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
seed=seed,
)

prediction_service_client = self._endpoint._prediction_client
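
Streaming accepts the seed as well; a sketch under the same placeholder assumptions. The seed only fixes the sampling noise; it does not change how chunks arrive:

from vertexai.language_models import TextGenerationModel

model = TextGenerationModel.from_pretrained("text-bison@002")  # placeholder

# Each partial TextGenerationResponse is printed as it is generated.
for chunk in model.predict_streaming(
    "List three uses of reproducible generation.",
    temperature=0.9,
    seed=123,
):
    print(chunk.text, end="")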
@@ -1587,6 +1611,7 @@ async def predict_streaming_async(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
seed: Optional[int] = None,
) -> AsyncIterator[TextGenerationResponse]:
"""Asynchronously gets a streaming model response for a single prompt.
@@ -1619,6 +1644,12 @@ async def predict_streaming_async(
Larger positive bias increases the probability of choosing the token.
Smaller negative bias decreases the probability of choosing the token.
Range: [-100.0, 100.0]
seed:
The decoder generates random noise with a pseudo-random number generator (PRNG); temperature * noise is
added to the logits before sampling. The PRNG produces the same output for the same seed, so setting a
seed makes the generated noise, and therefore the sampled output, deterministic; leaving seed unset makes
the noise non-deterministic.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
@@ -1635,6 +1666,7 @@
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
seed=seed,
)

prediction_service_async_client = self._endpoint._prediction_async_client
@@ -1671,6 +1703,7 @@ def _create_text_generation_prediction_request(
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, int]] = None,
seed: Optional[int] = None,
) -> "_PredictionRequest":
"""Prepares the text generation request for a single prompt.
@@ -1703,6 +1736,12 @@ def _create_text_generation_prediction_request(
Larger positive bias increases the probability of choosing the token.
Smaller negative bias decreases the probability of choosing the token.
Range: [-100.0, 100.0]
seed:
The decoder generates random noise with a pseudo-random number generator (PRNG); temperature * noise is
added to the logits before sampling. The PRNG produces the same output for the same seed, so setting a
seed makes the generated noise, and therefore the sampled output, deterministic; leaving seed unset makes
the noise non-deterministic.
Returns:
A `_PredictionRequest` object that contains prediction instance and parameters.
@@ -1749,6 +1788,9 @@ def _create_text_generation_prediction_request(
if logit_bias is not None:
prediction_parameters["logitBias"] = logit_bias

if seed is not None:
prediction_parameters["seed"] = seed

return _PredictionRequest(
instance=instance,
parameters=prediction_parameters,
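
For reference, the parameters dict this helper builds for the unit-test call above; a sketch limited to the keys that test asserts:

prediction_parameters = {
    "presencePenalty": 1.0,
    "frequencyPenalty": 1.0,
    "logitBias": {1: 100.0, 2: -100.0},
    "seed": 42,
    # ...plus whichever other generation parameters were set on the call.
}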
