diff --git a/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/constants.ts b/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/constants.ts index 210361569f927..9922690c3d582 100644 --- a/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/constants.ts +++ b/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/constants.ts @@ -5,28 +5,12 @@ * 2.0. */ -import type { QueryParams, AllInferenceEndpointsTableState, FilterOptions } from './types'; -import { SortFieldInferenceEndpoint, SortOrder } from './types'; - -export const DEFAULT_TABLE_ACTIVE_PAGE = 1; -export const DEFAULT_TABLE_LIMIT = 25; - -export const DEFAULT_QUERY_PARAMS: QueryParams = { - page: DEFAULT_TABLE_ACTIVE_PAGE, - perPage: DEFAULT_TABLE_LIMIT, - sortField: SortFieldInferenceEndpoint.inference_id, - sortOrder: SortOrder.asc, -}; +import type { FilterOptions } from './types'; export const DEFAULT_FILTER_OPTIONS: FilterOptions = { provider: [], type: [], }; -export const DEFAULT_INFERENCE_ENDPOINTS_TABLE_STATE: AllInferenceEndpointsTableState = { - filterOptions: DEFAULT_FILTER_OPTIONS, - queryParams: DEFAULT_QUERY_PARAMS, -}; - export const PIPELINE_URL = 'ingest/ingest_pipelines'; export const SERVERLESS_INDEX_MANAGEMENT_URL = 'index_details'; diff --git a/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/tabular_page.tsx b/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/tabular_page.tsx index eca1d2e560119..289b4211aa5ed 100644 --- a/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/tabular_page.tsx +++ b/x-pack/solutions/search/plugins/search_inference_endpoints/public/components/all_inference_endpoints/tabular_page.tsx @@ -10,7 +10,7 @@ import { i18n as kbnI18n } from '@kbn/i18n'; import { css } from '@emotion/react'; import type { EuiBasicTableColumn, UseEuiTheme } from '@elastic/eui'; -import { EuiBasicTable, EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; +import { EuiInMemoryTable, EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils'; import type { InferenceInferenceEndpointInfo, @@ -23,8 +23,9 @@ import * as i18n from '../../../common/translations'; import { useTableData } from '../../hooks/use_table_data'; import type { FilterOptions } from './types'; +import { INFERENCE_ENDPOINTS_TABLE_PER_PAGE_VALUES } from './types'; -import { useAllInferenceEndpointsState } from '../../hooks/use_all_inference_endpoints_state'; +import { DEFAULT_FILTER_OPTIONS } from './constants'; import { ServiceProviderFilter } from './filter/service_provider_filter'; import { TaskTypeFilter } from './filter/task_type_filter'; import { TableSearch } from './search/table_search'; @@ -34,6 +35,7 @@ import { ServiceProvider } from './render_table_columns/render_service_provider/ import { TaskType } from './render_table_columns/render_task_type/task_type'; import { DeleteAction } from './render_table_columns/render_actions/actions/delete/delete_action'; import { useKibana } from '../../hooks/use_kibana'; +import { getModelId } from '../../utils/get_model_id'; import { isEndpointPreconfigured } from '../../utils/preconfigured_endpoint_helper'; import { EditInferenceFlyout } from '../edit_inference_endpoints/edit_inference_flyout'; import { docLinks } from '../../../common/doc_links'; @@ -58,8 +60,7 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) InferenceInferenceEndpointInfo | undefined >(undefined); const [searchKey, setSearchKey] = React.useState(''); - const { queryParams, setQueryParams, filterOptions, setFilterOptions } = - useAllInferenceEndpointsState(); + const [filterOptions, setFilterOptions] = useState(DEFAULT_FILTER_OPTIONS); const copyContent = useCallback( (inferenceId: string) => { @@ -111,19 +112,11 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) setSelectedInferenceEndpoint(undefined); }, []); - const onFilterChangedCallback = useCallback( - (newFilterOptions: Partial) => { - setFilterOptions(newFilterOptions); - }, - [setFilterOptions] - ); + const onFilterChangedCallback = useCallback((newFilterOptions: Partial) => { + setFilterOptions((prev) => ({ ...prev, ...newFilterOptions })); + }, []); - const { tableData, paginatedSortedTableData, pagination, sorting } = useTableData( - inferenceEndpoints, - queryParams, - filterOptions, - searchKey - ); + const tableData = useTableData(inferenceEndpoints, filterOptions, searchKey); const tableColumns = useMemo>>( () => [ @@ -151,6 +144,7 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) render: (endpointInfo: InferenceInferenceEndpointInfo) => { return ; }, + sortable: (endpointInfo: InferenceInferenceEndpointInfo) => getModelId(endpointInfo) ?? '', width: '200px', }, { @@ -164,7 +158,7 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) return null; }, - sortable: false, + sortable: true, width: '285px', }, { @@ -178,7 +172,7 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) return null; }, - sortable: false, + sortable: true, width: '100px', }, { @@ -218,24 +212,6 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) [copyContent, displayDeleteActionitem, displayInferenceFlyout] ); - const handleTableChange = useCallback( - ({ page, sort }: any) => { - const newQueryParams = { - ...queryParams, - ...(sort && { - sortField: sort.field, - sortOrder: sort.direction, - }), - ...(page && { - page: page.index + 1, - perPage: page.size, - }), - }; - setQueryParams(newQueryParams); - }, - [queryParams, setQueryParams] - ); - return ( <> @@ -278,13 +254,20 @@ export const TabularPage: React.FC = ({ inferenceEndpoints }) - ) => void; - filterOptions: FilterOptions; - setFilterOptions: (filterOptions: Partial) => void; -} - -export function useAllInferenceEndpointsState(): UseAllInferenceEndpointsStateReturn { - const [tableState, setTableState] = useState( - DEFAULT_INFERENCE_ENDPOINTS_TABLE_STATE - ); - const setState = useCallback((state: AllInferenceEndpointsTableState) => { - setTableState(state); - }, []); - - return { - queryParams: { - ...DEFAULT_INFERENCE_ENDPOINTS_TABLE_STATE.queryParams, - ...tableState.queryParams, - }, - setQueryParams: (newQueryParams: Partial) => { - setState({ - filterOptions: tableState.filterOptions, - queryParams: { ...tableState.queryParams, ...newQueryParams }, - }); - }, - filterOptions: { - ...DEFAULT_INFERENCE_ENDPOINTS_TABLE_STATE.filterOptions, - ...tableState.filterOptions, - }, - setFilterOptions: (newFilterOptions: Partial) => { - setState({ - filterOptions: { ...tableState.filterOptions, ...newFilterOptions }, - queryParams: tableState.queryParams, - }); - }, - }; -} diff --git a/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.test.tsx b/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.test.tsx index 34aad44c4af1e..0adf381ac5f59 100644 --- a/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.test.tsx +++ b/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.test.tsx @@ -7,13 +7,9 @@ import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils'; import { renderHook } from '@testing-library/react'; -import type { QueryParams } from '../components/all_inference_endpoints/types'; -import { SortFieldInferenceEndpoint, SortOrder } from '../components/all_inference_endpoints/types'; +import { ServiceProviderKeys } from '@kbn/inference-endpoint-ui-common'; import { useTableData } from './use_table_data'; -import { INFERENCE_ENDPOINTS_TABLE_PER_PAGE_VALUES } from '../components/all_inference_endpoints/types'; -import { QueryClient, QueryClientProvider } from '@kbn/react-query'; -import React from 'react'; -import { TRAINED_MODEL_STATS_QUERY_KEY } from '../../common/constants'; +import type { FilterOptions } from '../components/all_inference_endpoints/types'; const inferenceEndpoints: InferenceAPIConfigResponse[] = [ { @@ -50,148 +46,161 @@ const inferenceEndpoints: InferenceAPIConfigResponse[] = [ }, ]; -const queryParams: QueryParams = { - page: 1, - perPage: 10, - sortField: SortFieldInferenceEndpoint.inference_id, - sortOrder: SortOrder.desc, -}; - -const filterOptions = { - provider: ['elasticsearch', 'openai'], - type: ['sparse_embedding', 'text_embedding'], -} as any; - -const searchKey = 'my'; - describe('useTableData', () => { - const queryClient = new QueryClient(); - const wrapper = ({ children }: { children: React.ReactNode }) => { - return {children}; - }; - - beforeEach(() => { - queryClient.setQueryData([TRAINED_MODEL_STATS_QUERY_KEY], { - trained_model_stats: [ - { - model_id: '.elser_model_2', - deployment_stats: { deployment_id: 'my-elser-model-01', state: 'started' }, - }, - ], - }); - }); - it('should return correct pagination', () => { - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, searchKey), - { wrapper } - ); + it('should return all data when no filters are applied', () => { + const filterOptions: FilterOptions = { provider: [], type: [] }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, '')); - expect(result.current.pagination).toEqual({ - pageIndex: 0, - pageSize: 10, - pageSizeOptions: INFERENCE_ENDPOINTS_TABLE_PER_PAGE_VALUES, - totalItemCount: 3, - }); + expect(result.current.length).toBe(3); }); - it('should return correct sorting', () => { - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, searchKey), - { wrapper } - ); + it('should filter data by provider', () => { + const filterOptions: FilterOptions = { + provider: [ServiceProviderKeys.elasticsearch], + type: [], + }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, '')); - expect(result.current.sorting).toEqual({ - sort: { - direction: 'desc', - field: 'inference_id', - }, - }); + expect(result.current.length).toBe(2); + expect(result.current.every((endpoint) => endpoint.service === 'elasticsearch')).toBe(true); }); - it('should return correctly sorted data', () => { - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, searchKey), - { wrapper } - ); + it('should filter data by task type', () => { + const filterOptions: FilterOptions = { provider: [], type: ['text_embedding'] }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, '')); - const expectedSortedData = [...inferenceEndpoints].sort((a, b) => - b.inference_id.localeCompare(a.inference_id) - ); - - const sortedEndpoints = result.current.sortedTableData.map((item) => item.inference_id); - const expectedModelIds = expectedSortedData.map((item) => item.inference_id); - - expect(sortedEndpoints).toEqual(expectedModelIds); + expect(result.current.length).toBe(1); + expect(result.current[0].task_type).toBe('text_embedding'); }); - it('should filter data based on provider and type from filterOptions', () => { - const filterOptions2 = { - provider: ['elasticsearch'], - type: ['text_embedding'], - } as any; - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions2, searchKey), - { wrapper } - ); + it('should filter data by both provider and type', () => { + const filterOptions: FilterOptions = { + provider: [ServiceProviderKeys.elasticsearch], + type: ['sparse_embedding'], + }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, '')); - const filteredData = result.current.sortedTableData; + expect(result.current.length).toBe(2); expect( - filteredData.every( + result.current.every( (endpoint) => - filterOptions.provider.includes(endpoint.service) && - filterOptions.type.includes(endpoint.task_type) + endpoint.service === 'elasticsearch' && endpoint.task_type === 'sparse_embedding' ) - ).toBeTruthy(); + ).toBe(true); }); it('should filter data based on searchKey matching inference_id', () => { - const searchKey2 = 'model-05'; - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, searchKey2), - { wrapper } + const filterOptions: FilterOptions = { provider: [], type: [] }; + const { result } = renderHook(() => + useTableData(inferenceEndpoints, filterOptions, 'model-05') ); - const filteredData = result.current.sortedTableData; - expect(filteredData.length).toBe(1); - expect(filteredData[0].inference_id).toBe('my-openai-model-05'); + + expect(result.current.length).toBe(1); + expect(result.current[0].inference_id).toBe('my-openai-model-05'); }); it('should filter data based on searchKey matching model_id', () => { - // Search for 'third-party' which only exists in model_id, not in inference_id - const searchKey2 = 'third-party'; - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, searchKey2), - { wrapper } + const filterOptions: FilterOptions = { provider: [], type: [] }; + const { result } = renderHook(() => + useTableData(inferenceEndpoints, filterOptions, 'third-party') ); - const filteredData = result.current.sortedTableData; - expect(filteredData.length).toBe(1); - // Verify the correct endpoint was found by checking both inference_id and model_id - expect(filteredData[0].inference_id).toBe('my-openai-model-05'); - expect(filteredData[0].service_settings.model_id).toBe('third-party-model'); + + expect(result.current.length).toBe(1); + expect(result.current[0].inference_id).toBe('my-openai-model-05'); + expect(result.current[0].service_settings.model_id).toBe('third-party-model'); }); it('should filter data case-insensitively', () => { - const searchKey2 = 'ELSER'; - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, searchKey2), - { wrapper } + const filterOptions: FilterOptions = { provider: [], type: [] }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, 'ELSER')); + + expect(result.current.length).toBe(2); + expect(result.current.every((item) => item.inference_id.includes('elser'))).toBe(true); + }); + + it('should combine provider, type, and search filters', () => { + const filterOptions: FilterOptions = { + provider: [ServiceProviderKeys.elasticsearch], + type: ['sparse_embedding'], + }; + const { result } = renderHook(() => + useTableData(inferenceEndpoints, filterOptions, 'model-01') + ); + + expect(result.current.length).toBe(1); + expect(result.current[0].inference_id).toBe('my-elser-model-01'); + }); + + it('should return empty array when no endpoints match filters', () => { + const filterOptions: FilterOptions = { provider: [ServiceProviderKeys.cohere], type: [] }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, '')); + + expect(result.current.length).toBe(0); + }); + + it('should return empty array when inferenceEndpoints is empty', () => { + const filterOptions: FilterOptions = { provider: [], type: [] }; + const { result } = renderHook(() => useTableData([], filterOptions, '')); + + expect(result.current.length).toBe(0); + }); + + it('should filter by multiple providers', () => { + const filterOptions: FilterOptions = { + provider: [ServiceProviderKeys.elasticsearch, ServiceProviderKeys.openai], + type: [], + }; + const { result } = renderHook(() => useTableData(inferenceEndpoints, filterOptions, '')); + + expect(result.current.length).toBe(3); + }); + + it('should handle endpoints with no model_id in service_settings', () => { + const endpointsWithNoModelId: InferenceAPIConfigResponse[] = [ + { + inference_id: 'endpoint-no-model', + task_type: 'sparse_embedding', + service: 'elasticsearch', + service_settings: { + num_allocations: 1, + num_threads: 1, + }, + task_settings: {}, + }, + ]; + const filterOptions: FilterOptions = { provider: [], type: [] }; + + // Should still find by inference_id + const { result } = renderHook(() => + useTableData(endpointsWithNoModelId, filterOptions, 'endpoint-no-model') ); - const filteredData = result.current.sortedTableData; - expect(filteredData.length).toBe(2); - expect(filteredData.every((item) => item.inference_id.includes('elser'))).toBeTruthy(); + expect(result.current.length).toBe(1); + + // Should not match when searching for non-existent model_id + const { result: result2 } = renderHook(() => + useTableData(endpointsWithNoModelId, filterOptions, 'some-model-id') + ); + expect(result2.current.length).toBe(0); }); - it('should set pagination total to filtered count', () => { - const filteredSearchKey = 'third-party'; - const { result } = renderHook( - () => useTableData(inferenceEndpoints, queryParams, filterOptions, filteredSearchKey), - { wrapper } + it('should search by service_settings.model field (alternative to model_id)', () => { + const endpointsWithModelField: InferenceAPIConfigResponse[] = [ + { + inference_id: 'bedrock-endpoint', + task_type: 'text_embedding', + service: 'amazonbedrock', + service_settings: { + model: 'amazon.titan-embed-text-v1', + }, + task_settings: {}, + }, + ]; + const filterOptions: FilterOptions = { provider: [], type: [] }; + const { result } = renderHook(() => + useTableData(endpointsWithModelField, filterOptions, 'titan-embed') ); - expect(result.current.pagination).toEqual({ - pageIndex: 0, - pageSize: 10, - pageSizeOptions: INFERENCE_ENDPOINTS_TABLE_PER_PAGE_VALUES, - totalItemCount: 1, - }); + expect(result.current.length).toBe(1); + expect(result.current[0].inference_id).toBe('bedrock-endpoint'); }); }); diff --git a/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.tsx b/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.tsx index ff7a7de23f669..b98a2fa8f314e 100644 --- a/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.tsx +++ b/x-pack/solutions/search/plugins/search_inference_endpoints/public/hooks/use_table_data.tsx @@ -5,97 +5,50 @@ * 2.0. */ -import type { EuiTableSortingType } from '@elastic/eui'; -import type { Pagination } from '@elastic/eui'; import type { InferenceAPIConfigResponse } from '@kbn/ml-trained-models-utils'; import { useMemo } from 'react'; import { ServiceProviderKeys } from '@kbn/inference-endpoint-ui-common'; import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/types'; -import { DEFAULT_TABLE_LIMIT } from '../components/all_inference_endpoints/constants'; -import type { FilterOptions, QueryParams } from '../components/all_inference_endpoints/types'; -import { - INFERENCE_ENDPOINTS_TABLE_PER_PAGE_VALUES, - SortOrder, -} from '../components/all_inference_endpoints/types'; +import type { FilterOptions } from '../components/all_inference_endpoints/types'; import { getModelId } from '../utils/get_model_id'; -interface UseTableDataReturn { - tableData: InferenceInferenceEndpointInfo[]; - sortedTableData: InferenceInferenceEndpointInfo[]; - paginatedSortedTableData: InferenceInferenceEndpointInfo[]; - pagination: Pagination; - sorting: EuiTableSortingType; -} - +/** + * Hook that filters inference endpoints based on provider, type, and search criteria. + * Sorting and pagination are handled by EuiInMemoryTable. + */ export const useTableData = ( inferenceEndpoints: InferenceAPIConfigResponse[], - queryParams: QueryParams, filterOptions: FilterOptions, searchKey: string -): UseTableDataReturn => { - const tableData: InferenceInferenceEndpointInfo[] = useMemo(() => { +): InferenceInferenceEndpointInfo[] => { + return useMemo(() => { let filteredEndpoints = inferenceEndpoints; + // Filter by provider if (filterOptions.provider.length > 0) { filteredEndpoints = filteredEndpoints.filter((endpoint) => filterOptions.provider.includes(ServiceProviderKeys[endpoint.service]) ); } + // Filter by task type if (filterOptions.type.length > 0) { filteredEndpoints = filteredEndpoints.filter((endpoint) => filterOptions.type.includes(endpoint.task_type) ); } - return filteredEndpoints.filter((endpoint) => { + // Filter by search key (matches inference_id or model_id) + if (searchKey) { const lowerSearchKey = searchKey.toLowerCase(); - const inferenceIdMatch = endpoint.inference_id.toLowerCase().includes(lowerSearchKey); - const modelId = getModelId(endpoint); - const modelIdMatch = modelId ? modelId.toLowerCase().includes(lowerSearchKey) : false; - return inferenceIdMatch || modelIdMatch; - }); - }, [inferenceEndpoints, searchKey, filterOptions]); - - const sortedTableData: InferenceInferenceEndpointInfo[] = useMemo(() => { - return [...tableData].sort((a, b) => { - const aValue = a[queryParams.sortField]; - const bValue = b[queryParams.sortField]; - - if (queryParams.sortOrder === SortOrder.asc) { - return aValue.localeCompare(bValue); - } else { - return bValue.localeCompare(aValue); - } - }); - }, [tableData, queryParams]); - - const pagination: Pagination = useMemo( - () => ({ - pageIndex: queryParams.page - 1, - pageSize: queryParams.perPage, - pageSizeOptions: INFERENCE_ENDPOINTS_TABLE_PER_PAGE_VALUES, - totalItemCount: tableData.length ?? 0, - }), - [tableData, queryParams] - ); - - const paginatedSortedTableData: InferenceInferenceEndpointInfo[] = useMemo(() => { - const pageSize = pagination.pageSize || DEFAULT_TABLE_LIMIT; - const startIndex = pagination.pageIndex * pageSize; - const endIndex = startIndex + pageSize; - return sortedTableData.slice(startIndex, endIndex); - }, [sortedTableData, pagination]); - - const sorting = useMemo( - () => ({ - sort: { - direction: queryParams.sortOrder, - field: queryParams.sortField, - }, - }), - [queryParams.sortField, queryParams.sortOrder] - ); + filteredEndpoints = filteredEndpoints.filter((endpoint) => { + const inferenceIdMatch = endpoint.inference_id.toLowerCase().includes(lowerSearchKey); + const modelId = getModelId(endpoint); + const modelIdMatch = modelId ? modelId.toLowerCase().includes(lowerSearchKey) : false; + return inferenceIdMatch || modelIdMatch; + }); + } - return { tableData, sortedTableData, paginatedSortedTableData, pagination, sorting }; + return filteredEndpoints; + }, [inferenceEndpoints, searchKey, filterOptions]); };