# Build apps
> This bundle contains all pages in the Build apps section.
> Source: https://www.union.ai/docs/v2/union/user-guide/build-apps/

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps ===

# Build apps

This section covers how to build different types of apps with Flyte, from single-script apps to multi-file projects, common usage patterns, and authentication.

> [!TIP]
> Go to [Introducing apps](https://www.union.ai/docs/v2/union/user-guide/core-concepts/introducing-apps/page.md) for an overview of apps and a quick example. For pre-built environments for popular frameworks like Streamlit, FastAPI, vLLM, and SGLang, see [Native app integrations](https://www.union.ai/docs/v2/union/user-guide/native-app-integrations/_index).

## App types

Flyte supports various types of apps:

- **UI dashboard apps**: Interactive web dashboards and data visualization tools like Streamlit and Gradio
- **Web API apps**: REST APIs, webhooks, and backend services like FastAPI and Flask
- **Model serving apps**: High-performance LLM serving with vLLM and SGLang
- **Connector apps**: Long-running services that delegate task execution to external systems

For ready-to-use environments for these frameworks, see [Native app integrations](https://www.union.ai/docs/v2/union/user-guide/native-app-integrations/_index).

## Usage patterns

Apps and tasks can interact in several ways: they can call each other over HTTP (including webhooks) or WebSockets, and users can open apps directly in a browser.

| Pattern | Use Case | Implementation |
|---------|----------|----------------|
| App | Stand-alone serving app | HTTP requests from arbitrary clients |
| App → App | Microservices, proxies, agent routers, LLM routers | HTTP requests between apps |
| App → Task | Webhooks, APIs triggering workflows | Flyte SDK in app |
| Task → App | Batch processing using inference services | HTTP requests from task |
| Browser app | User-facing dashboards (e.g. Streamlit, Gradio) | Direct browser access |

## Next steps

- **Build apps > Single-script apps**: The simplest way to build and deploy apps in a single Python script
- **Build apps > Multi-script apps**: Build FastAPI and Streamlit apps with multiple files
- **Build apps > Serving graphs**: Apps calling other apps for microservice architectures
- **Build apps > Hybrid app-task graphs**: Tasks calling apps and apps calling tasks (webhooks, APIs)
- **Build apps > WebSocket apps**: Real-time, bidirectional communication with WebSockets
- **Build apps > Browser apps**: User-facing dashboards and UIs
- **Build apps > Secret-based authentication**: Authenticate FastAPI apps using Flyte secrets
- **Build apps > Connector app**: Deploy a connector as a long-running service

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/single-script-apps ===

# Single-script apps

The simplest way to build and deploy an app with Flyte is to write everything in a single Python script. This approach works well for:

- **Quick prototypes**: Rapidly test ideas and concepts
- **Simple services**: Basic HTTP servers, APIs, or dashboards
- **Learning**: Understanding how Flyte apps work without complexity
- **Minimal examples**: Demonstrating core functionality

All the code for your app—the application logic, the app environment configuration, and the deployment code—lives in one file. This makes it easy to understand, share, and deploy.

## Plain Python HTTP server

The simplest possible app is a plain Python HTTP server using Python's built-in `http.server` module. This requires no external dependencies beyond the Flyte SDK.

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
# ]
# ///

"""A plain Python HTTP server example - the simplest possible app."""

import flyte
import flyte.app
from pathlib import Path

# {{docs-fragment server-code}}
# Create a simple HTTP server handler
from http.server import HTTPServer, BaseHTTPRequestHandler

class SimpleHandler(BaseHTTPRequestHandler):
    """A simple HTTP server handler."""

    def do_GET(self):
        if self.path == "/":
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            self.wfile.write(b"<h1>Hello from Plain Python Server!</h1>")

        elif self.path == "/health":
            self.send_response(200)
            self.send_header("Content-type", "application/json")
            self.end_headers()
            self.wfile.write(b'{"status": "healthy"}')

        else:
            self.send_response(404)
            self.end_headers()
# {{/docs-fragment server-code}}

# {{docs-fragment app-env}}
file_name = Path(__file__).name
app_env = flyte.app.AppEnvironment(
    name="plain-python-server",
    image=flyte.Image.from_debian_base(python_version=(3, 12)),
    args=["python", file_name, "--server"],
    port=8080,
    resources=flyte.Resources(cpu="1", memory="512Mi"),
    requires_auth=False,
)
# {{/docs-fragment app-env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    import sys

    if "--server" in sys.argv:
        server = HTTPServer(("0.0.0.0", 8080), SimpleHandler)
        print("Server running on port 8080")
        server.serve_forever()
    else:
        flyte.init_from_config(root_dir=Path(__file__).parent)
        app = flyte.serve(app_env)
        print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/plain_python_server.py*

**Key points**

- **No external dependencies**: Uses only Python's standard library
- **Simple handler**: Define request handlers as Python classes
- **Basic command**: Run the server with a simple Python command
- **Minimal resources**: Requires only basic CPU and memory
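Because the handler uses only the standard library, you can exercise it on your machine before deploying. The sketch below copies `SimpleHandler` from the example, binds it to an OS-assigned port (port 0, instead of the 8080 the app uses) in a background thread, and requests each route with `urllib`. The `run_smoke_test` helper and the quiet `log_message` override are illustrative additions, not part of the example app.

```python
import json
import threading
import urllib.error
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

# Ignore any HTTP(S)_PROXY settings so requests to localhost stay local.
OPENER = urllib.request.build_opener(urllib.request.ProxyHandler({}))


class SimpleHandler(BaseHTTPRequestHandler):
    """Same routes as the example app: /, /health, and a 404 fallback."""

    def do_GET(self):
        if self.path == "/":
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            self.wfile.write(b"<h1>Hello from Plain Python Server!</h1>")
        elif self.path == "/health":
            self.send_response(200)
            self.send_header("Content-type", "application/json")
            self.end_headers()
            self.wfile.write(b'{"status": "healthy"}')
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        pass  # keep test output quiet


def run_smoke_test() -> dict:
    """Serve SimpleHandler on a free local port, request each route, return the results."""
    server = HTTPServer(("127.0.0.1", 0), SimpleHandler)  # port 0: the OS picks a free port
    threading.Thread(target=server.serve_forever, daemon=True).start()
    base = f"http://127.0.0.1:{server.server_address[1]}"
    try:
        with OPENER.open(f"{base}/") as resp:
            root = (resp.status, resp.read().decode())
        with OPENER.open(f"{base}/health") as resp:
            health = (resp.status, json.loads(resp.read()))
        try:
            OPENER.open(f"{base}/missing").close()
            missing = 200
        except urllib.error.HTTPError as err:
            missing = err.code  # 4xx responses are raised as HTTPError
    finally:
        server.shutdown()
        server.server_close()
    return {"root": root, "health": health, "missing": missing}


if __name__ == "__main__":
    print(run_smoke_test())
```

Running the handler this way catches routing mistakes in seconds, without building an image or waiting for a deployment.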

## Streamlit app

Streamlit makes it easy to build interactive web dashboards. Here's a complete single-script Streamlit app:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "streamlit",
# ]
# ///

"""A single-script Streamlit app example."""

import pathlib
import streamlit as st
import flyte
import flyte.app

# {{docs-fragment streamlit-app}}
def main():
    st.set_page_config(page_title="Simple Streamlit App", page_icon="🚀")

    st.title("Hello from Streamlit!")
    st.write("This is a simple single-script Streamlit app.")

    name = st.text_input("What's your name?", "World")
    st.write(f"Hello, {name}!")

    if st.button("Click me!"):
        st.balloons()
        st.success("Button clicked!")
# {{/docs-fragment streamlit-app}}

# {{docs-fragment app-env}}
file_name = pathlib.Path(__file__).name
app_env = flyte.app.AppEnvironment(
    name="streamlit-single-script",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "streamlit==1.41.1"
    ),
    args=["streamlit", "run", file_name, "--server.port", "8080", "--", "--server"],
    port=8080,
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    requires_auth=False,
)
# {{/docs-fragment app-env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    import sys

    if "--server" in sys.argv:
        main()
    else:
        flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
        app = flyte.serve(app_env)
        print(f"App URL: {app.url}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit_single_script.py*

**Key points**

- **Interactive UI**: Streamlit provides widgets and visualizations out of the box
- **Single file**: All UI logic and deployment code in one script
- **Simple deployment**: Just specify the Streamlit command and port
- **Rich ecosystem**: Access to Streamlit's extensive component library

## FastAPI app

FastAPI is a modern, fast web framework for building APIs. Here's a minimal single-script FastAPI app:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""A single-script FastAPI app example - the simplest FastAPI app."""

from fastapi import FastAPI
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment fastapi-app}}
app = FastAPI(
    title="Simple FastAPI App",
    description="A minimal single-script FastAPI application",
    version="1.0.0",
)

@app.get("/")
async def root():
    return {"message": "Hello, World!"}

@app.get("/health")
async def health():
    return {"status": "healthy"}
# {{/docs-fragment fastapi-app}}

# {{docs-fragment app-env}}
app_env = FastAPIAppEnvironment(
    name="fastapi-single-script",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)
# {{/docs-fragment app-env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.serve(app_env)
    print(f"Deployed: {app_deployment.url}")
    print(f"API docs: {app_deployment.url}/docs")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi_single_script.py*

**Key points**

- **FastAPIAppEnvironment**: Automatically configures uvicorn and FastAPI
- **Type hints**: FastAPI uses Python type hints for automatic validation
- **Auto docs**: Interactive API documentation at `/docs` endpoint
- **Async support**: Built-in support for async/await patterns

## Running single-script apps

To run any of these examples:

1. **Save the script** to a file (e.g., `my_app.py`)
2. **Ensure you have a config file** (`./.flyte/config.yaml` or `./config.yaml`)
3. **Run the script**:

```bash
python my_app.py
```

Or using `uv`:

```bash
uv run my_app.py
```

The script will:
- Initialize Flyte from your config
- Deploy the app to your Union/Flyte instance
- Print the app URL
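The plain-server and Streamlit examples use one file in two roles. Run locally without flags, the script deploys the app. Inside the app container, the `args` list in `AppEnvironment` appends `--server`, so the same file starts the server instead. For Streamlit, everything after the `--` separator in `streamlit run <file> -- --server` is passed through to the script's `sys.argv`. The sketch below isolates that branch decision; `choose_mode` is a hypothetical helper, and the "deploy" branch stands in for the `flyte.init_from_config(...)` plus `flyte.serve(app_env)` calls in the examples.

```python
import sys


def choose_mode(argv: list[str]) -> str:
    """Decide which branch of a dual-mode app script should run.

    Mirrors the `if "--server" in sys.argv` check in the examples above.
    """
    # Inside the container, AppEnvironment(args=[..., "--server"]) supplies the flag.
    if "--server" in argv:
        return "serve"
    # Invoked locally (python my_app.py / uv run my_app.py): deploy instead.
    return "deploy"


if __name__ == "__main__":
    mode = choose_mode(sys.argv)
    # In the real scripts, "serve" starts the HTTP or Streamlit server and
    # "deploy" initializes Flyte from config and calls flyte.serve(app_env).
    print(f"Running in {mode!r} mode")
```

Keeping the flag check in one place means the container command and the local command can never disagree about which branch runs.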

## When to use single-script apps

**Use single-script apps when:**
- Building prototypes or proof-of-concepts
- Creating simple services with minimal logic
- Learning how Flyte apps work
- Sharing complete, runnable examples
- Building demos or tutorials

**Consider multi-script apps when:**
- Your app grows beyond a few hundred lines
- You need to organize code into modules
- You want to reuse components across apps
- You're building production applications

See [**Multi-script apps**](./multi-script-apps) for examples of organizing apps across multiple files.

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/multi-script-apps ===

# Multi-script apps

Real-world applications often span multiple files. This page shows how to build FastAPI and Streamlit apps with multiple Python files.

## FastAPI multi-script app

### Project structure

```
project/
├── app.py          # Main FastAPI app file
└── module.py       # Helper module
```

### Example: Multi-file FastAPI app

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""Multi-file FastAPI app example."""

from fastapi import FastAPI
from module import function  # Import from another file
import pathlib

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment app-definition}}
app = FastAPI(title="Multi-file FastAPI Demo")

app_env = FastAPIAppEnvironment(
    name="fastapi-multi-file",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    # FastAPIAppEnvironment automatically includes necessary files
    # But you can also specify explicitly:
    # include=["app.py", "module.py"],
)
# {{/docs-fragment app-definition}}

# {{docs-fragment endpoint}}
@app.get("/")
async def root():
    return function()  # Uses function from module.py
# {{/docs-fragment endpoint}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(app_env)
    print(f"Deployed: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/app.py*

```
# {{docs-fragment helper-function}}
def function():
    """Helper function used by the FastAPI app."""
    return {"message": "Hello from module.py!"}
# {{/docs-fragment helper-function}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/multi_file/module.py*

### Automatic file discovery

`FastAPIAppEnvironment` automatically discovers and includes the necessary files by analyzing your imports. However, if you have files that aren't automatically detected (like configuration files or data files), you can explicitly include them:

```python
app_env = FastAPIAppEnvironment(
    name="fastapi-with-config",
    app=app,
    include=["app.py", "module.py", "config.yaml"],  # Explicit includes
    # ...
)
```
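To get a feel for what an import-based scan can and cannot see, the sketch below parses a script with `ast`, collects its absolute imports, and keeps those that resolve to a `.py` file or a package directory next to the script. This is a simplified illustration under those assumptions, not Flyte's actual discovery algorithm. It also shows why non-Python files such as `config.yaml` need an explicit `include`: nothing imports them, so no import scan can find them. `local_imports` is a hypothetical helper.

```python
import ast
import tempfile
from pathlib import Path


def local_imports(script: Path) -> list[str]:
    """List the local modules/packages that `script` imports directly.

    Only absolute imports are considered, and a name counts as local when it
    resolves to `<name>.py` or `<name>/__init__.py` next to the script.
    Transitive imports and non-Python files are not detected.
    """
    root = script.parent
    tree = ast.parse(script.read_text())
    names: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            names.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            names.add(node.module.split(".")[0])

    found = []
    for name in sorted(names):
        if (root / f"{name}.py").is_file():
            found.append(f"{name}.py")
        elif (root / name / "__init__.py").is_file():
            found.append(f"{name}/")
    return found


def example() -> list[str]:
    """Scan a throwaway project shaped like the multi-file FastAPI examples."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "module.py").write_text("def function():\n    return {}\n")
        (root / "models").mkdir()
        (root / "models" / "__init__.py").write_text("")
        (root / "config.yaml").write_text("debug: true\n")  # never imported, so never found
        (root / "app.py").write_text(
            "import pathlib\n"
            "from fastapi import FastAPI\n"
            "from models.user import User\n"
            "from module import function\n"
        )
        return local_imports(root / "app.py")
```

Third-party and standard-library imports (`fastapi`, `pathlib`) are skipped because they do not resolve to files in the project directory.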

## Streamlit multi-script app

### Project structure

```
project/
├── main.py         # Main Streamlit app
├── utils.py        # Utility functions
└── components.py   # Reusable components
```

### Example: Multi-file Streamlit app

```
import os

import streamlit as st
from utils import generate_data

# {{docs-fragment streamlit-app}}
all_columns = ["Apples", "Orange", "Pineapple"]
with st.container(border=True):
    columns = st.multiselect("Columns", all_columns, default=all_columns)

all_data = st.cache_data(generate_data)(columns=all_columns, seed=101)

data = all_data[columns]

tab1, tab2 = st.tabs(["Chart", "Dataframe"])
tab1.line_chart(data, height=250)
tab2.dataframe(data, height=250, use_container_width=True)
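# Debug output only: this prints every environment variable, including any secrets
# exposed as environment variables. Remove it before serving with requires_auth=False.
st.write(f"Environment: {os.environ}")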
# {{/docs-fragment streamlit-app}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/main.py*

```
import numpy as np
import pandas as pd

# {{docs-fragment utils-function}}
def generate_data(columns: list[str], seed: int = 42):
    rng = np.random.default_rng(seed)
    data = pd.DataFrame(rng.random(size=(20, len(columns))), columns=columns)
    return data
# {{/docs-fragment utils-function}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/utils.py*

### Deploying multi-file Streamlit app

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
# ]
# ///

"""A custom Streamlit app with multiple files."""

import pathlib
import flyte
import flyte.app

# {{docs-fragment app-env}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "streamlit==1.41.1",
    "pandas==2.2.3",
    "numpy==2.2.3",
)

app_env = flyte.app.AppEnvironment(
    name="streamlit-multi-file-app",
    image=image,
    args="streamlit run main.py --server.port 8080",
    port=8080,
    include=["main.py", "utils.py"],  # Include your app files
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    requires_auth=False,
)
# {{/docs-fragment app-env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app = flyte.deploy(app_env)
    print(f"Deployed app: {app[0].summary_repr()}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/streamlit/multi_file_streamlit.py*

## Complex multi-file example

Here's a more complex example with multiple modules:

### Project structure

```
project/
├── app.py
├── models/
│   ├── __init__.py
│   └── user.py
├── services/
│   ├── __init__.py
│   └── auth.py
└── utils/
    ├── __init__.py
    └── helpers.py
```

### Example code

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""Complex multi-file FastAPI app example."""

from pathlib import Path
from fastapi import FastAPI
from models.user import User
from services.auth import authenticate
from utils.helpers import format_response

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment complex-app}}
app = FastAPI(title="Complex Multi-file App")

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    user = User(id=user_id, name="John Doe")
    return format_response(user)
# {{/docs-fragment complex-app}}

# {{docs-fragment complex-env}}
app_env = FastAPIAppEnvironment(
    name="complex-app",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
        "pydantic",
    ),
    # Include all necessary files
    include=[
        "app.py",
        "models/",
        "services/",
        "utils/",
    ],
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment complex-env}}

if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    app_deployment = flyte.deploy(app_env)
    print(f"Deployed: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/app.py*

```
# {{docs-fragment user-model}}
from pydantic import BaseModel

class User(BaseModel):
    id: int
    name: str
# {{/docs-fragment user-model}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/models/user.py*

```
# {{docs-fragment auth-service}}
def authenticate(token: str) -> bool:
    """Authenticate a user by token."""
    # ... authentication logic ...
    return True
# {{/docs-fragment auth-service}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/services/auth.py*

```
# {{docs-fragment helpers}}
def format_response(data):
    """Format a response with standard structure."""
    return {"data": data, "status": "success"}
# {{/docs-fragment helpers}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/complex_multi_file/utils/helpers.py*

## Best practices

1. **Use explicit includes**: For Streamlit apps, explicitly list all files in `include`
2. **Automatic discovery**: For FastAPI apps, `FastAPIAppEnvironment` handles most cases automatically
3. **Organize modules**: Use proper Python package structure with `__init__.py` files
4. **Test locally**: Test your multi-file app locally before deploying
5. **Include all dependencies**: Include all files that your app imports
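One lightweight way to test locally is to import every module you plan to ship from the project root before deploying, so a missing file or a wrong import path fails on your machine instead of in the container. The sketch below runs each import in a fresh subprocess so the current interpreter's `sys.modules` and `sys.path` stay untouched. `check_imports` is a hypothetical helper, and the stand-in modules use a `dataclass` instead of pydantic so the sketch needs only the standard library. It checks that modules import, not that the app serves requests.

```python
import os
import subprocess
import sys
import tempfile
from pathlib import Path


def check_imports(root: Path, modules: list[str]) -> dict[str, bool]:
    """Import each module in a fresh interpreter whose working directory is `root`.

    A separate subprocess per module keeps this interpreter's sys.modules and
    sys.path untouched and mimics an app process started from the project root.
    """
    env = {**os.environ, "PYTHONPATH": str(root)}
    results = {}
    for name in modules:
        proc = subprocess.run(
            [sys.executable, "-c", f"import {name}"],
            cwd=root,
            env=env,
            capture_output=True,
            text=True,
        )
        results[name] = proc.returncode == 0
    return results


def example() -> dict[str, bool]:
    """Recreate part of the complex example's layout (stdlib-only stand-ins) and check it."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for pkg in ("models", "services"):
            (root / pkg).mkdir()
            (root / pkg / "__init__.py").write_text("")
        (root / "models" / "user.py").write_text(
            "from dataclasses import dataclass\n\n"
            "@dataclass\nclass User:\n    id: int\n    name: str\n"
        )
        (root / "services" / "auth.py").write_text(
            "def authenticate(token: str) -> bool:\n    return True\n"
        )
        # services/missing.py is deliberately absent, so its import should fail.
        return check_imports(root, ["models.user", "services.auth", "services.missing"])
```

A `False` entry points at the same problems listed under Troubleshooting below: a file missing from the project, or an import path that does not match the directory layout.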

## Troubleshooting

**Import errors:**
- Verify all files are included in the `include` parameter
- Check that file paths are correct (relative to the app definition file)
- Ensure `__init__.py` files are included for packages

**Module not found:**
- Add missing files to the `include` list
- Check that import paths match the file structure
- Verify that the image includes all necessary packages

**File not found at runtime:**
- Ensure all referenced files are included
- Check mount paths for file/directory inputs
- Verify file paths are relative to the app root directory
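Many of these failures come down to an `include` entry that does not match what is on disk. The pre-deploy check sketched below resolves each entry against the directory containing the app definition file, reports entries that do not exist, and flags package directories without an `__init__.py`. `find_include_problems` is a hypothetical helper that assumes `include` paths are relative to that directory; it does not reproduce how Flyte itself packages files.

```python
import tempfile
from pathlib import Path


def find_include_problems(app_dir: Path, include: list[str]) -> list[str]:
    """Return human-readable problems with `include` entries, relative to `app_dir`."""
    problems = []
    for entry in include:
        path = app_dir / entry
        if not path.exists():
            problems.append(f"{entry}: not found")
        elif path.is_dir() and not (path / "__init__.py").is_file():
            # Not fatal in Python 3 (namespace packages), but the best practices
            # above recommend regular packages with __init__.py files.
            problems.append(f"{entry}: directory has no __init__.py")
    return problems


def example() -> list[str]:
    """Check an include list like the complex example's against a partial project."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "app.py").write_text("")
        (root / "models").mkdir()
        (root / "models" / "__init__.py").write_text("")
        (root / "services").mkdir()  # package directory without __init__.py
        # utils/ is intentionally missing
        return find_include_problems(root, ["app.py", "models/", "services/", "utils/"])
```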

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/serving-graphs ===

# Serving graphs

A *serving graph* is a set of Flyte apps that talk to each other inside the
cluster. Instead of putting every stage of a request into one process, you
split the work across multiple `AppEnvironment`s that you deploy together —
each one sized for its own bottleneck, with its own image and scaling policy.

This pattern is useful for:

- **Heterogeneous resource requirements**: CPU pre/postprocessing in front of a GPU forward pass
- **Microservice architectures**: Independent components with distinct lifecycles
- **A/B testing and canary rollouts**: A root app routes traffic across variant apps
- **Proxy / gateway patterns**: One app fronts several backends

## Core concepts: a minimal two-app chain

The simplest serving graph — `app2` proxies HTTP calls to `app1` — is enough
to introduce every core concept: deploying multiple apps together, discovering
an upstream app's endpoint, and sizing each app independently.

Both apps share an image and live in the same Python file:

```
import logging
import os
import pathlib
import typing

import httpx
from fastapi import FastAPI

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment image}}
image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn", "httpx")
# {{/docs-fragment image}}

# {{docs-fragment apps}}
app1 = FastAPI(
    title="App 1",
    description="A FastAPI app that runs some computations",
)

app2 = FastAPI(
    title="App 2",
    description="A FastAPI app that proxies requests to another FastAPI app",
)
# {{/docs-fragment apps}}

# {{docs-fragment env-direct}}
env1 = FastAPIAppEnvironment(
    name="app1-is-called-by-app2",
    app=app1,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
)
# {{/docs-fragment env-direct}}

# {{docs-fragment env-with-parameter}}
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
    parameters=[
        flyte.app.Parameter(
            name="app1_url",
            value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
            env_var="APP1_URL",
        ),
    ],
    depends_on=[env1],
    env_vars={"LOG_LEVEL": "10"},
)
# {{/docs-fragment env-with-parameter}}

@app1.get("/greeting/{name}")
async def greeting(name: str) -> str:
    return f"Hello, {name}!"

# {{docs-fragment endpoint-property-pattern}}
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
    return env1.endpoint

@app2.get("/greeting/{name}")
async def greeting_proxy(name: str) -> typing.Any:
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{env1.endpoint}/greeting/{name}")
        return response.json()
# {{/docs-fragment endpoint-property-pattern}}

# {{docs-fragment endpoint-env-var-pattern}}
@app2.get("/app1-url")
async def get_app1_url() -> str:
    return os.getenv("APP1_URL")
# {{/docs-fragment endpoint-env-var-pattern}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.DEBUG,
    )
    app = flyte.serve(env2)
    print(f"Deployed FastAPI app: {app.url}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*
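Stripped of Flyte and FastAPI, the env-var pattern is simple: the caller reads the callee's base URL from an environment variable (`APP1_URL`, which the `flyte.app.Parameter(..., env_var="APP1_URL")` entry injects in the example) and forwards requests to it. The standard-library sketch below runs both roles as local servers on OS-assigned ports to show the data flow. It is not how Flyte wires apps together internally; the handler classes and the `example` helper are hypothetical, and the sketch sets `APP1_URL` itself where Flyte would inject it.

```python
import json
import os
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer

# Ignore any HTTP(S)_PROXY settings so requests to localhost stay local.
OPENER = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def send_json(handler: BaseHTTPRequestHandler, body: bytes) -> None:
    handler.send_response(200)
    handler.send_header("Content-Type", "application/json")
    handler.send_header("Content-Length", str(len(body)))
    handler.end_headers()
    handler.wfile.write(body)


class App1Handler(BaseHTTPRequestHandler):
    """Plays the role of app1: GET /greeting/{name} returns "Hello, {name}!"."""

    def do_GET(self):
        name = self.path.removeprefix("/greeting/")
        send_json(self, json.dumps(f"Hello, {name}!").encode())

    def log_message(self, format, *args):
        pass


class App2Handler(BaseHTTPRequestHandler):
    """Plays the role of app2: forwards the request path to the URL in APP1_URL."""

    def do_GET(self):
        upstream = os.environ["APP1_URL"]  # read at request time, like get_app1_url()
        with OPENER.open(f"{upstream}{self.path}") as resp:
            send_json(self, resp.read())

    def log_message(self, format, *args):
        pass


def start(handler) -> tuple[HTTPServer, str]:
    """Serve `handler` on a free local port in a background thread."""
    server = HTTPServer(("127.0.0.1", 0), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server, f"http://127.0.0.1:{server.server_address[1]}"


def example(name: str = "Flyte") -> str:
    app1, app1_url = start(App1Handler)
    os.environ["APP1_URL"] = app1_url  # in Flyte, the Parameter injects this value
    app2, app2_url = start(App2Handler)
    try:
        with OPENER.open(f"{app2_url}/greeting/{name}") as resp:
            return json.loads(resp.read())
    finally:
        for server in (app1, app2):
            server.shutdown()
            server.server_close()
```

Reading the URL from the environment keeps the caller unaware of where the callee runs, which is what lets each app be redeployed or resized independently.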

### Deploying multiple apps together with `depends_on`

The callee env is straightforward — it has no upstream dependencies of its
own:

```
env1 = FastAPIAppEnvironment(
    name="app1-is-called-by-app2",
    app=app1,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

The caller declares `depends_on=[env1]`, which tells Flyte that `env1` must
be deployed alongside this one:

```
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
    parameters=[
        flyte.app.Parameter(
            name="app1_url",
            value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
            env_var="APP1_URL",
        ),
    ],
    depends_on=[env1],
    env_vars={"LOG_LEVEL": "10"},
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*
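Conceptually, `depends_on` edges form a directed graph, and deploying from a single entry point means walking that graph transitively. The sketch below computes such a closure with a depth-first traversal that lists each app after everything it depends on and rejects cycles. It uses plain app names instead of `AppEnvironment` objects, `deployment_order` is a hypothetical helper, and Flyte's actual ordering and rollout behavior may differ.

```python
def deployment_order(entry: str, depends_on: dict[str, list[str]]) -> list[str]:
    """Return `entry` plus everything it transitively depends on, dependencies first."""
    order: list[str] = []
    done: set[str] = set()
    in_progress: set[str] = set()

    def visit(app: str) -> None:
        if app in done:
            return
        if app in in_progress:
            raise ValueError(f"dependency cycle involving {app!r}")
        in_progress.add(app)
        for dep in depends_on.get(app, []):
            visit(dep)
        in_progress.remove(app)
        done.add(app)
        order.append(app)  # appended only after all of its dependencies

    visit(entry)
    return order


if __name__ == "__main__":
    # The two-app chain from this page: app2 depends on app1.
    print(deployment_order("app2-calls-app1", {"app2-calls-app1": ["app1-is-called-by-app2"]}))
```

Because the walk starts from the entry point, apps that nothing reachable depends on are left out, which matches naming only the entry-point app when you deploy.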

Calling `flyte.serve(env2)` deploys the whole dependency closure transitively,
so you only need to name the entry-point app:

```
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.DEBUG,
    )
    app = flyte.serve(env2)
    print(f"Deployed FastAPI app: {app.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

`depends_on` governs deployment co-scheduling, not request-time ordering: at
runtime, each app runs independently.
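The phrase "dependency closure" can be made concrete with a small, Flyte-free sketch. `Env` and `deployment_closure` below are hypothetical stand-ins, not Flyte APIs. The sketch only shows how following `depends_on` transitively from one entry point collects every app that has to be deployed. Emitting dependencies before their dependents is a choice made in this sketch, not a documented Flyte ordering guarantee.

```python
from dataclasses import dataclass, field


@dataclass
class Env:
    """Hypothetical stand-in for an app environment; only name and depends_on matter here."""
    name: str
    depends_on: list["Env"] = field(default_factory=list)


def deployment_closure(entry: Env) -> list[str]:
    """Return the names of every env reachable from `entry`, dependencies first."""
    ordered: list[str] = []
    seen: set[str] = set()

    def visit(env: Env) -> None:
        # `seen` skips shared dependencies and stops infinite recursion on cycles.
        if env.name in seen:
            return
        seen.add(env.name)
        for dep in env.depends_on:
            visit(dep)
        ordered.append(env.name)  # post-order: dependencies are appended first

    visit(entry)
    return ordered


env1 = Env("app1-is-called-by-app2")
env2 = Env("app2-calls-app1", depends_on=[env1])
print(deployment_closure(env2))  # ['app1-is-called-by-app2', 'app2-calls-app1']
```

Naming only `env2` is enough to reach `env1`. The same walk handles deeper chains and shared dependencies, which is why only the entry point needs to be passed to `flyte.serve`.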

### Getting an upstream app's endpoint

There are two ways for one app to discover another app's URL. Both resolve
correctly across local, in-cluster, and external contexts.

**Pattern A — `env.endpoint` (Python property).** When both apps live in the
same Python module, the upstream env object is in scope and you can read
`env.endpoint` directly. The example above uses this pattern in `app2`'s
proxy endpoint:

```
@app2.get("/app1-endpoint")
async def get_app1_endpoint() -> str:
    return env1.endpoint

@app2.get("/greeting/{name}")
async def greeting_proxy(name: str) -> typing.Any:
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{env1.endpoint}/greeting/{name}")
        return response.json()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*
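One detail is worth hardening in the proxy. `greeting_proxy` interpolates `name` directly into the upstream path, so a value containing `?`, `#`, or a space changes the request that `app1` receives. The following minimal sketch uses the standard library's `urllib.parse.quote`. `build_greeting_url` is a hypothetical helper that is not part of the example or of Flyte, and the base URL is made up:

```python
from urllib.parse import quote


def build_greeting_url(base_url: str, name: str) -> str:
    """Join an upstream endpoint and a greeting name, percent-encoding the name.

    safe="" also encodes "/", so the name always stays a single path segment.
    """
    return f"{base_url.rstrip('/')}/greeting/{quote(name, safe='')}"


print(build_greeting_url("http://app1.internal:8080/", "Ada Lovelace?"))
# http://app1.internal:8080/greeting/Ada%20Lovelace%3F
```

Inside the proxy, you would pass `build_greeting_url(env1.endpoint, name)` to `client.get(...)` in place of the raw f-string.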

**Pattern B — `flyte.app.AppEndpoint` as a parameter.** When the upstream env
object isn't importable (it lives in a different file or process, or you only
know the app by name), declare a `flyte.app.Parameter` whose value is a
`flyte.app.AppEndpoint`. Flyte then injects the resolved URL into an environment
variable. The `env2` definition above does exactly this: `app1_url` is
available as `os.getenv("APP1_URL")` at runtime:

```
env2 = FastAPIAppEnvironment(
    name="app2-calls-app1",
    app=app2,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=True,
    parameters=[
        flyte.app.Parameter(
            name="app1_url",
            value=flyte.app.AppEndpoint(app_name="app1-is-called-by-app2"),
            env_var="APP1_URL",
        ),
    ],
    depends_on=[env1],
    env_vars={"LOG_LEVEL": "10"},
)

@app2.get("/app1-url")
async def get_app1_url() -> str:
    return os.getenv("APP1_URL")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/two_app_chain.py*

### Sizing each node independently

Each `AppEnvironment` carries its own image, resources, and scaling, which is
the whole point of splitting. For example, the GPU side of an inference graph
can stay narrow with `scaling=flyte.app.Scaling(replicas=(1, 2))` while the CPU
side scales wide with `scaling=flyte.app.Scaling(replicas=(1, 8))`, with no
shared autoscaling policy between them. The next example shows this in practice.

## Example: CPU / GPU inference split

This is the canonical heterogeneous-resource pipeline: heavy CPU preprocessing
and postprocessing around a fast GPU forward pass, with the two apps talking to
each other over HTTP inside the cluster.

```mermaid
flowchart LR
    client["client"] --> cpu["cpu_app (×N replicas)<br/>decode + resize<br/>+ softmax"]
    cpu --> gpu["gpu_app (×M replicas)<br/>ResNet18 forward only"]
    gpu --> cpu
    cpu --> client
```

In a typical vision or audio pipeline, the GPU forward pass takes milliseconds
but is sandwiched between slow CPU work (image decode, resize, normalization,
softmax, label lookup). If both stages share one process, you pay for an idle
GPU during preprocessing. Splitting them lets each side scale independently:
cheap CPU wide, expensive GPU narrow.
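A back-of-the-envelope model makes the idle-GPU argument concrete. Assume each request is handled serially, with `cpu_ms` of pre- and postprocessing and `gpu_ms` of forward pass, and ignore network and batching overhead. When the stages are co-located, the GPU is busy only `gpu_ms / (cpu_ms + gpu_ms)` of the time. When they are split, roughly `cpu_ms / gpu_ms` single-request CPU workers are needed to keep one GPU replica fed. The 45 ms / 5 ms timings below are made-up illustrative numbers, not measurements of this example:

```python
import math


def colocated_gpu_busy_fraction(cpu_ms: float, gpu_ms: float) -> float:
    """Fraction of wall time the GPU works when both stages run serially in one process."""
    return gpu_ms / (cpu_ms + gpu_ms)


def cpu_workers_per_gpu(cpu_ms: float, gpu_ms: float) -> int:
    """CPU workers (one request at a time each) needed so one GPU replica never waits."""
    return math.ceil(cpu_ms / gpu_ms)


cpu_ms, gpu_ms = 45.0, 5.0  # hypothetical per-request timings
print(colocated_gpu_busy_fraction(cpu_ms, gpu_ms))  # 0.1
print(cpu_workers_per_gpu(cpu_ms, gpu_ms))  # 9
```

A "worker" here is an abstract single-request unit, not a replica. A replica with several CPUs can serve multiple requests at once. The asymmetry is the same one the example expresses with `replicas=(1, 8)` on the CPU side and `replicas=(1, 2)` on the GPU side.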

### Disjoint images per node

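The two apps share a small base image and layer their own stacks on top:
`httpx` and `pillow` for the CPU app, `torch` and `torchvision` for the GPU app.
The CPU app never imports `torch`, because the GPU code imports it lazily inside
its lifespan and handler: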

```
"""
Serving graph — CPU pre/post split from a GPU forward pass.

This example shows the canonical "two-app" inference graph: heavy CPU work
on one app, the GPU forward pass on another, talking to each other over HTTP
inside the cluster.

Why split? In a typical vision/audio/feature-engineering pipeline the GPU
forward pass is fast (millis) but is sandwiched between slow CPU work
(image decode, resize, denoise, NMS, label lookup, etc.). If you put both
stages in one process you pay for an idle GPU during preprocessing. Splitting
them lets each side scale independently:

    client ──► [cpu_app  x N replicas]  ──► [gpu_app x M replicas] ──► back
                preprocess + postprocess        model.forward only
                cheap CPU, scale wide           expensive GPU, scale narrow

Wire format between the two apps is raw float32 bytes (not JSON) — for
anything tensor-shaped this is the single biggest perf knob.
"""

import io
import ipaddress
import logging
import pathlib
import socket
from contextlib import asynccontextmanager

import httpx
import numpy as np
from fastapi import FastAPI, HTTPException, Request, Response
from PIL import Image, ImageFilter
from pydantic import BaseModel

import flyte
import flyte.app
from flyte.app.extras import FastAPIAppEnvironment

# ---------------------------------------------------------------------------
# Images
# ---------------------------------------------------------------------------
# Shared base with the deps both apps need (HTTP server + numpy). The CPU and
# GPU images extend it with their own disjoint stacks — the CPU app never
# imports torch and the GPU app never imports PIL. Sharing the base layer
# means the registry only stores one copy of fastapi/uvicorn/numpy.

# {{docs-fragment images}}
base_image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "fastapi",
    "uvicorn",
    "numpy",
)

cpu_image = base_image.with_pip_packages(
    "httpx",
    "pillow",
)

gpu_image = base_image.with_pip_packages(
    "torch==2.7.1",
    "torchvision==0.22.1",
)
# {{/docs-fragment images}}

# ---------------------------------------------------------------------------
# Shared tensor layout
# ---------------------------------------------------------------------------

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32

# ===========================================================================
# GPU app — model.forward only
# ===========================================================================

# {{docs-fragment gpu-lifespan}}
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
# {{/docs-fragment gpu-lifespan}}

@gpu_app.get("/health")
async def gpu_health() -> dict:
    return {"status": "ok", "device": gpu_app.state.device}

@gpu_app.get("/labels")
async def labels() -> list[str]:
    # Exposed so the CPU side can fetch labels once at startup instead of
    # hard-coding the ImageNet class list.
    return gpu_app.state.categories

# {{docs-fragment gpu-infer}}
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

    raw = await request.body()
    arr = np.frombuffer(raw, dtype=TENSOR_DTYPE)
    if arr.size % (INPUT_C * INPUT_H * INPUT_W) != 0:
        raise HTTPException(400, "payload size is not a multiple of one image tensor")
    batch = arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
# {{/docs-fragment gpu-infer}}

# {{docs-fragment gpu-env}}
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
# {{/docs-fragment gpu-env}}

# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

# {{docs-fragment cpu-preprocess}}
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
# {{/docs-fragment cpu-preprocess}}

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# {{docs-fragment cpu-lifespan}}
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
# {{/docs-fragment cpu-lifespan}}

@cpu_app.get("/health")
async def cpu_health() -> dict:
    return {"status": "ok", "labels_loaded": len(cpu_app.state.labels)}

# {{docs-fragment cpu-classify}}
async def validate_public_image_url(image_url: str) -> str:
    # Best-effort SSRF guard: only allow http(s) URLs whose host resolves to
    # public addresses, so callers can't make the CPU app fetch internal services.
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        # Note: socket.getaddrinfo is a blocking call inside this async function.
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    # Reject the URL if any resolved address is private, loopback, link-local, etc.
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
# {{/docs-fragment cpu-classify}}

# {{docs-fragment cpu-env}}
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
# {{/docs-fragment cpu-env}}

# ===========================================================================
# Deploy
# ===========================================================================

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
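The module docstring above names the raw `float32` wire format as the biggest performance lever. The standalone NumPy sketch below illustrates that contract. It encodes a `(B, 3, 224, 224)` batch to bytes, decodes it with the same whole-image size check that `/infer` performs, and compares the payload size against a JSON encoding. `encode_batch` and `decode_batch` are hypothetical helpers. Pinning little-endian `<f4` is an assumption added here; the example uses native-order `np.float32`, which is identical on typical little-endian hosts.

```python
import json

import numpy as np

INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
WIRE_DTYPE = np.dtype("<f4")  # little-endian float32, pinned explicitly
FLOATS_PER_IMAGE = INPUT_C * INPUT_H * INPUT_W


def encode_batch(batch: np.ndarray) -> bytes:
    """Serialize a (B, 3, 224, 224) batch as raw C-contiguous float32 bytes."""
    if batch.ndim != 4 or batch.shape[1:] != (INPUT_C, INPUT_H, INPUT_W):
        raise ValueError(f"expected (B, {INPUT_C}, {INPUT_H}, {INPUT_W}), got {batch.shape}")
    return np.ascontiguousarray(batch, dtype=WIRE_DTYPE).tobytes()


def decode_batch(raw: bytes) -> np.ndarray:
    """Inverse of encode_batch; rejects payloads that are not a whole number of images."""
    if len(raw) % (FLOATS_PER_IMAGE * WIRE_DTYPE.itemsize) != 0:
        raise ValueError("payload size is not a multiple of one image tensor")
    return np.frombuffer(raw, dtype=WIRE_DTYPE).reshape(-1, INPUT_C, INPUT_H, INPUT_W)


rng = np.random.default_rng(0)
batch = rng.standard_normal((32, INPUT_C, INPUT_H, INPUT_W), dtype=np.float32)
raw = encode_batch(batch)
print(len(raw))  # 19267584 bytes, about 19 MB
assert np.array_equal(decode_batch(raw), batch)

# Even for a single image, JSON text is larger than the raw float32 bytes.
one_image_json = json.dumps(batch[:1].tolist())
print(len(one_image_json) > len(encode_batch(batch[:1])))  # True
```

`np.frombuffer` returns a read-only view over the received bytes, so decoding does not copy the payload.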

### GPU app: model.forward only

The GPU app loads the model once at startup using FastAPI's lifespan, so model
weights stay resident across requests:

```
@asynccontextmanager
async def _gpu_lifespan(app: FastAPI):
    # Imported lazily so the CPU app never has to import torch.
    import torch
    from torchvision.models import ResNet18_Weights, resnet18

    weights = ResNet18_Weights.IMAGENET1K_V1
    model = resnet18(weights=weights).eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        model = model.to("cuda")
    app.state.model = model
    app.state.device = device
    app.state.categories = list(weights.meta["categories"])
    logging.getLogger(__name__).info("model loaded on %s", device)
    yield

gpu_app = FastAPI(
    title="inference-gpu",
    description="ResNet18 forward pass.",
    lifespan=_gpu_lifespan,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*

The GPU app's `/infer` endpoint exchanges raw `float32` bytes over
`application/octet-stream` instead of JSON. For tensor-shaped data this is the
single biggest performance knob: a batch of 32 preprocessed images is about
19 MB, and JSON-serializing it would dominate end-to-end latency:

```
@gpu_app.post("/infer")
async def infer(request: Request) -> Response:
    """Run a batched forward pass.

    Request body:  raw float32 bytes, shape (B, 3, 224, 224), C-contiguous.
    Response body: raw float32 bytes, shape (B, 1000) — raw logits.

    We deliberately do NOT use JSON here. For a batch of 32 images the tensor
    is ~19MB; JSON-serializing that is the dominant cost end-to-end.
    """
    import torch

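    raw = await request.body()
    # Check the byte length before reinterpreting it: np.frombuffer raises a
    # ValueError (an opaque 500) when the length is not a multiple of 4 bytes.
    bytes_per_image = INPUT_C * INPUT_H * INPUT_W * np.dtype(TENSOR_DTYPE).itemsize
    if not raw or len(raw) % bytes_per_image != 0:
        raise HTTPException(400, "payload must be a non-empty multiple of one image tensor")
    batch = np.frombuffer(raw, dtype=TENSOR_DTYPE).reshape(-1, INPUT_C, INPUT_H, INPUT_W)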

    x = torch.from_numpy(batch).to(gpu_app.state.device)
    with torch.inference_mode():
        logits = gpu_app.state.model(x)
    out = logits.detach().to("cpu").numpy().astype(TENSOR_DTYPE, copy=False)
    return Response(content=out.tobytes(), media_type="application/octet-stream")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
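
To make the size difference concrete, the standalone sketch below encodes a small batch both ways and checks that `np.frombuffer` restores the raw bytes exactly. It is an illustration, not part of the example app: `encode` and `decode` are hypothetical helper names, and it uses 2 images instead of 32 so it runs quickly.

```python
import json

import numpy as np

# Tensor layout mirroring the constants in the example app.
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
TENSOR_DTYPE = np.float32


def encode(batch: np.ndarray) -> bytes:
    """Serialize a (B, C, H, W) batch as raw, C-contiguous float32 bytes."""
    return np.ascontiguousarray(batch, dtype=TENSOR_DTYPE).tobytes()


def decode(payload: bytes) -> np.ndarray:
    """Reinterpret raw bytes as float32 and restore the (B, C, H, W) shape."""
    # np.frombuffer itself raises ValueError if len(payload) is not a multiple of 4.
    arr = np.frombuffer(payload, dtype=TENSOR_DTYPE)
    per_image = INPUT_C * INPUT_H * INPUT_W
    if arr.size == 0 or arr.size % per_image != 0:
        raise ValueError("payload is not a whole number of image tensors")
    return arr.reshape(-1, INPUT_C, INPUT_H, INPUT_W)


rng = np.random.default_rng(0)
# Two images keep the demo quick; both sizes grow linearly with the batch size.
batch = rng.standard_normal((2, INPUT_C, INPUT_H, INPUT_W), dtype=TENSOR_DTYPE)

raw = encode(batch)
as_json = json.dumps(batch.tolist()).encode()

print(f"raw:  {len(raw):,} bytes")
print(f"JSON: {len(as_json):,} bytes")

# The raw round trip is lossless: same shape, same float32 values.
restored = decode(raw)
assert restored.shape == batch.shape
assert np.array_equal(restored, batch)
```

At the 32-image batch size mentioned in the `/infer` docstring, the raw body is 32 × 3 × 224 × 224 × 4 = 19,267,584 bytes (about 19 MB). The JSON text for the same values is several times larger, because each float becomes a decimal string of roughly 20 characters, and it still has to be parsed on the other side.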

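The GPU environment requests one A10G GPU and keeps the replica range narrow (one to two replicas). The minimum of one keeps a replica warm with the model weights loaded, and the low cap limits spend on GPU instances: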

```
gpu_env = FastAPIAppEnvironment(
    name="serving-graph-gpu",
    app=gpu_app,
    image=gpu_image,
    resources=flyte.Resources(cpu=2, memory="8Gi", gpu="A10G:1"),
    # GPU replicas are expensive; keep at least one warm so model weights stay
    # resident, and cap the max. Bump if a single replica saturates.
    scaling=flyte.app.Scaling(replicas=(1, 2)),
    requires_auth=True,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
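
The `/infer` endpoint already accepts a leading batch dimension, while `classify` sends one image per request. With the GPU side capped at two replicas, batching is one way to get more throughput from each warm replica before raising the cap. Below is a minimal sketch of just the packing and unpacking step. It assumes hypothetical helpers `pack_batch` and `unpack_logits`, which are not part of the example app. Collecting concurrent requests into a batch (for example, with a short queue and timeout) is not shown.

```python
import numpy as np

# Tensor layout mirroring the constants in the example app.
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
NUM_CLASSES = 1000
TENSOR_DTYPE = np.float32


def pack_batch(tensors: list[np.ndarray]) -> bytes:
    """Stack preprocessed (C, H, W) tensors into one (B, C, H, W) request body."""
    batch = np.stack(tensors).astype(TENSOR_DTYPE, copy=False)
    if batch.shape[1:] != (INPUT_C, INPUT_H, INPUT_W):
        raise ValueError(f"unexpected per-image shape {batch.shape[1:]}")
    return np.ascontiguousarray(batch).tobytes()


def unpack_logits(payload: bytes, batch_size: int) -> np.ndarray:
    """Split a raw (B, NUM_CLASSES) float32 response into one row per image."""
    logits = np.frombuffer(payload, dtype=TENSOR_DTYPE)
    expected = batch_size * NUM_CLASSES
    if logits.size != expected:
        raise ValueError(f"expected {expected} logits, got {logits.size}")
    return logits.reshape(batch_size, NUM_CLASSES)


# Usage with stand-in data: four blank images and a fake GPU reply.
tensors = [np.zeros((INPUT_C, INPUT_H, INPUT_W), dtype=TENSOR_DTYPE) for _ in range(4)]
body = pack_batch(tensors)  # would be POSTed to /infer as application/octet-stream
fake_reply = np.arange(4 * NUM_CLASSES, dtype=TENSOR_DTYPE).tobytes()
rows = unpack_logits(fake_reply, batch_size=len(tensors))
```

Checking the reply length before `reshape` turns a mismatched response into an error message that names the expected and actual sizes.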

### CPU app: pre/postprocess + call GPU

Preprocessing is deliberately CPU-bound: decode, denoise with a Gaussian blur, resize to 224 × 224, and normalize with the ImageNet mean and standard deviation:

```
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# {{docs-fragment cpu-lifespan}}
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
# {{/docs-fragment cpu-lifespan}}

@cpu_app.get("/health")
async def cpu_health() -> dict:
    return {"status": "ok", "labels_loaded": len(cpu_app.state.labels)}

# {{docs-fragment cpu-classify}}
async def validate_public_image_url(image_url: str) -> str:
    """Reject URLs that are not http(s) or that resolve to non-public IPs (basic SSRF guard)."""
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
# {{/docs-fragment cpu-classify}}

# {{docs-fragment cpu-env}}
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
# {{/docs-fragment cpu-env}}

# ===========================================================================
# Deploy
# ===========================================================================

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
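To make the raw-bytes wire format concrete, here is a stdlib-only sketch of the framing both apps agree on: a flat, C-ordered run of float32 values whose length must be a whole number of image tensors. It uses Python's `array` module in place of NumPy (an assumption for portability; the example itself uses `ndarray.tobytes` and `np.frombuffer`), and the helper names `encode_batch` and `decode_batch` are hypothetical. Unlike the example's `/infer` handler, this sketch also rejects an empty payload.

```python
from array import array

# Shared tensor layout, mirroring INPUT_C/INPUT_H/INPUT_W in the example.
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
FLOAT32_BYTES = array("f").itemsize  # size of a C float, 4 bytes on mainstream platforms
IMAGE_BYTES = INPUT_C * INPUT_H * INPUT_W * FLOAT32_BYTES


def encode_batch(values: list[float]) -> bytes:
    """Pack a flat, C-ordered list of floats into raw float32 bytes (hypothetical helper)."""
    return array("f", values).tobytes()


def decode_batch(payload: bytes) -> tuple[int, array]:
    """Return (batch_size, flat float32 values), rejecting malformed payloads."""
    # Check the byte length before decoding so a truncated body fails clearly.
    if len(payload) == 0 or len(payload) % IMAGE_BYTES != 0:
        raise ValueError("payload size is not a positive multiple of one image tensor")
    values = array("f")
    values.frombytes(payload)
    return len(payload) // IMAGE_BYTES, values


# Two fake images, every value 0.5.
payload = encode_batch([0.5] * (2 * INPUT_C * INPUT_H * INPUT_W))
batch_size, flat = decode_batch(payload)
print(batch_size, len(flat))
```

With the 3×224×224 layout, one image is 602,112 bytes, so a batch of 32 is 19,267,584 bytes, which is the ~19MB figure quoted in the `/infer` docstring.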

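The CPU app's lifespan resolves the GPU endpoint via `gpu_env.endpoint`,
fetches the label list once at startup, and creates one persistent
`httpx.AsyncClient` per replica. If the GPU app is not ready yet, the label
fetch fails and startup raises, so the replica crash-loops until the GPU app
comes up. Reusing one client avoids a new TCP/TLS handshake on every request,
which matters at high request rates: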

```
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
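The lifespan above nests two client scopes: a short-lived bootstrap client that is closed before the app starts serving, and a persistent client that lives until shutdown. The sketch below reproduces just that ordering with the standard library, using hypothetical `FakeClient` and `State` classes as stand-ins for `httpx.AsyncClient` and FastAPI's `app.state`; it does not start a server or make network calls.

```python
import asyncio
from contextlib import asynccontextmanager

events: list[str] = []


class FakeClient:
    """Hypothetical stand-in for httpx.AsyncClient that records open/close."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def __aenter__(self) -> "FakeClient":
        events.append(f"open {self.name}")
        return self

    async def __aexit__(self, *exc_info) -> None:
        events.append(f"close {self.name}")


class State:
    """Minimal stand-in for FastAPI's app.state namespace."""


@asynccontextmanager
async def lifespan(state: State):
    # Short-lived client: used once at startup, closed before serving begins.
    async with FakeClient("bootstrap"):
        state.labels = ["cat", "dog"]  # pretend this came from GET /labels
    # Long-lived client: shared by every request for the replica's lifetime.
    async with FakeClient("persistent") as client:
        state.client = client
        yield
    # Leaving the block above closed the persistent client at shutdown.


async def main() -> None:
    state = State()
    async with lifespan(state):
        events.append("serving")  # requests would run here, reusing state.client
    events.append("stopped")


asyncio.run(main())
print(events)
```

The printed order is `open bootstrap`, `close bootstrap`, `open persistent`, `serving`, `close persistent`, `stopped`: every request handled while serving shares one client, and FastAPI runs the code after `yield` at shutdown.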

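The `/classify` endpoint ties everything together. It first checks that
`image_url` uses http or https and resolves only to public addresses, then
downloads and preprocesses the image in this process. The GPU forward pass is
delegated over HTTP using the raw-bytes wire format, and softmax plus top-k
selection run back on the CPU: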

```
async def validate_public_image_url(image_url: str) -> str:
    """Reject URLs that are not http(s) or that resolve to non-public IPs (basic SSRF guard)."""
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
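The address check in `validate_public_image_url` can be sketched as two pieces: resolve the host, then reject any non-global address. This stdlib-only version (the function names are hypothetical) resolves with `loop.getaddrinfo`, which runs the lookup in the event loop's default executor, whereas the example calls the blocking `socket.getaddrinfo` directly inside an `async` function. Treat it as a basic guard rather than complete SSRF protection: the HTTP client resolves the name again when it connects, so a DNS server that changes its answer between the two lookups can still slip through.

```python
import asyncio
import ipaddress
import socket


def reject_non_public(addresses: list[str]) -> None:
    """Raise ValueError if any resolved address is not globally routable."""
    for text in addresses:
        if not ipaddress.ip_address(text).is_global:
            raise ValueError(f"{text} is not a public address")


async def resolve(host: str, port: int) -> list[str]:
    """Resolve host without blocking the event loop."""
    loop = asyncio.get_running_loop()
    infos = await loop.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    # Each entry is (family, type, proto, canonname, sockaddr); sockaddr[0] is the IP.
    return [info[4][0] for info in infos]


async def check_host(host: str, port: int = 443) -> list[str]:
    """Resolve host and fail if any of its addresses is loopback, private, link-local, etc."""
    addresses = await resolve(host, port)
    reject_non_public(addresses)
    return addresses
```

Checking every returned address, not just the first, matters because a host can publish both a public and a private record.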

The CPU environment scales wide (between 1 and 8 replicas) and declares
`depends_on=[gpu_env]` so both apps deploy together:

```
# ===========================================================================
# CPU app — pre/postprocess, calls the GPU app
# ===========================================================================

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=TENSOR_DTYPE).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=TENSOR_DTYPE).reshape(3, 1, 1)

class ClassifyRequest(BaseModel):
    image_url: str
    top_k: int = 5

class Prediction(BaseModel):
    label: str
    score: float

# {{docs-fragment cpu-preprocess}}
def _preprocess(img_bytes: bytes) -> np.ndarray:
    """Decode → denoise → resize → normalize. CPU-bound, deliberately so.

    Real preprocessing stacks (detection, OCR, audio) do substantially more
    than this — sliding window crops, color-space conversion, etc. The point
    is that none of it benefits from a GPU sitting next to it.
    """
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    img = img.filter(ImageFilter.GaussianBlur(radius=1.0))
    img = img.resize((INPUT_W, INPUT_H), Image.BILINEAR)
    arr = np.asarray(img, dtype=TENSOR_DTYPE) / 255.0
    arr = arr.transpose(2, 0, 1)  # HWC → CHW
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return np.ascontiguousarray(arr, dtype=TENSOR_DTYPE)
# {{/docs-fragment cpu-preprocess}}

def _softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# {{docs-fragment cpu-lifespan}}
@asynccontextmanager
async def _cpu_lifespan(app: FastAPI):
    # Resolved at serving time via the cluster-internal endpoint pattern,
    # so this stays correct across local/remote deploys without an env var.
    gpu_url = gpu_env.endpoint
    log = logging.getLogger(__name__)
    log.info("resolved GPU endpoint: %s", gpu_url)
    async with httpx.AsyncClient(timeout=30.0) as bootstrap:
        try:
            r = await bootstrap.get(f"{gpu_url}/labels")
            r.raise_for_status()
        except (httpx.HTTPError, OSError) as e:
            # Most common reason on a fresh deploy: GPU replica hasn't finished
            # pulling its image / loading weights yet. Crash-looping is fine —
            # the next attempt will likely succeed — but make the cause obvious.
            log.error("downstream GPU app at %s not ready: %s", gpu_url, e)
            raise
        app.state.labels = r.json()
    # One persistent client per replica — avoids TCP/TLS handshake per request,
    # which matters once you're doing 100s of req/s.
    async with httpx.AsyncClient(
        base_url=gpu_url,
        timeout=httpx.Timeout(30.0, connect=5.0),
        limits=httpx.Limits(max_connections=64, max_keepalive_connections=32),
    ) as client:
        app.state.client = client
        yield

cpu_app = FastAPI(
    title="inference-cpu",
    description="Pre/post around the GPU forward pass.",
    lifespan=_cpu_lifespan,
)
# {{/docs-fragment cpu-lifespan}}

@cpu_app.get("/health")
async def cpu_health() -> dict:
    return {"status": "ok", "labels_loaded": len(cpu_app.state.labels)}

# {{docs-fragment cpu-classify}}
async def validate_public_image_url(image_url: str) -> str:
    """Reject image URLs that aren't http(s) or whose host resolves to a non-public address."""
    try:
        parsed = httpx.URL(image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image_url.") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(status_code=400, detail="image_url must use http or https.")
    host = parsed.host
    if not host:
        raise HTTPException(status_code=400, detail="image_url must include a hostname.")
    try:
        addr_info = socket.getaddrinfo(host, parsed.port or (443 if parsed.scheme == "https" else 80))
    except socket.gaierror as exc:
        raise HTTPException(status_code=400, detail="image_url host could not be resolved.") from exc
    # Every resolved address must be globally routable (no loopback, private, or link-local).
    for info in addr_info:
        ip_text = info[4][0]
        ip_obj = ipaddress.ip_address(ip_text)
        if not ip_obj.is_global:
            raise HTTPException(status_code=400, detail="image_url host resolves to a non-public address.")
    return str(parsed)

@cpu_app.post("/classify", response_model=list[Prediction])
async def classify(req: ClassifyRequest) -> list[Prediction]:
    async with httpx.AsyncClient(timeout=30.0) as client:
        img_resp = await client.get(await validate_public_image_url(req.image_url))
        img_resp.raise_for_status()

    tensor = _preprocess(img_resp.content)  # heavy CPU
    batch = tensor[np.newaxis, ...]  # add batch dim

    gpu_resp = await cpu_app.state.client.post(
        "/infer",
        content=batch.tobytes(),
        headers={"content-type": "application/octet-stream"},
    )
    gpu_resp.raise_for_status()
    logits = np.frombuffer(gpu_resp.content, dtype=TENSOR_DTYPE).reshape(1, NUM_CLASSES)

    probs = _softmax(logits, axis=-1)[0]  # back to CPU work
    top_idx = np.argsort(-probs)[: req.top_k]
    return [Prediction(label=cpu_app.state.labels[i], score=float(probs[i])) for i in top_idx]
# {{/docs-fragment cpu-classify}}

# {{docs-fragment cpu-env}}
cpu_env = FastAPIAppEnvironment(
    name="serving-graph-cpu",
    app=cpu_app,
    image=cpu_image,
    resources=flyte.Resources(cpu=4, memory="4Gi"),
    # Cheap, so scale wide. Use scale-to-zero (replicas=(0, 8)) for bursty
    # traffic; keep replicas=(1, 8) here to avoid cold starts in the demo.
    scaling=flyte.app.Scaling(replicas=(1, 8)),
    requires_auth=True,
    depends_on=[gpu_env],
)
# {{/docs-fragment cpu-env}}

# ===========================================================================
# Deploy
# ===========================================================================

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
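The `/infer` contract (raw float32 bytes, C-contiguous, shape `(B, 3, 224, 224)`) is easy to break subtly with a wrong dtype, a non-contiguous array, or a missing batch dimension. The sketch below isolates the encode/decode step with plain NumPy so it can be unit-tested without running either app. `encode_batch`, `decode_batch`, and `IMAGE_BYTES` are hypothetical helper names, not part of the example file. The sketch also reproduces the size estimate in the `/infer` docstring for a 32-image batch.

```python
import numpy as np

# Tensor layout shared by both apps in the example above.
INPUT_C, INPUT_H, INPUT_W = 3, 224, 224
TENSOR_DTYPE = np.float32
IMAGE_BYTES = INPUT_C * INPUT_H * INPUT_W * np.dtype(TENSOR_DTYPE).itemsize


def encode_batch(batch: np.ndarray) -> bytes:
    """Serialize a (B, C, H, W) batch to raw C-contiguous float32 bytes."""
    if batch.ndim != 4 or batch.shape[1:] != (INPUT_C, INPUT_H, INPUT_W):
        raise ValueError(f"unexpected batch shape {batch.shape}")
    return np.ascontiguousarray(batch, dtype=TENSOR_DTYPE).tobytes()


def decode_batch(raw: bytes) -> np.ndarray:
    """Inverse of encode_batch; rejects payloads that aren't whole images."""
    if not raw or len(raw) % IMAGE_BYTES != 0:
        raise ValueError("payload is not a positive multiple of one image tensor")
    return np.frombuffer(raw, dtype=TENSOR_DTYPE).reshape(-1, INPUT_C, INPUT_H, INPUT_W)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch = rng.standard_normal((32, INPUT_C, INPUT_H, INPUT_W), dtype=TENSOR_DTYPE)
    raw = encode_batch(batch)
    print(len(raw))  # 19267584 (about 19 MB)
    # The round trip is lossless because no text encoding is involved.
    assert np.array_equal(decode_batch(raw), batch)
```

Checking the byte length before calling `np.frombuffer` is what lets the receiving side return a clean client error instead of failing inside NumPy.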

### Deploy

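Because `cpu_env` declares `depends_on=[gpu_env]`, `flyte.serve(cpu_env)` deploys both apps. The CPU app is the public entry
point; the GPU app is reached only through its cluster-internal endpoint, which the CPU app resolves at startup from `gpu_env.endpoint`: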

```
# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(
        root_dir=pathlib.Path(__file__).parent,
        log_level=logging.INFO,
    )
    app = flyte.serve(cpu_env)
    print(f"Deployed serving graph; public CPU endpoint: {app.url}")
    print("Try: curl -X POST $URL/classify -H 'content-type: application/json' \\")
    print(
        '       -d \'{"image_url": "https://upload.wikimedia.org/wikipedia/commons/4/41/Sunflower_from_Silesia2.jpg"}\''
    )
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/image_classification.py*
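Because the CPU app is the public entry point and fetches whatever `image_url` it receives, the address check in `validate_public_image_url` is what keeps it from being pointed at cluster-internal services. The rule can be exercised without DNS. In the minimal sketch below, `is_public_address` and `all_public` are hypothetical helpers that mirror the loop over `getaddrinfo` results; they are not part of the example file.

```python
import ipaddress


def is_public_address(ip_text: str) -> bool:
    """Mirror of the per-address check in validate_public_image_url.

    Only globally routable addresses pass, so loopback, RFC 1918 private
    ranges, link-local (including the 169.254.169.254 cloud metadata
    address), and unspecified addresses are rejected.
    """
    return ipaddress.ip_address(ip_text).is_global


def all_public(resolved: list[str]) -> bool:
    """Accept a host only if it resolved to at least one address and all are public."""
    return bool(resolved) and all(is_public_address(ip) for ip in resolved)


if __name__ == "__main__":
    for ip in ["8.8.8.8", "127.0.0.1", "10.0.0.1", "169.254.169.254", "::1"]:
        print(ip, is_public_address(ip))
```

One limitation to keep in mind: `validate_public_image_url` resolves the host itself, and httpx resolves it again when it fetches the image, so a hostile DNS server could answer the two lookups differently (DNS rebinding). Closing that gap requires pinning the fetch to the address that was checked.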

## Example: A/B testing with Statsig

A serving graph can also shape traffic. Here a root app routes each incoming
request to one of two variant apps using a [Statsig](https://www.statsig.com/)
feature gate. Bucketing is keyed on a caller-supplied `user_key`, so a given
user consistently lands on the same variant.

```mermaid
flowchart LR
    client["client"] --> root["root_app<br/>(check_gate)"]
    root -->|"gate off"| a["app_a<br/>fast-processing"]
    root -->|"gate on"| b["app_b<br/>enhanced-processing"]
```
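Consistent bucketing means that, while the gate's rollout settings stay the same, the same `user_key` always lands in the same variant. Statsig performs this evaluation inside its SDK, and its exact hashing scheme is not shown in this guide. The sketch below is a generic salted-hash illustration of the idea, not Statsig's algorithm; `bucket` and `gate_passes` are hypothetical names.

```python
import hashlib


def bucket(user_key: str, salt: str, buckets: int = 10_000) -> int:
    """Map a user key to a stable bucket in [0, buckets) using a salted SHA-256 hash."""
    digest = hashlib.sha256(f"{salt}:{user_key}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % buckets


def gate_passes(user_key: str, gate_name: str, rollout_pct: float) -> bool:
    """True if the user's bucket falls inside the rollout percentage (0 to 100).

    With 10,000 buckets, each bucket represents 0.01% of traffic.
    """
    return bucket(user_key, gate_name) < rollout_pct * 100


if __name__ == "__main__":
    # Same key, same gate: same answer on every call.
    print(gate_passes("user123", "variant_b", 50) == gate_passes("user123", "variant_b", 50))
```

Salting the hash with the gate name keeps assignments independent across gates, so a user who lands in the treatment group for one gate is not automatically in the treatment group for every other gate.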

### Statsig client singleton

The variant routing logic needs a single Statsig client per process. Wrap it
in a singleton so lifespan startup/shutdown is the only place that touches its
lifecycle:

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
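    api_key = os.getenv("STATSIG_API_KEY")
    if api_key is None:
        # Don't include os.environ in the message: it would leak every secret in the container.
        raise RuntimeError("STATSIG_API_KEY is not set; attach the statsig-api-key secret to this app.")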
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the selected variant; both apps expose the same /process/{message} route
    target_env = env_b if use_variant_b else env_a
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{target_env.endpoint}/process/{message}")
        # Surface upstream failures instead of forwarding an error body as a result
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Escape values before inserting them with innerHTML: the message and
                    // user key echo user input and would otherwise allow HTML injection.
                    const escapeHtml = (value) => String(value)
                        .replace(/&/g, '&amp;')
                        .replace(/</g, '&lt;')
                        .replace(/>/g, '&gt;')
                        .replace(/"/g, '&quot;')
                        .replace(/'/g, '&#39;');

                    // Display result
                    const variant = escapeHtml(data.ab_test_result.selected_variant);
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${escapeHtml(data.ab_test_result.user_key)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${escapeHtml(data.ab_test_result.gate_name)})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${escapeHtml(data.response.message)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${escapeHtml(data.response.algorithm)}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The root app requires a 'statsig-api-key' secret, exposed as STATSIG_API_KEY.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
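To exercise the deployed root app from a script instead of the browser or `curl`, a few lines of standard-library Python are enough. This is a minimal sketch, assuming the endpoint is reachable without extra credentials (as the `curl` hint in the deploy block assumes); if your app requires authentication, a plain request like this will be rejected. `ROOT_ENDPOINT`, `build_process_url`, and `run_ab_test` are hypothetical names, and the URL is a placeholder for the one reported after deployment.

```python
import json
import urllib.parse
import urllib.request

# Placeholder: substitute the root app URL reported after deployment.
ROOT_ENDPOINT = "https://<your-root-app-endpoint>"


def build_process_url(endpoint: str, message: str, user_key: str) -> str:
    """Build the /process/{message}?user_key=... URL that the demo page requests.

    quote(..., safe="") percent-encodes reserved characters such as spaces and
    "?", similar to encodeURIComponent in the page's JavaScript.
    """
    path = "/process/" + urllib.parse.quote(message, safe="")
    query = urllib.parse.urlencode({"user_key": user_key})
    return f"{endpoint.rstrip('/')}{path}?{query}"


def run_ab_test(endpoint: str, message: str, user_key: str, timeout: float = 10.0) -> dict:
    """GET the root app's /process endpoint and return the decoded JSON body.

    urlopen raises urllib.error.HTTPError for non-2xx responses.
    """
    url = build_process_url(endpoint, message, user_key)
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.load(resp)
```

For example, `run_ab_test(ROOT_ENDPOINT, "hello", "user123")["ab_test_result"]["selected_variant"]` evaluates to `"A"` or `"B"`, and repeating the call with the same key should keep returning the same variant.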

### Variant apps

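The two variants are independent FastAPI apps that expose the same
`/process/{message}` endpoint. Each returns a payload tagged with its identity
(`"variant": "A"` or `"B"`) and runs in its own `FastAPIAppEnvironment`, so each
variant is a separate app with its own resources and scaling: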

```python
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment variant-envs}}
# Each variant gets its own environment, so it is deployed and scaled on its own
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
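The root app leaves the bucketing decision to Statsig's `check_gate`. To see what consistent bucketing means, or to try the two variants locally without a Statsig key, a deterministic hash is enough. The sketch below is a simplified stand-in, not Statsig's actual bucketing algorithm; `in_rollout`, `pick_variant_endpoint`, and the `rollout_percent` parameter are hypothetical names.

```python
import hashlib


def in_rollout(user_key: str, gate_name: str, rollout_percent: float) -> bool:
    """Deterministically decide whether user_key falls inside a gate's rollout.

    Hashing the gate name together with the user key gives each gate its own
    stable split: a given key always lands in the same bucket for a given gate.
    """
    digest = hashlib.sha256(f"{gate_name}:{user_key}".encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") % 10_000  # bucket in 0..9999
    return bucket < rollout_percent * 100  # e.g. 50.0 -> buckets 0..4999 pass


def pick_variant_endpoint(
    user_key: str, endpoint_a: str, endpoint_b: str, rollout_percent: float = 50.0
) -> tuple[str, str]:
    """Return ("A" or "B", base URL), mirroring the root app's 'variant_b' gate check."""
    if in_rollout(user_key, "variant_b", rollout_percent):
        return "B", endpoint_b
    return "A", endpoint_a
```

In the root app, `env_a.endpoint` and `env_b.endpoint` would take the place of the two URL arguments. Hashing the gate name together with the key, rather than the key alone, keeps a user's assignment in one gate independent of their assignment in another.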

```python
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY", None)
    if api_key is None:
        raise RuntimeError("STATSIG_API_KEY is not set; attach the 'statsig-api-key' secret to this app.")
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the variant chosen by the A/B test; raise on a non-2xx reply
    # instead of passing an error body through as if it were a result
    base_url = env_b.endpoint if use_variant_b else env_a.endpoint
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{base_url}/process/{message}")
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Escape values before inserting them as HTML: the message and
                    // user key come from form input and are echoed back by the server
                    const esc = (value) => String(value)
                        .replace(/&/g, '&amp;')
                        .replace(/</g, '&lt;')
                        .replace(/>/g, '&gt;')
                        .replace(/"/g, '&quot;');

                    // Display result
                    const variant = esc(data.ab_test_result.selected_variant);
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${esc(data.ab_test_result.user_key)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${esc(data.ab_test_result.gate_name)})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${esc(data.response.message)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${esc(data.response.algorithm)}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The root app reads STATSIG_API_KEY from the 'statsig-api-key' secret")
    print("      and fails to start if it is missing.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
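To call the root app from Python instead of the browser or curl, the request URL needs the same encoding the demo page applies with `encodeURIComponent`. The following is a minimal standard-library sketch; `build_process_url` is a hypothetical helper (not part of Flyte), and the endpoint passed in stands for the root app's deployed URL:

```python
from urllib.parse import quote


def build_process_url(endpoint: str, message: str, user_key: str) -> str:
    """Build the root app's /process URL (hypothetical helper, not a Flyte API)."""
    # quote(..., safe="") also encodes "/", so the message stays a single path
    # segment, similar to encodeURIComponent in the demo page
    path = f"/process/{quote(message, safe='')}"
    query = f"user_key={quote(user_key, safe='')}"
    return f"{endpoint.rstrip('/')}{path}?{query}"


if __name__ == "__main__":
    print(build_process_url("https://example.com/", "hello world", "user123"))
```

The resulting URL can be passed to any HTTP client, such as `httpx.get`, or pasted into the curl command printed by the deploy script.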

### Root app with Statsig in its lifespan

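The root app's FastAPI `lifespan` handler initializes the Statsig client once at startup and shuts it down when the app stops. The API key is read from the `STATSIG_API_KEY` environment variable, which Flyte populates from the `statsig-api-key` secret attached to the root app environment (see below):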

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY")
    if api_key is None:
        # Don't dump os.environ here: it would write every secret into the logs
        raise RuntimeError("STATSIG_API_KEY is not set; check that the 'statsig-api-key' secret is configured.")
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
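You can check the lifespan's startup and shutdown ordering locally without a Statsig account. The sketch below replaces Statsig with a hypothetical `FakeGateClient` that is not part of Statsig or Flyte. Its `check_gate` takes a plain user ID string rather than a `StatsigUser`. It assigns users deterministically by hashing the gate name with the user ID, so the same key always lands in the same bucket. Statsig's real bucketing algorithm is different. The lifespan itself is a plain `asynccontextmanager`, the same kind of object FastAPI's `lifespan=` parameter accepts:

```python
import asyncio
import hashlib
from contextlib import asynccontextmanager


class FakeGateClient:
    """Hypothetical local stand-in for the Statsig client (for testing only)."""

    def __init__(self, rollout_percent: int = 50):
        self.rollout_percent = rollout_percent
        self.running = False

    def initialize(self) -> None:
        self.running = True

    def shutdown(self) -> None:
        self.running = False

    def check_gate(self, user_id: str, gate_name: str) -> bool:
        if not self.running:
            raise RuntimeError("client not initialized")
        # Hash gate name + user ID into a stable bucket in [0, 100)
        digest = hashlib.sha256(f"{gate_name}:{user_id}".encode()).digest()
        bucket = int.from_bytes(digest[:8], "big") % 100
        return bucket < self.rollout_percent


client = FakeGateClient()


@asynccontextmanager
async def lifespan(_app):
    # Startup: runs once before the app serves requests
    client.initialize()
    yield
    # Shutdown: runs once after the app stops serving
    client.shutdown()


async def main() -> list[bool]:
    async with lifespan(None):
        # Checking the same user key twice gives the same answer
        first = client.check_gate("user123", "variant_b")
        second = client.check_gate("user123", "variant_b")
    # After the context exits, the shutdown step has run
    return [first, second, client.running]


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Keeping the client behind module-level state, like the `StatsigClient` singleton above, is what lets request handlers reach it after the lifespan has initialized it.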

### App environments

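The variant environments are minimal: each one needs only its FastAPI app, an image, and resources. The root environment also lists both variants in `depends_on` and exposes the Statsig secret as the `STATSIG_API_KEY` environment variable: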

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY")
    if api_key is None:
        # Don't dump os.environ here: it would write every secret into the logs
        raise RuntimeError("STATSIG_API_KEY is not set; check that the 'statsig-api-key' secret is configured.")
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the appropriate app based on A/B test result
    async with httpx.AsyncClient() as client:
        target = env_b if use_variant_b else env_a
        response = await client.get(f"{target.endpoint}/process/{message}")
        # Fail loudly if the variant app returns an error instead of forwarding its body
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Display result. Escape data-derived values before inserting them as HTML,
                    // because the message and user key are echoed back from user input.
                    const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => ({
                        '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;',
                    })[ch]);
                    const variant = data.ab_test_result.selected_variant;
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${escapeHtml(data.ab_test_result.user_key)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${escapeHtml(data.ab_test_result.gate_name)})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${escapeHtml(data.response.message)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${escapeHtml(data.response.algorithm)}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: Create the 'statsig-api-key' secret first; the root app fails to start without it.")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
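The deploy script suggests calling the root app with curl. The same request from Python is sketched below, assuming only the route shape shown above (`/process/{message}?user_key=...`) and the JSON keys the routing endpoint returns. The helper names `build_process_url`, `summarize`, and `fetch_ab_result`, and the `https://example.com` endpoint, are hypothetical placeholders. Percent-encoding the path segment and query parameter plays the same role as `encodeURIComponent` in the demo page.

```python
import json
from urllib.parse import quote, urlencode
from urllib.request import urlopen


def build_process_url(endpoint: str, message: str, user_key: str) -> str:
    """Build the URL for the root app's /process/{message} route.

    `endpoint` is the deployed root app's base URL (a placeholder in this sketch).
    The message is percent-encoded as a single path segment and user_key is
    encoded as a query parameter.
    """
    path_segment = quote(message, safe="")
    query = urlencode({"user_key": user_key})
    return f"{endpoint.rstrip('/')}/process/{path_segment}?{query}"


def summarize(payload: dict) -> str:
    """Condense the routing endpoint's JSON response into one line."""
    ab = payload["ab_test_result"]
    resp = payload["response"]
    return f"variant {ab['selected_variant']} for {ab['user_key']}: {resp['message']}"


def fetch_ab_result(url: str, timeout: float = 10.0) -> dict:
    """GET the URL and decode the JSON body (needs network access to the deployed app)."""
    with urlopen(url, timeout=timeout) as resp:
        return json.load(resp)
```

Against a deployed app, `summarize(fetch_ab_result(build_process_url(endpoint, "hello", "user123")))` would produce a line such as `variant A for user123: App A processed: hello`; which variant you get depends on your `variant_b` gate configuration. If your app requires authentication, you will need to add the appropriate credentials to the request.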

The root environment declares `depends_on=[env_a, env_b]`, so deploying it also deploys both variant apps. It also injects the Statsig API key from the `statsig-api-key` Flyte secret as the `STATSIG_API_KEY` environment variable:

```
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*
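On the consuming side, the secret is just an environment variable. A minimal sketch of reading it defensively follows, assuming only what the example's `lifespan` already relies on: that `flyte.Secret(..., as_env_var="STATSIG_API_KEY")` exposes the value under that name. The helper `read_required_secret` is hypothetical. Its error message names the missing variable but never echoes the environment, because the environment can hold other injected secrets that would then end up in logs.

```python
import os


def read_required_secret(env_var: str = "STATSIG_API_KEY") -> str:
    """Return the secret injected as `env_var`, failing fast with a safe message.

    Hypothetical helper: treats unset, empty, or whitespace-only values as missing,
    and reports only the variable *name*, never its value or other env vars.
    """
    value = os.environ.get(env_var, "").strip()
    if not value:
        raise RuntimeError(
            f"{env_var} is not set. Check that the app environment declares the "
            f"secret with as_env_var={env_var!r}."
        )
    return value
```

In the example's `lifespan`, this would replace the manual check: `StatsigClient.initialize(read_required_secret())`.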

### Routing endpoint

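The root app checks the `variant_b` Statsig feature gate for the caller's `user_key`. If the gate passes, it forwards the request to App B; otherwise it forwards to App A, building the URL from the chosen environment's `endpoint` property: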

```
import os
import typing
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI

import flyte
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment statsig-client}}
class StatsigClient:
    """Singleton to manage Statsig client lifecycle."""

    _instance: "StatsigClient | None" = None
    _statsig = None

    @classmethod
    def initialize(cls, api_key: str):
        """Initialize Statsig client (call during lifespan startup)."""
        if cls._instance is None:
            cls._instance = cls()

        # Import statsig at runtime (only available in container)
        from statsig_python_core import Statsig

        cls._statsig = Statsig(api_key)
        cls._statsig.initialize().wait()

    @classmethod
    def get_client(cls):
        """Get the initialized Statsig instance."""
        if cls._statsig is None:
            raise RuntimeError("StatsigClient not initialized. Call initialize() first.")
        return cls._statsig

    @classmethod
    def shutdown(cls):
        """Shutdown Statsig client (call during lifespan shutdown)."""
        if cls._statsig is not None:
            cls._statsig.shutdown()
            cls._statsig = None
            cls._instance = None
# {{/docs-fragment statsig-client}}

# {{docs-fragment variant-apps}}
# Image with statsig-python-core for A/B testing
image = flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn", "httpx", "statsig-python-core")

# App A - First variant
app_a = FastAPI(
    title="App A",
    description="Variant A for A/B testing",
)

# App B - Second variant
app_b = FastAPI(
    title="App B",
    description="Variant B for A/B testing",
)
# {{/docs-fragment variant-apps}}

# {{docs-fragment root-lifespan}}
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Initialize and shutdown Statsig for A/B testing."""
    # Startup: Initialize Statsig using singleton
    api_key = os.getenv("STATSIG_API_KEY", None)
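    if not api_key:
        # Don't interpolate os.environ here: it would also print any other injected secrets
        raise RuntimeError("STATSIG_API_KEY is not set; attach the 'statsig-api-key' secret to the root app.")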
    StatsigClient.initialize(api_key)

    yield

    # Shutdown: Cleanup Statsig
    StatsigClient.shutdown()

# Root App - Performs A/B testing and routes to A or B
root_app = FastAPI(
    title="Root App - A/B Testing",
    description="Routes requests to App A or App B based on Statsig A/B test",
    lifespan=lifespan,
)
# {{/docs-fragment root-lifespan}}

# {{docs-fragment variant-envs}}
env_a = FastAPIAppEnvironment(
    name="app-a-variant",
    app=app_a,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

env_b = FastAPIAppEnvironment(
    name="app-b-variant",
    app=app_b,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)
# {{/docs-fragment variant-envs}}

# {{docs-fragment root-env}}
env_root = FastAPIAppEnvironment(
    name="root-ab-testing-app",
    app=root_app,
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[env_a, env_b],
    secrets=flyte.Secret("statsig-api-key", as_env_var="STATSIG_API_KEY"),
)
# {{/docs-fragment root-env}}

# {{docs-fragment variant-endpoints}}
# App A endpoints
@app_a.get("/process/{message}")
async def process_a(message: str) -> dict[str, str]:
    return {
        "variant": "A",
        "message": f"App A processed: {message}",
        "algorithm": "fast-processing",
    }

# App B endpoints
@app_b.get("/process/{message}")
async def process_b(message: str) -> dict[str, str]:
    return {
        "variant": "B",
        "message": f"App B processed: {message}",
        "algorithm": "enhanced-processing",
    }
# {{/docs-fragment variant-endpoints}}

# {{docs-fragment routing-endpoint}}
# Root app A/B testing endpoint
@root_app.get("/process/{message}")
async def process_with_ab_test(message: str, user_key: str) -> dict[str, typing.Any]:
    """
    Process a message using A/B testing to determine which app to call.

    Args:
        message: The message to process
        user_key: User identifier for A/B test bucketing (e.g., user_id, session_id)

    Returns:
        Response from either App A or App B, plus metadata about which variant was used
    """
    # Import StatsigUser at runtime (only available in container)
    from statsig_python_core import StatsigUser

    # Get statsig client from singleton
    statsig = StatsigClient.get_client()

    # Create Statsig user with the provided key
    user = StatsigUser(user_id=user_key)

    # Check the feature gate "variant_b" to determine which variant
    # If gate is enabled, use App B; otherwise use App A
    use_variant_b = statsig.check_gate(user, "variant_b")

    # Call the variant selected by the A/B test, using that environment's endpoint
    target_env = env_b if use_variant_b else env_a
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{target_env.endpoint}/process/{message}")
        # Fail loudly on upstream errors instead of returning an error body as the result
        response.raise_for_status()
        result = response.json()

    # Add A/B test metadata to response
    return {
        "ab_test_result": {
            "user_key": user_key,
            "selected_variant": "B" if use_variant_b else "A",
            "gate_name": "variant_b",
        },
        "response": result,
    }
# {{/docs-fragment routing-endpoint}}

@root_app.get("/endpoints")
async def get_endpoints() -> dict[str, str]:
    """Get the endpoints for App A and App B."""
    return {
        "app_a_endpoint": env_a.endpoint,
        "app_b_endpoint": env_b.endpoint,
    }

@root_app.get("/")
async def index():
    """Serve the A/B testing demo HTML page."""
    from fastapi.responses import HTMLResponse

    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>A/B Testing Demo - Statsig</title>
        <style>
            * {
                margin: 0;
                padding: 0;
                box-sizing: border-box;
            }

            body {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell,
                sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
                display: flex;
                justify-content: center;
                align-items: center;
                padding: 20px;
            }

            .container {
                background: white;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
                padding: 40px;
                max-width: 600px;
                width: 100%;
            }

            h1 {
                color: #333;
                margin-bottom: 10px;
                font-size: 28px;
            }

            .subtitle {
                color: #666;
                margin-bottom: 30px;
                font-size: 14px;
            }

            .form-group {
                margin-bottom: 20px;
            }

            label {
                display: block;
                margin-bottom: 8px;
                color: #555;
                font-weight: 500;
                font-size: 14px;
            }

            input {
                width: 100%;
                padding: 12px 16px;
                border: 2px solid #e0e0e0;
                border-radius: 8px;
                font-size: 14px;
                transition: border-color 0.3s;
            }

            input:focus {
                outline: none;
                border-color: #667eea;
            }

            button {
                width: 100%;
                padding: 14px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: 600;
                cursor: pointer;
                transition: transform 0.2s, box-shadow 0.2s;
            }

            button:hover {
                transform: translateY(-2px);
                box-shadow: 0 10px 20px rgba(102, 126, 234, 0.4);
            }

            button:active {
                transform: translateY(0);
            }

            button:disabled {
                opacity: 0.6;
                cursor: not-allowed;
            }

            .result {
                margin-top: 30px;
                padding: 20px;
                border-radius: 12px;
                display: none;
            }

            .result.show {
                display: block;
            }

            .result.variant-a {
                background: #e3f2fd;
                border: 2px solid #2196f3;
            }

            .result.variant-b {
                background: #f3e5f5;
                border: 2px solid #9c27b0;
            }

            .result-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 15px;
                display: flex;
                align-items: center;
                gap: 10px;
            }

            .variant-badge {
                display: inline-block;
                padding: 4px 12px;
                border-radius: 12px;
                font-size: 12px;
                font-weight: 700;
            }

            .variant-a .variant-badge {
                background: #2196f3;
                color: white;
            }

            .variant-b .variant-badge {
                background: #9c27b0;
                color: white;
            }

            .result-content {
                margin-top: 10px;
            }

            .result-item {
                margin-bottom: 10px;
                padding: 10px;
                background: rgba(255, 255, 255, 0.8);
                border-radius: 6px;
            }

            .result-label {
                font-weight: 600;
                color: #555;
                font-size: 13px;
            }

            .result-value {
                color: #333;
                margin-top: 4px;
            }

            .error {
                background: #ffebee;
                border: 2px solid #f44336;
                color: #c62828;
                padding: 16px;
                border-radius: 8px;
                margin-top: 20px;
                display: none;
            }

            .error.show {
                display: block;
            }

            .info {
                background: #fff3e0;
                border-left: 4px solid #ff9800;
                padding: 12px 16px;
                margin-top: 20px;
                border-radius: 4px;
                font-size: 13px;
                color: #e65100;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎯 A/B Testing Demo</h1>
            <p class="subtitle">Test Statsig-powered variant selection</p>

            <form id="abTestForm">
                <div class="form-group">
                    <label for="message">Message to Process</label>
                    <input
                        type="text"
                        id="message"
                        name="message"
                        placeholder="e.g., hello, world, test"
                        required
                        value="hello"
                    >
                </div>

                <div class="form-group">
                    <label for="userKey">User Key (for A/B bucketing)</label>
                    <input
                        type="text"
                        id="userKey"
                        name="userKey"
                        placeholder="e.g., user123, session456"
                        required
                        value="user123"
                    >
                </div>

                <button type="submit" id="submitBtn">Run A/B Test</button>
            </form>

            <div id="result" class="result"></div>
            <div id="error" class="error"></div>

            <div class="info">
                💡 <strong>Tip:</strong> Try different user keys to see how Statsig routes to different variants.
                The same user key will always get the same variant (consistent bucketing).
            </div>
        </div>

        <script>
            const form = document.getElementById('abTestForm');
            const resultDiv = document.getElementById('result');
            const errorDiv = document.getElementById('error');
            const submitBtn = document.getElementById('submitBtn');

            form.addEventListener('submit', async (e) => {
                e.preventDefault();

                const message = document.getElementById('message').value;
                const userKey = document.getElementById('userKey').value;

                // Reset previous results
                resultDiv.classList.remove('show', 'variant-a', 'variant-b');
                errorDiv.classList.remove('show');
                submitBtn.disabled = true;
                submitBtn.textContent = 'Processing...';

                try {
                    const response =
                        await fetch(`/process/${encodeURIComponent(message)}?user_key=${encodeURIComponent(userKey)}`);

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const data = await response.json();

                    // Escape server-provided values before interpolating them into HTML
                    const escapeHtml = (value) => String(value).replace(/[&<>"']/g, (ch) => ({
                        '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;',
                    })[ch]);

                    // Display result
                    const variant = escapeHtml(data.ab_test_result.selected_variant);
                    const variantClass = `variant-${variant.toLowerCase()}`;

                    resultDiv.className = `result show ${variantClass}`;
                    resultDiv.innerHTML = `
                        <div class="result-header">
                            <span>A/B Test Result</span>
                            <span class="variant-badge">Variant ${variant}</span>
                        </div>
                        <div class="result-content">
                            <div class="result-item">
                                <div class="result-label">User Key</div>
                                <div class="result-value">${escapeHtml(data.ab_test_result.user_key)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Selected Variant</div>
                                <div class="result-value">Variant ${variant}
                                    (Gate: ${escapeHtml(data.ab_test_result.gate_name)})</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Response from App ${variant}</div>
                                <div class="result-value">${escapeHtml(data.response.message)}</div>
                            </div>
                            <div class="result-item">
                                <div class="result-label">Algorithm</div>
                                <div class="result-value">${escapeHtml(data.response.algorithm)}</div>
                            </div>
                        </div>
                    `;

                } catch (error) {
                    errorDiv.textContent = `Error: ${error.message}`;
                    errorDiv.classList.add('show');
                } finally {
                    submitBtn.disabled = false;
                    submitBtn.textContent = 'Run A/B Test';
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The root app requires the 'statsig-api-key' secret (exposed as STATSIG_API_KEY).")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*

Use stable identifiers (user ID, session ID) for `user_key` so the same user
always lands in the same bucket. To swap `check_gate` for an experiment or
dynamic config:

```python
experiment = statsig.get_experiment(user, "my_experiment")
variant = experiment.get("variant", "A")
```
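Statsig performs the bucketing for you; the sketch below is not Statsig's hashing scheme, only an illustration of why a stable `user_key` gives a stable assignment: the variant is a pure function of a hash of the key, so the same key always maps to the same bucket. The function name `assign_variant` and the salting with the experiment name are assumptions made for illustration.

```python
import hashlib


def assign_variant(user_key: str, experiment: str, percent_b: float = 50.0) -> str:
    """Deterministically map a user key to variant "A" or "B" (illustrative only)."""
    # Salting with the experiment name keeps assignments independent across experiments.
    digest = hashlib.sha256(f"{experiment}:{user_key}".encode("utf-8")).digest()
    # Map the first 8 bytes of the hash to a bucket in [0, 10000).
    bucket = int.from_bytes(digest[:8], "big") % 10_000
    # Buckets below the rollout threshold get variant B.
    return "B" if bucket < percent_b * 100 else "A"
```

Because no state is stored, any replica of the root app computes the same answer for the same key, which is the property the A/B example relies on.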

### Deploy

Deploying `env_root` also deploys `env_a` and `env_b`, because the root environment lists them in `depends_on`:

```python
if __name__ == "__main__":
    flyte.init_from_config()
    flyte.deploy(env_root)
    print("Deployed A/B Testing Root App")
    print("\nUsage:")
    print("  Open your browser to '<endpoint>/' to access the interactive demo")
    print("  Or use curl: curl '<endpoint>/process/hello?user_key=user123'")
    print("\nNote: The root app requires the 'statsig-api-key' secret (exposed as STATSIG_API_KEY).")
    print("      Create a feature gate named 'variant_b' in your Statsig dashboard.")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/serving_graphs/ab_testing.py*

**Setup before running:**

1. Get a Server Secret Key at [statsig.com](https://www.statsig.com/) → Settings → API Keys.
2. Create a feature gate named `variant_b` (e.g. 50% rollout).
3. Set the Flyte secret:
   ```bash
   flyte create secret statsig-api-key <your-secret-key-here>
   ```
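The example's `lifespan` already refuses to start without `STATSIG_API_KEY`, but its error message interpolates the whole of `os.environ`, so a missing secret would write every other environment variable, including other injected secrets, to the logs. A minimal sketch of a stricter check; the helper name `require_env` is hypothetical and not part of Flyte:

```python
import os


def require_env(name: str) -> str:
    """Return a required environment variable or fail fast.

    The error names only the missing variable; it never echoes os.environ,
    which would copy other secrets into the app's logs.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"{name} is not set. Create the Flyte secret and redeploy the app.")
    return value
```

In the lifespan, `StatsigClient.initialize(require_env("STATSIG_API_KEY"))` would replace the `os.getenv` call and the manual `None` check.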

## When to split into a serving graph

Split when stages have:

- **Different bottlenecks** — CPU vs GPU vs memory
- **Different scaling needs** — bursty vs steady, wide vs narrow
- **Different lifecycles** — model weights you don't want to reload, expensive cold starts
- **Different routing concerns** — A/B, canary, proxy, gateway

Don't split just to separate code — a single app with a few endpoints is
simpler to operate.

## Best practices

1. **Use `depends_on`**: Always specify dependencies to ensure the dependency closure is deployed in one shot.
2. **Persistent HTTP clients**: Open one `httpx.AsyncClient` per replica in the app's lifespan rather than per request, to avoid TCP/TLS setup overhead.
3. **Pick the right wire format**: For tensor-shaped payloads, send raw bytes over `application/octet-stream` instead of JSON.
4. **Size each node independently**: Give each app only what its stage needs, for example a few GPU replicas for inference and more CPU replicas for pre- and post-processing; use scale-to-zero (`replicas=(0, N)`) for bursty downstream services.
5. **Authentication**: Set `requires_auth=True` on internal-only apps so that only authenticated callers can reach them, and put public-facing auth on the entry-point app.
6. **Endpoint access**: Prefer `app_env.endpoint` for in-module references; use `flyte.app.AppEndpoint` parameters when the upstream env isn't importable.
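Best practice 3 (raw bytes instead of JSON for tensors) needs sender and receiver to agree on a layout. A minimal stdlib-only sketch, assuming a made-up layout of a little-endian `uint32` rank, one `uint32` per dimension, then float32 data; this framing is an illustration, not a Flyte or Union wire format:

```python
import math
import struct
import sys
from array import array


def encode_tensor(values: list[float], shape: tuple[int, ...]) -> bytes:
    """Pack a float32 tensor as: ndim (uint32), dims (uint32 each), raw data."""
    if math.prod(shape) != len(values):
        raise ValueError(f"shape {shape} does not match {len(values)} values")
    data = array("f", values)  # "f" is a C float: 4 bytes on mainstream platforms
    if sys.byteorder == "big":
        data.byteswap()  # always send little-endian on the wire
    header = struct.pack(f"<I{len(shape)}I", len(shape), *shape)
    return header + data.tobytes()


def decode_tensor(payload: bytes) -> tuple[list[float], tuple[int, ...]]:
    """Inverse of encode_tensor."""
    (ndim,) = struct.unpack_from("<I", payload, 0)
    shape = struct.unpack_from(f"<{ndim}I", payload, 4)
    data = array("f")
    data.frombytes(payload[4 + 4 * ndim:])
    if sys.byteorder == "big":
        data.byteswap()
    if len(data) != math.prod(shape):
        raise ValueError("payload length does not match the encoded shape")
    return data.tolist(), tuple(shape)
```

To send the payload, pass it as `content=` with `headers={"Content-Type": "application/octet-stream"}` to `httpx.AsyncClient.post`; in a FastAPI handler, read it with `await request.body()`. With NumPy, `ndarray.tobytes()` and `np.frombuffer(...).reshape(shape)` take the place of the `array` calls.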

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/hybrid-graphs ===

# Hybrid app-task graphs

Apps and tasks can interact with each other: tasks can call apps via HTTP, and apps can trigger task execution via the Flyte SDK. This page covers both patterns.

## Call app from task

Tasks can call apps by making HTTP requests to the app's endpoint. This is useful when:

- You need to use a long-running service during task execution
- You want to call a model serving endpoint from a batch processing task
- You need to interact with an API from a workflow

### Example: FastAPI app called from a task 

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
#    "httpx",
# ]
# ///

"""Example of a task calling an app."""

import pathlib
import httpx
from fastapi import FastAPI
import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI(title="Add One", description="Adds one to the input", version="1.0.0")

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("fastapi", "uvicorn", "httpx")

# {{docs-fragment app-definition}}
app_env = FastAPIAppEnvironment(
    name="add-one-app",
    app=app,
    description="Adds one to the input",
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
)
# {{/docs-fragment app-definition}}

# {{docs-fragment task-env}}
task_env = flyte.TaskEnvironment(
    name="add_one_task_env",
    image=image,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    depends_on=[app_env],  # Ensure app is deployed before task runs
)
# {{/docs-fragment task-env}}

# {{docs-fragment app-endpoint}}
@app.get("/")
async def add_one(x: int) -> dict[str, int]:
    """Main endpoint for the add-one app."""
    return {"result": x + 1}
# {{/docs-fragment app-endpoint}}

# {{docs-fragment task}}
@task_env.task
async def add_one_task(x: int) -> int:
    print(f"Calling app at {app_env.endpoint}")
    async with httpx.AsyncClient() as client:
        response = await client.get(app_env.endpoint, params={"x": x})
        response.raise_for_status()
        return response.json()["result"]
# {{/docs-fragment task}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    deployments = flyte.deploy(task_env)
    print(f"Deployed task environment: {deployments}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/task_calling_app.py*

Key points:

- The task environment uses `depends_on=[app_env]` to ensure the app is deployed first
- Access the app endpoint via `app_env.endpoint`
- Use standard HTTP client libraries (like `httpx`) to make requests

## Call task from app (webhooks / APIs)

Apps can trigger task execution using the Flyte SDK. This is useful for:

- Webhooks that trigger workflows
- APIs that need to run batch jobs
- Services that need to execute tasks asynchronously

Webhooks are HTTP endpoints that trigger actions in response to external events. Flyte apps can serve as webhook endpoints that trigger task runs, workflows, or other operations.

### Example: Basic webhook app

Here's a simple webhook that triggers Flyte tasks:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""A webhook that triggers Flyte tasks."""

import pathlib
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment

# {{docs-fragment auth}}
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()

async def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
    """Verify the API key from the bearer token."""
    if credentials.credentials != WEBHOOK_API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return credentials
# {{/docs-fragment auth}}

# {{docs-fragment lifespan}}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize Flyte before accepting requests."""
    await flyte.init_in_cluster.aio()
    yield
    # Cleanup if needed
# {{/docs-fragment lifespan}}

# {{docs-fragment app}}
app = FastAPI(
    title="Flyte Webhook Runner",
    description="A webhook service that triggers Flyte task runs",
    version="1.0.0",
    lifespan=lifespan,
)

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}
# {{/docs-fragment app}}

# {{docs-fragment webhook-endpoint}}
@app.post("/run-task/{project}/{domain}/{name}/{version}")
async def run_task(
    project: str,
    domain: str,
    name: str,
    version: str,
    inputs: dict,
    credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
    """
    Trigger a Flyte task run via webhook.

    Returns information about the launched run.
    """
    # Fetch the task
    task = remote.Task.get(
        project=project,
        domain=domain,
        name=name,
        version=version,
    )

    # Run the task
    run = await flyte.run.aio(task, **inputs)

    return {
        "url": run.url,
        "id": run.id,
        "status": "started",
    }
# {{/docs-fragment webhook-endpoint}}

# {{docs-fragment env}}
env = FastAPIAppEnvironment(
    name="webhook-runner",
    app=app,
    description="A webhook service that triggers Flyte task runs",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,  # We handle auth in the app
    env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)
# {{/docs-fragment env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed webhook: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/webhook/basic_webhook.py*

Once deployed, you can trigger tasks via HTTP POST:

```bash
curl -X POST "https://your-webhook-url/run-task/flytesnacks/development/my_task/v1" \
  -H "Authorization: Bearer test-api-key" \
  -H "Content-Type: application/json" \
  -d '{"input_key": "input_value"}'
```

Response:

```json
{
  "url": "https://console.union.ai/...",
  "id": "abc123",
  "status": "started"
}
```
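To call the webhook from Python instead of `curl`, you can build the same request with the standard library's `urllib.request`. This is a minimal sketch: the base URL, task coordinates, and API key are the same placeholders as in the `curl` example, and `build_run_task_request` is a hypothetical helper name, not part of Flyte.

```python
import json
import urllib.request


def build_run_task_request(
    base_url: str,
    project: str,
    domain: str,
    name: str,
    version: str,
    inputs: dict,
    api_key: str,
) -> urllib.request.Request:
    """Build a POST request for the /run-task/{project}/{domain}/{name}/{version} route."""
    url = f"{base_url.rstrip('/')}/run-task/{project}/{domain}/{name}/{version}"
    return urllib.request.Request(
        url,
        # The JSON body is the dict of task inputs the endpoint forwards to flyte.run
        data=json.dumps(inputs).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
```

Pass the result to `urllib.request.urlopen(req)` to send it, then read the JSON response with `json.load(resp)`.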

### Advanced webhook patterns

**Webhook with validation**

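Use a Pydantic model as the request body so FastAPI validates incoming payloads and rejects malformed ones with a 422 response. The model's fields are passed to the task as keyword arguments via `model_dump()`, so the target task must accept `data` and `priority` inputs: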

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""A webhook with Pydantic validation."""

import pathlib
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
from pydantic import BaseModel
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment

WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()

async def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
    """Verify the API key from the bearer token."""
    if credentials.credentials != WEBHOOK_API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return credentials

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize Flyte before accepting requests."""
    await flyte.init_in_cluster.aio()
    yield

app = FastAPI(
    title="Flyte Webhook Runner with Validation",
    description="A webhook service that triggers Flyte task runs with Pydantic validation",
    version="1.0.0",
    lifespan=lifespan,
)

# {{docs-fragment validation-model}}
class TaskInput(BaseModel):
    data: dict
    priority: int = 0
# {{/docs-fragment validation-model}}

# {{docs-fragment validated-webhook}}
@app.post("/run-task/{project}/{domain}/{name}/{version}")
async def run_task(
    project: str,
    domain: str,
    name: str,
    version: str,
    inputs: TaskInput,  # Validated input
    credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
    task = remote.Task.get(
        project=project,
        domain=domain,
        name=name,
        version=version,
    )

    run = await flyte.run.aio(task, **inputs.model_dump())

    return {
        "run_id": run.id,
        "url": run.url,
    }
# {{/docs-fragment validated-webhook}}

env = FastAPIAppEnvironment(
    name="webhook-with-validation",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed webhook: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/webhook_validation.py*

**Webhook with response waiting**

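To return the task's outputs in the HTTP response, wait for the run to finish before responding. The HTTP request stays open for the whole run, so this pattern suits short-running tasks: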

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""A webhook that waits for task completion."""

import pathlib
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import os
from contextlib import asynccontextmanager
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment

WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY", "test-api-key")
security = HTTPBearer()

async def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
    """Verify the API key from the bearer token."""
    if credentials.credentials != WEBHOOK_API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return credentials

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize Flyte before accepting requests."""
    await flyte.init_in_cluster.aio()
    yield

app = FastAPI(
    title="Flyte Webhook Runner (Wait for Completion)",
    description="A webhook service that triggers Flyte task runs and waits for completion",
    version="1.0.0",
    lifespan=lifespan,
)

# {{docs-fragment wait-webhook}}
@app.post("/run-task-and-wait/{project}/{domain}/{name}/{version}")
async def run_task_and_wait(
    project: str,
    domain: str,
    name: str,
    version: str,
    inputs: dict,
    credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
    task = remote.Task.get(
        project=project,
        domain=domain,
        name=name,
        version=version,
    )

    run = await flyte.run.aio(task, **inputs)
    run.wait()  # Wait for completion

    return {
        "run_id": run.id,
        "url": run.url,
        "status": run.status,
        "outputs": run.outputs(),
    }
# {{/docs-fragment wait-webhook}}

env = FastAPIAppEnvironment(
    name="webhook-wait-completion",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    env_vars={"WEBHOOK_API_KEY": os.getenv("WEBHOOK_API_KEY", "test-api-key")},
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed webhook: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/webhook_wait.py*
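If `run.wait()` blocks the calling thread, as a synchronous call does, then calling it inside an `async` endpoint also blocks the event loop, so other requests (including `/health`) stall until the run finishes. A common mitigation, sketched here under that assumption, is to run the blocking wait in a worker thread with `asyncio.to_thread` and bound it with `asyncio.wait_for`. `wait_in_thread` is a hypothetical helper, and `fake_run_wait` stands in for `run.wait` so the sketch runs on its own.

```python
import asyncio
import time
from typing import Any, Callable


async def wait_in_thread(blocking_wait: Callable[[], Any], timeout_s: float) -> Any:
    """Run a blocking wait function in a worker thread, bounded by a timeout.

    Raises asyncio.TimeoutError if the wait exceeds timeout_s. The worker
    thread itself is not interrupted; only the awaiting coroutine gives up.
    """
    return await asyncio.wait_for(asyncio.to_thread(blocking_wait), timeout=timeout_s)


def fake_run_wait() -> str:
    """Hypothetical stand-in for run.wait(): blocks the calling thread briefly."""
    time.sleep(0.1)
    return "SUCCEEDED"


if __name__ == "__main__":
    # The event loop stays free to serve other coroutines while the wait runs in a thread
    print(asyncio.run(wait_in_thread(fake_run_wait, timeout_s=2.0)))
```

Inside the endpoint this would look like `await wait_in_thread(run.wait, timeout_s=300)`, with a timeout chosen for your task. For long-running tasks, returning the run URL immediately, as the first webhook example does, avoids holding the HTTP connection open at all.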

**Webhook with secret management**

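Rather than passing the API key through `env_vars`, store it as a Flyte secret and expose it to the app as an environment variable: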

```python
env = FastAPIAppEnvironment(
    name="webhook-runner",
    app=app,
    secrets=flyte.Secret(key="webhook-api-key", as_env_var="WEBHOOK_API_KEY"),
    # ...
)
```

Then read it in your app:

```python
import os

WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY")
```
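`os.getenv` returns `None` when the secret is not mounted, which with the `verify_token` examples above rejects every request with a 403 instead of surfacing a configuration error. A minimal sketch of failing fast at import time instead; `require_env` is a hypothetical helper, not part of the Flyte SDK:

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or fail loudly if it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set; "
            "check that the Flyte secret is attached to the app environment."
        )
    return value
```

With `WEBHOOK_API_KEY = require_env("WEBHOOK_API_KEY")` at module level, a missing secret makes the app fail at startup, where it shows up in the app logs.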

### Webhook security and best practices

- **Authentication**: Always secure webhooks with authentication (API keys, tokens, etc.).
- **Input validation**: Validate webhook inputs using Pydantic models.
- **Error handling**: Handle errors gracefully and return meaningful error messages.
- **Async operations**: Use async/await for I/O operations.
- **Health checks**: Include health check endpoints.
- **Logging**: Log webhook requests for debugging and auditing.
- **Rate limiting**: Consider implementing rate limiting for production.
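For rate limiting, a dedicated library or an API gateway is usually the better choice in production. To illustrate the idea, here is a minimal in-memory sliding-window limiter keyed by a client identifier such as the API key or client IP. `SlidingWindowRateLimiter` is a hypothetical class; because its state lives in process memory, limits apply per replica and reset when the app restarts.

```python
import time
from collections import defaultdict, deque
from typing import Callable


class SlidingWindowRateLimiter:
    """Allow at most `max_requests` per `window_s` seconds for each client key."""

    def __init__(self, max_requests: int, window_s: float,
                 clock: Callable[[], float] = time.monotonic):
        self.max_requests = max_requests
        self.window_s = window_s
        self.clock = clock  # injectable so the limiter can be tested without sleeping
        self._hits: dict[str, deque[float]] = defaultdict(deque)

    def allow(self, key: str) -> bool:
        """Record a request for `key` and return True if it is within the limit."""
        now = self.clock()
        hits = self._hits[key]
        # Drop timestamps that have fallen out of the window
        while hits and now - hits[0] >= self.window_s:
            hits.popleft()
        if len(hits) >= self.max_requests:
            return False  # rejected requests are not recorded
        hits.append(now)
        return True
```

In a FastAPI endpoint you might call `limiter.allow(credentials.credentials)` and raise `HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS)` when it returns `False`.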

Security considerations:

- Store API keys in Flyte secrets, not in code.
- Always use HTTPS in production.
- Validate all inputs to prevent injection attacks.
- Implement proper access control mechanisms.
- Log all webhook invocations for security auditing.

### Example: GitHub webhook

The following webhook verifies GitHub's `X-Hub-Signature-256` HMAC signature and triggers a deployment task on `push` events; other event types are ignored:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""A GitHub webhook that triggers Flyte tasks based on GitHub events."""

import pathlib
import hmac
import hashlib
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, Header, HTTPException
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize Flyte before accepting requests."""
    await flyte.init_in_cluster.aio()
    yield

app = FastAPI(
    title="GitHub Webhook Handler",
    description="Triggers Flyte tasks based on GitHub events",
    version="1.0.0",
    lifespan=lifespan,
)

# {{docs-fragment github-webhook}}
@app.post("/github-webhook")
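async def github_webhook(
    request: Request,
    x_hub_signature_256: str | None = Header(None),
):
    """Handle GitHub webhook events."""
    body = await request.body()

    # Verify the HMAC-SHA256 signature GitHub sends in X-Hub-Signature-256
    secret = os.getenv("GITHUB_WEBHOOK_SECRET")
    if not secret:
        raise HTTPException(status_code=500, detail="Webhook secret is not configured")
    if x_hub_signature_256 is None:
        raise HTTPException(status_code=403, detail="Missing signature")

    signature = hmac.new(
        secret.encode(),
        body,
        hashlib.sha256,
    ).hexdigest()

    expected_signature = f"sha256={signature}"
    if not hmac.compare_digest(x_hub_signature_256, expected_signature):
        raise HTTPException(status_code=403, detail="Invalid signature")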

    # Process webhook
    event = await request.json()
    event_type = request.headers.get("X-GitHub-Event")

    if event_type == "push":
        # Trigger deployment task
        task = remote.Task.get(
            project="my-project",
            domain="development",
            name="deploy-task",
            version="v1",
        )
        run = await flyte.run.aio(task, commit=event["after"])
        return {"run_id": run.id, "url": run.url}

    return {"status": "ignored"}
# {{/docs-fragment github-webhook}}

# {{docs-fragment env}}
env = FastAPIAppEnvironment(
    name="github-webhook",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    secrets=flyte.Secret(key="GITHUB_WEBHOOK_SECRET", as_env_var="GITHUB_WEBHOOK_SECRET"),
)
# {{/docs-fragment env}}

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed GitHub webhook: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/github_webhook.py*

### Gradio agent UI

For AI agents, a Gradio app gives you an interactive UI that kicks off agent runs. The app uses `flyte.with_runcontext()` to run the agent task either locally or on a remote cluster, depending on the `RUN_MODE` environment variable.

```python
import os

import gradio as gr
import flyte
import flyte.app
from research_agent import agent  # module that defines your agent task

RUN_MODE = os.getenv("RUN_MODE", "remote")

serving_env = flyte.app.AppEnvironment(
    name="research-agent-ui",
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "gradio", "langchain-core", "langchain-openai", "langgraph",
    ),
    secrets=flyte.Secret(key="OPENAI_API_KEY", as_env_var="OPENAI_API_KEY"),
    port=7860,
)

def run_query(request: str):
    """Kick off the agent as a Flyte task and return its first output."""
    result = flyte.with_runcontext(mode=RUN_MODE).run(agent, request=request)
    result.wait()
    return result.outputs()[0]

def create_demo():
    """Build the Gradio UI. A minimal single-textbox layout is shown for illustration."""
    return gr.Interface(fn=run_query, inputs="text", outputs="text", title="Research agent")

@serving_env.server
def app_server():
    # Bind to all interfaces on the port declared in serving_env
    create_demo().launch(server_name="0.0.0.0", server_port=7860)

if __name__ == "__main__":
    create_demo().launch()
```

The `RUN_MODE` variable gives you a smooth development progression:

1. **Fully local**: `RUN_MODE=local python agent_app.py`. Everything runs in your local Python environment, which makes iteration fast.
2. **Local app, remote task**: `python agent_app.py`. Because `RUN_MODE` defaults to `remote`, the UI runs locally while the agent executes on the cluster with full compute resources.
3. **Full remote**: `flyte deploy agent_app.py serving_env`. Both the UI and the agent run on the cluster.

## Best practices

1. **Use `depends_on`**: Always specify dependencies to ensure proper deployment order.
2. **Handle errors**: Implement proper error handling for HTTP requests.
3. **Use async clients**: Use async HTTP clients (`httpx.AsyncClient`) in async contexts.
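4. **Initialize Flyte**: For apps that call tasks, initialize Flyte at startup, for example with `await flyte.init_in_cluster.aio()` in a FastAPI `lifespan` handler, as the webhook examples do.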
5. **Endpoint access**: Use `app_env.endpoint` or `AppEndpoint` parameter for accessing app URLs.
6. **Webhook security**: Secure webhooks with auth, validation, and HTTPS.

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/websocket-apps ===

# WebSocket apps

WebSockets enable bidirectional, real-time communication between clients and servers. Flyte apps can serve WebSocket endpoints for real-time applications like chat, live updates, or streaming data.

## Example: Basic WebSocket app

The following FastAPI app exposes a `/ws` endpoint that echoes each message back to its sender and broadcasts it to all connected clients:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
#    "websockets",
# ]
# ///

"""A FastAPI app with WebSocket support."""

import pathlib
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
import asyncio
import json
from datetime import UTC, datetime
import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI(
    title="Flyte WebSocket Demo",
    description="A FastAPI app with WebSocket support",
    version="1.0.0",
)

# {{docs-fragment connection-manager}}
class ConnectionManager:
    """Manages WebSocket connections."""

    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept and register a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.append(websocket)
        print(f"Client connected. Total: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Remove a WebSocket connection."""
        self.active_connections.remove(websocket)
        print(f"Client disconnected. Total: {len(self.active_connections)}")

    async def send_personal_message(self, message: str, websocket: WebSocket):
        """Send a message to a specific WebSocket connection."""
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        """Broadcast a message to all active connections."""
        for connection in self.active_connections:
            try:
                await connection.send_text(message)
            except Exception as e:
                print(f"Error broadcasting: {e}")

manager = ConnectionManager()
# {{/docs-fragment connection-manager}}

# {{docs-fragment websocket-endpoint}}
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time communication."""
    await manager.connect(websocket)

    try:
        # Send welcome message
        await manager.send_personal_message(
            json.dumps({
                "type": "system",
                "message": "Welcome! You are connected.",
                "timestamp": datetime.now(UTC).isoformat(),
            }),
            websocket,
        )

        # Listen for messages
        while True:
            data = await websocket.receive_text()

            # Echo back to sender
            await manager.send_personal_message(
                json.dumps({
                    "type": "echo",
                    "message": f"Echo: {data}",
                    "timestamp": datetime.now(UTC).isoformat(),
                }),
                websocket,
            )

            # Broadcast to all clients
            await manager.broadcast(
                json.dumps({
                    "type": "broadcast",
                    "message": f"Broadcast: {data}",
                    "timestamp": datetime.now(UTC).isoformat(),
                    "connections": len(manager.active_connections),
                })
            )

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(
            json.dumps({
                "type": "system",
                "message": "A client disconnected",
                "connections": len(manager.active_connections),
            })
        )
# {{/docs-fragment websocket-endpoint}}

# {{docs-fragment env}}
env = FastAPIAppEnvironment(
    name="websocket-app",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
        "websockets",
    ),
    resources=flyte.Resources(cpu=1, memory="1Gi"),
    requires_auth=False,
)
# {{/docs-fragment env}}

# {{docs-fragment deploy}}
if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed websocket app: {app_deployment[0].summary_repr()}")
# {{/docs-fragment deploy}}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/basic_websocket.py*

## WebSocket patterns

**Echo server**

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
#    "websockets",
# ]
# ///

"""WebSocket patterns: echo, broadcast, streaming, and chat."""

import asyncio
import json
import random
from datetime import datetime, UTC
from pathlib import Path
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import flyte
from flyte.app.extras import FastAPIAppEnvironment

app = FastAPI(
    title="WebSocket Patterns Demo",
    description="Demonstrates various WebSocket patterns",
    version="1.0.0",
)

# {{docs-fragment echo-server}}
@app.websocket("/echo")
async def echo(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Echo: {data}")
    except WebSocketDisconnect:
        pass
# {{/docs-fragment echo-server}}

# Connection manager for broadcast
class ConnectionManager:
    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            try:
                await connection.send_text(message)
            except Exception:
                pass

manager = ConnectionManager()

# {{docs-fragment broadcast-server}}
@app.websocket("/broadcast")
async def broadcast(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(data)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
# {{/docs-fragment broadcast-server}}

# {{docs-fragment streaming-server}}
@app.websocket("/stream")
async def stream_data(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            # Generate or fetch data
            data = {"timestamp": datetime.now(UTC).isoformat(), "value": random.random()}
            await websocket.send_json(data)
            await asyncio.sleep(1)  # Send update every second
    except WebSocketDisconnect:
        pass
# {{/docs-fragment streaming-server}}

# {{docs-fragment chat-room}}
class ChatRoom:
    def __init__(self, name: str):
        self.name = name
        self.connections: list[WebSocket] = []

    async def join(self, websocket: WebSocket):
        self.connections.append(websocket)

    async def leave(self, websocket: WebSocket):
        self.connections.remove(websocket)

    async def broadcast(self, message: str, sender: WebSocket):
        for connection in self.connections:
            if connection != sender:
                await connection.send_text(message)

rooms: dict[str, ChatRoom] = {}

@app.websocket("/chat/{room_name}")
async def chat(websocket: WebSocket, room_name: str):
    await websocket.accept()

    if room_name not in rooms:
        rooms[room_name] = ChatRoom(room_name)

    room = rooms[room_name]
    await room.join(websocket)

    try:
        while True:
            data = await websocket.receive_text()
            await room.broadcast(data, websocket)
    except WebSocketDisconnect:
        await room.leave(websocket)
# {{/docs-fragment chat-room}}

env = FastAPIAppEnvironment(
    name="websocket-patterns",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
        "websockets",
    ),
    resources=flyte.Resources(cpu=1, memory="1Gi"),
    requires_auth=False,
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed WebSocket patterns app: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*

**Broadcast server**

```python
# Excerpt from websocket_patterns.py: the connection manager and the /broadcast endpoint
class ConnectionManager:
    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            try:
                await connection.send_text(message)
            except Exception:
                pass

manager = ConnectionManager()

@app.websocket("/broadcast")
async def broadcast(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(data)
    except WebSocketDisconnect:
        manager.disconnect(websocket)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*
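The broadcast helpers on this page skip connections whose `send_text` fails but leave them in `active_connections`, so a dead socket is retried on every broadcast. A variant that sends concurrently and prunes failed connections might look like the sketch below. It relies only on each connection having an async `send_text(str)` method, which FastAPI's `WebSocket` provides, so it is exercised here with a hypothetical `FakeConnection` instead of real sockets. In a real manager, `connect` would still call `await websocket.accept()` before registering the connection.

```python
import asyncio
from typing import Protocol


class TextSender(Protocol):
    async def send_text(self, data: str) -> None:
        ...


class PruningConnectionManager:
    """Connection registry whose broadcast drops connections that fail to receive."""

    def __init__(self) -> None:
        self.active_connections: list[TextSender] = []

    def add(self, connection: TextSender) -> None:
        self.active_connections.append(connection)

    async def broadcast(self, message: str) -> int:
        """Send `message` to every connection concurrently; return how many were pruned."""
        # Work on a snapshot so connects/disconnects during the send don't break iteration
        targets = list(self.active_connections)
        results = await asyncio.gather(
            *(conn.send_text(message) for conn in targets),
            return_exceptions=True,
        )
        dead = [conn for conn, result in zip(targets, results) if isinstance(result, Exception)]
        for conn in dead:
            if conn in self.active_connections:
                self.active_connections.remove(conn)
        return len(dead)


class FakeConnection:
    """Hypothetical stand-in for a WebSocket, used only to exercise the manager."""

    def __init__(self, fail: bool = False) -> None:
        self.fail = fail
        self.received: list[str] = []

    async def send_text(self, data: str) -> None:
        if self.fail:
            raise ConnectionError("socket closed")
        self.received.append(data)
```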

**Real-time data streaming**

```python
# Excerpt from websocket_patterns.py: the /stream endpoint pushes a JSON update every second
@app.websocket("/stream")
async def stream_data(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            # Generate or fetch data
            data = {"timestamp": datetime.now(UTC).isoformat(), "value": random.random()}
            await websocket.send_json(data)
            await asyncio.sleep(1)  # Send update every second
    except WebSocketDisconnect:
        pass
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/websocket_patterns.py*

**Chat application**

The chat-room pattern is the `ChatRoom` class and the `/chat/{room_name}` endpoint in the `websocket_patterns.py` listing above: each connection joins a named room, and every message it sends is relayed to the other connections in that room.

## Using WebSockets with Flyte tasks

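A WebSocket endpoint can also trigger Flyte tasks. In the example below, the client sends a JSON message naming a registered task (`project`, `domain`, `task`, `version`) and its `inputs`; the endpoint starts a run with `flyte.run.aio`, replies with the run's ID and URL, and then streams status updates over the same connection. The FastAPI `lifespan` handler calls `flyte.init_in_cluster` so the Flyte client is initialized before the first connection is accepted: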

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
#    "websockets",
# ]
# ///

"""A WebSocket app that triggers Flyte tasks and streams updates."""

import json
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import flyte
import flyte.remote as remote
from flyte.app.extras import FastAPIAppEnvironment

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize Flyte before accepting requests."""
    await flyte.init_in_cluster.aio()
    yield

app = FastAPI(
    title="WebSocket Task Runner",
    description="Triggers Flyte tasks via WebSocket and streams updates",
    version="1.0.0",
    lifespan=lifespan,
)

# {{docs-fragment task-runner-websocket}}
@app.websocket("/task-runner")
async def task_runner(websocket: WebSocket):
    await websocket.accept()

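    try:
        while True:
            # Receive a task request
            message = await websocket.receive_text()
            try:
                request = json.loads(message)
                project, domain = request["project"], request["domain"]
                task_name, version = request["task"], request["version"]
                inputs = request["inputs"]
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # Report a malformed request and keep the connection open
                await websocket.send_json({"status": "error", "message": f"Invalid request: {e!r}"})
                continue

            # Look up the registered task and trigger a run
            task = remote.Task.get(
                project=project,
                domain=domain,
                name=task_name,
                version=version,
            )

            run = await flyte.run.aio(task, **inputs)

            # Send run info back
            await websocket.send_json({
                "run_id": run.id,
                "url": run.url,
                "status": "started",
            })

            # Optionally stream updates
            async for update in run.stream():
                await websocket.send_json({
                    "status": update.status,
                    "message": update.message,
                })

    except WebSocketDisconnect:
        pass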
# {{/docs-fragment task-runner-websocket}}

env = FastAPIAppEnvironment(
    name="task-runner-websocket",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
        "websockets",
    ),
    resources=flyte.Resources(cpu=1, memory="1Gi"),
    requires_auth=False,
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed WebSocket task runner: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/websocket/task_runner_websocket.py*
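A client drives `/task-runner` by sending one JSON message per run, using the keys the endpoint reads. The sketch below is a minimal client under a few assumptions: the `websockets` package is installed, and the URL, project, task name, and version you pass in are placeholders for your own deployment. `build_task_request` and `start_task` are illustrative helper names, not part of Flyte.

```python
import json
from typing import Any


def build_task_request(
    project: str, domain: str, task: str, version: str, inputs: dict[str, Any]
) -> str:
    """Serialize a request with the keys the /task-runner endpoint reads."""
    return json.dumps(
        {"project": project, "domain": domain, "task": task, "version": version, "inputs": inputs}
    )


async def start_task(uri: str, request: str) -> dict[str, Any]:
    """Send one request and return the endpoint's first reply (the "started" message)."""
    # Third-party dependency, imported lazily so build_task_request needs only the stdlib
    import websockets

    async with websockets.connect(uri) as ws:
        await ws.send(request)
        return json.loads(await ws.recv())
```

For example, `asyncio.run(start_task("wss://your-app-url/task-runner", build_task_request("my-project", "development", "my_task", "v1", {"x": 1})))` should return the `started` message containing the run URL. Keeping the connection open and calling `recv()` again would receive the streamed status updates.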

## WebSocket client example

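Connect to the echo endpoint from Python with the `websockets` library. Deployed apps are served over HTTPS, so use the `wss://` scheme:

```python
import asyncio
import websockets

async def client():
    # Replace your-app-url with the deployed app's host; /echo is the echo endpoint above
    uri = "wss://your-app-url/echo"
    async with websockets.connect(uri) as websocket:
        # Send a message
        await websocket.send("Hello, Server!")

        # Receive the reply; the echo endpoint prepends "Echo: "
        response = await websocket.recv()
        print(f"Received: {response}")

asyncio.run(client())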
```

## Best practices

1. **Connection management**: Track active connections and handle disconnections gracefully.
2. **Heartbeats**: Implement ping/pong for connection health monitoring.
3. **Rate limiting**: Consider rate limiting for production deployments.
4. **Error handling**: Handle WebSocket errors and connection drops.
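5. **Authentication**: The WebSocket examples on this page set `requires_auth=False`. For production, enable Flyte's platform authentication with `requires_auth=True`, or check a token in your endpoint before calling `websocket.accept()`.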
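For rate limiting, one common approach is a per-connection token bucket: each message spends a token, and tokens refill at a fixed rate up to a burst capacity. The class below is a minimal, framework-independent sketch; `TokenBucket` and its parameters are illustrative names, not part of Flyte or FastAPI. Inside a WebSocket handler you would create one bucket per connection and call `allow()` before processing each received message, replying with an error or closing the connection when it returns `False`.

```python
import time
from typing import Callable


class TokenBucket:
    """Allow bursts of up to `capacity` messages, refilling `rate` tokens per second."""

    def __init__(self, rate: float, capacity: int, clock: Callable[[], float] = time.monotonic):
        self.rate = rate
        self.capacity = capacity
        self.clock = clock  # injectable so the behavior can be tested deterministically
        self.tokens = float(capacity)  # start full so a new client can burst immediately
        self.last = clock()

    def allow(self) -> bool:
        now = self.clock()
        # Refill in proportion to elapsed time, never beyond capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False
```

Using `time.monotonic` rather than wall-clock time keeps the bucket correct if the system clock is adjusted.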

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/browser-apps ===

# Browser apps

Browser-based apps, such as Streamlit, Gradio, or custom HTML/JS dashboards, are used through their web interface: users open the app URL in a browser and interact with the UI directly, without other services calling the app's API.

## Accessing browser-based apps

To access a browser-based app:

1. Deploy the app using `flyte deploy` or `flyte serve`
2. Navigate to the app URL in a browser
3. Interact with the UI directly

## Common browser-based app types

### Streamlit apps

Streamlit is ideal for data dashboards and ML prototypes. See [Streamlit app](https://www.union.ai/docs/v2/union/user-guide/native-app-integrations/streamlit-app) for details.

### Gradio apps

Gradio is great for ML model demos and interactive interfaces. You can deploy a Gradio app by building a custom [`AppEnvironment`](./single-script-apps) with the `gradio` package installed in your image.

### Custom HTML/JS apps

You can also serve custom HTML/JS applications using FastAPI's static file serving or any other web framework.

## Best practices

1. **Authentication**: For sensitive apps, enable authentication with `requires_auth=True`.
2. **Responsive design**: Design UIs that work on various screen sizes.
3. **Loading states**: Show loading indicators for long-running operations.
4. **Error handling**: Display user-friendly error messages.
5. **Resource management**: Configure appropriate CPU/memory resources based on expected usage.

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/secret-based-authentication ===

# Secret-based authentication

In this guide, we'll deploy a FastAPI app that uses API key authentication with Flyte secrets. This allows you to invoke the endpoint from the public internet securely without exposing API keys in your code.

## Create the secret

Before defining and deploying the app, you need to create the `API_KEY` secret in Flyte. This secret will store your API key securely.

Create the secret using the Flyte CLI:

```bash
flyte create secret API_KEY <your-api-key-value>
```

For example:

```bash
flyte create secret API_KEY my-secret-api-key-12345
```

> [!NOTE]
> The secret name `API_KEY` must match the key specified in the `flyte.Secret()` call in your code. The secret will be available to your app as the environment variable specified in `as_env_var`.

## Define the FastAPI app

Here's a simple FastAPI app that reads bearer tokens with `HTTPBearer` and `HTTPAuthorizationCredentials` and checks them against an API key stored as a Flyte secret:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "fastapi",
# ]
# ///

"""Basic FastAPI authentication using dependency injection."""

from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette import status
import hmac
import os
import pathlib
import flyte
from flyte.app.extras import FastAPIAppEnvironment

# Get API key from environment variable (loaded from Flyte secret)
# The secret must be created using: flyte create secret API_KEY <your-api-key-value>
API_KEY = os.getenv("API_KEY")
security = HTTPBearer()

async def verify_token(
    credentials: HTTPAuthorizationCredentials = Security(security),
) -> HTTPAuthorizationCredentials:
    """Verify the API key from the bearer token."""
    if not API_KEY:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="API_KEY not configured",
        )
    # Compare in constant time so response timing does not leak how much of the key matched
    if not hmac.compare_digest(credentials.credentials.encode(), API_KEY.encode()):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )
    return credentials

app = FastAPI(title="Authenticated API")

@app.get("/public")
async def public_endpoint():
    """Public endpoint that doesn't require authentication."""
    return {"message": "This is public"}

@app.get("/protected")
async def protected_endpoint(
    credentials: HTTPAuthorizationCredentials = Security(verify_token),
):
    """Protected endpoint that requires authentication."""
    return {
        "message": "This is protected",
        "user": credentials.credentials,
    }

env = FastAPIAppEnvironment(
    name="authenticated-api",
    app=app,
    image=flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
        "fastapi",
        "uvicorn",
    ),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,  # We handle auth in the app
    secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY"),
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app_deployment = flyte.deploy(env)
    print(f"Deployed: {app_deployment[0].summary_repr()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/fastapi/basic_auth.py*

As you can see, we:

1. Define a `FastAPI` app
2. Create a `verify_token` function that verifies the API key from the Bearer token
3. Define endpoints that use the `verify_token` function to authenticate requests
4. Configure the `FastAPIAppEnvironment` with:
   - `requires_auth=False` - This allows the endpoint to be reached without going through Flyte's authentication, since we're handling authentication ourselves using the `API_KEY` secret
   - `secrets=flyte.Secret(key="API_KEY", as_env_var="API_KEY")` - This injects the secret value into the `API_KEY` environment variable at runtime

The key difference from using `env_vars` is that secrets are stored securely in Flyte's secret store and injected at runtime, rather than being passed as plain environment variables.

## Deploy the FastAPI app

Once the secret is created, you can deploy the FastAPI app. Make sure your `config.yaml` file is in the same directory as your script, then run:

```bash
python basic_auth.py
```

Or use the Flyte CLI:

```bash
flyte serve basic_auth.py
```

Deploying the application will stream the status to the console and display the app URL:

```
✨ Deploying Application: authenticated-api
🔎 Console URL: https://<union-tenant>/console/projects/my-project/domains/development/apps/authenticated-api
[Status] Pending: App is pending deployment
[Status] Started: Service is ready
🚀 Deployed Endpoint: https://rough-meadow-97cf5.apps.<union-tenant>
```

## Invoke the endpoint

Once deployed, you can invoke the authenticated endpoint using curl:

```bash
curl -X GET "https://rough-meadow-97cf5.apps.<union-tenant>/protected" \
  -H "Authorization: Bearer <your-api-key-value>"
```

Replace `<your-api-key-value>` with the actual API key value you used when creating the secret.

For example, if you created the secret with value `my-secret-api-key-12345`:

```bash
curl -X GET "https://rough-meadow-97cf5.apps.<union-tenant>/protected" \
  -H "Authorization: Bearer my-secret-api-key-12345"
```

You should receive a response:

```json
{
  "message": "This is protected",
  "user": "my-secret-api-key-12345"
}
```
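The same request can be made from Python. This sketch uses only the standard library's `urllib.request`; the base URL and key you pass in are placeholders for your deployed endpoint and secret value, and `build_protected_request` / `call_protected` are illustrative names.

```python
import json
import urllib.request


def build_protected_request(base_url: str, api_key: str) -> urllib.request.Request:
    """Build a GET request for /protected with the API key as a bearer token."""
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/protected",
        headers={"Authorization": f"Bearer {api_key}"},
        method="GET",
    )


def call_protected(base_url: str, api_key: str) -> dict:
    """Call /protected and decode the JSON body."""
    # urlopen raises urllib.error.HTTPError for error responses (e.g. 403 on a wrong key)
    with urllib.request.urlopen(build_protected_request(base_url, api_key), timeout=30) as resp:
        return json.loads(resp.read())
```

Calling `call_protected("https://rough-meadow-97cf5.apps.<union-tenant>", "my-secret-api-key-12345")` against the deployed app should return the same JSON as the curl example.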

## Authentication for vLLM and SGLang apps

Both vLLM and SGLang apps support API key authentication through their native `--api-key` argument. This allows you to secure your LLM endpoints while keeping them accessible from the public internet.

### Create the authentication secret

Create a secret to store your API key:

```bash
flyte create secret AUTH_SECRET <your-api-key-value>
```

For example:

```bash
flyte create secret AUTH_SECRET my-llm-api-key-12345
```

### Deploy vLLM app with authentication

Here's how to deploy a vLLM app with API key authentication:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-vllm>=2.0.0b45",
# ]
# ///

"""vLLM app with API key authentication."""

import pathlib
from flyteplugins.vllm import VLLMAppEnvironment
import flyte

# The secret must be created using: flyte create secret AUTH_SECRET <your-api-key-value>
vllm_app = VLLMAppEnvironment(
    name="vllm-app-with-auth",
    model_hf_path="Qwen/Qwen3-0.6B",  # HuggingFace model path
    model_id="qwen3-0.6b",  # Model ID exposed by vLLM
    resources=flyte.Resources(
        cpu="4",
        memory="16Gi",
        gpu="L40s:1",  # GPU required for LLM serving
        disk="10Gi",
    ),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),
        scaledown_after=300,  # Scale down after 5 minutes of inactivity
    ),
    # Disable Union's platform-level authentication so you can access the
    # endpoint from the public internet
    requires_auth=False,
    # Inject the secret as an environment variable
    secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
    # Pass the API key to vLLM's --api-key argument
    # The $AUTH_SECRET will be replaced with the actual secret value at runtime
    extra_args=[
        "--api-key", "$AUTH_SECRET",
    ],
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app = flyte.serve(vllm_app)
    print(f"Deployed vLLM app: {app.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/vllm/vllm_with_auth.py*

Key points:

1. **`requires_auth=False`** - Disables Union's platform-level authentication so the endpoint can be accessed from the public internet
2. **`secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET")`** - Injects the secret as an environment variable
3. **`extra_args=["--api-key", "$AUTH_SECRET"]`** - Passes the API key to vLLM's `--api-key` argument. The `$AUTH_SECRET` will be replaced with the actual secret value at runtime
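Conceptually, the substitution expands environment-variable references in each argument after the secret has been injected into the container. The sketch below illustrates that idea with `string.Template`; it is not Flyte's implementation, only a way to see what the server's command line ends up containing and why the variable name after `$` must match `as_env_var` exactly. `resolve_args` is a hypothetical helper.

```python
from string import Template


def resolve_args(extra_args: list[str], env: dict[str, str]) -> list[str]:
    """Replace $VAR / ${VAR} references with values from `env` (illustration only).

    Variables missing from `env` are left unchanged, so a misspelled name would be
    passed through literally instead of becoming the API key.
    """
    return [Template(arg).safe_substitute(env) for arg in extra_args]
```

For instance, with `AUTH_SECRET` set to `my-llm-api-key-12345`, the arguments `["--api-key", "$AUTH_SECRET"]` resolve to `["--api-key", "my-llm-api-key-12345"]`.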

Deploy the app:

```bash
python vllm_with_auth.py
```

Or use the Flyte CLI:

```bash
flyte serve vllm_with_auth.py
```

### Deploy SGLang app with authentication

Here's how to deploy an SGLang app with API key authentication:

```
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-sglang>=2.0.0b45",
# ]
# ///

"""SGLang app with API key authentication."""

import pathlib
from flyteplugins.sglang import SGLangAppEnvironment
import flyte

# The secret must be created using: flyte create secret AUTH_SECRET <your-api-key-value>
sglang_app = SGLangAppEnvironment(
    name="sglang-with-auth",
    model_hf_path="Qwen/Qwen3-0.6B",  # HuggingFace model path
    model_id="qwen3-0.6b",  # Model ID exposed by SGLang
    resources=flyte.Resources(
        cpu="4",
        memory="16Gi",
        gpu="L40s:1",  # GPU required for LLM serving
        disk="10Gi",
    ),
    scaling=flyte.app.Scaling(
        replicas=(0, 1),
        scaledown_after=300,  # Scale down after 5 minutes of inactivity
    ),
    # Disable Union's platform-level authentication so you can access the
    # endpoint from the public internet
    requires_auth=False,
    # Inject the secret as an environment variable
    secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET"),
    # Pass the API key to SGLang's --api-key argument
    # The $AUTH_SECRET will be replaced with the actual secret value at runtime
    extra_args=[
        "--api-key", "$AUTH_SECRET",
    ],
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
    app = flyte.serve(sglang_app)
    print(f"Deployed SGLang app: {app.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/user-guide/build-apps/sglang/sglang_with_auth.py*

The configuration is similar to vLLM:

1. **`requires_auth=False`** - Disables Union's platform-level authentication
2. **`secrets=flyte.Secret(key="AUTH_SECRET", as_env_var="AUTH_SECRET")`** - Injects the secret as an environment variable
3. **`extra_args=["--api-key", "$AUTH_SECRET"]`** - Passes the API key to SGLang's `--api-key` argument

Deploy the app:

```bash
python sglang_with_auth.py
```

Or use the Flyte CLI:

```bash
flyte serve sglang_with_auth.py
```

### Invoke authenticated LLM endpoints

Once deployed, you can invoke the authenticated endpoints using the OpenAI-compatible API format. Both vLLM and SGLang expose OpenAI-compatible endpoints.

For example, to make a chat completion request:

```bash
curl -X POST "https://your-app-url/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer <your-api-key-value>" \
  -d '{
    "model": "qwen3-0.6b",
    "messages": [
      {"role": "user", "content": "Hello, how are you?"}
    ]
  }'
```

Replace `<your-api-key-value>` with the actual API key value you used when creating the secret.

For example, if you created the secret with value `my-llm-api-key-12345`:

```bash
curl -X POST "https://your-app-url/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer my-llm-api-key-12345" \
  -d '{
    "model": "qwen3-0.6b",
    "messages": [
      {"role": "user", "content": "Hello, how are you?"}
    ]
  }'
```

You should receive a response with the model's completion.

> [!NOTE]
> The `$AUTH_SECRET` syntax in `extra_args` is automatically replaced with the actual secret value at runtime. This ensures the API key is never exposed in your code or configuration files.
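From Python, the same chat-completion call can be made with the standard library alone. This is a sketch: the app URL and key are placeholders, and the response parsing assumes the standard OpenAI chat-completions shape (`choices[0].message.content`). The official `openai` client should also work if you set its `base_url` to `https://your-app-url/v1` and `api_key` to your secret value.

```python
import json
import urllib.request


def build_chat_request(base_url: str, api_key: str, model: str, prompt: str) -> urllib.request.Request:
    """Build a POST to the OpenAI-compatible /v1/chat/completions route."""
    body = {"model": model, "messages": [{"role": "user", "content": prompt}]}
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        method="POST",
    )


def chat(base_url: str, api_key: str, model: str, prompt: str) -> str:
    """Send one user message and return the model's reply text."""
    with urllib.request.urlopen(build_chat_request(base_url, api_key, model, prompt), timeout=120) as resp:
        payload = json.loads(resp.read())
    # OpenAI-compatible servers return the reply text here
    return payload["choices"][0]["message"]["content"]
```

For example, `chat("https://your-app-url", "my-llm-api-key-12345", "qwen3-0.6b", "Hello, how are you?")` mirrors the second curl command.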

## Accessing Swagger documentation

FastAPI apps, including the authenticated API above, also serve auto-generated API documentation. These pages, like the `/public` endpoint, load without a token:

- **Swagger UI**: `https://your-app-url/docs`
- **ReDoc**: `https://your-app-url/redoc`
- **Public endpoint**: `https://your-app-url/public`

Because the app declares an `HTTPBearer` security scheme, the Swagger UI shows an **Authorize** button where you can enter your bearer token and test the protected endpoints directly from the browser.

## Security best practices

1. **Use strong API keys**: Generate cryptographically secure random strings for your API keys
2. **Rotate keys regularly**: Periodically rotate your API keys for better security
3. **Scope secrets appropriately**: Use project/domain scoping when creating secrets if you want to limit access:
   ```bash
   flyte create secret --project my-project --domain development API_KEY my-secret-value
   ```
4. **Never commit secrets**: Always use Flyte secrets for API keys, never hardcode them in your code
5. **Use HTTPS**: Always use HTTPS in production (Flyte apps are served over HTTPS by default)
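For the first practice, Python's standard `secrets` module is a convenient way to produce a suitable key. A minimal sketch, assuming 32 random bytes is enough entropy for your use case: the bytes are URL-safe base64-encoded, so the result can be pasted into `flyte create secret` and sent in an `Authorization` header without escaping.

```python
import secrets


def generate_api_key(num_bytes: int = 32) -> str:
    """Return a URL-safe random key backed by `num_bytes` bytes from the OS CSPRNG."""
    return secrets.token_urlsafe(num_bytes)


if __name__ == "__main__":
    print(generate_api_key())
```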

## Troubleshooting

**Authentication failing:**
- Verify the secret exists: `flyte get secret API_KEY`
- Check that the secret key name matches exactly (case-sensitive)
- Ensure you're using the correct Bearer token value
- Verify the `as_env_var` parameter matches the environment variable name in your code

**Secret not found:**
- Make sure you've created the secret before deploying the app
- Check that the secret scope (organization vs. project/domain) matches your app's project/domain
- Verify the secret name matches exactly (should be `API_KEY`)

**App not starting:**
- Check container logs for errors
- Verify all dependencies are installed in the image
- Ensure the secret is accessible in the app's project/domain

**LLM app authentication not working:**
- Verify the secret exists: `flyte get secret AUTH_SECRET`
- Check that `$AUTH_SECRET` is correctly specified in `extra_args` (note the `$` prefix)
- Ensure the `as_env_var` name in `flyte.Secret()` matches the variable referenced in `extra_args` (both are `AUTH_SECRET` in the examples above); names are case-sensitive
- For both vLLM and SGLang, verify that `--api-key` is passed in `extra_args`
- Check that `requires_auth=False` is set to allow public access

## Next steps

- Learn more about [managing secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets) in Flyte
- See [hybrid graphs](./hybrid-graphs) for webhook examples and authentication patterns
- Learn about [vLLM apps](https://www.union.ai/docs/v2/union/user-guide/native-app-integrations/vllm-app) and [SGLang apps](https://www.union.ai/docs/v2/union/user-guide/native-app-integrations/sglang-app) for serving LLMs

=== PAGE: https://www.union.ai/docs/v2/union/user-guide/build-apps/connector-app ===

# Connector app

A **connector** lets you extend Flyte with custom task execution backends — for example, submitting jobs to an internal batch service, a proprietary ML platform, or any external API. Rather than running your task code directly in a container, Flyte delegates execution to the connector, which polls the external system and reports status back to the orchestrator.

Connectors are deployed as long-running services via `flyte.app.ConnectorEnvironment`, the same app deployment model used for FastAPI endpoints and model servers.

## When to build a custom connector

Build a connector when:

- Tasks need to submit work to an external system (e.g., a job scheduler, a cloud ML service) and poll for completion asynchronously
- You want Flyte's orchestration, observability, and data lineage on top of a non-Kubernetes backend
- Multiple tasks share the same external integration and you want to centralize that logic

## Project structure

A connector app spans two concerns: the connector service (deployed once) and the task plugin (used in each workflow). A typical layout is:

```
my_project/
├── app.py              # Deploy the connector as a flyte app
├── main.py             # Example task that uses the connector
└── my_connector/
    ├── __init__.py
    ├── connector.py    # AsyncConnector implementation
    └── task.py         # Task plugin (AsyncConnectorExecutorMixin)
```

## Step 1: Implement the connector

The connector implements four lifecycle methods: `create` (submit the job), `get` (poll status), `delete` (cancel), and `get_logs` (stream paginated log lines to the UI).

```
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Optional

from flyteidl2.connector.connector_pb2 import (
    GetTaskLogsResponse,
    GetTaskLogsResponseBody,
    GetTaskLogsResponseHeader,
)
from flyteidl2.core.execution_pb2 import TaskExecution
from flyteidl2.logs.dataplane.payload_pb2 import LogLine, LogLineOriginator
from google.protobuf.timestamp_pb2 import Timestamp

from flyte import logger
from flyte.connectors import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta

@dataclass
class BatchJobMetadata(ResourceMeta):
    job_id: str
    created_at: float

class BatchJobConnector(AsyncConnector):
    name = "Batch Job Connector"
    task_type_name = "batch_job"
    metadata_type = BatchJobMetadata

    async def create(self, task_template, inputs: Optional[Dict[str, Any]] = None, **kwargs) -> BatchJobMetadata:
        job_id = str(uuid.uuid4())[:8]
        logger.info(f"Submitted batch job {job_id}")
        return BatchJobMetadata(job_id=job_id, created_at=time.time())

    async def get(self, resource_meta: BatchJobMetadata, **kwargs) -> Resource:
        elapsed = time.time() - resource_meta.created_at
        if elapsed < 5:
            return Resource(phase=TaskExecution.RUNNING, message="Job in progress")
        return Resource(
            phase=TaskExecution.SUCCEEDED,
            message="Job completed",
            outputs={"result": f"output-from-{resource_meta.job_id}"},
        )

    async def delete(self, resource_meta: BatchJobMetadata, **kwargs):
        logger.info(f"Cancelled job {resource_meta.job_id}")

    async def get_logs(self, resource_meta: BatchJobMetadata, token: str = "", **kwargs):
        def line(message: str, ts: float) -> LogLine:
            t = Timestamp()
            t.FromSeconds(int(ts))
            return LogLine(timestamp=t, message=message, originator=LogLineOriginator.USER)

        start = resource_meta.created_at
        job_id = resource_meta.job_id
        pages = {
            "": GetTaskLogsResponseBody(lines=[
                line(f"[INFO] Job {job_id} submitted", start),
                line(f"[INFO] Job {job_id} started", start + 1),
            ]),
            "page-2": GetTaskLogsResponseBody(lines=[
                line(f"[INFO] Job {job_id} finished", start + 5),
            ]),
        }
        next_tokens = {"": "page-2", "page-2": ""}
        yield GetTaskLogsResponse(body=pages.get(token, GetTaskLogsResponseBody(lines=[])))
        next_token = next_tokens.get(token, "")
        if next_token:
            yield GetTaskLogsResponse(header=GetTaskLogsResponseHeader(token=next_token))

ConnectorRegistry.register(BatchJobConnector())
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/connectors/batch_job/connector.py*

Key points:

- **`task_type_name`** must match the `_TASK_TYPE` on your task plugin (see Step 2)
- **`ResourceMeta`** carries whatever state you need between `create` and subsequent `get` / `delete` calls (e.g., a job ID)
- **`Resource.outputs`** maps output names to values; these become the task's return values
- **`get_logs`** is an async generator that yields `GetTaskLogsResponse` messages. Yield a `body` with log lines, then optionally a `header` carrying the next-page token. Omit the final header to signal end of pagination. The UI displays these lines under the **Logs** tab of the task run.
- Register the connector with `ConnectorRegistry.register()` at module level so it is discovered on startup

Connector logs as shown in the UI:

![Connector logs shown in the UI Logs tab](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/user-guide/build-apps/connector-app/connector-logs-ui.png)
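To see how the pagination contract plays out, here is a standard-library sketch of a caller that keeps requesting pages until a call ends without a next-page header. `Body` and `Header` are simplified stand-ins for `GetTaskLogsResponseBody` and `GetTaskLogsResponseHeader`, and `fake_get_logs` mirrors the two-page example above. Flyte itself is the real caller, so treat this only as an illustration of the token protocol.

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncIterator, Union


@dataclass
class Body:  # stand-in for GetTaskLogsResponseBody
    lines: list[str] = field(default_factory=list)


@dataclass
class Header:  # stand-in for GetTaskLogsResponseHeader
    token: str


async def fake_get_logs(token: str = "") -> AsyncIterator[Union[Body, Header]]:
    """Mirror of the connector's get_logs: one body, then a header if more pages exist."""
    pages = {"": ["submitted", "started"], "page-2": ["finished"]}
    next_tokens = {"": "page-2", "page-2": ""}
    yield Body(lines=pages.get(token, []))
    if next_tokens.get(token, ""):
        yield Header(token=next_tokens[token])


async def read_all_logs() -> list[str]:
    """Follow next-page tokens until a call yields no header."""
    lines: list[str] = []
    token = ""
    while True:
        next_token = ""
        async for msg in fake_get_logs(token):
            if isinstance(msg, Body):
                lines.extend(msg.lines)
            else:
                next_token = msg.token
        if not next_token:
            return lines
        token = next_token


if __name__ == "__main__":
    print(asyncio.run(read_all_logs()))  # ['submitted', 'started', 'finished']
```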

## Step 2: Create the task plugin

The task plugin is the Python object your workflows use. It declares the task type, inputs, outputs, and any configuration; `AsyncConnectorExecutorMixin` wires it to the connector at execution time.

```
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.extend import TaskTemplate
from flyte.models import NativeInterface, SerializationContext

@dataclass
class BatchJobConfig:
    timeout_seconds: int = 300

class BatchJobTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "batch_job"

    def __init__(self, name: str, plugin_config: BatchJobConfig,
                 inputs: Optional[Dict[str, Type]] = None,
                 outputs: Optional[Dict[str, Type]] = None, **kwargs):
        super().__init__(
            name=name,
            interface=NativeInterface(
                {k: (v, None) for k, v in inputs.items()} if inputs else {},
                outputs or {},
            ),
            task_type=self._TASK_TYPE,
            image=None,
            **kwargs,
        )
        self.plugin_config = plugin_config

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        return {"timeout_seconds": self.plugin_config.timeout_seconds}
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/connectors/batch_job/task.py*

Key points:

- **`_TASK_TYPE`** ties this plugin to the connector that declares the same `task_type_name`
- **`custom_config`** serializes plugin-specific settings into the task template; the connector receives these in `task_template.custom` during `create`
- `image=None` is correct — the connector service, not the task container, executes this task
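The example connector receives `timeout_seconds` but never enforces it. One way to honor it, sketched here under assumptions, is to read the value from `task_template.custom` in `create`, store it on the metadata alongside the job ID, and have `get` report failure once the job overruns. The exact access pattern depends on how `task_template.custom` is represented, so it is left out; the helper below isolates only the decision. `JobState` and `decide_phase` are hypothetical names, and the phase strings stand in for `TaskExecution.RUNNING`, `SUCCEEDED`, and `FAILED`.

```python
from dataclasses import dataclass


@dataclass
class JobState:
    elapsed_seconds: float  # time since create() submitted the job
    job_done: bool  # completion status reported by the external system
    timeout_seconds: int  # copied from the task's custom config in create()


def decide_phase(state: JobState) -> str:
    """Map external job state to a phase name, enforcing the configured timeout."""
    if state.job_done:
        return "SUCCEEDED"
    if state.elapsed_seconds > state.timeout_seconds:
        return "FAILED"  # a real connector might also cancel the external job here
    return "RUNNING"
```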

## Step 3: Deploy the connector

`ConnectorEnvironment` builds and deploys the connector as a long-running service. The `include` parameter lists the Python packages or modules to copy into the connector image.

```python
# app.py
from pathlib import Path

import flyte
import flyte.app

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyte[connector]")

connector = flyte.app.ConnectorEnvironment(
    name="batch-job-connector",
    image=image,
    resources=flyte.Resources(cpu="1", memory="1Gi"),
    include=["my_connector"],
)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    d = flyte.deploy(connector)
    print(d[0])
```

Deploy with:

```bash
python app.py
```

Or using the CLI:

```bash
flyte deploy app.py connector
```

Flyte builds the image, pushes it, and starts the connector service. The service stays running and handles all `create` / `get` / `delete` calls for tasks with `task_type_name = "batch_job"`.

## Step 4: Register and run tasks

Create and register a `TaskEnvironment` that points to your connector, then run the task:

```python
# main.py
from pathlib import Path

from my_connector.task import BatchJobConfig, BatchJobTask

import flyte

batch_job = BatchJobTask(
    name="my_batch_job",
    plugin_config=BatchJobConfig(timeout_seconds=60),
    inputs={"name": str},
    outputs={"result": str},
)

flyte.TaskEnvironment.from_task("batch-job-env", batch_job)

if __name__ == "__main__":
    flyte.init_from_config(root_dir=Path(__file__).parent)
    result = flyte.run(batch_job, name="hello")
    print(result.url)
```

`TaskEnvironment.from_task` registers the task under a named environment, `batch-job-env`, so Flyte can deploy and run it. Executions are routed to the connector service by the task's type (`batch_job`), not by the environment name.

Run the task:

```bash
python main.py
# or
flyte run main.py my_batch_job --name hello
```

## How it all fits together

```
flyte run main.py my_batch_job
       │
       ▼
  Flyte executor sees task_type = "batch_job"
       │
       ▼
  Routes to the deployed connector service (batch-job-connector)
       │
       ├─ connector.create()  →  submits job, returns BatchJobMetadata
       │
       ├─ connector.get()     →  polls status (RUNNING / SUCCEEDED / FAILED)
       │
       └─ connector.delete()  →  called on cancellation
```
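
To see the order of calls in the diagram, here is a self-contained `asyncio` simulation. It is hypothetical throughout: `FakeBatchConnector` keeps jobs in memory instead of calling an external system, and `run_task` plays the part of Flyte's executor, which in practice does the routing, polling, and cancellation for you.

```python
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FakeJobMetadata:
    # Hypothetical stand-in for BatchJobMetadata: just enough to find the job again.
    job_id: str


class FakeBatchConnector:
    """Hypothetical in-memory connector; a real one calls the external system."""

    task_type_name = "batch_job"

    def __init__(self, polls_until_done: int = 2) -> None:
        self._polls_until_done = polls_until_done
        self._remaining: Dict[str, int] = {}
        self.deleted: List[str] = []

    async def create(self, name: str) -> FakeJobMetadata:
        # Submit the job and return metadata that identifies it.
        job_id = f"job-{name}"
        self._remaining[job_id] = self._polls_until_done
        return FakeJobMetadata(job_id=job_id)

    async def get(self, meta: FakeJobMetadata) -> str:
        # Report RUNNING until the job has been polled enough times.
        if self._remaining[meta.job_id] > 0:
            self._remaining[meta.job_id] -= 1
            return "RUNNING"
        return "SUCCEEDED"

    async def delete(self, meta: FakeJobMetadata) -> None:
        # Only called when the execution is cancelled.
        self.deleted.append(meta.job_id)
        self._remaining.pop(meta.job_id, None)


async def run_task(
    connectors: Dict[str, FakeBatchConnector],
    task_type: str,
    name: str,
    cancel_after: Optional[int] = None,
) -> str:
    # Pick the connector by task type, as the executor does for "batch_job".
    connector = connectors[task_type]
    meta = await connector.create(name)
    polls = 0
    while True:
        if cancel_after is not None and polls >= cancel_after:
            await connector.delete(meta)
            return "ABORTED"
        status = await connector.get(meta)
        polls += 1
        if status in ("SUCCEEDED", "FAILED"):
            return status
        await asyncio.sleep(0)  # a real executor waits between polls


connectors = {FakeBatchConnector.task_type_name: FakeBatchConnector()}
print(asyncio.run(run_task(connectors, "batch_job", "hello")))  # prints SUCCEEDED
```

Passing `cancel_after=1` makes the simulated driver call `delete` after the first poll, mirroring what happens when a running execution is aborted.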

The connector service is the only component that needs network access to the external system. Neither your workflow code nor Flyte's propeller communicates with that system directly.

## Secrets

There are two ways to pass credentials to a connector.

### Connector-level secrets

Connector-level secrets are shared across all tasks that use the connector. Add them to `ConnectorEnvironment.secrets` and read them from `os.environ` inside the connector. See [Connectors](https://www.union.ai/docs/v2/union/user-guide/integrations/_index) for details.
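
Inside the connector, a connector-level secret is read like any other environment variable. A minimal sketch, assuming the secret is exposed as `BATCH_JOB_API_KEY` (a hypothetical name; use whichever variable you configure through `ConnectorEnvironment.secrets`), with `load_api_key` as a hypothetical helper called from `create` and `get`:

```python
import os


def load_api_key(env_var: str = "BATCH_JOB_API_KEY") -> str:
    # Hypothetical helper: read the connector-level secret from the environment.
    value = os.environ.get(env_var)
    if not value:
        # Fail loudly instead of sending unauthenticated requests to the external system.
        raise RuntimeError(f"{env_var} is not set in the connector environment")
    return value
```

Raising early makes a missing secret show up as a clear error rather than as an authentication failure from the external service.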

### Per-task secrets (per-user credentials)

When different users need to run the same connector with their own credentials, pass the secret *name* through the task plugin. Flyte fetches the secret value at execution time and injects it as a keyword argument into the connector's `create` and `get` methods.

**1. Accept a secret name in the task plugin:**

```python
# my_connector/task.py
# Same imports and BatchJobConfig as the task plugin above; only the api_key handling is new.

class BatchJobTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "batch_job"

    def __init__(
        self,
        name: str,
        plugin_config: BatchJobConfig,
        inputs: Optional[Dict[str, Type]] = None,
        outputs: Optional[Dict[str, Type]] = None,
        api_key: Optional[str] = None,   # name of the secret in Flyte's secret store
        **kwargs,
    ):
        super().__init__(...)
        self.plugin_config = plugin_config
        self.api_key = api_key

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        config = {"timeout_seconds": self.plugin_config.timeout_seconds}
        if self.api_key is not None:
            config["secrets"] = {"api_key": self.api_key}  # secret name, not value
        return config
```

**2. Receive the secret value as a kwarg in the connector:**

Flyte reads the secret name from `task_template.custom["secrets"]`, fetches the value from the secrets store, and passes it to `create` and `get` as a keyword argument named after its key in that dict (here, `api_key`):

```python
# my_connector/connector.py

class BatchJobConnector(AsyncConnector):
    ...

    async def create(
        self,
        task_template,
        inputs=None,
        api_key: Optional[str] = None,   # value injected by Flyte
        **kwargs,
    ) -> BatchJobMetadata:
        # use api_key to authenticate against the external service
        ...

    async def get(
        self,
        resource_meta: BatchJobMetadata,
        api_key: Optional[str] = None,
        **kwargs,
    ) -> Resource:
        ...
```

**3. Each user specifies their own secret name when defining the task:**

```python
# Alice's workflow
batch_job = BatchJobTask(
    name="alice_batch_job",
    plugin_config=BatchJobConfig(timeout_seconds=60),
    inputs={"name": str},
    outputs={"result": str},
    api_key="alice-api-key",   # Alice's secret, stored under this name in Flyte
)

# Bob's workflow
batch_job = BatchJobTask(
    name="bob_batch_job",
    plugin_config=BatchJobConfig(timeout_seconds=60),
    inputs={"name": str},
    outputs={"result": str},
    api_key="bob-api-key",     # Bob's own secret
)
```
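
Putting the three steps together, the resolution between the task plugin and the connector can be simulated in plain Python. This is a hypothetical model of the behavior described above, not Flyte's implementation: `SECRET_STORE` stands in for Flyte's secrets store, and `resolve_secret_kwargs` turns each secret *name* in `custom["secrets"]` into a keyword argument carrying the secret *value*.

```python
import asyncio
from typing import Any, Dict, Optional

# Hypothetical secrets store: secret name -> secret value.
SECRET_STORE: Dict[str, str] = {
    "alice-api-key": "alice-token",
    "bob-api-key": "bob-token",
}


def custom_config(timeout_seconds: int, api_key: Optional[str]) -> Dict[str, Any]:
    # Same shape as BatchJobTask.custom_config: only the secret *name* is serialized.
    config: Dict[str, Any] = {"timeout_seconds": timeout_seconds}
    if api_key is not None:
        config["secrets"] = {"api_key": api_key}
    return config


def resolve_secret_kwargs(custom: Dict[str, Any]) -> Dict[str, str]:
    # Look up each secret name and pass its value under the same key.
    return {kwarg: SECRET_STORE[name] for kwarg, name in custom.get("secrets", {}).items()}


async def create(
    template_custom: Dict[str, Any],
    api_key: Optional[str] = None,
    **kwargs: Any,
) -> str:
    # Stand-in for BatchJobConnector.create: api_key now holds the value, not the name.
    timeout = template_custom["timeout_seconds"]
    return f"submitted with {api_key} (timeout {timeout}s)"


async def run(custom: Dict[str, Any]) -> str:
    return await create(custom, **resolve_secret_kwargs(custom))


print(asyncio.run(run(custom_config(60, "alice-api-key"))))  # prints submitted with alice-token (timeout 60s)
print(asyncio.run(run(custom_config(60, "bob-api-key"))))    # prints submitted with bob-token (timeout 60s)
```

Because only the name travels with the task, the same connector serves both users while each execution sees a different credential.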

See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets) for how to store secrets in Flyte.

## Related

- [`ConnectorEnvironment` API reference](https://www.union.ai/docs/v2/union/api-reference/flyte-sdk/packages/flyte.app/connectorenvironment)
- [`AsyncConnector` API reference](https://www.union.ai/docs/v2/union/api-reference/flyte-sdk/packages/flyte.connectors/asyncconnector)
- [Task plugins](https://www.union.ai/docs/v2/union/user-guide/task-configuration/task-plugins)

