Skip to content

Commit

Permalink
Use a new KAGGLE_GRPC_DATA_PROXY_URL env variable for gRPC proxying (#…
Browse files Browse the repository at this point in the history
…1337)

http://b/308644984

---------

Co-authored-by: Prathamesh Bang <[email protected]>
  • Loading branch information
Philmod and psbang authored Dec 6, 2023
1 parent 2da7966 commit 2ac180a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
14 changes: 9 additions & 5 deletions patches/sitecustomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def post_import_logic(module):
if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None:
return
if (os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or
os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or
os.getenv('KAGGLE_DATA_PROXY_URL') == None):
os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or
(os.getenv('KAGGLE_DATA_PROXY_URL') == None and
os.getenv('KAGGLE_GRPC_DATA_PROXY_URL') == None)):
return

old_configure = module.configure
Expand All @@ -101,12 +102,15 @@ def new_configure(*args, **kwargs):
client_options = kwargs['client_options']
else:
client_options = {}
client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL']

if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None:
client_options['api_endpoint'] += '/palmapi'
kwargs['transport'] = 'rest'
elif 'transport' in kwargs and kwargs['transport'] == 'rest':

if 'transport' in kwargs and kwargs['transport'] == 'rest':
client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL']
client_options['api_endpoint'] += '/palmapi'
else:
client_options['api_endpoint'] = os.environ['KAGGLE_GRPC_DATA_PROXY_URL']
kwargs['client_options'] = client_options

old_configure(*args, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions tests/test_google_generativeai_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_proxy_enabled(self):
env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token)
env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token)
env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
env.set("KAGGLE_GRPC_DATA_PROXY_URL", "http://127.0.0.1:50001")
env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True")
server_address = urlparse(self.endpoint)
with env:
Expand Down
11 changes: 7 additions & 4 deletions tests/test_google_generativeai_patch_disabled.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@ def do_HEAD(self):
self.send_response(200)

def do_GET(self):
print('YO MOD', self.path)
HTTPHandler.called = True
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()

class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase):
endpoint = "http://127.0.0.1:80"
http_endpoint = "http://127.0.0.1:80"
grpc_endpoint = "http://127.0.0.1:50001"

def test_disabled(self):
env = EnvironmentVarGuard()
env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar")
env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar")
env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
env.set("KAGGLE_DATA_PROXY_URL", self.http_endpoint)
env.set("KAGGLE_GRPC_DATA_PROXY_URL", self.grpc_endpoint)
env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True")
server_address = urlparse(self.endpoint)
server_address = urlparse(self.http_endpoint)
with env:
with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd:
threading.Thread(target=httpd.serve_forever).start()
Expand All @@ -40,4 +43,4 @@ def test_disabled(self):
except:
pass
httpd.shutdown()
self.assertFalse(HTTPHandler.called)
self.assertFalse(HTTPHandler.called)

0 comments on commit 2ac180a

Please sign in to comment.