# Source code for oumi.inference.together_inference_engine

# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from pathlib import Path
from typing import Any

import aiofiles
import aiofiles.os
import aiohttp
import jsonlines
from typing_extensions import override

from oumi.core.configs import GenerationParams, ModelParams
from oumi.core.types.conversation import Conversation
from oumi.inference.remote_inference_engine import (
    _BATCH_ENDPOINT,
    BatchInfo,
    RemoteInferenceEngine,
)


class TogetherInferenceEngine(RemoteInferenceEngine):
    """Engine for running inference against the Together AI API.

    Together AI supports batch inference via their batch API. Note that Together
    uses a redirect-based file upload flow that differs from OpenAI's direct
    multipart upload.

    See: https://docs.together.ai/docs/batch-inference
    """

    @property
    @override
    def base_url(self) -> str | None:
        """Return the default base URL for the Together API."""
        return "https://api.together.xyz/v1/chat/completions"

    @property
    @override
    def api_key_env_varname(self) -> str | None:
        """Return the default environment variable name for the Together API key."""
        return "TOGETHER_API_KEY"

    @property
    def _batch_purpose(self) -> str:
        """Return the purpose value for batch file uploads.

        Together AI uses "batch-api" instead of OpenAI's "batch".
        """
        return "batch-api"

    @override
    async def _upload_batch_file(
        self,
        batch_requests: list[dict],
    ) -> str:
        """Uploads a JSONL file for batch processing using Together's redirect flow.

        Together AI uses a different upload mechanism than OpenAI:

        1. POST to /files with a form-encoded body to get a signed upload URL
           (302 redirect)
        2. PUT the file content to the signed URL
        3. POST to /files/{file_id}/preprocess to finalize

        Args:
            batch_requests: List of request objects to include in the batch

        Returns:
            str: The uploaded file ID

        Raises:
            RuntimeError: If the upload-URL request returns neither 302 nor 200,
                if the 302 response lacks the redirect URL or file ID headers,
                if the PUT upload does not return 200/201, or if the preprocess
                call does not return 200.
        """
        # delete=False: the file must outlive this handle because it is re-opened
        # asynchronously below. It is removed explicitly in the ``finally`` clause.
        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
        tmp_path = Path(tmp.name)

        try:
            # Serialize inside the ``try`` so that a failure while writing (for
            # example a request that is not JSON-serializable) still removes the
            # temporary file instead of leaking it on disk.
            with tmp, jsonlines.Writer(tmp) as writer:
                for request in batch_requests:
                    writer.write(request)

            async with self._create_session() as (session, headers):
                # Step 1: Request signed upload URL
                # Together expects form-encoded data (not JSON or multipart)
                request_data = {
                    "purpose": self._batch_purpose,
                    "file_name": tmp_path.name,
                    "file_type": "jsonl",
                }

                async with session.post(
                    self.get_file_api_url(),
                    data=request_data,  # Form-encoded, not json=
                    headers=headers,
                    allow_redirects=False,  # We need to handle the redirect manually
                ) as response:
                    if response.status == 302:
                        # Get the signed URL and file ID from headers
                        redirect_url = response.headers.get("Location")
                        file_id = response.headers.get("X-Together-File-Id")

                        if not redirect_url or not file_id:
                            raise RuntimeError(
                                "Together API did not return redirect URL or file ID. "
                                f"Headers: {dict(response.headers)}"
                            )
                    elif response.status == 200:
                        # Some endpoints might return the file directly
                        data = await response.json()
                        return data["id"]
                    else:
                        error_text = await response.text()
                        raise RuntimeError(
                            f"Failed to get upload URL from Together: {error_text}"
                        )

                # Step 2: Upload file content to signed URL
                async with aiofiles.open(tmp_path, "rb") as f:
                    file_content = await f.read()

                async with session.put(
                    redirect_url,
                    data=file_content,
                ) as upload_response:
                    if upload_response.status not in (200, 201):
                        error_text = await upload_response.text()
                        raise RuntimeError(
                            f"Failed to upload file to Together: {error_text}"
                        )

                # Step 3: Finalize upload by calling preprocess endpoint
                preprocess_url = f"{self.get_file_api_url()}/{file_id}/preprocess"
                async with session.post(
                    preprocess_url,
                    headers=self._get_request_headers(self._remote_params),
                ) as preprocess_response:
                    if preprocess_response.status != 200:
                        error_text = await preprocess_response.text()
                        raise RuntimeError(
                            f"Failed to preprocess file on Together: {error_text}"
                        )

                return file_id

        finally:
            # Clean up temporary file
            try:
                await aiofiles.os.remove(tmp_path)
            except OSError:
                pass  # Ignore cleanup errors to avoid masking original exceptions

    def _normalize_together_response(self, data: dict[str, Any]) -> dict[str, Any]:
        """Normalize Together's response format to match OpenAI's format.

        Together uses:

        - Uppercase status values (e.g., "COMPLETED" instead of "completed")
        - ISO 8601 timestamps instead of Unix timestamps
        - ``progress`` (float 0-100) for batch completion instead of OpenAI's
          ``request_counts`` dict with total/completed/failed keys.

        Args:
            data: Parsed JSON body of a Together batch response. It is modified
                in place.

        Returns:
            The same ``data`` dict, after normalization.
        """
        # Only lowercase real strings; a null/absent status is left untouched
        # rather than raising AttributeError here.
        status = data.get("status")
        if isinstance(status, str):
            data["status"] = status.lower()

        # Convert ISO 8601 timestamps to Unix timestamps for base class compatibility
        timestamp_fields = (
            "created_at",
            "completed_at",
            "in_progress_at",
            "expires_at",
            "finalizing_at",
            "failed_at",
            "expired_at",
            "cancelling_at",
            "cancelled_at",
        )
        for field in timestamp_fields:
            value = data.get(field)
            if isinstance(value, str):
                # NOTE(review): _parse_iso_timestamp is not defined in this class;
                # presumably inherited from RemoteInferenceEngine - confirm.
                dt = self._parse_iso_timestamp(value)
                if dt is not None:
                    data[field] = int(dt.timestamp())

        # ``progress`` may be missing, null, or (defensively) non-numeric; only a
        # positive number is translated into a synthetic ``request_counts``.
        progress = data.get("progress")
        if isinstance(progress, (int, float)) and progress > 0:
            data["request_counts"] = {
                "total": 100,
                "completed": int(progress),
            }

        return data

    def _parse_batch_create_response(
        self, response: aiohttp.ClientResponse, data: dict[str, Any]
    ) -> str:
        """Parse batch creation response.

        Together returns 201 and {"job": {"id": "..."}} instead of OpenAI's
        200 and {"id": "..."}.

        Args:
            response: The HTTP response of the batch creation request.
            data: The parsed JSON body of ``response``.

        Returns:
            str: The ID of the created batch job.

        Raises:
            RuntimeError: If the response status is neither 200 nor 201.
        """
        if response.status not in (200, 201):
            raise RuntimeError(f"Unexpected status code: {response.status}")
        if "job" in data:
            return data["job"]["id"]
        return data["id"]

    @override
    async def _create_batch(
        self,
        conversations: list[Conversation],
        generation_params: GenerationParams,
        model_params: ModelParams,
    ) -> str:
        """Creates a batch job, handling Together's response format.

        Args:
            conversations: Conversations to run; each becomes one batch request
                with ``custom_id`` ``request-<index>``.
            generation_params: Generation parameters used to build each request.
            model_params: Model parameters used to build each request.

        Returns:
            str: The ID of the created batch job.

        Raises:
            RuntimeError: If the file upload fails or the batch creation request
                does not return 200/201.
        """
        # Prepare batch requests
        batch_requests = [
            {
                "custom_id": f"request-{i}",
                "method": "POST",
                "url": _BATCH_ENDPOINT,
                "body": self._convert_conversation_to_api_input(
                    conv, generation_params, model_params
                ),
            }
            for i, conv in enumerate(conversations)
        ]

        # Upload batch file (uses Together's redirect-based flow)
        file_id = await self._upload_batch_file(batch_requests)

        # Create batch
        async with self._create_session() as (session, headers):
            async with session.post(
                self.get_batch_api_url(),
                json={
                    "input_file_id": file_id,
                    "endpoint": _BATCH_ENDPOINT,
                    "completion_window": self._remote_params.batch_completion_window,
                },
                headers=headers,
            ) as response:
                if response.status not in (200, 201):
                    raise RuntimeError(
                        f"Failed to create batch: {await response.text()}"
                    )
                data = await response.json()
                return self._parse_batch_create_response(response, data)

    @override
    async def _get_batch_status(self, batch_id: str) -> BatchInfo:
        """Gets the status of a batch job, normalizing Together's response format.

        Args:
            batch_id: ID of the batch job to query.

        Returns:
            BatchInfo: Batch information built from the normalized response.

        Raises:
            RuntimeError: If the status request does not return 200.
        """
        async with self._create_session() as (session, headers):
            async with session.get(
                f"{self.get_batch_api_url()}/{batch_id}",
                headers=headers,
            ) as response:
                if response.status != 200:
                    raise RuntimeError(
                        f"Failed to get batch status: {await response.text()}"
                    )
                data = await response.json()
                return BatchInfo.from_api_response(
                    self._normalize_together_response(data)
                )