diff --git a/examples/metrics-monitoring/README.md b/examples/metrics-monitoring/README.md new file mode 100644 index 000000000000..64ef1160c66b --- /dev/null +++ b/examples/metrics-monitoring/README.md @@ -0,0 +1,4 @@ +# Metrics Monitoring + +## Continuous Batching Metrics in Transformers + diff --git a/examples/metrics-monitoring/continuous-batching-dashboard.json b/examples/metrics-monitoring/continuous-batching-dashboard.json new file mode 100644 index 000000000000..e0a293d06295 --- /dev/null +++ b/examples/metrics-monitoring/continuous-batching-dashboard.json @@ -0,0 +1,974 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 2, + "links": [], + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "description": "Memory usage of the PagedAttentionCache", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "max": 10737418240, + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "yellow", + "value": 5368709120 + }, + { + "color": "red", + "value": 8589934592 + } + ] + }, + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 0, + "y": 0 + }, + "id": 2, + "options": { + "minVizHeight": 75, + "minVizWidth": 75, + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "sizing": "auto" + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "kv_cache_memory_bytes", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "KV Cache Memory Usage", + "transparent": true, + "type": "gauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "dark-blue" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 6, + "y": 0 + }, + "id": 13, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "disableTextWrap": false, + "editorMode": "builder", + "expr": "active_requests_count", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Active Requests", + "transparent": true, + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + 
"mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "dark-orange" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 12, + "y": 0 + }, + "id": 14, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "disableTextWrap": false, + "editorMode": "builder", + "expr": "waiting_requests_count", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Waiting Requests", + "transparent": true, + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "description": "Ratio of decode tokens to prefill tokens in a batch", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "blue" + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 6, + "x": 18, + "y": 0 + }, + "id": 6, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "percentChangeColorMode": "standard", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "decode_prefill_ratio", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Decode/Prefill Ratio", + "transparent": true, + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 10, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(decode_tokens_processed_total[$__rate_interval])", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": 
"Decode tokens throupught tok/s", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "id": 11, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "rate(prefill_tokens_processed_total[$__rate_interval])", + "legendFormat": "__auto", + "range": true, + "refId": "A" + } + ], + "title": "Prefill rate tok/s", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 16 + }, + "id": 9, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))", + "legendFormat": "p95", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(batch_fill_percentage_percent_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p99", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by(le) 
(rate(batch_fill_percentage_percent_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p50", + "range": true, + "refId": "C" + } + ], + "title": "Batch fill percentage percentiles", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "description": "KV Cache Memory Usage Over Time", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 20, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "bytes" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 16 + }, + "id": 4, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "kv_cache_memory_bytes", + "fullMetaSearch": false, + "includeNullMetadata": true, + "legendFormat": "Used memory", + "range": true, + "refId": "A", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "kv_cache_free_memory_bytes", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "free memory", + "range": true, + "refId": "B", + "useBackend": false + } + ], + "title": "KV Cache Memory Usage Over Time", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "id": 8, + "options": { + "displayMode": "gradient", + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "maxVizHeight": 300, + "minVizHeight": 10, + "minVizWidth": 0, + "namePlacement": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showUnfilled": true, + "sizing": "auto", + "valueMode": "color" + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "histogram_quantile(0.95, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))", + "fullMetaSearch": false, + "includeNullMetadata": true, + 
"legendFormat": "p95", + "range": true, + "refId": "A", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "histogram_quantile(0.5, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": true, + "legendFormat": "p50", + "range": true, + "refId": "B", + "useBackend": false + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "histogram_quantile(0.99, sum by(le) (rate(ttft_milliseconds_bucket[$__rate_interval])))", + "fullMetaSearch": false, + "hide": false, + "includeNullMetadata": false, + "instant": false, + "legendFormat": "p99", + "range": true, + "refId": "C", + "useBackend": false + } + ], + "title": "Time to First Token (TTFT)", + "type": "bargauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green" + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "ms" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 12, + "options": { + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.0.0", + "targets": [ + { + "editorMode": "code", + "expr": "histogram_quantile(0.5, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))", + "legendFormat": "p50", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p95", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "PBFA97CFB590B2093" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum by(le) (rate(request_latency_milliseconds_bucket[$__rate_interval])))", + "hide": false, + "instant": false, + "legendFormat": "p99", + "range": true, + "refId": "C" + } + ], + "title": "Request latency percentiles", + "type": "timeseries" + } + ], + "preload": false, + "refresh": "5s", + "schemaVersion": 41, + "tags": [], + "templating": { + "list": [] + }, + "time": { + "from": "now-15m", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Transformers Continuous Batching Metrics", + "uid": "Lw6CTvVSz", + "version": 5 +} \ No newline at end of file diff --git 
a/examples/metrics-monitoring/docker-compose.yml b/examples/metrics-monitoring/docker-compose.yml new file mode 100644 index 000000000000..936f4a894ced --- /dev/null +++ b/examples/metrics-monitoring/docker-compose.yml @@ -0,0 +1,55 @@ +services: + memcached: + image: memcached:1.6.29 + container_name: memcached + ports: + - "11211:11211" + environment: + - MEMCACHED_MAX_MEMORY=64m # Set the maximum memory usage + - MEMCACHED_THREADS=4 # Number of threads to use + + prometheus: + image: prom/prometheus:latest + command: + - "--config.file=/etc/prometheus/prometheus.yml" + - --web.enable-otlp-receiver # Enable OTLP receiver + - --web.enable-remote-write-receiver + - --enable-feature=exemplar-storage + - --enable-feature=native-histograms + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml + ports: + - "9090:9090" + + tempo: + image: grafana/tempo:latest + command: [ "-config.file=/etc/tempo.yaml" ] + volumes: + - ./tempo.yaml:/etc/tempo.yaml + ports: + - "14268:14268" # jaeger ingest + - "3200:3200" # tempo + - "9095:9095" # tempo grpc + - "4317:4317" # otlp grpc + - "4318:4318" # otlp http + - "9411:9411" # zipkin + depends_on: + - memcached + + grafana: + image: grafana/grafana:latest + volumes: + - ./continuous-batching-dashboard.json:/etc/grafana/provisioning/dashboards/continuous-batching-dashboard.json + - ./grafana-dashboard.yaml:/etc/grafana/provisioning/dashboards/grafana-dashboard.yaml + - ./grafana-datasources.yaml:/etc/grafana/provisioning/datasources/datasources.yaml + environment: + - GF_AUTH_ANONYMOUS_ENABLED=true + - GF_AUTH_ANONYMOUS_ORG_ROLE=Admin + - GF_AUTH_DISABLE_LOGIN_FORM=true + - GF_FEATURE_TOGGLES_ENABLE=traceqlEditor metricsSummary + - GF_INSTALL_PLUGINS=https://storage.googleapis.com/integration-artifacts/grafana-exploretraces-app/grafana-exploretraces-app-latest.zip;grafana-traces-app + ports: + - "3000:3000" + depends_on: + - prometheus + - tempo diff --git a/examples/metrics-monitoring/grafana-dashboard.yaml b/examples/metrics-monitoring/grafana-dashboard.yaml new file mode 100644 index 000000000000..6dd396d00e16 --- /dev/null +++ b/examples/metrics-monitoring/grafana-dashboard.yaml @@ -0,0 +1,11 @@ +apiVersion: 1 + +providers: + - name: 'Transformers Dashboards' + orgId: 1 + folder: 'Transformers' + type: file + disableDeletion: false + editable: true + options: + path: /etc/grafana/provisioning/dashboards diff --git a/examples/metrics-monitoring/grafana-datasources.yaml b/examples/metrics-monitoring/grafana-datasources.yaml new file mode 100644 index 000000000000..e3f2e78becea --- /dev/null +++ b/examples/metrics-monitoring/grafana-datasources.yaml @@ -0,0 +1,14 @@ +apiVersion: 1 + +datasources: + - name: Prometheus + type: prometheus + access: proxy + url: http://prometheus:9090 + isDefault: true + + - name: Tempo + type: tempo + access: proxy + url: http://tempo:3200 + uid: tempo diff --git a/examples/metrics-monitoring/metrics_example.py b/examples/metrics-monitoring/metrics_example.py new file mode 100644 index 000000000000..df3551b68d48 --- /dev/null +++ b/examples/metrics-monitoring/metrics_example.py @@ -0,0 +1,48 @@ +# Example usage of the trace and attach_tracer decorators + +from transformers.utils.metrics import attach_tracer, traced + + +@attach_tracer() +class ExampleClass: + def __init__(self, name): + # The attach_tracer decorator has already created self.tracer for us + self.name = name + + @traced # This method will use the tracer from the class instance + def process_data(self, data): + # This method is traced and can use 
self.tracer + return f"Processed {data} with {self.name}" + + @traced(span_name="custom_operation") # With custom span name + def special_operation(self, value): + # Also traced, with a custom span name + return value * 2 + + @traced( + additional_attributes=[ + ("name", "object.name", lambda x: x.upper()), # Using a transform function + ("name", "object.fixed_value", "static_value"), # Using a fixed value + ] + ) + def operation_with_attributes(self): + # This will add the specified attributes to the span + return "Operation completed" + + +# For functions without a class, the traced decorator still works +@traced +def standalone_function(arg1, arg2): + # For functions, a tracer is created based on the module name + return arg1 + arg2 + + +# Usage: +if __name__ == "__main__": + # With OpenTelemetry configured, these will produce traces + example = ExampleClass("test_object") + example.process_data("sample") + example.special_operation(42) + example.operation_with_attributes() + + result = standalone_function(1, 2) diff --git a/examples/metrics-monitoring/prometheus.yml b/examples/metrics-monitoring/prometheus.yml new file mode 100644 index 000000000000..6c578ad89f51 --- /dev/null +++ b/examples/metrics-monitoring/prometheus.yml @@ -0,0 +1,3 @@ +global: + scrape_interval: 15s + diff --git a/examples/metrics-monitoring/tempo.yaml b/examples/metrics-monitoring/tempo.yaml new file mode 100644 index 000000000000..353b83e1cccf --- /dev/null +++ b/examples/metrics-monitoring/tempo.yaml @@ -0,0 +1,90 @@ +stream_over_http_enabled: true +server: + http_listen_port: 3200 + log_level: info + + +cache: + background: + writeback_goroutines: 5 + caches: + - roles: + - frontend-search + memcached: + addresses: dns+memcached:11211 + +query_frontend: + search: + duration_slo: 5s + throughput_bytes_slo: 1.073741824e+09 + metadata_slo: + duration_slo: 5s + throughput_bytes_slo: 1.073741824e+09 + trace_by_id: + duration_slo: 100ms + metrics: + max_duration: 200h # maximum duration of a metrics query, increase for local setups + query_backend_after: 5m + duration_slo: 5s + throughput_bytes_slo: 1.073741824e+09 + +distributor: + receivers: # this configuration will listen on all ports and protocols that tempo is capable of. + jaeger: # the receives all come from the OpenTelemetry collector. more configuration information can + protocols: # be found there: https://github.com/open-telemetry/opentelemetry-collector/tree/main/receiver + thrift_http: # + endpoint: "tempo:14268" # for a production deployment you should only enable the receivers you need! + grpc: + endpoint: "tempo:14250" + thrift_binary: + endpoint: "tempo:6832" + thrift_compact: + endpoint: "tempo:6831" + zipkin: + endpoint: "tempo:9411" + otlp: + protocols: + grpc: + endpoint: "tempo:4317" + http: + endpoint: "tempo:4318" + opencensus: + endpoint: "tempo:55678" + +ingester: + max_block_duration: 5m # cut the headblock when this much time passes. this is being set for demo purposes and should probably be left alone normally + +compactor: + compaction: + block_retention: 720h # overall Tempo trace retention. 
set for demo purposes + +metrics_generator: + registry: + external_labels: + source: tempo + cluster: docker-compose + storage: + path: /var/tempo/generator/wal + remote_write: + - url: http://prometheus:9090/api/v1/write + send_exemplars: true + traces_storage: + path: /var/tempo/generator/traces + processor: + local_blocks: + filter_server_spans: false + flush_to_storage: true + +storage: + trace: + backend: local # backend configuration to use + wal: + path: /var/tempo/wal # where to store the wal locally + local: + path: /var/tempo/blocks + +overrides: + defaults: + metrics_generator: + processors: [service-graphs, span-metrics, local-blocks] # enables metrics generator + generate_native_histograms: both diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py new file mode 100644 index 000000000000..9aaa836f7bae --- /dev/null +++ b/examples/pytorch/continuous_batching.py @@ -0,0 +1,109 @@ +import time + +import datasets +import torch + +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + + +torch.set_float32_matmul_precision("high") + +model_id = "meta-llama/Llama-3.2-3b-Instruct" +model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto" +).eval() +tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left") + +generation_config = GenerationConfig( + max_new_tokens=512, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + num_blocks=2048, + block_size=128, + do_sample=True, + max_batch_tokens=1024, # Maximum number of tokens to process in a single batch + scheduler="prefill_first", +) + +train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test") + +# --- Example 1: Simple Version using generate_batch --- +print("--- Running CB Generation Example ---") + + +def tokenize_function(examples): + return tokenizer(examples["question"]) + + +tokenized_datasets = train_dataset.map(tokenize_function, batched=True) +simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + +start_time_simple = time.time() +# model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs", fullgraph=True) +batch_outputs = model.generate_batch( + inputs=simple_batch_inputs, + generation_config=generation_config, +) +end_time_simple = time.time() + +for request in batch_outputs: + input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False) + try: + output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False) + except Exception as e: + print(f"Decoding failed for request {request}: {e}") + output_text = tokenizer.decode(batch_outputs[request].generated_tokens[1:], skip_special_tokens=False) + if len(output_text) > 0: + print("-" * 20) + print(f"{request} Input: {input_text}") + print(f"{request} Output: {output_text}") + else: + print("", end="\r\r\r\r") +print("-" * 20) +print("--- Finished CB Generation Example ---\n\n") + + +print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds") + + +# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version + +# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512) +# simple_batch_inputs = list(tokenized_test_prompts["input_ids"]) + +# def tokenize_function(examples): +# # Truncate to avoid overly 
long prompts exceeding max context length +# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512) + + +# tokenized_datasets = train_dataset.map(tokenize_function, batched=True) +# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets] + + +# model.config.attn_implementation = "sdpa" +# start_time_simple = time.time() +# batch_size = 64 +# full_outputs = [] +# from tqdm import tqdm + +# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)): +# outputs = model.generate( +# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device), +# generation_config=GenerationConfig( +# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id +# ), +# ) +# full_outputs.extend(outputs.tolist()) + +# end_time_simple = time.time() +# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds") + +# print("\nResults from simple generate_batch:") +# for i, request in enumerate(full_outputs): +# output_text = tokenizer.decode(request, skip_special_tokens=False) +# print("-" * 20) +# print(f" Output: {output_text}") +# print("-" * 20) +# print("--- Finished Simple Batch Generation Example ---\n\n") diff --git a/setup.py b/setup.py index c3888e537358..a2401bf12dca 100644 --- a/setup.py +++ b/setup.py @@ -201,6 +201,9 @@ "pytest-rich", "libcst", "rich", + "opentelemetry-api", + "opentelemetry-exporter-otlp", + "opentelemetry-sdk", ] @@ -435,6 +438,9 @@ def run(self): extras["benchmark"] = deps_list("optimum-benchmark") +# OpenTelemetry dependencies for metrics collection in continuous batching +extras["open-telemetry"] = deps_list("opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk") + # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index dc2b37a19287..6a8c4d156d80 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -103,4 +103,7 @@ "pytest-rich": "pytest-rich", "libcst": "libcst", "rich": "rich", + "opentelemetry-api": "opentelemetry-api", + "opentelemetry-exporter-otlp": "opentelemetry-exporter-otlp", + "opentelemetry-sdk": "opentelemetry-sdk", } diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index cf1fa3661e0c..64ebfe6fc7c3 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -97,6 +97,9 @@ "validate_stopping_criteria", "StopStringCriteria", ] + _import_structure["continuous_batching"] = [ + "ContinuousMixin", + ] _import_structure["utils"] = [ "GenerationMixin", "GreedySearchEncoderDecoderOutput", @@ -213,6 +216,7 @@ EarlyExitCandidateGenerator, PromptLookupCandidateGenerator, ) + from .continuous_batching import ContinuousMixin from .logits_process import ( AlternatingCodebooksLogitsProcessor, ClassifierFreeGuidanceLogitsProcessor, diff --git a/src/transformers/generation/continuous_batching.py b/src/transformers/generation/continuous_batching.py new file mode 100644 index 000000000000..a1aa6fa2ec9b --- /dev/null +++ b/src/transformers/generation/continuous_batching.py @@ -0,0 +1,1446 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import queue +import statistics +import threading +import time +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import dataclass, field +from enum import Enum +from functools import partial +from typing import Deque, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.nn as nn +from torch.profiler import profile, schedule, tensorboard_trace_handler +from tqdm import tqdm + +from ..cache_utils import Cache +from ..configuration_utils import PretrainedConfig +from ..generation.configuration_utils import GenerationConfig +from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced + + +class RequestStatus(Enum): + """Status of a generation request through its lifecycle.""" + + PENDING = "pending" + PREFILLING = "prefilling" + PREFILLING_SPLIT = "prefilling_split" + SPLIT_PENDING_REMAINDER = "split_pending_remainder" + DECODING = "decoding" + FINISHED = "finished" + FAILED = "failed" + + +# Setup your logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@dataclass +class GenerationOutput: + """Tracks the output of a generation request. + + Attributes: + request_id (str): The ID of the generation request. + prompt_ids (List[int]): The IDs of the prompt tokens. + generated_tokens (List[int]): The generated tokens. + logprobs (List[float]): The log probabilities of the generated tokens. + error (Optional[str]): Any error message associated with the request. When None, the request was successful. + """ + + request_id: str + prompt_ids: List[int] = field(default_factory=list) + generated_tokens: List[int] = field(default_factory=list) + logprobs: List[float] = field(default_factory=list) + error: Optional[str] = None + status: RequestStatus = RequestStatus.PENDING + created_time: float = field(default_factory=time.time) + + +@dataclass +class RequestState: + """Tracks the state of a generation request through its lifecycle. 
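+ A request typically moves from PENDING to PREFILLING (long prompts may pass through PREFILLING_SPLIT and SPLIT_PENDING_REMAINDER when the prefill is split across batches), then to DECODING, and finally to FINISHED or FAILED.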
+ + Attributes: + status (RequestStatus): can be one of PENDING, PREFILLING, PREFILLING_SPLIT, + SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED + """ + + # Required fields + request_id: str + prompt_ids: Optional[List[int]] = None # the one being processed + full_prompt_ids: Optional[List[int]] = None # the full prompt + remaining_prompt_ids: List[int] = field(default_factory=list) # For split requests + static_outputs: List[int] = field(default_factory=list) + allocated_blocks: List[int] = field(default_factory=list) + position_offset: int = 0 # Current position in the sequence for position_ids + status: RequestStatus = RequestStatus.PENDING + max_new_tokens: int = 20 + eos_token_id: int = -1 + created_time: float = field(default_factory=time.time) + error: Optional[str] = None + + def current_len(self) -> int: + """Get the current length of the sequence (prompt + generated tokens).""" + return self.position_offset + + def generated_len(self) -> int: + """Get the number of tokens generated so far.""" + return len(self.static_outputs) + + @traced + def update_with_token(self, token_id: int) -> bool: + """Update the request with a newly generated token and check for completion. + + Args: + token_id: The token ID to add to the output sequence + + Returns: + bool: True if the request is now complete, False otherwise + """ + # Only update if we're in decoding state + if self.status != RequestStatus.DECODING: + return False + + is_eos = token_id == self.eos_token_id and self.eos_token_id != -1 + is_max_len = self.generated_len() >= self.max_new_tokens + + if is_eos or is_max_len: + self.status = RequestStatus.FINISHED + return True + return False + + def __repr__(self): + return f"RequestState(\n\trequest_id={self.request_id},\n\tstatus={self.status},\n\tout_tokens={self.generated_len()},\n\tquery_length={len(self.prompt_ids)}, \n\tremaining_tokens={len(self.remaining_prompt_ids)}, \n\tkv_length={self.position_offset},\n\tfull_prompt_length={len(self.full_prompt_ids)},\n\tallocated_blocks={self.allocated_blocks},\n\tgenerated_tokens={self.static_outputs}\n)" + + def to_generation_output(self): + """Convert the request state to a GenerationOutput object.""" + return GenerationOutput( + request_id=self.request_id, + prompt_ids=self.full_prompt_ids, + status=self.status, + generated_tokens=self.static_outputs, + logprobs=[], + error=self.error, + ) + + +@attach_tracer() +class PagedAttentionCache(Cache): + def __init__( + self, + config: PretrainedConfig, + generation_config: GenerationConfig, + device: torch.device, + dtype: torch.dtype = torch.float16, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + initial_prompt_shapes: Optional[List[List[int]]] = None, + ) -> None: + """Initialize a paged attention cache for efficient memory usage.
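+ The cache is organized, per layer, as num_blocks blocks of block_size token slots; a per-request block table maps logical token positions to physical slots, so KV memory is allocated and freed block by block as requests progress.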
+ + Args: + config: Model configuration + generation_config: Generation configuration containing cache parameters + device: Device for the cache tensors + dtype: Data type for the cache tensors + layer_device_map: Optional mapping of layer indices to devices + initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size + """ + # Extract model dimensions + self.num_key_value_heads = ( + config.num_attention_heads + if getattr(config, "num_key_value_heads", None) is None + else config.num_key_value_heads + ) + self.head_dim = ( + config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads + ) + self.num_hidden_layers = config.num_hidden_layers + + # Calculate optimal block size and number if not provided + num_blocks = getattr(generation_config, "num_blocks", None) + block_size = getattr(generation_config, "block_size", None) + if num_blocks is None or block_size is None: + logger.info("Calculating optimal block size and number...") + num_blocks, block_size = compute_optimal_blocks( + device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200 + ) + logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}") + + self.block_size = block_size + self.num_blocks = num_blocks + self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim) + + self.dtype = dtype + self.device = device + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + for idx in range(config.num_hidden_layers): + layer_device = layer_device_map[idx] if layer_device_map is not None else device + new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. 
+ torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + # Block management data structures + self._free_blocks = deque(range(num_blocks)) + self._block_tables: Dict[str, List[int]] = {} + + @traced + def allocate_blocks(self, n_blocks: int, request_id: str) -> List[int]: + """Allocates n_blocks for a given request_id.""" + if len(self._free_blocks) < n_blocks: + return False + + allocated = [] + for _ in range(n_blocks): + allocated.append(self._free_blocks.popleft()) + + if request_id not in self._block_tables: + self._block_tables[request_id] = [] + self._block_tables[request_id].extend(allocated) + return allocated + + @traced + def free_blocks(self, request_id: str) -> None: + """Frees all blocks associated with a request_id.""" + if request_id in self._block_tables: + blocks_to_free = self._block_tables.pop(request_id) + self._free_blocks.extend(blocks_to_free) + else: + logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}") + + def get_num_free_blocks(self) -> int: + """Returns the number of free blocks available.""" + return len(self._free_blocks) + + def get_block_table(self, request_id: str) -> List[int]: + """Returns the block table for a request.""" + return self._block_tables.get(request_id, []) + + @traced + def _get_physical_indices(self, state: RequestState, logical_indices: List[int]) -> List[int]: + """ + Maps logical sequence indices to physical cache indices using the block table, using PyTorch. + + Args: + request_id: The request ID. + logical_indices: A list of logical indices. + + Returns: + A list of physical indices. + + Raises: + ValueError: If no block table is found for the request ID. + IndexError: If a logical index maps to a block index that is out of bounds. + """ + request_id = state.request_id + block_table = self._block_tables.get(request_id) + if not block_table: + raise ValueError(f"No block table found for request {request_id}") + + block_size = self.block_size + physical_indices = [] + + for idx in logical_indices: + block_idx = idx // block_size + block_offset = idx % block_size + + if block_idx >= len(block_table): + raise IndexError( + f"Logical index {idx} maps to block index {block_idx} which is out of bounds " + f"for request {request_id}" + ) + + physical_block_num = block_table[block_idx] + physical_index = physical_block_num * block_size + block_offset + physical_indices.append(physical_index) + + return physical_indices + + @traced + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + read_index, + write_index, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Reshape cache for easier indexing + total_slots = self.num_blocks * self.block_size + k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim) + k_cache_flat[:, write_index, :] = key_states[0] + v_cache_flat[:, write_index, :] = value_states[0] + return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :] + + +class Scheduler(ABC): + """ + Abstract base class for scheduling requests in the continuous batch processor. + It is expected that cache allocation and scheduling logic will be implemented in subclasses. 
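+ Subclasses must implement add_waiting_request, schedule_batch and finish_request.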
+ """ + + def __init__(self, cache: PagedAttentionCache): + self.active_requests: Dict[str, RequestState] = {} + self.waiting_requests: Dict[str, RequestState] = {} + self.waiting_requests_order: Deque[str] = deque() + self.cache = cache + + @abstractmethod + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + pass + + @abstractmethod + def schedule_batch(self, token_budget: int) -> List[RequestState]: + pass + + @traced + def has_pending_requests(self) -> bool: + """Check if there are requests ready to be processed.""" + return self.active_requests or self.waiting_requests + + @abstractmethod + def finish_request(self, state: RequestState): + """Finish processing a request and free its allocated blocks.""" + pass + + @traced + def get_active_request_static_outputs(self, request_id: str) -> List[int]: + if request_id in self.active_requests: + return self.active_requests[request_id].static_outputs + return [] + + +@attach_tracer() +class FIFOScheduler(Scheduler): + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: Set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> List[RequestState]: + priority_states: List[RequestState] = [] + second_priority_states: List[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.DECODING: + priority_states.append(state) + if state.status == 
RequestStatus.SPLIT_PENDING_REMAINDER: + second_priority_states.append(state) + + # Add waiting requests to second priority + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, state: RequestState): + request_id = state.request_id + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +@attach_tracer() +class PrefillFirstScheduler(Scheduler): + @traced + def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int): + # 1. we check that the occupancy is less than the requested length + # 2. we allocate enough blocks to cover the requested length + current_len = state.current_len() + occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len + if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0): + blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1 + allocated = self.cache.allocate_blocks(blocks_needed, state.request_id) + if not allocated: + return False + state.allocated_blocks.extend(allocated) + return True + + @traced(span_name="prepare_request") + def _prepare_request_for_processing( + self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: Set[str] + ): + """Prepare a request for processing in the current batch.""" + request_tokens = ( + state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids + ) + if len(request_tokens) < token_budget: + # Can process the entire prompt/remainder + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING + state.prompt_ids = state.remaining_prompt_ids + state.remaining_prompt_ids = [] + else: + # Need to split the request + if state.status == RequestStatus.PENDING: + self.active_requests[state.request_id] = state + state.status = RequestStatus.PREFILLING_SPLIT + request_ids_to_remove_from_waiting.add(state.request_id) + elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + state.status = RequestStatus.PREFILLING_SPLIT + state.remaining_prompt_ids = request_tokens[token_budget:] + state.prompt_ids = 
request_tokens[:token_budget] + + @traced + def add_waiting_request(self, state: RequestState): + """Add a request to the waiting list.""" + self.waiting_requests[state.request_id] = state + self.waiting_requests_order.append(state.request_id) + + @traced + def schedule_batch(self, token_budget: int) -> List[RequestState]: + priority_states: List[RequestState] = [] + second_priority_states: List[RequestState] = [] + scheduled_requests = [] + + for state in self.active_requests.values(): + if state.status == RequestStatus.SPLIT_PENDING_REMAINDER: + priority_states.append(state) + elif state.status == RequestStatus.DECODING: + second_priority_states.append(state) + + for req_id in self.waiting_requests_order: + second_priority_states.append(self.waiting_requests[req_id]) + + candidates = priority_states + second_priority_states + + request_ids_to_remove_from_waiting = set() + + for state in candidates: + self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting) + request_len = len(state.prompt_ids) + if not self._allocate_blocks_if_needed( + state, len(state.prompt_ids) + ): # don't schedule if we can't allocate blocks + if len(self.cache._free_blocks) == 0: + break + continue + + @traced + def _add_to_scheduled_requests(state: RequestState): + scheduled_requests.append(state) + + _add_to_scheduled_requests(state) + + token_budget -= request_len + + @traced + def _remove_from_waiting_requests(state: RequestState): + req_id = state.request_id + if req_id in self.waiting_requests: + del self.waiting_requests[req_id] + request_ids_to_remove_from_waiting.add(req_id) + + _remove_from_waiting_requests(state) + + if token_budget == 0: + break + + self.waiting_requests_order = deque( + [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting] + ) + + return scheduled_requests + + @traced + def finish_request(self, state: RequestState): + request_id = state.request_id + self.cache.free_blocks(request_id) + if request_id in self.active_requests: + del self.active_requests[request_id] + + +@traced(standalone=True) +def compute_optimal_blocks( + device: torch.device, + config: PretrainedConfig, + generation_config: GenerationConfig, + inputs: List[List[int]], + dtype: torch.dtype = torch.bfloat16, + safety_margin: float = 0.9, + median_prefill_length: Optional[int] = None, +): + """Calculate optimal number and size of blocks for the KV cache. 
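+ The heuristic takes the free device memory (scaled by safety_margin), divides it by the per-token KV footprint (2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers), and derives a power-of-two block size and block count sized for the median prompt length plus max_new_tokens.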
+ + Args: + device: The device where the model runs + config: The model configuration + generation_config: The generation configuration + inputs: Sample input sequences to estimate memory requirements + dtype: Data type for cache tensors + safety_margin: Fraction of available memory to use + median_prefill_length: Override for median prefill length calculation + + Returns: + Tuple of (num_blocks, block_size) + """ + # Extract model dimensions + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + num_hidden_layers = getattr(config, "num_hidden_layers", 40) + + # Get available device memory + if device.type == "cuda": + device_properties = torch.cuda.get_device_properties(device) + total_memory = device_properties.total_memory + allocated_memory = torch.cuda.memory_allocated(device) + reserved_memory = torch.cuda.memory_reserved(device) + available_memory = total_memory - max(allocated_memory, reserved_memory) + elif device.type == "mps": + logger.warning("MPS memory estimation is approximate. Using conservative defaults.") + return 2048, 256 + else: + logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.") + return 32, 128 + + # Apply safety margin + available_memory = int(available_memory * safety_margin) + if available_memory <= 0: + logger.warning("Not enough available memory. Using minimum configuration.") + return 8, 128 # Minimum viable configuration + + # Calculate memory per token + dtype_size = torch.tensor([], dtype=dtype).element_size() + memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches + + # Estimate sequence length requirements + tokens_to_generate = getattr(generation_config, "max_new_tokens", 20) + + if median_prefill_length is None and inputs: + non_empty_inputs = [len(seq) for seq in inputs if seq] + median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64 + elif median_prefill_length is None: + median_prefill_length = 64 # Reasonable default if no inputs provided + + # Total sequence length including generated tokens + seq_length = median_prefill_length + tokens_to_generate + + # Calculate block parameters + MIN_BLOCK_SIZE = 16 + + # Estimate number of concurrent sequences + per_sequence_memory = seq_length * memory_per_token + max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory)) + + # Total tokens that can fit in memory + total_tokens = available_memory // memory_per_token + + # Calculate block size (rounded to power of 2) + initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2)) + block_size = 1 << (initial_block_size - 1).bit_length() # Round to power of 2 + + # Calculate number of blocks + num_blocks = max(1, total_tokens // block_size) + + logger.info( + f"Optimal cache: {num_blocks} blocks of size {block_size} " + f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})" + ) + + return int(num_blocks), int(block_size) + + +@dataclass +class PagedAttentionArgs: + input_ids: torch.Tensor + attention_mask: torch.Tensor + position_ids: torch.Tensor + cumulative_seqlens_q: torch.Tensor + cumulative_seqlens_k: torch.Tensor + max_seqlen_q: int + max_seqlen_k: int + write_index: torch.Tensor + read_index: torch.Tensor + logits_indices: torch.Tensor + block_tables: Dict[str, List[int]] + cache: PagedAttentionCache + use_cache: bool 
= False + + +@traced +def create_document_mask(cumulative_seqlens_q, cumulative_seqlens_k): + # Number of documents + valid_docs_q = cumulative_seqlens_q[1:] > cumulative_seqlens_q[:-1] + valid_docs_k = cumulative_seqlens_k[1:] > cumulative_seqlens_k[:-1] + num_valid_docs = min(valid_docs_q.sum(), valid_docs_k.sum()) + + # Trim to valid docs + cumulative_seqlens_q = cumulative_seqlens_q[: num_valid_docs + 1] + cumulative_seqlens_k = cumulative_seqlens_k[: num_valid_docs + 1] + + total_q = cumulative_seqlens_q[-1] + total_k = cumulative_seqlens_k[-1] + + q_indices = torch.arange(total_q, device=cumulative_seqlens_q.device) + k_indices = torch.arange(total_k, device=cumulative_seqlens_k.device) + + q_doc_ids = torch.bucketize(q_indices, cumulative_seqlens_q[1:], right=True) + k_doc_ids = torch.bucketize(k_indices, cumulative_seqlens_k[1:], right=False) + doc_mask = q_doc_ids[:, None] == k_doc_ids[None, :] + # apply causal mask where no decoding (same nb of q than k) + + is_causal = ~(cumulative_seqlens_q[1:] - cumulative_seqlens_q[:-1] == 1) * cumulative_seqlens_q[1:] + apply_causal = torch.bucketize(q_indices, is_causal, right=True)[:, None] == k_doc_ids + # TODO don't apply on prefill splitting + causal_mask = torch.triu(torch.ones(total_q, total_k, device=q_doc_ids.device), diagonal=1).bool() + doc_mask.masked_fill_((apply_causal & causal_mask), False) + return doc_mask + + +# Continuous Batch Processor (Internal Logic) +@attach_tracer() +class ContinuousBatchProcessor: + def __init__( + self, + cache: PagedAttentionCache, + config: PretrainedConfig, + generation_config: GenerationConfig, + input_queue: queue.Queue, + output_queue: queue.Queue, + stop_event: threading.Event, + model_device: torch.device, + model_dtype: torch.dtype, + scheduler: Scheduler, + streaming: bool = False, + ): + """Initialize the continuous batch processor. 
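+ The processor pulls RequestState objects from the input queue, schedules them with the given scheduler, builds the static batch tensors consumed by the model forward pass, and pushes GenerationOutput results to the output queue.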
+ + Args: + cache: The paged attention cache to use + generation_config: The generation configuration + input_queue: Queue for incoming requests + output_queue: Queue for outgoing results + stop_event: Event to signal processing should stop + model_device: Device for model inputs/outputs + model_dtype: Data type for model inputs/outputs + streaming: Whether to stream tokens as they're generated + """ + self.cache = cache + self.config = config + self.generation_config = generation_config + self.input_queue = input_queue + self.output_queue = output_queue + self.stop_event = stop_event + self.model_device = model_device + self.model_dtype = model_dtype + self.scheduler = scheduler + self.streaming = streaming + + self.requests_in_batch: List[RequestState] = [] + + # Get batch size parameters from generation config + self._configure_batch_parameters() + + # Set up metrics collector + self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens) + + self.setup_static_tensors() + + @traced(standalone=True) + def setup_static_tensors(self): + T = self.max_batch_tokens + max_token_budget = self.cache.num_blocks * self.cache.block_size + tensor_metadata = {"dtype": torch.int32, "device": self.model_device} + self.tensor_metadata = tensor_metadata + self.input_ids = torch.zeros((1, T), **tensor_metadata) + self.position_ids = torch.zeros((1, T), **tensor_metadata) + self.attention_mask = torch.zeros( + (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device + ) + self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata) + self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata) + self.write_index = torch.zeros((T,), **tensor_metadata) + self.read_index = torch.zeros((max_token_budget,), **tensor_metadata) + self.logits_indices = torch.full((T,), -1, **tensor_metadata) + self.max_seqlen_q = 0 + self.max_seqlen_k = 0 + self.output_ids = torch.full((1, T), -1, **tensor_metadata) + + @traced + @torch.no_grad() + @torch.compile() + def reset_static_tensors(self): + """Reset static tensors for the next batch.""" + self.input_ids.zero_() + self.position_ids.zero_() + self.attention_mask.fill_(torch.finfo(self.model_dtype).min) + self.cumulative_seqlens_q.zero_() + self.cumulative_seqlens_k.zero_() + self.write_index.fill_(-1) + self.read_index.fill_(-1) + self.logits_indices.fill_(-1) + self.max_seqlen_q = 0 + self.max_seqlen_k = 0 + self.output_ids.zero_() + + def get_model_kwargs(self) -> PagedAttentionArgs: + """Get model keyword arguments for the current batch.""" + # torch.set_printoptions(threshold=100000,linewidth=10000) + return { + "input_ids": self.input_ids, + "position_ids": self.position_ids, + "attention_mask": self.attention_mask, + "cumulative_seqlens_q": self.cumulative_seqlens_q, + "cumulative_seqlens_k": self.cumulative_seqlens_k, + "write_index": self.write_index, + "read_index": self.read_index, + "logits_indices": self.logits_indices, + "max_seqlen_q": self.max_seqlen_q, + "max_seqlen_k": self.max_seqlen_k, + "block_tables": self.cache._block_tables, + "cache": self.cache, + "use_cache": False, + } + + def __repr__(self): + return ( + f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})" + + self.get_model_kwargs().__repr__() + ) + + @traced(standalone=True) + def _configure_batch_parameters(self): + """Set up batch processing parameters based on generation config.""" + # Calculate total 
cache capacity + total_cache_tokens = self.cache.num_blocks * self.cache.block_size + + # Get or calculate max tokens per batch + user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None) + if user_batch_tokens is not None: + self.max_batch_tokens = user_batch_tokens + else: + # Default to 1/8 of total cache capacity, adjusted for context + self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048) + recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len) + self.max_batch_tokens = max(64, recommended_batch_size) + + # Context length and EOS token + self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048) + + @traced + def _get_new_requests(self): + """Pull new requests from the input queue and add to waiting list.""" + while not self.input_queue.empty(): + try: + state = self.input_queue.get_nowait() + if state is None: # Sentinel value + continue + self.scheduler.add_waiting_request(state) + + except queue.Empty: + break + except Exception as e: + logger.error(f"Error processing new request: {e}", exc_info=True) + state: RequestState = locals().get("state") + if state is not None: + self._handle_request_error(e, state) + + @traced + def _handle_request_error(self, error, state: RequestState): + """Handle general request processing error.""" + state.status = RequestStatus.FAILED + state.error = str(error) + + # Include any generated tokens if this is an active request + if isinstance(state.request_id, str): + state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id) + else: + state.static_outputs = [] + + self.metrics.record_request_completion(state.created_time, state.request_id) + self.output_queue.put(state.to_generation_output()) + + @traced + def prepare_next_batch(self): + """Prepare tensors and metadata for the next model forward pass.""" + # Get new requests from the queue + self._get_new_requests() + if not self.scheduler.has_pending_requests(): + return None + + self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests)) + + self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens) + if not self.requests_in_batch: + return None + + # Get the request objects for this batch + self.reset_static_tensors() + position_ids = [] + input_ids = [] + read_index = [] + write_index = [] + cumulative_seqlens_q = [0] + cumulative_seqlens_k = [0] + logits_indices = [] + self.metrics.record_batch_metrics(self.requests_in_batch) + + for state in self.requests_in_batch: + next_input_ids = state.prompt_ids + input_ids.extend(next_input_ids) + past_length = state.position_offset + query_length = len(next_input_ids) + key_length = query_length + past_length + cache_index = list(range(key_length)) + + positions_to_add = cache_index[past_length:] + read_indices = self.cache._get_physical_indices(state, cache_index) + write_indices = read_indices[-query_length:] + + position_ids.extend(positions_to_add) + read_index.extend(read_indices) + write_index.extend(write_indices) + cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length) + cumulative_seqlens_k.append(cumulative_seqlens_k[-1] + key_length) + if len(state.remaining_prompt_ids) == 0: + logits_indices.append(cumulative_seqlens_q[-1] - 1) + self.max_seqlen_q = max(self.max_seqlen_q, query_length) + self.max_seqlen_k = max(self.max_seqlen_k, key_length) + state.position_offset += query_length + + logger.warning( + f"Scheduled: 
{len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}" + ) + self._build_tensors( + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ) + + self.metrics.record_kv_cache_memory_metrics(self.cache) + + @traced + def _build_tensors( + self, + input_ids, + position_ids, + read_index, + write_index, + cumulative_seqlens_q, + cumulative_seqlens_k, + logits_indices, + ): + to_tensor = partial(torch.tensor, **self.tensor_metadata) + self.input_ids[:, : len(input_ids)] = to_tensor(input_ids) + self.position_ids[:, : len(position_ids)] = to_tensor(position_ids) + self.write_index[: len(write_index)] = to_tensor(write_index) + self.read_index[: len(read_index)] = to_tensor(read_index) + self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q) + self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k) + self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices) + min_value = torch.finfo(self.model_dtype).min + if self.config._attn_implementation != "paged_attention": # we set `is_causal` to True in paged call` + for i in range(len(cumulative_seqlens_q) - 1): + if ( + cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] + < cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i] + and cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] >= 1 + ): + diagonal = ( + cumulative_seqlens_k[i + 1] - (cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]) + 1 + ) + diagonal = diagonal - cumulative_seqlens_k[i] + else: + diagonal = 1 + query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1]) + key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1]) + + mask = torch.triu( + torch.full( + self.attention_mask[..., query_range, key_range].shape, + min_value, + dtype=self.model_dtype, + device=self.model_device, + ), + diagonal=diagonal, + ) + self.attention_mask[..., query_range, key_range] = mask + + @traced + def _sync(self): + return self.output_ids.tolist()[0] # should be the only synch we do + + @traced + def _maybe_send_output(self, state: RequestState, token: int): + """Send output to the queue based on streaming mode and request state.""" + if self.streaming: + state.next_token = token + self.output_queue.put(state.to_generation_output()) + elif state.status == RequestStatus.FINISHED: + self.output_queue.put(state.to_generation_output()) + + @traced + def update_batch(self): + """Update request states based on generated tokens.""" + out_tokens = self._sync() + finished_request_ids = [] + for i, state in enumerate(self.requests_in_batch): + req_id = state.request_id + if len(state.remaining_prompt_ids) == 0: + self.metrics.record_ttft_metric(state.created_time, state.request_id) + state.status = RequestStatus.DECODING + token = out_tokens[self.logits_indices[i]] + state.static_outputs.extend([token]) + state.prompt_ids = [token] + if state.update_with_token(token): + self.metrics.record_request_completion(state.created_time, state.request_id) + self.scheduler.finish_request(state) + finished_request_ids.append(req_id) + self._maybe_send_output(state, token) + elif state.status == RequestStatus.PREFILLING_SPLIT: + state.status = RequestStatus.SPLIT_PENDING_REMAINDER + + @traced + def has_pending_requests(self) -> bool: + """Check if there 
are any active or waiting requests.""" + return self.scheduler.has_pending_requests() + + @traced + def handle_batch_error(self, error): + """Handle errors during batch processing.""" + failed_reqs = self.requests_in_batch + for req in failed_reqs: + self._handle_request_error(error, req) + self.scheduler.finish_request(req) + + @traced + def fail_all_requests(self, error): + """Fail all active requests with the given error. + + Args: + error: The error to report in the failure message + """ + for state in self.scheduler.active_requests.values(): + self._handle_request_error(error, state) + self.scheduler.finish_request(state) + + # Also fail any requests in the waiting queue + for req_id in list(self.scheduler.waiting_requests.keys()): + state = self.scheduler.waiting_requests.pop(req_id) + self._handle_request_error(error, state) + + # Clear the ordering queue + self.scheduler.waiting_requests_order.clear() + + +SCHEDULER_MAPPING = { + "fifo": FIFOScheduler, + "prefill_first": PrefillFirstScheduler, +} + + +# Manager Class (User Interface) +@attach_tracer() +class ContinuousBatchingManager: + """Manager for handling continuous batching of generation requests. + + This class provides the user interface for submitting generation requests, + retrieving results, and managing the background generation thread. + """ + + def __init__(self, model, generation_config: GenerationConfig, max_queue_size=0, streaming: bool = True): + """Initialize the continuous batching manager. + + Args: + model: The language model for generation + generation_config: Configuration for generation parameters + max_queue_size: Maximum size of the request queue (0 = unlimited) + streaming: Whether to stream tokens as they are generated + """ + self.model = model + self.generation_config = generation_config + self.input_queue = queue.Queue(maxsize=max_queue_size) + self.output_queue = queue.Queue() + self.stop_event = threading.Event() + self.streaming = streaming + self.log_prob_generation = getattr(generation_config, "log_prob_generation", False) + self._generation_thread = None + self._request_counter = 0 + self._request_lock = threading.Lock() + self.model.generation_config.top_p = None + self.do_sample = getattr(generation_config, "do_sample", True) + self.logit_processor = self.model._get_logits_processor(self.model.generation_config) + self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True) + self.profile = getattr(generation_config, "profile", False) + + @traced + def start(self): + """Start the background generation thread.""" + if self._generation_thread is not None and self._generation_thread.is_alive(): + logger.warning("Manager thread is already running.") + return + + self._result_queue = queue.Queue() + self._generation_thread = threading.Thread(target=self._run_generation_loop) + self._generation_thread.start() + logger.info("Continuous batching manager started.") + + def is_running(self): + """Check if the background generation thread is running.""" + return self._generation_thread is not None and self._generation_thread.is_alive() + + def stop(self, block: bool = False, timeout: Optional[float] = None): + """Signal the background thread to stop. 
+ + Args: + block: Whether to wait for the thread to stop + timeout: Maximum time to wait for the thread to stop + """ + if self._generation_thread is None: + logger.warning("Manager not started.") + return + + if not self.stop_event.is_set(): + self.stop_event.set() + logger.info("Stopping continuous batching manager...") + + if block: + self.join(timeout) + + def join(self, timeout: Optional[float] = None): + """Wait for the background thread to finish. + + Args: + timeout: Maximum time to wait for the thread to stop + """ + if self._generation_thread is not None: + self._generation_thread.join(timeout=timeout) + if self._generation_thread.is_alive(): + logger.warning("Generation thread did not exit after join timeout.") + else: + logger.info("Continuous Batching Manager stopped.") + self._generation_thread = None + + def add_request( + self, input_ids: List[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None + ) -> str: + """Add a new generation request to the queue. + + Args: + input_ids: Input token IDs to use as prompt + request_id: Optional custom request ID (auto-generated if None) + **kwargs: Additional generation parameters + + Returns: + str: The request ID + """ + if request_id is None: + with self._request_lock: + request_id = f"req_{self._request_counter}" + self._request_counter += 1 + + max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens + + state = RequestState( + request_id=request_id, + prompt_ids=list(input_ids), + full_prompt_ids=list(input_ids), + max_new_tokens=max_new_tokens, + eos_token_id=self.generation_config.eos_token_id, + ) + + # Use block=True with timeout to handle backpressure if queue is full + self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg? + logger.debug(f"Added request {request_id} to queue.") + return request_id + + def add_requests(self, inputs: List[List[int]], **kwargs): + for i, input_ids in enumerate(inputs): + # Assign a predictable request ID for ordering results later + req_id = f"batch_req_{i}" + self.add_request(input_ids, request_id=req_id, **kwargs) + + def get_result(self, timeout=None) -> Optional[GenerationOutput]: + """Retrieve one result from the output queue. + + Args: + timeout: Maximum time to wait for a result + + Returns: + Optional[Dict]: The result data or None if timeout + """ + if self._generation_thread is None and self.output_queue.empty(): + return None + try: + result = self.output_queue.get(block=True, timeout=timeout) + logger.debug(f"Retrieved result for request {result.request_id}") + return result + except queue.Empty: + return None + + def __iter__(self): + """Iterate over results as they become available.""" + while ( + self._generation_thread is not None and self._generation_thread.is_alive() or not self.output_queue.empty() + ): + result = self.get_result(timeout=0.1) # allow the model to run for 10 seconds + if result is not None: + yield result + + @traced + def warmup(self, batch_processor): + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + # Warmup the model with a dummy forward pass + self._generation_step(batch_processor) + torch.cuda.current_stream().wait_stream(stream) + + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph): + self._generation_step(batch_processor) + + @traced + # @torch.compile + def _generation_step(self, batch_processor: ContinuousBatchProcessor): + """Perform a single generation step. 
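# Illustrative sketch (not from the patch itself) of the manager API documented above: submit a
# request with add_request(), block on get_result(), then shut the background thread down.
# The checkpoint, prompt, and the extra `scheduler` kwarg are assumptions for illustration only;
# the generation loop further below looks `generation_config.scheduler` up in SCHEDULER_MAPPING.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3b-Instruct", torch_dtype="bfloat16", device_map="auto"
).eval()

gen_cfg = GenerationConfig(
    max_new_tokens=32,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=False,
    scheduler="fifo",                                        # resolved against SCHEDULER_MAPPING
)
manager = ContinuousBatchingManager(model=model, generation_config=gen_cfg, streaming=False)
manager.start()
req_id = manager.add_request(tokenizer("Paged attention lets us")["input_ids"])
result = manager.get_result(timeout=120)                     # with streaming=False, only finished requests are queued
print(tokenizer.decode(result.static_outputs, skip_special_tokens=True))
manager.stop(block=True, timeout=5.0)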
This is cuda graphed""" + batch_data = batch_processor.get_model_kwargs() + with torch.no_grad(): + logits = self._model_forward(batch_data) + if self.log_prob_generation: + batch_processor.output_probs.copy_(logits) # TODO + probs = self._process_logit(batch_data, logits) + self._sample(batch_processor, probs) + + @traced(span_name="model_forward") + def _model_forward(self, batch_data): + return self.model(**batch_data).logits + + @traced(span_name="logit_processing") + @torch.compile() + def _process_logit(self, batch_data, logits): + return self.logit_processor(batch_data["input_ids"], logits) + + @traced(span_name="sampling") + def _sample(self, batch_processor: ContinuousBatchProcessor, probs): + if self.do_sample: # sample + probs = nn.functional.softmax(probs, dim=-1) + next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(probs, dim=-1) + batch_processor.output_ids.copy_(next_tokens) + + def _run_generation_loop(self): + """Main processing loop running in the background thread.""" + batch_processor = None + try: + paged_attention_cache = PagedAttentionCache( + self.model.config, + self.generation_config, + self.model.device, + self.model.dtype, + ) + + scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler) + if scheduler is None: + logger.warning(f"Scheduler '{scheduler}' not found. Defaulting to FIFO.") + scheduler = FIFOScheduler + + batch_processor = ContinuousBatchProcessor( + paged_attention_cache, + self.model.config, + self.generation_config, + self.input_queue, + self.output_queue, + self.stop_event, + self.model.device, + self.model.dtype, + scheduler(paged_attention_cache), + self.streaming, + ) + is_first = True + + if self.profile: + tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1) + trace_handler = tensorboard_trace_handler( + dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile" + ) + activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + with profile( + activities=activities, + schedule=tracing_schedule, + on_trace_ready=trace_handler, + record_shapes=False, + with_stack=True, + ) as prof: + while not self.stop_event.is_set() or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor, is_first) + if is_first: + is_first = False + prof.step() + else: + while not self.stop_event.is_set() or batch_processor.has_pending_requests(): + self._inner_generation_loop(batch_processor, is_first) + if is_first: + is_first = False + + except Exception as e: + logger.error(f"Error in generation loop: {e}", exc_info=True) + self._handle_critical_error(e, batch_processor) + finally: + logger.info("Generation loop finished.") + + @traced(span_name="generation_loop") + def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor, is_first: bool = False): + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.prepare_next_batch() + if torch.cuda.is_available() and self.use_cuda_graph: + if is_first: + self.warmup(batch_processor) + elif hasattr(self, "graph"): + try: + self._graph_replay() + except Exception as e: + logger.error(f"Model forward pass failed: {e}", exc_info=True) + batch_processor.handle_batch_error(e) + return + else: + self._generation_step(batch_processor) + else: + self._generation_step(batch_processor) + if torch.cuda.is_available(): + torch.cuda.synchronize() + batch_processor.update_batch() + + @traced(span_name="graph_replay") + 
def _graph_replay(self): + self.graph.replay() + + @traced + def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]): + """Handle critical errors that terminate the generation loop.""" + # Signal stop + self.stop_event.set() + + # Fail pending requests in input queue + try: + while True: + req_data = self.input_queue.get_nowait() + if batch_processor is not None: + batch_processor._handle_request_error(error, req_data) + except queue.Empty: + pass + + # Fail active requests + if batch_processor is not None: + batch_processor.fail_all_requests(error) + + +class ContinuousMixin: + """Mixin class for models to add continuous batching capabilities.""" + + def init_continuous_batching( + self, + generation_config: Optional[GenerationConfig] = None, + max_queue_size: int = 0, + scheduler: str = "fifo", + streaming: bool = False, + ) -> ContinuousBatchingManager: + """Initialize a manager for continuous batching inference. + + Args: + generation_config: Custom generation configuration + max_queue_size: Maximum size of the input request queue + streaming: Whether to stream tokens as they are generated + + Returns: + `ContinuousBatchingManager`: The manager instance to add requests and retrieve results. + """ + if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"): + raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.") + + gen_config = generation_config if generation_config is not None else self.generation_config + if gen_config is None: + raise ValueError("A GenerationConfig must be provided or set in the model.") + + if gen_config.eos_token_id is None: + logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).") + gen_config.eos_token_id = -1 + + # Create and return the manager + return ContinuousBatchingManager( + model=self, generation_config=gen_config, max_queue_size=max_queue_size, streaming=streaming + ) + + @traced + @torch.inference_mode() + def generate_batch( + self, + inputs: List[List[int]], + generation_config: Optional[GenerationConfig] = None, + progress_bar: bool = True, + **kwargs, + ) -> List[List[int]]: + """Generate sequences for a batch of prompts using continuous batching. + + Args: + inputs: List of input token sequences (prompts) + generation_config: Optional generation configuration + **kwargs: Additional generation parameters + + Returns: + `List[List[int]]`: A list containing the generated sequences (including prompt tokens + if not handled otherwise) for each input prompt, in the same order. + Returns an empty list `[]` for requests that failed. 
+ """ + if not inputs: + return [] + + # Initialize manager with the batch inputs + manager = self.init_continuous_batching(generation_config=generation_config) + manager.start() + results = {} + num_requests = len(inputs) + try: + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm([logger]): + with tqdm( + total=num_requests, + disable=(not progress_bar), + desc=f"Solving {num_requests} requests", + unit="request", + ) as pbar: + manager.add_requests(inputs, **kwargs) + finished_count = 0 + while finished_count < num_requests: + result = manager.get_result(timeout=1) + if result: + req_id = result.request_id + if result.status == RequestStatus.FINISHED: + results[req_id] = result + finished_count += 1 + pbar.update(1) + else: + if not manager.is_running(): + logger.error("Generation thread terminated unexpectedly.") + break + + except Exception as e: + logger.error(f"Error during batch generation: {e}", exc_info=True) + finally: + manager.stop(block=True, timeout=5.0) + return results diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 64227812152e..7eef82d116fd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -71,6 +71,7 @@ GenerationConfig, GenerationMode, ) +from .continuous_batching import ContinuousMixin from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, @@ -344,7 +345,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput] -class GenerationMixin: +class GenerationMixin(ContinuousMixin): """ A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes. 
Inheriting from this class causes the model to have special generation-related behavior, such as loading a @@ -1014,10 +1015,10 @@ def _get_candidate_generator( def _get_logits_processor( self, generation_config: GenerationConfig, - input_ids_seq_length: int, - encoder_input_ids: torch.LongTensor, - prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], - logits_processor: Optional[LogitsProcessorList], + input_ids_seq_length: Optional[int] = None, + encoder_input_ids: torch.LongTensor = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, device: Optional[str] = None, model_kwargs: Optional[Dict[str, Any]] = None, negative_prompt_ids: Optional[torch.Tensor] = None, @@ -1029,6 +1030,8 @@ def _get_logits_processor( """ # instantiate processors list processors = LogitsProcessorList() + if logits_processor is None: + logits_processor = [] if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1: processors.append( @@ -1098,7 +1101,7 @@ def _get_logits_processor( ) if ( generation_config.min_length is not None - and generation_config._eos_token_tensor is not None + and getattr(generation_config, "_eos_token_tensor", None) is not None and generation_config.min_length > 0 ): processors.append( @@ -1110,7 +1113,7 @@ def _get_logits_processor( ) if ( generation_config.min_new_tokens is not None - and generation_config._eos_token_tensor is not None + and getattr(generation_config, "_eos_token_tensor", None) is not None and generation_config.min_new_tokens > 0 ): processors.append( diff --git a/src/transformers/integrations/eager_paged.py b/src/transformers/integrations/eager_paged.py new file mode 100644 index 000000000000..9893e10c89ae --- /dev/null +++ b/src/transformers/integrations/eager_paged.py @@ -0,0 +1,45 @@ +from typing import Optional + +import torch +from torch import nn + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
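# Quick sanity check (illustrative, not part of the patch) that the expand+reshape in repeat_kv()
# matches torch.repeat_interleave along the key/value head dimension; shapes are arbitrary.
import torch

x = torch.randn(2, 4, 7, 16)                                 # (batch, num_kv_heads, seq_len, head_dim)
n_rep = 3
via_expand = x[:, :, None, :, :].expand(2, 4, n_rep, 7, 16).reshape(2, 4 * n_rep, 7, 16)
assert torch.equal(via_expand, torch.repeat_interleave(x, repeats=n_rep, dim=1))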
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_paged_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + cache = kwargs.pop("cache", None) + if cache is not None: + key, value = cache.update(key, value, module.layer_idx, **kwargs) + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights diff --git a/src/transformers/integrations/flash_paged.py b/src/transformers/integrations/flash_paged.py new file mode 100644 index 000000000000..b0463d952487 --- /dev/null +++ b/src/transformers/integrations/flash_paged.py @@ -0,0 +1,64 @@ +import torch + +from ..generation.continuous_batching import PagedAttentionCache +from ..utils import is_flash_attn_2_available + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + + +def paged_attention_forward( + module: torch.nn.Module, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: torch.Tensor = None, + cache: PagedAttentionCache = None, + cumulative_seqlens_q=None, + cumulative_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + block_tables=None, + **kwargs, +) -> torch.Tensor: + r"""Perform the forward pass of attention with paged key-value cache. + + This function handles the cache updates and performs the attention computation + using the flash_attn_varlen_func for efficient processing. + + Args: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full k + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. but if there is a block table it can be the full v + cumulative_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cumulative_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. 
Anything > 0 activates softcapping attention. + """ + k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs) + + attn_output = flash_attn_varlen_func( + q.transpose(1, 2).squeeze(0), + k.transpose(1, 2).squeeze(0), + v.transpose(1, 2).squeeze(0), + cumulative_seqlens_q.to(torch.int32), + cumulative_seqlens_k.to(torch.int32), + max_seqlen_q, + max_seqlen_k, + softmax_scale=module.scaling, + causal=True, # kind of a must, it automatically aligns the mask for q < k + window_size=(-1, -1), # -1 means infinite context window + # block_table=block_tables, -> torch.Tensor + # **kwargs, + ) + + return attn_output, None diff --git a/src/transformers/integrations/sdpa_paged.py b/src/transformers/integrations/sdpa_paged.py new file mode 100644 index 000000000000..640db16d0dec --- /dev/null +++ b/src/transformers/integrations/sdpa_paged.py @@ -0,0 +1,51 @@ +from typing import Optional, Tuple + +import torch + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def sdpa_attention_paged_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + cache = kwargs.pop("cache", None) + if cache is not None: + key, value = cache.update(key, value, module.layer_idx, **kwargs) + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=False, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, None diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index c7d54dc41514..57e974e9f912 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -421,9 +421,9 @@ class FlashAttentionKwargs(TypedDict, total=False): Keyword arguments for Flash Attention with Compile. Attributes: - cu_seq_lens_q (`torch.LongTensor`, *optional*) + cumulative_seqlens_q (`torch.LongTensor`, *optional*) Gets cumulative sequence length for query state. - cu_seq_lens_k (`torch.LongTensor`, *optional*) + cumulative_seqlens_k (`torch.LongTensor`, *optional*) Gets cumulative sequence length for key state. max_length_q (`int`, *optional*): Maximum sequence length for query state. @@ -431,7 +431,7 @@ class FlashAttentionKwargs(TypedDict, total=False): Maximum sequence length for key state. 
""" - cu_seq_lens_q: Optional[torch.LongTensor] - cu_seq_lens_k: Optional[torch.LongTensor] + cumulative_seqlens_q: Optional[torch.LongTensor] + cumulative_seqlens_k: Optional[torch.LongTensor] max_length_q: Optional[int] max_length_k: Optional[int] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f994d9b08769..9b316383a280 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -58,9 +58,12 @@ from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled from .integrations.accelerate import find_tied_parameters, init_empty_weights from .integrations.deepspeed import _load_state_dict_into_zero3_model +from .integrations.eager_paged import eager_paged_attention_forward from .integrations.flash_attention import flash_attention_forward +from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward from .integrations.sdpa_attention import sdpa_attention_forward +from .integrations.sdpa_paged import sdpa_attention_paged_forward from .integrations.tensor_parallel import ( SUPPORTED_TP_STYLES, shard_and_distribute_module, @@ -6064,7 +6067,10 @@ class AttentionInterface(MutableMapping): _global_mapping = { "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, + "paged_attention": paged_attention_forward, "sdpa": sdpa_attention_forward, + "sdpa_paged": sdpa_attention_paged_forward, + "eager_paged": eager_paged_attention_forward, } def __init__(self): diff --git a/src/transformers/utils/metrics.py b/src/transformers/utils/metrics.py new file mode 100644 index 000000000000..16379831a97b --- /dev/null +++ b/src/transformers/utils/metrics.py @@ -0,0 +1,434 @@ +import functools +import logging +import time +from enum import Enum +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch + + +class RequestStatus(Enum): + """Status of a generation request through its lifecycle.""" + + PENDING = "pending" + PREFILLING = "prefilling" + PREFILLING_SPLIT = "prefilling_split" + SPLIT_PENDING_REMAINDER = "split_pending_remainder" + DECODING = "decoding" + FINISHED = "finished" + FAILED = "failed" + + +try: + from opentelemetry import metrics, trace + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.trace import Status, StatusCode, get_tracer + + resource = Resource.create({"service.name": "transformers"}) + + metrics_exporter = PeriodicExportingMetricReader(OTLPMetricExporter(), export_interval_millis=1000) + meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter]) + metrics.set_meter_provider(meter_provider) + + trace_exporter = OTLPSpanExporter() + tracer_provider = TracerProvider(resource=resource) + tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + trace.set_tracer_provider(tracer_provider) + + _has_opentelemetry = True +except ImportError: + _has_opentelemetry = False + + +def attach_tracer(tracer_name_template=None): + """ + Decorator that attaches a tracer to a class. 
+ + This decorator should be applied to classes that need OpenTelemetry tracing. + It adds a tracer attribute to the class instance that can be used by the traced decorator. + + Args: + tracer_name_template: Optional template string for the tracer name. + If provided, it should contain {module} which will be replaced with the class's full module path + and {class_name} for the class name. + If None, a default naming scheme will be used where: + - If the module already starts with "transformers.", it will use that directly + - Otherwise, it will prepend "transformers." to the module name + + Returns: + Class decorator function + """ + if not _has_opentelemetry: + return lambda cls: cls + + def decorator(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init_with_tracer(self, *args, **kwargs): + original_init(self, *args, **kwargs) + + module_name = cls.__module__ + class_name = cls.__qualname__ + + if tracer_name_template is None: + if module_name.startswith("transformers."): + tracer_name = f"{module_name}.{class_name}" + else: + tracer_name = f"transformers.{module_name}.{class_name}" + else: + tracer_name = tracer_name_template.format(module=module_name, class_name=class_name) + + self.tracer = get_tracer(tracer_name) + + cls.__init__ = init_with_tracer + return cls + + return decorator + + +def traced( + func=None, + *, + span_name=None, + standalone=False, + additional_attributes: Optional[List[Tuple[str, str, Union[Any, Callable[[Any], Any]]]]] = None, +): + """ + Decorator to trace function calls with OpenTelemetry. + + Can be used as @traced or @traced(span_name="custom_name") + + Args: + func: The function to trace + span_name: Optional custom name for the span (defaults to function name) + standalone: If True, creates a parentless span + additional_attributes: Optional list of additional attributes to set on the span. 
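# Hypothetical usage of the two decorators documented above; the class and method are invented
# for illustration. attach_tracer() injects `self.tracer` once __init__ returns, and traced()
# wraps each call in a span (both are no-ops when OpenTelemetry is not installed).
from transformers.utils.metrics import attach_tracer, traced


@attach_tracer()
class ToyBlockAllocator:
    def __init__(self, num_blocks: int):
        self.num_blocks = num_blocks      # self.tracer is available from here on

    @traced(span_name="allocate_blocks")
    def allocate(self, n: int) -> int:
        return min(n, self.num_blocks)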
+ Each item is a tuple of (instance_attribute_name, span_attribute_key, value_or_transform_function) + where: + - instance_attribute_name: Name of the attribute to get from the class instance + - span_attribute_key: Key to use when setting the attribute on the span + - value_or_transform_function: Either a raw value to use directly, or a function to transform + the attribute value before setting it on the span + + Returns: + Decorated function with tracing + """ + + def decorator(func): + if not _has_opentelemetry: + return func + + import functools + + @functools.wraps(func) + def wrapper(*args, **kwargs): + instance = args[0] if args and (hasattr(func, "__self__") and func.__self__ is not None) else None + is_method = instance is not None + + if is_method and hasattr(instance, "tracer"): + tracer = instance.tracer + else: + tracer = get_tracer(f"transformers.{func.__module__}.{func.__name__}") + + name = span_name or func.__name__ + span_fn = tracer.start_span if standalone else tracer.start_as_current_span + with span_fn(name) as span: + span.set_attribute("function.name", func.__name__) + span.set_attribute("function.module", func.__module__) + span.set_attribute("function.is_method", is_method) + + if args: + for i, arg in enumerate(args): + if isinstance(arg, (str, int, float, bool)) or arg is None: + span.set_attribute(f"args.{i}", str(arg)) + else: + span.set_attribute(f"args.{i}", str(type(arg))) + if kwargs: + for key, value in kwargs.items(): + if isinstance(value, (str, int, float, bool)) or value is None: + span.set_attribute(f"kwargs.{key}", str(value)) + else: + span.set_attribute(f"kwargs.{key}", str(type(value))) + + if additional_attributes and is_method: + for attr_config in additional_attributes: + instance_attribute_name, span_attribute_key, value_or_transform_function = attr_config + if hasattr(instance, instance_attribute_name): + attribute_value = getattr(instance, instance_attribute_name) + if callable(value_or_transform_function): + transformed_value = value_or_transform_function(attribute_value) + else: + transformed_value = value_or_transform_function + span.set_attribute(span_attribute_key, transformed_value) + + try: + result = func(*args, **kwargs) + return result + except Exception as e: + span.set_status(Status(StatusCode.ERROR)) + span.record_exception(e) + raise + + return wrapper + + if func is None: + return decorator + return decorator(func) + + +logger = logging.getLogger(__name__) + + +@attach_tracer() +class ContinuousBatchProcessorMetrics: + """Metrics collection for ContinuousBatchProcessor.""" + + def __init__(self, max_batch_tokens: int): + """Initialize metrics for continuous batch processor. + + Args: + max_batch_tokens: Maximum number of tokens in a batch + """ + self.max_batch_tokens = max_batch_tokens + + self._setup_metrics() + + def _setup_metrics(self): + """Initialize OpenTelemetry metrics and tracing if the library is available.""" + + if not _has_opentelemetry: + logger.info("OpenTelemetry is not installed. 
Metrics and tracing will not be recorded.") + return + + self.meter = metrics.get_meter("transformers.generation.continuous_batch_processor") + + # Define appropriate buckets for TTFT (typically ranges from ~50ms to several seconds) + ttft_buckets = [10, 25, 50, 75, 100, 150, 200, 300, 500, 750, 1000, 2000, 5000, 10000] + + self.ttft_histogram = self.meter.create_histogram( + name="ttft_milliseconds", + description="Time to first token in milliseconds", + unit="ms", + explicit_bucket_boundaries_advisory=ttft_buckets, + ) + + self.active_requests_gauge = self.meter.create_gauge( + name="active_requests_count", + description="Number of active requests currently being processed", + unit="requests", + ) + + self.waiting_requests_gauge = self.meter.create_gauge( + name="waiting_requests_count", + description="Number of requests waiting to be processed", + unit="requests", + ) + + # Define appropriate buckets for request latency (similar to TTFT but with higher upper bounds) + latency_buckets = [50, 100, 250, 500, 1000, 2000, 5000, 10000, 20000, 30000, 60000] + + self.request_latency_histogram = self.meter.create_histogram( + name="request_latency_milliseconds", + description="End-to-end latency for completed requests in milliseconds", + unit="ms", + explicit_bucket_boundaries_advisory=latency_buckets, + ) + + self.decode_prefill_ratio_gauge = self.meter.create_gauge( + name="decode_prefill_ratio", + description="Ratio of decode tokens to prefill tokens in a batch", + unit="ratio", + ) + + self.prefill_tokens_counter = self.meter.create_counter( + name="prefill_tokens_processed", + description="Number of prefill tokens processed", + unit="tokens", + ) + + self.decode_tokens_counter = self.meter.create_counter( + name="decode_tokens_processed", + description="Number of decode tokens processed", + unit="tokens", + ) + + # Define appropriate buckets for batch fill percentage (0-100%) + batch_fill_buckets = [5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 98, 100] + + self.batch_fill_percentage_histogram = self.meter.create_histogram( + name="batch_fill_percentage", + description="Percentage of max_batch_tokens utilized in each batch", + unit="percent", + explicit_bucket_boundaries_advisory=batch_fill_buckets, + ) + + self.kv_cache_free_memory_gauge = self.meter.create_gauge( + name="kv_cache_free_memory_bytes", + description="Free memory of the PagedAttentionCache in bytes", + unit="bytes", + ) + + self.kv_cache_memory_gauge = self.meter.create_gauge( + name="kv_cache_memory_bytes", + description="Memory usage of the PagedAttentionCache in bytes", + unit="bytes", + ) + + @traced + def record_ttft_metric(self, created_time: float, request_id: str) -> None: + """Record Time to First Token (TTFT). + + Args: + created_time: The time the request was created + request_id: The ID of the request + """ + if not _has_opentelemetry: + return + + ttft_ms = (time.time() - created_time) * 1000.0 + + try: + self.ttft_histogram.record(ttft_ms) + logger.debug(f"Recorded TTFT for request {request_id}: {ttft_ms:.2f}ms") + except Exception as e: + logger.warning(f"Failed to record TTFT metric: {e}") + + @traced + def record_batch_metrics(self, requests_in_batch: List) -> None: + """Record metrics about the batch composition including decode/prefill ratio and batch fill percentage. 
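# Toy numbers (illustrative only) for the two batch-composition metrics recorded by
# record_batch_metrics(): six requests in DECODING contribute one token each, and a single
# prefill request contributes its full prompt.
decode_tokens, prefill_tokens, max_batch_tokens = 6, 250, 512

decode_prefill_ratio = decode_tokens / prefill_tokens                               # 0.024
batch_fill_percentage = (decode_tokens + prefill_tokens) / max_batch_tokens * 100   # 50.0 %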
+ + Args: + requests_in_batch: List of request states in the current batch + """ + if not _has_opentelemetry or not requests_in_batch: + return + + decode_tokens = 0 + prefill_tokens = 0 + + for state in requests_in_batch: + if state.status == RequestStatus.DECODING: + decode_tokens += 1 + elif state.status in [RequestStatus.PREFILLING, RequestStatus.PREFILLING_SPLIT]: + prefill_tokens += len(state.prompt_ids) + + total_batch_tokens = decode_tokens + prefill_tokens + + try: + if prefill_tokens > 0: + self.prefill_tokens_counter.add(prefill_tokens) + + if decode_tokens > 0: + self.decode_tokens_counter.add(decode_tokens) + + if prefill_tokens > 0: + ratio = decode_tokens / prefill_tokens + self.decode_prefill_ratio_gauge.set(ratio) + + fill_percentage = (total_batch_tokens / self.max_batch_tokens) * 100.0 + self.batch_fill_percentage_histogram.record(fill_percentage) + logger.debug( + f"Batch metrics: {decode_tokens} decode tokens, {prefill_tokens} prefill tokens, " + f"batch fill: {fill_percentage:.2f}% ({total_batch_tokens}/{self.max_batch_tokens})" + ) + except Exception as e: + logger.warning(f"Failed to record batch metrics: {e}") + + @traced + def record_kv_cache_memory_metrics(self, cache) -> None: + """Record memory usage of the PagedAttentionCache without GPU synchronization. + + This calculates the theoretical memory usage based on cache configuration + and the number of blocks currently in use. + + Args: + cache: The PagedAttentionCache object to measure + """ + if not _has_opentelemetry: + return + + try: + # Calculate memory usage based on cache configuration + num_used_blocks = cache.num_blocks - len(cache._free_blocks) + num_layers = len(cache.key_cache) + + # Each used block stores key and value states + # Each with shape: (num_kv_heads, block_size, head_dim) + bytes_per_parameter = 2 if cache.dtype in [torch.float16, torch.bfloat16] else 4 # Size in bytes + + # Total bytes = num_layers * num_used_blocks * block_size * + # num_kv_heads * head_dim * 2 (both K and V) * bytes_per_parameter + memory_bytes = ( + num_layers + * num_used_blocks + * cache.block_size + * cache.num_key_value_heads + * cache.head_dim + * 2 # For both key and value caches + * bytes_per_parameter + ) + + free_memory_bytes = ( + num_layers + * len(cache._free_blocks) + * cache.block_size + * cache.num_key_value_heads + * cache.head_dim + * 2 # For both key and value caches + * bytes_per_parameter + ) + + self.kv_cache_memory_gauge.set(memory_bytes) + self.kv_cache_free_memory_gauge.set(free_memory_bytes) + logger.debug( + f"KV Cache memory: {memory_bytes / (1024 * 1024):.2f}MB, " + f"Used blocks: {num_used_blocks}/{cache.num_blocks} " + f"({num_used_blocks / cache.num_blocks * 100:.1f}%)" + ) + except Exception as e: + logger.warning(f"Failed to record KV cache memory metrics: {e}") + + @traced + def record_queue_metrics(self, active_requests: int, waiting_requests: int) -> None: + """Record metrics about active and waiting requests. 
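# A worked instance (invented model dimensions, not from the patch) of the byte count computed
# by record_kv_cache_memory_metrics() above:
# layers * used_blocks * block_size * kv_heads * head_dim * 2 (K and V) * bytes_per_param.
num_layers, num_used_blocks, block_size = 32, 100, 32
num_key_value_heads, head_dim, bytes_per_parameter = 8, 128, 2                      # bfloat16

memory_bytes = (
    num_layers * num_used_blocks * block_size * num_key_value_heads * head_dim * 2 * bytes_per_parameter
)
print(f"{memory_bytes / 2**20:.0f} MiB")                                            # 400 MiB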
+ + Args: + active_requests: Number of active requests + waiting_requests: Number of waiting requests + """ + if not _has_opentelemetry: + return + + try: + self.active_requests_gauge.set(active_requests) + self.waiting_requests_gauge.set(waiting_requests) + logger.debug(f"Queue metrics: {active_requests} active requests, {waiting_requests} waiting requests") + except Exception as e: + logger.warning(f"Failed to record queue metrics: {e}") + + @traced + def record_request_completion(self, created_time: float, request_id: str) -> None: + """Record metrics about a completed request. + + Args: + created_time: The time the request was created + request_id: The ID of the request + """ + if not _has_opentelemetry: + return + + latency_ms = (time.time() - created_time) * 1000.0 + + try: + self.request_latency_histogram.record(latency_ms) + + logger.debug(f"Recorded request completion for {request_id}: {latency_ms:.2f}ms") + except Exception as e: + logger.warning(f"Failed to record request completion metric: {e}") diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py new file mode 100644 index 000000000000..c06975844bfc --- /dev/null +++ b/tests/generation/test_paged_attention.py @@ -0,0 +1,86 @@ +import time +import unittest + +from parameterized import parameterized + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers.testing_utils import require_flash_attn, require_torch_gpu, run_slow + + +_TEST_PROMPTS = [ + "A man is a walking his dog down the street, and a the turn he sees", + "Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an", + "A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?", + "Please fill in the form to", + "For safety reasons, the train is stopped in the middle of the", +] + +_EXPECTED_OUTPUTS = [ + "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes up a conversation, and they quickly discover that they have a lot in common. They exchange numbers and", + "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n## Step 2: Determine the taste and nutritional value of the fruit\nThe fruit is described as sweet", + "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer is not a straightforward one, and it requires some lateral thinking to arrive at the correct solution.", + "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]\n[Your Message]\n\nWe are looking forward to hearing from you!\n\n[Insert Contact Information]\n\nNote:", + "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. 
How many kilometers does the train travel in 30 minutes?\n## Step 1: Convert the speed from km/h to km/min", +] + + +@run_slow +@require_torch_gpu +@require_flash_attn +class TestBatchGeneration(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-3b-Instruct", torch_dtype="bfloat16", device_map="auto" + ).eval() + + cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left") + + if cls.tokenizer.pad_token is None: + cls.tokenizer.pad_token = cls.tokenizer.eos_token + cls.model.config.pad_token_id = cls.model.config.eos_token_id + + cls.model.use_cache = False + + @parameterized.expand( + [ + ("eager_paged", 64, 128, 64), + ("sdpa_paged", 32, 256, 128), + ("paged_attention", 16, 512, 256), + ("flex_paged", 64, 128, 64), + ] + ) + def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens): + self.model.config.attn_implementation = attn_impl + + generation_config = GenerationConfig( + max_new_tokens=50, + top_k=0, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + use_cache=False, + num_blocks=num_blocks, + block_size=block_size, + max_batch_tokens=max_batch_tokens, + ) + + tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512) + batch_inputs = list(tokenized["input_ids"]) + + start = time.time() + batch_outputs = self.model.generate_batch( + inputs=batch_inputs, + generation_config=generation_config, + ) + end = time.time() + print( + f"\n[{attn_impl}] Batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}" + ) + + for i, req_id in enumerate(batch_outputs): + generated = self.tokenizer.decode(batch_outputs[req_id].static_outputs, skip_special_tokens=False).strip() + expected = _EXPECTED_OUTPUTS[i].strip() + self.assertTrue( + generated.startswith(expected), + msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected start: {expected}\nGot: {generated}", + )
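# A minimal sketch, assuming an OTLP/HTTP collector on the default port: the exporters
# instantiated in src/transformers/utils/metrics.py honour the standard OpenTelemetry
# environment variables, so pointing them at a collector is usually just an endpoint setting.
# The endpoint value is illustrative, and it should be set before the metrics module is first
# imported, since the exporters are created at module import time.
import os

os.environ.setdefault("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4318")

import transformers  # noqa: E402  -- subsequent continuous batching runs should export the metrics defined above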