Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions resources_servers/workbench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
6. Legal Approval Status: TBD

Rollouts -
Link: https://huggingface.co/datasets/Nexusflow/abhibha-traj-coll-workbench
Link: https://huggingface.co/datasets/Nexusflow/abhibha-gpt-rollouts-completions-fixed-tools

Commands -
Spin up server:
Expand All @@ -28,11 +28,12 @@ ng_collect_rollouts +agent_name=workbench_simple_agent \
+limit=1
```

Data links: https://gitlab-master.nvidia.com/bxyu/nemo-gym/-/ml/models/55/versions/69#/
Data links:
Nemogym prompt datasets: https://gitlab-master.nvidia.com/bxyu/nemo-gym/-/ml/models/55/versions/98#/

# Licensing information
Code: Apache 2.0
Data: Apache 2.0

Dependencies
- nemo_gym: Apache 2.0
- nemo_gym: Apache 2.0
175 changes: 40 additions & 135 deletions resources_servers/workbench/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, ConfigDict, Field

from nemo_gym.base_resources_server import (
BaseResourcesServerConfig,
BaseSeedSessionRequest,
BaseSeedSessionResponse,
BaseVerifyRequest,
BaseVerifyResponse,
SimpleResourcesServer,
)
from resources_servers.workbench.utils import is_correct
from resources_servers.workbench.workbench_tools.analytics import (
AnalyticsTool,
)
from resources_servers.workbench.workbench_tools.calendar import (
CalendarTool,
)
from resources_servers.workbench.workbench_tools.company_directory import (
CompanyDirectoryTool,
)
from resources_servers.workbench.workbench_tools.customer_relationship_manager import (
CustomerRelationshipManagerTool,
)
from resources_servers.workbench.workbench_tools.email import (
EmailTool,
)
from resources_servers.workbench.workbench_tools.project_management import (
ProjectManagementTool,
)
from nemo_gym.server_utils import SESSION_ID_KEY
from resources_servers.workbench.utils import get_tools, is_correct


REASONING_TAG = os.getenv("REASONING_TAG", "think")
Expand Down Expand Up @@ -72,147 +57,67 @@ class WorkbenchVerifyResponse(BaseVerifyResponse):

class WorkbenchResourcesServer(SimpleResourcesServer):
config: WorkbenchResourcesServerConfig
session_id_to_tool_env: Dict[str, Any] = Field(default_factory=dict)

def setup_webserver(self) -> FastAPI:
    """Build the FastAPI app and attach the catch-all tool route.

    Starts from the application returned by the parent class's
    ``setup_webserver`` and registers ``route_to_python_function`` as the
    POST handler for ``/{path}``.

    Returns:
        The FastAPI application with the extra route registered.
    """
    webserver = super().setup_webserver()
    register_post = webserver.post("/{path}")
    register_post(self.route_to_python_function)
    return webserver

async def route_to_python_function(self, path: str, body: WorkbenchRequest) -> WorkbenchResponse:
# Resolve the URL path segment to a tool class and a method name, then look
# the method up on a freshly constructed instance of that class.
# Raises HTTPException(404) when the path is not in the table or when the
# named attribute is missing / not callable.
# NOTE(review): `fn` is only resolved within these lines, not invoked; the
# call presumably follows later in the handler - confirm.

# Dispatch table: route name -> {"class": tool class, "function": method name}.
tool_name_to_class_to_function_mapping = {
"company_directory_find_email_address": {
"class": CompanyDirectoryTool,
"function": "find_email_address",
},
"email_get_email_information_by_id": {
"class": EmailTool,
"function": "get_email_information_by_id",
},
"email_search_emails": {"class": EmailTool, "function": "search_emails"},
"email_send_email": {"class": EmailTool, "function": "send_email"},
"email_delete_email": {"class": EmailTool, "function": "delete_email"},
"email_forward_email": {"class": EmailTool, "function": "forward_email"},
"email_reply_email": {"class": EmailTool, "function": "reply_email"},
"calendar_get_event_information_by_id": {
"class": CalendarTool,
"function": "get_event_information_by_id",
},
"calendar_search_events": {
"class": CalendarTool,
"function": "search_events",
},
"calendar_create_event": {
"class": CalendarTool,
"function": "create_event",
},
"calendar_delete_event": {
"class": CalendarTool,
"function": "delete_event",
},
"calendar_update_event": {
"class": CalendarTool,
"function": "update_event",
},
"analytics_engaged_users_count": {
"class": AnalyticsTool,
"function": "engaged_users_count",
},
"analytics_get_visitor_information_by_id": {
"class": AnalyticsTool,
"function": "get_visitor_information_by_id",
},
"analytics_create_plot": {
"class": AnalyticsTool,
"function": "create_plot",
},
"analytics_traffic_source_count": {
"class": AnalyticsTool,
"function": "traffic_source_count",
},
"analytics_total_visits_count": {
"class": AnalyticsTool,
"function": "total_visits_count",
},
"analytics_get_average_session_duration": {
"class": AnalyticsTool,
"function": "get_average_session_duration",
},
"project_management_get_task_information_by_id": {
"class": ProjectManagementTool,
"function": "get_task_information_by_id",
},
"project_management_search_tasks": {
"class": ProjectManagementTool,
"function": "search_tasks",
},
"project_management_create_task": {
"class": ProjectManagementTool,
"function": "create_task",
},
"project_management_delete_task": {
"class": ProjectManagementTool,
"function": "delete_task",
},
"project_management_update_task": {
"class": ProjectManagementTool,
"function": "update_task",
},
"customer_relationship_manager_search_customers": {
"class": CustomerRelationshipManagerTool,
"function": "search_customers",
},
"customer_relationship_manager_update_customer": {
"class": CustomerRelationshipManagerTool,
"function": "update_customer",
},
"customer_relationship_manager_add_customer": {
"class": CustomerRelationshipManagerTool,
"function": "add_customer",
},
"customer_relationship_manager_delete_customer": {
"class": CustomerRelationshipManagerTool,
"function": "delete_customer",
},
}

# Unknown route name -> 404.
class_function_mapping = tool_name_to_class_to_function_mapping.get(path)
if not class_function_mapping:
raise HTTPException(status_code=404, detail="Class not found")

# A new tool object is constructed on every request.
class_object = class_function_mapping["class"]()

method_name = class_function_mapping["function"] # string, e.g. "search_emails"

# getattr with a None default lets a missing attribute be reported as 404.
fn = getattr(class_object, method_name, None) # bound method on the instance
if fn is None or not callable(fn):
raise HTTPException(status_code=404, detail=f"Method {method_name} not found")

async def seed_session(self, request: Request, body: BaseSeedSessionRequest) -> BaseSeedSessionResponse:
    """Create and store the tool environment for the calling session.

    Reads the session id from ``request.session`` and stores the result of
    ``get_tools`` for the listed toolkits under that id, replacing anything
    previously stored for the same id.

    Args:
        request: Incoming request; its session is read at ``SESSION_ID_KEY``.
        body: Seed payload; not read by this handler.

    Returns:
        An empty ``BaseSeedSessionResponse``.
    """
    enabled_toolkits = [
        "email",
        "calendar",
        "analytics",
        "project_management",
        "customer_relationship_manager",
    ]

    # Seeding happens once per sample, so each session gets its own environment.
    sid = request.session[SESSION_ID_KEY]
    fresh_env = get_tools(enabled_toolkits)
    self.session_id_to_tool_env[sid] = fresh_env
    return BaseSeedSessionResponse()

async def route_to_python_function(self, path: str, body: WorkbenchRequest, request: Request) -> WorkbenchResponse:
    """Run the tool named by ``path`` inside the caller's session environment.

    Args:
        path: Tool name taken from the URL; used as the key into the
            session environment's ``"functions"`` mapping.
        body: Tool arguments. Fields that were not set, and fields whose
            value is ``None``, are dropped before the call.
        request: Incoming request; its session is read at ``SESSION_ID_KEY``.

    Returns:
        A ``WorkbenchResponse`` whose ``output`` is the tool's result, or an
        error message string when the tool is unknown or the call fails.

    Raises:
        HTTPException: 400 when no environment was seeded for this session.
    """
    session_id = request.session[SESSION_ID_KEY]

    # Check if session exists
    if session_id not in self.session_id_to_tool_env:
        raise HTTPException(
            status_code=400,
            detail="Session not initialized. Please call seed_session first.",
        )

    tool_env = self.session_id_to_tool_env[session_id]
    args = {key: value for key, value in body.model_dump(exclude_unset=True).items() if value is not None}

    # Resolve the tool before entering the try block: looking it up inside
    # the try left `function` unbound when the name was unknown, so the
    # error handler itself crashed while formatting its message.
    available_functions = tool_env["functions"]
    if path not in available_functions:
        # Reported to the model (not raised) so that it can correct itself.
        return WorkbenchResponse(output=f"Error executing tool '{path}': tool not found")
    function = available_functions[path]

    try:
        result = function(**args)
        return WorkbenchResponse(output=result)
    except Exception as e:
        # Deliberately broad: any tool failure is returned to the model as
        # text so that it can correct itself, instead of failing the request.
        tool_name = getattr(function, "__name__", path)
        return WorkbenchResponse(output=f"Error executing tool '{tool_name}': {str(e)}")

async def verify(self, body: WorkbenchVerifyRequest) -> WorkbenchVerifyResponse:
    """Score a rollout by checking its function calls against the ground truth.

    Args:
        body: Verify request; ``ground_truth`` and ``response.output`` are read.

    Returns:
        A ``WorkbenchVerifyResponse`` built from the request's own fields plus
        ``reward``, which is ``is_correct(...) * 1.0``.
    """
    # Convert the function-call items of the response into plain dictionaries.
    predicted_function_calls = [
        message.model_dump() for message in body.response.output if message.type == "function_call"
    ]

    # Use a single reward for correctness
    reward = is_correct(predicted_function_calls, body.ground_truth, None) * 1.0
    return WorkbenchVerifyResponse(**body.model_dump(), reward=reward)

Expand Down
6 changes: 3 additions & 3 deletions resources_servers/workbench/configs/workbench.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@ workbench_simple_agent:
jsonl_fpath: resources_servers/workbench/data/train.jsonl
gitlab_identifier:
dataset_name: workbench
version: 0.0.1
version: 0.0.4
artifact_fpath: train.jsonl
license: Apache 2.0
- name: validation
type: validation
jsonl_fpath: resources_servers/workbench/data/validation.jsonl
gitlab_identifier:
dataset_name: workbench
version: 0.0.1
version: 0.0.4
artifact_fpath: validation.jsonl
license: Apache 2.0
- name: example
type: example
jsonl_fpath: resources_servers/workbench/data/example.jsonl
max_steps: 6
max_steps: 6
Loading