diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..ff261bad --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,9 @@ +ARG VARIANT="3.9" +FROM mcr.microsoft.com/vscode/devcontainers/python:0-${VARIANT} + +USER vscode + +RUN curl -sSf https://rye.astral.sh/get | RYE_VERSION="0.44.0" RYE_INSTALL_OPTION="--yes" bash +ENV PATH=/home/vscode/.rye/shims:$PATH + +RUN echo "[[ -d .venv ]] && source .venv/bin/activate || export PATH=\$PATH" >> /home/vscode/.bashrc diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..c17fdc16 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,43 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/debian +{ + "name": "Debian", + "build": { + "dockerfile": "Dockerfile", + "context": ".." + }, + + "postStartCommand": "rye sync --all-features", + + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python" + ], + "settings": { + "terminal.integrated.shell.linux": "/bin/bash", + "python.pythonPath": ".venv/bin/python", + "python.defaultInterpreterPath": ".venv/bin/python", + "python.typeChecking": "basic", + "terminal.integrated.env.linux": { + "PATH": "/home/vscode/.rye/shims:${env:PATH}" + } + } + } + }, + "features": { + "ghcr.io/devcontainers/features/node:1": {} + } + + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Configure tool-specific properties. + // "customizations": {}, + + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..a3ffd9a8 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,104 @@ +name: CI +on: + push: + branches: + - '**' + - '!integrated/**' + - '!stl-preview-head/**' + - '!stl-preview-base/**' + - '!generated' + - '!codegen/**' + - 'codegen/stl/**' + pull_request: + branches-ignore: + - 'stl-preview-head/**' + - 'stl-preview-base/**' + +jobs: + lint: + timeout-minutes: 10 + name: lint + runs-on: ${{ github.repository == 'stainless-sdks/writer-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + if: (github.event_name == 'push' || github.event.pull_request.head.repo.fork) && (github.event_name != 'push' || github.event.head_commit.message != 'codegen metadata') + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Install Rye + run: | + curl -sSf https://rye.astral.sh/get | bash + echo "$HOME/.rye/shims" >> $GITHUB_PATH + env: + RYE_VERSION: '0.44.0' + RYE_INSTALL_OPTION: '--yes' + + - name: Install dependencies + run: rye sync --all-features + + - name: Run lints + run: ./scripts/lint + + build: + if: (github.event_name == 'push' || github.event.pull_request.head.repo.fork) && (github.event_name != 'push' || github.event.head_commit.message != 'codegen metadata') + timeout-minutes: 10 + name: build + permissions: + contents: read + id-token: write + runs-on: ${{ github.repository == 'stainless-sdks/writer-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Install Rye + run: | + curl -sSf https://rye.astral.sh/get | bash + echo "$HOME/.rye/shims" >> $GITHUB_PATH + env: + RYE_VERSION: '0.44.0' + RYE_INSTALL_OPTION: '--yes' + + - name: Install dependencies + run: rye sync --all-features + + - name: Run build + run: rye build + + - name: Get GitHub OIDC Token + if: |- + github.repository == 'stainless-sdks/writer-python' && + !startsWith(github.ref, 'refs/heads/stl/') + id: github-oidc + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 + with: + script: core.setOutput('github_token', await core.getIDToken()); + + - name: Upload tarball + if: |- + github.repository == 'stainless-sdks/writer-python' && + !startsWith(github.ref, 'refs/heads/stl/') + env: + URL: https://pkg.stainless.com/s + AUTH: ${{ steps.github-oidc.outputs.github_token }} + SHA: ${{ github.sha }} + run: ./scripts/utils/upload-artifact.sh + + test: + timeout-minutes: 10 + name: test + runs-on: ${{ github.repository == 'stainless-sdks/writer-python' && 'depot-ubuntu-24.04' || 'ubuntu-latest' }} + if: github.event_name == 'push' || github.event.pull_request.head.repo.fork + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Install Rye + run: | + curl -sSf https://rye.astral.sh/get | bash + echo "$HOME/.rye/shims" >> $GITHUB_PATH + env: + RYE_VERSION: '0.44.0' + RYE_INSTALL_OPTION: '--yes' + + - name: Bootstrap + run: ./scripts/bootstrap + + - name: Run tests + run: ./scripts/test diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml new file mode 100644 index 00000000..8a772966 --- /dev/null +++ b/.github/workflows/publish-pypi.yml @@ -0,0 +1,31 @@ +# This workflow is triggered when a GitHub release is created. +# It can also be run manually to re-publish to PyPI in case it failed for some reason. +# You can run this workflow by navigating to https://www.github.com/writer/writer-python/actions/workflows/publish-pypi.yml +name: Publish PyPI +on: + workflow_dispatch: + + release: + types: [published] + +jobs: + publish: + name: publish + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Install Rye + run: | + curl -sSf https://rye.astral.sh/get | bash + echo "$HOME/.rye/shims" >> $GITHUB_PATH + env: + RYE_VERSION: '0.44.0' + RYE_INSTALL_OPTION: '--yes' + + - name: Publish to PyPI + run: | + bash ./bin/publish-pypi + env: + PYPI_TOKEN: ${{ secrets.WRITER_PYPI_TOKEN || secrets.PYPI_TOKEN }} diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml new file mode 100644 index 00000000..e797ffe1 --- /dev/null +++ b/.github/workflows/release-doctor.yml @@ -0,0 +1,21 @@ +name: Release Doctor +on: + pull_request: + branches: + - main + workflow_dispatch: + +jobs: + release_doctor: + name: release doctor + runs-on: ubuntu-latest + if: github.repository == 'writer/writer-python' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next') + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Check release environment + run: | + bash ./bin/check-release-environment + env: + PYPI_TOKEN: ${{ secrets.WRITER_PYPI_TOKEN || secrets.PYPI_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..3824f4c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +.prism.log +.stdy.log +_dev + +__pycache__ +.mypy_cache + +dist + +.venv +.idea + +.env +.envrc +codegen.log +Brewfile.lock.json diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..43077b24 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.9.18 diff --git a/.release-please-manifest.json b/.release-please-manifest.json new file mode 100644 index 00000000..4191c889 --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "3.0.0" +} \ No newline at end of file diff --git a/.stats.yml b/.stats.yml new file mode 100644 index 00000000..45b858e1 --- /dev/null +++ b/.stats.yml @@ -0,0 +1,4 @@ +configured_endpoints: 30 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/writerai/writer-275de8f7afa2d37404ebebc082dda35e70ab94437de270b5bc6e2fdc94c9fdae.yml +openapi_spec_hash: 4d4a9ba232d19a6180e6d4a7d5566103 +config_hash: 8701b1a467238f1afdeceeb7feb1adc6 diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..5b010307 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.analysis.importFormat": "relative", +} diff --git a/Brewfile b/Brewfile new file mode 100644 index 00000000..492ca37b --- /dev/null +++ b/Brewfile @@ -0,0 +1,2 @@ +brew "rye" + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..508bbc40 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,127 @@ +## Setting up the environment + +### With Rye + +We use [Rye](https://rye.astral.sh/) to manage dependencies because it will automatically provision a Python environment with the expected Python version. To set it up, run: + +```sh +$ ./scripts/bootstrap +``` + +Or [install Rye manually](https://rye.astral.sh/guide/installation/) and run: + +```sh +$ rye sync --all-features +``` + +You can then run scripts using `rye run python script.py` or by activating the virtual environment: + +```sh +# Activate the virtual environment - https://docs.python.org/3/library/venv.html#how-venvs-work +$ source .venv/bin/activate + +# now you can omit the `rye run` prefix +$ python script.py +``` + +### Without Rye + +Alternatively if you don't want to install `Rye`, you can stick with the standard `pip` setup by ensuring you have the Python version specified in `.python-version`, create a virtual environment however you desire and then install dependencies using this command: + +```sh +$ pip install -r requirements-dev.lock +``` + +## Modifying/Adding code + +Most of the SDK is generated code. Modifications to code will be persisted between generations, but may +result in merge conflicts between manual patches and changes from the generator. The generator will never +modify the contents of the `src/writerai/lib/` and `examples/` directories. + +## Adding and running examples + +All files in the `examples/` directory are not modified by the generator and can be freely edited or added to. + +```py +# add an example to examples/.py + +#!/usr/bin/env -S rye run python +… +``` + +```sh +$ chmod +x examples/.py +# run the example against your api +$ ./examples/.py +``` + +## Using the repository from source + +If you’d like to use the repository from source, you can either install from git or link to a cloned repository: + +To install via git: + +```sh +$ pip install git+ssh://git@github.com/writer/writer-python.git +``` + +Alternatively, you can build from source and install the wheel file: + +Building this package will create two files in the `dist/` directory, a `.tar.gz` containing the source files and a `.whl` that can be used to install the package efficiently. + +To create a distributable version of the library, all you have to do is run this command: + +```sh +$ rye build +# or +$ python -m build +``` + +Then to install: + +```sh +$ pip install ./path-to-wheel-file.whl +``` + +## Running tests + +Most tests require you to [set up a mock server](https://github.com/dgellow/steady) against the OpenAPI spec to run the tests. + +```sh +$ ./scripts/mock +``` + +```sh +$ ./scripts/test +``` + +## Linting and formatting + +This repository uses [ruff](https://github.com/astral-sh/ruff) and +[black](https://github.com/psf/black) to format the code in the repository. + +To lint: + +```sh +$ ./scripts/lint +``` + +To format and fix all ruff issues automatically: + +```sh +$ ./scripts/format +``` + +## Publishing and releases + +Changes made to this repository via the automated release PR pipeline should publish to PyPI automatically. If +the changes aren't made through the automated pipeline, you may want to make releases manually. + +### Publish with a GitHub workflow + +You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/writer/writer-python/actions/workflows/publish-pypi.yml). This requires a setup organization or repository secret to be set up. + +### Publish manually + +If you need to manually release a package, you can run the `bin/publish-pypi` script with a `PYPI_TOKEN` set on +the environment. diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..a1e82cd6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 Writer + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 9c18f8ee..8be84e69 100644 --- a/README.md +++ b/README.md @@ -1 +1,559 @@ -# writer-python \ No newline at end of file +# Writer Python API library + + +[![PyPI version](https://img.shields.io/pypi/v/writer-sdk.svg?label=pypi%20(stable))](https://pypi.org/project/writer-sdk/) + +The Writer Python library provides convenient access to the Writer REST API from any Python 3.9+ +application. The library includes type definitions for all request params and response fields, +and offers both synchronous and asynchronous clients powered by [httpx](https://github.com/encode/httpx). + +It is generated with [Stainless](https://www.stainless.com/). + +## MCP Server + +Use the Writer MCP Server to enable AI assistants to interact with this API, allowing them to explore endpoints, make test requests, and use documentation to help integrate this SDK into your application. + +[![Add to Cursor](https://cursor.com/deeplink/mcp-install-dark.svg)](https://cursor.com/en-US/install-mcp?name=writer-sdk-mcp&config=eyJjb21tYW5kIjoibnB4IiwiYXJncyI6WyIteSIsIndyaXRlci1zZGstbWNwIl0sImVudiI6eyJXUklURVJfQVBJX0tFWSI6Ik15IEFQSSBLZXkifX0) +[![Install in VS Code](https://img.shields.io/badge/_-Add_to_VS_Code-blue?style=for-the-badge&logo=data:image/svg%2bxml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIGZpbGw9Im5vbmUiIHZpZXdCb3g9IjAgMCA0MCA0MCI+PHBhdGggZmlsbD0iI0VFRSIgZmlsbC1ydWxlPSJldmVub2RkIiBkPSJNMzAuMjM1IDM5Ljg4NGEyLjQ5MSAyLjQ5MSAwIDAgMS0xLjc4MS0uNzNMMTIuNyAyNC43OGwtMy40NiAyLjYyNC0zLjQwNiAyLjU4MmExLjY2NSAxLjY2NSAwIDAgMS0xLjA4Mi4zMzggMS42NjQgMS42NjQgMCAwIDEtMS4wNDYtLjQzMWwtMi4yLTJhMS42NjYgMS42NjYgMCAwIDEgMC0yLjQ2M0w3LjQ1OCAyMCA0LjY3IDE3LjQ1MyAxLjUwNyAxNC41N2ExLjY2NSAxLjY2NSAwIDAgMSAwLTIuNDYzbDIuMi0yYTEuNjY1IDEuNjY1IDAgMCAxIDIuMTMtLjA5N2w2Ljg2MyA1LjIwOUwyOC40NTIuODQ0YTIuNDg4IDIuNDg4IDAgMCAxIDEuODQxLS43MjljLjM1MS4wMDkuNjk5LjA5MSAxLjAxOS4yNDVsOC4yMzYgMy45NjFhMi41IDIuNSAwIDAgMSAxLjQxNSAyLjI1M3YuMDk5LS4wNDVWMzMuMzd2LS4wNDUuMDk1YTIuNTAxIDIuNTAxIDAgMCAxLTEuNDE2IDIuMjU3bC04LjIzNSAzLjk2MWEyLjQ5MiAyLjQ5MiAwIDAgMS0xLjA3Ny4yNDZabS43MTYtMjguOTQ3LTExLjk0OCA5LjA2MiAxMS45NTIgOS4wNjUtLjAwNC0xOC4xMjdaIi8+PC9zdmc+)](https://vscode.stainless.com/mcp/%7B%22name%22%3A%22writer-sdk-mcp%22%2C%22command%22%3A%22npx%22%2C%22args%22%3A%5B%22-y%22%2C%22writer-sdk-mcp%22%5D%2C%22env%22%3A%7B%22WRITER_API_KEY%22%3A%22My%20API%20Key%22%7D%7D) + +> Note: You may need to set environment variables in your MCP client. + +## Documentation + +The REST API documentation can be found on [dev.writer.com](https://dev.writer.com/api-guides/introduction). The full API of this library can be found in [api.md](api.md). + +## Installation + +```sh +# install from PyPI +pip install writer-sdk +``` + +## Usage + +The full API of this library can be found in [api.md](api.md). + +```python +import os +from writerai import Writer + +client = Writer( + api_key=os.environ.get("WRITER_API_KEY"), # This is the default and can be omitted +) + +chat_completion = client.chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", +) +print(chat_completion.id) +``` + +While you can provide an `api_key` keyword argument, +we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) +to add `WRITER_API_KEY="My API Key"` to your `.env` file +so that your API Key is not stored in source control. + +## Async usage + +Simply import `AsyncWriter` instead of `Writer` and use `await` with each API call: + +```python +import os +import asyncio +from writerai import AsyncWriter + +client = AsyncWriter( + api_key=os.environ.get("WRITER_API_KEY"), # This is the default and can be omitted +) + + +async def main() -> None: + chat_completion = await client.chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", + ) + print(chat_completion.id) + + +asyncio.run(main()) +``` + +Functionality between the synchronous and asynchronous clients is otherwise identical. + +### With aiohttp + +By default, the async client uses `httpx` for HTTP requests. However, for improved concurrency performance you may also use `aiohttp` as the HTTP backend. + +You can enable this by installing `aiohttp`: + +```sh +# install from PyPI +pip install writer-sdk[aiohttp] +``` + +Then you can enable it by instantiating the client with `http_client=DefaultAioHttpClient()`: + +```python +import os +import asyncio +from writerai import DefaultAioHttpClient +from writerai import AsyncWriter + + +async def main() -> None: + async with AsyncWriter( + api_key=os.environ.get("WRITER_API_KEY"), # This is the default and can be omitted + http_client=DefaultAioHttpClient(), + ) as client: + chat_completion = await client.chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", + ) + print(chat_completion.id) + + +asyncio.run(main()) +``` + +## Streaming responses + +We provide support for streaming responses using Server Side Events (SSE). + +```python +from writerai import Writer + +client = Writer() + +stream = client.chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", + stream=True, +) +for chat_completion in stream: + print(chat_completion.id) +``` + +The async client uses the exact same interface. + +```python +from writerai import AsyncWriter + +client = AsyncWriter() + +stream = await client.chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", + stream=True, +) +async for chat_completion in stream: + print(chat_completion.id) +``` + +## Using types + +Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like: + +- Serializing back into JSON, `model.to_json()` +- Converting to a dictionary, `model.to_dict()` + +Typed requests and responses provide autocomplete and documentation within your editor. If you would like to see type errors in VS Code to help catch bugs earlier, set `python.analysis.typeCheckingMode` to `basic`. + +## Pagination + +List methods in the Writer API are paginated. + +This library provides auto-paginating iterators with each list response, so you do not have to request successive pages manually: + +```python +from writerai import Writer + +client = Writer() + +all_graphs = [] +# Automatically fetches more pages as needed. +for graph in client.graphs.list(): + # Do something with graph here + all_graphs.append(graph) +print(all_graphs) +``` + +Or, asynchronously: + +```python +import asyncio +from writerai import AsyncWriter + +client = AsyncWriter() + + +async def main() -> None: + all_graphs = [] + # Iterate through items across all pages, issuing requests as needed. + async for graph in client.graphs.list(): + all_graphs.append(graph) + print(all_graphs) + + +asyncio.run(main()) +``` + +Alternatively, you can use the `.has_next_page()`, `.next_page_info()`, or `.get_next_page()` methods for more granular control working with pages: + +```python +first_page = await client.graphs.list() +if first_page.has_next_page(): + print(f"will fetch next page using these details: {first_page.next_page_info()}") + next_page = await first_page.get_next_page() + print(f"number of items we just fetched: {len(next_page.data)}") + +# Remove `await` for non-async usage. +``` + +Or just work directly with the returned data: + +```python +first_page = await client.graphs.list() + +print(f"next page cursor: {first_page.after}") # => "next page cursor: ..." +for graph in first_page.data: + print(graph.id) + +# Remove `await` for non-async usage. +``` + +## Nested params + +Nested parameters are dictionaries, typed using `TypedDict`, for example: + +```python +from writerai import Writer + +client = Writer() + +chat_completion = client.chat.chat( + messages=[{"role": "user"}], + model="model", + response_format={"type": "text"}, +) +print(chat_completion.response_format) +``` + +## Handling errors + +When the library is unable to connect to the API (for example, due to network connection problems or a timeout), a subclass of `writerai.APIConnectionError` is raised. + +When the API returns a non-success status code (that is, 4xx or 5xx +response), a subclass of `writerai.APIStatusError` is raised, containing `status_code` and `response` properties. + +All errors inherit from `writerai.APIError`. + +```python +import writerai +from writerai import Writer + +client = Writer() + +try: + client.chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", + ) +except writerai.APIConnectionError as e: + print("The server could not be reached") + print(e.__cause__) # an underlying Exception, likely raised within httpx. +except writerai.RateLimitError as e: + print("A 429 status code was received; we should back off a bit.") +except writerai.APIStatusError as e: + print("Another non-200-range status code was received") + print(e.status_code) + print(e.response) +``` + +Error codes are as follows: + +| Status Code | Error Type | +| ----------- | -------------------------- | +| 400 | `BadRequestError` | +| 401 | `AuthenticationError` | +| 403 | `PermissionDeniedError` | +| 404 | `NotFoundError` | +| 422 | `UnprocessableEntityError` | +| 429 | `RateLimitError` | +| >=500 | `InternalServerError` | +| N/A | `APIConnectionError` | + +### Retries + +Certain errors are automatically retried 7 times by default, with a short exponential backoff. +Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict, +429 Rate Limit, and >=500 Internal errors are all retried by default. + +You can use the `max_retries` option to configure or disable retry settings: + +```python +from writerai import Writer + +# Configure the default for all requests: +client = Writer( + # default is 2 + max_retries=0, +) + +# Or, configure per-request: +client.with_options(max_retries=5).chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", +) +``` + +### Timeouts + +By default requests time out after 3 minutes. You can configure this with a `timeout` option, +which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/timeouts/#fine-tuning-the-configuration) object: + +```python +from writerai import Writer + +# Configure the default for all requests: +client = Writer( + # 20 seconds (default is 3 minutes) + timeout=20.0, +) + +# More granular control: +client = Writer( + timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0), +) + +# Override per-request: +client.with_options(timeout=5.0).chat.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", +) +``` + +On timeout, an `APITimeoutError` is thrown. + +Note that requests that time out are [retried twice by default](#retries). + +## Advanced + +### Logging + +We use the standard library [`logging`](https://docs.python.org/3/library/logging.html) module. + +You can enable logging by setting the environment variable `WRITER_LOG` to `info`. + +```shell +$ export WRITER_LOG=info +``` + +Or to `debug` for more verbose logging. + +### How to tell whether `None` means `null` or missing + +In an API response, a field may be explicitly `null`, or missing entirely; in either case, its value is `None` in this library. You can differentiate the two cases with `.model_fields_set`: + +```py +if response.my_field is None: + if 'my_field' not in response.model_fields_set: + print('Got json like {}, without a "my_field" key present at all.') + else: + print('Got json like {"my_field": null}.') +``` + +### Accessing raw response data (e.g. headers) + +The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call, e.g., + +```py +from writerai import Writer + +client = Writer() +response = client.chat.with_raw_response.chat( + messages=[{ + "content": "Write a haiku about programming", + "role": "user", + }], + model="palmyra-x5", +) +print(response.headers.get('X-My-Header')) + +chat = response.parse() # get the object that `chat.chat()` would have returned +print(chat.id) +``` + +These methods return an [`APIResponse`](https://github.com/writer/writer-python/tree/main/src/writerai/_response.py) object. + +The async client returns an [`AsyncAPIResponse`](https://github.com/writer/writer-python/tree/main/src/writerai/_response.py) with the same structure, the only difference being `await`able methods for reading the response content. + +#### `.with_streaming_response` + +The above interface eagerly reads the full response body when you make the request, which may not always be what you want. + +To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods. + +```python +with client.chat.with_streaming_response.chat( + messages=[ + { + "content": "Write a haiku about programming", + "role": "user", + } + ], + model="palmyra-x5", +) as response: + print(response.headers.get("X-My-Header")) + + for line in response.iter_lines(): + print(line) +``` + +The context manager is required so that the response will reliably be closed. + +### Making custom/undocumented requests + +This library is typed for convenient access to the documented API. + +If you need to access undocumented endpoints, params, or response properties, the library can still be used. + +#### Undocumented endpoints + +To make requests to undocumented endpoints, you can make requests using `client.get`, `client.post`, and other +http verbs. Options on the client will be respected (such as retries) when making this request. + +```py +import httpx + +response = client.post( + "/foo", + cast_to=httpx.Response, + body={"my_param": True}, +) + +print(response.headers.get("x-foo")) +``` + +#### Undocumented request params + +If you want to explicitly send an extra param, you can do so with the `extra_query`, `extra_body`, and `extra_headers` request +options. + +#### Undocumented response properties + +To access undocumented response properties, you can access the extra fields like `response.unknown_prop`. You +can also get all the extra fields on the Pydantic model as a dict with +[`response.model_extra`](https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_extra). + +### Configuring the HTTP client + +You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including: + +- Support for [proxies](https://www.python-httpx.org/advanced/proxies/) +- Custom [transports](https://www.python-httpx.org/advanced/transports/) +- Additional [advanced](https://www.python-httpx.org/advanced/clients/) functionality + +```python +import httpx +from writerai import Writer, DefaultHttpxClient + +client = Writer( + # Or use the `WRITER_BASE_URL` env var + base_url="http://my.test.server.example.com:8083", + http_client=DefaultHttpxClient( + proxy="http://my.test.proxy.example.com", + transport=httpx.HTTPTransport(local_address="0.0.0.0"), + ), +) +``` + +You can also customize the client on a per-request basis by using `with_options()`: + +```python +client.with_options(http_client=DefaultHttpxClient(...)) +``` + +### Managing HTTP resources + +By default the library closes underlying HTTP connections whenever the client is [garbage collected](https://docs.python.org/3/reference/datamodel.html#object.__del__). You can manually close the client using the `.close()` method if desired, or with a context manager that closes when exiting. + +```py +from writerai import Writer + +with Writer() as client: + # make requests here + ... + +# HTTP client is now closed +``` + +## Versioning + +This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions: + +1. Changes that only affect static types, without breaking runtime behavior. +2. Changes to library internals which are technically public but not intended or documented for external use. _(Please open a GitHub issue to let us know if you are relying on such internals.)_ +3. Changes that we do not expect to impact the vast majority of users in practice. + +We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience. + +We are keen for your feedback; please open an [issue](https://www.github.com/writer/writer-python/issues) with questions, bugs, or suggestions. + +### Determining the installed version + +If you've upgraded to the latest version but aren't seeing any new features you were expecting then your python environment is likely still using an older version. + +You can determine the version that is being used at runtime with: + +```py +import writerai +print(writerai.__version__) +``` + +## Requirements + +Python 3.9 or higher. + +## Contributing + +See [the contributing documentation](./CONTRIBUTING.md). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..edf66351 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,27 @@ +# Security Policy + +## Reporting Security Issues + +This SDK is generated by [Stainless Software Inc](http://stainless.com). Stainless takes security seriously, and encourages you to report any security vulnerability promptly so that appropriate action can be taken. + +To report a security issue, please contact the Stainless team at security@stainless.com. + +## Responsible Disclosure + +We appreciate the efforts of security researchers and individuals who help us maintain the security of +SDKs we generate. If you believe you have found a security vulnerability, please adhere to responsible +disclosure practices by allowing us a reasonable amount of time to investigate and address the issue +before making any information public. + +## Reporting Non-SDK Related Security Issues + +If you encounter security issues that are not directly related to SDKs but pertain to the services +or products provided by Writer, please follow the respective company's security reporting guidelines. + +### Writer Terms and Policies + +Please contact dev-feedback@writer.com for any questions or concerns regarding the security of our services. + +--- + +Thank you for helping us keep the SDKs and systems they interact with secure. diff --git a/api.md b/api.md new file mode 100644 index 00000000..7377b2f1 --- /dev/null +++ b/api.md @@ -0,0 +1,195 @@ +# Shared Types + +```python +from writerai.types import ( + ErrorMessage, + ErrorObject, + FunctionDefinition, + FunctionParams, + GraphData, + Logprobs, + LogprobsToken, + Source, + ToolCall, + ToolCallStreaming, + ToolChoiceJsonObject, + ToolChoiceString, + ToolParam, +) +``` + +# Applications + +Types: + +```python +from writerai.types import ( + ApplicationGenerateContentChunk, + ApplicationGenerateContentResponse, + ApplicationRetrieveResponse, + ApplicationListResponse, +) +``` + +Methods: + +- client.applications.retrieve(application_id) -> ApplicationRetrieveResponse +- client.applications.list(\*\*params) -> SyncCursorPage[ApplicationListResponse] +- client.applications.generate_content(application_id, \*\*params) -> ApplicationGenerateContentResponse + +## Jobs + +Types: + +```python +from writerai.types.applications import ( + ApplicationGenerateAsyncResponse, + ApplicationJobsListResponse, + JobCreateResponse, + JobRetryResponse, +) +``` + +Methods: + +- client.applications.jobs.create(application_id, \*\*params) -> JobCreateResponse +- client.applications.jobs.retrieve(job_id) -> ApplicationGenerateAsyncResponse +- client.applications.jobs.list(application_id, \*\*params) -> SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse] +- client.applications.jobs.retry(job_id) -> JobRetryResponse + +## Graphs + +Types: + +```python +from writerai.types.applications import ApplicationGraphsResponse +``` + +Methods: + +- client.applications.graphs.update(application_id, \*\*params) -> ApplicationGraphsResponse +- client.applications.graphs.list(application_id) -> ApplicationGraphsResponse + +# Chat + +Types: + +```python +from writerai.types import ( + ChatCompletion, + ChatCompletionChoice, + ChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionParams, + ChatCompletionUsage, +) +``` + +Methods: + +- client.chat.chat(\*\*params) -> ChatCompletion + +# Completions + +Types: + +```python +from writerai.types import Completion, CompletionChunk, CompletionParams +``` + +Methods: + +- client.completions.create(\*\*params) -> Completion + +# Models + +Types: + +```python +from writerai.types import ModelListResponse +``` + +Methods: + +- client.models.list() -> ModelListResponse + +# Graphs + +Types: + +```python +from writerai.types import ( + Graph, + Question, + QuestionResponseChunk, + GraphCreateResponse, + GraphUpdateResponse, + GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, +) +``` + +Methods: + +- client.graphs.create(\*\*params) -> GraphCreateResponse +- client.graphs.retrieve(graph_id) -> Graph +- client.graphs.update(graph_id, \*\*params) -> GraphUpdateResponse +- client.graphs.list(\*\*params) -> SyncCursorPage[Graph] +- client.graphs.delete(graph_id) -> GraphDeleteResponse +- client.graphs.add_file_to_graph(graph_id, \*\*params) -> File +- client.graphs.question(\*\*params) -> Question +- client.graphs.remove_file_from_graph(file_id, \*, graph_id) -> GraphRemoveFileFromGraphResponse + +# Files + +Types: + +```python +from writerai.types import File, FileDeleteResponse, FileRetryResponse +``` + +Methods: + +- client.files.retrieve(file_id) -> File +- client.files.list(\*\*params) -> SyncCursorPage[File] +- client.files.delete(file_id) -> FileDeleteResponse +- client.files.download(file_id) -> BinaryAPIResponse +- client.files.retry(\*\*params) -> FileRetryResponse +- client.files.upload(content, \*\*params) -> File + +# Tools + +Types: + +```python +from writerai.types import ToolParsePdfResponse, ToolWebSearchResponse +``` + +Methods: + +- client.tools.parse_pdf(file_id, \*\*params) -> ToolParsePdfResponse +- client.tools.web_search(\*\*params) -> ToolWebSearchResponse + +# Translation + +Types: + +```python +from writerai.types import TranslationRequest, TranslationResponse +``` + +Methods: + +- client.translation.translate(\*\*params) -> TranslationResponse + +# Vision + +Types: + +```python +from writerai.types import VisionRequest, VisionResponse +``` + +Methods: + +- client.vision.analyze(\*\*params) -> VisionResponse diff --git a/bin/check-release-environment b/bin/check-release-environment new file mode 100644 index 00000000..b845b0f4 --- /dev/null +++ b/bin/check-release-environment @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +errors=() + +if [ -z "${PYPI_TOKEN}" ]; then + errors+=("The PYPI_TOKEN secret has not been set. Please set it in either this repository's secrets or your organization secrets.") +fi + +lenErrors=${#errors[@]} + +if [[ lenErrors -gt 0 ]]; then + echo -e "Found the following errors in the release environment:\n" + + for error in "${errors[@]}"; do + echo -e "- $error\n" + done + + exit 1 +fi + +echo "The environment is ready to push releases!" diff --git a/bin/publish-pypi b/bin/publish-pypi new file mode 100644 index 00000000..826054e9 --- /dev/null +++ b/bin/publish-pypi @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -eux +mkdir -p dist +rye build --clean +rye publish --yes --token=$PYPI_TOKEN diff --git a/examples/.keep b/examples/.keep new file mode 100644 index 00000000..d8c73e93 --- /dev/null +++ b/examples/.keep @@ -0,0 +1,4 @@ +File generated from our OpenAPI spec by Stainless. + +This directory can be used to store example files demonstrating usage of this SDK. +It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 00000000..53bca7ff --- /dev/null +++ b/noxfile.py @@ -0,0 +1,9 @@ +import nox + + +@nox.session(reuse_venv=True, name="test-pydantic-v1") +def test_pydantic_v1(session: nox.Session) -> None: + session.install("-r", "requirements-dev.lock") + session.install("pydantic<2") + + session.run("pytest", "--showlocals", "--ignore=tests/functional", *session.posargs) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..229c1a83 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,269 @@ +[project] +name = "writer-sdk" +version = "3.0.0" +description = "The official Python library for the writer API" +dynamic = ["readme"] +license = "Apache-2.0" +authors = [ +{ name = "Writer", email = "dev-feedback@writer.com" }, +] + +dependencies = [ + "httpx>=0.23.0, <1", + "pydantic>=1.9.0, <3", + "typing-extensions>=4.14, <5", + "anyio>=3.5.0, <5", + "distro>=1.7.0, <2", + "sniffio", +] + +requires-python = ">= 3.9" +classifiers = [ + "Typing :: Typed", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Operating System :: OS Independent", + "Operating System :: POSIX", + "Operating System :: MacOS", + "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", + "Topic :: Software Development :: Libraries :: Python Modules", + "License :: OSI Approved :: Apache Software License" +] + +[project.urls] +Homepage = "https://github.com/writer/writer-python" +Repository = "https://github.com/writer/writer-python" + +[project.optional-dependencies] +aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.9"] + +[tool.rye] +managed = true +# version pins are in requirements-dev.lock +dev-dependencies = [ + "pyright==1.1.399", + "mypy==1.17", + "respx", + "pytest", + "pytest-asyncio", + "ruff", + "time-machine", + "nox", + "dirty-equals>=0.6.0", + "importlib-metadata>=6.7.0", + "rich>=13.7.1", + "pytest-xdist>=3.6.1", +] + +[tool.rye.scripts] +format = { chain = [ + "format:ruff", + "format:docs", + "fix:ruff", + # run formatting again to fix any inconsistencies when imports are stripped + "format:ruff", +]} +"format:docs" = "bash -c 'python scripts/utils/ruffen-docs.py README.md $(find . -type f -name api.md)'" +"format:ruff" = "ruff format" + +"lint" = { chain = [ + "check:ruff", + "typecheck", + "check:importable", +]} +"check:ruff" = "ruff check ." +"fix:ruff" = "ruff check --fix ." + +"check:importable" = "python -c 'import writerai'" + +typecheck = { chain = [ + "typecheck:pyright", + "typecheck:mypy" +]} +"typecheck:pyright" = "pyright" +"typecheck:verify-types" = "pyright --verifytypes writerai --ignoreexternal" +"typecheck:mypy" = "mypy ." + +[build-system] +requires = ["hatchling==1.26.3", "hatch-fancy-pypi-readme"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = [ + "src/*" +] + +[tool.hatch.build.targets.wheel] +packages = ["src/writerai"] + +[tool.hatch.build.targets.sdist] +# Basically everything except hidden files/directories (such as .github, .devcontainers, .python-version, etc) +include = [ + "/*.toml", + "/*.json", + "/*.lock", + "/*.md", + "/mypy.ini", + "/noxfile.py", + "bin/*", + "examples/*", + "src/*", + "tests/*", +] + +[tool.hatch.metadata.hooks.fancy-pypi-readme] +content-type = "text/markdown" + +[[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]] +path = "README.md" + +[[tool.hatch.metadata.hooks.fancy-pypi-readme.substitutions]] +# replace relative links with absolute links +pattern = '\[(.+?)\]\(((?!https?://)\S+?)\)' +replacement = '[\1](https://github.com/writer/writer-python/tree/main/\g<2>)' + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "--tb=short -n auto" +xfail_strict = true +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "session" +filterwarnings = [ + "error" +] + +[tool.pyright] +# this enables practically every flag given by pyright. +# there are a couple of flags that are still disabled by +# default in strict mode as they are experimental and niche. +typeCheckingMode = "strict" +pythonVersion = "3.9" + +exclude = [ + "_dev", + ".venv", + ".nox", + ".git", +] + +reportImplicitOverride = true +reportOverlappingOverload = false + +reportImportCycles = false +reportPrivateUsage = false + +[tool.mypy] +pretty = true +show_error_codes = true + +# Exclude _files.py because mypy isn't smart enough to apply +# the correct type narrowing and as this is an internal module +# it's fine to just use Pyright. +# +# We also exclude our `tests` as mypy doesn't always infer +# types correctly and Pyright will still catch any type errors. +exclude = ["src/writerai/_files.py", "_dev/.*.py", "tests/.*"] + +strict_equality = true +implicit_reexport = true +check_untyped_defs = true +no_implicit_optional = true + +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true + +# Turn these options off as it could cause conflicts +# with the Pyright options. +warn_unused_ignores = false +warn_redundant_casts = false + +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_subclassing_any = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +cache_fine_grained = true + +# By default, mypy reports an error if you assign a value to the result +# of a function call that doesn't return anything. We do this in our test +# cases: +# ``` +# result = ... +# assert result is None +# ``` +# Changing this codegen to make mypy happy would increase complexity +# and would not be worth it. +disable_error_code = "func-returns-value,overload-cannot-match" + +# https://github.com/python/mypy/issues/12162 +[[tool.mypy.overrides]] +module = "black.files.*" +ignore_errors = true +ignore_missing_imports = true + + +[tool.ruff] +line-length = 120 +output-format = "grouped" +target-version = "py38" + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +select = [ + # isort + "I", + # bugbear rules + "B", + # remove unused imports + "F401", + # check for missing future annotations + "FA102", + # bare except statements + "E722", + # unused arguments + "ARG", + # print statements + "T201", + "T203", + # misuse of typing.TYPE_CHECKING + "TC004", + # import rules + "TID251", +] +ignore = [ + # mutable defaults + "B006", +] +unfixable = [ + # disable auto fix for print statements + "T201", + "T203", +] + +extend-safe-fixes = ["FA102"] + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"functools.lru_cache".msg = "This function does not retain type information for the wrapped function's arguments; The `lru_cache` function from `_utils` should be used instead" + +[tool.ruff.lint.isort] +length-sort = true +length-sort-straight = true +combine-as-imports = true +extra-standard-library = ["typing_extensions"] +known-first-party = ["writerai", "tests"] + +[tool.ruff.lint.per-file-ignores] +"bin/**.py" = ["T201", "T203"] +"scripts/**.py" = ["T201", "T203"] +"tests/**.py" = ["T201", "T203"] +"examples/**.py" = ["T201", "T203"] diff --git a/release-please-config.json b/release-please-config.json new file mode 100644 index 00000000..6b62e34d --- /dev/null +++ b/release-please-config.json @@ -0,0 +1,66 @@ +{ + "packages": { + ".": {} + }, + "$schema": "https://raw.githubusercontent.com/stainless-api/release-please/main/schemas/config.json", + "include-v-in-tag": true, + "include-component-in-tag": false, + "versioning": "prerelease", + "prerelease": true, + "bump-minor-pre-major": true, + "bump-patch-for-minor-pre-major": false, + "pull-request-header": "Automated Release PR", + "pull-request-title-pattern": "release: ${version}", + "changelog-sections": [ + { + "type": "feat", + "section": "Features" + }, + { + "type": "fix", + "section": "Bug Fixes" + }, + { + "type": "perf", + "section": "Performance Improvements" + }, + { + "type": "revert", + "section": "Reverts" + }, + { + "type": "chore", + "section": "Chores" + }, + { + "type": "docs", + "section": "Documentation" + }, + { + "type": "style", + "section": "Styles" + }, + { + "type": "refactor", + "section": "Refactors" + }, + { + "type": "test", + "section": "Tests", + "hidden": true + }, + { + "type": "build", + "section": "Build System" + }, + { + "type": "ci", + "section": "Continuous Integration", + "hidden": true + } + ], + "release-type": "python", + "extra-files": [ + "src/writerai/_version.py" + ] +} \ No newline at end of file diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 00000000..96b012f6 --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,149 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: true +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via httpx-aiohttp + # via writer-sdk +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +anyio==4.12.1 + # via httpx + # via writer-sdk +argcomplete==3.6.3 + # via nox +async-timeout==5.0.1 + # via aiohttp +attrs==25.4.0 + # via aiohttp + # via nox +backports-asyncio-runner==1.2.0 + # via pytest-asyncio +certifi==2026.1.4 + # via httpcore + # via httpx +colorlog==6.10.1 + # via nox +dependency-groups==1.3.1 + # via nox +dirty-equals==0.11 +distlib==0.4.0 + # via virtualenv +distro==1.9.0 + # via writer-sdk +exceptiongroup==1.3.1 + # via anyio + # via pytest +execnet==2.1.2 + # via pytest-xdist +filelock==3.19.1 + # via virtualenv +frozenlist==1.8.0 + # via aiohttp + # via aiosignal +h11==0.16.0 + # via httpcore +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via httpx-aiohttp + # via respx + # via writer-sdk +httpx-aiohttp==0.1.12 + # via writer-sdk +humanize==4.13.0 + # via nox +idna==3.11 + # via anyio + # via httpx + # via yarl +importlib-metadata==8.7.1 +iniconfig==2.1.0 + # via pytest +markdown-it-py==3.0.0 + # via rich +mdurl==0.1.2 + # via markdown-it-py +multidict==6.7.0 + # via aiohttp + # via yarl +mypy==1.17.0 +mypy-extensions==1.1.0 + # via mypy +nodeenv==1.10.0 + # via pyright +nox==2025.11.12 +packaging==25.0 + # via dependency-groups + # via nox + # via pytest +pathspec==1.0.3 + # via mypy +platformdirs==4.4.0 + # via virtualenv +pluggy==1.6.0 + # via pytest +propcache==0.4.1 + # via aiohttp + # via yarl +pydantic==2.12.5 + # via writer-sdk +pydantic-core==2.41.5 + # via pydantic +pygments==2.19.2 + # via pytest + # via rich +pyright==1.1.399 +pytest==8.4.2 + # via pytest-asyncio + # via pytest-xdist +pytest-asyncio==1.2.0 +pytest-xdist==3.8.0 +python-dateutil==2.9.0.post0 + # via time-machine +respx==0.22.0 +rich==14.2.0 +ruff==0.14.13 +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via writer-sdk +time-machine==2.19.0 +tomli==2.4.0 + # via dependency-groups + # via mypy + # via nox + # via pytest +typing-extensions==4.15.0 + # via aiosignal + # via anyio + # via exceptiongroup + # via multidict + # via mypy + # via pydantic + # via pydantic-core + # via pyright + # via pytest-asyncio + # via typing-inspection + # via virtualenv + # via writer-sdk +typing-inspection==0.4.2 + # via pydantic +virtualenv==20.36.1 + # via nox +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 00000000..da1e220e --- /dev/null +++ b/requirements.lock @@ -0,0 +1,76 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: true +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via httpx-aiohttp + # via writer-sdk +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +anyio==4.12.1 + # via httpx + # via writer-sdk +async-timeout==5.0.1 + # via aiohttp +attrs==25.4.0 + # via aiohttp +certifi==2026.1.4 + # via httpcore + # via httpx +distro==1.9.0 + # via writer-sdk +exceptiongroup==1.3.1 + # via anyio +frozenlist==1.8.0 + # via aiohttp + # via aiosignal +h11==0.16.0 + # via httpcore +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via httpx-aiohttp + # via writer-sdk +httpx-aiohttp==0.1.12 + # via writer-sdk +idna==3.11 + # via anyio + # via httpx + # via yarl +multidict==6.7.0 + # via aiohttp + # via yarl +propcache==0.4.1 + # via aiohttp + # via yarl +pydantic==2.12.5 + # via writer-sdk +pydantic-core==2.41.5 + # via pydantic +sniffio==1.3.1 + # via writer-sdk +typing-extensions==4.15.0 + # via aiosignal + # via anyio + # via exceptiongroup + # via multidict + # via pydantic + # via pydantic-core + # via typing-inspection + # via writer-sdk +typing-inspection==0.4.2 + # via pydantic +yarl==1.22.0 + # via aiohttp diff --git a/scripts/bootstrap b/scripts/bootstrap new file mode 100755 index 00000000..fe8451e4 --- /dev/null +++ b/scripts/bootstrap @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [ -f "Brewfile" ] && [ "$(uname -s)" = "Darwin" ] && [ "${SKIP_BREW:-}" != "1" ] && [ -t 0 ]; then + brew bundle check >/dev/null 2>&1 || { + echo -n "==> Install Homebrew dependencies? (y/N): " + read -r response + case "$response" in + [yY][eE][sS]|[yY]) + brew bundle + ;; + *) + ;; + esac + echo + } +fi + +echo "==> Installing Python dependencies…" + +# experimental uv support makes installations significantly faster +rye config --set-bool behavior.use-uv=true + +rye sync --all-features diff --git a/scripts/format b/scripts/format new file mode 100755 index 00000000..667ec2d7 --- /dev/null +++ b/scripts/format @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +echo "==> Running formatters" +rye run format diff --git a/scripts/lint b/scripts/lint new file mode 100755 index 00000000..6ac647f4 --- /dev/null +++ b/scripts/lint @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [ "$1" = "--fix" ]; then + echo "==> Running lints with --fix" + rye run fix:ruff +else + echo "==> Running lints" + rye run lint +fi + +echo "==> Making sure it imports" +rye run python -c 'import writerai' diff --git a/scripts/mock b/scripts/mock new file mode 100755 index 00000000..feebe5ed --- /dev/null +++ b/scripts/mock @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [[ -n "$1" && "$1" != '--'* ]]; then + URL="$1" + shift +else + URL="$(grep 'openapi_spec_url' .stats.yml | cut -d' ' -f2)" +fi + +# Check if the URL is empty +if [ -z "$URL" ]; then + echo "Error: No OpenAPI spec path/url provided or found in .stats.yml" + exit 1 +fi + +echo "==> Starting mock server with URL ${URL}" + +# Run steady mock on the given spec +if [ "$1" == "--daemon" ]; then + # Pre-install the package so the download doesn't eat into the startup timeout + npm exec --package=@stdy/cli@0.22.1 -- steady --version + + npm exec --package=@stdy/cli@0.22.1 -- steady --host 127.0.0.1 -p 4010 --validator-query-array-format=comma --validator-form-array-format=comma --validator-query-object-format=brackets --validator-form-object-format=brackets "$URL" &> .stdy.log & + + # Wait for server to come online via health endpoint (max 30s) + echo -n "Waiting for server" + attempts=0 + while ! curl --silent --fail "http://127.0.0.1:4010/_x-steady/health" >/dev/null 2>&1; do + if ! kill -0 $! 2>/dev/null; then + echo + cat .stdy.log + exit 1 + fi + attempts=$((attempts + 1)) + if [ "$attempts" -ge 300 ]; then + echo + echo "Timed out waiting for Steady server to start" + cat .stdy.log + exit 1 + fi + echo -n "." + sleep 0.1 + done + + echo +else + npm exec --package=@stdy/cli@0.22.1 -- steady --host 127.0.0.1 -p 4010 --validator-query-array-format=comma --validator-form-array-format=comma --validator-query-object-format=brackets --validator-form-object-format=brackets "$URL" +fi diff --git a/scripts/test b/scripts/test new file mode 100755 index 00000000..19acc916 --- /dev/null +++ b/scripts/test @@ -0,0 +1,61 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[0;33m' +NC='\033[0m' # No Color + +function steady_is_running() { + curl --silent "http://127.0.0.1:4010/_x-steady/health" >/dev/null 2>&1 +} + +kill_server_on_port() { + pids=$(lsof -t -i tcp:"$1" || echo "") + if [ "$pids" != "" ]; then + kill "$pids" + echo "Stopped $pids." + fi +} + +function is_overriding_api_base_url() { + [ -n "$TEST_API_BASE_URL" ] +} + +if ! is_overriding_api_base_url && ! steady_is_running ; then + # When we exit this script, make sure to kill the background mock server process + trap 'kill_server_on_port 4010' EXIT + + # Start the dev server + ./scripts/mock --daemon +fi + +if is_overriding_api_base_url ; then + echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" + echo +elif ! steady_is_running ; then + echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Steady server" + echo -e "running against your OpenAPI spec." + echo + echo -e "To run the server, pass in the path or url of your OpenAPI" + echo -e "spec to the steady command:" + echo + echo -e " \$ ${YELLOW}npm exec --package=@stdy/cli@0.22.1 -- steady path/to/your.openapi.yml --host 127.0.0.1 -p 4010 --validator-query-array-format=comma --validator-form-array-format=comma --validator-query-object-format=brackets --validator-form-object-format=brackets${NC}" + echo + + exit 1 +else + echo -e "${GREEN}✔ Mock steady server is running with your OpenAPI spec${NC}" + echo +fi + +export DEFER_PYDANTIC_BUILD=false + +echo "==> Running tests" +rye run pytest "$@" + +echo "==> Running Pydantic v1 tests" +rye run nox -s test-pydantic-v1 -- "$@" diff --git a/scripts/utils/ruffen-docs.py b/scripts/utils/ruffen-docs.py new file mode 100644 index 00000000..0cf2bd2f --- /dev/null +++ b/scripts/utils/ruffen-docs.py @@ -0,0 +1,167 @@ +# fork of https://github.com/asottile/blacken-docs adapted for ruff +from __future__ import annotations + +import re +import sys +import argparse +import textwrap +import contextlib +import subprocess +from typing import Match, Optional, Sequence, Generator, NamedTuple, cast + +MD_RE = re.compile( + r"(?P^(?P *)```\s*python\n)" r"(?P.*?)" r"(?P^(?P=indent)```\s*$)", + re.DOTALL | re.MULTILINE, +) +MD_PYCON_RE = re.compile( + r"(?P^(?P *)```\s*pycon\n)" r"(?P.*?)" r"(?P^(?P=indent)```.*$)", + re.DOTALL | re.MULTILINE, +) +PYCON_PREFIX = ">>> " +PYCON_CONTINUATION_PREFIX = "..." +PYCON_CONTINUATION_RE = re.compile( + rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)", +) +DEFAULT_LINE_LENGTH = 100 + + +class CodeBlockError(NamedTuple): + offset: int + exc: Exception + + +def format_str( + src: str, +) -> tuple[str, Sequence[CodeBlockError]]: + errors: list[CodeBlockError] = [] + + @contextlib.contextmanager + def _collect_error(match: Match[str]) -> Generator[None, None, None]: + try: + yield + except Exception as e: + errors.append(CodeBlockError(match.start(), e)) + + def _md_match(match: Match[str]) -> str: + code = textwrap.dedent(match["code"]) + with _collect_error(match): + code = format_code_block(code) + code = textwrap.indent(code, match["indent"]) + return f"{match['before']}{code}{match['after']}" + + def _pycon_match(match: Match[str]) -> str: + code = "" + fragment = cast(Optional[str], None) + + def finish_fragment() -> None: + nonlocal code + nonlocal fragment + + if fragment is not None: + with _collect_error(match): + fragment = format_code_block(fragment) + fragment_lines = fragment.splitlines() + code += f"{PYCON_PREFIX}{fragment_lines[0]}\n" + for line in fragment_lines[1:]: + # Skip blank lines to handle Black adding a blank above + # functions within blocks. A blank line would end the REPL + # continuation prompt. + # + # >>> if True: + # ... def f(): + # ... pass + # ... + if line: + code += f"{PYCON_CONTINUATION_PREFIX} {line}\n" + if fragment_lines[-1].startswith(" "): + code += f"{PYCON_CONTINUATION_PREFIX}\n" + fragment = None + + indentation = None + for line in match["code"].splitlines(): + orig_line, line = line, line.lstrip() + if indentation is None and line: + indentation = len(orig_line) - len(line) + continuation_match = PYCON_CONTINUATION_RE.match(line) + if continuation_match and fragment is not None: + fragment += line[continuation_match.end() :] + "\n" + else: + finish_fragment() + if line.startswith(PYCON_PREFIX): + fragment = line[len(PYCON_PREFIX) :] + "\n" + else: + code += orig_line[indentation:] + "\n" + finish_fragment() + return code + + def _md_pycon_match(match: Match[str]) -> str: + code = _pycon_match(match) + code = textwrap.indent(code, match["indent"]) + return f"{match['before']}{code}{match['after']}" + + src = MD_RE.sub(_md_match, src) + src = MD_PYCON_RE.sub(_md_pycon_match, src) + return src, errors + + +def format_code_block(code: str) -> str: + return subprocess.check_output( + [ + sys.executable, + "-m", + "ruff", + "format", + "--stdin-filename=script.py", + f"--line-length={DEFAULT_LINE_LENGTH}", + ], + encoding="utf-8", + input=code, + ) + + +def format_file( + filename: str, + skip_errors: bool, +) -> int: + with open(filename, encoding="UTF-8") as f: + contents = f.read() + new_contents, errors = format_str(contents) + for error in errors: + lineno = contents[: error.offset].count("\n") + 1 + print(f"{filename}:{lineno}: code block parse error {error.exc}") + if errors and not skip_errors: + return 1 + if contents != new_contents: + print(f"{filename}: Rewriting...") + with open(filename, "w", encoding="UTF-8") as f: + f.write(new_contents) + return 0 + else: + return 0 + + +def main(argv: Sequence[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument( + "-l", + "--line-length", + type=int, + default=DEFAULT_LINE_LENGTH, + ) + parser.add_argument( + "-S", + "--skip-string-normalization", + action="store_true", + ) + parser.add_argument("-E", "--skip-errors", action="store_true") + parser.add_argument("filenames", nargs="*") + args = parser.parse_args(argv) + + retv = 0 + for filename in args.filenames: + retv |= format_file(filename, skip_errors=args.skip_errors) + return retv + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/utils/upload-artifact.sh b/scripts/utils/upload-artifact.sh new file mode 100755 index 00000000..af535aa5 --- /dev/null +++ b/scripts/utils/upload-artifact.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -exuo pipefail + +FILENAME=$(basename dist/*.whl) + +RESPONSE=$(curl -X POST "$URL?filename=$FILENAME" \ + -H "Authorization: Bearer $AUTH" \ + -H "Content-Type: application/json") + +SIGNED_URL=$(echo "$RESPONSE" | jq -r '.url') + +if [[ "$SIGNED_URL" == "null" ]]; then + echo -e "\033[31mFailed to get signed URL.\033[0m" + exit 1 +fi + +UPLOAD_RESPONSE=$(curl -v -X PUT \ + -H "Content-Type: binary/octet-stream" \ + --data-binary "@dist/$FILENAME" "$SIGNED_URL" 2>&1) + +if echo "$UPLOAD_RESPONSE" | grep -q "HTTP/[0-9.]* 200"; then + echo -e "\033[32mUploaded build to Stainless storage.\033[0m" + echo -e "\033[32mInstallation: pip install 'https://pkg.stainless.com/s/writer-python/$SHA/$FILENAME'\033[0m" +else + echo -e "\033[31mFailed to upload artifact.\033[0m" + exit 1 +fi diff --git a/src/writerai/__init__.py b/src/writerai/__init__.py new file mode 100644 index 00000000..cc744dd2 --- /dev/null +++ b/src/writerai/__init__.py @@ -0,0 +1,92 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +import typing as _t + +from . import types +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given +from ._utils import file_from_path +from ._client import Client, Stream, Writer, Timeout, Transport, AsyncClient, AsyncStream, AsyncWriter, RequestOptions +from ._models import BaseModel +from ._version import __title__, __version__ +from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse +from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS +from ._exceptions import ( + APIError, + WriterError, + ConflictError, + NotFoundError, + APIStatusError, + RateLimitError, + APITimeoutError, + BadRequestError, + APIConnectionError, + AuthenticationError, + InternalServerError, + PermissionDeniedError, + UnprocessableEntityError, + APIResponseValidationError, +) +from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient +from ._utils._logs import setup_logging as _setup_logging + +__all__ = [ + "types", + "__version__", + "__title__", + "NoneType", + "Transport", + "ProxiesTypes", + "NotGiven", + "NOT_GIVEN", + "not_given", + "Omit", + "omit", + "WriterError", + "APIError", + "APIStatusError", + "APITimeoutError", + "APIConnectionError", + "APIResponseValidationError", + "BadRequestError", + "AuthenticationError", + "PermissionDeniedError", + "NotFoundError", + "ConflictError", + "UnprocessableEntityError", + "RateLimitError", + "InternalServerError", + "Timeout", + "RequestOptions", + "Client", + "AsyncClient", + "Stream", + "AsyncStream", + "Writer", + "AsyncWriter", + "file_from_path", + "BaseModel", + "DEFAULT_TIMEOUT", + "DEFAULT_MAX_RETRIES", + "DEFAULT_CONNECTION_LIMITS", + "DefaultHttpxClient", + "DefaultAsyncHttpxClient", + "DefaultAioHttpClient", +] + +if not _t.TYPE_CHECKING: + from ._utils._resources_proxy import resources as resources + +_setup_logging() + +# Update the __module__ attribute for exported symbols so that +# error messages point to this module instead of the module +# it was originally defined in, e.g. +# writerai._exceptions.NotFoundError -> writerai.NotFoundError +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + try: + __locals[__name].__module__ = "writerai" + except (TypeError, AttributeError): + # Some of our exported symbols are builtins which we can't set attributes for. + pass diff --git a/src/writerai/_base_client.py b/src/writerai/_base_client.py new file mode 100644 index 00000000..22c0ce8f --- /dev/null +++ b/src/writerai/_base_client.py @@ -0,0 +1,2131 @@ +from __future__ import annotations + +import sys +import json +import time +import uuid +import email +import asyncio +import inspect +import logging +import platform +import warnings +import email.utils +from types import TracebackType +from random import random +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Type, + Union, + Generic, + Mapping, + TypeVar, + Iterable, + Iterator, + Optional, + Generator, + AsyncIterator, + cast, + overload, +) +from typing_extensions import Literal, override, get_origin + +import anyio +import httpx +import distro +import pydantic +from httpx import URL +from pydantic import PrivateAttr + +from . import _exceptions +from ._qs import Querystring +from ._files import to_httpx_files, async_to_httpx_files +from ._types import ( + Body, + Omit, + Query, + Headers, + Timeout, + NotGiven, + ResponseT, + AnyMapping, + PostParser, + BinaryTypes, + RequestFiles, + HttpxSendArgs, + RequestOptions, + AsyncBinaryTypes, + HttpxRequestFiles, + ModelBuilderProtocol, + not_given, +) +from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping +from ._compat import PYDANTIC_V1, model_copy, model_dump +from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type +from ._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + extract_response_type, +) +from ._constants import ( + DEFAULT_TIMEOUT, + MAX_RETRY_DELAY, + DEFAULT_MAX_RETRIES, + INITIAL_RETRY_DELAY, + RAW_RESPONSE_HEADER, + OVERRIDE_CAST_TO_HEADER, + DEFAULT_CONNECTION_LIMITS, +) +from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder +from ._exceptions import ( + APIStatusError, + APITimeoutError, + APIConnectionError, + APIResponseValidationError, +) +from ._utils._json import openapi_dumps + +log: logging.Logger = logging.getLogger(__name__) + +# TODO: make base page type vars covariant +SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]") +AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + +_StreamT = TypeVar("_StreamT", bound=Stream[Any]) +_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any]) + +if TYPE_CHECKING: + from httpx._config import ( + DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage] + ) + + HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG +else: + try: + from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT + except ImportError: + # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366 + HTTPX_DEFAULT_TIMEOUT = Timeout(5.0) + + +class PageInfo: + """Stores the necessary information to build the request to retrieve the next page. + + Either `url` or `params` must be set. + """ + + url: URL | NotGiven + params: Query | NotGiven + json: Body | NotGiven + + @overload + def __init__( + self, + *, + url: URL, + ) -> None: ... + + @overload + def __init__( + self, + *, + params: Query, + ) -> None: ... + + @overload + def __init__( + self, + *, + json: Body, + ) -> None: ... + + def __init__( + self, + *, + url: URL | NotGiven = not_given, + json: Body | NotGiven = not_given, + params: Query | NotGiven = not_given, + ) -> None: + self.url = url + self.json = json + self.params = params + + @override + def __repr__(self) -> str: + if self.url: + return f"{self.__class__.__name__}(url={self.url})" + if self.json: + return f"{self.__class__.__name__}(json={self.json})" + return f"{self.__class__.__name__}(params={self.params})" + + +class BasePage(GenericModel, Generic[_T]): + """ + Defines the core interface for pagination. + + Type Args: + ModelT: The pydantic model that represents an item in the response. + + Methods: + has_next_page(): Check if there is another page available + next_page_info(): Get the necessary information to make a request for the next page + """ + + _options: FinalRequestOptions = PrivateAttr() + _model: Type[_T] = PrivateAttr() + + def has_next_page(self) -> bool: + items = self._get_page_items() + if not items: + return False + return self.next_page_info() is not None + + def next_page_info(self) -> Optional[PageInfo]: ... + + def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body] + ... + + def _params_from_url(self, url: URL) -> httpx.QueryParams: + # TODO: do we have to preprocess params here? + return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params) + + def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: + options = model_copy(self._options) + options._strip_raw_response_header() + + if not isinstance(info.params, NotGiven): + options.params = {**options.params, **info.params} + return options + + if not isinstance(info.url, NotGiven): + params = self._params_from_url(info.url) + url = info.url.copy_with(params=params) + options.params = dict(url.params) + options.url = str(url) + return options + + if not isinstance(info.json, NotGiven): + if not is_mapping(info.json): + raise TypeError("Pagination is only supported with mappings") + + if not options.json_data: + options.json_data = {**info.json} + else: + if not is_mapping(options.json_data): + raise TypeError("Pagination is only supported with mappings") + + options.json_data = {**options.json_data, **info.json} + return options + + raise ValueError("Unexpected PageInfo state") + + +class BaseSyncPage(BasePage[_T], Generic[_T]): + _client: SyncAPIClient = pydantic.PrivateAttr() + + def _set_private_attributes( + self, + client: SyncAPIClient, + model: Type[_T], + options: FinalRequestOptions, + ) -> None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: + self.__pydantic_private__ = {} + + self._model = model + self._client = client + self._options = options + + # Pydantic uses a custom `__iter__` method to support casting BaseModels + # to dictionaries. e.g. dict(model). + # As we want to support `for item in page`, this is inherently incompatible + # with the default pydantic behaviour. It is not possible to support both + # use cases at once. Fortunately, this is not a big deal as all other pydantic + # methods should continue to work as expected as there is an alternative method + # to cast a model to a dictionary, model.dict(), which is used internally + # by pydantic. + def __iter__(self) -> Iterator[_T]: # type: ignore + for page in self.iter_pages(): + for item in page._get_page_items(): + yield item + + def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]: + page = self + while True: + yield page + if page.has_next_page(): + page = page.get_next_page() + else: + return + + def get_next_page(self: SyncPageT) -> SyncPageT: + info = self.next_page_info() + if not info: + raise RuntimeError( + "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`." + ) + + options = self._info_to_options(info) + return self._client._request_api_list(self._model, page=self.__class__, options=options) + + +class AsyncPaginator(Generic[_T, AsyncPageT]): + def __init__( + self, + client: AsyncAPIClient, + options: FinalRequestOptions, + page_cls: Type[AsyncPageT], + model: Type[_T], + ) -> None: + self._model = model + self._client = client + self._options = options + self._page_cls = page_cls + + def __await__(self) -> Generator[Any, None, AsyncPageT]: + return self._get_page().__await__() + + async def _get_page(self) -> AsyncPageT: + def _parser(resp: AsyncPageT) -> AsyncPageT: + resp._set_private_attributes( + model=self._model, + options=self._options, + client=self._client, + ) + return resp + + self._options.post_parser = _parser + + return await self._client.request(self._page_cls, self._options) + + async def __aiter__(self) -> AsyncIterator[_T]: + # https://github.com/microsoft/pyright/issues/3464 + page = cast( + AsyncPageT, + await self, # type: ignore + ) + async for item in page: + yield item + + +class BaseAsyncPage(BasePage[_T], Generic[_T]): + _client: AsyncAPIClient = pydantic.PrivateAttr() + + def _set_private_attributes( + self, + model: Type[_T], + client: AsyncAPIClient, + options: FinalRequestOptions, + ) -> None: + if (not PYDANTIC_V1) and getattr(self, "__pydantic_private__", None) is None: + self.__pydantic_private__ = {} + + self._model = model + self._client = client + self._options = options + + async def __aiter__(self) -> AsyncIterator[_T]: + async for page in self.iter_pages(): + for item in page._get_page_items(): + yield item + + async def iter_pages(self: AsyncPageT) -> AsyncIterator[AsyncPageT]: + page = self + while True: + yield page + if page.has_next_page(): + page = await page.get_next_page() + else: + return + + async def get_next_page(self: AsyncPageT) -> AsyncPageT: + info = self.next_page_info() + if not info: + raise RuntimeError( + "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`." + ) + + options = self._info_to_options(info) + return await self._client._request_api_list(self._model, page=self.__class__, options=options) + + +_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient]) +_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]]) + + +class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]): + _client: _HttpxClientT + _version: str + _base_url: URL + max_retries: int + timeout: Union[float, Timeout, None] + _strict_response_validation: bool + _idempotency_header: str | None + _default_stream_cls: type[_DefaultStreamT] | None = None + + def __init__( + self, + *, + version: str, + base_url: str | URL, + _strict_response_validation: bool, + max_retries: int = DEFAULT_MAX_RETRIES, + timeout: float | Timeout | None = DEFAULT_TIMEOUT, + custom_headers: Mapping[str, str] | None = None, + custom_query: Mapping[str, object] | None = None, + ) -> None: + self._version = version + self._base_url = self._enforce_trailing_slash(URL(base_url)) + self.max_retries = max_retries + self.timeout = timeout + self._custom_headers = custom_headers or {} + self._custom_query = custom_query or {} + self._strict_response_validation = _strict_response_validation + self._idempotency_header = None + self._platform: Platform | None = None + + if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] + raise TypeError( + "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `writerai.DEFAULT_MAX_RETRIES`" + ) + + def _enforce_trailing_slash(self, url: URL) -> URL: + if url.raw_path.endswith(b"/"): + return url + return url.copy_with(raw_path=url.raw_path + b"/") + + def _make_status_error_from_response( + self, + response: httpx.Response, + ) -> APIStatusError: + if response.is_closed and not response.is_stream_consumed: + # We can't read the response body as it has been closed + # before it was read. This can happen if an event hook + # raises a status error. + body = None + err_msg = f"Error code: {response.status_code}" + else: + err_text = response.text.strip() + body = err_text + + try: + body = json.loads(err_text) + err_msg = f"Error code: {response.status_code} - {body}" + except Exception: + err_msg = err_text or f"Error code: {response.status_code}" + + return self._make_status_error(err_msg, body=body, response=response) + + def _make_status_error( + self, + err_msg: str, + *, + body: object, + response: httpx.Response, + ) -> _exceptions.APIStatusError: + raise NotImplementedError() + + def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers: + custom_headers = options.headers or {} + headers_dict = _merge_mappings(self.default_headers, custom_headers) + self._validate_headers(headers_dict, custom_headers) + + # headers are case-insensitive while dictionaries are not. + headers = httpx.Headers(headers_dict) + + idempotency_header = self._idempotency_header + if idempotency_header and options.idempotency_key and idempotency_header not in headers: + headers[idempotency_header] = options.idempotency_key + + # Don't set these headers if they were already set or removed by the caller. We check + # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case. + lower_custom_headers = [header.lower() for header in custom_headers] + if "x-stainless-retry-count" not in lower_custom_headers: + headers["x-stainless-retry-count"] = str(retries_taken) + if "x-stainless-read-timeout" not in lower_custom_headers: + timeout = self.timeout if isinstance(options.timeout, NotGiven) else options.timeout + if isinstance(timeout, Timeout): + timeout = timeout.read + if timeout is not None: + headers["x-stainless-read-timeout"] = str(timeout) + + return headers + + def _prepare_url(self, url: str) -> URL: + """ + Merge a URL argument together with any 'base_url' on the client, + to create the URL used for the outgoing request. + """ + # Copied from httpx's `_merge_url` method. + merge_url = URL(url) + if merge_url.is_relative_url: + merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/") + return self.base_url.copy_with(raw_path=merge_raw_path) + + return merge_url + + def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: + return SSEDecoder() + + def _build_request( + self, + options: FinalRequestOptions, + *, + retries_taken: int = 0, + ) -> httpx.Request: + if log.isEnabledFor(logging.DEBUG): + log.debug( + "Request options: %s", + model_dump( + options, + exclude_unset=True, + # Pydantic v1 can't dump every type we support in content, so we exclude it for now. + exclude={ + "content", + } + if PYDANTIC_V1 + else {}, + ), + ) + kwargs: dict[str, Any] = {} + + json_data = options.json_data + if options.extra_json is not None: + if json_data is None: + json_data = cast(Body, options.extra_json) + elif is_mapping(json_data): + json_data = _merge_mappings(json_data, options.extra_json) + else: + raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`") + + headers = self._build_headers(options, retries_taken=retries_taken) + params = _merge_mappings(self.default_query, options.params) + content_type = headers.get("Content-Type") + files = options.files + + # If the given Content-Type header is multipart/form-data then it + # has to be removed so that httpx can generate the header with + # additional information for us as it has to be in this form + # for the server to be able to correctly parse the request: + # multipart/form-data; boundary=---abc-- + if content_type is not None and content_type.startswith("multipart/form-data"): + if "boundary" not in content_type: + # only remove the header if the boundary hasn't been explicitly set + # as the caller doesn't want httpx to come up with their own boundary + headers.pop("Content-Type") + + # As we are now sending multipart/form-data instead of application/json + # we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding + if json_data: + if not is_dict(json_data): + raise TypeError( + f"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead." + ) + kwargs["data"] = self._serialize_multipartform(json_data) + + # httpx determines whether or not to send a "multipart/form-data" + # request based on the truthiness of the "files" argument. + # This gets around that issue by generating a dict value that + # evaluates to true. + # + # https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186 + if not files: + files = cast(HttpxRequestFiles, ForceMultipartDict()) + + prepared_url = self._prepare_url(options.url) + # preserve hard-coded query params from the url + if params and prepared_url.query: + params = {**dict(prepared_url.params.items()), **params} + prepared_url = prepared_url.copy_with(raw_path=prepared_url.raw_path.split(b"?", 1)[0]) + if "_" in prepared_url.host: + # work around https://github.com/encode/httpx/discussions/2880 + kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")} + + is_body_allowed = options.method.lower() != "get" + + if is_body_allowed: + if options.content is not None and json_data is not None: + raise TypeError("Passing both `content` and `json_data` is not supported") + if options.content is not None and files is not None: + raise TypeError("Passing both `content` and `files` is not supported") + if options.content is not None: + kwargs["content"] = options.content + elif isinstance(json_data, bytes): + kwargs["content"] = json_data + elif not files: + # Don't set content when JSON is sent as multipart/form-data, + # since httpx's content param overrides other body arguments + kwargs["content"] = openapi_dumps(json_data) if is_given(json_data) and json_data is not None else None + kwargs["files"] = files + else: + headers.pop("Content-Type", None) + kwargs.pop("data", None) + + # TODO: report this error to httpx + return self._client.build_request( # pyright: ignore[reportUnknownMemberType] + headers=headers, + timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout, + method=options.method, + url=prepared_url, + # the `Query` type that we use is incompatible with qs' + # `Params` type as it needs to be typed as `Mapping[str, object]` + # so that passing a `TypedDict` doesn't cause an error. + # https://github.com/microsoft/pyright/issues/3526#event-6715453066 + params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None, + **kwargs, + ) + + def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]: + items = self.qs.stringify_items( + # TODO: type ignore is required as stringify_items is well typed but we can't be + # well typed without heavy validation. + data, # type: ignore + array_format="brackets", + ) + serialized: dict[str, object] = {} + for key, value in items: + existing = serialized.get(key) + + if not existing: + serialized[key] = value + continue + + # If a value has already been set for this key then that + # means we're sending data like `array[]=[1, 2, 3]` and we + # need to tell httpx that we want to send multiple values with + # the same key which is done by using a list or a tuple. + # + # Note: 2d arrays should never result in the same key at both + # levels so it's safe to assume that if the value is a list, + # it was because we changed it to be a list. + if is_list(existing): + existing.append(value) + else: + serialized[key] = [existing, value] + + return serialized + + def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]: + if not is_given(options.headers): + return cast_to + + # make a copy of the headers so we don't mutate user-input + headers = dict(options.headers) + + # we internally support defining a temporary header to override the + # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` + # see _response.py for implementation details + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, not_given) + if is_given(override_cast_to): + options.headers = headers + return cast(Type[ResponseT], override_cast_to) + + return cast_to + + def _should_stream_response_body(self, request: httpx.Request) -> bool: + return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return] + + def _process_response_data( + self, + *, + data: object, + cast_to: type[ResponseT], + response: httpx.Response, + ) -> ResponseT: + if data is None: + return cast(ResponseT, None) + + if cast_to is object: + return cast(ResponseT, data) + + try: + if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol): + return cast(ResponseT, cast_to.build(response=response, data=data)) + + if self._strict_response_validation: + return cast(ResponseT, validate_type(type_=cast_to, value=data)) + + return cast(ResponseT, construct_type(type_=cast_to, value=data)) + except pydantic.ValidationError as err: + raise APIResponseValidationError(response=response, body=data) from err + + @property + def qs(self) -> Querystring: + return Querystring() + + @property + def custom_auth(self) -> httpx.Auth | None: + return None + + @property + def auth_headers(self) -> dict[str, str]: + return {} + + @property + def default_headers(self) -> dict[str, str | Omit]: + return { + "Accept": "application/json", + "Content-Type": "application/json", + "User-Agent": self.user_agent, + **self.platform_headers(), + **self.auth_headers, + **self._custom_headers, + } + + @property + def default_query(self) -> dict[str, object]: + return { + **self._custom_query, + } + + def _validate_headers( + self, + headers: Headers, # noqa: ARG002 + custom_headers: Headers, # noqa: ARG002 + ) -> None: + """Validate the given default headers and custom headers. + + Does nothing by default. + """ + return + + @property + def user_agent(self) -> str: + return f"{self.__class__.__name__}/Python {self._version}" + + @property + def base_url(self) -> URL: + return self._base_url + + @base_url.setter + def base_url(self, url: URL | str) -> None: + self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url)) + + def platform_headers(self) -> Dict[str, str]: + # the actual implementation is in a separate `lru_cache` decorated + # function because adding `lru_cache` to methods will leak memory + # https://github.com/python/cpython/issues/88476 + return platform_headers(self._version, platform=self._platform) + + def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None: + """Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified. + + About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax + """ + if response_headers is None: + return None + + # First, try the non-standard `retry-after-ms` header for milliseconds, + # which is more precise than integer-seconds `retry-after` + try: + retry_ms_header = response_headers.get("retry-after-ms", None) + return float(retry_ms_header) / 1000 + except (TypeError, ValueError): + pass + + # Next, try parsing `retry-after` header as seconds (allowing nonstandard floats). + retry_header = response_headers.get("retry-after") + try: + # note: the spec indicates that this should only ever be an integer + # but if someone sends a float there's no reason for us to not respect it + return float(retry_header) + except (TypeError, ValueError): + pass + + # Last, try parsing `retry-after` as a date. + retry_date_tuple = email.utils.parsedate_tz(retry_header) + if retry_date_tuple is None: + return None + + retry_date = email.utils.mktime_tz(retry_date_tuple) + return float(retry_date - time.time()) + + def _calculate_retry_timeout( + self, + remaining_retries: int, + options: FinalRequestOptions, + response_headers: Optional[httpx.Headers] = None, + ) -> float: + max_retries = options.get_max_retries(self.max_retries) + + # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. + retry_after = self._parse_retry_after_header(response_headers) + if retry_after is not None and 0 < retry_after <= 60: + return retry_after + + # Also cap retry count to 1000 to avoid any potential overflows with `pow` + nb_retries = min(max_retries - remaining_retries, 1000) + + # Apply exponential backoff, but not more than the max. + sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) + + # Apply some jitter, plus-or-minus half a second. + jitter = 1 - 0.25 * random() + timeout = sleep_seconds * jitter + return timeout if timeout >= 0 else 0 + + def _should_retry(self, response: httpx.Response) -> bool: + # Note: this is not a standard header + should_retry_header = response.headers.get("x-should-retry") + + # If the server explicitly says whether or not to retry, obey. + if should_retry_header == "true": + log.debug("Retrying as header `x-should-retry` is set to `true`") + return True + if should_retry_header == "false": + log.debug("Not retrying as header `x-should-retry` is set to `false`") + return False + + # Retry on request timeouts. + if response.status_code == 408: + log.debug("Retrying due to status code %i", response.status_code) + return True + + # Retry on lock timeouts. + if response.status_code == 409: + log.debug("Retrying due to status code %i", response.status_code) + return True + + # Retry on rate limits. + if response.status_code == 429: + log.debug("Retrying due to status code %i", response.status_code) + return True + + # Retry internal errors. + if response.status_code >= 500: + log.debug("Retrying due to status code %i", response.status_code) + return True + + log.debug("Not retrying") + return False + + def _idempotency_key(self) -> str: + return f"stainless-python-retry-{uuid.uuid4()}" + + +class _DefaultHttpxClient(httpx.Client): + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + super().__init__(**kwargs) + + +if TYPE_CHECKING: + DefaultHttpxClient = httpx.Client + """An alias to `httpx.Client` that provides the same defaults that this SDK + uses internally. + + This is useful because overriding the `http_client` with your own instance of + `httpx.Client` will result in httpx's defaults being used, not ours. + """ +else: + DefaultHttpxClient = _DefaultHttpxClient + + +class SyncHttpxClientWrapper(DefaultHttpxClient): + def __del__(self) -> None: + if self.is_closed: + return + + try: + self.close() + except Exception: + pass + + +class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]): + _client: httpx.Client + _default_stream_cls: type[Stream[Any]] | None = None + + def __init__( + self, + *, + version: str, + base_url: str | URL, + max_retries: int = DEFAULT_MAX_RETRIES, + timeout: float | Timeout | None | NotGiven = not_given, + http_client: httpx.Client | None = None, + custom_headers: Mapping[str, str] | None = None, + custom_query: Mapping[str, object] | None = None, + _strict_response_validation: bool, + ) -> None: + if not is_given(timeout): + # if the user passed in a custom http client with a non-default + # timeout set then we use that timeout. + # + # note: there is an edge case here where the user passes in a client + # where they've explicitly set the timeout to match the default timeout + # as this check is structural, meaning that we'll think they didn't + # pass in a timeout and will ignore it + if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT: + timeout = http_client.timeout + else: + timeout = DEFAULT_TIMEOUT + + if http_client is not None and not isinstance(http_client, httpx.Client): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}" + ) + + super().__init__( + version=version, + # cast to a valid type because mypy doesn't understand our type narrowing + timeout=cast(Timeout, timeout), + base_url=base_url, + max_retries=max_retries, + custom_query=custom_query, + custom_headers=custom_headers, + _strict_response_validation=_strict_response_validation, + ) + self._client = http_client or SyncHttpxClientWrapper( + base_url=base_url, + # cast to a valid type because mypy doesn't understand our type narrowing + timeout=cast(Timeout, timeout), + ) + + def is_closed(self) -> bool: + return self._client.is_closed + + def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + # If an error is thrown while constructing a client, self._client + # may not be present + if hasattr(self, "_client"): + self._client.close() + + def __enter__(self: _T) -> _T: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def _prepare_options( + self, + options: FinalRequestOptions, # noqa: ARG002 + ) -> FinalRequestOptions: + """Hook for mutating the given options""" + return options + + def _prepare_request( + self, + request: httpx.Request, # noqa: ARG002 + ) -> None: + """This method is used as a callback for mutating the `Request` object + after it has been constructed. + This is useful for cases where you want to add certain headers based off of + the request properties, e.g. `url`, `method` etc. + """ + return None + + @overload + def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: Literal[True], + stream_cls: Type[_StreamT], + ) -> _StreamT: ... + + @overload + def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: bool = False, + stream_cls: Type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: ... + + def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: bool = False, + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: + cast_to = self._maybe_override_cast_to(cast_to, options) + + # create a copy of the options we were given so that if the + # options are mutated later & we then retry, the retries are + # given the original options + input_options = model_copy(options) + if input_options.idempotency_key is None and input_options.method.lower() != "get": + # ensure the idempotency key is reused between requests + input_options.idempotency_key = self._idempotency_key() + + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) + + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = self._prepare_options(options) + + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + self._prepare_request(request) + + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth + + if options.follow_redirects is not None: + kwargs["follow_redirects"] = options.follow_redirects + + log.debug("Sending HTTP Request: %s %s", request.method, request.url) + + response = None + try: + response = self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, + ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + err.response.close() + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue + + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + err.response.read() + + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None + + break + + assert response is not None, "could not resolve response (should never happen)" + return self._process_response( + cast_to=cast_to, + options=options, + response=response, + stream=stream, + stream_cls=stream_cls, + retries_taken=retries_taken, + ) + + def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken + if remaining_retries == 1: + log.debug("1 retry left") + else: + log.debug("%i retries left", remaining_retries) + + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) + log.info("Retrying request to %s in %f seconds", options.url, timeout) + + time.sleep(timeout) + + def _process_response( + self, + *, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + retries_taken: int = 0, + ) -> ResponseT: + origin = get_origin(cast_to) or cast_to + + if ( + inspect.isclass(origin) + and issubclass(origin, BaseAPIResponse) + # we only want to actually return the custom BaseAPIResponse class if we're + # returning the raw response, or if we're not streaming SSE, as if we're streaming + # SSE then `cast_to` doesn't actively reflect the type we need to parse into + and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER))) + ): + if not issubclass(origin, APIResponse): + raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_to) + return cast( + ResponseT, + response_cls( + raw=response, + client=self, + cast_to=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + retries_taken=retries_taken, + ), + ) + + if cast_to == httpx.Response: + return cast(ResponseT, response) + + api_response = APIResponse( + raw=response, + client=self, + cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast] + stream=stream, + stream_cls=stream_cls, + options=options, + retries_taken=retries_taken, + ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return api_response.parse() + + def _request_api_list( + self, + model: Type[object], + page: Type[SyncPageT], + options: FinalRequestOptions, + ) -> SyncPageT: + def _parser(resp: SyncPageT) -> SyncPageT: + resp._set_private_attributes( + client=self, + model=model, + options=options, + ) + return resp + + options.post_parser = _parser + + return self.request(page, options, stream=False) + + @overload + def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: Literal[True], + stream_cls: type[_StreamT], + ) -> _StreamT: ... + + @overload + def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: bool, + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: ... + + def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: bool = False, + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: + opts = FinalRequestOptions.construct(method="get", url=path, **options) + # cast is required because mypy complains about returning Any even though + # it understands the type variables + return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) + + @overload + def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + options: RequestOptions = {}, + files: RequestFiles | None = None, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + options: RequestOptions = {}, + files: RequestFiles | None = None, + stream: Literal[True], + stream_cls: type[_StreamT], + ) -> _StreamT: ... + + @overload + def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + options: RequestOptions = {}, + files: RequestFiles | None = None, + stream: bool, + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: ... + + def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + options: RequestOptions = {}, + files: RequestFiles | None = None, + stream: bool = False, + stream_cls: type[_StreamT] | None = None, + ) -> ResponseT | _StreamT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct( + method="post", url=path, json_data=body, content=content, files=to_httpx_files(files), **options + ) + return cast(ResponseT, self.request(cast_to, opts, stream=stream, stream_cls=stream_cls)) + + def patch( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct( + method="patch", url=path, json_data=body, content=content, files=to_httpx_files(files), **options + ) + return self.request(cast_to, opts) + + def put( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct( + method="put", url=path, json_data=body, content=content, files=to_httpx_files(files), **options + ) + return self.request(cast_to, opts) + + def delete( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: BinaryTypes | None = None, + options: RequestOptions = {}, + ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, content=content, **options) + return self.request(cast_to, opts) + + def get_api_list( + self, + path: str, + *, + model: Type[object], + page: Type[SyncPageT], + body: Body | None = None, + options: RequestOptions = {}, + method: str = "get", + ) -> SyncPageT: + opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options) + return self._request_api_list(model, page, opts) + + +class _DefaultAsyncHttpxClient(httpx.AsyncClient): + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + super().__init__(**kwargs) + + +try: + import httpx_aiohttp +except ImportError: + + class _DefaultAioHttpClient(httpx.AsyncClient): + def __init__(self, **_kwargs: Any) -> None: + raise RuntimeError("To use the aiohttp client you must have installed the package with the `aiohttp` extra") +else: + + class _DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient): # type: ignore + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) + kwargs.setdefault("follow_redirects", True) + + super().__init__(**kwargs) + + +if TYPE_CHECKING: + DefaultAsyncHttpxClient = httpx.AsyncClient + """An alias to `httpx.AsyncClient` that provides the same defaults that this SDK + uses internally. + + This is useful because overriding the `http_client` with your own instance of + `httpx.AsyncClient` will result in httpx's defaults being used, not ours. + """ + + DefaultAioHttpClient = httpx.AsyncClient + """An alias to `httpx.AsyncClient` that changes the default HTTP transport to `aiohttp`.""" +else: + DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient + DefaultAioHttpClient = _DefaultAioHttpClient + + +class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient): + def __del__(self) -> None: + if self.is_closed: + return + + try: + # TODO(someday): support non asyncio runtimes here + asyncio.get_running_loop().create_task(self.aclose()) + except Exception: + pass + + +class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]): + _client: httpx.AsyncClient + _default_stream_cls: type[AsyncStream[Any]] | None = None + + def __init__( + self, + *, + version: str, + base_url: str | URL, + _strict_response_validation: bool, + max_retries: int = DEFAULT_MAX_RETRIES, + timeout: float | Timeout | None | NotGiven = not_given, + http_client: httpx.AsyncClient | None = None, + custom_headers: Mapping[str, str] | None = None, + custom_query: Mapping[str, object] | None = None, + ) -> None: + if not is_given(timeout): + # if the user passed in a custom http client with a non-default + # timeout set then we use that timeout. + # + # note: there is an edge case here where the user passes in a client + # where they've explicitly set the timeout to match the default timeout + # as this check is structural, meaning that we'll think they didn't + # pass in a timeout and will ignore it + if http_client and http_client.timeout != HTTPX_DEFAULT_TIMEOUT: + timeout = http_client.timeout + else: + timeout = DEFAULT_TIMEOUT + + if http_client is not None and not isinstance(http_client, httpx.AsyncClient): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"Invalid `http_client` argument; Expected an instance of `httpx.AsyncClient` but got {type(http_client)}" + ) + + super().__init__( + version=version, + base_url=base_url, + # cast to a valid type because mypy doesn't understand our type narrowing + timeout=cast(Timeout, timeout), + max_retries=max_retries, + custom_query=custom_query, + custom_headers=custom_headers, + _strict_response_validation=_strict_response_validation, + ) + self._client = http_client or AsyncHttpxClientWrapper( + base_url=base_url, + # cast to a valid type because mypy doesn't understand our type narrowing + timeout=cast(Timeout, timeout), + ) + + def is_closed(self) -> bool: + return self._client.is_closed + + async def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + await self._client.aclose() + + async def __aenter__(self: _T) -> _T: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + async def _prepare_options( + self, + options: FinalRequestOptions, # noqa: ARG002 + ) -> FinalRequestOptions: + """Hook for mutating the given options""" + return options + + async def _prepare_request( + self, + request: httpx.Request, # noqa: ARG002 + ) -> None: + """This method is used as a callback for mutating the `Request` object + after it has been constructed. + This is useful for cases where you want to add certain headers based off of + the request properties, e.g. `url`, `method` etc. + """ + return None + + @overload + async def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + async def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: Literal[True], + stream_cls: type[_AsyncStreamT], + ) -> _AsyncStreamT: ... + + @overload + async def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: bool, + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: ... + + async def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: bool = False, + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: + if self._platform is None: + # `get_platform` can make blocking IO calls so we + # execute it earlier while we are in an async context + self._platform = await asyncify(get_platform)() + + cast_to = self._maybe_override_cast_to(cast_to, options) + + # create a copy of the options we were given so that if the + # options are mutated later & we then retry, the retries are + # given the original options + input_options = model_copy(options) + if input_options.idempotency_key is None and input_options.method.lower() != "get": + # ensure the idempotency key is reused between requests + input_options.idempotency_key = self._idempotency_key() + + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) + + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = await self._prepare_options(options) + + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + await self._prepare_request(request) + + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth + + if options.follow_redirects is not None: + kwargs["follow_redirects"] = options.follow_redirects + + log.debug("Sending HTTP Request: %s %s", request.method, request.url) + + response = None + try: + response = await self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, + ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + await err.response.aclose() + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue + + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + await err.response.aread() + + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None + + break + + assert response is not None, "could not resolve response (should never happen)" + return await self._process_response( + cast_to=cast_to, + options=options, + response=response, + stream=stream, + stream_cls=stream_cls, + retries_taken=retries_taken, + ) + + async def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken + if remaining_retries == 1: + log.debug("1 retry left") + else: + log.debug("%i retries left", remaining_retries) + + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) + log.info("Retrying request to %s in %f seconds", options.url, timeout) + + await anyio.sleep(timeout) + + async def _process_response( + self, + *, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + retries_taken: int = 0, + ) -> ResponseT: + origin = get_origin(cast_to) or cast_to + + if ( + inspect.isclass(origin) + and issubclass(origin, BaseAPIResponse) + # we only want to actually return the custom BaseAPIResponse class if we're + # returning the raw response, or if we're not streaming SSE, as if we're streaming + # SSE then `cast_to` doesn't actively reflect the type we need to parse into + and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER))) + ): + if not issubclass(origin, AsyncAPIResponse): + raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_to) + return cast( + "ResponseT", + response_cls( + raw=response, + client=self, + cast_to=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + retries_taken=retries_taken, + ), + ) + + if cast_to == httpx.Response: + return cast(ResponseT, response) + + api_response = AsyncAPIResponse( + raw=response, + client=self, + cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast] + stream=stream, + stream_cls=stream_cls, + options=options, + retries_taken=retries_taken, + ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return await api_response.parse() + + def _request_api_list( + self, + model: Type[_T], + page: Type[AsyncPageT], + options: FinalRequestOptions, + ) -> AsyncPaginator[_T, AsyncPageT]: + return AsyncPaginator(client=self, options=options, page_cls=page, model=model) + + @overload + async def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + async def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: Literal[True], + stream_cls: type[_AsyncStreamT], + ) -> _AsyncStreamT: ... + + @overload + async def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: bool, + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: ... + + async def get( + self, + path: str, + *, + cast_to: Type[ResponseT], + options: RequestOptions = {}, + stream: bool = False, + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: + opts = FinalRequestOptions.construct(method="get", url=path, **options) + return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls) + + @overload + async def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + stream: Literal[False] = False, + ) -> ResponseT: ... + + @overload + async def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + stream: Literal[True], + stream_cls: type[_AsyncStreamT], + ) -> _AsyncStreamT: ... + + @overload + async def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + stream: bool, + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: ... + + async def post( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + stream: bool = False, + stream_cls: type[_AsyncStreamT] | None = None, + ) -> ResponseT | _AsyncStreamT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct( + method="post", url=path, json_data=body, content=content, files=await async_to_httpx_files(files), **options + ) + return await self.request(cast_to, opts, stream=stream, stream_cls=stream_cls) + + async def patch( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct( + method="patch", + url=path, + json_data=body, + content=content, + files=await async_to_httpx_files(files), + **options, + ) + return await self.request(cast_to, opts) + + async def put( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + files: RequestFiles | None = None, + options: RequestOptions = {}, + ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if files is not None and content is not None: + raise TypeError("Passing both `files` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct( + method="put", url=path, json_data=body, content=content, files=await async_to_httpx_files(files), **options + ) + return await self.request(cast_to, opts) + + async def delete( + self, + path: str, + *, + cast_to: Type[ResponseT], + body: Body | None = None, + content: AsyncBinaryTypes | None = None, + options: RequestOptions = {}, + ) -> ResponseT: + if body is not None and content is not None: + raise TypeError("Passing both `body` and `content` is not supported") + if isinstance(body, bytes): + warnings.warn( + "Passing raw bytes as `body` is deprecated and will be removed in a future version. " + "Please pass raw bytes via the `content` parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + opts = FinalRequestOptions.construct(method="delete", url=path, json_data=body, content=content, **options) + return await self.request(cast_to, opts) + + def get_api_list( + self, + path: str, + *, + model: Type[_T], + page: Type[AsyncPageT], + body: Body | None = None, + options: RequestOptions = {}, + method: str = "get", + ) -> AsyncPaginator[_T, AsyncPageT]: + opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options) + return self._request_api_list(model, page, opts) + + +def make_request_options( + *, + query: Query | None = None, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + idempotency_key: str | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + post_parser: PostParser | NotGiven = not_given, +) -> RequestOptions: + """Create a dict of type RequestOptions without keys of NotGiven values.""" + options: RequestOptions = {} + if extra_headers is not None: + options["headers"] = extra_headers + + if extra_body is not None: + options["extra_json"] = cast(AnyMapping, extra_body) + + if query is not None: + options["params"] = query + + if extra_query is not None: + options["params"] = {**options.get("params", {}), **extra_query} + + if not isinstance(timeout, NotGiven): + options["timeout"] = timeout + + if idempotency_key is not None: + options["idempotency_key"] = idempotency_key + + if is_given(post_parser): + # internal + options["post_parser"] = post_parser # type: ignore + + return options + + +class ForceMultipartDict(Dict[str, None]): + def __bool__(self) -> bool: + return True + + +class OtherPlatform: + def __init__(self, name: str) -> None: + self.name = name + + @override + def __str__(self) -> str: + return f"Other:{self.name}" + + +Platform = Union[ + OtherPlatform, + Literal[ + "MacOS", + "Linux", + "Windows", + "FreeBSD", + "OpenBSD", + "iOS", + "Android", + "Unknown", + ], +] + + +def get_platform() -> Platform: + try: + system = platform.system().lower() + platform_name = platform.platform().lower() + except Exception: + return "Unknown" + + if "iphone" in platform_name or "ipad" in platform_name: + # Tested using Python3IDE on an iPhone 11 and Pythonista on an iPad 7 + # system is Darwin and platform_name is a string like: + # - Darwin-21.6.0-iPhone12,1-64bit + # - Darwin-21.6.0-iPad7,11-64bit + return "iOS" + + if system == "darwin": + return "MacOS" + + if system == "windows": + return "Windows" + + if "android" in platform_name: + # Tested using Pydroid 3 + # system is Linux and platform_name is a string like 'Linux-5.10.81-android12-9-00001-geba40aecb3b7-ab8534902-aarch64-with-libc' + return "Android" + + if system == "linux": + # https://distro.readthedocs.io/en/latest/#distro.id + distro_id = distro.id() + if distro_id == "freebsd": + return "FreeBSD" + + if distro_id == "openbsd": + return "OpenBSD" + + return "Linux" + + if platform_name: + return OtherPlatform(platform_name) + + return "Unknown" + + +@lru_cache(maxsize=None) +def platform_headers(version: str, *, platform: Platform | None) -> Dict[str, str]: + return { + "X-Stainless-Lang": "python", + "X-Stainless-Package-Version": version, + "X-Stainless-OS": str(platform or get_platform()), + "X-Stainless-Arch": str(get_architecture()), + "X-Stainless-Runtime": get_python_runtime(), + "X-Stainless-Runtime-Version": get_python_version(), + } + + +class OtherArch: + def __init__(self, name: str) -> None: + self.name = name + + @override + def __str__(self) -> str: + return f"other:{self.name}" + + +Arch = Union[OtherArch, Literal["x32", "x64", "arm", "arm64", "unknown"]] + + +def get_python_runtime() -> str: + try: + return platform.python_implementation() + except Exception: + return "unknown" + + +def get_python_version() -> str: + try: + return platform.python_version() + except Exception: + return "unknown" + + +def get_architecture() -> Arch: + try: + machine = platform.machine().lower() + except Exception: + return "unknown" + + if machine in ("arm64", "aarch64"): + return "arm64" + + # TODO: untested + if machine == "arm": + return "arm" + + if machine == "x86_64": + return "x64" + + # TODO: untested + if sys.maxsize <= 2**32: + return "x32" + + if machine: + return OtherArch(machine) + + return "unknown" + + +def _merge_mappings( + obj1: Mapping[_T_co, Union[_T, Omit]], + obj2: Mapping[_T_co, Union[_T, Omit]], +) -> Dict[_T_co, _T]: + """Merge two mappings of the same type, removing any values that are instances of `Omit`. + + In cases with duplicate keys the second mapping takes precedence. + """ + merged = {**obj1, **obj2} + return {key: value for key, value in merged.items() if not isinstance(value, Omit)} diff --git a/src/writerai/_client.py b/src/writerai/_client.py new file mode 100644 index 00000000..b6cc406d --- /dev/null +++ b/src/writerai/_client.py @@ -0,0 +1,764 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Mapping +from typing_extensions import Self, override + +import httpx + +from . import _exceptions +from ._qs import Querystring +from ._types import ( + Omit, + Timeout, + NotGiven, + Transport, + ProxiesTypes, + RequestOptions, + not_given, +) +from ._utils import ( + is_given, + is_mapping_t, + get_async_library, +) +from ._compat import cached_property +from ._version import __version__ +from ._streaming import Stream as Stream, AsyncStream as AsyncStream +from ._exceptions import WriterError, APIStatusError +from ._base_client import ( + DEFAULT_MAX_RETRIES, + SyncAPIClient, + AsyncAPIClient, +) + +if TYPE_CHECKING: + from .resources import chat, files, tools, graphs, models, vision, completions, translation, applications + from .resources.chat import ChatResource, AsyncChatResource + from .resources.files import FilesResource, AsyncFilesResource + from .resources.tools import ToolsResource, AsyncToolsResource + from .resources.graphs import GraphsResource, AsyncGraphsResource + from .resources.models import ModelsResource, AsyncModelsResource + from .resources.vision import VisionResource, AsyncVisionResource + from .resources.completions import CompletionsResource, AsyncCompletionsResource + from .resources.translation import TranslationResource, AsyncTranslationResource + from .resources.applications.applications import ApplicationsResource, AsyncApplicationsResource + +__all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "Writer", "AsyncWriter", "Client", "AsyncClient"] + + +class Writer(SyncAPIClient): + # client options + api_key: str + + def __init__( + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | Timeout | None | NotGiven = not_given, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + # Configure a custom httpx client. + # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. + # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. + http_client: httpx.Client | None = None, + # Enable or disable schema validation for data returned by the API. + # When enabled an error APIResponseValidationError is raised + # if the API responds with invalid data for the expected schema. + # + # This parameter may be removed or changed in the future. + # If you rely on this feature, please open a GitHub issue + # outlining your use-case to help us decide if it should be + # part of our public interface in the future. + _strict_response_validation: bool = False, + ) -> None: + """Construct a new synchronous Writer client instance. + + This automatically infers the `api_key` argument from the `WRITER_API_KEY` environment variable if it is not provided. + """ + if api_key is None: + api_key = os.environ.get("WRITER_API_KEY") + if api_key is None: + raise WriterError( + "The api_key client option must be set either by passing api_key to the client or by setting the WRITER_API_KEY environment variable" + ) + self.api_key = api_key + + if base_url is None: + base_url = os.environ.get("WRITER_BASE_URL") + if base_url is None: + base_url = f"https://api.writer.com" + + custom_headers_env = os.environ.get("WRITER_CUSTOM_HEADERS") + if custom_headers_env is not None: + parsed: dict[str, str] = {} + for line in custom_headers_env.split("\n"): + colon = line.find(":") + if colon >= 0: + parsed[line[:colon].strip()] = line[colon + 1 :].strip() + default_headers = {**parsed, **(default_headers if is_mapping_t(default_headers) else {})} + + super().__init__( + version=__version__, + base_url=base_url, + max_retries=max_retries, + timeout=timeout, + http_client=http_client, + custom_headers=default_headers, + custom_query=default_query, + _strict_response_validation=_strict_response_validation, + ) + + self._default_stream_cls = Stream + + @cached_property + def applications(self) -> ApplicationsResource: + from .resources.applications import ApplicationsResource + + return ApplicationsResource(self) + + @cached_property + def chat(self) -> ChatResource: + from .resources.chat import ChatResource + + return ChatResource(self) + + @cached_property + def completions(self) -> CompletionsResource: + from .resources.completions import CompletionsResource + + return CompletionsResource(self) + + @cached_property + def models(self) -> ModelsResource: + from .resources.models import ModelsResource + + return ModelsResource(self) + + @cached_property + def graphs(self) -> GraphsResource: + from .resources.graphs import GraphsResource + + return GraphsResource(self) + + @cached_property + def files(self) -> FilesResource: + from .resources.files import FilesResource + + return FilesResource(self) + + @cached_property + def tools(self) -> ToolsResource: + from .resources.tools import ToolsResource + + return ToolsResource(self) + + @cached_property + def translation(self) -> TranslationResource: + from .resources.translation import TranslationResource + + return TranslationResource(self) + + @cached_property + def vision(self) -> VisionResource: + from .resources.vision import VisionResource + + return VisionResource(self) + + @cached_property + def with_raw_response(self) -> WriterWithRawResponse: + return WriterWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> WriterWithStreamedResponse: + return WriterWithStreamedResponse(self) + + @property + @override + def qs(self) -> Querystring: + return Querystring(array_format="comma") + + @property + @override + def auth_headers(self) -> dict[str, str]: + api_key = self.api_key + return {"Authorization": f"Bearer {api_key}"} + + @property + @override + def default_headers(self) -> dict[str, str | Omit]: + return { + **super().default_headers, + "X-Stainless-Async": "false", + **self._custom_headers, + } + + def copy( + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | Timeout | None | NotGiven = not_given, + http_client: httpx.Client | None = None, + max_retries: int | NotGiven = not_given, + default_headers: Mapping[str, str] | None = None, + set_default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + set_default_query: Mapping[str, object] | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: + """ + Create a new client instance re-using the same options given to the current client with optional overriding. + """ + if default_headers is not None and set_default_headers is not None: + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + + if default_query is not None and set_default_query is not None: + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") + + headers = self._custom_headers + if default_headers is not None: + headers = {**headers, **default_headers} + elif set_default_headers is not None: + headers = set_default_headers + + params = self._custom_query + if default_query is not None: + params = {**params, **default_query} + elif set_default_query is not None: + params = set_default_query + + http_client = http_client or self._client + return self.__class__( + api_key=api_key or self.api_key, + base_url=base_url or self.base_url, + timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, + http_client=http_client, + max_retries=max_retries if is_given(max_retries) else self.max_retries, + default_headers=headers, + default_query=params, + **_extra_kwargs, + ) + + # Alias for `copy` for nicer inline usage, e.g. + # client.with_options(timeout=10).foo.create(...) + with_options = copy + + @override + def _make_status_error( + self, + err_msg: str, + *, + body: object, + response: httpx.Response, + ) -> APIStatusError: + if response.status_code == 400: + return _exceptions.BadRequestError(err_msg, response=response, body=body) + + if response.status_code == 401: + return _exceptions.AuthenticationError(err_msg, response=response, body=body) + + if response.status_code == 403: + return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) + + if response.status_code == 404: + return _exceptions.NotFoundError(err_msg, response=response, body=body) + + if response.status_code == 409: + return _exceptions.ConflictError(err_msg, response=response, body=body) + + if response.status_code == 422: + return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) + + if response.status_code == 429: + return _exceptions.RateLimitError(err_msg, response=response, body=body) + + if response.status_code >= 500: + return _exceptions.InternalServerError(err_msg, response=response, body=body) + return APIStatusError(err_msg, response=response, body=body) + + +class AsyncWriter(AsyncAPIClient): + # client options + api_key: str + + def __init__( + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | Timeout | None | NotGiven = not_given, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + # Configure a custom httpx client. + # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. + # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. + http_client: httpx.AsyncClient | None = None, + # Enable or disable schema validation for data returned by the API. + # When enabled an error APIResponseValidationError is raised + # if the API responds with invalid data for the expected schema. + # + # This parameter may be removed or changed in the future. + # If you rely on this feature, please open a GitHub issue + # outlining your use-case to help us decide if it should be + # part of our public interface in the future. + _strict_response_validation: bool = False, + ) -> None: + """Construct a new async AsyncWriter client instance. + + This automatically infers the `api_key` argument from the `WRITER_API_KEY` environment variable if it is not provided. + """ + if api_key is None: + api_key = os.environ.get("WRITER_API_KEY") + if api_key is None: + raise WriterError( + "The api_key client option must be set either by passing api_key to the client or by setting the WRITER_API_KEY environment variable" + ) + self.api_key = api_key + + if base_url is None: + base_url = os.environ.get("WRITER_BASE_URL") + if base_url is None: + base_url = f"https://api.writer.com" + + custom_headers_env = os.environ.get("WRITER_CUSTOM_HEADERS") + if custom_headers_env is not None: + parsed: dict[str, str] = {} + for line in custom_headers_env.split("\n"): + colon = line.find(":") + if colon >= 0: + parsed[line[:colon].strip()] = line[colon + 1 :].strip() + default_headers = {**parsed, **(default_headers if is_mapping_t(default_headers) else {})} + + super().__init__( + version=__version__, + base_url=base_url, + max_retries=max_retries, + timeout=timeout, + http_client=http_client, + custom_headers=default_headers, + custom_query=default_query, + _strict_response_validation=_strict_response_validation, + ) + + self._default_stream_cls = AsyncStream + + @cached_property + def applications(self) -> AsyncApplicationsResource: + from .resources.applications import AsyncApplicationsResource + + return AsyncApplicationsResource(self) + + @cached_property + def chat(self) -> AsyncChatResource: + from .resources.chat import AsyncChatResource + + return AsyncChatResource(self) + + @cached_property + def completions(self) -> AsyncCompletionsResource: + from .resources.completions import AsyncCompletionsResource + + return AsyncCompletionsResource(self) + + @cached_property + def models(self) -> AsyncModelsResource: + from .resources.models import AsyncModelsResource + + return AsyncModelsResource(self) + + @cached_property + def graphs(self) -> AsyncGraphsResource: + from .resources.graphs import AsyncGraphsResource + + return AsyncGraphsResource(self) + + @cached_property + def files(self) -> AsyncFilesResource: + from .resources.files import AsyncFilesResource + + return AsyncFilesResource(self) + + @cached_property + def tools(self) -> AsyncToolsResource: + from .resources.tools import AsyncToolsResource + + return AsyncToolsResource(self) + + @cached_property + def translation(self) -> AsyncTranslationResource: + from .resources.translation import AsyncTranslationResource + + return AsyncTranslationResource(self) + + @cached_property + def vision(self) -> AsyncVisionResource: + from .resources.vision import AsyncVisionResource + + return AsyncVisionResource(self) + + @cached_property + def with_raw_response(self) -> AsyncWriterWithRawResponse: + return AsyncWriterWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncWriterWithStreamedResponse: + return AsyncWriterWithStreamedResponse(self) + + @property + @override + def qs(self) -> Querystring: + return Querystring(array_format="comma") + + @property + @override + def auth_headers(self) -> dict[str, str]: + api_key = self.api_key + return {"Authorization": f"Bearer {api_key}"} + + @property + @override + def default_headers(self) -> dict[str, str | Omit]: + return { + **super().default_headers, + "X-Stainless-Async": f"async:{get_async_library()}", + **self._custom_headers, + } + + def copy( + self, + *, + api_key: str | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | Timeout | None | NotGiven = not_given, + http_client: httpx.AsyncClient | None = None, + max_retries: int | NotGiven = not_given, + default_headers: Mapping[str, str] | None = None, + set_default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + set_default_query: Mapping[str, object] | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: + """ + Create a new client instance re-using the same options given to the current client with optional overriding. + """ + if default_headers is not None and set_default_headers is not None: + raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") + + if default_query is not None and set_default_query is not None: + raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive") + + headers = self._custom_headers + if default_headers is not None: + headers = {**headers, **default_headers} + elif set_default_headers is not None: + headers = set_default_headers + + params = self._custom_query + if default_query is not None: + params = {**params, **default_query} + elif set_default_query is not None: + params = set_default_query + + http_client = http_client or self._client + return self.__class__( + api_key=api_key or self.api_key, + base_url=base_url or self.base_url, + timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, + http_client=http_client, + max_retries=max_retries if is_given(max_retries) else self.max_retries, + default_headers=headers, + default_query=params, + **_extra_kwargs, + ) + + # Alias for `copy` for nicer inline usage, e.g. + # client.with_options(timeout=10).foo.create(...) + with_options = copy + + @override + def _make_status_error( + self, + err_msg: str, + *, + body: object, + response: httpx.Response, + ) -> APIStatusError: + if response.status_code == 400: + return _exceptions.BadRequestError(err_msg, response=response, body=body) + + if response.status_code == 401: + return _exceptions.AuthenticationError(err_msg, response=response, body=body) + + if response.status_code == 403: + return _exceptions.PermissionDeniedError(err_msg, response=response, body=body) + + if response.status_code == 404: + return _exceptions.NotFoundError(err_msg, response=response, body=body) + + if response.status_code == 409: + return _exceptions.ConflictError(err_msg, response=response, body=body) + + if response.status_code == 422: + return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body) + + if response.status_code == 429: + return _exceptions.RateLimitError(err_msg, response=response, body=body) + + if response.status_code >= 500: + return _exceptions.InternalServerError(err_msg, response=response, body=body) + return APIStatusError(err_msg, response=response, body=body) + + +class WriterWithRawResponse: + _client: Writer + + def __init__(self, client: Writer) -> None: + self._client = client + + @cached_property + def applications(self) -> applications.ApplicationsResourceWithRawResponse: + from .resources.applications import ApplicationsResourceWithRawResponse + + return ApplicationsResourceWithRawResponse(self._client.applications) + + @cached_property + def chat(self) -> chat.ChatResourceWithRawResponse: + from .resources.chat import ChatResourceWithRawResponse + + return ChatResourceWithRawResponse(self._client.chat) + + @cached_property + def completions(self) -> completions.CompletionsResourceWithRawResponse: + from .resources.completions import CompletionsResourceWithRawResponse + + return CompletionsResourceWithRawResponse(self._client.completions) + + @cached_property + def models(self) -> models.ModelsResourceWithRawResponse: + from .resources.models import ModelsResourceWithRawResponse + + return ModelsResourceWithRawResponse(self._client.models) + + @cached_property + def graphs(self) -> graphs.GraphsResourceWithRawResponse: + from .resources.graphs import GraphsResourceWithRawResponse + + return GraphsResourceWithRawResponse(self._client.graphs) + + @cached_property + def files(self) -> files.FilesResourceWithRawResponse: + from .resources.files import FilesResourceWithRawResponse + + return FilesResourceWithRawResponse(self._client.files) + + @cached_property + def tools(self) -> tools.ToolsResourceWithRawResponse: + from .resources.tools import ToolsResourceWithRawResponse + + return ToolsResourceWithRawResponse(self._client.tools) + + @cached_property + def translation(self) -> translation.TranslationResourceWithRawResponse: + from .resources.translation import TranslationResourceWithRawResponse + + return TranslationResourceWithRawResponse(self._client.translation) + + @cached_property + def vision(self) -> vision.VisionResourceWithRawResponse: + from .resources.vision import VisionResourceWithRawResponse + + return VisionResourceWithRawResponse(self._client.vision) + + +class AsyncWriterWithRawResponse: + _client: AsyncWriter + + def __init__(self, client: AsyncWriter) -> None: + self._client = client + + @cached_property + def applications(self) -> applications.AsyncApplicationsResourceWithRawResponse: + from .resources.applications import AsyncApplicationsResourceWithRawResponse + + return AsyncApplicationsResourceWithRawResponse(self._client.applications) + + @cached_property + def chat(self) -> chat.AsyncChatResourceWithRawResponse: + from .resources.chat import AsyncChatResourceWithRawResponse + + return AsyncChatResourceWithRawResponse(self._client.chat) + + @cached_property + def completions(self) -> completions.AsyncCompletionsResourceWithRawResponse: + from .resources.completions import AsyncCompletionsResourceWithRawResponse + + return AsyncCompletionsResourceWithRawResponse(self._client.completions) + + @cached_property + def models(self) -> models.AsyncModelsResourceWithRawResponse: + from .resources.models import AsyncModelsResourceWithRawResponse + + return AsyncModelsResourceWithRawResponse(self._client.models) + + @cached_property + def graphs(self) -> graphs.AsyncGraphsResourceWithRawResponse: + from .resources.graphs import AsyncGraphsResourceWithRawResponse + + return AsyncGraphsResourceWithRawResponse(self._client.graphs) + + @cached_property + def files(self) -> files.AsyncFilesResourceWithRawResponse: + from .resources.files import AsyncFilesResourceWithRawResponse + + return AsyncFilesResourceWithRawResponse(self._client.files) + + @cached_property + def tools(self) -> tools.AsyncToolsResourceWithRawResponse: + from .resources.tools import AsyncToolsResourceWithRawResponse + + return AsyncToolsResourceWithRawResponse(self._client.tools) + + @cached_property + def translation(self) -> translation.AsyncTranslationResourceWithRawResponse: + from .resources.translation import AsyncTranslationResourceWithRawResponse + + return AsyncTranslationResourceWithRawResponse(self._client.translation) + + @cached_property + def vision(self) -> vision.AsyncVisionResourceWithRawResponse: + from .resources.vision import AsyncVisionResourceWithRawResponse + + return AsyncVisionResourceWithRawResponse(self._client.vision) + + +class WriterWithStreamedResponse: + _client: Writer + + def __init__(self, client: Writer) -> None: + self._client = client + + @cached_property + def applications(self) -> applications.ApplicationsResourceWithStreamingResponse: + from .resources.applications import ApplicationsResourceWithStreamingResponse + + return ApplicationsResourceWithStreamingResponse(self._client.applications) + + @cached_property + def chat(self) -> chat.ChatResourceWithStreamingResponse: + from .resources.chat import ChatResourceWithStreamingResponse + + return ChatResourceWithStreamingResponse(self._client.chat) + + @cached_property + def completions(self) -> completions.CompletionsResourceWithStreamingResponse: + from .resources.completions import CompletionsResourceWithStreamingResponse + + return CompletionsResourceWithStreamingResponse(self._client.completions) + + @cached_property + def models(self) -> models.ModelsResourceWithStreamingResponse: + from .resources.models import ModelsResourceWithStreamingResponse + + return ModelsResourceWithStreamingResponse(self._client.models) + + @cached_property + def graphs(self) -> graphs.GraphsResourceWithStreamingResponse: + from .resources.graphs import GraphsResourceWithStreamingResponse + + return GraphsResourceWithStreamingResponse(self._client.graphs) + + @cached_property + def files(self) -> files.FilesResourceWithStreamingResponse: + from .resources.files import FilesResourceWithStreamingResponse + + return FilesResourceWithStreamingResponse(self._client.files) + + @cached_property + def tools(self) -> tools.ToolsResourceWithStreamingResponse: + from .resources.tools import ToolsResourceWithStreamingResponse + + return ToolsResourceWithStreamingResponse(self._client.tools) + + @cached_property + def translation(self) -> translation.TranslationResourceWithStreamingResponse: + from .resources.translation import TranslationResourceWithStreamingResponse + + return TranslationResourceWithStreamingResponse(self._client.translation) + + @cached_property + def vision(self) -> vision.VisionResourceWithStreamingResponse: + from .resources.vision import VisionResourceWithStreamingResponse + + return VisionResourceWithStreamingResponse(self._client.vision) + + +class AsyncWriterWithStreamedResponse: + _client: AsyncWriter + + def __init__(self, client: AsyncWriter) -> None: + self._client = client + + @cached_property + def applications(self) -> applications.AsyncApplicationsResourceWithStreamingResponse: + from .resources.applications import AsyncApplicationsResourceWithStreamingResponse + + return AsyncApplicationsResourceWithStreamingResponse(self._client.applications) + + @cached_property + def chat(self) -> chat.AsyncChatResourceWithStreamingResponse: + from .resources.chat import AsyncChatResourceWithStreamingResponse + + return AsyncChatResourceWithStreamingResponse(self._client.chat) + + @cached_property + def completions(self) -> completions.AsyncCompletionsResourceWithStreamingResponse: + from .resources.completions import AsyncCompletionsResourceWithStreamingResponse + + return AsyncCompletionsResourceWithStreamingResponse(self._client.completions) + + @cached_property + def models(self) -> models.AsyncModelsResourceWithStreamingResponse: + from .resources.models import AsyncModelsResourceWithStreamingResponse + + return AsyncModelsResourceWithStreamingResponse(self._client.models) + + @cached_property + def graphs(self) -> graphs.AsyncGraphsResourceWithStreamingResponse: + from .resources.graphs import AsyncGraphsResourceWithStreamingResponse + + return AsyncGraphsResourceWithStreamingResponse(self._client.graphs) + + @cached_property + def files(self) -> files.AsyncFilesResourceWithStreamingResponse: + from .resources.files import AsyncFilesResourceWithStreamingResponse + + return AsyncFilesResourceWithStreamingResponse(self._client.files) + + @cached_property + def tools(self) -> tools.AsyncToolsResourceWithStreamingResponse: + from .resources.tools import AsyncToolsResourceWithStreamingResponse + + return AsyncToolsResourceWithStreamingResponse(self._client.tools) + + @cached_property + def translation(self) -> translation.AsyncTranslationResourceWithStreamingResponse: + from .resources.translation import AsyncTranslationResourceWithStreamingResponse + + return AsyncTranslationResourceWithStreamingResponse(self._client.translation) + + @cached_property + def vision(self) -> vision.AsyncVisionResourceWithStreamingResponse: + from .resources.vision import AsyncVisionResourceWithStreamingResponse + + return AsyncVisionResourceWithStreamingResponse(self._client.vision) + + +Client = Writer + +AsyncClient = AsyncWriter diff --git a/src/writerai/_compat.py b/src/writerai/_compat.py new file mode 100644 index 00000000..e6690a4f --- /dev/null +++ b/src/writerai/_compat.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload +from datetime import date, datetime +from typing_extensions import Self, Literal, TypedDict + +import pydantic +from pydantic.fields import FieldInfo + +from ._types import IncEx, StrBytesIntFloat + +_T = TypeVar("_T") +_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) + +# --------------- Pydantic v2, v3 compatibility --------------- + +# Pyright incorrectly reports some of our functions as overriding a method when they don't +# pyright: reportIncompatibleMethodOverride=false + +PYDANTIC_V1 = pydantic.VERSION.startswith("1.") + +if TYPE_CHECKING: + + def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001 + ... + + def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001 + ... + + def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001 + ... + + def is_union(tp: type[Any] | None) -> bool: # noqa: ARG001 + ... + + def get_origin(t: type[Any]) -> type[Any] | None: # noqa: ARG001 + ... + + def is_literal_type(type_: type[Any]) -> bool: # noqa: ARG001 + ... + + def is_typeddict(type_: type[Any]) -> bool: # noqa: ARG001 + ... + +else: + # v1 re-exports + if PYDANTIC_V1: + from pydantic.typing import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, + ) + from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime + else: + from ._utils import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + parse_date as parse_date, + is_typeddict as is_typeddict, + parse_datetime as parse_datetime, + is_literal_type as is_literal_type, + ) + + +# refactored config +if TYPE_CHECKING: + from pydantic import ConfigDict as ConfigDict +else: + if PYDANTIC_V1: + # TODO: provide an error message here? + ConfigDict = None + else: + from pydantic import ConfigDict as ConfigDict + + +# renamed methods / properties +def parse_obj(model: type[_ModelT], value: object) -> _ModelT: + if PYDANTIC_V1: + return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + else: + return model.model_validate(value) + + +def field_is_required(field: FieldInfo) -> bool: + if PYDANTIC_V1: + return field.required # type: ignore + return field.is_required() + + +def field_get_default(field: FieldInfo) -> Any: + value = field.get_default() + if PYDANTIC_V1: + return value + from pydantic_core import PydanticUndefined + + if value == PydanticUndefined: + return None + return value + + +def field_outer_type(field: FieldInfo) -> Any: + if PYDANTIC_V1: + return field.outer_type_ # type: ignore + return field.annotation + + +def get_model_config(model: type[pydantic.BaseModel]) -> Any: + if PYDANTIC_V1: + return model.__config__ # type: ignore + return model.model_config + + +def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]: + if PYDANTIC_V1: + return model.__fields__ # type: ignore + return model.model_fields + + +def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT: + if PYDANTIC_V1: + return model.copy(deep=deep) # type: ignore + return model.model_copy(deep=deep) + + +def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: + if PYDANTIC_V1: + return model.json(indent=indent) # type: ignore + return model.model_dump_json(indent=indent) + + +class _ModelDumpKwargs(TypedDict, total=False): + by_alias: bool + + +def model_dump( + model: pydantic.BaseModel, + *, + exclude: IncEx | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + warnings: bool = True, + mode: Literal["json", "python"] = "python", + by_alias: bool | None = None, +) -> dict[str, Any]: + if (not PYDANTIC_V1) or hasattr(model, "model_dump"): + kwargs: _ModelDumpKwargs = {} + if by_alias is not None: + kwargs["by_alias"] = by_alias + return model.model_dump( + mode=mode, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + # warnings are not supported in Pydantic v1 + warnings=True if PYDANTIC_V1 else warnings, + **kwargs, + ) + return cast( + "dict[str, Any]", + model.dict( # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, by_alias=bool(by_alias) + ), + ) + + +def model_parse(model: type[_ModelT], data: Any) -> _ModelT: + if PYDANTIC_V1: + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + return model.model_validate(data) + + +# generic models +if TYPE_CHECKING: + + class GenericModel(pydantic.BaseModel): ... + +else: + if PYDANTIC_V1: + import pydantic.generics + + class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ... + else: + # there no longer needs to be a distinction in v2 but + # we still have to create our own subclass to avoid + # inconsistent MRO ordering errors + class GenericModel(pydantic.BaseModel): ... + + +# cached properties +if TYPE_CHECKING: + cached_property = property + + # we define a separate type (copied from typeshed) + # that represents that `cached_property` is `set`able + # at runtime, which differs from `@property`. + # + # this is a separate type as editors likely special case + # `@property` and we don't want to cause issues just to have + # more helpful internal types. + + class typed_cached_property(Generic[_T]): + func: Callable[[Any], _T] + attrname: str | None + + def __init__(self, func: Callable[[Any], _T]) -> None: ... + + @overload + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ... + + @overload + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ... + + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: + raise NotImplementedError() + + def __set_name__(self, owner: type[Any], name: str) -> None: ... + + # __set__ is not defined at runtime, but @cached_property is designed to be settable + def __set__(self, instance: object, value: _T) -> None: ... +else: + from functools import cached_property as cached_property + + typed_cached_property = cached_property diff --git a/src/writerai/_constants.py b/src/writerai/_constants.py new file mode 100644 index 00000000..9a4c97ab --- /dev/null +++ b/src/writerai/_constants.py @@ -0,0 +1,14 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +import httpx + +RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" +OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to" + +# default timeout is 3 minutes +DEFAULT_TIMEOUT = httpx.Timeout(timeout=180, connect=5.0) +DEFAULT_MAX_RETRIES = 7 +DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) + +INITIAL_RETRY_DELAY = 1.0 +MAX_RETRY_DELAY = 60.0 diff --git a/src/writerai/_exceptions.py b/src/writerai/_exceptions.py new file mode 100644 index 00000000..684af996 --- /dev/null +++ b/src/writerai/_exceptions.py @@ -0,0 +1,108 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal + +import httpx + +__all__ = [ + "BadRequestError", + "AuthenticationError", + "PermissionDeniedError", + "NotFoundError", + "ConflictError", + "UnprocessableEntityError", + "RateLimitError", + "InternalServerError", +] + + +class WriterError(Exception): + pass + + +class APIError(WriterError): + message: str + request: httpx.Request + + body: object | None + """The API response body. + + If the API responded with a valid JSON structure then this property will be the + decoded result. + + If it isn't a valid JSON structure then this will be the raw response. + + If there was no response associated with this error then it will be `None`. + """ + + def __init__(self, message: str, request: httpx.Request, *, body: object | None) -> None: # noqa: ARG002 + super().__init__(message) + self.request = request + self.message = message + self.body = body + + +class APIResponseValidationError(APIError): + response: httpx.Response + status_code: int + + def __init__(self, response: httpx.Response, body: object | None, *, message: str | None = None) -> None: + super().__init__(message or "Data returned by API invalid for expected schema.", response.request, body=body) + self.response = response + self.status_code = response.status_code + + +class APIStatusError(APIError): + """Raised when an API response has a status code of 4xx or 5xx.""" + + response: httpx.Response + status_code: int + + def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None: + super().__init__(message, response.request, body=body) + self.response = response + self.status_code = response.status_code + + +class APIConnectionError(APIError): + def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None: + super().__init__(message, request, body=None) + + +class APITimeoutError(APIConnectionError): + def __init__(self, request: httpx.Request) -> None: + super().__init__(message="Request timed out.", request=request) + + +class BadRequestError(APIStatusError): + status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride] + + +class AuthenticationError(APIStatusError): + status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride] + + +class PermissionDeniedError(APIStatusError): + status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride] + + +class NotFoundError(APIStatusError): + status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride] + + +class ConflictError(APIStatusError): + status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride] + + +class UnprocessableEntityError(APIStatusError): + status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride] + + +class RateLimitError(APIStatusError): + status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride] + + +class InternalServerError(APIStatusError): + pass diff --git a/src/writerai/_files.py b/src/writerai/_files.py new file mode 100644 index 00000000..76da9e08 --- /dev/null +++ b/src/writerai/_files.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import io +import os +import pathlib +from typing import Sequence, cast, overload +from typing_extensions import TypeVar, TypeGuard + +import anyio + +from ._types import ( + FileTypes, + FileContent, + RequestFiles, + HttpxFileTypes, + Base64FileInput, + HttpxFileContent, + HttpxRequestFiles, +) +from ._utils import is_list, is_mapping, is_tuple_t, is_mapping_t, is_sequence_t + +_T = TypeVar("_T") + + +def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: + return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + + +def is_file_content(obj: object) -> TypeGuard[FileContent]: + return ( + isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + ) + + +def assert_is_file_content(obj: object, *, key: str | None = None) -> None: + if not is_file_content(obj): + prefix = f"Expected entry at `{key}`" if key is not None else f"Expected file input `{obj!r}`" + raise RuntimeError( + f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(obj)} instead." + ) from None + + +@overload +def to_httpx_files(files: None) -> None: ... + + +@overload +def to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... + + +def to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: + if files is None: + return None + + if is_mapping_t(files): + files = {key: _transform_file(file) for key, file in files.items()} + elif is_sequence_t(files): + files = [(key, _transform_file(file)) for key, file in files] + else: + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") + + return files + + +def _transform_file(file: FileTypes) -> HttpxFileTypes: + if is_file_content(file): + if isinstance(file, os.PathLike): + path = pathlib.Path(file) + return (path.name, path.read_bytes()) + + return file + + if is_tuple_t(file): + return (file[0], read_file_content(file[1]), *file[2:]) + + raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + + +def read_file_content(file: FileContent) -> HttpxFileContent: + if isinstance(file, os.PathLike): + return pathlib.Path(file).read_bytes() + return file + + +@overload +async def async_to_httpx_files(files: None) -> None: ... + + +@overload +async def async_to_httpx_files(files: RequestFiles) -> HttpxRequestFiles: ... + + +async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None: + if files is None: + return None + + if is_mapping_t(files): + files = {key: await _async_transform_file(file) for key, file in files.items()} + elif is_sequence_t(files): + files = [(key, await _async_transform_file(file)) for key, file in files] + else: + raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence") + + return files + + +async def _async_transform_file(file: FileTypes) -> HttpxFileTypes: + if is_file_content(file): + if isinstance(file, os.PathLike): + path = anyio.Path(file) + return (path.name, await path.read_bytes()) + + return file + + if is_tuple_t(file): + return (file[0], await async_read_file_content(file[1]), *file[2:]) + + raise TypeError(f"Expected file types input to be a FileContent type or to be a tuple") + + +async def async_read_file_content(file: FileContent) -> HttpxFileContent: + if isinstance(file, os.PathLike): + return await anyio.Path(file).read_bytes() + + return file + + +def deepcopy_with_paths(item: _T, paths: Sequence[Sequence[str]]) -> _T: + """Copy only the containers along the given paths. + + Used to guard against mutation by extract_files without copying the entire structure. + Only dicts and lists that lie on a path are copied; everything else + is returned by reference. + + For example, given paths=[["foo", "files", "file"]] and the structure: + { + "foo": { + "bar": {"baz": {}}, + "files": {"file": } + } + } + The root dict, "foo", and "files" are copied (they lie on the path). + "bar" and "baz" are returned by reference (off the path). + """ + return _deepcopy_with_paths(item, paths, 0) + + +def _deepcopy_with_paths(item: _T, paths: Sequence[Sequence[str]], index: int) -> _T: + if not paths: + return item + if is_mapping(item): + key_to_paths: dict[str, list[Sequence[str]]] = {} + for path in paths: + if index < len(path): + key_to_paths.setdefault(path[index], []).append(path) + + # if no path continues through this mapping, it won't be mutated and copying it is redundant + if not key_to_paths: + return item + + result = dict(item) + for key, subpaths in key_to_paths.items(): + if key in result: + result[key] = _deepcopy_with_paths(result[key], subpaths, index + 1) + return cast(_T, result) + if is_list(item): + array_paths = [path for path in paths if index < len(path) and path[index] == ""] + + # if no path expects a list here, nothing will be mutated inside it - return by reference + if not array_paths: + return cast(_T, item) + return cast(_T, [_deepcopy_with_paths(entry, array_paths, index + 1) for entry in item]) + return item diff --git a/src/writerai/_models.py b/src/writerai/_models.py new file mode 100644 index 00000000..8c5ab260 --- /dev/null +++ b/src/writerai/_models.py @@ -0,0 +1,952 @@ +from __future__ import annotations + +import os +import inspect +import weakref +from typing import ( + IO, + TYPE_CHECKING, + Any, + Type, + Union, + Generic, + TypeVar, + Callable, + Iterable, + Optional, + AsyncIterable, + cast, +) +from datetime import date, datetime +from typing_extensions import ( + List, + Unpack, + Literal, + ClassVar, + Protocol, + Required, + Annotated, + ParamSpec, + TypeAlias, + TypedDict, + TypeGuard, + final, + override, + runtime_checkable, +) + +import pydantic +from pydantic.fields import FieldInfo + +from ._types import ( + Body, + IncEx, + Query, + ModelT, + Headers, + Timeout, + NotGiven, + AnyMapping, + HttpxRequestFiles, +) +from ._utils import ( + PropertyInfo, + is_list, + is_given, + json_safe, + lru_cache, + is_mapping, + parse_date, + coerce_boolean, + parse_datetime, + strip_not_given, + extract_type_arg, + is_annotated_type, + is_type_alias_type, + strip_annotated_type, +) +from ._compat import ( + PYDANTIC_V1, + ConfigDict, + GenericModel as BaseGenericModel, + get_args, + is_union, + parse_obj, + get_origin, + is_literal_type, + get_model_config, + get_model_fields, + field_get_default, +) +from ._constants import RAW_RESPONSE_HEADER + +if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler, ValidatorFunctionWrapHandler + from pydantic_core import CoreSchema, core_schema + from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema +else: + try: + from pydantic_core import CoreSchema, core_schema + except ImportError: + CoreSchema = None + core_schema = None + +__all__ = ["BaseModel", "GenericModel"] + +_T = TypeVar("_T") +_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel") + +P = ParamSpec("P") + + +@runtime_checkable +class _ConfigProtocol(Protocol): + allow_population_by_field_name: bool + + +class BaseModel(pydantic.BaseModel): + if PYDANTIC_V1: + + @property + @override + def model_fields_set(self) -> set[str]: + # a forwards-compat shim for pydantic v2 + return self.__fields_set__ # type: ignore + + class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] + extra: Any = pydantic.Extra.allow # type: ignore + else: + model_config: ClassVar[ConfigDict] = ConfigDict( + extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true")) + ) + + def to_dict( + self, + *, + mode: Literal["json", "python"] = "python", + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> dict[str, object]: + """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + mode: + If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`. + If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)` + + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2. + """ + return self.model_dump( + mode=mode, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + def to_json( + self, + *, + indent: int | None = 2, + use_api_names: bool = True, + exclude_unset: bool = True, + exclude_defaults: bool = False, + exclude_none: bool = False, + warnings: bool = True, + ) -> str: + """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation). + + By default, fields that were not set by the API will not be included, + and keys will match the API response, *not* the property names from the model. + + For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property, + the output will use the `"fooBar"` key (unless `use_api_names=False` is passed). + + Args: + indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2` + use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. + warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2. + """ + return self.model_dump_json( + indent=indent, + by_alias=use_api_names, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + warnings=warnings, + ) + + @override + def __str__(self) -> str: + # mypy complains about an invalid self arg + return f"{self.__repr_name__()}({self.__repr_str__(', ')})" # type: ignore[misc] + + # Override the 'construct' method in a way that supports recursive parsing without validation. + # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. + @classmethod + @override + def construct( # pyright: ignore[reportIncompatibleMethodOverride] + __cls: Type[ModelT], + _fields_set: set[str] | None = None, + **values: object, + ) -> ModelT: + m = __cls.__new__(__cls) + fields_values: dict[str, object] = {} + + config = get_model_config(__cls) + populate_by_name = ( + config.allow_population_by_field_name + if isinstance(config, _ConfigProtocol) + else config.get("populate_by_name") + ) + + if _fields_set is None: + _fields_set = set() + + model_fields = get_model_fields(__cls) + for name, field in model_fields.items(): + key = field.alias + if key is None or (key not in values and populate_by_name): + key = name + + if key in values: + fields_values[name] = _construct_field(value=values[key], field=field, key=key) + _fields_set.add(name) + else: + fields_values[name] = field_get_default(field) + + extra_field_type = _get_extra_fields_type(__cls) + + _extra = {} + for key, value in values.items(): + if key not in model_fields: + parsed = construct_type(value=value, type_=extra_field_type) if extra_field_type is not None else value + + if PYDANTIC_V1: + _fields_set.add(key) + fields_values[key] = parsed + else: + _extra[key] = parsed + + object.__setattr__(m, "__dict__", fields_values) + + if PYDANTIC_V1: + # init_private_attributes() does not exist in v2 + m._init_private_attributes() # type: ignore + + # copied from Pydantic v1's `construct()` method + object.__setattr__(m, "__fields_set__", _fields_set) + else: + # these properties are copied from Pydantic's `model_construct()` method + object.__setattr__(m, "__pydantic_private__", None) + object.__setattr__(m, "__pydantic_extra__", _extra) + object.__setattr__(m, "__pydantic_fields_set__", _fields_set) + + return m + + if not TYPE_CHECKING: + # type checkers incorrectly complain about this assignment + # because the type signatures are technically different + # although not in practice + model_construct = construct + + if PYDANTIC_V1: + # we define aliases for some of the new pydantic v2 methods so + # that we can just document these methods without having to specify + # a specific pydantic version as some users may not know which + # pydantic version they are currently using + + @override + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: IncEx | None = None, + exclude: IncEx | None = None, + context: Any | None = None, + by_alias: bool | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_computed_fields: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, + ) -> dict[str, Any]: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump + + Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + Args: + mode: The mode in which `to_python` should run. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + include: A set of fields to include in the output. + exclude: A set of fields to exclude from the output. + context: Additional context to pass to the serializer. + by_alias: Whether to use the field's alias in the dictionary key if defined. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + exclude_computed_fields: Whether to exclude computed fields. + While this can be useful for round-tripping, it is usually recommended to use the dedicated + `round_trip` parameter instead. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, + "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + fallback: A function to call when an unknown value is encountered. If not provided, + a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + + Returns: + A dictionary representation of the model. + """ + if mode not in {"json", "python"}: + raise ValueError("mode must be either 'json' or 'python'") + if round_trip != False: + raise ValueError("round_trip is only supported in Pydantic v2") + if warnings != True: + raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") + dumped = super().dict( # pyright: ignore[reportDeprecated] + include=include, + exclude=exclude, + by_alias=by_alias if by_alias is not None else False, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + return cast("dict[str, Any]", json_safe(dumped)) if mode == "json" else dumped + + @override + def model_dump_json( + self, + *, + indent: int | None = None, + ensure_ascii: bool = False, + include: IncEx | None = None, + exclude: IncEx | None = None, + context: Any | None = None, + by_alias: bool | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_computed_fields: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, + ) -> str: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json + + Generates a JSON representation of the model using Pydantic's `to_json` method. + + Args: + indent: Indentation to use in the JSON output. If None is passed, the output will be compact. + include: Field(s) to include in the JSON output. Can take either a string or set of strings. + exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. + by_alias: Whether to serialize using field aliases. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that have the default value. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: Whether to use serialization/deserialization between JSON and class instance. + warnings: Whether to show any warnings that occurred during serialization. + + Returns: + A JSON string representation of the model. + """ + if round_trip != False: + raise ValueError("round_trip is only supported in Pydantic v2") + if warnings != True: + raise ValueError("warnings is only supported in Pydantic v2") + if context is not None: + raise ValueError("context is only supported in Pydantic v2") + if serialize_as_any != False: + raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") + if ensure_ascii != False: + raise ValueError("ensure_ascii is only supported in Pydantic v2") + if exclude_computed_fields != False: + raise ValueError("exclude_computed_fields is only supported in Pydantic v2") + return super().json( # type: ignore[reportDeprecated] + indent=indent, + include=include, + exclude=exclude, + by_alias=by_alias if by_alias is not None else False, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + +class _EagerIterable(list[_T], Generic[_T]): + """ + Accepts any Iterable[T] input (including generators), consumes it + eagerly, and validates all items upfront. + + Validation preserves the original container type where possible + (e.g. a set[T] stays a set[T]). Serialization (model_dump / JSON) + always emits a list — round-tripping through model_dump() will not + restore the original container type. + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + (item_type,) = get_args(source_type) or (Any,) + item_schema: CoreSchema = handler.generate_schema(item_type) + list_of_items_schema: CoreSchema = core_schema.list_schema(item_schema) + + return core_schema.no_info_wrap_validator_function( + cls._validate, + list_of_items_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + cls._serialize, + info_arg=False, + ), + ) + + @staticmethod + def _validate(v: Iterable[_T], handler: "ValidatorFunctionWrapHandler") -> Any: + original_type: type[Any] = type(v) + + # Normalize to list so list_schema can validate each item + if isinstance(v, list): + items: list[_T] = v + else: + try: + items = list(v) + except TypeError as e: + raise TypeError("Value is not iterable") from e + + # Validate items against the inner schema + validated: list[_T] = handler(items) + + # Reconstruct original container type + if original_type is list: + return validated + # str(list) produces the list's repr, not a string built from items, + # so skip reconstruction for str and its subclasses. + if issubclass(original_type, str): + return validated + try: + return original_type(validated) + except (TypeError, ValueError): + # If the type cannot be reconstructed, just return the validated list + return validated + + @staticmethod + def _serialize(v: Iterable[_T]) -> list[_T]: + """Always serialize as a list so Pydantic's JSON encoder is happy.""" + if isinstance(v, list): + return v + return list(v) + + +EagerIterable: TypeAlias = Annotated[Iterable[_T], _EagerIterable] + + +def _construct_field(value: object, field: FieldInfo, key: str) -> object: + if value is None: + return field_get_default(field) + + if PYDANTIC_V1: + type_ = cast(type, field.outer_type_) # type: ignore + else: + type_ = field.annotation # type: ignore + + if type_ is None: + raise RuntimeError(f"Unexpected field type is None for {key}") + + return construct_type(value=value, type_=type_, metadata=getattr(field, "metadata", None)) + + +def _get_extra_fields_type(cls: type[pydantic.BaseModel]) -> type | None: + if PYDANTIC_V1: + # TODO + return None + + schema = cls.__pydantic_core_schema__ + if schema["type"] == "model": + fields = schema["schema"] + if fields["type"] == "model-fields": + extras = fields.get("extras_schema") + if extras and "cls" in extras: + # mypy can't narrow the type + return extras["cls"] # type: ignore[no-any-return] + + return None + + +def is_basemodel(type_: type) -> bool: + """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" + if is_union(type_): + for variant in get_args(type_): + if is_basemodel(variant): + return True + + return False + + return is_basemodel_type(type_) + + +def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: + origin = get_origin(type_) or type_ + if not inspect.isclass(origin): + return False + return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) + + +def build( + base_model_cls: Callable[P, _BaseModelT], + *args: P.args, + **kwargs: P.kwargs, +) -> _BaseModelT: + """Construct a BaseModel class without validation. + + This is useful for cases where you need to instantiate a `BaseModel` + from an API response as this provides type-safe params which isn't supported + by helpers like `construct_type()`. + + ```py + build(MyModel, my_field_a="foo", my_field_b=123) + ``` + """ + if args: + raise TypeError( + "Received positional arguments which are not supported; Keyword arguments must be used instead", + ) + + return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs)) + + +def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T: + """Loose coercion to the expected type with construction of nested values. + + Note: the returned value from this function is not guaranteed to match the + given type. + """ + return cast(_T, construct_type(value=value, type_=type_)) + + +def construct_type(*, value: object, type_: object, metadata: Optional[List[Any]] = None) -> object: + """Loose coercion to the expected type with construction of nested values. + + If the given value does not match the expected type then it is returned as-is. + """ + + # store a reference to the original type we were given before we extract any inner + # types so that we can properly resolve forward references in `TypeAliasType` annotations + original_type = None + + # we allow `object` as the input type because otherwise, passing things like + # `Literal['value']` will be reported as a type error by type checkers + type_ = cast("type[object]", type_) + if is_type_alias_type(type_): + original_type = type_ # type: ignore[unreachable] + type_ = type_.__value__ # type: ignore[unreachable] + + # unwrap `Annotated[T, ...]` -> `T` + if metadata is not None and len(metadata) > 0: + meta: tuple[Any, ...] = tuple(metadata) + elif is_annotated_type(type_): + meta = get_args(type_)[1:] + type_ = extract_type_arg(type_, 0) + else: + meta = tuple() + + # we need to use the origin class for any types that are subscripted generics + # e.g. Dict[str, object] + origin = get_origin(type_) or type_ + args = get_args(type_) + + if is_union(origin): + try: + return validate_type(type_=cast("type[object]", original_type or type_), value=value) + except Exception: + pass + + # if the type is a discriminated union then we want to construct the right variant + # in the union, even if the data doesn't match exactly, otherwise we'd break code + # that relies on the constructed class types, e.g. + # + # class FooType: + # kind: Literal['foo'] + # value: str + # + # class BarType: + # kind: Literal['bar'] + # value: int + # + # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then + # we'd end up constructing `FooType` when it should be `BarType`. + discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + if discriminator and is_mapping(value): + variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + if variant_value and isinstance(variant_value, str): + variant_type = discriminator.mapping.get(variant_value) + if variant_type: + return construct_type(type_=variant_type, value=value) + + # if the data is not valid, use the first variant that doesn't fail while deserializing + for variant in args: + try: + return construct_type(value=value, type_=variant) + except Exception: + continue + + raise RuntimeError(f"Could not convert data into a valid instance of {type_}") + + if origin == dict: + if not is_mapping(value): + return value + + _, items_type = get_args(type_) # Dict[_, items_type] + return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} + + if ( + not is_literal_type(type_) + and inspect.isclass(origin) + and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)) + ): + if is_list(value): + return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] + + if is_mapping(value): + if issubclass(type_, BaseModel): + return type_.construct(**value) # type: ignore[arg-type] + + return cast(Any, type_).construct(**value) + + if origin == list: + if not is_list(value): + return value + + inner_type = args[0] # List[inner_type] + return [construct_type(value=entry, type_=inner_type) for entry in value] + + if origin == float: + if isinstance(value, int): + coerced = float(value) + if coerced != value: + return value + return coerced + + return value + + if type_ == datetime: + try: + return parse_datetime(value) # type: ignore + except Exception: + return value + + if type_ == date: + try: + return parse_date(value) # type: ignore + except Exception: + return value + + return value + + +@runtime_checkable +class CachedDiscriminatorType(Protocol): + __discriminator__: DiscriminatorDetails + + +DISCRIMINATOR_CACHE: weakref.WeakKeyDictionary[type, DiscriminatorDetails] = weakref.WeakKeyDictionary() + + +class DiscriminatorDetails: + field_name: str + """The name of the discriminator field in the variant class, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] + ``` + + Will result in field_name='type' + """ + + field_alias_from: str | None + """The name of the discriminator field in the API response, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] = Field(alias='type_from_api') + ``` + + Will result in field_alias_from='type_from_api' + """ + + mapping: dict[str, type] + """Mapping of discriminator value to variant type, e.g. + + {'foo': FooVariant, 'bar': BarVariant} + """ + + def __init__( + self, + *, + mapping: dict[str, type], + discriminator_field: str, + discriminator_alias: str | None, + ) -> None: + self.mapping = mapping + self.field_name = discriminator_field + self.field_alias_from = discriminator_alias + + +def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: + cached = DISCRIMINATOR_CACHE.get(union) + if cached is not None: + return cached + + discriminator_field_name: str | None = None + + for annotation in meta_annotations: + if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + discriminator_field_name = annotation.discriminator + break + + if not discriminator_field_name: + return None + + mapping: dict[str, type] = {} + discriminator_alias: str | None = None + + for variant in get_args(union): + variant = strip_annotated_type(variant) + if is_basemodel_type(variant): + if PYDANTIC_V1: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field_info.alias + + if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): + for entry in get_args(annotation): + if isinstance(entry, str): + mapping[entry] = variant + else: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in cast("LiteralSchema", field_schema)["expected"]: + if isinstance(entry, str): + mapping[entry] = variant + + if not mapping: + return None + + details = DiscriminatorDetails( + mapping=mapping, + discriminator_field=discriminator_field_name, + discriminator_alias=discriminator_alias, + ) + DISCRIMINATOR_CACHE.setdefault(union, details) + return details + + +def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: + schema = model.__pydantic_core_schema__ + if schema["type"] == "definitions": + schema = schema["schema"] + + if schema["type"] != "model": + return None + + schema = cast("ModelSchema", schema) + fields_schema = schema["schema"] + if fields_schema["type"] != "model-fields": + return None + + fields_schema = cast("ModelFieldsSchema", fields_schema) + field = fields_schema["fields"].get(field_name) + if not field: + return None + + return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] + + +def validate_type(*, type_: type[_T], value: object) -> _T: + """Strict validation that the given value matches the expected type""" + if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): + return cast(_T, parse_obj(type_, value)) + + return cast(_T, _validate_non_model_type(type_=type_, value=value)) + + +def set_pydantic_config(typ: Any, config: pydantic.ConfigDict) -> None: + """Add a pydantic config for the given type. + + Note: this is a no-op on Pydantic v1. + """ + setattr(typ, "__pydantic_config__", config) # noqa: B010 + + +# our use of subclassing here causes weirdness for type checkers, +# so we just pretend that we don't subclass +if TYPE_CHECKING: + GenericModel = BaseModel +else: + + class GenericModel(BaseGenericModel, BaseModel): + pass + + +if not PYDANTIC_V1: + from pydantic import TypeAdapter as _TypeAdapter + + _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter)) + + if TYPE_CHECKING: + from pydantic import TypeAdapter + else: + TypeAdapter = _CachedTypeAdapter + + def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: + return TypeAdapter(type_).validate_python(value) + +elif not TYPE_CHECKING: # TODO: condition is weird + + class RootModel(GenericModel, Generic[_T]): + """Used as a placeholder to easily convert runtime types to a Pydantic format + to provide validation. + + For example: + ```py + validated = RootModel[int](__root__="5").__root__ + # validated: 5 + ``` + """ + + __root__: _T + + def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: + model = _create_pydantic_model(type_).validate(value) + return cast(_T, model.__root__) + + def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]: + return RootModel[type_] # type: ignore + + +class FinalRequestOptionsInput(TypedDict, total=False): + method: Required[str] + url: Required[str] + params: Query + headers: Headers + max_retries: int + timeout: float | Timeout | None + files: HttpxRequestFiles | None + idempotency_key: str + content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] + json_data: Body + extra_json: AnyMapping + follow_redirects: bool + + +@final +class FinalRequestOptions(pydantic.BaseModel): + method: str + url: str + params: Query = {} + headers: Union[Headers, NotGiven] = NotGiven() + max_retries: Union[int, NotGiven] = NotGiven() + timeout: Union[float, Timeout, None, NotGiven] = NotGiven() + files: Union[HttpxRequestFiles, None] = None + idempotency_key: Union[str, None] = None + post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() + follow_redirects: Union[bool, None] = None + + content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None + # It should be noted that we cannot use `json` here as that would override + # a BaseModel method in an incompatible fashion. + json_data: Union[Body, None] = None + extra_json: Union[AnyMapping, None] = None + + if PYDANTIC_V1: + + class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] + arbitrary_types_allowed: bool = True + else: + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + def get_max_retries(self, max_retries: int) -> int: + if isinstance(self.max_retries, NotGiven): + return max_retries + return self.max_retries + + def _strip_raw_response_header(self) -> None: + if not is_given(self.headers): + return + + if self.headers.get(RAW_RESPONSE_HEADER): + self.headers = {**self.headers} + self.headers.pop(RAW_RESPONSE_HEADER) + + # override the `construct` method so that we can run custom transformations. + # this is necessary as we don't want to do any actual runtime type checking + # (which means we can't use validators) but we do want to ensure that `NotGiven` + # values are not present + # + # type ignore required because we're adding explicit types to `**values` + @classmethod + def construct( # type: ignore + cls, + _fields_set: set[str] | None = None, + **values: Unpack[FinalRequestOptionsInput], + ) -> FinalRequestOptions: + kwargs: dict[str, Any] = { + # we unconditionally call `strip_not_given` on any value + # as it will just ignore any non-mapping types + key: strip_not_given(value) + for key, value in values.items() + } + if PYDANTIC_V1: + return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] + return super().model_construct(_fields_set, **kwargs) + + if not TYPE_CHECKING: + # type checkers incorrectly complain about this assignment + model_construct = construct diff --git a/src/writerai/_qs.py b/src/writerai/_qs.py new file mode 100644 index 00000000..4127c19c --- /dev/null +++ b/src/writerai/_qs.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import Any, List, Tuple, Union, Mapping, TypeVar +from urllib.parse import parse_qs, urlencode +from typing_extensions import get_args + +from ._types import NotGiven, ArrayFormat, NestedFormat, not_given +from ._utils import flatten + +_T = TypeVar("_T") + +PrimitiveData = Union[str, int, float, bool, None] +# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"] +# https://github.com/microsoft/pyright/issues/3555 +Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] +Params = Mapping[str, Data] + + +class Querystring: + array_format: ArrayFormat + nested_format: NestedFormat + + def __init__( + self, + *, + array_format: ArrayFormat = "repeat", + nested_format: NestedFormat = "brackets", + ) -> None: + self.array_format = array_format + self.nested_format = nested_format + + def parse(self, query: str) -> Mapping[str, object]: + # Note: custom format syntax is not supported yet + return parse_qs(query) + + def stringify( + self, + params: Params, + *, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, + ) -> str: + return urlencode( + self.stringify_items( + params, + array_format=array_format, + nested_format=nested_format, + ) + ) + + def stringify_items( + self, + params: Params, + *, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, + ) -> list[tuple[str, str]]: + opts = Options( + qs=self, + array_format=array_format, + nested_format=nested_format, + ) + return flatten([self._stringify_item(key, value, opts) for key, value in params.items()]) + + def _stringify_item( + self, + key: str, + value: Data, + opts: Options, + ) -> list[tuple[str, str]]: + if isinstance(value, Mapping): + items: list[tuple[str, str]] = [] + nested_format = opts.nested_format + for subkey, subvalue in value.items(): + items.extend( + self._stringify_item( + # TODO: error if unknown format + f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]", + subvalue, + opts, + ) + ) + return items + + if isinstance(value, (list, tuple)): + array_format = opts.array_format + if array_format == "comma": + return [ + ( + key, + ",".join(self._primitive_value_to_str(item) for item in value if item is not None), + ), + ] + elif array_format == "repeat": + items = [] + for item in value: + items.extend(self._stringify_item(key, item, opts)) + return items + elif array_format == "indices": + items = [] + for i, item in enumerate(value): + items.extend(self._stringify_item(f"{key}[{i}]", item, opts)) + return items + elif array_format == "brackets": + items = [] + key = key + "[]" + for item in value: + items.extend(self._stringify_item(key, item, opts)) + return items + else: + raise NotImplementedError( + f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}" + ) + + serialised = self._primitive_value_to_str(value) + if not serialised: + return [] + return [(key, serialised)] + + def _primitive_value_to_str(self, value: PrimitiveData) -> str: + # copied from httpx + if value is True: + return "true" + elif value is False: + return "false" + elif value is None: + return "" + return str(value) + + +_qs = Querystring() +parse = _qs.parse +stringify = _qs.stringify +stringify_items = _qs.stringify_items + + +class Options: + array_format: ArrayFormat + nested_format: NestedFormat + + def __init__( + self, + qs: Querystring = _qs, + *, + array_format: ArrayFormat | NotGiven = not_given, + nested_format: NestedFormat | NotGiven = not_given, + ) -> None: + self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format + self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format diff --git a/src/writerai/_resource.py b/src/writerai/_resource.py new file mode 100644 index 00000000..b62a9f28 --- /dev/null +++ b/src/writerai/_resource.py @@ -0,0 +1,43 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +import anyio + +if TYPE_CHECKING: + from ._client import Writer, AsyncWriter + + +class SyncAPIResource: + _client: Writer + + def __init__(self, client: Writer) -> None: + self._client = client + self._get = client.get + self._post = client.post + self._patch = client.patch + self._put = client.put + self._delete = client.delete + self._get_api_list = client.get_api_list + + def _sleep(self, seconds: float) -> None: + time.sleep(seconds) + + +class AsyncAPIResource: + _client: AsyncWriter + + def __init__(self, client: AsyncWriter) -> None: + self._client = client + self._get = client.get + self._post = client.post + self._patch = client.patch + self._put = client.put + self._delete = client.delete + self._get_api_list = client.get_api_list + + async def _sleep(self, seconds: float) -> None: + await anyio.sleep(seconds) diff --git a/src/writerai/_response.py b/src/writerai/_response.py new file mode 100644 index 00000000..814b0419 --- /dev/null +++ b/src/writerai/_response.py @@ -0,0 +1,833 @@ +from __future__ import annotations + +import os +import inspect +import logging +import datetime +import functools +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Union, + Generic, + TypeVar, + Callable, + Iterator, + AsyncIterator, + cast, + overload, +) +from typing_extensions import Awaitable, ParamSpec, override, get_origin + +import anyio +import httpx +import pydantic + +from ._types import NoneType +from ._utils import is_given, extract_type_arg, is_annotated_type, is_type_alias_type, extract_type_var_from_base +from ._models import BaseModel, is_basemodel +from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER +from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type +from ._exceptions import WriterError, APIResponseValidationError + +if TYPE_CHECKING: + from ._models import FinalRequestOptions + from ._base_client import BaseClient + + +P = ParamSpec("P") +R = TypeVar("R") +_T = TypeVar("_T") +_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") +_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]") + +log: logging.Logger = logging.getLogger(__name__) + + +class BaseAPIResponse(Generic[R]): + _cast_to: type[R] + _client: BaseClient[Any, Any] + _parsed_by_type: dict[type[Any], Any] + _is_sse_stream: bool + _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None + _options: FinalRequestOptions + + http_response: httpx.Response + + retries_taken: int + """The number of retries made. If no retries happened this will be `0`""" + + def __init__( + self, + *, + raw: httpx.Response, + cast_to: type[R], + client: BaseClient[Any, Any], + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + options: FinalRequestOptions, + retries_taken: int = 0, + ) -> None: + self._cast_to = cast_to + self._client = client + self._parsed_by_type = {} + self._is_sse_stream = stream + self._stream_cls = stream_cls + self._options = options + self.http_response = raw + self.retries_taken = retries_taken + + @property + def headers(self) -> httpx.Headers: + return self.http_response.headers + + @property + def http_request(self) -> httpx.Request: + """Returns the httpx Request instance associated with the current response.""" + return self.http_response.request + + @property + def status_code(self) -> int: + return self.http_response.status_code + + @property + def url(self) -> httpx.URL: + """Returns the URL for which the request was made.""" + return self.http_response.url + + @property + def method(self) -> str: + return self.http_request.method + + @property + def http_version(self) -> str: + return self.http_response.http_version + + @property + def elapsed(self) -> datetime.timedelta: + """The time taken for the complete request/response cycle to complete.""" + return self.http_response.elapsed + + @property + def is_closed(self) -> bool: + """Whether or not the response body has been closed. + + If this is False then there is response data that has not been read yet. + You must either fully consume the response body or call `.close()` + before discarding the response to prevent resource leaks. + """ + return self.http_response.is_closed + + @override + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>" + ) + + def _parse(self, *, to: type[_T] | None = None) -> R | _T: + cast_to = to if to is not None else self._cast_to + + # unwrap `TypeAlias('Name', T)` -> `T` + if is_type_alias_type(cast_to): + cast_to = cast_to.__value__ # type: ignore[unreachable] + + # unwrap `Annotated[T, ...]` -> `T` + if cast_to and is_annotated_type(cast_to): + cast_to = extract_type_arg(cast_to, 0) + + origin = get_origin(cast_to) or cast_to + + if self._is_sse_stream: + if to: + if not is_stream_class_type(to): + raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}") + + return cast( + _T, + to( + cast_to=extract_stream_chunk_type( + to, + failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]", + ), + response=self.http_response, + client=cast(Any, self._client), + options=self._options, + ), + ) + + if self._stream_cls: + return cast( + R, + self._stream_cls( + cast_to=extract_stream_chunk_type(self._stream_cls), + response=self.http_response, + client=cast(Any, self._client), + options=self._options, + ), + ) + + stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls) + if stream_cls is None: + raise MissingStreamClassError() + + return cast( + R, + stream_cls( + cast_to=cast_to, + response=self.http_response, + client=cast(Any, self._client), + options=self._options, + ), + ) + + if cast_to is NoneType: + return cast(R, None) + + response = self.http_response + if cast_to == str: + return cast(R, response.text) + + if cast_to == bytes: + return cast(R, response.content) + + if cast_to == int: + return cast(R, int(response.text)) + + if cast_to == float: + return cast(R, float(response.text)) + + if cast_to == bool: + return cast(R, response.text.lower() == "true") + + if origin == APIResponse: + raise RuntimeError("Unexpected state - cast_to is `APIResponse`") + + if inspect.isclass(origin) and issubclass(origin, httpx.Response): + # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response + # and pass that class to our request functions. We cannot change the variance to be either + # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct + # the response class ourselves but that is something that should be supported directly in httpx + # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. + if cast_to != httpx.Response: + raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`") + return cast(R, response) + + if ( + inspect.isclass( + origin # pyright: ignore[reportUnknownArgumentType] + ) + and not issubclass(origin, BaseModel) + and issubclass(origin, pydantic.BaseModel) + ): + raise TypeError("Pydantic models must subclass our base model type, e.g. `from writerai import BaseModel`") + + if ( + cast_to is not object + and not origin is list + and not origin is dict + and not origin is Union + and not issubclass(origin, BaseModel) + ): + raise RuntimeError( + f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." + ) + + # split is required to handle cases where additional information is included + # in the response, e.g. application/json; charset=utf-8 + content_type, *_ = response.headers.get("content-type", "*").split(";") + if not content_type.endswith("json"): + if is_basemodel(cast_to): + try: + data = response.json() + except Exception as exc: + log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) + else: + return self._client._process_response_data( + data=data, + cast_to=cast_to, # type: ignore + response=response, + ) + + if self._client._strict_response_validation: + raise APIResponseValidationError( + response=response, + message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", + body=response.text, + ) + + # If the API responds with content that isn't JSON then we just return + # the (decoded) text without performing any parsing so that you can still + # handle the response however you need to. + return response.text # type: ignore + + data = response.json() + + return self._client._process_response_data( + data=data, + cast_to=cast_to, # type: ignore + response=response, + ) + + +class APIResponse(BaseAPIResponse[R]): + @overload + def parse(self, *, to: type[_T]) -> _T: ... + + @overload + def parse(self) -> R: ... + + def parse(self, *, to: type[_T] | None = None) -> R | _T: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from writerai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `int` + - `float` + - `httpx.Response` + """ + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] + + if not self._is_sse_stream: + self.read() + + parsed = self._parse(to=to) + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed_by_type[cache_key] = parsed + return parsed + + def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return self.http_response.read() + except httpx.StreamConsumed as exc: + # The default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message. + raise StreamAlreadyConsumed() from exc + + def text(self) -> str: + """Read and decode the response content into a string.""" + self.read() + return self.http_response.text + + def json(self) -> object: + """Read and decode the JSON response content.""" + self.read() + return self.http_response.json() + + def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self.http_response.close() + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. + """ + for chunk in self.http_response.iter_bytes(chunk_size): + yield chunk + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + for chunk in self.http_response.iter_text(chunk_size): + yield chunk + + def iter_lines(self) -> Iterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + for chunk in self.http_response.iter_lines(): + yield chunk + + +class AsyncAPIResponse(BaseAPIResponse[R]): + @overload + async def parse(self, *, to: type[_T]) -> _T: ... + + @overload + async def parse(self) -> R: ... + + async def parse(self, *, to: type[_T] | None = None) -> R | _T: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from writerai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `httpx.Response` + """ + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] + + if not self._is_sse_stream: + await self.read() + + parsed = self._parse(to=to) + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed_by_type[cache_key] = parsed + return parsed + + async def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return await self.http_response.aread() + except httpx.StreamConsumed as exc: + # the default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message + raise StreamAlreadyConsumed() from exc + + async def text(self) -> str: + """Read and decode the response content into a string.""" + await self.read() + return self.http_response.text + + async def json(self) -> object: + """Read and decode the JSON response content.""" + await self.read() + return self.http_response.json() + + async def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.http_response.aclose() + + async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. + """ + async for chunk in self.http_response.aiter_bytes(chunk_size): + yield chunk + + async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + async for chunk in self.http_response.aiter_text(chunk_size): + yield chunk + + async def iter_lines(self) -> AsyncIterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + async for chunk in self.http_response.aiter_lines(): + yield chunk + + +class BinaryAPIResponse(APIResponse[bytes]): + """Subclass of APIResponse providing helpers for dealing with binary data. + + Note: If you want to stream the response data instead of eagerly reading it + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + + def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + + Note: if you want to stream the data to the file instead of writing + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + with open(file, mode="wb") as f: + for data in self.iter_bytes(): + f.write(data) + + +class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]): + """Subclass of APIResponse providing helpers for dealing with binary data. + + Note: If you want to stream the response data instead of eagerly reading it + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + + async def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + + Note: if you want to stream the data to the file instead of writing + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + path = anyio.Path(file) + async with await path.open(mode="wb") as f: + async for data in self.iter_bytes(): + await f.write(data) + + +class StreamedBinaryAPIResponse(APIResponse[bytes]): + def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + """Streams the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + """ + with open(file, mode="wb") as f: + for data in self.iter_bytes(chunk_size): + f.write(data) + + +class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]): + async def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + """Streams the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + """ + path = anyio.Path(file) + async with await path.open(mode="wb") as f: + async for data in self.iter_bytes(chunk_size): + await f.write(data) + + +class MissingStreamClassError(TypeError): + def __init__(self) -> None: + super().__init__( + "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `writerai._streaming` for reference", + ) + + +class StreamAlreadyConsumed(WriterError): + """ + Attempted to read or stream content, but the content has already + been streamed. + + This can happen if you use a method like `.iter_lines()` and then attempt + to read th entire response body afterwards, e.g. + + ```py + response = await client.post(...) + async for line in response.iter_lines(): + ... # do something with `line` + + content = await response.read() + # ^ error + ``` + + If you want this behaviour you'll need to either manually accumulate the response + content or call `await response.read()` before iterating over the stream. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. " + "This could be due to attempting to stream the response " + "content more than once." + "\n\n" + "You can fix this by manually accumulating the response content while streaming " + "or by calling `.read()` before starting to stream." + ) + super().__init__(message) + + +class ResponseContextManager(Generic[_APIResponseT]): + """Context manager for ensuring that a request is not made + until it is entered and that the response will always be closed + when the context manager exits + """ + + def __init__(self, request_func: Callable[[], _APIResponseT]) -> None: + self._request_func = request_func + self.__response: _APIResponseT | None = None + + def __enter__(self) -> _APIResponseT: + self.__response = self._request_func() + return self.__response + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__response is not None: + self.__response.close() + + +class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]): + """Context manager for ensuring that a request is not made + until it is entered and that the response will always be closed + when the context manager exits + """ + + def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None: + self._api_request = api_request + self.__response: _AsyncAPIResponseT | None = None + + async def __aenter__(self) -> _AsyncAPIResponseT: + self.__response = await self._api_request + return self.__response + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__response is not None: + await self.__response.close() + + +def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support streaming and returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]: + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + + kwargs["extra_headers"] = extra_headers + + make_request = functools.partial(func, *args, **kwargs) + + return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request)) + + return wrapped + + +def async_to_streamed_response_wrapper( + func: Callable[P, Awaitable[R]], +) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support streaming and returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]: + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + + kwargs["extra_headers"] = extra_headers + + make_request = func(*args, **kwargs) + + return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request)) + + return wrapped + + +def to_custom_streamed_response_wrapper( + func: Callable[P, object], + response_cls: type[_APIResponseT], +) -> Callable[P, ResponseContextManager[_APIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support streaming and returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]: + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + make_request = functools.partial(func, *args, **kwargs) + + return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request)) + + return wrapped + + +def async_to_custom_streamed_response_wrapper( + func: Callable[P, Awaitable[object]], + response_cls: type[_AsyncAPIResponseT], +) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support streaming and returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]: + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + make_request = func(*args, **kwargs) + + return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request)) + + return wrapped + + +def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]: + """Higher order function that takes one of our bound API methods and wraps it + to support returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + + kwargs["extra_headers"] = extra_headers + + return cast(APIResponse[R], func(*args, **kwargs)) + + return wrapped + + +def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]: + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + + kwargs["extra_headers"] = extra_headers + + return cast(AsyncAPIResponse[R], await func(*args, **kwargs)) + + return wrapped + + +def to_custom_raw_response_wrapper( + func: Callable[P, object], + response_cls: type[_APIResponseT], +) -> Callable[P, _APIResponseT]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT: + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + return cast(_APIResponseT, func(*args, **kwargs)) + + return wrapped + + +def async_to_custom_raw_response_wrapper( + func: Callable[P, Awaitable[object]], + response_cls: type[_AsyncAPIResponseT], +) -> Callable[P, Awaitable[_AsyncAPIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs)) + + return wrapped + + +def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: + """Given a type like `APIResponse[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(APIResponse[bytes]): + ... + + extract_response_type(MyResponse) -> bytes + ``` + """ + return extract_type_var_from_base( + typ, + generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)), + index=0, + ) diff --git a/src/writerai/_streaming.py b/src/writerai/_streaming.py new file mode 100644 index 00000000..389e0f31 --- /dev/null +++ b/src/writerai/_streaming.py @@ -0,0 +1,376 @@ +# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py +from __future__ import annotations + +import json +import inspect +from types import TracebackType +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, Optional, AsyncIterator, cast +from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable + +import httpx + +from ._utils import extract_type_var_from_base + +if TYPE_CHECKING: + from ._client import Writer, AsyncWriter + from ._models import FinalRequestOptions + + +_T = TypeVar("_T") + + +class Stream(Generic[_T]): + """Provides the core interface to iterate over a synchronous stream response.""" + + response: httpx.Response + _options: Optional[FinalRequestOptions] = None + _decoder: SSEBytesDecoder + + def __init__( + self, + *, + cast_to: type[_T], + response: httpx.Response, + client: Writer, + options: Optional[FinalRequestOptions] = None, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + self._options = options + self._decoder = client._make_sse_decoder() + self._iterator = self.__stream__() + + def __next__(self) -> _T: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[_T]: + for item in self._iterator: + yield item + + def _iter_events(self) -> Iterator[ServerSentEvent]: + yield from self._decoder.iter_bytes(self.response.iter_bytes()) + + def __stream__(self) -> Iterator[_T]: + cast_to = cast(Any, self._cast_to) + response = self.response + process_data = self._client._process_response_data + iterator = self._iter_events() + + try: + for sse in iterator: + if sse.data.startswith("[DONE]"): + break + + if sse.event is None: + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + if sse.event == "error": + body = sse.data + + try: + body = sse.json() + err_msg = f"{body}" + except Exception: + err_msg = sse.data or f"Error code: {response.status_code}" + + raise self._client._make_status_error( + err_msg, + body=body, + response=self.response, + ) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + response.close() + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self.response.close() + + +class AsyncStream(Generic[_T]): + """Provides the core interface to iterate over an asynchronous stream response.""" + + response: httpx.Response + _options: Optional[FinalRequestOptions] = None + _decoder: SSEDecoder | SSEBytesDecoder + + def __init__( + self, + *, + cast_to: type[_T], + response: httpx.Response, + client: AsyncWriter, + options: Optional[FinalRequestOptions] = None, + ) -> None: + self.response = response + self._cast_to = cast_to + self._client = client + self._options = options + self._decoder = client._make_sse_decoder() + self._iterator = self.__stream__() + + async def __anext__(self) -> _T: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[_T]: + async for item in self._iterator: + yield item + + async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse + + async def __stream__(self) -> AsyncIterator[_T]: + cast_to = cast(Any, self._cast_to) + response = self.response + process_data = self._client._process_response_data + iterator = self._iter_events() + + try: + async for sse in iterator: + if sse.data.startswith("[DONE]"): + break + + if sse.event is None: + yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + if sse.event == "error": + body = sse.data + + try: + body = sse.json() + err_msg = f"{body}" + except Exception: + err_msg = sse.data or f"Error code: {response.status_code}" + + raise self._client._make_status_error( + err_msg, + body=body, + response=self.response, + ) + finally: + # Ensure the response is closed even if the consumer doesn't read all data + await response.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + async def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.response.aclose() + + +class ServerSentEvent: + def __init__( + self, + *, + event: str | None = None, + data: str | None = None, + id: str | None = None, + retry: int | None = None, + ) -> None: + if data is None: + data = "" + + self._id = id + self._data = data + self._event = event or None + self._retry = retry + + @property + def event(self) -> str | None: + return self._event + + @property + def id(self) -> str | None: + return self._id + + @property + def retry(self) -> int | None: + return self._retry + + @property + def data(self) -> str: + return self._data + + def json(self) -> Any: + return json.loads(self.data) + + @override + def __repr__(self) -> str: + return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" + + +class SSEDecoder: + _data: list[str] + _event: str | None + _retry: int | None + _last_event_id: str | None + + def __init__(self) -> None: + self._event = None + self._data = [] + self._last_event_id = None + self._retry = None + + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + for chunk in self._iter_chunks(iterator): + # Split before decoding so splitlines() only uses \r and \n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]: + """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" + data = b"" + for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data + + async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + async for chunk in self._aiter_chunks(iterator): + # Split before decoding so splitlines() only uses \r and \n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]: + """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" + data = b"" + async for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data + + def decode(self, line: str) -> ServerSentEvent | None: + # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 + + if not line: + if not self._event and not self._data and not self._last_event_id and self._retry is None: + return None + + sse = ServerSentEvent( + event=self._event, + data="\n".join(self._data), + id=self._last_event_id, + retry=self._retry, + ) + + # NOTE: as per the SSE spec, do not reset last_event_id. + self._event = None + self._data = [] + self._retry = None + + return sse + + if line.startswith(":"): + return None + + fieldname, _, value = line.partition(":") + + if value.startswith(" "): + value = value[1:] + + if fieldname == "event": + self._event = value + elif fieldname == "data": + self._data.append(value) + elif fieldname == "id": + if "\0" in value: + pass + else: + self._last_event_id = value + elif fieldname == "retry": + try: + self._retry = int(value) + except (TypeError, ValueError): + pass + else: + pass # Field is ignored. + + return None + + +@runtime_checkable +class SSEBytesDecoder(Protocol): + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + +def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: + """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" + origin = get_origin(typ) or typ + return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream)) + + +def extract_stream_chunk_type( + stream_cls: type, + *, + failure_message: str | None = None, +) -> type: + """Given a type like `Stream[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyStream(Stream[bytes]): + ... + + extract_stream_chunk_type(MyStream) -> bytes + ``` + """ + from ._base_client import Stream, AsyncStream + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), + failure_message=failure_message, + ) diff --git a/src/writerai/_types.py b/src/writerai/_types.py new file mode 100644 index 00000000..2a661e32 --- /dev/null +++ b/src/writerai/_types.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +from os import PathLike +from typing import ( + IO, + TYPE_CHECKING, + Any, + Dict, + List, + Type, + Tuple, + Union, + Mapping, + TypeVar, + Callable, + Iterable, + Iterator, + Optional, + Sequence, + AsyncIterable, +) +from typing_extensions import ( + Set, + Literal, + Protocol, + TypeAlias, + TypedDict, + SupportsIndex, + overload, + override, + runtime_checkable, +) + +import httpx +import pydantic +from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport + +if TYPE_CHECKING: + from ._models import BaseModel + from ._response import APIResponse, AsyncAPIResponse + +Transport = BaseTransport +AsyncTransport = AsyncBaseTransport +Query = Mapping[str, object] +Body = object +AnyMapping = Mapping[str, object] +ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) +_T = TypeVar("_T") + +ArrayFormat = Literal["comma", "repeat", "indices", "brackets"] +NestedFormat = Literal["dots", "brackets"] + + +# Approximates httpx internal ProxiesTypes and RequestFiles types +# while adding support for `PathLike` instances +ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] +ProxiesTypes = Union[str, Proxy, ProxiesDict] +if TYPE_CHECKING: + Base64FileInput = Union[IO[bytes], PathLike[str]] + FileContent = Union[IO[bytes], bytes, PathLike[str]] +else: + Base64FileInput = Union[IO[bytes], PathLike] + FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. + + +# Used for sending raw binary data / streaming data in request bodies +# e.g. for file uploads without multipart encoding +BinaryTypes = Union[bytes, bytearray, IO[bytes], Iterable[bytes]] +AsyncBinaryTypes = Union[bytes, bytearray, IO[bytes], AsyncIterable[bytes]] + +FileTypes = Union[ + # file (or bytes) + FileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], FileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + +# duplicate of the above but without our custom file support +HttpxFileContent = Union[IO[bytes], bytes] +HttpxFileTypes = Union[ + # file (or bytes) + HttpxFileContent, + # (filename, file (or bytes)) + Tuple[Optional[str], HttpxFileContent], + # (filename, file (or bytes), content_type) + Tuple[Optional[str], HttpxFileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + Tuple[Optional[str], HttpxFileContent, Optional[str], Mapping[str, str]], +] +HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] + +# Workaround to support (cast_to: Type[ResponseT]) -> ResponseT +# where ResponseT includes `None`. In order to support directly +# passing `None`, overloads would have to be defined for every +# method that uses `ResponseT` which would lead to an unacceptable +# amount of code duplication and make it unreadable. See _base_client.py +# for example usage. +# +# This unfortunately means that you will either have +# to import this type and pass it explicitly: +# +# from writerai import NoneType +# client.get('/foo', cast_to=NoneType) +# +# or build it yourself: +# +# client.get('/foo', cast_to=type(None)) +if TYPE_CHECKING: + NoneType: Type[None] +else: + NoneType = type(None) + + +class RequestOptions(TypedDict, total=False): + headers: Headers + max_retries: int + timeout: float | Timeout | None + params: Query + extra_json: AnyMapping + idempotency_key: str + follow_redirects: bool + + +# Sentinel class used until PEP 0661 is accepted +class NotGiven: + """ + For parameters with a meaningful None value, we need to distinguish between + the user explicitly passing None, and the user not passing the parameter at + all. + + User code shouldn't need to use not_given directly. + + For example: + + ```py + def create(timeout: Timeout | None | NotGiven = not_given): ... + + + create(timeout=1) # 1s timeout + create(timeout=None) # No timeout + create() # Default timeout behavior + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + @override + def __repr__(self) -> str: + return "NOT_GIVEN" + + +not_given = NotGiven() +# for backwards compatibility: +NOT_GIVEN = NotGiven() + + +class Omit: + """ + To explicitly omit something from being sent in a request, use `omit`. + + ```py + # as the default `Content-Type` header is `application/json` that will be sent + client.post("/upload/files", files={"file": b"my raw file content"}) + + # you can't explicitly override the header as it has to be dynamically generated + # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' + client.post(..., headers={"Content-Type": "multipart/form-data"}) + + # instead you can remove the default `application/json` header by passing omit + client.post(..., headers={"Content-Type": omit}) + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + +omit = Omit() + + +@runtime_checkable +class ModelBuilderProtocol(Protocol): + @classmethod + def build( + cls: type[_T], + *, + response: Response, + data: object, + ) -> _T: ... + + +Headers = Mapping[str, Union[str, Omit]] + + +class HeadersLikeProtocol(Protocol): + def get(self, __key: str) -> str | None: ... + + +HeadersLike = Union[Headers, HeadersLikeProtocol] + +ResponseT = TypeVar( + "ResponseT", + bound=Union[ + object, + str, + None, + "BaseModel", + List[Any], + Dict[str, Any], + Response, + ModelBuilderProtocol, + "APIResponse[Any]", + "AsyncAPIResponse[Any]", + ], +) + +StrBytesIntFloat = Union[str, bytes, int, float] + +# Note: copied from Pydantic +# https://github.com/pydantic/pydantic/blob/6f31f8f68ef011f84357330186f603ff295312fd/pydantic/main.py#L79 +IncEx: TypeAlias = Union[Set[int], Set[str], Mapping[int, Union["IncEx", bool]], Mapping[str, Union["IncEx", bool]]] + +PostParser = Callable[[Any], Any] + + +@runtime_checkable +class InheritsGeneric(Protocol): + """Represents a type that has inherited from `Generic` + + The `__orig_bases__` property can be used to determine the resolved + type variable for a given base class. + """ + + __orig_bases__: tuple[_GenericAlias] + + +class _GenericAlias(Protocol): + __origin__: type[object] + + +class HttpxSendArgs(TypedDict, total=False): + auth: httpx.Auth + follow_redirects: bool + + +_T_co = TypeVar("_T_co", covariant=True) + + +if TYPE_CHECKING: + # This works because str.__contains__ does not accept object (either in typeshed or at runtime) + # https://github.com/hauntsaninja/useful_types/blob/5e9710f3875107d068e7679fd7fec9cfab0eff3b/useful_types/__init__.py#L285 + # + # Note: index() and count() methods are intentionally omitted to allow pyright to properly + # infer TypedDict types when dict literals are used in lists assigned to SequenceNotStr. + class SequenceNotStr(Protocol[_T_co]): + @overload + def __getitem__(self, index: SupportsIndex, /) -> _T_co: ... + @overload + def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ... + def __contains__(self, value: object, /) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __reversed__(self) -> Iterator[_T_co]: ... +else: + # just point this to a normal `Sequence` at runtime to avoid having to special case + # deserializing our custom sequence type + SequenceNotStr = Sequence diff --git a/src/writerai/_utils/__init__.py b/src/writerai/_utils/__init__.py new file mode 100644 index 00000000..1c090e51 --- /dev/null +++ b/src/writerai/_utils/__init__.py @@ -0,0 +1,64 @@ +from ._path import path_template as path_template +from ._sync import asyncify as asyncify +from ._proxy import LazyProxy as LazyProxy +from ._utils import ( + flatten as flatten, + is_dict as is_dict, + is_list as is_list, + is_given as is_given, + is_tuple as is_tuple, + json_safe as json_safe, + lru_cache as lru_cache, + is_mapping as is_mapping, + is_tuple_t as is_tuple_t, + is_iterable as is_iterable, + is_sequence as is_sequence, + coerce_float as coerce_float, + is_mapping_t as is_mapping_t, + removeprefix as removeprefix, + removesuffix as removesuffix, + extract_files as extract_files, + is_sequence_t as is_sequence_t, + required_args as required_args, + coerce_boolean as coerce_boolean, + coerce_integer as coerce_integer, + file_from_path as file_from_path, + strip_not_given as strip_not_given, + get_async_library as get_async_library, + maybe_coerce_float as maybe_coerce_float, + get_required_header as get_required_header, + maybe_coerce_boolean as maybe_coerce_boolean, + maybe_coerce_integer as maybe_coerce_integer, +) +from ._compat import ( + get_args as get_args, + is_union as is_union, + get_origin as get_origin, + is_typeddict as is_typeddict, + is_literal_type as is_literal_type, +) +from ._typing import ( + is_list_type as is_list_type, + is_union_type as is_union_type, + extract_type_arg as extract_type_arg, + is_iterable_type as is_iterable_type, + is_required_type as is_required_type, + is_sequence_type as is_sequence_type, + is_annotated_type as is_annotated_type, + is_type_alias_type as is_type_alias_type, + strip_annotated_type as strip_annotated_type, + extract_type_var_from_base as extract_type_var_from_base, +) +from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator +from ._transform import ( + PropertyInfo as PropertyInfo, + transform as transform, + async_transform as async_transform, + maybe_transform as maybe_transform, + async_maybe_transform as async_maybe_transform, +) +from ._reflection import ( + function_has_argument as function_has_argument, + assert_signatures_in_sync as assert_signatures_in_sync, +) +from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime diff --git a/src/writerai/_utils/_compat.py b/src/writerai/_utils/_compat.py new file mode 100644 index 00000000..2c70b299 --- /dev/null +++ b/src/writerai/_utils/_compat.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import sys +import typing_extensions +from typing import Any, Type, Union, Literal, Optional +from datetime import date, datetime +from typing_extensions import get_args as _get_args, get_origin as _get_origin + +from .._types import StrBytesIntFloat +from ._datetime_parse import parse_date as _parse_date, parse_datetime as _parse_datetime + +_LITERAL_TYPES = {Literal, typing_extensions.Literal} + + +def get_args(tp: type[Any]) -> tuple[Any, ...]: + return _get_args(tp) + + +def get_origin(tp: type[Any]) -> type[Any] | None: + return _get_origin(tp) + + +def is_union(tp: Optional[Type[Any]]) -> bool: + if sys.version_info < (3, 10): + return tp is Union # type: ignore[comparison-overlap] + else: + import types + + return tp is Union or tp is types.UnionType # type: ignore[comparison-overlap] + + +def is_typeddict(tp: Type[Any]) -> bool: + return typing_extensions.is_typeddict(tp) + + +def is_literal_type(tp: Type[Any]) -> bool: + return get_origin(tp) in _LITERAL_TYPES + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + return _parse_date(value) + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + return _parse_datetime(value) diff --git a/src/writerai/_utils/_datetime_parse.py b/src/writerai/_utils/_datetime_parse.py new file mode 100644 index 00000000..7cb9d9e6 --- /dev/null +++ b/src/writerai/_utils/_datetime_parse.py @@ -0,0 +1,136 @@ +""" +This file contains code from https://github.com/pydantic/pydantic/blob/main/pydantic/v1/datetime_parse.py +without the Pydantic v1 specific errors. +""" + +from __future__ import annotations + +import re +from typing import Dict, Union, Optional +from datetime import date, datetime, timezone, timedelta + +from .._types import StrBytesIntFloat + +date_expr = r"(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})" +time_expr = ( + r"(?P\d{1,2}):(?P\d{1,2})" + r"(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?" + r"(?PZ|[+-]\d{2}(?::?\d{2})?)?$" +) + +date_re = re.compile(f"{date_expr}$") +datetime_re = re.compile(f"{date_expr}[T ]{time_expr}") + + +EPOCH = datetime(1970, 1, 1) +# if greater than this, the number is in ms, if less than or equal it's in seconds +# (in seconds this is 11th October 2603, in ms it's 20th August 1970) +MS_WATERSHED = int(2e10) +# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9 +MAX_NUMBER = int(3e20) + + +def _get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]: + if isinstance(value, (int, float)): + return value + try: + return float(value) + except ValueError: + return None + except TypeError: + raise TypeError(f"invalid type; expected {native_expected_type}, string, bytes, int or float") from None + + +def _from_unix_seconds(seconds: Union[int, float]) -> datetime: + if seconds > MAX_NUMBER: + return datetime.max + elif seconds < -MAX_NUMBER: + return datetime.min + + while abs(seconds) > MS_WATERSHED: + seconds /= 1000 + dt = EPOCH + timedelta(seconds=seconds) + return dt.replace(tzinfo=timezone.utc) + + +def _parse_timezone(value: Optional[str]) -> Union[None, int, timezone]: + if value == "Z": + return timezone.utc + elif value is not None: + offset_mins = int(value[-2:]) if len(value) > 3 else 0 + offset = 60 * int(value[1:3]) + offset_mins + if value[0] == "-": + offset = -offset + return timezone(timedelta(minutes=offset)) + else: + return None + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + """ + Parse a datetime/int/float/string and return a datetime.datetime. + + This function supports time zone offsets. When the input contains one, + the output uses a timezone with a fixed offset from UTC. + + Raise ValueError if the input is well formatted but not a valid datetime. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, datetime): + return value + + number = _get_numeric(value, "datetime") + if number is not None: + return _from_unix_seconds(number) + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + + match = datetime_re.match(value) + if match is None: + raise ValueError("invalid datetime format") + + kw = match.groupdict() + if kw["microsecond"]: + kw["microsecond"] = kw["microsecond"].ljust(6, "0") + + tzinfo = _parse_timezone(kw.pop("tzinfo")) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_["tzinfo"] = tzinfo + + return datetime(**kw_) # type: ignore + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + """ + Parse a date/int/float/string and return a datetime.date. + + Raise ValueError if the input is well formatted but not a valid date. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, date): + if isinstance(value, datetime): + return value.date() + else: + return value + + number = _get_numeric(value, "date") + if number is not None: + return _from_unix_seconds(number).date() + + if isinstance(value, bytes): + value = value.decode() + + assert not isinstance(value, (float, int)) + match = date_re.match(value) + if match is None: + raise ValueError("invalid date format") + + kw = {k: int(v) for k, v in match.groupdict().items()} + + try: + return date(**kw) + except ValueError: + raise ValueError("invalid date format") from None diff --git a/src/writerai/_utils/_json.py b/src/writerai/_utils/_json.py new file mode 100644 index 00000000..60584214 --- /dev/null +++ b/src/writerai/_utils/_json.py @@ -0,0 +1,35 @@ +import json +from typing import Any +from datetime import datetime +from typing_extensions import override + +import pydantic + +from .._compat import model_dump + + +def openapi_dumps(obj: Any) -> bytes: + """ + Serialize an object to UTF-8 encoded JSON bytes. + + Extends the standard json.dumps with support for additional types + commonly used in the SDK, such as `datetime`, `pydantic.BaseModel`, etc. + """ + return json.dumps( + obj, + cls=_CustomEncoder, + # Uses the same defaults as httpx's JSON serialization + ensure_ascii=False, + separators=(",", ":"), + allow_nan=False, + ).encode() + + +class _CustomEncoder(json.JSONEncoder): + @override + def default(self, o: Any) -> Any: + if isinstance(o, datetime): + return o.isoformat() + if isinstance(o, pydantic.BaseModel): + return model_dump(o, exclude_unset=True, mode="json", by_alias=True) + return super().default(o) diff --git a/src/writerai/_utils/_logs.py b/src/writerai/_utils/_logs.py new file mode 100644 index 00000000..f9529c8f --- /dev/null +++ b/src/writerai/_utils/_logs.py @@ -0,0 +1,25 @@ +import os +import logging + +logger: logging.Logger = logging.getLogger("writerai") +httpx_logger: logging.Logger = logging.getLogger("httpx") + + +def _basic_config() -> None: + # e.g. [2023-10-05 14:12:26 - writerai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK" + logging.basicConfig( + format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + +def setup_logging() -> None: + env = os.environ.get("WRITER_LOG") + if env == "debug": + _basic_config() + logger.setLevel(logging.DEBUG) + httpx_logger.setLevel(logging.DEBUG) + elif env == "info": + _basic_config() + logger.setLevel(logging.INFO) + httpx_logger.setLevel(logging.INFO) diff --git a/src/writerai/_utils/_path.py b/src/writerai/_utils/_path.py new file mode 100644 index 00000000..4d6e1e4c --- /dev/null +++ b/src/writerai/_utils/_path.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import re +from typing import ( + Any, + Mapping, + Callable, +) +from urllib.parse import quote + +# Matches '.' or '..' where each dot is either literal or percent-encoded (%2e / %2E). +_DOT_SEGMENT_RE = re.compile(r"^(?:\.|%2[eE]){1,2}$") + +_PLACEHOLDER_RE = re.compile(r"\{(\w+)\}") + + +def _quote_path_segment_part(value: str) -> str: + """Percent-encode `value` for use in a URI path segment. + + Considers characters not in `pchar` set from RFC 3986 §3.3 to be unsafe. + https://datatracker.ietf.org/doc/html/rfc3986#section-3.3 + """ + # quote() already treats unreserved characters (letters, digits, and -._~) + # as safe, so we only need to add sub-delims, ':', and '@'. + # Notably, unlike the default `safe` for quote(), / is unsafe and must be quoted. + return quote(value, safe="!$&'()*+,;=:@") + + +def _quote_query_part(value: str) -> str: + """Percent-encode `value` for use in a URI query string. + + Considers &, = and characters not in `query` set from RFC 3986 §3.4 to be unsafe. + https://datatracker.ietf.org/doc/html/rfc3986#section-3.4 + """ + return quote(value, safe="!$'()*+,;:@/?") + + +def _quote_fragment_part(value: str) -> str: + """Percent-encode `value` for use in a URI fragment. + + Considers characters not in `fragment` set from RFC 3986 §3.5 to be unsafe. + https://datatracker.ietf.org/doc/html/rfc3986#section-3.5 + """ + return quote(value, safe="!$&'()*+,;=:@/?") + + +def _interpolate( + template: str, + values: Mapping[str, Any], + quoter: Callable[[str], str], +) -> str: + """Replace {name} placeholders in `template`, quoting each value with `quoter`. + + Placeholder names are looked up in `values`. + + Raises: + KeyError: If a placeholder is not found in `values`. + """ + # re.split with a capturing group returns alternating + # [text, name, text, name, ..., text] elements. + parts = _PLACEHOLDER_RE.split(template) + + for i in range(1, len(parts), 2): + name = parts[i] + if name not in values: + raise KeyError(f"a value for placeholder {{{name}}} was not provided") + val = values[name] + if val is None: + parts[i] = "null" + elif isinstance(val, bool): + parts[i] = "true" if val else "false" + else: + parts[i] = quoter(str(values[name])) + + return "".join(parts) + + +def path_template(template: str, /, **kwargs: Any) -> str: + """Interpolate {name} placeholders in `template` from keyword arguments. + + Args: + template: The template string containing {name} placeholders. + **kwargs: Keyword arguments to interpolate into the template. + + Returns: + The template with placeholders interpolated and percent-encoded. + + Safe characters for percent-encoding are dependent on the URI component. + Placeholders in path and fragment portions are percent-encoded where the `segment` + and `fragment` sets from RFC 3986 respectively are considered safe. + Placeholders in the query portion are percent-encoded where the `query` set from + RFC 3986 §3.3 is considered safe except for = and & characters. + + Raises: + KeyError: If a placeholder is not found in `kwargs`. + ValueError: If resulting path contains /./ or /../ segments (including percent-encoded dot-segments). + """ + # Split the template into path, query, and fragment portions. + fragment_template: str | None = None + query_template: str | None = None + + rest = template + if "#" in rest: + rest, fragment_template = rest.split("#", 1) + if "?" in rest: + rest, query_template = rest.split("?", 1) + path_template = rest + + # Interpolate each portion with the appropriate quoting rules. + path_result = _interpolate(path_template, kwargs, _quote_path_segment_part) + + # Reject dot-segments (. and ..) in the final assembled path. The check + # runs after interpolation so that adjacent placeholders or a mix of static + # text and placeholders that together form a dot-segment are caught. + # Also reject percent-encoded dot-segments to protect against incorrectly + # implemented normalization in servers/proxies. + for segment in path_result.split("/"): + if _DOT_SEGMENT_RE.match(segment): + raise ValueError(f"Constructed path {path_result!r} contains dot-segment {segment!r} which is not allowed") + + result = path_result + if query_template is not None: + result += "?" + _interpolate(query_template, kwargs, _quote_query_part) + if fragment_template is not None: + result += "#" + _interpolate(fragment_template, kwargs, _quote_fragment_part) + + return result diff --git a/src/writerai/_utils/_proxy.py b/src/writerai/_utils/_proxy.py new file mode 100644 index 00000000..0f239a33 --- /dev/null +++ b/src/writerai/_utils/_proxy.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar, Iterable, cast +from typing_extensions import override + +T = TypeVar("T") + + +class LazyProxy(Generic[T], ABC): + """Implements data methods to pretend that an instance is another instance. + + This includes forwarding attribute access and other methods. + """ + + # Note: we have to special case proxies that themselves return proxies + # to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz` + + def __getattr__(self, attr: str) -> object: + proxied = self.__get_proxied__() + if isinstance(proxied, LazyProxy): + return proxied # pyright: ignore + return getattr(proxied, attr) + + @override + def __repr__(self) -> str: + proxied = self.__get_proxied__() + if isinstance(proxied, LazyProxy): + return proxied.__class__.__name__ + return repr(self.__get_proxied__()) + + @override + def __str__(self) -> str: + proxied = self.__get_proxied__() + if isinstance(proxied, LazyProxy): + return proxied.__class__.__name__ + return str(proxied) + + @override + def __dir__(self) -> Iterable[str]: + proxied = self.__get_proxied__() + if isinstance(proxied, LazyProxy): + return [] + return proxied.__dir__() + + @property # type: ignore + @override + def __class__(self) -> type: # pyright: ignore + try: + proxied = self.__get_proxied__() + except Exception: + return type(self) + if issubclass(type(proxied), LazyProxy): + return type(proxied) + return proxied.__class__ + + def __get_proxied__(self) -> T: + return self.__load__() + + def __as_proxied__(self) -> T: + """Helper method that returns the current proxy, typed as the loaded object""" + return cast(T, self) + + @abstractmethod + def __load__(self) -> T: ... diff --git a/src/writerai/_utils/_reflection.py b/src/writerai/_utils/_reflection.py new file mode 100644 index 00000000..89aa712a --- /dev/null +++ b/src/writerai/_utils/_reflection.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import inspect +from typing import Any, Callable + + +def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool: + """Returns whether or not the given function has a specific parameter""" + sig = inspect.signature(func) + return arg_name in sig.parameters + + +def assert_signatures_in_sync( + source_func: Callable[..., Any], + check_func: Callable[..., Any], + *, + exclude_params: set[str] = set(), +) -> None: + """Ensure that the signature of the second function matches the first.""" + + check_sig = inspect.signature(check_func) + source_sig = inspect.signature(source_func) + + errors: list[str] = [] + + for name, source_param in source_sig.parameters.items(): + if name in exclude_params: + continue + + custom_param = check_sig.parameters.get(name) + if not custom_param: + errors.append(f"the `{name}` param is missing") + continue + + if custom_param.annotation != source_param.annotation: + errors.append( + f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(custom_param.annotation)}" + ) + continue + + if errors: + raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors)) diff --git a/src/writerai/_utils/_resources_proxy.py b/src/writerai/_utils/_resources_proxy.py new file mode 100644 index 00000000..200626ef --- /dev/null +++ b/src/writerai/_utils/_resources_proxy.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Any +from typing_extensions import override + +from ._proxy import LazyProxy + + +class ResourcesProxy(LazyProxy[Any]): + """A proxy for the `writerai.resources` module. + + This is used so that we can lazily import `writerai.resources` only when + needed *and* so that users can just import `writerai` and reference `writerai.resources` + """ + + @override + def __load__(self) -> Any: + import importlib + + mod = importlib.import_module("writerai.resources") + return mod + + +resources = ResourcesProxy().__as_proxied__() diff --git a/src/writerai/_utils/_streams.py b/src/writerai/_utils/_streams.py new file mode 100644 index 00000000..f4a0208f --- /dev/null +++ b/src/writerai/_utils/_streams.py @@ -0,0 +1,12 @@ +from typing import Any +from typing_extensions import Iterator, AsyncIterator + + +def consume_sync_iterator(iterator: Iterator[Any]) -> None: + for _ in iterator: + ... + + +async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None: + async for _ in iterator: + ... diff --git a/src/writerai/_utils/_sync.py b/src/writerai/_utils/_sync.py new file mode 100644 index 00000000..f6027c18 --- /dev/null +++ b/src/writerai/_utils/_sync.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import asyncio +import functools +from typing import TypeVar, Callable, Awaitable +from typing_extensions import ParamSpec + +import anyio +import sniffio +import anyio.to_thread + +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") + + +async def to_thread( + func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs +) -> T_Retval: + if sniffio.current_async_library() == "asyncio": + return await asyncio.to_thread(func, *args, **kwargs) + + return await anyio.to_thread.run_sync( + functools.partial(func, *args, **kwargs), + ) + + +# inspired by `asyncer`, https://github.com/tiangolo/asyncer +def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: + """ + Take a blocking function and create an async one that receives the same + positional and keyword arguments. + + Usage: + + ```python + def blocking_func(arg1, arg2, kwarg1=None): + # blocking code + return result + + + result = asyncify(blocking_function)(arg1, arg2, kwarg1=value1) + ``` + + ## Arguments + + `function`: a blocking regular callable (e.g. a function) + + ## Return + + An async function that takes the same positional and keyword arguments as the + original one, that when called runs the same original function in a thread worker + and returns the result. + """ + + async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: + return await to_thread(function, *args, **kwargs) + + return wrapper diff --git a/src/writerai/_utils/_transform.py b/src/writerai/_utils/_transform.py new file mode 100644 index 00000000..52075492 --- /dev/null +++ b/src/writerai/_utils/_transform.py @@ -0,0 +1,457 @@ +from __future__ import annotations + +import io +import base64 +import pathlib +from typing import Any, Mapping, TypeVar, cast +from datetime import date, datetime +from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints + +import anyio +import pydantic + +from ._utils import ( + is_list, + is_given, + lru_cache, + is_mapping, + is_iterable, + is_sequence, +) +from .._files import is_base64_file_input +from ._compat import get_origin, is_typeddict +from ._typing import ( + is_list_type, + is_union_type, + extract_type_arg, + is_iterable_type, + is_required_type, + is_sequence_type, + is_annotated_type, + strip_annotated_type, +) + +_T = TypeVar("_T") + + +# TODO: support for drilling globals() and locals() +# TODO: ensure works correctly with forward references in all cases + + +PropertyFormat = Literal["iso8601", "base64", "custom"] + + +class PropertyInfo: + """Metadata class to be used in Annotated types to provide information about a given type. + + For example: + + class MyParams(TypedDict): + account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')] + + This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API. + """ + + alias: str | None + format: PropertyFormat | None + format_template: str | None + discriminator: str | None + + def __init__( + self, + *, + alias: str | None = None, + format: PropertyFormat | None = None, + format_template: str | None = None, + discriminator: str | None = None, + ) -> None: + self.alias = alias + self.format = format + self.format_template = format_template + self.discriminator = discriminator + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" + + +def maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `transform()` that allows `None` to be passed. + + See `transform()` for more details. + """ + if data is None: + return None + return transform(data, expected_type) + + +# Wrapper over _transform_recursive providing fake types +def transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = _transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +@lru_cache(maxsize=8096) +def _get_annotated_type(type_: type) -> type | None: + """If the given type is an `Annotated` type then it is returned, if not `None` is returned. + + This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]` + """ + if is_required_type(type_): + # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]` + type_ = get_args(type_)[0] + + if is_annotated_type(type_): + return type_ + + return None + + +def _maybe_transform_key(key: str, type_: type) -> str: + """Transform the given `data` based on the annotations provided in `type_`. + + Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata. + """ + annotated_type = _get_annotated_type(type_) + if annotated_type is None: + # no `Annotated` definition for this type, no transformation needed + return key + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.alias is not None: + return annotation.alias + + return key + + +def _no_transform_needed(annotation: type) -> bool: + return annotation == float or annotation == int + + +def _transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + from .._compat import model_dump + + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + origin = get_origin(stripped_type) or stripped_type + if is_typeddict(stripped_type) and is_mapping(data): + return _transform_typeddict(data, stripped_type) + + if origin == dict and is_mapping(data): + items_type = get_args(stripped_type)[1] + return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) + ): + # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually + # intended as an iterable, so we don't transform it. + if isinstance(data, dict): + return cast(object, data) + + inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + + return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = _transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True, mode="json") + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return _format_data(data, annotation.format, annotation.format_template) + + return data + + +def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, (date, datetime)): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = data.read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +def _transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + if not is_given(value): + # we don't need to include omitted values here as they'll + # be stripped out before the request is sent anyway + continue + + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) + return result + + +async def async_maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `async_transform()` that allows `None` to be passed. + + See `async_transform()` for more details. + """ + if data is None: + return None + return await async_transform(data, expected_type) + + +async def async_transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +async def _async_transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + from .._compat import model_dump + + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + origin = get_origin(stripped_type) or stripped_type + if is_typeddict(stripped_type) and is_mapping(data): + return await _async_transform_typeddict(data, stripped_type) + + if origin == dict and is_mapping(data): + items_type = get_args(stripped_type)[1] + return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + # Sequence[T] + or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str)) + ): + # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually + # intended as an iterable, so we don't transform it. + if isinstance(data, dict): + return cast(object, data) + + inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True, mode="json") + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return await _async_format_data(data, annotation.format, annotation.format_template) + + return data + + +async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, (date, datetime)): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = await anyio.Path(data).read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +async def _async_transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + if not is_given(value): + # we don't need to include omitted values here as they'll + # be stripped out before the request is sent anyway + continue + + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + return result + + +@lru_cache(maxsize=8096) +def get_type_hints( + obj: Any, + globalns: dict[str, Any] | None = None, + localns: Mapping[str, Any] | None = None, + include_extras: bool = False, +) -> dict[str, Any]: + return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) diff --git a/src/writerai/_utils/_typing.py b/src/writerai/_utils/_typing.py new file mode 100644 index 00000000..193109f3 --- /dev/null +++ b/src/writerai/_utils/_typing.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import sys +import typing +import typing_extensions +from typing import Any, TypeVar, Iterable, cast +from collections import abc as _c_abc +from typing_extensions import ( + TypeIs, + Required, + Annotated, + get_args, + get_origin, +) + +from ._utils import lru_cache +from .._types import InheritsGeneric +from ._compat import is_union as _is_union + + +def is_annotated_type(typ: type) -> bool: + return get_origin(typ) == Annotated + + +def is_list_type(typ: type) -> bool: + return (get_origin(typ) or typ) == list + + +def is_sequence_type(typ: type) -> bool: + origin = get_origin(typ) or typ + return origin == typing_extensions.Sequence or origin == typing.Sequence or origin == _c_abc.Sequence + + +def is_iterable_type(typ: type) -> bool: + """If the given type is `typing.Iterable[T]`""" + origin = get_origin(typ) or typ + return origin == Iterable or origin == _c_abc.Iterable + + +def is_union_type(typ: type) -> bool: + return _is_union(get_origin(typ)) + + +def is_required_type(typ: type) -> bool: + return get_origin(typ) == Required + + +def is_typevar(typ: type) -> bool: + # type ignore is required because type checkers + # think this expression will always return False + return type(typ) == TypeVar # type: ignore + + +_TYPE_ALIAS_TYPES: tuple[type[typing_extensions.TypeAliasType], ...] = (typing_extensions.TypeAliasType,) +if sys.version_info >= (3, 12): + _TYPE_ALIAS_TYPES = (*_TYPE_ALIAS_TYPES, typing.TypeAliasType) + + +def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]: + """Return whether the provided argument is an instance of `TypeAliasType`. + + ```python + type Int = int + is_type_alias_type(Int) + # > True + Str = TypeAliasType("Str", str) + is_type_alias_type(Str) + # > True + ``` + """ + return isinstance(tp, _TYPE_ALIAS_TYPES) + + +# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +@lru_cache(maxsize=8096) +def strip_annotated_type(typ: type) -> type: + if is_required_type(typ) or is_annotated_type(typ): + return strip_annotated_type(cast(type, get_args(typ)[0])) + + return typ + + +def extract_type_arg(typ: type, index: int) -> type: + args = get_args(typ) + try: + return cast(type, args[index]) + except IndexError as err: + raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err + + +def extract_type_var_from_base( + typ: type, + *, + generic_bases: tuple[type, ...], + index: int, + failure_message: str | None = None, +) -> type: + """Given a type like `Foo[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(Foo[bytes]): + ... + + extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes + ``` + + And where a generic subclass is given: + ```py + _T = TypeVar('_T') + class MyResponse(Foo[_T]): + ... + + extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes + ``` + """ + cls = cast(object, get_origin(typ) or typ) + if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains] + # we're given the class directly + return extract_type_arg(typ, index) + + # if a subclass is given + # --- + # this is needed as __orig_bases__ is not present in the typeshed stubs + # because it is intended to be for internal use only, however there does + # not seem to be a way to resolve generic TypeVars for inherited subclasses + # without using it. + if isinstance(cls, InheritsGeneric): + target_base_class: Any | None = None + for base in cls.__orig_bases__: + if base.__origin__ in generic_bases: + target_base_class = base + break + + if target_base_class is None: + raise RuntimeError( + "Could not find the generic base class;\n" + "This should never happen;\n" + f"Does {cls} inherit from one of {generic_bases} ?" + ) + + extracted = extract_type_arg(target_base_class, index) + if is_typevar(extracted): + # If the extracted type argument is itself a type variable + # then that means the subclass itself is generic, so we have + # to resolve the type argument from the class itself, not + # the base class. + # + # Note: if there is more than 1 type argument, the subclass could + # change the ordering of the type arguments, this is not currently + # supported. + return extract_type_arg(typ, index) + + return extracted + + raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/src/writerai/_utils/_utils.py b/src/writerai/_utils/_utils.py new file mode 100644 index 00000000..199cd231 --- /dev/null +++ b/src/writerai/_utils/_utils.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +import os +import re +import inspect +import functools +from typing import ( + Any, + Tuple, + Mapping, + TypeVar, + Callable, + Iterable, + Sequence, + cast, + overload, +) +from pathlib import Path +from datetime import date, datetime +from typing_extensions import TypeGuard, get_args + +import sniffio + +from .._types import Omit, NotGiven, FileTypes, ArrayFormat, HeadersLike + +_T = TypeVar("_T") +_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...]) +_MappingT = TypeVar("_MappingT", bound=Mapping[str, object]) +_SequenceT = TypeVar("_SequenceT", bound=Sequence[object]) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) + + +def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: + return [item for sublist in t for item in sublist] + + +def extract_files( + # TODO: this needs to take Dict but variance issues..... + # create protocol type ? + query: Mapping[str, object], + *, + paths: Sequence[Sequence[str]], + array_format: ArrayFormat = "brackets", +) -> list[tuple[str, FileTypes]]: + """Recursively extract files from the given dictionary based on specified paths. + + A path may look like this ['foo', 'files', '', 'data']. + + ``array_format`` controls how ```` segments contribute to the emitted + field name. Supported values: ``"brackets"`` (``foo[]``), ``"repeat"`` and + ``"comma"`` (``foo``), ``"indices"`` (``foo[0]``, ``foo[1]``). + + Note: this mutates the given dictionary. + """ + files: list[tuple[str, FileTypes]] = [] + for path in paths: + files.extend(_extract_items(query, path, index=0, flattened_key=None, array_format=array_format)) + return files + + +def _array_suffix(array_format: ArrayFormat, array_index: int) -> str: + if array_format == "brackets": + return "[]" + if array_format == "indices": + return f"[{array_index}]" + if array_format == "repeat" or array_format == "comma": + # Both repeat the bare field name for each file part; there is no + # meaningful way to comma-join binary parts. + return "" + raise NotImplementedError( + f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}" + ) + + +def _extract_items( + obj: object, + path: Sequence[str], + *, + index: int, + flattened_key: str | None, + array_format: ArrayFormat, +) -> list[tuple[str, FileTypes]]: + try: + key = path[index] + except IndexError: + if not is_given(obj): + # no value was provided - we can safely ignore + return [] + + # cyclical import + from .._files import assert_is_file_content + + # We have exhausted the path, return the entry we found. + assert flattened_key is not None + + if is_list(obj): + files: list[tuple[str, FileTypes]] = [] + for array_index, entry in enumerate(obj): + suffix = _array_suffix(array_format, array_index) + emitted_key = (flattened_key + suffix) if flattened_key else suffix + assert_is_file_content(entry, key=emitted_key) + files.append((emitted_key, cast(FileTypes, entry))) + return files + + assert_is_file_content(obj, key=flattened_key) + return [(flattened_key, cast(FileTypes, obj))] + + index += 1 + if is_dict(obj): + try: + # Remove the field if there are no more dict keys in the path, + # only "" traversal markers or end. + if all(p == "" for p in path[index:]): + item = obj.pop(key) + else: + item = obj[key] + except KeyError: + # Key was not present in the dictionary, this is not indicative of an error + # as the given path may not point to a required field. We also do not want + # to enforce required fields as the API may differ from the spec in some cases. + return [] + if flattened_key is None: + flattened_key = key + else: + flattened_key += f"[{key}]" + return _extract_items( + item, + path, + index=index, + flattened_key=flattened_key, + array_format=array_format, + ) + elif is_list(obj): + if key != "": + return [] + + return flatten( + [ + _extract_items( + item, + path, + index=index, + flattened_key=( + (flattened_key if flattened_key is not None else "") + _array_suffix(array_format, array_index) + ), + array_format=array_format, + ) + for array_index, item in enumerate(obj) + ] + ) + + # Something unexpected was passed, just ignore it. + return [] + + +def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]: + return not isinstance(obj, NotGiven) and not isinstance(obj, Omit) + + +# Type safe methods for narrowing types with TypeVars. +# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown], +# however this cause Pyright to rightfully report errors. As we know we don't +# care about the contained types we can safely use `object` in its place. +# +# There are two separate functions defined, `is_*` and `is_*_t` for different use cases. +# `is_*` is for when you're dealing with an unknown input +# `is_*_t` is for when you're narrowing a known union type to a specific subset + + +def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]: + return isinstance(obj, tuple) + + +def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]: + return isinstance(obj, tuple) + + +def is_sequence(obj: object) -> TypeGuard[Sequence[object]]: + return isinstance(obj, Sequence) + + +def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]: + return isinstance(obj, Sequence) + + +def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]: + return isinstance(obj, Mapping) + + +def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]: + return isinstance(obj, Mapping) + + +def is_dict(obj: object) -> TypeGuard[dict[object, object]]: + return isinstance(obj, dict) + + +def is_list(obj: object) -> TypeGuard[list[object]]: + return isinstance(obj, list) + + +def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: + return isinstance(obj, Iterable) + + +# copied from https://github.com/Rapptz/RoboDanny +def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: + size = len(seq) + if size == 0: + return "" + + if size == 1: + return seq[0] + + if size == 2: + return f"{seq[0]} {final} {seq[1]}" + + return delim.join(seq[:-1]) + f" {final} {seq[-1]}" + + +def quote(string: str) -> str: + """Add single quotation marks around the given string. Does *not* do any escaping.""" + return f"'{string}'" + + +def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]: + """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function. + + Useful for enforcing runtime validation of overloaded functions. + + Example usage: + ```py + @overload + def foo(*, a: str) -> str: ... + + + @overload + def foo(*, b: bool) -> str: ... + + + # This enforces the same constraints that a static type checker would + # i.e. that either a or b must be passed to the function + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: bool | None = None) -> str: ... + ``` + """ + + def inner(func: CallableT) -> CallableT: + params = inspect.signature(func).parameters + positional = [ + name + for name, param in params.items() + if param.kind + in { + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + } + ] + + @functools.wraps(func) + def wrapper(*args: object, **kwargs: object) -> object: + given_params: set[str] = set() + for i, _ in enumerate(args): + try: + given_params.add(positional[i]) + except IndexError: + raise TypeError( + f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given" + ) from None + + for key in kwargs.keys(): + given_params.add(key) + + for variant in variants: + matches = all((param in given_params for param in variant)) + if matches: + break + else: # no break + if len(variants) > 1: + variations = human_join( + ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants] + ) + msg = f"Missing required arguments; Expected either {variations} arguments to be given" + else: + assert len(variants) > 0 + + # TODO: this error message is not deterministic + missing = list(set(variants[0]) - given_params) + if len(missing) > 1: + msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}" + else: + msg = f"Missing required argument: {quote(missing[0])}" + raise TypeError(msg) + return func(*args, **kwargs) + + return wrapper # type: ignore + + return inner + + +_K = TypeVar("_K") +_V = TypeVar("_V") + + +@overload +def strip_not_given(obj: None) -> None: ... + + +@overload +def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ... + + +@overload +def strip_not_given(obj: object) -> object: ... + + +def strip_not_given(obj: object | None) -> object: + """Remove all top-level keys where their values are instances of `NotGiven`""" + if obj is None: + return None + + if not is_mapping(obj): + return obj + + return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} + + +def coerce_integer(val: str) -> int: + return int(val, base=10) + + +def coerce_float(val: str) -> float: + return float(val) + + +def coerce_boolean(val: str) -> bool: + return val == "true" or val == "1" or val == "on" + + +def maybe_coerce_integer(val: str | None) -> int | None: + if val is None: + return None + return coerce_integer(val) + + +def maybe_coerce_float(val: str | None) -> float | None: + if val is None: + return None + return coerce_float(val) + + +def maybe_coerce_boolean(val: str | None) -> bool | None: + if val is None: + return None + return coerce_boolean(val) + + +def removeprefix(string: str, prefix: str) -> str: + """Remove a prefix from a string. + + Backport of `str.removeprefix` for Python < 3.9 + """ + if string.startswith(prefix): + return string[len(prefix) :] + return string + + +def removesuffix(string: str, suffix: str) -> str: + """Remove a suffix from a string. + + Backport of `str.removesuffix` for Python < 3.9 + """ + if string.endswith(suffix): + return string[: -len(suffix)] + return string + + +def file_from_path(path: str) -> FileTypes: + contents = Path(path).read_bytes() + file_name = os.path.basename(path) + return (file_name, contents) + + +def get_required_header(headers: HeadersLike, header: str) -> str: + lower_header = header.lower() + if is_mapping_t(headers): + # mypy doesn't understand the type narrowing here + for k, v in headers.items(): # type: ignore + if k.lower() == lower_header and isinstance(v, str): + return v + + # to deal with the case where the header looks like Stainless-Event-Id + intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize()) + + for normalized_header in [header, lower_header, header.upper(), intercaps_header]: + value = headers.get(normalized_header) + if value: + return value + + raise ValueError(f"Could not find {header} header") + + +def get_async_library() -> str: + try: + return sniffio.current_async_library() + except Exception: + return "false" + + +def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]: + """A version of functools.lru_cache that retains the type signature + for the wrapped function arguments. + """ + wrapper = functools.lru_cache( # noqa: TID251 + maxsize=maxsize, + ) + return cast(Any, wrapper) # type: ignore[no-any-return] + + +def json_safe(data: object) -> object: + """Translates a mapping / sequence recursively in the same fashion + as `pydantic` v2's `model_dump(mode="json")`. + """ + if is_mapping(data): + return {json_safe(key): json_safe(value) for key, value in data.items()} + + if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)): + return [json_safe(item) for item in data] + + if isinstance(data, (datetime, date)): + return data.isoformat() + + return data diff --git a/src/writerai/_version.py b/src/writerai/_version.py new file mode 100644 index 00000000..450bd395 --- /dev/null +++ b/src/writerai/_version.py @@ -0,0 +1,4 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +__title__ = "writerai" +__version__ = "3.0.0" # x-release-please-version diff --git a/src/writerai/lib/.keep b/src/writerai/lib/.keep new file mode 100644 index 00000000..5e2c99fd --- /dev/null +++ b/src/writerai/lib/.keep @@ -0,0 +1,4 @@ +File generated from our OpenAPI spec by Stainless. + +This directory can be used to store custom files to expand the SDK. +It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/src/writerai/pagination.py b/src/writerai/pagination.py new file mode 100644 index 00000000..d8534f51 --- /dev/null +++ b/src/writerai/pagination.py @@ -0,0 +1,178 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Any, List, Generic, TypeVar, Optional, cast +from typing_extensions import Protocol, override, runtime_checkable + +from pydantic import Field as FieldInfo + +from ._models import BaseModel +from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage + +__all__ = [ + "SyncCursorPage", + "AsyncCursorPage", + "ApplicationJobsOffsetPagination", + "SyncApplicationJobsOffset", + "AsyncApplicationJobsOffset", +] + +_T = TypeVar("_T") + + +@runtime_checkable +class CursorPageItem(Protocol): + id: Optional[str] + + +class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]): + data: List[_T] + has_more: bool + + @override + def _get_page_items(self) -> List[_T]: + data = self.data + if not data: + return [] + return data + + @override + def has_next_page(self) -> bool: + has_more = self.has_more + return has_more and super().has_next_page() + + @override + def next_page_info(self) -> Optional[PageInfo]: + is_forwards = not self._options.params.get("before", False) + + data = self.data + if not data: + return None + + if is_forwards: + item = cast(Any, data[-1]) + if not isinstance(item, CursorPageItem) or item.id is None: + # TODO emit warning log + return None + + return PageInfo(params={"after": item.id}) + else: + item = cast(Any, self.data[0]) + if not isinstance(item, CursorPageItem) or item.id is None: + # TODO emit warning log + return None + + return PageInfo(params={"before": item.id}) + + +class AsyncCursorPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]): + data: List[_T] + has_more: bool + + @override + def _get_page_items(self) -> List[_T]: + data = self.data + if not data: + return [] + return data + + @override + def has_next_page(self) -> bool: + has_more = self.has_more + return has_more and super().has_next_page() + + @override + def next_page_info(self) -> Optional[PageInfo]: + is_forwards = not self._options.params.get("before", False) + + data = self.data + if not data: + return None + + if is_forwards: + item = cast(Any, data[-1]) + if not isinstance(item, CursorPageItem) or item.id is None: + # TODO emit warning log + return None + + return PageInfo(params={"after": item.id}) + else: + item = cast(Any, self.data[0]) + if not isinstance(item, CursorPageItem) or item.id is None: + # TODO emit warning log + return None + + return PageInfo(params={"before": item.id}) + + +class ApplicationJobsOffsetPagination(BaseModel): + limit: Optional[int] = None + + offset: Optional[int] = None + + +class SyncApplicationJobsOffset(BaseSyncPage[_T], BasePage[_T], Generic[_T]): + result: List[_T] + total_count: Optional[int] = FieldInfo(alias="totalCount", default=None) + pagination: Optional[ApplicationJobsOffsetPagination] = None + + @override + def _get_page_items(self) -> List[_T]: + result = self.result + if not result: + return [] + return result + + @override + def next_page_info(self) -> Optional[PageInfo]: + offset = None + if self.pagination is not None: + if self.pagination.offset is not None: + offset = self.pagination.offset + if offset is None: + return None # type: ignore[unreachable] + + length = len(self._get_page_items()) + current_count = offset + length + + total_count = self.total_count + if total_count is None: + return None + + if current_count < total_count: + return PageInfo(params={"offset": current_count}) + + return None + + +class AsyncApplicationJobsOffset(BaseAsyncPage[_T], BasePage[_T], Generic[_T]): + result: List[_T] + total_count: Optional[int] = FieldInfo(alias="totalCount", default=None) + pagination: Optional[ApplicationJobsOffsetPagination] = None + + @override + def _get_page_items(self) -> List[_T]: + result = self.result + if not result: + return [] + return result + + @override + def next_page_info(self) -> Optional[PageInfo]: + offset = None + if self.pagination is not None: + if self.pagination.offset is not None: + offset = self.pagination.offset + if offset is None: + return None # type: ignore[unreachable] + + length = len(self._get_page_items()) + current_count = offset + length + + total_count = self.total_count + if total_count is None: + return None + + if current_count < total_count: + return PageInfo(params={"offset": current_count}) + + return None diff --git a/src/writerai/py.typed b/src/writerai/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/writerai/resources/__init__.py b/src/writerai/resources/__init__.py new file mode 100644 index 00000000..767c0120 --- /dev/null +++ b/src/writerai/resources/__init__.py @@ -0,0 +1,131 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .chat import ( + ChatResource, + AsyncChatResource, + ChatResourceWithRawResponse, + AsyncChatResourceWithRawResponse, + ChatResourceWithStreamingResponse, + AsyncChatResourceWithStreamingResponse, +) +from .files import ( + FilesResource, + AsyncFilesResource, + FilesResourceWithRawResponse, + AsyncFilesResourceWithRawResponse, + FilesResourceWithStreamingResponse, + AsyncFilesResourceWithStreamingResponse, +) +from .tools import ( + ToolsResource, + AsyncToolsResource, + ToolsResourceWithRawResponse, + AsyncToolsResourceWithRawResponse, + ToolsResourceWithStreamingResponse, + AsyncToolsResourceWithStreamingResponse, +) +from .graphs import ( + GraphsResource, + AsyncGraphsResource, + GraphsResourceWithRawResponse, + AsyncGraphsResourceWithRawResponse, + GraphsResourceWithStreamingResponse, + AsyncGraphsResourceWithStreamingResponse, +) +from .models import ( + ModelsResource, + AsyncModelsResource, + ModelsResourceWithRawResponse, + AsyncModelsResourceWithRawResponse, + ModelsResourceWithStreamingResponse, + AsyncModelsResourceWithStreamingResponse, +) +from .vision import ( + VisionResource, + AsyncVisionResource, + VisionResourceWithRawResponse, + AsyncVisionResourceWithRawResponse, + VisionResourceWithStreamingResponse, + AsyncVisionResourceWithStreamingResponse, +) +from .completions import ( + CompletionsResource, + AsyncCompletionsResource, + CompletionsResourceWithRawResponse, + AsyncCompletionsResourceWithRawResponse, + CompletionsResourceWithStreamingResponse, + AsyncCompletionsResourceWithStreamingResponse, +) +from .translation import ( + TranslationResource, + AsyncTranslationResource, + TranslationResourceWithRawResponse, + AsyncTranslationResourceWithRawResponse, + TranslationResourceWithStreamingResponse, + AsyncTranslationResourceWithStreamingResponse, +) +from .applications import ( + ApplicationsResource, + AsyncApplicationsResource, + ApplicationsResourceWithRawResponse, + AsyncApplicationsResourceWithRawResponse, + ApplicationsResourceWithStreamingResponse, + AsyncApplicationsResourceWithStreamingResponse, +) + +__all__ = [ + "ApplicationsResource", + "AsyncApplicationsResource", + "ApplicationsResourceWithRawResponse", + "AsyncApplicationsResourceWithRawResponse", + "ApplicationsResourceWithStreamingResponse", + "AsyncApplicationsResourceWithStreamingResponse", + "ChatResource", + "AsyncChatResource", + "ChatResourceWithRawResponse", + "AsyncChatResourceWithRawResponse", + "ChatResourceWithStreamingResponse", + "AsyncChatResourceWithStreamingResponse", + "CompletionsResource", + "AsyncCompletionsResource", + "CompletionsResourceWithRawResponse", + "AsyncCompletionsResourceWithRawResponse", + "CompletionsResourceWithStreamingResponse", + "AsyncCompletionsResourceWithStreamingResponse", + "ModelsResource", + "AsyncModelsResource", + "ModelsResourceWithRawResponse", + "AsyncModelsResourceWithRawResponse", + "ModelsResourceWithStreamingResponse", + "AsyncModelsResourceWithStreamingResponse", + "GraphsResource", + "AsyncGraphsResource", + "GraphsResourceWithRawResponse", + "AsyncGraphsResourceWithRawResponse", + "GraphsResourceWithStreamingResponse", + "AsyncGraphsResourceWithStreamingResponse", + "FilesResource", + "AsyncFilesResource", + "FilesResourceWithRawResponse", + "AsyncFilesResourceWithRawResponse", + "FilesResourceWithStreamingResponse", + "AsyncFilesResourceWithStreamingResponse", + "ToolsResource", + "AsyncToolsResource", + "ToolsResourceWithRawResponse", + "AsyncToolsResourceWithRawResponse", + "ToolsResourceWithStreamingResponse", + "AsyncToolsResourceWithStreamingResponse", + "TranslationResource", + "AsyncTranslationResource", + "TranslationResourceWithRawResponse", + "AsyncTranslationResourceWithRawResponse", + "TranslationResourceWithStreamingResponse", + "AsyncTranslationResourceWithStreamingResponse", + "VisionResource", + "AsyncVisionResource", + "VisionResourceWithRawResponse", + "AsyncVisionResourceWithRawResponse", + "VisionResourceWithStreamingResponse", + "AsyncVisionResourceWithStreamingResponse", +] diff --git a/src/writerai/resources/applications/__init__.py b/src/writerai/resources/applications/__init__.py new file mode 100644 index 00000000..ab99e9c1 --- /dev/null +++ b/src/writerai/resources/applications/__init__.py @@ -0,0 +1,47 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .jobs import ( + JobsResource, + AsyncJobsResource, + JobsResourceWithRawResponse, + AsyncJobsResourceWithRawResponse, + JobsResourceWithStreamingResponse, + AsyncJobsResourceWithStreamingResponse, +) +from .graphs import ( + GraphsResource, + AsyncGraphsResource, + GraphsResourceWithRawResponse, + AsyncGraphsResourceWithRawResponse, + GraphsResourceWithStreamingResponse, + AsyncGraphsResourceWithStreamingResponse, +) +from .applications import ( + ApplicationsResource, + AsyncApplicationsResource, + ApplicationsResourceWithRawResponse, + AsyncApplicationsResourceWithRawResponse, + ApplicationsResourceWithStreamingResponse, + AsyncApplicationsResourceWithStreamingResponse, +) + +__all__ = [ + "JobsResource", + "AsyncJobsResource", + "JobsResourceWithRawResponse", + "AsyncJobsResourceWithRawResponse", + "JobsResourceWithStreamingResponse", + "AsyncJobsResourceWithStreamingResponse", + "GraphsResource", + "AsyncGraphsResource", + "GraphsResourceWithRawResponse", + "AsyncGraphsResourceWithRawResponse", + "GraphsResourceWithStreamingResponse", + "AsyncGraphsResourceWithStreamingResponse", + "ApplicationsResource", + "AsyncApplicationsResource", + "ApplicationsResourceWithRawResponse", + "AsyncApplicationsResourceWithRawResponse", + "ApplicationsResourceWithStreamingResponse", + "AsyncApplicationsResourceWithStreamingResponse", +] diff --git a/src/writerai/resources/applications/applications.py b/src/writerai/resources/applications/applications.py new file mode 100644 index 00000000..8c13bb40 --- /dev/null +++ b/src/writerai/resources/applications/applications.py @@ -0,0 +1,645 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Literal, overload + +import httpx + +from .jobs import ( + JobsResource, + AsyncJobsResource, + JobsResourceWithRawResponse, + AsyncJobsResourceWithRawResponse, + JobsResourceWithStreamingResponse, + AsyncJobsResourceWithStreamingResponse, +) +from .graphs import ( + GraphsResource, + AsyncGraphsResource, + GraphsResourceWithRawResponse, + AsyncGraphsResourceWithRawResponse, + GraphsResourceWithStreamingResponse, + AsyncGraphsResourceWithStreamingResponse, +) +from ...types import application_list_params, application_generate_content_params +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given +from ..._utils import path_template, required_args, maybe_transform, async_maybe_transform +from ..._compat import cached_property +from ..._resource import SyncAPIResource, AsyncAPIResource +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from ..._streaming import Stream, AsyncStream +from ...pagination import SyncCursorPage, AsyncCursorPage +from ..._base_client import AsyncPaginator, make_request_options +from ...types.application_list_response import ApplicationListResponse +from ...types.application_retrieve_response import ApplicationRetrieveResponse +from ...types.application_generate_content_chunk import ApplicationGenerateContentChunk +from ...types.application_generate_content_response import ApplicationGenerateContentResponse + +__all__ = ["ApplicationsResource", "AsyncApplicationsResource"] + + +class ApplicationsResource(SyncAPIResource): + @cached_property + def jobs(self) -> JobsResource: + return JobsResource(self._client) + + @cached_property + def graphs(self) -> GraphsResource: + return GraphsResource(self._client) + + @cached_property + def with_raw_response(self) -> ApplicationsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return ApplicationsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ApplicationsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return ApplicationsResourceWithStreamingResponse(self) + + def retrieve( + self, + application_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationRetrieveResponse: + """ + Retrieves detailed information for a specific no-code agent (formerly called + no-code applications), including its configuration and current status. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._get( + path_template("/v1/applications/{application_id}", application_id=application_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationRetrieveResponse, + ) + + def list( + self, + *, + after: str | Omit = omit, + before: str | Omit = omit, + limit: int | Omit = omit, + order: Literal["asc", "desc"] | Omit = omit, + type: Literal["generation"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> SyncCursorPage[ApplicationListResponse]: + """ + Retrieves a paginated list of no-code agents (formerly called no-code + applications) with optional filtering and sorting capabilities. + + Args: + after: Return results after this application ID for pagination. + + before: Return results before this application ID for pagination. + + limit: Maximum number of applications to return in the response. + + order: Sort order for the results based on creation time. + + type: Filter applications by their type. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/v1/applications", + page=SyncCursorPage[ApplicationListResponse], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "before": before, + "limit": limit, + "order": order, + "type": type, + }, + application_list_params.ApplicationListParams, + ), + ), + model=ApplicationListResponse, + ) + + @overload + def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: Literal[False] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateContentResponse: + """ + Generate content from an existing no-code agent (formerly called no-code + applications) with inputs. + + Args: + stream: Indicates whether the response should be streamed. Currently only supported for + research assistant applications. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: Literal[True], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Stream[ApplicationGenerateContentChunk]: + """ + Generate content from an existing no-code agent (formerly called no-code + applications) with inputs. + + Args: + stream: Indicates whether the response should be streamed. Currently only supported for + research assistant applications. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: bool, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateContentResponse | Stream[ApplicationGenerateContentChunk]: + """ + Generate content from an existing no-code agent (formerly called no-code + applications) with inputs. + + Args: + stream: Indicates whether the response should be streamed. Currently only supported for + research assistant applications. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["inputs"], ["inputs", "stream"]) + def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: Literal[False] | Literal[True] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateContentResponse | Stream[ApplicationGenerateContentChunk]: + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._post( + path_template("/v1/applications/{application_id}", application_id=application_id), + body=maybe_transform( + { + "inputs": inputs, + "stream": stream, + }, + application_generate_content_params.ApplicationGenerateContentParamsStreaming + if stream + else application_generate_content_params.ApplicationGenerateContentParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGenerateContentResponse, + stream=stream or False, + stream_cls=Stream[ApplicationGenerateContentChunk], + ) + + +class AsyncApplicationsResource(AsyncAPIResource): + @cached_property + def jobs(self) -> AsyncJobsResource: + return AsyncJobsResource(self._client) + + @cached_property + def graphs(self) -> AsyncGraphsResource: + return AsyncGraphsResource(self._client) + + @cached_property + def with_raw_response(self) -> AsyncApplicationsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncApplicationsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncApplicationsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncApplicationsResourceWithStreamingResponse(self) + + async def retrieve( + self, + application_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationRetrieveResponse: + """ + Retrieves detailed information for a specific no-code agent (formerly called + no-code applications), including its configuration and current status. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return await self._get( + path_template("/v1/applications/{application_id}", application_id=application_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationRetrieveResponse, + ) + + def list( + self, + *, + after: str | Omit = omit, + before: str | Omit = omit, + limit: int | Omit = omit, + order: Literal["asc", "desc"] | Omit = omit, + type: Literal["generation"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncPaginator[ApplicationListResponse, AsyncCursorPage[ApplicationListResponse]]: + """ + Retrieves a paginated list of no-code agents (formerly called no-code + applications) with optional filtering and sorting capabilities. + + Args: + after: Return results after this application ID for pagination. + + before: Return results before this application ID for pagination. + + limit: Maximum number of applications to return in the response. + + order: Sort order for the results based on creation time. + + type: Filter applications by their type. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/v1/applications", + page=AsyncCursorPage[ApplicationListResponse], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "before": before, + "limit": limit, + "order": order, + "type": type, + }, + application_list_params.ApplicationListParams, + ), + ), + model=ApplicationListResponse, + ) + + @overload + async def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: Literal[False] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateContentResponse: + """ + Generate content from an existing no-code agent (formerly called no-code + applications) with inputs. + + Args: + stream: Indicates whether the response should be streamed. Currently only supported for + research assistant applications. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: Literal[True], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncStream[ApplicationGenerateContentChunk]: + """ + Generate content from an existing no-code agent (formerly called no-code + applications) with inputs. + + Args: + stream: Indicates whether the response should be streamed. Currently only supported for + research assistant applications. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: bool, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateContentResponse | AsyncStream[ApplicationGenerateContentChunk]: + """ + Generate content from an existing no-code agent (formerly called no-code + applications) with inputs. + + Args: + stream: Indicates whether the response should be streamed. Currently only supported for + research assistant applications. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["inputs"], ["inputs", "stream"]) + async def generate_content( + self, + application_id: str, + *, + inputs: Iterable[application_generate_content_params.Input], + stream: Literal[False] | Literal[True] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateContentResponse | AsyncStream[ApplicationGenerateContentChunk]: + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return await self._post( + path_template("/v1/applications/{application_id}", application_id=application_id), + body=await async_maybe_transform( + { + "inputs": inputs, + "stream": stream, + }, + application_generate_content_params.ApplicationGenerateContentParamsStreaming + if stream + else application_generate_content_params.ApplicationGenerateContentParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGenerateContentResponse, + stream=stream or False, + stream_cls=AsyncStream[ApplicationGenerateContentChunk], + ) + + +class ApplicationsResourceWithRawResponse: + def __init__(self, applications: ApplicationsResource) -> None: + self._applications = applications + + self.retrieve = to_raw_response_wrapper( + applications.retrieve, + ) + self.list = to_raw_response_wrapper( + applications.list, + ) + self.generate_content = to_raw_response_wrapper( + applications.generate_content, + ) + + @cached_property + def jobs(self) -> JobsResourceWithRawResponse: + return JobsResourceWithRawResponse(self._applications.jobs) + + @cached_property + def graphs(self) -> GraphsResourceWithRawResponse: + return GraphsResourceWithRawResponse(self._applications.graphs) + + +class AsyncApplicationsResourceWithRawResponse: + def __init__(self, applications: AsyncApplicationsResource) -> None: + self._applications = applications + + self.retrieve = async_to_raw_response_wrapper( + applications.retrieve, + ) + self.list = async_to_raw_response_wrapper( + applications.list, + ) + self.generate_content = async_to_raw_response_wrapper( + applications.generate_content, + ) + + @cached_property + def jobs(self) -> AsyncJobsResourceWithRawResponse: + return AsyncJobsResourceWithRawResponse(self._applications.jobs) + + @cached_property + def graphs(self) -> AsyncGraphsResourceWithRawResponse: + return AsyncGraphsResourceWithRawResponse(self._applications.graphs) + + +class ApplicationsResourceWithStreamingResponse: + def __init__(self, applications: ApplicationsResource) -> None: + self._applications = applications + + self.retrieve = to_streamed_response_wrapper( + applications.retrieve, + ) + self.list = to_streamed_response_wrapper( + applications.list, + ) + self.generate_content = to_streamed_response_wrapper( + applications.generate_content, + ) + + @cached_property + def jobs(self) -> JobsResourceWithStreamingResponse: + return JobsResourceWithStreamingResponse(self._applications.jobs) + + @cached_property + def graphs(self) -> GraphsResourceWithStreamingResponse: + return GraphsResourceWithStreamingResponse(self._applications.graphs) + + +class AsyncApplicationsResourceWithStreamingResponse: + def __init__(self, applications: AsyncApplicationsResource) -> None: + self._applications = applications + + self.retrieve = async_to_streamed_response_wrapper( + applications.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + applications.list, + ) + self.generate_content = async_to_streamed_response_wrapper( + applications.generate_content, + ) + + @cached_property + def jobs(self) -> AsyncJobsResourceWithStreamingResponse: + return AsyncJobsResourceWithStreamingResponse(self._applications.jobs) + + @cached_property + def graphs(self) -> AsyncGraphsResourceWithStreamingResponse: + return AsyncGraphsResourceWithStreamingResponse(self._applications.graphs) diff --git a/src/writerai/resources/applications/graphs.py b/src/writerai/resources/applications/graphs.py new file mode 100644 index 00000000..212844a9 --- /dev/null +++ b/src/writerai/resources/applications/graphs.py @@ -0,0 +1,257 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import httpx + +from ..._types import Body, Query, Headers, NotGiven, SequenceNotStr, not_given +from ..._utils import path_template, maybe_transform, async_maybe_transform +from ..._compat import cached_property +from ..._resource import SyncAPIResource, AsyncAPIResource +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from ..._base_client import make_request_options +from ...types.applications import graph_update_params +from ...types.applications.application_graphs_response import ApplicationGraphsResponse + +__all__ = ["GraphsResource", "AsyncGraphsResource"] + + +class GraphsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> GraphsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return GraphsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> GraphsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return GraphsResourceWithStreamingResponse(self) + + def update( + self, + application_id: str, + *, + graph_ids: SequenceNotStr[str], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGraphsResponse: + """ + Updates the list of Knowledge Graphs associated with a no-code chat agent. + + Args: + graph_ids: A list of Knowledge Graph IDs to associate with the application. Note that this + will replace the existing list of Knowledge Graphs associated with the + application, not add to it. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._put( + path_template("/v1/applications/{application_id}/graphs", application_id=application_id), + body=maybe_transform({"graph_ids": graph_ids}, graph_update_params.GraphUpdateParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGraphsResponse, + ) + + def list( + self, + application_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGraphsResponse: + """ + Retrieve Knowledge Graphs associated with a no-code agent that has chat + capabilities. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._get( + path_template("/v1/applications/{application_id}/graphs", application_id=application_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGraphsResponse, + ) + + +class AsyncGraphsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncGraphsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncGraphsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncGraphsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncGraphsResourceWithStreamingResponse(self) + + async def update( + self, + application_id: str, + *, + graph_ids: SequenceNotStr[str], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGraphsResponse: + """ + Updates the list of Knowledge Graphs associated with a no-code chat agent. + + Args: + graph_ids: A list of Knowledge Graph IDs to associate with the application. Note that this + will replace the existing list of Knowledge Graphs associated with the + application, not add to it. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return await self._put( + path_template("/v1/applications/{application_id}/graphs", application_id=application_id), + body=await async_maybe_transform({"graph_ids": graph_ids}, graph_update_params.GraphUpdateParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGraphsResponse, + ) + + async def list( + self, + application_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGraphsResponse: + """ + Retrieve Knowledge Graphs associated with a no-code agent that has chat + capabilities. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return await self._get( + path_template("/v1/applications/{application_id}/graphs", application_id=application_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGraphsResponse, + ) + + +class GraphsResourceWithRawResponse: + def __init__(self, graphs: GraphsResource) -> None: + self._graphs = graphs + + self.update = to_raw_response_wrapper( + graphs.update, + ) + self.list = to_raw_response_wrapper( + graphs.list, + ) + + +class AsyncGraphsResourceWithRawResponse: + def __init__(self, graphs: AsyncGraphsResource) -> None: + self._graphs = graphs + + self.update = async_to_raw_response_wrapper( + graphs.update, + ) + self.list = async_to_raw_response_wrapper( + graphs.list, + ) + + +class GraphsResourceWithStreamingResponse: + def __init__(self, graphs: GraphsResource) -> None: + self._graphs = graphs + + self.update = to_streamed_response_wrapper( + graphs.update, + ) + self.list = to_streamed_response_wrapper( + graphs.list, + ) + + +class AsyncGraphsResourceWithStreamingResponse: + def __init__(self, graphs: AsyncGraphsResource) -> None: + self._graphs = graphs + + self.update = async_to_streamed_response_wrapper( + graphs.update, + ) + self.list = async_to_streamed_response_wrapper( + graphs.list, + ) diff --git a/src/writerai/resources/applications/jobs.py b/src/writerai/resources/applications/jobs.py new file mode 100644 index 00000000..69891e71 --- /dev/null +++ b/src/writerai/resources/applications/jobs.py @@ -0,0 +1,461 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Literal + +import httpx + +from ..._types import Body, Omit, Query, Headers, NotGiven, omit, not_given +from ..._utils import path_template, maybe_transform, async_maybe_transform +from ..._compat import cached_property +from ..._resource import SyncAPIResource, AsyncAPIResource +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from ...pagination import SyncApplicationJobsOffset, AsyncApplicationJobsOffset +from ..._base_client import AsyncPaginator, make_request_options +from ...types.applications import job_list_params, job_create_params +from ...types.applications.job_retry_response import JobRetryResponse +from ...types.applications.job_create_response import JobCreateResponse +from ...types.applications.application_generate_async_response import ApplicationGenerateAsyncResponse + +__all__ = ["JobsResource", "AsyncJobsResource"] + + +class JobsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> JobsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return JobsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> JobsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return JobsResourceWithStreamingResponse(self) + + def create( + self, + application_id: str, + *, + inputs: Iterable[job_create_params.Input], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> JobCreateResponse: + """ + Generate content asynchronously from an existing no-code agent (formerly called + no-code applications) with inputs. + + Args: + inputs: A list of input objects to generate content for. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._post( + path_template("/v1/applications/{application_id}/jobs", application_id=application_id), + body=maybe_transform({"inputs": inputs}, job_create_params.JobCreateParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=JobCreateResponse, + ) + + def retrieve( + self, + job_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateAsyncResponse: + """ + Retrieves a single job created via the Async API. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not job_id: + raise ValueError(f"Expected a non-empty value for `job_id` but received {job_id!r}") + return self._get( + path_template("/v1/applications/jobs/{job_id}", job_id=job_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGenerateAsyncResponse, + ) + + def list( + self, + application_id: str, + *, + limit: int | Omit = omit, + offset: int | Omit = omit, + status: Literal["in_progress", "failed", "completed"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse]: + """ + Retrieve all jobs created via the async API, linked to the provided application + ID (or alias). + + Args: + limit: The pagination limit for retrieving the jobs. + + offset: The pagination offset for retrieving the jobs. + + status: The status of the job. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._get_api_list( + path_template("/v1/applications/{application_id}/jobs", application_id=application_id), + page=SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "limit": limit, + "offset": offset, + "status": status, + }, + job_list_params.JobListParams, + ), + ), + model=ApplicationGenerateAsyncResponse, + ) + + def retry( + self, + job_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> JobRetryResponse: + """ + Re-triggers the async execution of a single job previously created via the Async + api and terminated in error. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not job_id: + raise ValueError(f"Expected a non-empty value for `job_id` but received {job_id!r}") + return self._post( + path_template("/v1/applications/jobs/{job_id}/retry", job_id=job_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=JobRetryResponse, + ) + + +class AsyncJobsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncJobsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncJobsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncJobsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncJobsResourceWithStreamingResponse(self) + + async def create( + self, + application_id: str, + *, + inputs: Iterable[job_create_params.Input], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> JobCreateResponse: + """ + Generate content asynchronously from an existing no-code agent (formerly called + no-code applications) with inputs. + + Args: + inputs: A list of input objects to generate content for. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return await self._post( + path_template("/v1/applications/{application_id}/jobs", application_id=application_id), + body=await async_maybe_transform({"inputs": inputs}, job_create_params.JobCreateParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=JobCreateResponse, + ) + + async def retrieve( + self, + job_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ApplicationGenerateAsyncResponse: + """ + Retrieves a single job created via the Async API. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not job_id: + raise ValueError(f"Expected a non-empty value for `job_id` but received {job_id!r}") + return await self._get( + path_template("/v1/applications/jobs/{job_id}", job_id=job_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ApplicationGenerateAsyncResponse, + ) + + def list( + self, + application_id: str, + *, + limit: int | Omit = omit, + offset: int | Omit = omit, + status: Literal["in_progress", "failed", "completed"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncPaginator[ApplicationGenerateAsyncResponse, AsyncApplicationJobsOffset[ApplicationGenerateAsyncResponse]]: + """ + Retrieve all jobs created via the async API, linked to the provided application + ID (or alias). + + Args: + limit: The pagination limit for retrieving the jobs. + + offset: The pagination offset for retrieving the jobs. + + status: The status of the job. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not application_id: + raise ValueError(f"Expected a non-empty value for `application_id` but received {application_id!r}") + return self._get_api_list( + path_template("/v1/applications/{application_id}/jobs", application_id=application_id), + page=AsyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "limit": limit, + "offset": offset, + "status": status, + }, + job_list_params.JobListParams, + ), + ), + model=ApplicationGenerateAsyncResponse, + ) + + async def retry( + self, + job_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> JobRetryResponse: + """ + Re-triggers the async execution of a single job previously created via the Async + api and terminated in error. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not job_id: + raise ValueError(f"Expected a non-empty value for `job_id` but received {job_id!r}") + return await self._post( + path_template("/v1/applications/jobs/{job_id}/retry", job_id=job_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=JobRetryResponse, + ) + + +class JobsResourceWithRawResponse: + def __init__(self, jobs: JobsResource) -> None: + self._jobs = jobs + + self.create = to_raw_response_wrapper( + jobs.create, + ) + self.retrieve = to_raw_response_wrapper( + jobs.retrieve, + ) + self.list = to_raw_response_wrapper( + jobs.list, + ) + self.retry = to_raw_response_wrapper( + jobs.retry, + ) + + +class AsyncJobsResourceWithRawResponse: + def __init__(self, jobs: AsyncJobsResource) -> None: + self._jobs = jobs + + self.create = async_to_raw_response_wrapper( + jobs.create, + ) + self.retrieve = async_to_raw_response_wrapper( + jobs.retrieve, + ) + self.list = async_to_raw_response_wrapper( + jobs.list, + ) + self.retry = async_to_raw_response_wrapper( + jobs.retry, + ) + + +class JobsResourceWithStreamingResponse: + def __init__(self, jobs: JobsResource) -> None: + self._jobs = jobs + + self.create = to_streamed_response_wrapper( + jobs.create, + ) + self.retrieve = to_streamed_response_wrapper( + jobs.retrieve, + ) + self.list = to_streamed_response_wrapper( + jobs.list, + ) + self.retry = to_streamed_response_wrapper( + jobs.retry, + ) + + +class AsyncJobsResourceWithStreamingResponse: + def __init__(self, jobs: AsyncJobsResource) -> None: + self._jobs = jobs + + self.create = async_to_streamed_response_wrapper( + jobs.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + jobs.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + jobs.list, + ) + self.retry = async_to_streamed_response_wrapper( + jobs.retry, + ) diff --git a/src/writerai/resources/chat.py b/src/writerai/resources/chat.py new file mode 100644 index 00000000..52835ec2 --- /dev/null +++ b/src/writerai/resources/chat.py @@ -0,0 +1,851 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union, Iterable +from typing_extensions import Literal, overload + +import httpx + +from ..types import chat_chat_params +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given +from .._utils import required_args, maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._streaming import Stream, AsyncStream +from .._base_client import make_request_options +from ..types.chat_completion import ChatCompletion +from ..types.chat_completion_chunk import ChatCompletionChunk +from ..types.shared_params.tool_param import ToolParam + +__all__ = ["ChatResource", "AsyncChatResource"] + + +class ChatResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ChatResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return ChatResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ChatResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return ChatResourceWithStreamingResponse(self) + + @overload + def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ChatCompletion: + """Generate a chat completion based on the provided messages. + + The response shown + below is for non-streaming. To learn about streaming responses, see the + [chat completion guide](https://dev.writer.com/home/chat-completion). + + Args: + messages: An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + + model: The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + + logprobs: Specifies whether to return log probabilities of the output tokens. + + max_tokens: Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + + n: Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + + response_format: The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + + stop: A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + + stream: Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + + stream_options: Additional options for streaming. + + temperature: Controls the randomness or creativity of the model's responses. A higher + temperature results in more varied and less predictable text, while a lower + temperature produces more deterministic and conservative outputs. + + tool_choice: + Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + + tools: An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + + top_p: Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + stream: Literal[True], + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Stream[ChatCompletionChunk]: + """Generate a chat completion based on the provided messages. + + The response shown + below is for non-streaming. To learn about streaming responses, see the + [chat completion guide](https://dev.writer.com/home/chat-completion). + + Args: + messages: An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + + model: The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + + stream: Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + + logprobs: Specifies whether to return log probabilities of the output tokens. + + max_tokens: Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + + n: Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + + response_format: The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + + stop: A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + + stream_options: Additional options for streaming. + + temperature: Controls the randomness or creativity of the model's responses. A higher + temperature results in more varied and less predictable text, while a lower + temperature produces more deterministic and conservative outputs. + + tool_choice: + Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + + tools: An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + + top_p: Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + stream: bool, + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: + """Generate a chat completion based on the provided messages. + + The response shown + below is for non-streaming. To learn about streaming responses, see the + [chat completion guide](https://dev.writer.com/home/chat-completion). + + Args: + messages: An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + + model: The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + + stream: Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + + logprobs: Specifies whether to return log probabilities of the output tokens. + + max_tokens: Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + + n: Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + + response_format: The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + + stop: A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + + stream_options: Additional options for streaming. + + temperature: Controls the randomness or creativity of the model's responses. A higher + temperature results in more varied and less predictable text, while a lower + temperature produces more deterministic and conservative outputs. + + tool_choice: + Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + + tools: An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + + top_p: Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["messages", "model"], ["messages", "model", "stream"]) + def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Literal[True] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: + return self._post( + "/v1/chat", + body=maybe_transform( + { + "messages": messages, + "model": model, + "logprobs": logprobs, + "max_tokens": max_tokens, + "n": n, + "response_format": response_format, + "stop": stop, + "stream": stream, + "stream_options": stream_options, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_p": top_p, + }, + chat_chat_params.ChatChatParamsStreaming if stream else chat_chat_params.ChatChatParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ChatCompletion, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], + ) + + +class AsyncChatResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncChatResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncChatResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncChatResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncChatResourceWithStreamingResponse(self) + + @overload + async def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ChatCompletion: + """Generate a chat completion based on the provided messages. + + The response shown + below is for non-streaming. To learn about streaming responses, see the + [chat completion guide](https://dev.writer.com/home/chat-completion). + + Args: + messages: An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + + model: The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + + logprobs: Specifies whether to return log probabilities of the output tokens. + + max_tokens: Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + + n: Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + + response_format: The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + + stop: A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + + stream: Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + + stream_options: Additional options for streaming. + + temperature: Controls the randomness or creativity of the model's responses. A higher + temperature results in more varied and less predictable text, while a lower + temperature produces more deterministic and conservative outputs. + + tool_choice: + Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + + tools: An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + + top_p: Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + stream: Literal[True], + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncStream[ChatCompletionChunk]: + """Generate a chat completion based on the provided messages. + + The response shown + below is for non-streaming. To learn about streaming responses, see the + [chat completion guide](https://dev.writer.com/home/chat-completion). + + Args: + messages: An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + + model: The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + + stream: Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + + logprobs: Specifies whether to return log probabilities of the output tokens. + + max_tokens: Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + + n: Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + + response_format: The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + + stop: A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + + stream_options: Additional options for streaming. + + temperature: Controls the randomness or creativity of the model's responses. A higher + temperature results in more varied and less predictable text, while a lower + temperature produces more deterministic and conservative outputs. + + tool_choice: + Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + + tools: An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + + top_p: Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + stream: bool, + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: + """Generate a chat completion based on the provided messages. + + The response shown + below is for non-streaming. To learn about streaming responses, see the + [chat completion guide](https://dev.writer.com/home/chat-completion). + + Args: + messages: An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + + model: The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + + stream: Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + + logprobs: Specifies whether to return log probabilities of the output tokens. + + max_tokens: Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + + n: Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + + response_format: The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + + stop: A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + + stream_options: Additional options for streaming. + + temperature: Controls the randomness or creativity of the model's responses. A higher + temperature results in more varied and less predictable text, while a lower + temperature produces more deterministic and conservative outputs. + + tool_choice: + Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + + tools: An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + + top_p: Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["messages", "model"], ["messages", "model", "stream"]) + async def chat( + self, + *, + messages: Iterable[chat_chat_params.Message], + model: str, + logprobs: bool | Omit = omit, + max_tokens: int | Omit = omit, + n: int | Omit = omit, + response_format: chat_chat_params.ResponseFormat | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Literal[True] | Omit = omit, + stream_options: chat_chat_params.StreamOptions | Omit = omit, + temperature: float | Omit = omit, + tool_choice: chat_chat_params.ToolChoice | Omit = omit, + tools: Iterable[ToolParam] | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: + return await self._post( + "/v1/chat", + body=await async_maybe_transform( + { + "messages": messages, + "model": model, + "logprobs": logprobs, + "max_tokens": max_tokens, + "n": n, + "response_format": response_format, + "stop": stop, + "stream": stream, + "stream_options": stream_options, + "temperature": temperature, + "tool_choice": tool_choice, + "tools": tools, + "top_p": top_p, + }, + chat_chat_params.ChatChatParamsStreaming if stream else chat_chat_params.ChatChatParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ChatCompletion, + stream=stream or False, + stream_cls=AsyncStream[ChatCompletionChunk], + ) + + +class ChatResourceWithRawResponse: + def __init__(self, chat: ChatResource) -> None: + self._chat = chat + + self.chat = to_raw_response_wrapper( + chat.chat, + ) + + +class AsyncChatResourceWithRawResponse: + def __init__(self, chat: AsyncChatResource) -> None: + self._chat = chat + + self.chat = async_to_raw_response_wrapper( + chat.chat, + ) + + +class ChatResourceWithStreamingResponse: + def __init__(self, chat: ChatResource) -> None: + self._chat = chat + + self.chat = to_streamed_response_wrapper( + chat.chat, + ) + + +class AsyncChatResourceWithStreamingResponse: + def __init__(self, chat: AsyncChatResource) -> None: + self._chat = chat + + self.chat = async_to_streamed_response_wrapper( + chat.chat, + ) diff --git a/src/writerai/resources/completions.py b/src/writerai/resources/completions.py new file mode 100644 index 00000000..382cb6a0 --- /dev/null +++ b/src/writerai/resources/completions.py @@ -0,0 +1,586 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union +from typing_extensions import Literal, overload + +import httpx + +from ..types import completion_create_params +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given +from .._utils import required_args, maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._streaming import Stream, AsyncStream +from .._base_client import make_request_options +from ..types.completion import Completion +from ..types.completion_chunk import CompletionChunk + +__all__ = ["CompletionsResource", "AsyncCompletionsResource"] + + +class CompletionsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> CompletionsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return CompletionsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> CompletionsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return CompletionsResourceWithStreamingResponse(self) + + @overload + def create( + self, + *, + model: str, + prompt: str, + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Completion: + """Generate text completions using the specified model and prompt. + + This endpoint is + useful for text generation tasks that don't require conversational context. + + Args: + model: The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + + prompt: The input text that the model will process to generate a response. + + best_of: Specifies the number of completions to generate and return the best one. Useful + for generating multiple outputs and choosing the best based on some criteria. + + max_tokens: The maximum number of tokens that the model can generate in the response. + + random_seed: A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + + stop: Specifies stopping conditions for the model's output generation. This can be an + array of strings or a single string that the model will look for as a signal to + stop generating further tokens. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + temperature: Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic. + + top_p: Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def create( + self, + *, + model: str, + prompt: str, + stream: Literal[True], + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Stream[CompletionChunk]: + """Generate text completions using the specified model and prompt. + + This endpoint is + useful for text generation tasks that don't require conversational context. + + Args: + model: The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + + prompt: The input text that the model will process to generate a response. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + best_of: Specifies the number of completions to generate and return the best one. Useful + for generating multiple outputs and choosing the best based on some criteria. + + max_tokens: The maximum number of tokens that the model can generate in the response. + + random_seed: A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + + stop: Specifies stopping conditions for the model's output generation. This can be an + array of strings or a single string that the model will look for as a signal to + stop generating further tokens. + + temperature: Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic. + + top_p: Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def create( + self, + *, + model: str, + prompt: str, + stream: bool, + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Completion | Stream[CompletionChunk]: + """Generate text completions using the specified model and prompt. + + This endpoint is + useful for text generation tasks that don't require conversational context. + + Args: + model: The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + + prompt: The input text that the model will process to generate a response. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + best_of: Specifies the number of completions to generate and return the best one. Useful + for generating multiple outputs and choosing the best based on some criteria. + + max_tokens: The maximum number of tokens that the model can generate in the response. + + random_seed: A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + + stop: Specifies stopping conditions for the model's output generation. This can be an + array of strings or a single string that the model will look for as a signal to + stop generating further tokens. + + temperature: Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic. + + top_p: Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["model", "prompt"], ["model", "prompt", "stream"]) + def create( + self, + *, + model: str, + prompt: str, + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Literal[True] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Completion | Stream[CompletionChunk]: + return self._post( + "/v1/completions", + body=maybe_transform( + { + "model": model, + "prompt": prompt, + "best_of": best_of, + "max_tokens": max_tokens, + "random_seed": random_seed, + "stop": stop, + "stream": stream, + "temperature": temperature, + "top_p": top_p, + }, + completion_create_params.CompletionCreateParamsStreaming + if stream + else completion_create_params.CompletionCreateParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Completion, + stream=stream or False, + stream_cls=Stream[CompletionChunk], + ) + + +class AsyncCompletionsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncCompletionsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncCompletionsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncCompletionsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncCompletionsResourceWithStreamingResponse(self) + + @overload + async def create( + self, + *, + model: str, + prompt: str, + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Completion: + """Generate text completions using the specified model and prompt. + + This endpoint is + useful for text generation tasks that don't require conversational context. + + Args: + model: The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + + prompt: The input text that the model will process to generate a response. + + best_of: Specifies the number of completions to generate and return the best one. Useful + for generating multiple outputs and choosing the best based on some criteria. + + max_tokens: The maximum number of tokens that the model can generate in the response. + + random_seed: A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + + stop: Specifies stopping conditions for the model's output generation. This can be an + array of strings or a single string that the model will look for as a signal to + stop generating further tokens. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + temperature: Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic. + + top_p: Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def create( + self, + *, + model: str, + prompt: str, + stream: Literal[True], + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncStream[CompletionChunk]: + """Generate text completions using the specified model and prompt. + + This endpoint is + useful for text generation tasks that don't require conversational context. + + Args: + model: The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + + prompt: The input text that the model will process to generate a response. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + best_of: Specifies the number of completions to generate and return the best one. Useful + for generating multiple outputs and choosing the best based on some criteria. + + max_tokens: The maximum number of tokens that the model can generate in the response. + + random_seed: A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + + stop: Specifies stopping conditions for the model's output generation. This can be an + array of strings or a single string that the model will look for as a signal to + stop generating further tokens. + + temperature: Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic. + + top_p: Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def create( + self, + *, + model: str, + prompt: str, + stream: bool, + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Completion | AsyncStream[CompletionChunk]: + """Generate text completions using the specified model and prompt. + + This endpoint is + useful for text generation tasks that don't require conversational context. + + Args: + model: The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + + prompt: The input text that the model will process to generate a response. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + best_of: Specifies the number of completions to generate and return the best one. Useful + for generating multiple outputs and choosing the best based on some criteria. + + max_tokens: The maximum number of tokens that the model can generate in the response. + + random_seed: A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + + stop: Specifies stopping conditions for the model's output generation. This can be an + array of strings or a single string that the model will look for as a signal to + stop generating further tokens. + + temperature: Controls the randomness of the model's outputs. Higher values lead to more + random outputs, while lower values make the model more deterministic. + + top_p: Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["model", "prompt"], ["model", "prompt", "stream"]) + async def create( + self, + *, + model: str, + prompt: str, + best_of: int | Omit = omit, + max_tokens: int | Omit = omit, + random_seed: int | Omit = omit, + stop: Union[SequenceNotStr[str], str] | Omit = omit, + stream: Literal[False] | Literal[True] | Omit = omit, + temperature: float | Omit = omit, + top_p: float | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Completion | AsyncStream[CompletionChunk]: + return await self._post( + "/v1/completions", + body=await async_maybe_transform( + { + "model": model, + "prompt": prompt, + "best_of": best_of, + "max_tokens": max_tokens, + "random_seed": random_seed, + "stop": stop, + "stream": stream, + "temperature": temperature, + "top_p": top_p, + }, + completion_create_params.CompletionCreateParamsStreaming + if stream + else completion_create_params.CompletionCreateParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Completion, + stream=stream or False, + stream_cls=AsyncStream[CompletionChunk], + ) + + +class CompletionsResourceWithRawResponse: + def __init__(self, completions: CompletionsResource) -> None: + self._completions = completions + + self.create = to_raw_response_wrapper( + completions.create, + ) + + +class AsyncCompletionsResourceWithRawResponse: + def __init__(self, completions: AsyncCompletionsResource) -> None: + self._completions = completions + + self.create = async_to_raw_response_wrapper( + completions.create, + ) + + +class CompletionsResourceWithStreamingResponse: + def __init__(self, completions: CompletionsResource) -> None: + self._completions = completions + + self.create = to_streamed_response_wrapper( + completions.create, + ) + + +class AsyncCompletionsResourceWithStreamingResponse: + def __init__(self, completions: AsyncCompletionsResource) -> None: + self._completions = completions + + self.create = async_to_streamed_response_wrapper( + completions.create, + ) diff --git a/src/writerai/resources/files.py b/src/writerai/resources/files.py new file mode 100644 index 00000000..a7e52e5c --- /dev/null +++ b/src/writerai/resources/files.py @@ -0,0 +1,716 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing_extensions import Literal + +import httpx + +from ..types import file_list_params, file_retry_params, file_upload_params +from .._files import read_file_content, async_read_file_content +from .._types import ( + Body, + Omit, + Query, + Headers, + NotGiven, + BinaryTypes, + FileContent, + SequenceNotStr, + AsyncBinaryTypes, + omit, + not_given, +) +from .._utils import path_template, maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + BinaryAPIResponse, + AsyncBinaryAPIResponse, + StreamedBinaryAPIResponse, + AsyncStreamedBinaryAPIResponse, + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + to_custom_raw_response_wrapper, + async_to_streamed_response_wrapper, + to_custom_streamed_response_wrapper, + async_to_custom_raw_response_wrapper, + async_to_custom_streamed_response_wrapper, +) +from ..pagination import SyncCursorPage, AsyncCursorPage +from ..types.file import File +from .._base_client import AsyncPaginator, make_request_options +from ..types.file_retry_response import FileRetryResponse +from ..types.file_delete_response import FileDeleteResponse + +__all__ = ["FilesResource", "AsyncFilesResource"] + + +class FilesResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> FilesResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return FilesResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> FilesResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return FilesResourceWithStreamingResponse(self) + + def retrieve( + self, + file_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> File: + """ + Retrieve detailed information about a specific file, including its metadata, + status, and associated graphs. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return self._get( + path_template("/v1/files/{file_id}", file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=File, + ) + + def list( + self, + *, + after: str | Omit = omit, + before: str | Omit = omit, + file_types: str | Omit = omit, + graph_id: str | Omit = omit, + limit: int | Omit = omit, + order: Literal["asc", "desc"] | Omit = omit, + status: Literal["in_progress", "completed", "failed"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> SyncCursorPage[File]: + """ + Retrieve a paginated list of files with optional filtering by status, graph + association, and file type. + + Args: + after: The ID of the last object in the previous page. This parameter instructs the API + to return the next page of results. + + before: The ID of the first object in the previous page. This parameter instructs the + API to return the previous page of results. + + file_types: The extensions of the files to retrieve. Separate multiple extensions with a + comma. For example: `pdf,jpg,docx`. + + graph_id: The unique identifier of the graph to which the files belong. + + limit: Specifies the maximum number of objects returned in a page. The default value + is 50. The minimum value is 1, and the maximum value is 100. + + order: Specifies the order of the results. Valid values are asc for ascending and desc + for descending. + + status: Specifies the status of the files to retrieve. Valid values are in_progress, + completed or failed. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/v1/files", + page=SyncCursorPage[File], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "before": before, + "file_types": file_types, + "graph_id": graph_id, + "limit": limit, + "order": order, + "status": status, + }, + file_list_params.FileListParams, + ), + ), + model=File, + ) + + def delete( + self, + file_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> FileDeleteResponse: + """Permanently delete a file from the system. + + This action cannot be undone. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return self._delete( + path_template("/v1/files/{file_id}", file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=FileDeleteResponse, + ) + + def download( + self, + file_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> BinaryAPIResponse: + """Download the binary content of a file. + + The response will contain the file data + in the appropriate MIME type. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})} + return self._get( + path_template("/v1/files/{file_id}/download", file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=BinaryAPIResponse, + ) + + def retry( + self, + *, + file_ids: SequenceNotStr[str], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> FileRetryResponse: + """Retry processing of files that previously failed to process. + + This will + re-attempt the processing of the specified files. + + Args: + file_ids: The unique identifier of the files to retry. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/files/retry", + body=maybe_transform({"file_ids": file_ids}, file_retry_params.FileRetryParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=FileRetryResponse, + ) + + def upload( + self, + content: FileContent | BinaryTypes, + *, + content_disposition: str, + graph_id: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> File: + """Upload a new file to the system. + + Supports various file formats including PDF, + DOC, DOCX, PPT, PPTX, JPG, PNG, EML, HTML, SRT, CSV, XLS, and XLSX. + + Args: + graph_id: The unique identifier of the Knowledge Graph to associate the uploaded file + with. + + Note: The response from the upload endpoint does not include the `graphId` + field, but the association will be visible when you retrieve the file using the + file retrieval endpoint. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Content-Disposition": content_disposition, **(extra_headers or {})} + extra_headers["Content-Type"] = "text/plain" + return self._post( + "/v1/files", + content=read_file_content(content) if isinstance(content, os.PathLike) else content, + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"graph_id": graph_id}, file_upload_params.FileUploadParams), + ), + cast_to=File, + ) + + +class AsyncFilesResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncFilesResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncFilesResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncFilesResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncFilesResourceWithStreamingResponse(self) + + async def retrieve( + self, + file_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> File: + """ + Retrieve detailed information about a specific file, including its metadata, + status, and associated graphs. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return await self._get( + path_template("/v1/files/{file_id}", file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=File, + ) + + def list( + self, + *, + after: str | Omit = omit, + before: str | Omit = omit, + file_types: str | Omit = omit, + graph_id: str | Omit = omit, + limit: int | Omit = omit, + order: Literal["asc", "desc"] | Omit = omit, + status: Literal["in_progress", "completed", "failed"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncPaginator[File, AsyncCursorPage[File]]: + """ + Retrieve a paginated list of files with optional filtering by status, graph + association, and file type. + + Args: + after: The ID of the last object in the previous page. This parameter instructs the API + to return the next page of results. + + before: The ID of the first object in the previous page. This parameter instructs the + API to return the previous page of results. + + file_types: The extensions of the files to retrieve. Separate multiple extensions with a + comma. For example: `pdf,jpg,docx`. + + graph_id: The unique identifier of the graph to which the files belong. + + limit: Specifies the maximum number of objects returned in a page. The default value + is 50. The minimum value is 1, and the maximum value is 100. + + order: Specifies the order of the results. Valid values are asc for ascending and desc + for descending. + + status: Specifies the status of the files to retrieve. Valid values are in_progress, + completed or failed. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/v1/files", + page=AsyncCursorPage[File], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "before": before, + "file_types": file_types, + "graph_id": graph_id, + "limit": limit, + "order": order, + "status": status, + }, + file_list_params.FileListParams, + ), + ), + model=File, + ) + + async def delete( + self, + file_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> FileDeleteResponse: + """Permanently delete a file from the system. + + This action cannot be undone. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return await self._delete( + path_template("/v1/files/{file_id}", file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=FileDeleteResponse, + ) + + async def download( + self, + file_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncBinaryAPIResponse: + """Download the binary content of a file. + + The response will contain the file data + in the appropriate MIME type. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})} + return await self._get( + path_template("/v1/files/{file_id}/download", file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=AsyncBinaryAPIResponse, + ) + + async def retry( + self, + *, + file_ids: SequenceNotStr[str], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> FileRetryResponse: + """Retry processing of files that previously failed to process. + + This will + re-attempt the processing of the specified files. + + Args: + file_ids: The unique identifier of the files to retry. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/files/retry", + body=await async_maybe_transform({"file_ids": file_ids}, file_retry_params.FileRetryParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=FileRetryResponse, + ) + + async def upload( + self, + content: FileContent | AsyncBinaryTypes, + *, + content_disposition: str, + graph_id: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> File: + """Upload a new file to the system. + + Supports various file formats including PDF, + DOC, DOCX, PPT, PPTX, JPG, PNG, EML, HTML, SRT, CSV, XLS, and XLSX. + + Args: + graph_id: The unique identifier of the Knowledge Graph to associate the uploaded file + with. + + Note: The response from the upload endpoint does not include the `graphId` + field, but the association will be visible when you retrieve the file using the + file retrieval endpoint. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Content-Disposition": content_disposition, **(extra_headers or {})} + extra_headers["Content-Type"] = "text/plain" + return await self._post( + "/v1/files", + content=await async_read_file_content(content) if isinstance(content, os.PathLike) else content, + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"graph_id": graph_id}, file_upload_params.FileUploadParams), + ), + cast_to=File, + ) + + +class FilesResourceWithRawResponse: + def __init__(self, files: FilesResource) -> None: + self._files = files + + self.retrieve = to_raw_response_wrapper( + files.retrieve, + ) + self.list = to_raw_response_wrapper( + files.list, + ) + self.delete = to_raw_response_wrapper( + files.delete, + ) + self.download = to_custom_raw_response_wrapper( + files.download, + BinaryAPIResponse, + ) + self.retry = to_raw_response_wrapper( + files.retry, + ) + self.upload = to_raw_response_wrapper( + files.upload, + ) + + +class AsyncFilesResourceWithRawResponse: + def __init__(self, files: AsyncFilesResource) -> None: + self._files = files + + self.retrieve = async_to_raw_response_wrapper( + files.retrieve, + ) + self.list = async_to_raw_response_wrapper( + files.list, + ) + self.delete = async_to_raw_response_wrapper( + files.delete, + ) + self.download = async_to_custom_raw_response_wrapper( + files.download, + AsyncBinaryAPIResponse, + ) + self.retry = async_to_raw_response_wrapper( + files.retry, + ) + self.upload = async_to_raw_response_wrapper( + files.upload, + ) + + +class FilesResourceWithStreamingResponse: + def __init__(self, files: FilesResource) -> None: + self._files = files + + self.retrieve = to_streamed_response_wrapper( + files.retrieve, + ) + self.list = to_streamed_response_wrapper( + files.list, + ) + self.delete = to_streamed_response_wrapper( + files.delete, + ) + self.download = to_custom_streamed_response_wrapper( + files.download, + StreamedBinaryAPIResponse, + ) + self.retry = to_streamed_response_wrapper( + files.retry, + ) + self.upload = to_streamed_response_wrapper( + files.upload, + ) + + +class AsyncFilesResourceWithStreamingResponse: + def __init__(self, files: AsyncFilesResource) -> None: + self._files = files + + self.retrieve = async_to_streamed_response_wrapper( + files.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + files.list, + ) + self.delete = async_to_streamed_response_wrapper( + files.delete, + ) + self.download = async_to_custom_streamed_response_wrapper( + files.download, + AsyncStreamedBinaryAPIResponse, + ) + self.retry = async_to_streamed_response_wrapper( + files.retry, + ) + self.upload = async_to_streamed_response_wrapper( + files.upload, + ) diff --git a/src/writerai/resources/graphs.py b/src/writerai/resources/graphs.py new file mode 100644 index 00000000..1a78bcb2 --- /dev/null +++ b/src/writerai/resources/graphs.py @@ -0,0 +1,1131 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Literal, overload + +import httpx + +from ..types import ( + graph_list_params, + graph_create_params, + graph_update_params, + graph_question_params, + graph_add_file_to_graph_params, +) +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given +from .._utils import path_template, required_args, maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._streaming import Stream, AsyncStream +from ..pagination import SyncCursorPage, AsyncCursorPage +from ..types.file import File +from ..types.graph import Graph +from .._base_client import AsyncPaginator, make_request_options +from ..types.question import Question +from ..types.graph_create_response import GraphCreateResponse +from ..types.graph_delete_response import GraphDeleteResponse +from ..types.graph_update_response import GraphUpdateResponse +from ..types.question_response_chunk import QuestionResponseChunk +from ..types.graph_remove_file_from_graph_response import GraphRemoveFileFromGraphResponse + +__all__ = ["GraphsResource", "AsyncGraphsResource"] + + +class GraphsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> GraphsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return GraphsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> GraphsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return GraphsResourceWithStreamingResponse(self) + + def create( + self, + *, + description: str | Omit = omit, + name: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphCreateResponse: + """ + Create a new Knowledge Graph. + + Args: + description: A description of the Knowledge Graph (max 255 characters). Omitting this field + leaves the description unchanged. + + name: The name of the Knowledge Graph (max 255 characters). Omitting this field leaves + the name unchanged. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/graphs", + body=maybe_transform( + { + "description": description, + "name": name, + }, + graph_create_params.GraphCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphCreateResponse, + ) + + def retrieve( + self, + graph_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Graph: + """ + Retrieve a Knowledge Graph. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return self._get( + path_template("/v1/graphs/{graph_id}", graph_id=graph_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Graph, + ) + + def update( + self, + graph_id: str, + *, + description: str | Omit = omit, + name: str | Omit = omit, + urls: Iterable[graph_update_params.URL] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphUpdateResponse: + """ + Update the name and description of a Knowledge Graph. + + Args: + description: A description of the Knowledge Graph (max 255 characters). Omitting this field + leaves the description unchanged. + + name: The name of the Knowledge Graph (max 255 characters). Omitting this field leaves + the name unchanged. + + urls: An array of web connector URLs to update for this Knowledge Graph. You can only + connect URLs to Knowledge Graphs with the type `web`. To clear the list of URLs, + set this field to an empty array. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return self._put( + path_template("/v1/graphs/{graph_id}", graph_id=graph_id), + body=maybe_transform( + { + "description": description, + "name": name, + "urls": urls, + }, + graph_update_params.GraphUpdateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphUpdateResponse, + ) + + def list( + self, + *, + after: str | Omit = omit, + before: str | Omit = omit, + limit: int | Omit = omit, + order: Literal["asc", "desc"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> SyncCursorPage[Graph]: + """ + Retrieve a list of Knowledge Graphs. + + Args: + after: The ID of the last object in the previous page. This parameter instructs the API + to return the next page of results. + + before: The ID of the first object in the previous page. This parameter instructs the + API to return the previous page of results. + + limit: Specifies the maximum number of objects returned in a page. The default value + is 50. The minimum value is 1, and the maximum value is 100. + + order: Specifies the order of the results. Valid values are asc for ascending and desc + for descending. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/v1/graphs", + page=SyncCursorPage[Graph], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "before": before, + "limit": limit, + "order": order, + }, + graph_list_params.GraphListParams, + ), + ), + model=Graph, + ) + + def delete( + self, + graph_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphDeleteResponse: + """ + Delete a Knowledge Graph. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return self._delete( + path_template("/v1/graphs/{graph_id}", graph_id=graph_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphDeleteResponse, + ) + + def add_file_to_graph( + self, + graph_id: str, + *, + file_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> File: + """ + Add a file to a Knowledge Graph. + + Args: + file_id: The unique identifier of the file. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return self._post( + path_template("/v1/graphs/{graph_id}/file", graph_id=graph_id), + body=maybe_transform({"file_id": file_id}, graph_add_file_to_graph_params.GraphAddFileToGraphParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=File, + ) + + @overload + def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + query_config: graph_question_params.QueryConfig | Omit = omit, + stream: Literal[False] | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Question: + """ + Ask a question to specified Knowledge Graphs. + + Args: + graph_ids: The unique identifiers of the Knowledge Graphs to query. + + question: The question to answer using the Knowledge Graph. + + query_config: Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + subqueries: Specify whether to include subqueries. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + stream: Literal[True], + query_config: graph_question_params.QueryConfig | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Stream[QuestionResponseChunk]: + """ + Ask a question to specified Knowledge Graphs. + + Args: + graph_ids: The unique identifiers of the Knowledge Graphs to query. + + question: The question to answer using the Knowledge Graph. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + query_config: Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + + subqueries: Specify whether to include subqueries. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + stream: bool, + query_config: graph_question_params.QueryConfig | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Question | Stream[QuestionResponseChunk]: + """ + Ask a question to specified Knowledge Graphs. + + Args: + graph_ids: The unique identifiers of the Knowledge Graphs to query. + + question: The question to answer using the Knowledge Graph. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + query_config: Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + + subqueries: Specify whether to include subqueries. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["graph_ids", "question"], ["graph_ids", "question", "stream"]) + def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + query_config: graph_question_params.QueryConfig | Omit = omit, + stream: Literal[False] | Literal[True] | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Question | Stream[QuestionResponseChunk]: + return self._post( + "/v1/graphs/question", + body=maybe_transform( + { + "graph_ids": graph_ids, + "question": question, + "query_config": query_config, + "stream": stream, + "subqueries": subqueries, + }, + graph_question_params.GraphQuestionParamsStreaming + if stream + else graph_question_params.GraphQuestionParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Question, + stream=stream or False, + stream_cls=Stream[QuestionResponseChunk], + ) + + def remove_file_from_graph( + self, + file_id: str, + *, + graph_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphRemoveFileFromGraphResponse: + """ + Remove a file from a Knowledge Graph. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return self._delete( + path_template("/v1/graphs/{graph_id}/file/{file_id}", graph_id=graph_id, file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphRemoveFileFromGraphResponse, + ) + + +class AsyncGraphsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncGraphsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncGraphsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncGraphsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncGraphsResourceWithStreamingResponse(self) + + async def create( + self, + *, + description: str | Omit = omit, + name: str | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphCreateResponse: + """ + Create a new Knowledge Graph. + + Args: + description: A description of the Knowledge Graph (max 255 characters). Omitting this field + leaves the description unchanged. + + name: The name of the Knowledge Graph (max 255 characters). Omitting this field leaves + the name unchanged. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/graphs", + body=await async_maybe_transform( + { + "description": description, + "name": name, + }, + graph_create_params.GraphCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphCreateResponse, + ) + + async def retrieve( + self, + graph_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Graph: + """ + Retrieve a Knowledge Graph. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return await self._get( + path_template("/v1/graphs/{graph_id}", graph_id=graph_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Graph, + ) + + async def update( + self, + graph_id: str, + *, + description: str | Omit = omit, + name: str | Omit = omit, + urls: Iterable[graph_update_params.URL] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphUpdateResponse: + """ + Update the name and description of a Knowledge Graph. + + Args: + description: A description of the Knowledge Graph (max 255 characters). Omitting this field + leaves the description unchanged. + + name: The name of the Knowledge Graph (max 255 characters). Omitting this field leaves + the name unchanged. + + urls: An array of web connector URLs to update for this Knowledge Graph. You can only + connect URLs to Knowledge Graphs with the type `web`. To clear the list of URLs, + set this field to an empty array. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return await self._put( + path_template("/v1/graphs/{graph_id}", graph_id=graph_id), + body=await async_maybe_transform( + { + "description": description, + "name": name, + "urls": urls, + }, + graph_update_params.GraphUpdateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphUpdateResponse, + ) + + def list( + self, + *, + after: str | Omit = omit, + before: str | Omit = omit, + limit: int | Omit = omit, + order: Literal["asc", "desc"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncPaginator[Graph, AsyncCursorPage[Graph]]: + """ + Retrieve a list of Knowledge Graphs. + + Args: + after: The ID of the last object in the previous page. This parameter instructs the API + to return the next page of results. + + before: The ID of the first object in the previous page. This parameter instructs the + API to return the previous page of results. + + limit: Specifies the maximum number of objects returned in a page. The default value + is 50. The minimum value is 1, and the maximum value is 100. + + order: Specifies the order of the results. Valid values are asc for ascending and desc + for descending. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._get_api_list( + "/v1/graphs", + page=AsyncCursorPage[Graph], + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "after": after, + "before": before, + "limit": limit, + "order": order, + }, + graph_list_params.GraphListParams, + ), + ), + model=Graph, + ) + + async def delete( + self, + graph_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphDeleteResponse: + """ + Delete a Knowledge Graph. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return await self._delete( + path_template("/v1/graphs/{graph_id}", graph_id=graph_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphDeleteResponse, + ) + + async def add_file_to_graph( + self, + graph_id: str, + *, + file_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> File: + """ + Add a file to a Knowledge Graph. + + Args: + file_id: The unique identifier of the file. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + return await self._post( + path_template("/v1/graphs/{graph_id}/file", graph_id=graph_id), + body=await async_maybe_transform( + {"file_id": file_id}, graph_add_file_to_graph_params.GraphAddFileToGraphParams + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=File, + ) + + @overload + async def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + query_config: graph_question_params.QueryConfig | Omit = omit, + stream: Literal[False] | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Question: + """ + Ask a question to specified Knowledge Graphs. + + Args: + graph_ids: The unique identifiers of the Knowledge Graphs to query. + + question: The question to answer using the Knowledge Graph. + + query_config: Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + subqueries: Specify whether to include subqueries. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + stream: Literal[True], + query_config: graph_question_params.QueryConfig | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> AsyncStream[QuestionResponseChunk]: + """ + Ask a question to specified Knowledge Graphs. + + Args: + graph_ids: The unique identifiers of the Knowledge Graphs to query. + + question: The question to answer using the Knowledge Graph. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + query_config: Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + + subqueries: Specify whether to include subqueries. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + stream: bool, + query_config: graph_question_params.QueryConfig | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Question | AsyncStream[QuestionResponseChunk]: + """ + Ask a question to specified Knowledge Graphs. + + Args: + graph_ids: The unique identifiers of the Knowledge Graphs to query. + + question: The question to answer using the Knowledge Graph. + + stream: Determines whether the model's output should be streamed. If true, the output is + generated and sent incrementally, which can be useful for real-time + applications. + + query_config: Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + + subqueries: Specify whether to include subqueries. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @required_args(["graph_ids", "question"], ["graph_ids", "question", "stream"]) + async def question( + self, + *, + graph_ids: SequenceNotStr[str], + question: str, + query_config: graph_question_params.QueryConfig | Omit = omit, + stream: Literal[False] | Literal[True] | Omit = omit, + subqueries: bool | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> Question | AsyncStream[QuestionResponseChunk]: + return await self._post( + "/v1/graphs/question", + body=await async_maybe_transform( + { + "graph_ids": graph_ids, + "question": question, + "query_config": query_config, + "stream": stream, + "subqueries": subqueries, + }, + graph_question_params.GraphQuestionParamsStreaming + if stream + else graph_question_params.GraphQuestionParamsNonStreaming, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Question, + stream=stream or False, + stream_cls=AsyncStream[QuestionResponseChunk], + ) + + async def remove_file_from_graph( + self, + file_id: str, + *, + graph_id: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> GraphRemoveFileFromGraphResponse: + """ + Remove a file from a Knowledge Graph. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not graph_id: + raise ValueError(f"Expected a non-empty value for `graph_id` but received {graph_id!r}") + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return await self._delete( + path_template("/v1/graphs/{graph_id}/file/{file_id}", graph_id=graph_id, file_id=file_id), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=GraphRemoveFileFromGraphResponse, + ) + + +class GraphsResourceWithRawResponse: + def __init__(self, graphs: GraphsResource) -> None: + self._graphs = graphs + + self.create = to_raw_response_wrapper( + graphs.create, + ) + self.retrieve = to_raw_response_wrapper( + graphs.retrieve, + ) + self.update = to_raw_response_wrapper( + graphs.update, + ) + self.list = to_raw_response_wrapper( + graphs.list, + ) + self.delete = to_raw_response_wrapper( + graphs.delete, + ) + self.add_file_to_graph = to_raw_response_wrapper( + graphs.add_file_to_graph, + ) + self.question = to_raw_response_wrapper( + graphs.question, + ) + self.remove_file_from_graph = to_raw_response_wrapper( + graphs.remove_file_from_graph, + ) + + +class AsyncGraphsResourceWithRawResponse: + def __init__(self, graphs: AsyncGraphsResource) -> None: + self._graphs = graphs + + self.create = async_to_raw_response_wrapper( + graphs.create, + ) + self.retrieve = async_to_raw_response_wrapper( + graphs.retrieve, + ) + self.update = async_to_raw_response_wrapper( + graphs.update, + ) + self.list = async_to_raw_response_wrapper( + graphs.list, + ) + self.delete = async_to_raw_response_wrapper( + graphs.delete, + ) + self.add_file_to_graph = async_to_raw_response_wrapper( + graphs.add_file_to_graph, + ) + self.question = async_to_raw_response_wrapper( + graphs.question, + ) + self.remove_file_from_graph = async_to_raw_response_wrapper( + graphs.remove_file_from_graph, + ) + + +class GraphsResourceWithStreamingResponse: + def __init__(self, graphs: GraphsResource) -> None: + self._graphs = graphs + + self.create = to_streamed_response_wrapper( + graphs.create, + ) + self.retrieve = to_streamed_response_wrapper( + graphs.retrieve, + ) + self.update = to_streamed_response_wrapper( + graphs.update, + ) + self.list = to_streamed_response_wrapper( + graphs.list, + ) + self.delete = to_streamed_response_wrapper( + graphs.delete, + ) + self.add_file_to_graph = to_streamed_response_wrapper( + graphs.add_file_to_graph, + ) + self.question = to_streamed_response_wrapper( + graphs.question, + ) + self.remove_file_from_graph = to_streamed_response_wrapper( + graphs.remove_file_from_graph, + ) + + +class AsyncGraphsResourceWithStreamingResponse: + def __init__(self, graphs: AsyncGraphsResource) -> None: + self._graphs = graphs + + self.create = async_to_streamed_response_wrapper( + graphs.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + graphs.retrieve, + ) + self.update = async_to_streamed_response_wrapper( + graphs.update, + ) + self.list = async_to_streamed_response_wrapper( + graphs.list, + ) + self.delete = async_to_streamed_response_wrapper( + graphs.delete, + ) + self.add_file_to_graph = async_to_streamed_response_wrapper( + graphs.add_file_to_graph, + ) + self.question = async_to_streamed_response_wrapper( + graphs.question, + ) + self.remove_file_from_graph = async_to_streamed_response_wrapper( + graphs.remove_file_from_graph, + ) diff --git a/src/writerai/resources/models.py b/src/writerai/resources/models.py new file mode 100644 index 00000000..d7bc28a6 --- /dev/null +++ b/src/writerai/resources/models.py @@ -0,0 +1,141 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import httpx + +from .._types import Body, Query, Headers, NotGiven, not_given +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.model_list_response import ModelListResponse + +__all__ = ["ModelsResource", "AsyncModelsResource"] + + +class ModelsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ModelsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return ModelsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ModelsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return ModelsResourceWithStreamingResponse(self) + + def list( + self, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ModelListResponse: + """ + Retrieve a list of available models that can be used for text generation, chat + completions, and other AI tasks. + """ + return self._get( + "/v1/models", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ModelListResponse, + ) + + +class AsyncModelsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncModelsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncModelsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncModelsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncModelsResourceWithStreamingResponse(self) + + async def list( + self, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ModelListResponse: + """ + Retrieve a list of available models that can be used for text generation, chat + completions, and other AI tasks. + """ + return await self._get( + "/v1/models", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ModelListResponse, + ) + + +class ModelsResourceWithRawResponse: + def __init__(self, models: ModelsResource) -> None: + self._models = models + + self.list = to_raw_response_wrapper( + models.list, + ) + + +class AsyncModelsResourceWithRawResponse: + def __init__(self, models: AsyncModelsResource) -> None: + self._models = models + + self.list = async_to_raw_response_wrapper( + models.list, + ) + + +class ModelsResourceWithStreamingResponse: + def __init__(self, models: ModelsResource) -> None: + self._models = models + + self.list = to_streamed_response_wrapper( + models.list, + ) + + +class AsyncModelsResourceWithStreamingResponse: + def __init__(self, models: AsyncModelsResource) -> None: + self._models = models + + self.list = async_to_streamed_response_wrapper( + models.list, + ) diff --git a/src/writerai/resources/tools.py b/src/writerai/resources/tools.py new file mode 100644 index 00000000..8fd00630 --- /dev/null +++ b/src/writerai/resources/tools.py @@ -0,0 +1,764 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import typing_extensions +from typing import Union +from typing_extensions import Literal + +import httpx + +from ..types import tool_parse_pdf_params, tool_web_search_params +from .._types import Body, Omit, Query, Headers, NotGiven, SequenceNotStr, omit, not_given +from .._utils import path_template, maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.tool_parse_pdf_response import ToolParsePdfResponse +from ..types.tool_web_search_response import ToolWebSearchResponse + +__all__ = ["ToolsResource", "AsyncToolsResource"] + + +class ToolsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ToolsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return ToolsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ToolsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return ToolsResourceWithStreamingResponse(self) + + @typing_extensions.deprecated( + "Will be removed in a future release. A replacement PDF parsing tool for chat completions is planned; see documentation at dev.writer.com for more information." + ) + def parse_pdf( + self, + file_id: str, + *, + format: Literal["text", "markdown"], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ToolParsePdfResponse: + """ + Parse PDF to other formats. + + Args: + format: The format into which the PDF content should be converted. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return self._post( + path_template("/v1/tools/pdf-parser/{file_id}", file_id=file_id), + body=maybe_transform({"format": format}, tool_parse_pdf_params.ToolParsePdfParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ToolParsePdfResponse, + ) + + @typing_extensions.deprecated( + "Will be removed in a future release. Migrate to `chat.chat` with the web search tool for web search capabilities. See documentation at dev.writer.com for more information." + ) + def web_search( + self, + *, + chunks_per_source: int | Omit = omit, + country: Literal[ + "afghanistan", + "albania", + "algeria", + "andorra", + "angola", + "argentina", + "armenia", + "australia", + "austria", + "azerbaijan", + "bahamas", + "bahrain", + "bangladesh", + "barbados", + "belarus", + "belgium", + "belize", + "benin", + "bhutan", + "bolivia", + "bosnia and herzegovina", + "botswana", + "brazil", + "brunei", + "bulgaria", + "burkina faso", + "burundi", + "cambodia", + "cameroon", + "canada", + "cape verde", + "central african republic", + "chad", + "chile", + "china", + "colombia", + "comoros", + "congo", + "costa rica", + "croatia", + "cuba", + "cyprus", + "czech republic", + "denmark", + "djibouti", + "dominican republic", + "ecuador", + "egypt", + "el salvador", + "equatorial guinea", + "eritrea", + "estonia", + "ethiopia", + "fiji", + "finland", + "france", + "gabon", + "gambia", + "georgia", + "germany", + "ghana", + "greece", + "guatemala", + "guinea", + "haiti", + "honduras", + "hungary", + "iceland", + "india", + "indonesia", + "iran", + "iraq", + "ireland", + "israel", + "italy", + "jamaica", + "japan", + "jordan", + "kazakhstan", + "kenya", + "kuwait", + "kyrgyzstan", + "latvia", + "lebanon", + "lesotho", + "liberia", + "libya", + "liechtenstein", + "lithuania", + "luxembourg", + "madagascar", + "malawi", + "malaysia", + "maldives", + "mali", + "malta", + "mauritania", + "mauritius", + "mexico", + "moldova", + "monaco", + "mongolia", + "montenegro", + "morocco", + "mozambique", + "myanmar", + "namibia", + "nepal", + "netherlands", + "new zealand", + "nicaragua", + "niger", + "nigeria", + "north korea", + "north macedonia", + "norway", + "oman", + "pakistan", + "panama", + "papua new guinea", + "paraguay", + "peru", + "philippines", + "poland", + "portugal", + "qatar", + "romania", + "russia", + "rwanda", + "saudi arabia", + "senegal", + "serbia", + "singapore", + "slovakia", + "slovenia", + "somalia", + "south africa", + "south korea", + "south sudan", + "spain", + "sri lanka", + "sudan", + "sweden", + "switzerland", + "syria", + "taiwan", + "tajikistan", + "tanzania", + "thailand", + "togo", + "trinidad and tobago", + "tunisia", + "turkey", + "turkmenistan", + "uganda", + "ukraine", + "united arab emirates", + "united kingdom", + "united states", + "uruguay", + "uzbekistan", + "venezuela", + "vietnam", + "yemen", + "zambia", + "zimbabwe", + ] + | Omit = omit, + days: int | Omit = omit, + exclude_domains: SequenceNotStr[str] | Omit = omit, + include_answer: bool | Omit = omit, + include_domains: SequenceNotStr[str] | Omit = omit, + include_raw_content: Union[Literal["text", "markdown"], bool] | Omit = omit, + max_results: int | Omit = omit, + query: str | Omit = omit, + search_depth: Literal["basic", "advanced"] | Omit = omit, + stream: bool | Omit = omit, + time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] | Omit = omit, + topic: Literal["general", "news"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ToolWebSearchResponse: + """ + Search the web for information about a given query and return relevant results + with source URLs. + + Args: + chunks_per_source: Only applies when `search_depth` is `advanced`. Specifies how many text segments + to extract from each source. Limited to 3 chunks maximum. + + country: Localizes search results to a specific country. Only applies to general topic + searches. + + days: For news topic searches, specifies how many days of news coverage to include. + + exclude_domains: Domains to exclude from the search. If unset, the search includes all domains. + + include_answer: Whether to include a generated answer to the query in the response. If `false`, + only search results are returned. + + include_domains: Domains to include in the search. If unset, the search includes all domains. + + include_raw_content: + Controls how raw content is included in search results: + + - `text`: Returns plain text without formatting markup + - `markdown`: Returns structured content with markdown formatting (headers, + links, bold text) + - `true`: Same as `markdown` + - `false`: Raw content is not included (default if unset) + + max_results: Limits the number of search results returned. Cannot exceed 20 sources. + + query: The search query. + + search_depth: + Controls search comprehensiveness: + + - `basic`: Returns fewer but highly relevant results + - `advanced`: Performs a deeper search with more results + + stream: Enables streaming of search results as they become available. + + time_range: Filters results to content published within the specified time range back from + the current date. For example, `week` or `w` returns results from the past 7 + days. + + topic: The search topic category. Use `news` for current events and news articles, or + `general` for broader web search. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/tools/web-search", + body=maybe_transform( + { + "chunks_per_source": chunks_per_source, + "country": country, + "days": days, + "exclude_domains": exclude_domains, + "include_answer": include_answer, + "include_domains": include_domains, + "include_raw_content": include_raw_content, + "max_results": max_results, + "query": query, + "search_depth": search_depth, + "stream": stream, + "time_range": time_range, + "topic": topic, + }, + tool_web_search_params.ToolWebSearchParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ToolWebSearchResponse, + ) + + +class AsyncToolsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncToolsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncToolsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncToolsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncToolsResourceWithStreamingResponse(self) + + @typing_extensions.deprecated( + "Will be removed in a future release. A replacement PDF parsing tool for chat completions is planned; see documentation at dev.writer.com for more information." + ) + async def parse_pdf( + self, + file_id: str, + *, + format: Literal["text", "markdown"], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ToolParsePdfResponse: + """ + Parse PDF to other formats. + + Args: + format: The format into which the PDF content should be converted. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not file_id: + raise ValueError(f"Expected a non-empty value for `file_id` but received {file_id!r}") + return await self._post( + path_template("/v1/tools/pdf-parser/{file_id}", file_id=file_id), + body=await async_maybe_transform({"format": format}, tool_parse_pdf_params.ToolParsePdfParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ToolParsePdfResponse, + ) + + @typing_extensions.deprecated( + "Will be removed in a future release. Migrate to `chat.chat` with the web search tool for web search capabilities. See documentation at dev.writer.com for more information." + ) + async def web_search( + self, + *, + chunks_per_source: int | Omit = omit, + country: Literal[ + "afghanistan", + "albania", + "algeria", + "andorra", + "angola", + "argentina", + "armenia", + "australia", + "austria", + "azerbaijan", + "bahamas", + "bahrain", + "bangladesh", + "barbados", + "belarus", + "belgium", + "belize", + "benin", + "bhutan", + "bolivia", + "bosnia and herzegovina", + "botswana", + "brazil", + "brunei", + "bulgaria", + "burkina faso", + "burundi", + "cambodia", + "cameroon", + "canada", + "cape verde", + "central african republic", + "chad", + "chile", + "china", + "colombia", + "comoros", + "congo", + "costa rica", + "croatia", + "cuba", + "cyprus", + "czech republic", + "denmark", + "djibouti", + "dominican republic", + "ecuador", + "egypt", + "el salvador", + "equatorial guinea", + "eritrea", + "estonia", + "ethiopia", + "fiji", + "finland", + "france", + "gabon", + "gambia", + "georgia", + "germany", + "ghana", + "greece", + "guatemala", + "guinea", + "haiti", + "honduras", + "hungary", + "iceland", + "india", + "indonesia", + "iran", + "iraq", + "ireland", + "israel", + "italy", + "jamaica", + "japan", + "jordan", + "kazakhstan", + "kenya", + "kuwait", + "kyrgyzstan", + "latvia", + "lebanon", + "lesotho", + "liberia", + "libya", + "liechtenstein", + "lithuania", + "luxembourg", + "madagascar", + "malawi", + "malaysia", + "maldives", + "mali", + "malta", + "mauritania", + "mauritius", + "mexico", + "moldova", + "monaco", + "mongolia", + "montenegro", + "morocco", + "mozambique", + "myanmar", + "namibia", + "nepal", + "netherlands", + "new zealand", + "nicaragua", + "niger", + "nigeria", + "north korea", + "north macedonia", + "norway", + "oman", + "pakistan", + "panama", + "papua new guinea", + "paraguay", + "peru", + "philippines", + "poland", + "portugal", + "qatar", + "romania", + "russia", + "rwanda", + "saudi arabia", + "senegal", + "serbia", + "singapore", + "slovakia", + "slovenia", + "somalia", + "south africa", + "south korea", + "south sudan", + "spain", + "sri lanka", + "sudan", + "sweden", + "switzerland", + "syria", + "taiwan", + "tajikistan", + "tanzania", + "thailand", + "togo", + "trinidad and tobago", + "tunisia", + "turkey", + "turkmenistan", + "uganda", + "ukraine", + "united arab emirates", + "united kingdom", + "united states", + "uruguay", + "uzbekistan", + "venezuela", + "vietnam", + "yemen", + "zambia", + "zimbabwe", + ] + | Omit = omit, + days: int | Omit = omit, + exclude_domains: SequenceNotStr[str] | Omit = omit, + include_answer: bool | Omit = omit, + include_domains: SequenceNotStr[str] | Omit = omit, + include_raw_content: Union[Literal["text", "markdown"], bool] | Omit = omit, + max_results: int | Omit = omit, + query: str | Omit = omit, + search_depth: Literal["basic", "advanced"] | Omit = omit, + stream: bool | Omit = omit, + time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] | Omit = omit, + topic: Literal["general", "news"] | Omit = omit, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> ToolWebSearchResponse: + """ + Search the web for information about a given query and return relevant results + with source URLs. + + Args: + chunks_per_source: Only applies when `search_depth` is `advanced`. Specifies how many text segments + to extract from each source. Limited to 3 chunks maximum. + + country: Localizes search results to a specific country. Only applies to general topic + searches. + + days: For news topic searches, specifies how many days of news coverage to include. + + exclude_domains: Domains to exclude from the search. If unset, the search includes all domains. + + include_answer: Whether to include a generated answer to the query in the response. If `false`, + only search results are returned. + + include_domains: Domains to include in the search. If unset, the search includes all domains. + + include_raw_content: + Controls how raw content is included in search results: + + - `text`: Returns plain text without formatting markup + - `markdown`: Returns structured content with markdown formatting (headers, + links, bold text) + - `true`: Same as `markdown` + - `false`: Raw content is not included (default if unset) + + max_results: Limits the number of search results returned. Cannot exceed 20 sources. + + query: The search query. + + search_depth: + Controls search comprehensiveness: + + - `basic`: Returns fewer but highly relevant results + - `advanced`: Performs a deeper search with more results + + stream: Enables streaming of search results as they become available. + + time_range: Filters results to content published within the specified time range back from + the current date. For example, `week` or `w` returns results from the past 7 + days. + + topic: The search topic category. Use `news` for current events and news articles, or + `general` for broader web search. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/tools/web-search", + body=await async_maybe_transform( + { + "chunks_per_source": chunks_per_source, + "country": country, + "days": days, + "exclude_domains": exclude_domains, + "include_answer": include_answer, + "include_domains": include_domains, + "include_raw_content": include_raw_content, + "max_results": max_results, + "query": query, + "search_depth": search_depth, + "stream": stream, + "time_range": time_range, + "topic": topic, + }, + tool_web_search_params.ToolWebSearchParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ToolWebSearchResponse, + ) + + +class ToolsResourceWithRawResponse: + def __init__(self, tools: ToolsResource) -> None: + self._tools = tools + + self.parse_pdf = ( # pyright: ignore[reportDeprecated] + to_raw_response_wrapper( + tools.parse_pdf, # pyright: ignore[reportDeprecated], + ) + ) + self.web_search = ( # pyright: ignore[reportDeprecated] + to_raw_response_wrapper( + tools.web_search, # pyright: ignore[reportDeprecated], + ) + ) + + +class AsyncToolsResourceWithRawResponse: + def __init__(self, tools: AsyncToolsResource) -> None: + self._tools = tools + + self.parse_pdf = ( # pyright: ignore[reportDeprecated] + async_to_raw_response_wrapper( + tools.parse_pdf, # pyright: ignore[reportDeprecated], + ) + ) + self.web_search = ( # pyright: ignore[reportDeprecated] + async_to_raw_response_wrapper( + tools.web_search, # pyright: ignore[reportDeprecated], + ) + ) + + +class ToolsResourceWithStreamingResponse: + def __init__(self, tools: ToolsResource) -> None: + self._tools = tools + + self.parse_pdf = ( # pyright: ignore[reportDeprecated] + to_streamed_response_wrapper( + tools.parse_pdf, # pyright: ignore[reportDeprecated], + ) + ) + self.web_search = ( # pyright: ignore[reportDeprecated] + to_streamed_response_wrapper( + tools.web_search, # pyright: ignore[reportDeprecated], + ) + ) + + +class AsyncToolsResourceWithStreamingResponse: + def __init__(self, tools: AsyncToolsResource) -> None: + self._tools = tools + + self.parse_pdf = ( # pyright: ignore[reportDeprecated] + async_to_streamed_response_wrapper( + tools.parse_pdf, # pyright: ignore[reportDeprecated], + ) + ) + self.web_search = ( # pyright: ignore[reportDeprecated] + async_to_streamed_response_wrapper( + tools.web_search, # pyright: ignore[reportDeprecated], + ) + ) diff --git a/src/writerai/resources/translation.py b/src/writerai/resources/translation.py new file mode 100644 index 00000000..36166407 --- /dev/null +++ b/src/writerai/resources/translation.py @@ -0,0 +1,278 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import typing_extensions +from typing_extensions import Literal + +import httpx + +from ..types import translation_translate_params +from .._types import Body, Query, Headers, NotGiven, not_given +from .._utils import maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.translation_response import TranslationResponse + +__all__ = ["TranslationResource", "AsyncTranslationResource"] + + +class TranslationResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> TranslationResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return TranslationResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> TranslationResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return TranslationResourceWithStreamingResponse(self) + + @typing_extensions.deprecated( + "Will be removed in a future release. Migrate to `chat.chat` with the translate tool for translation capabilities. See documentation at dev.writer.com for more information." + ) + def translate( + self, + *, + formality: bool, + length_control: bool, + mask_profanity: bool, + model: Literal["palmyra-translate"], + source_language_code: str, + target_language_code: str, + text: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> TranslationResponse: + """ + Translate text from one language to another. + + Args: + formality: Whether to use formal or informal language in the translation. See the + [list of languages that support formality](https://dev.writer.com/api-reference/translation-api/language-support#formality). + If the language does not support formality, this parameter is ignored. + + length_control: Whether to control the length of the translated text. See the + [list of languages that support length control](https://dev.writer.com/api-reference/translation-api/language-support#length-control). + If the language does not support length control, this parameter is ignored. + + mask_profanity: Whether to mask profane words in the translated text. See the + [list of languages that do not support profanity masking](https://dev.writer.com/api-reference/translation-api/language-support#profanity-masking). + If the language does not support profanity masking, this parameter is ignored. + + model: The model to use for translation. + + source_language_code: The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the original text to translate. For example, `en` for English, + `zh` for Chinese, `fr` for French, `es` for Spanish. If the language has a + variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + For example, Mexican Spanish is `es-MX`. See the + [list of supported languages and language codes](https://dev.writer.com/api-reference/translation-api/language-support). + + target_language_code: The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the target language for the translation. For example, `en` for + English, `zh` for Chinese, `fr` for French, `es` for Spanish. If the language + has a variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + For example, Mexican Spanish is `es-MX`. See the + [list of supported languages and language codes](https://dev.writer.com/api-reference/translation-api/language-support). + + text: The text to translate. Maximum of 100,000 words. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/translation", + body=maybe_transform( + { + "formality": formality, + "length_control": length_control, + "mask_profanity": mask_profanity, + "model": model, + "source_language_code": source_language_code, + "target_language_code": target_language_code, + "text": text, + }, + translation_translate_params.TranslationTranslateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=TranslationResponse, + ) + + +class AsyncTranslationResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncTranslationResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncTranslationResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncTranslationResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncTranslationResourceWithStreamingResponse(self) + + @typing_extensions.deprecated( + "Will be removed in a future release. Migrate to `chat.chat` with the translate tool for translation capabilities. See documentation at dev.writer.com for more information." + ) + async def translate( + self, + *, + formality: bool, + length_control: bool, + mask_profanity: bool, + model: Literal["palmyra-translate"], + source_language_code: str, + target_language_code: str, + text: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> TranslationResponse: + """ + Translate text from one language to another. + + Args: + formality: Whether to use formal or informal language in the translation. See the + [list of languages that support formality](https://dev.writer.com/api-reference/translation-api/language-support#formality). + If the language does not support formality, this parameter is ignored. + + length_control: Whether to control the length of the translated text. See the + [list of languages that support length control](https://dev.writer.com/api-reference/translation-api/language-support#length-control). + If the language does not support length control, this parameter is ignored. + + mask_profanity: Whether to mask profane words in the translated text. See the + [list of languages that do not support profanity masking](https://dev.writer.com/api-reference/translation-api/language-support#profanity-masking). + If the language does not support profanity masking, this parameter is ignored. + + model: The model to use for translation. + + source_language_code: The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the original text to translate. For example, `en` for English, + `zh` for Chinese, `fr` for French, `es` for Spanish. If the language has a + variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + For example, Mexican Spanish is `es-MX`. See the + [list of supported languages and language codes](https://dev.writer.com/api-reference/translation-api/language-support). + + target_language_code: The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the target language for the translation. For example, `en` for + English, `zh` for Chinese, `fr` for French, `es` for Spanish. If the language + has a variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + For example, Mexican Spanish is `es-MX`. See the + [list of supported languages and language codes](https://dev.writer.com/api-reference/translation-api/language-support). + + text: The text to translate. Maximum of 100,000 words. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/translation", + body=await async_maybe_transform( + { + "formality": formality, + "length_control": length_control, + "mask_profanity": mask_profanity, + "model": model, + "source_language_code": source_language_code, + "target_language_code": target_language_code, + "text": text, + }, + translation_translate_params.TranslationTranslateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=TranslationResponse, + ) + + +class TranslationResourceWithRawResponse: + def __init__(self, translation: TranslationResource) -> None: + self._translation = translation + + self.translate = ( # pyright: ignore[reportDeprecated] + to_raw_response_wrapper( + translation.translate, # pyright: ignore[reportDeprecated], + ) + ) + + +class AsyncTranslationResourceWithRawResponse: + def __init__(self, translation: AsyncTranslationResource) -> None: + self._translation = translation + + self.translate = ( # pyright: ignore[reportDeprecated] + async_to_raw_response_wrapper( + translation.translate, # pyright: ignore[reportDeprecated], + ) + ) + + +class TranslationResourceWithStreamingResponse: + def __init__(self, translation: TranslationResource) -> None: + self._translation = translation + + self.translate = ( # pyright: ignore[reportDeprecated] + to_streamed_response_wrapper( + translation.translate, # pyright: ignore[reportDeprecated], + ) + ) + + +class AsyncTranslationResourceWithStreamingResponse: + def __init__(self, translation: AsyncTranslationResource) -> None: + self._translation = translation + + self.translate = ( # pyright: ignore[reportDeprecated] + async_to_streamed_response_wrapper( + translation.translate, # pyright: ignore[reportDeprecated], + ) + ) diff --git a/src/writerai/resources/vision.py b/src/writerai/resources/vision.py new file mode 100644 index 00000000..446b31d2 --- /dev/null +++ b/src/writerai/resources/vision.py @@ -0,0 +1,200 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Literal + +import httpx + +from ..types import vision_analyze_params +from .._types import Body, Query, Headers, NotGiven, not_given +from .._utils import maybe_transform, async_maybe_transform +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.vision_response import VisionResponse + +__all__ = ["VisionResource", "AsyncVisionResource"] + + +class VisionResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> VisionResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return VisionResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> VisionResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return VisionResourceWithStreamingResponse(self) + + def analyze( + self, + *, + model: Literal["palmyra-vision"], + prompt: str, + variables: Iterable[vision_analyze_params.Variable], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VisionResponse: + """Submit images and documents with a prompt to generate an analysis. + + Supports JPG, + PNG, PDF, and TXT files up to 7MB each. + + Args: + model: The model to use for image analysis. + + prompt: The prompt to use for the image analysis. The prompt must include the name of + each image variable, surrounded by double curly braces (`{{}}`). For example, + `Describe the difference between the image {{image_1}} and the image {{image_2}}`. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/v1/vision", + body=maybe_transform( + { + "model": model, + "prompt": prompt, + "variables": variables, + }, + vision_analyze_params.VisionAnalyzeParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=VisionResponse, + ) + + +class AsyncVisionResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncVisionResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/writer/writer-python#accessing-raw-response-data-eg-headers + """ + return AsyncVisionResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncVisionResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/writer/writer-python#with_streaming_response + """ + return AsyncVisionResourceWithStreamingResponse(self) + + async def analyze( + self, + *, + model: Literal["palmyra-vision"], + prompt: str, + variables: Iterable[vision_analyze_params.Variable], + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = not_given, + ) -> VisionResponse: + """Submit images and documents with a prompt to generate an analysis. + + Supports JPG, + PNG, PDF, and TXT files up to 7MB each. + + Args: + model: The model to use for image analysis. + + prompt: The prompt to use for the image analysis. The prompt must include the name of + each image variable, surrounded by double curly braces (`{{}}`). For example, + `Describe the difference between the image {{image_1}} and the image {{image_2}}`. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/v1/vision", + body=await async_maybe_transform( + { + "model": model, + "prompt": prompt, + "variables": variables, + }, + vision_analyze_params.VisionAnalyzeParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=VisionResponse, + ) + + +class VisionResourceWithRawResponse: + def __init__(self, vision: VisionResource) -> None: + self._vision = vision + + self.analyze = to_raw_response_wrapper( + vision.analyze, + ) + + +class AsyncVisionResourceWithRawResponse: + def __init__(self, vision: AsyncVisionResource) -> None: + self._vision = vision + + self.analyze = async_to_raw_response_wrapper( + vision.analyze, + ) + + +class VisionResourceWithStreamingResponse: + def __init__(self, vision: VisionResource) -> None: + self._vision = vision + + self.analyze = to_streamed_response_wrapper( + vision.analyze, + ) + + +class AsyncVisionResourceWithStreamingResponse: + def __init__(self, vision: AsyncVisionResource) -> None: + self._vision = vision + + self.analyze = async_to_streamed_response_wrapper( + vision.analyze, + ) diff --git a/src/writerai/types/__init__.py b/src/writerai/types/__init__.py new file mode 100644 index 00000000..195d442a --- /dev/null +++ b/src/writerai/types/__init__.py @@ -0,0 +1,63 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from .file import File as File +from .graph import Graph as Graph +from .shared import ( + Source as Source, + Logprobs as Logprobs, + ToolCall as ToolCall, + GraphData as GraphData, + ToolParam as ToolParam, + ErrorObject as ErrorObject, + ErrorMessage as ErrorMessage, + LogprobsToken as LogprobsToken, + FunctionParams as FunctionParams, + ToolChoiceString as ToolChoiceString, + ToolCallStreaming as ToolCallStreaming, + FunctionDefinition as FunctionDefinition, + ToolChoiceJsonObject as ToolChoiceJsonObject, +) +from .question import Question as Question +from .completion import Completion as Completion +from .chat_completion import ChatCompletion as ChatCompletion +from .vision_response import VisionResponse as VisionResponse +from .chat_chat_params import ChatChatParams as ChatChatParams +from .completion_chunk import CompletionChunk as CompletionChunk +from .file_list_params import FileListParams as FileListParams +from .file_retry_params import FileRetryParams as FileRetryParams +from .graph_list_params import GraphListParams as GraphListParams +from .file_upload_params import FileUploadParams as FileUploadParams +from .file_retry_response import FileRetryResponse as FileRetryResponse +from .graph_create_params import GraphCreateParams as GraphCreateParams +from .graph_update_params import GraphUpdateParams as GraphUpdateParams +from .model_list_response import ModelListResponse as ModelListResponse +from .file_delete_response import FileDeleteResponse as FileDeleteResponse +from .translation_response import TranslationResponse as TranslationResponse +from .chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk +from .chat_completion_usage import ChatCompletionUsage as ChatCompletionUsage +from .graph_create_response import GraphCreateResponse as GraphCreateResponse +from .graph_delete_response import GraphDeleteResponse as GraphDeleteResponse +from .graph_question_params import GraphQuestionParams as GraphQuestionParams +from .graph_update_response import GraphUpdateResponse as GraphUpdateResponse +from .tool_parse_pdf_params import ToolParsePdfParams as ToolParsePdfParams +from .vision_analyze_params import VisionAnalyzeParams as VisionAnalyzeParams +from .chat_completion_choice import ChatCompletionChoice as ChatCompletionChoice +from .tool_web_search_params import ToolWebSearchParams as ToolWebSearchParams +from .application_list_params import ApplicationListParams as ApplicationListParams +from .chat_completion_message import ChatCompletionMessage as ChatCompletionMessage +from .question_response_chunk import QuestionResponseChunk as QuestionResponseChunk +from .tool_parse_pdf_response import ToolParsePdfResponse as ToolParsePdfResponse +from .completion_create_params import CompletionCreateParams as CompletionCreateParams +from .tool_web_search_response import ToolWebSearchResponse as ToolWebSearchResponse +from .application_list_response import ApplicationListResponse as ApplicationListResponse +from .translation_translate_params import TranslationTranslateParams as TranslationTranslateParams +from .application_retrieve_response import ApplicationRetrieveResponse as ApplicationRetrieveResponse +from .graph_add_file_to_graph_params import GraphAddFileToGraphParams as GraphAddFileToGraphParams +from .application_generate_content_chunk import ApplicationGenerateContentChunk as ApplicationGenerateContentChunk +from .application_generate_content_params import ApplicationGenerateContentParams as ApplicationGenerateContentParams +from .application_generate_content_response import ( + ApplicationGenerateContentResponse as ApplicationGenerateContentResponse, +) +from .graph_remove_file_from_graph_response import GraphRemoveFileFromGraphResponse as GraphRemoveFileFromGraphResponse diff --git a/src/writerai/types/application_generate_content_chunk.py b/src/writerai/types/application_generate_content_chunk.py new file mode 100644 index 00000000..70e96939 --- /dev/null +++ b/src/writerai/types/application_generate_content_chunk.py @@ -0,0 +1,33 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from .._models import BaseModel + +__all__ = ["ApplicationGenerateContentChunk", "Delta", "DeltaStage"] + + +class DeltaStage(BaseModel): + id: str + """The unique identifier for the stage.""" + + content: str + """The text content of the stage.""" + + sources: Optional[List[str]] = None + """A list of sources (URLs) that that stage used to process that particular step.""" + + +class Delta(BaseModel): + content: Optional[str] = None + """The main text output.""" + + stages: Optional[List[DeltaStage]] = None + """A list of stages that show the 'thinking process'.""" + + title: Optional[str] = None + """The name of the output.""" + + +class ApplicationGenerateContentChunk(BaseModel): + delta: Delta diff --git a/src/writerai/types/application_generate_content_params.py b/src/writerai/types/application_generate_content_params.py new file mode 100644 index 00000000..401d52d4 --- /dev/null +++ b/src/writerai/types/application_generate_content_params.py @@ -0,0 +1,61 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union, Iterable +from typing_extensions import Literal, Required, TypedDict + +from .._types import SequenceNotStr + +__all__ = [ + "ApplicationGenerateContentParamsBase", + "Input", + "ApplicationGenerateContentParamsNonStreaming", + "ApplicationGenerateContentParamsStreaming", +] + + +class ApplicationGenerateContentParamsBase(TypedDict, total=False): + inputs: Required[Iterable[Input]] + + +class Input(TypedDict, total=False): + id: Required[str] + """The unique identifier for the input field from the application. + + All input types from the No-code application are supported (i.e. Text input, + Dropdown, File upload, Image input). The identifier should be the name of the + input type. + """ + + value: Required[SequenceNotStr[str]] + """The value for the input field. + + If the input type is "File upload", you must pass the `file_id` of an uploaded + file. You cannot pass a file object directly. See the + [file upload endpoint](https://dev.writer.com/api-reference/file-api/upload-files) + for instructions on uploading files or the + [list files endpoint](https://dev.writer.com/api-reference/file-api/get-all-files) + for how to see a list of uploaded files and their IDs. + """ + + +class ApplicationGenerateContentParamsNonStreaming(ApplicationGenerateContentParamsBase, total=False): + stream: Literal[False] + """Indicates whether the response should be streamed. + + Currently only supported for research assistant applications. + """ + + +class ApplicationGenerateContentParamsStreaming(ApplicationGenerateContentParamsBase): + stream: Required[Literal[True]] + """Indicates whether the response should be streamed. + + Currently only supported for research assistant applications. + """ + + +ApplicationGenerateContentParams = Union[ + ApplicationGenerateContentParamsNonStreaming, ApplicationGenerateContentParamsStreaming +] diff --git a/src/writerai/types/application_generate_content_response.py b/src/writerai/types/application_generate_content_response.py new file mode 100644 index 00000000..240a6a82 --- /dev/null +++ b/src/writerai/types/application_generate_content_response.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional + +from .._models import BaseModel + +__all__ = ["ApplicationGenerateContentResponse"] + + +class ApplicationGenerateContentResponse(BaseModel): + suggestion: str + """The response from the model specified in the application.""" + + title: Optional[str] = None + """The name of the output field.""" diff --git a/src/writerai/types/application_list_params.py b/src/writerai/types/application_list_params.py new file mode 100644 index 00000000..1753e83b --- /dev/null +++ b/src/writerai/types/application_list_params.py @@ -0,0 +1,24 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, TypedDict + +__all__ = ["ApplicationListParams"] + + +class ApplicationListParams(TypedDict, total=False): + after: str + """Return results after this application ID for pagination.""" + + before: str + """Return results before this application ID for pagination.""" + + limit: int + """Maximum number of applications to return in the response.""" + + order: Literal["asc", "desc"] + """Sort order for the results based on creation time.""" + + type: Literal["generation"] + """Filter applications by their type.""" diff --git a/src/writerai/types/application_list_response.py b/src/writerai/types/application_list_response.py new file mode 100644 index 00000000..d6eb0734 --- /dev/null +++ b/src/writerai/types/application_list_response.py @@ -0,0 +1,121 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Union, Optional +from datetime import datetime +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = [ + "ApplicationListResponse", + "Input", + "InputOptions", + "InputOptionsApplicationInputDropdownOptions", + "InputOptionsApplicationInputFileOptions", + "InputOptionsApplicationInputMediaOptions", + "InputOptionsApplicationInputTextOptions", +] + + +class InputOptionsApplicationInputDropdownOptions(BaseModel): + """Configuration options specific to dropdown-type input fields.""" + + list: List[str] + """List of available options in the dropdown menu.""" + + +class InputOptionsApplicationInputFileOptions(BaseModel): + """Configuration options specific to file upload input fields.""" + + file_types: List[str] + """List of allowed file extensions.""" + + max_file_size_mb: int + """Maximum file size allowed in megabytes.""" + + max_files: int + """Maximum number of files that can be uploaded.""" + + max_word_count: int + """Maximum number of words allowed in text files.""" + + upload_types: List[Literal["url", "file_id"]] + """List of allowed upload types for file inputs.""" + + +class InputOptionsApplicationInputMediaOptions(BaseModel): + """Configuration options specific to media upload input fields.""" + + file_types: List[str] + """List of allowed media file types.""" + + max_image_size_mb: int + """Maximum media file size allowed in megabytes.""" + + +class InputOptionsApplicationInputTextOptions(BaseModel): + """Configuration options specific to text input fields.""" + + max_fields: int + """Maximum number of text fields allowed.""" + + min_fields: int + """Minimum number of text fields required.""" + + +InputOptions: TypeAlias = Union[ + InputOptionsApplicationInputDropdownOptions, + InputOptionsApplicationInputFileOptions, + InputOptionsApplicationInputMediaOptions, + InputOptionsApplicationInputTextOptions, +] + + +class Input(BaseModel): + """Configuration for an individual input field in the application.""" + + input_type: Literal["text", "dropdown", "file", "media"] + """Type of input field determining its behavior and validation rules.""" + + name: str + """Identifier for the input field.""" + + required: bool + """Indicates if this input field is mandatory.""" + + description: Optional[str] = None + """Human-readable description of the input field's purpose.""" + + options: Optional[InputOptions] = None + """Type-specific configuration options for input fields.""" + + +class ApplicationListResponse(BaseModel): + """Detailed application object including its input configuration.""" + + id: str + """Unique identifier for the application.""" + + created_at: datetime + """Timestamp when the application was created.""" + + inputs: List[Input] + """List of input configurations for the application.""" + + name: str + """Display name of the application.""" + + status: Literal["deployed", "draft"] + """Current deployment status of the application. + + Note: currently only `deployed` applications are returned. + """ + + type: Literal["generation"] + """The type of no-code application.""" + + updated_at: datetime + """Timestamp when the application was last updated.""" + + last_deployed_at: Optional[datetime] = None + """Timestamp when the application was last deployed.""" diff --git a/src/writerai/types/application_retrieve_response.py b/src/writerai/types/application_retrieve_response.py new file mode 100644 index 00000000..6fec026a --- /dev/null +++ b/src/writerai/types/application_retrieve_response.py @@ -0,0 +1,121 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Union, Optional +from datetime import datetime +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = [ + "ApplicationRetrieveResponse", + "Input", + "InputOptions", + "InputOptionsApplicationInputDropdownOptions", + "InputOptionsApplicationInputFileOptions", + "InputOptionsApplicationInputMediaOptions", + "InputOptionsApplicationInputTextOptions", +] + + +class InputOptionsApplicationInputDropdownOptions(BaseModel): + """Configuration options specific to dropdown-type input fields.""" + + list: List[str] + """List of available options in the dropdown menu.""" + + +class InputOptionsApplicationInputFileOptions(BaseModel): + """Configuration options specific to file upload input fields.""" + + file_types: List[str] + """List of allowed file extensions.""" + + max_file_size_mb: int + """Maximum file size allowed in megabytes.""" + + max_files: int + """Maximum number of files that can be uploaded.""" + + max_word_count: int + """Maximum number of words allowed in text files.""" + + upload_types: List[Literal["url", "file_id"]] + """List of allowed upload types for file inputs.""" + + +class InputOptionsApplicationInputMediaOptions(BaseModel): + """Configuration options specific to media upload input fields.""" + + file_types: List[str] + """List of allowed media file types.""" + + max_image_size_mb: int + """Maximum media file size allowed in megabytes.""" + + +class InputOptionsApplicationInputTextOptions(BaseModel): + """Configuration options specific to text input fields.""" + + max_fields: int + """Maximum number of text fields allowed.""" + + min_fields: int + """Minimum number of text fields required.""" + + +InputOptions: TypeAlias = Union[ + InputOptionsApplicationInputDropdownOptions, + InputOptionsApplicationInputFileOptions, + InputOptionsApplicationInputMediaOptions, + InputOptionsApplicationInputTextOptions, +] + + +class Input(BaseModel): + """Configuration for an individual input field in the application.""" + + input_type: Literal["text", "dropdown", "file", "media"] + """Type of input field determining its behavior and validation rules.""" + + name: str + """Identifier for the input field.""" + + required: bool + """Indicates if this input field is mandatory.""" + + description: Optional[str] = None + """Human-readable description of the input field's purpose.""" + + options: Optional[InputOptions] = None + """Type-specific configuration options for input fields.""" + + +class ApplicationRetrieveResponse(BaseModel): + """Detailed application object including its input configuration.""" + + id: str + """Unique identifier for the application.""" + + created_at: datetime + """Timestamp when the application was created.""" + + inputs: List[Input] + """List of input configurations for the application.""" + + name: str + """Display name of the application.""" + + status: Literal["deployed", "draft"] + """Current deployment status of the application. + + Note: currently only `deployed` applications are returned. + """ + + type: Literal["generation"] + """The type of no-code application.""" + + updated_at: datetime + """Timestamp when the application was last updated.""" + + last_deployed_at: Optional[datetime] = None + """Timestamp when the application was last deployed.""" diff --git a/src/writerai/types/applications/__init__.py b/src/writerai/types/applications/__init__.py new file mode 100644 index 00000000..92c8a570 --- /dev/null +++ b/src/writerai/types/applications/__init__.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from .job_list_params import JobListParams as JobListParams +from .job_create_params import JobCreateParams as JobCreateParams +from .job_retry_response import JobRetryResponse as JobRetryResponse +from .graph_update_params import GraphUpdateParams as GraphUpdateParams +from .job_create_response import JobCreateResponse as JobCreateResponse +from .application_graphs_response import ApplicationGraphsResponse as ApplicationGraphsResponse +from .application_jobs_list_response import ApplicationJobsListResponse as ApplicationJobsListResponse +from .application_generate_async_response import ApplicationGenerateAsyncResponse as ApplicationGenerateAsyncResponse diff --git a/src/writerai/types/applications/application_generate_async_response.py b/src/writerai/types/applications/application_generate_async_response.py new file mode 100644 index 00000000..08afe291 --- /dev/null +++ b/src/writerai/types/applications/application_generate_async_response.py @@ -0,0 +1,36 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from datetime import datetime +from typing_extensions import Literal + +from ..._models import BaseModel +from ..application_generate_content_response import ApplicationGenerateContentResponse + +__all__ = ["ApplicationGenerateAsyncResponse"] + + +class ApplicationGenerateAsyncResponse(BaseModel): + id: str + """The unique identifier for the job.""" + + application_id: str + """The ID of the application associated with this job.""" + + created_at: datetime + """The timestamp when the job was created.""" + + status: Literal["in_progress", "failed", "completed"] + """The status of the job.""" + + completed_at: Optional[datetime] = None + """The timestamp when the job was completed.""" + + data: Optional[ApplicationGenerateContentResponse] = None + """The result of the completed job, if applicable.""" + + error: Optional[str] = None + """The error message if the job failed.""" + + updated_at: Optional[datetime] = None + """The timestamp when the job was last updated.""" diff --git a/src/writerai/types/applications/application_graphs_response.py b/src/writerai/types/applications/application_graphs_response.py new file mode 100644 index 00000000..f979a58a --- /dev/null +++ b/src/writerai/types/applications/application_graphs_response.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List + +from ..._models import BaseModel + +__all__ = ["ApplicationGraphsResponse"] + + +class ApplicationGraphsResponse(BaseModel): + graph_ids: List[str] + """A list of Knowledge Graphs associated with the application.""" diff --git a/src/writerai/types/applications/application_jobs_list_response.py b/src/writerai/types/applications/application_jobs_list_response.py new file mode 100644 index 00000000..f6e92126 --- /dev/null +++ b/src/writerai/types/applications/application_jobs_list_response.py @@ -0,0 +1,27 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from pydantic import Field as FieldInfo + +from ..._models import BaseModel +from .application_generate_async_response import ApplicationGenerateAsyncResponse + +__all__ = ["ApplicationJobsListResponse", "Pagination"] + + +class Pagination(BaseModel): + limit: Optional[int] = None + """The pagination limit for retrieving the jobs.""" + + offset: Optional[int] = None + """The pagination offset for retrieving the jobs.""" + + +class ApplicationJobsListResponse(BaseModel): + result: List[ApplicationGenerateAsyncResponse] + + pagination: Optional[Pagination] = None + + total_count: Optional[int] = FieldInfo(alias="totalCount", default=None) + """The total number of jobs associated with the application.""" diff --git a/src/writerai/types/applications/graph_update_params.py b/src/writerai/types/applications/graph_update_params.py new file mode 100644 index 00000000..11fcf3bf --- /dev/null +++ b/src/writerai/types/applications/graph_update_params.py @@ -0,0 +1,18 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +from ..._types import SequenceNotStr + +__all__ = ["GraphUpdateParams"] + + +class GraphUpdateParams(TypedDict, total=False): + graph_ids: Required[SequenceNotStr[str]] + """A list of Knowledge Graph IDs to associate with the application. + + Note that this will replace the existing list of Knowledge Graphs associated + with the application, not add to it. + """ diff --git a/src/writerai/types/applications/job_create_params.py b/src/writerai/types/applications/job_create_params.py new file mode 100644 index 00000000..f3453a9f --- /dev/null +++ b/src/writerai/types/applications/job_create_params.py @@ -0,0 +1,36 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Required, TypedDict + +from ..._types import SequenceNotStr + +__all__ = ["JobCreateParams", "Input"] + + +class JobCreateParams(TypedDict, total=False): + inputs: Required[Iterable[Input]] + """A list of input objects to generate content for.""" + + +class Input(TypedDict, total=False): + id: Required[str] + """The unique identifier for the input field from the application. + + All input types from the No-code application are supported (i.e. Text input, + Dropdown, File upload, Image input). The identifier should be the name of the + input type. + """ + + value: Required[SequenceNotStr[str]] + """The value for the input field. + + If the input type is "File upload", you must pass the `file_id` of an uploaded + file. You cannot pass a file object directly. See the + [file upload endpoint](https://dev.writer.com/api-reference/file-api/upload-files) + for instructions on uploading files or the + [list files endpoint](https://dev.writer.com/api-reference/file-api/get-all-files) + for how to see a list of uploaded files and their IDs. + """ diff --git a/src/writerai/types/applications/job_create_response.py b/src/writerai/types/applications/job_create_response.py new file mode 100644 index 00000000..f83502c2 --- /dev/null +++ b/src/writerai/types/applications/job_create_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from datetime import datetime +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["JobCreateResponse"] + + +class JobCreateResponse(BaseModel): + id: str + """The unique identifier for the async job created.""" + + created_at: datetime + """The timestamp when the job was created.""" + + status: Literal["in_progress", "failed", "completed"] + """The status of the job.""" diff --git a/src/writerai/types/applications/job_list_params.py b/src/writerai/types/applications/job_list_params.py new file mode 100644 index 00000000..dfb85b65 --- /dev/null +++ b/src/writerai/types/applications/job_list_params.py @@ -0,0 +1,18 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, TypedDict + +__all__ = ["JobListParams"] + + +class JobListParams(TypedDict, total=False): + limit: int + """The pagination limit for retrieving the jobs.""" + + offset: int + """The pagination offset for retrieving the jobs.""" + + status: Literal["in_progress", "failed", "completed"] + """The status of the job.""" diff --git a/src/writerai/types/applications/job_retry_response.py b/src/writerai/types/applications/job_retry_response.py new file mode 100644 index 00000000..9555ce6d --- /dev/null +++ b/src/writerai/types/applications/job_retry_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from datetime import datetime +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["JobRetryResponse"] + + +class JobRetryResponse(BaseModel): + id: str + """The unique identifier for the async job created.""" + + created_at: datetime + """The timestamp when the job was created.""" + + status: Literal["in_progress", "failed", "completed"] + """The status of the job.""" diff --git a/src/writerai/types/chat_chat_params.py b/src/writerai/types/chat_chat_params.py new file mode 100644 index 00000000..b7b3d830 --- /dev/null +++ b/src/writerai/types/chat_chat_params.py @@ -0,0 +1,237 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union, Iterable, Optional +from typing_extensions import Literal, Required, TypeAlias, TypedDict + +from .._types import SequenceNotStr +from .shared_params.tool_call import ToolCall +from .shared_params.graph_data import GraphData +from .shared_params.tool_param import ToolParam +from .shared_params.tool_choice_string import ToolChoiceString +from .shared_params.tool_choice_json_object import ToolChoiceJsonObject + +__all__ = [ + "ChatChatParamsBase", + "Message", + "MessageContentMixedContent", + "MessageContentMixedContentTextFragment", + "MessageContentMixedContentImageFragment", + "MessageContentMixedContentImageFragmentImageURL", + "ResponseFormat", + "StreamOptions", + "ToolChoice", + "ChatChatParamsNonStreaming", + "ChatChatParamsStreaming", +] + + +class ChatChatParamsBase(TypedDict, total=False): + messages: Required[Iterable[Message]] + """ + An array of message objects that form the conversation history or context for + the model to respond to. The array must contain at least one message. + """ + + model: Required[str] + """ + The [ID of the model](https://dev.writer.com/home/models) to use for creating + the chat completion. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, + `palmyra-med`, `palmyra-creative`, and `palmyra-x-003-instruct`. + """ + + logprobs: bool + """Specifies whether to return log probabilities of the output tokens.""" + + max_tokens: int + """ + Defines the maximum number of tokens (words and characters) that the model can + generate in the response. This can be adjusted to allow for longer or shorter + responses as needed. The maximum value varies by model. See the + [models overview](/home/models) for more information about the maximum number of + tokens for each model. + """ + + n: int + """ + Specifies the number of completions (responses) to generate from the model in a + single request. This parameter allows for generating multiple responses, + offering a variety of potential replies from which to choose. + """ + + response_format: ResponseFormat + """ + The response format to use for the chat completion, available with `palmyra-x4` + and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) + is supported for structured responses. If you specify `json_schema`, you must + also provide a `json_schema` object. + """ + + stop: Union[SequenceNotStr[str], str] + """ + A token or sequence of tokens that, when generated, will cause the model to stop + producing further content. This can be a single token or an array of tokens, + acting as a signal to end the output. + """ + + stream_options: StreamOptions + """Additional options for streaming.""" + + temperature: float + """Controls the randomness or creativity of the model's responses. + + A higher temperature results in more varied and less predictable text, while a + lower temperature produces more deterministic and conservative outputs. + """ + + tool_choice: ToolChoice + """Configure how the model will call functions: + + - `auto`: allows the model to automatically choose the tool to use, or not call + a tool + - `none`: disables tool calling; the model will instead generate a message + - `required`: requires the model to call one or more tools + + You can also use a JSON object to force the model to call a specific tool. For + example, `{"type": "function", "function": {"name": "get_current_weather"}}` + requires the model to call the `get_current_weather` function, regardless of the + prompt. + """ + + tools: Iterable[ToolParam] + """ + An array containing tool definitions for tools that the model can use to + generate responses. The tool definitions use JSON schema. You can define your + own functions or use one of the built-in `graph`, `llm`, `translation`, or + `vision` tools. Note that you can only use one built-in tool type in the array + (only one of `graph`, `llm`, `translation`, or `vision`). You can pass multiple + [custom tools](https://dev.writer.com/home/tool-calling) of type `function` in + the same request. + """ + + top_p: float + """ + Sets the threshold for "nucleus sampling," a technique to focus the model's + token generation on the most likely subset of tokens. Only tokens with + cumulative probability above this threshold are considered, controlling the + trade-off between creativity and coherence. + """ + + +class MessageContentMixedContentTextFragment(TypedDict, total=False): + """Represents a text content fragment within a chat message.""" + + text: Required[str] + """The actual text content of the message fragment.""" + + type: Required[Literal["text"]] + """The type of content fragment. Must be `text` for text fragments.""" + + +class MessageContentMixedContentImageFragmentImageURL(TypedDict, total=False): + """The image URL object containing the location of the image.""" + + url: Required[str] + """The URL pointing to the image file. + + Supports common image formats like JPEG, PNG, GIF, etc. + """ + + +class MessageContentMixedContentImageFragment(TypedDict, total=False): + """Represents an image content fragment within a chat message. + + Note: This content type is only supported with the Palmyra X5 model. + """ + + image_url: Required[MessageContentMixedContentImageFragmentImageURL] + """The image URL object containing the location of the image.""" + + type: Required[Literal["image_url"]] + """The type of content fragment. Must be `image_url` for image fragments.""" + + +MessageContentMixedContent: TypeAlias = Union[ + MessageContentMixedContentTextFragment, MessageContentMixedContentImageFragment +] + + +class Message(TypedDict, total=False): + role: Required[Literal["user", "assistant", "system", "tool"]] + """The role of the chat message. + + You can provide a system prompt by setting the role to `system`, or specify that + a message is the result of a + [tool call](https://dev.writer.com/home/tool-calling) by setting the role to + `tool`. + """ + + content: Union[str, Iterable[MessageContentMixedContent], None] + """The content of the message. + + Can be either a string (for text-only messages) or an array of content fragments + (for mixed text and image messages). + """ + + graph_data: Optional[GraphData] + + name: Optional[str] + """An optional name for the message sender. + + Useful for identifying different users, personas, or tools in multi-participant + conversations. + """ + + refusal: Optional[str] + + tool_call_id: Optional[str] + + tool_calls: Optional[Iterable[ToolCall]] + + +class ResponseFormat(TypedDict, total=False): + """ + The response format to use for the chat completion, available with `palmyra-x4` and `palmyra-x5`. + + `text` is the default response format. [JSON Schema](https://json-schema.org/) is supported for structured responses. If you specify `json_schema`, you must also provide a `json_schema` object. + """ + + type: Required[Literal["text", "json_schema"]] + """The type of response format to use.""" + + json_schema: object + """The JSON schema to use for the response format.""" + + +class StreamOptions(TypedDict, total=False): + """Additional options for streaming.""" + + include_usage: Required[bool] + """Indicate whether to include usage information.""" + + +ToolChoice: TypeAlias = Union[ToolChoiceString, ToolChoiceJsonObject] + + +class ChatChatParamsNonStreaming(ChatChatParamsBase, total=False): + stream: Literal[False] + """ + Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + """ + + +class ChatChatParamsStreaming(ChatChatParamsBase): + stream: Required[Literal[True]] + """ + Indicates whether the response should be streamed incrementally as it is + generated or only returned once fully complete. Streaming can be useful for + providing real-time feedback in interactive applications. + """ + + +ChatChatParams = Union[ChatChatParamsNonStreaming, ChatChatParamsStreaming] diff --git a/src/writerai/types/chat_completion.py b/src/writerai/types/chat_completion.py new file mode 100644 index 00000000..a7521fff --- /dev/null +++ b/src/writerai/types/chat_completion.py @@ -0,0 +1,54 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from typing_extensions import Literal + +from .._models import BaseModel +from .chat_completion_usage import ChatCompletionUsage +from .chat_completion_choice import ChatCompletionChoice + +__all__ = ["ChatCompletion"] + + +class ChatCompletion(BaseModel): + id: str + """A globally unique identifier (UUID) for the response generated by the API. + + This ID can be used to reference the specific operation or transaction within + the system for tracking or debugging purposes. + """ + + choices: List[ChatCompletionChoice] + """ + An array of objects representing the different outcomes or results produced by + the model based on the input provided. + """ + + created: int + """The Unix timestamp (in seconds) when the response was created. + + This timestamp can be used to verify the timing of the response relative to + other events or operations. + """ + + model: str + """Identifies the specific model used to generate the response.""" + + object: Literal["chat.completion"] + """ + The type of object returned, which is always `chat.completion` for chat + responses. + """ + + service_tier: Optional[str] = None + """The service tier used for processing the request.""" + + system_fingerprint: Optional[str] = None + """A string representing the backend configuration that the model runs with.""" + + usage: Optional[ChatCompletionUsage] = None + """Usage information for the chat completion response. + + Please note that at this time Knowledge Graph tool usage is not included in this + object. + """ diff --git a/src/writerai/types/chat_completion_choice.py b/src/writerai/types/chat_completion_choice.py new file mode 100644 index 00000000..f96b5bee --- /dev/null +++ b/src/writerai/types/chat_completion_choice.py @@ -0,0 +1,32 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from typing_extensions import Literal + +from .._models import BaseModel +from .shared.logprobs import Logprobs +from .chat_completion_message import ChatCompletionMessage + +__all__ = ["ChatCompletionChoice"] + + +class ChatCompletionChoice(BaseModel): + finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] + """Describes the condition under which the model ceased generating content. + + Common reasons include 'length' (reached the maximum output size), 'stop' + (encountered a stop sequence), 'content_filter' (harmful content filtered out), + or 'tool_calls' (encountered tool calls). + """ + + index: int + """The index of the choice in the list of completions generated by the model.""" + + message: ChatCompletionMessage + """The chat completion message from the model. + + Note: this field is deprecated for streaming. Use `delta` instead. + """ + + logprobs: Optional[Logprobs] = None + """Log probability information for the choice.""" diff --git a/src/writerai/types/chat_completion_chunk.py b/src/writerai/types/chat_completion_chunk.py new file mode 100644 index 00000000..bc3bf85b --- /dev/null +++ b/src/writerai/types/chat_completion_chunk.py @@ -0,0 +1,127 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from typing_extensions import Literal + +from .._models import BaseModel +from .shared.logprobs import Logprobs +from .shared.graph_data import GraphData +from .chat_completion_usage import ChatCompletionUsage +from .chat_completion_message import ChatCompletionMessage +from .shared.tool_call_streaming import ToolCallStreaming + +__all__ = ["ChatCompletionChunk", "Choice", "ChoiceDelta", "ChoiceDeltaLlmData", "ChoiceDeltaTranslationData"] + + +class ChoiceDeltaLlmData(BaseModel): + model: str + """The model used by the tool.""" + + prompt: str + """The prompt processed by the model.""" + + +class ChoiceDeltaTranslationData(BaseModel): + source_language_code: str + """The language code of the source text.""" + + source_text: str + """The text the tool translated.""" + + target_language_code: str + """The language code of the target text.""" + + +class ChoiceDelta(BaseModel): + """A chat completion delta generated by streamed model responses.""" + + content: Optional[str] = None + """The text content produced by the model. + + This field contains the actual output generated, reflecting the model's response + to the input query or command. + """ + + graph_data: Optional[GraphData] = None + + llm_data: Optional[ChoiceDeltaLlmData] = None + + refusal: Optional[str] = None + + role: Optional[Literal["user", "assistant", "system"]] = None + """ + Specifies the role associated with the content, indicating whether the message + is from the 'assistant' or another defined role, helping to contextualize the + output within the interaction flow. + """ + + tool_calls: Optional[List[ToolCallStreaming]] = None + + translation_data: Optional[ChoiceDeltaTranslationData] = None + + +class Choice(BaseModel): + delta: ChoiceDelta + """A chat completion delta generated by streamed model responses.""" + + finish_reason: Optional[Literal["stop", "length", "content_filter", "tool_calls"]] = None + """Describes the condition under which the model ceased generating content. + + Common reasons include 'length' (reached the maximum output size), 'stop' + (encountered a stop sequence), 'content_filter' (harmful content filtered out), + or 'tool_calls' (encountered tool calls). + """ + + index: int + """The index of the choice in the list of completions generated by the model.""" + + logprobs: Optional[Logprobs] = None + """Log probability information for the choice.""" + + message: Optional[ChatCompletionMessage] = None + """The chat completion message from the model. + + Note: this field is deprecated for streaming. Use `delta` instead. + """ + + +class ChatCompletionChunk(BaseModel): + id: str + """A globally unique identifier (UUID) for the response generated by the API. + + This ID can be used to reference the specific operation or transaction within + the system for tracking or debugging purposes. + """ + + choices: List[Choice] + """ + An array of objects representing the different outcomes or results produced by + the model based on the input provided. + """ + + created: int + """The Unix timestamp (in seconds) when the response was created. + + This timestamp can be used to verify the timing of the response relative to + other events or operations. + """ + + model: str + """Identifies the specific model used to generate the response.""" + + object: Literal["chat.completion.chunk"] + """ + The type of object returned, which is always `chat.completion.chunk` for + streaming chat responses. + """ + + service_tier: Optional[str] = None + + system_fingerprint: Optional[str] = None + + usage: Optional[ChatCompletionUsage] = None + """Usage information for the chat completion response. + + Please note that at this time Knowledge Graph tool usage is not included in this + object. + """ diff --git a/src/writerai/types/chat_completion_message.py b/src/writerai/types/chat_completion_message.py new file mode 100644 index 00000000..97045c61 --- /dev/null +++ b/src/writerai/types/chat_completion_message.py @@ -0,0 +1,68 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from typing_extensions import Literal + +from .._models import BaseModel +from .shared.tool_call import ToolCall +from .shared.graph_data import GraphData + +__all__ = ["ChatCompletionMessage", "LlmData", "TranslationData", "WebSearchData", "WebSearchDataSource"] + + +class LlmData(BaseModel): + model: str + """The model used by the tool.""" + + prompt: str + """The prompt processed by the model.""" + + +class TranslationData(BaseModel): + source_language_code: str + """The language code of the source text.""" + + source_text: str + """The text the tool translated.""" + + target_language_code: str + """The language code of the target text.""" + + +class WebSearchDataSource(BaseModel): + raw_content: Optional[str] = None + + url: Optional[str] = None + + +class WebSearchData(BaseModel): + sources: List[WebSearchDataSource] + + +class ChatCompletionMessage(BaseModel): + """The chat completion message from the model. + + Note: this field is deprecated for streaming. Use `delta` instead. + """ + + content: str + """The text content produced by the model. + + This field contains the actual output generated, reflecting the model's response + to the input query or command. + """ + + refusal: Optional[str] = None + + role: Literal["assistant"] + """Specifies the role associated with the content.""" + + graph_data: Optional[GraphData] = None + + llm_data: Optional[LlmData] = None + + tool_calls: Optional[List[ToolCall]] = None + + translation_data: Optional[TranslationData] = None + + web_search_data: Optional[WebSearchData] = None diff --git a/src/writerai/types/chat_completion_usage.py b/src/writerai/types/chat_completion_usage.py new file mode 100644 index 00000000..ec7b7991 --- /dev/null +++ b/src/writerai/types/chat_completion_usage.py @@ -0,0 +1,32 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional + +from .._models import BaseModel + +__all__ = ["ChatCompletionUsage", "CompletionTokensDetails", "PromptTokenDetails"] + + +class CompletionTokensDetails(BaseModel): + reasoning_tokens: int + + +class PromptTokenDetails(BaseModel): + cached_tokens: int + + +class ChatCompletionUsage(BaseModel): + """Usage information for the chat completion response. + + Please note that at this time Knowledge Graph tool usage is not included in this object. + """ + + completion_tokens: int + + prompt_tokens: int + + total_tokens: int + + completion_tokens_details: Optional[CompletionTokensDetails] = None + + prompt_token_details: Optional[PromptTokenDetails] = None diff --git a/src/writerai/types/completion.py b/src/writerai/types/completion.py new file mode 100644 index 00000000..01fe40e5 --- /dev/null +++ b/src/writerai/types/completion.py @@ -0,0 +1,32 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from .._models import BaseModel +from .shared.logprobs import Logprobs + +__all__ = ["Completion", "Choice"] + + +class Choice(BaseModel): + text: str + """ + The generated text output from the model, which forms the main content of the + response. + """ + + log_probs: Optional[Logprobs] = None + + +class Completion(BaseModel): + choices: List[Choice] + """ + A list of choices generated by the model, each containing the text of the + completion and associated metadata such as log probabilities. + """ + + model: Optional[str] = None + """ + The identifier of the model that was used to generate the responses in the + 'choices' array. + """ diff --git a/src/writerai/types/completion_chunk.py b/src/writerai/types/completion_chunk.py new file mode 100644 index 00000000..b92b0a88 --- /dev/null +++ b/src/writerai/types/completion_chunk.py @@ -0,0 +1,9 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["CompletionChunk"] + + +class CompletionChunk(BaseModel): + value: str diff --git a/src/writerai/types/completion_create_params.py b/src/writerai/types/completion_create_params.py new file mode 100644 index 00000000..017e18bf --- /dev/null +++ b/src/writerai/types/completion_create_params.py @@ -0,0 +1,80 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union +from typing_extensions import Literal, Required, TypedDict + +from .._types import SequenceNotStr + +__all__ = ["CompletionCreateParamsBase", "CompletionCreateParamsNonStreaming", "CompletionCreateParamsStreaming"] + + +class CompletionCreateParamsBase(TypedDict, total=False): + model: Required[str] + """ + The [ID of the model](https://dev.writer.com/home/models) to use for generating + text. Supports `palmyra-x5`, `palmyra-x4`, `palmyra-fin`, `palmyra-med`, + `palmyra-creative`, and `palmyra-x-003-instruct`. + """ + + prompt: Required[str] + """The input text that the model will process to generate a response.""" + + best_of: int + """Specifies the number of completions to generate and return the best one. + + Useful for generating multiple outputs and choosing the best based on some + criteria. + """ + + max_tokens: int + """The maximum number of tokens that the model can generate in the response.""" + + random_seed: int + """ + A seed used to initialize the random number generator for the model, ensuring + reproducibility of the output when the same inputs are provided. + """ + + stop: Union[SequenceNotStr[str], str] + """Specifies stopping conditions for the model's output generation. + + This can be an array of strings or a single string that the model will look for + as a signal to stop generating further tokens. + """ + + temperature: float + """Controls the randomness of the model's outputs. + + Higher values lead to more random outputs, while lower values make the model + more deterministic. + """ + + top_p: float + """ + Used to control the nucleus sampling, where only the most probable tokens with a + cumulative probability of top_p are considered for sampling, providing a way to + fine-tune the randomness of predictions. + """ + + +class CompletionCreateParamsNonStreaming(CompletionCreateParamsBase, total=False): + stream: Literal[False] + """Determines whether the model's output should be streamed. + + If true, the output is generated and sent incrementally, which can be useful for + real-time applications. + """ + + +class CompletionCreateParamsStreaming(CompletionCreateParamsBase): + stream: Required[Literal[True]] + """Determines whether the model's output should be streamed. + + If true, the output is generated and sent incrementally, which can be useful for + real-time applications. + """ + + +CompletionCreateParams = Union[CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming] diff --git a/src/writerai/types/file.py b/src/writerai/types/file.py new file mode 100644 index 00000000..8b129976 --- /dev/null +++ b/src/writerai/types/file.py @@ -0,0 +1,31 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List +from datetime import datetime + +from .._models import BaseModel + +__all__ = ["File"] + + +class File(BaseModel): + id: str + """A unique identifier of the file.""" + + created_at: datetime + """The timestamp when the file was uploaded.""" + + graph_ids: List[str] + """A list of Knowledge Graph IDs that the file is associated with. + + If you provided a `graphId` during upload, the file is associated with that + Knowledge Graph. However, the `graph_ids` field in the upload response is an + empty list. The association will be visible in the `graph_ids` list when you + retrieve the file using the file retrieval endpoint. + """ + + name: str + """The name of the file.""" + + status: str + """The processing status of the file.""" diff --git a/src/writerai/types/file_delete_response.py b/src/writerai/types/file_delete_response.py new file mode 100644 index 00000000..d4735d06 --- /dev/null +++ b/src/writerai/types/file_delete_response.py @@ -0,0 +1,13 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["FileDeleteResponse"] + + +class FileDeleteResponse(BaseModel): + id: str + """A unique identifier of the deleted file.""" + + deleted: bool + """Indicates whether the file was successfully deleted.""" diff --git a/src/writerai/types/file_list_params.py b/src/writerai/types/file_list_params.py new file mode 100644 index 00000000..8bc101e1 --- /dev/null +++ b/src/writerai/types/file_list_params.py @@ -0,0 +1,48 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, TypedDict + +__all__ = ["FileListParams"] + + +class FileListParams(TypedDict, total=False): + after: str + """The ID of the last object in the previous page. + + This parameter instructs the API to return the next page of results. + """ + + before: str + """The ID of the first object in the previous page. + + This parameter instructs the API to return the previous page of results. + """ + + file_types: str + """The extensions of the files to retrieve. + + Separate multiple extensions with a comma. For example: `pdf,jpg,docx`. + """ + + graph_id: str + """The unique identifier of the graph to which the files belong.""" + + limit: int + """Specifies the maximum number of objects returned in a page. + + The default value is 50. The minimum value is 1, and the maximum value is 100. + """ + + order: Literal["asc", "desc"] + """Specifies the order of the results. + + Valid values are asc for ascending and desc for descending. + """ + + status: Literal["in_progress", "completed", "failed"] + """Specifies the status of the files to retrieve. + + Valid values are in_progress, completed or failed. + """ diff --git a/src/writerai/types/file_retry_params.py b/src/writerai/types/file_retry_params.py new file mode 100644 index 00000000..8882af6a --- /dev/null +++ b/src/writerai/types/file_retry_params.py @@ -0,0 +1,14 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +from .._types import SequenceNotStr + +__all__ = ["FileRetryParams"] + + +class FileRetryParams(TypedDict, total=False): + file_ids: Required[SequenceNotStr[str]] + """The unique identifier of the files to retry.""" diff --git a/src/writerai/types/file_retry_response.py b/src/writerai/types/file_retry_response.py new file mode 100644 index 00000000..4414586c --- /dev/null +++ b/src/writerai/types/file_retry_response.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional + +from .._models import BaseModel + +__all__ = ["FileRetryResponse"] + + +class FileRetryResponse(BaseModel): + success: Optional[bool] = None + """Indicates whether the retry operation was successful.""" diff --git a/src/writerai/types/file_upload_params.py b/src/writerai/types/file_upload_params.py new file mode 100644 index 00000000..0bd8f75e --- /dev/null +++ b/src/writerai/types/file_upload_params.py @@ -0,0 +1,23 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["FileUploadParams"] + + +class FileUploadParams(TypedDict, total=False): + content_disposition: Required[Annotated[str, PropertyInfo(alias="Content-Disposition")]] + + graph_id: Annotated[str, PropertyInfo(alias="graphId")] + """ + The unique identifier of the Knowledge Graph to associate the uploaded file + with. + + Note: The response from the upload endpoint does not include the `graphId` + field, but the association will be visible when you retrieve the file using the + file retrieval endpoint. + """ diff --git a/src/writerai/types/graph.py b/src/writerai/types/graph.py new file mode 100644 index 00000000..7721c972 --- /dev/null +++ b/src/writerai/types/graph.py @@ -0,0 +1,80 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["Graph", "FileStatus", "URL", "URLStatus"] + + +class FileStatus(BaseModel): + """The processing status of files in the Knowledge Graph.""" + + completed: int + """The number of files that have been successfully processed.""" + + failed: int + """The number of files that failed to process.""" + + in_progress: int + """The number of files currently being processed.""" + + total: int + """The total number of files associated with the Knowledge Graph.""" + + +class URLStatus(BaseModel): + """The current status of the URL processing.""" + + status: Literal["validating", "success", "error"] + """The current status of the URL processing.""" + + error_type: Optional[ + Literal["invalid_url", "not_searchable", "not_found", "paywall_or_login_page", "unexpected_error"] + ] = None + """The type of error that occurred during processing, if any.""" + + +class URL(BaseModel): + status: URLStatus + """The current status of the URL processing.""" + + type: Literal["single_page", "sub_pages"] + """The type of web connector processing for this URL.""" + + url: str + """The URL to be processed by the web connector.""" + + exclude_urls: Optional[List[str]] = None + """An array of URLs to exclude from processing within this web connector.""" + + +class Graph(BaseModel): + id: str + """The unique identifier of the Knowledge Graph.""" + + created_at: datetime + """The timestamp when the Knowledge Graph was created.""" + + file_status: FileStatus + """The processing status of files in the Knowledge Graph.""" + + name: str + """The name of the Knowledge Graph.""" + + type: Literal["manual", "connector", "web"] + """The type of Knowledge Graph. + + - `manual`: files are uploaded via UI or API + - `connector`: files are uploaded via a data connector such as Google Drive or + Confluence + - `web`: URLs are connected to the Knowledge Graph + """ + + description: Optional[str] = None + """A description of the Knowledge Graph.""" + + urls: Optional[List[URL]] = None + """An array of web connector URLs associated with this Knowledge Graph.""" diff --git a/src/writerai/types/graph_add_file_to_graph_params.py b/src/writerai/types/graph_add_file_to_graph_params.py new file mode 100644 index 00000000..b0c471d0 --- /dev/null +++ b/src/writerai/types/graph_add_file_to_graph_params.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["GraphAddFileToGraphParams"] + + +class GraphAddFileToGraphParams(TypedDict, total=False): + file_id: Required[str] + """The unique identifier of the file.""" diff --git a/src/writerai/types/graph_create_params.py b/src/writerai/types/graph_create_params.py new file mode 100644 index 00000000..30f2b213 --- /dev/null +++ b/src/writerai/types/graph_create_params.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import TypedDict + +__all__ = ["GraphCreateParams"] + + +class GraphCreateParams(TypedDict, total=False): + description: str + """A description of the Knowledge Graph (max 255 characters). + + Omitting this field leaves the description unchanged. + """ + + name: str + """The name of the Knowledge Graph (max 255 characters). + + Omitting this field leaves the name unchanged. + """ diff --git a/src/writerai/types/graph_create_response.py b/src/writerai/types/graph_create_response.py new file mode 100644 index 00000000..87b471c0 --- /dev/null +++ b/src/writerai/types/graph_create_response.py @@ -0,0 +1,52 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["GraphCreateResponse", "URL", "URLStatus"] + + +class URLStatus(BaseModel): + """The current status of the URL processing.""" + + status: Literal["validating", "success", "error"] + """The current status of the URL processing.""" + + error_type: Optional[ + Literal["invalid_url", "not_searchable", "not_found", "paywall_or_login_page", "unexpected_error"] + ] = None + """The type of error that occurred during processing, if any.""" + + +class URL(BaseModel): + status: URLStatus + """The current status of the URL processing.""" + + type: Literal["single_page", "sub_pages"] + """The type of web connector processing for this URL.""" + + url: str + """The URL to be processed by the web connector.""" + + exclude_urls: Optional[List[str]] = None + """An array of URLs to exclude from processing within this web connector.""" + + +class GraphCreateResponse(BaseModel): + id: str + """A unique identifier of the Knowledge Graph.""" + + created_at: datetime + """The timestamp when the Knowledge Graph was created.""" + + name: str + """The name of the Knowledge Graph (max 255 characters).""" + + description: Optional[str] = None + """A description of the Knowledge Graph (max 255 characters).""" + + urls: Optional[List[URL]] = None + """An array of web connector URLs associated with this Knowledge Graph.""" diff --git a/src/writerai/types/graph_delete_response.py b/src/writerai/types/graph_delete_response.py new file mode 100644 index 00000000..75b3b640 --- /dev/null +++ b/src/writerai/types/graph_delete_response.py @@ -0,0 +1,13 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["GraphDeleteResponse"] + + +class GraphDeleteResponse(BaseModel): + id: str + """A unique identifier of the deleted Knowledge Graph.""" + + deleted: bool + """Indicates whether the Knowledge Graph was successfully deleted.""" diff --git a/src/writerai/types/graph_list_params.py b/src/writerai/types/graph_list_params.py new file mode 100644 index 00000000..d83fcc33 --- /dev/null +++ b/src/writerai/types/graph_list_params.py @@ -0,0 +1,33 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, TypedDict + +__all__ = ["GraphListParams"] + + +class GraphListParams(TypedDict, total=False): + after: str + """The ID of the last object in the previous page. + + This parameter instructs the API to return the next page of results. + """ + + before: str + """The ID of the first object in the previous page. + + This parameter instructs the API to return the previous page of results. + """ + + limit: int + """Specifies the maximum number of objects returned in a page. + + The default value is 50. The minimum value is 1, and the maximum value is 100. + """ + + order: Literal["asc", "desc"] + """Specifies the order of the results. + + Valid values are asc for ascending and desc for descending. + """ diff --git a/src/writerai/types/graph_question_params.py b/src/writerai/types/graph_question_params.py new file mode 100644 index 00000000..dc61ae1d --- /dev/null +++ b/src/writerai/types/graph_question_params.py @@ -0,0 +1,119 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union +from typing_extensions import Literal, Required, TypedDict + +from .._types import SequenceNotStr + +__all__ = ["GraphQuestionParamsBase", "QueryConfig", "GraphQuestionParamsNonStreaming", "GraphQuestionParamsStreaming"] + + +class GraphQuestionParamsBase(TypedDict, total=False): + graph_ids: Required[SequenceNotStr[str]] + """The unique identifiers of the Knowledge Graphs to query.""" + + question: Required[str] + """The question to answer using the Knowledge Graph.""" + + query_config: QueryConfig + """ + Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + """ + + subqueries: bool + """Specify whether to include subqueries.""" + + +class QueryConfig(TypedDict, total=False): + """ + Configuration options for Knowledge Graph queries, including search parameters and citation settings. + """ + + grounding_level: float + """ + Level of grounding required for responses, controlling how closely answers must + be tied to source material. Set lower for grounded outputs, higher for + creativity. Higher values (closer to 1.0) allow more creative interpretation, + while lower values (closer to 0.0) stick more closely to source material. Range: + 0.0-1.0, Default: 0.0. + """ + + inline_citations: bool + """ + Whether to include inline citations in the response, showing which Knowledge + Graph sources were used. Default: false. + """ + + keyword_threshold: float + """Threshold for keyword-based matching when searching Knowledge Graph content. + + Set higher for stricter relevance, lower for broader range. Higher values + (closer to 1.0) require stronger keyword matches, while lower values (closer to + 0.0) allow more lenient matching. Range: 0.0-1.0, Default: 0.7. + """ + + max_snippets: int + """Maximum number of text snippets to retrieve from the Knowledge Graph for + context. + + Works in concert with `search_weight` to control best matches vs broader + coverage. While technically supports 1-60, values below 5 may return no results + due to RAG implementation. Recommended range: 5-25. Due to RAG system behavior, + you may see more snippets than requested. Range: 1-60, Default: 30. + """ + + max_subquestions: int + """Maximum number of subquestions to generate when processing complex queries. + + Set higher to improve detail, set lower to reduce response time. Range: 1-10, + Default: 6. + """ + + max_tokens: int + """Maximum number of tokens the model can generate in the response. + + This controls the length of the AI's answer. Set higher for longer answers, set + lower for shorter, faster answers. Range: 100-8000, Default: 4000. + """ + + search_weight: int + """Weight given to search results when ranking and selecting relevant information. + + Higher values (closer to 100) prioritize keyword-based matching, while lower + values (closer to 0) prioritize semantic similarity matching. Use higher values + for exact keyword searches, lower values for conceptual similarity searches. + Range: 0-100, Default: 50. + """ + + semantic_threshold: float + """ + Threshold for semantic similarity matching when searching Knowledge Graph + content. Set higher for stricter relevance, lower for broader range. Higher + values (closer to 1.0) require stronger semantic similarity, while lower values + (closer to 0.0) allow more lenient semantic matching. Range: 0.0-1.0, Default: + 0.7. + """ + + +class GraphQuestionParamsNonStreaming(GraphQuestionParamsBase, total=False): + stream: Literal[False] + """Determines whether the model's output should be streamed. + + If true, the output is generated and sent incrementally, which can be useful for + real-time applications. + """ + + +class GraphQuestionParamsStreaming(GraphQuestionParamsBase): + stream: Required[Literal[True]] + """Determines whether the model's output should be streamed. + + If true, the output is generated and sent incrementally, which can be useful for + real-time applications. + """ + + +GraphQuestionParams = Union[GraphQuestionParamsNonStreaming, GraphQuestionParamsStreaming] diff --git a/src/writerai/types/graph_remove_file_from_graph_response.py b/src/writerai/types/graph_remove_file_from_graph_response.py new file mode 100644 index 00000000..bde8c696 --- /dev/null +++ b/src/writerai/types/graph_remove_file_from_graph_response.py @@ -0,0 +1,13 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["GraphRemoveFileFromGraphResponse"] + + +class GraphRemoveFileFromGraphResponse(BaseModel): + id: str + """A unique identifier of the deleted file.""" + + deleted: bool + """Indicates whether the file was successfully deleted.""" diff --git a/src/writerai/types/graph_update_params.py b/src/writerai/types/graph_update_params.py new file mode 100644 index 00000000..f400f902 --- /dev/null +++ b/src/writerai/types/graph_update_params.py @@ -0,0 +1,42 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Literal, Required, TypedDict + +from .._types import SequenceNotStr + +__all__ = ["GraphUpdateParams", "URL"] + + +class GraphUpdateParams(TypedDict, total=False): + description: str + """A description of the Knowledge Graph (max 255 characters). + + Omitting this field leaves the description unchanged. + """ + + name: str + """The name of the Knowledge Graph (max 255 characters). + + Omitting this field leaves the name unchanged. + """ + + urls: Iterable[URL] + """An array of web connector URLs to update for this Knowledge Graph. + + You can only connect URLs to Knowledge Graphs with the type `web`. To clear the + list of URLs, set this field to an empty array. + """ + + +class URL(TypedDict, total=False): + type: Required[Literal["single_page", "sub_pages"]] + """The type of web connector processing for this URL.""" + + url: Required[str] + """The URL to be processed by the web connector.""" + + exclude_urls: SequenceNotStr[str] + """An array of URLs to exclude from processing within this web connector.""" diff --git a/src/writerai/types/graph_update_response.py b/src/writerai/types/graph_update_response.py new file mode 100644 index 00000000..6910f9d4 --- /dev/null +++ b/src/writerai/types/graph_update_response.py @@ -0,0 +1,52 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["GraphUpdateResponse", "URL", "URLStatus"] + + +class URLStatus(BaseModel): + """The current status of the URL processing.""" + + status: Literal["validating", "success", "error"] + """The current status of the URL processing.""" + + error_type: Optional[ + Literal["invalid_url", "not_searchable", "not_found", "paywall_or_login_page", "unexpected_error"] + ] = None + """The type of error that occurred during processing, if any.""" + + +class URL(BaseModel): + status: URLStatus + """The current status of the URL processing.""" + + type: Literal["single_page", "sub_pages"] + """The type of web connector processing for this URL.""" + + url: str + """The URL to be processed by the web connector.""" + + exclude_urls: Optional[List[str]] = None + """An array of URLs to exclude from processing within this web connector.""" + + +class GraphUpdateResponse(BaseModel): + id: str + """A unique identifier of the Knowledge Graph.""" + + created_at: datetime + """The timestamp when the Knowledge Graph was created.""" + + name: str + """The name of the Knowledge Graph (max 255 characters).""" + + description: Optional[str] = None + """A description of the Knowledge Graph (max 255 characters).""" + + urls: Optional[List[URL]] = None + """An array of web connector URLs associated with this Knowledge Graph.""" diff --git a/src/writerai/types/model_list_response.py b/src/writerai/types/model_list_response.py new file mode 100644 index 00000000..7393a6b2 --- /dev/null +++ b/src/writerai/types/model_list_response.py @@ -0,0 +1,23 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List + +from .._models import BaseModel + +__all__ = ["ModelListResponse", "Model"] + + +class Model(BaseModel): + id: str + """The ID of the particular LLM that you want to use""" + + name: str + """The name of the particular LLM that you want to use.""" + + +class ModelListResponse(BaseModel): + models: List[Model] + """ + The [ID of the model](https://dev.writer.com/home/models) to use for processing + the request. + """ diff --git a/src/writerai/types/question.py b/src/writerai/types/question.py new file mode 100644 index 00000000..58a656cb --- /dev/null +++ b/src/writerai/types/question.py @@ -0,0 +1,109 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from pydantic import Field as FieldInfo + +from .._models import BaseModel +from .shared.source import Source + +__all__ = ["Question", "References", "ReferencesFile", "ReferencesWeb", "Subquery"] + + +class ReferencesFile(BaseModel): + """ + A file-based reference containing text snippets from uploaded documents in the Knowledge Graph. + """ + + file_id: str = FieldInfo(alias="fileId") + """The unique identifier of the file in your Writer account.""" + + score: float + """ + Internal score used during the retrieval process for ranking and selecting + relevant snippets. + """ + + text: str + """ + The exact text snippet from the source document that was used to support the + response. + """ + + cite: Optional[str] = None + """ + Unique citation ID that appears in inline citations within the response text + (null if not cited). + """ + + page: Optional[int] = None + """Page number where this snippet was found in the source document.""" + + +class ReferencesWeb(BaseModel): + """ + A web-based reference containing text snippets from online sources accessed during the query. + """ + + score: float + """ + Internal score used during the retrieval process for ranking and selecting + relevant snippets. + """ + + text: str + """ + The exact text snippet from the web source that was used to support the + response. + """ + + title: str + """The title of the web page where this content was found.""" + + url: str + """The URL of the web page where this content was found.""" + + +class References(BaseModel): + """ + Detailed source information organized by reference type, providing comprehensive metadata about the sources used to generate the response. + """ + + files: Optional[List[ReferencesFile]] = None + """Array of file-based references from uploaded documents in the Knowledge Graph.""" + + web: Optional[List[ReferencesWeb]] = None + """Array of web-based references from online sources accessed during the query.""" + + +class Subquery(BaseModel): + """ + A sub-question generated to break down complex queries into more manageable parts, along with its answer and supporting sources. + """ + + answer: str + """The answer to the subquery based on Knowledge Graph content.""" + + query: str + """The subquery that was generated to help answer the main question.""" + + sources: List[Optional[Source]] + """Array of source snippets that were used to answer this subquery.""" + + +class Question(BaseModel): + answer: str + """The answer to the question.""" + + question: str + """The question that was asked.""" + + sources: List[Optional[Source]] + + references: Optional[References] = None + """ + Detailed source information organized by reference type, providing comprehensive + metadata about the sources used to generate the response. + """ + + subqueries: Optional[List[Optional[Subquery]]] = None diff --git a/src/writerai/types/question_response_chunk.py b/src/writerai/types/question_response_chunk.py new file mode 100644 index 00000000..d140acc3 --- /dev/null +++ b/src/writerai/types/question_response_chunk.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel +from .question import Question + +__all__ = ["QuestionResponseChunk"] + + +class QuestionResponseChunk(BaseModel): + data: Question diff --git a/src/writerai/types/shared/__init__.py b/src/writerai/types/shared/__init__.py new file mode 100644 index 00000000..d62055d0 --- /dev/null +++ b/src/writerai/types/shared/__init__.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .source import Source as Source +from .logprobs import Logprobs as Logprobs +from .tool_call import ToolCall as ToolCall +from .graph_data import GraphData as GraphData +from .tool_param import ToolParam as ToolParam +from .error_object import ErrorObject as ErrorObject +from .error_message import ErrorMessage as ErrorMessage +from .logprobs_token import LogprobsToken as LogprobsToken +from .function_params import FunctionParams as FunctionParams +from .tool_choice_string import ToolChoiceString as ToolChoiceString +from .function_definition import FunctionDefinition as FunctionDefinition +from .tool_call_streaming import ToolCallStreaming as ToolCallStreaming +from .tool_choice_json_object import ToolChoiceJsonObject as ToolChoiceJsonObject diff --git a/src/writerai/types/shared/error_message.py b/src/writerai/types/shared/error_message.py new file mode 100644 index 00000000..2b5b1599 --- /dev/null +++ b/src/writerai/types/shared/error_message.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict + +from ..._models import BaseModel + +__all__ = ["ErrorMessage"] + + +class ErrorMessage(BaseModel): + description: str + + extras: Dict[str, object] + + key: str diff --git a/src/writerai/types/shared/error_object.py b/src/writerai/types/shared/error_object.py new file mode 100644 index 00000000..e98f7a5e --- /dev/null +++ b/src/writerai/types/shared/error_object.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List + +from ..._models import BaseModel +from .error_message import ErrorMessage + +__all__ = ["ErrorObject"] + + +class ErrorObject(BaseModel): + errors: List[ErrorMessage] + + extras: Dict[str, object] + + tpe: str diff --git a/src/writerai/types/shared/function_definition.py b/src/writerai/types/shared/function_definition.py new file mode 100644 index 00000000..e71dd5fc --- /dev/null +++ b/src/writerai/types/shared/function_definition.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional + +from ..._models import BaseModel +from .function_params import FunctionParams + +__all__ = ["FunctionDefinition"] + + +class FunctionDefinition(BaseModel): + """A tool that uses a custom function.""" + + name: str + """Name of the function.""" + + description: Optional[str] = None + """Description of the function.""" + + parameters: Optional[FunctionParams] = None + """The parameters of the function.""" diff --git a/src/writerai/types/shared/function_params.py b/src/writerai/types/shared/function_params.py new file mode 100644 index 00000000..cb00a506 --- /dev/null +++ b/src/writerai/types/shared/function_params.py @@ -0,0 +1,8 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict +from typing_extensions import TypeAlias + +__all__ = ["FunctionParams"] + +FunctionParams: TypeAlias = Dict[str, object] diff --git a/src/writerai/types/shared/graph_data.py b/src/writerai/types/shared/graph_data.py new file mode 100644 index 00000000..bac0293d --- /dev/null +++ b/src/writerai/types/shared/graph_data.py @@ -0,0 +1,106 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from typing_extensions import Literal + +from pydantic import Field as FieldInfo + +from .source import Source +from ..._models import BaseModel + +__all__ = ["GraphData", "References", "ReferencesFile", "ReferencesWeb", "Subquery"] + + +class ReferencesFile(BaseModel): + """ + A file-based reference containing text snippets from uploaded documents in the Knowledge Graph. + """ + + file_id: str = FieldInfo(alias="fileId") + """The unique identifier of the file in your Writer account.""" + + score: float + """ + Internal score used during the retrieval process for ranking and selecting + relevant snippets. + """ + + text: str + """ + The exact text snippet from the source document that was used to support the + response. + """ + + cite: Optional[str] = None + """ + Unique citation ID that appears in inline citations within the response text + (null if not cited). + """ + + page: Optional[int] = None + """Page number where this snippet was found in the source document.""" + + +class ReferencesWeb(BaseModel): + """ + A web-based reference containing text snippets from online sources accessed during the query. + """ + + score: float + """ + Internal score used during the retrieval process for ranking and selecting + relevant snippets. + """ + + text: str + """ + The exact text snippet from the web source that was used to support the + response. + """ + + title: str + """The title of the web page where this content was found.""" + + url: str + """The URL of the web page where this content was found.""" + + +class References(BaseModel): + """ + Detailed source information organized by reference type, providing comprehensive metadata about the sources used to generate the response. + """ + + files: Optional[List[ReferencesFile]] = None + """Array of file-based references from uploaded documents in the Knowledge Graph.""" + + web: Optional[List[ReferencesWeb]] = None + """Array of web-based references from online sources accessed during the query.""" + + +class Subquery(BaseModel): + """ + A sub-question generated to break down complex queries into more manageable parts, along with its answer and supporting sources. + """ + + answer: str + """The answer to the subquery based on Knowledge Graph content.""" + + query: str + """The subquery that was generated to help answer the main question.""" + + sources: List[Optional[Source]] + """Array of source snippets that were used to answer this subquery.""" + + +class GraphData(BaseModel): + references: Optional[References] = None + """ + Detailed source information organized by reference type, providing comprehensive + metadata about the sources used to generate the response. + """ + + sources: Optional[List[Optional[Source]]] = None + + status: Optional[Literal["processing", "finished"]] = None + + subqueries: Optional[List[Optional[Subquery]]] = None diff --git a/src/writerai/types/shared/logprobs.py b/src/writerai/types/shared/logprobs.py new file mode 100644 index 00000000..33dbf7ed --- /dev/null +++ b/src/writerai/types/shared/logprobs.py @@ -0,0 +1,14 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from ..._models import BaseModel +from .logprobs_token import LogprobsToken + +__all__ = ["Logprobs"] + + +class Logprobs(BaseModel): + content: Optional[List[LogprobsToken]] = None + + refusal: Optional[List[LogprobsToken]] = None diff --git a/src/writerai/types/shared/logprobs_token.py b/src/writerai/types/shared/logprobs_token.py new file mode 100644 index 00000000..cf122336 --- /dev/null +++ b/src/writerai/types/shared/logprobs_token.py @@ -0,0 +1,29 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from ..._models import BaseModel + +__all__ = ["LogprobsToken", "TopLogprob"] + + +class TopLogprob(BaseModel): + """ + An array of mappings for each token to its top log probabilities, showing detailed prediction probabilities. + """ + + token: str + + logprob: float + + bytes: Optional[List[int]] = None + + +class LogprobsToken(BaseModel): + token: str + + logprob: float + + top_logprobs: List[TopLogprob] + + bytes: Optional[List[int]] = None diff --git a/src/writerai/types/shared/source.py b/src/writerai/types/shared/source.py new file mode 100644 index 00000000..65debe9c --- /dev/null +++ b/src/writerai/types/shared/source.py @@ -0,0 +1,18 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from ..._models import BaseModel + +__all__ = ["Source"] + + +class Source(BaseModel): + """A source snippet containing text and fileId from Knowledge Graph content.""" + + file_id: str + """The unique identifier of the file in your Writer account.""" + + snippet: str + """ + The exact text snippet from the source document that was used to support the + response. + """ diff --git a/src/writerai/types/shared/tool_call.py b/src/writerai/types/shared/tool_call.py new file mode 100644 index 00000000..33d5d929 --- /dev/null +++ b/src/writerai/types/shared/tool_call.py @@ -0,0 +1,24 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["ToolCall", "Function"] + + +class Function(BaseModel): + arguments: str + + name: Optional[str] = None + + +class ToolCall(BaseModel): + id: str + + function: Function + + type: Literal["function"] + + index: Optional[int] = None diff --git a/src/writerai/types/shared/tool_call_streaming.py b/src/writerai/types/shared/tool_call_streaming.py new file mode 100644 index 00000000..ad4619e6 --- /dev/null +++ b/src/writerai/types/shared/tool_call_streaming.py @@ -0,0 +1,24 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["ToolCallStreaming", "Function"] + + +class Function(BaseModel): + arguments: str + + name: Optional[str] = None + + +class ToolCallStreaming(BaseModel): + index: int + + id: Optional[str] = None + + function: Optional[Function] = None + + type: Optional[Literal["function"]] = None diff --git a/src/writerai/types/shared/tool_choice_json_object.py b/src/writerai/types/shared/tool_choice_json_object.py new file mode 100644 index 00000000..499b6d24 --- /dev/null +++ b/src/writerai/types/shared/tool_choice_json_object.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict + +from ..._models import BaseModel + +__all__ = ["ToolChoiceJsonObject"] + + +class ToolChoiceJsonObject(BaseModel): + value: Dict[str, object] + """A JSON object that specifies the tool to call. + + For example, `{"type": "function", "function": {"name": "get_current_weather"}}` + """ diff --git a/src/writerai/types/shared/tool_choice_string.py b/src/writerai/types/shared/tool_choice_string.py new file mode 100644 index 00000000..653664da --- /dev/null +++ b/src/writerai/types/shared/tool_choice_string.py @@ -0,0 +1,11 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["ToolChoiceString"] + + +class ToolChoiceString(BaseModel): + value: Literal["none", "auto", "required"] diff --git a/src/writerai/types/shared/tool_param.py b/src/writerai/types/shared/tool_param.py new file mode 100644 index 00000000..1391010a --- /dev/null +++ b/src/writerai/types/shared/tool_param.py @@ -0,0 +1,282 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Union, Optional +from typing_extensions import Literal, Annotated, TypeAlias + +from ..._utils import PropertyInfo +from ..._models import BaseModel +from .function_definition import FunctionDefinition + +__all__ = [ + "ToolParam", + "FunctionTool", + "GraphTool", + "GraphToolFunction", + "GraphToolFunctionQueryConfig", + "LlmTool", + "LlmToolFunction", + "TranslationTool", + "TranslationToolFunction", + "VisionTool", + "VisionToolFunction", + "VisionToolFunctionVariable", + "WebSearchTool", + "WebSearchToolFunction", +] + + +class FunctionTool(BaseModel): + function: FunctionDefinition + """A tool that uses a custom function.""" + + type: Literal["function"] + """The type of tool.""" + + +class GraphToolFunctionQueryConfig(BaseModel): + """ + Configuration options for Knowledge Graph queries, including search parameters and citation settings. + """ + + grounding_level: Optional[float] = None + """ + Level of grounding required for responses, controlling how closely answers must + be tied to source material. Set lower for grounded outputs, higher for + creativity. Higher values (closer to 1.0) allow more creative interpretation, + while lower values (closer to 0.0) stick more closely to source material. Range: + 0.0-1.0, Default: 0.0. + """ + + inline_citations: Optional[bool] = None + """ + Whether to include inline citations in the response, showing which Knowledge + Graph sources were used. Default: false. + """ + + keyword_threshold: Optional[float] = None + """Threshold for keyword-based matching when searching Knowledge Graph content. + + Set higher for stricter relevance, lower for broader range. Higher values + (closer to 1.0) require stronger keyword matches, while lower values (closer to + 0.0) allow more lenient matching. Range: 0.0-1.0, Default: 0.7. + """ + + max_snippets: Optional[int] = None + """Maximum number of text snippets to retrieve from the Knowledge Graph for + context. + + Works in concert with `search_weight` to control best matches vs broader + coverage. While technically supports 1-60, values below 5 may return no results + due to RAG implementation. Recommended range: 5-25. Due to RAG system behavior, + you may see more snippets than requested. Range: 1-60, Default: 30. + """ + + max_subquestions: Optional[int] = None + """Maximum number of subquestions to generate when processing complex queries. + + Set higher to improve detail, set lower to reduce response time. Range: 1-10, + Default: 6. + """ + + max_tokens: Optional[int] = None + """Maximum number of tokens the model can generate in the response. + + This controls the length of the AI's answer. Set higher for longer answers, set + lower for shorter, faster answers. Range: 100-8000, Default: 4000. + """ + + search_weight: Optional[int] = None + """Weight given to search results when ranking and selecting relevant information. + + Higher values (closer to 100) prioritize keyword-based matching, while lower + values (closer to 0) prioritize semantic similarity matching. Use higher values + for exact keyword searches, lower values for conceptual similarity searches. + Range: 0-100, Default: 50. + """ + + semantic_threshold: Optional[float] = None + """ + Threshold for semantic similarity matching when searching Knowledge Graph + content. Set higher for stricter relevance, lower for broader range. Higher + values (closer to 1.0) require stronger semantic similarity, while lower values + (closer to 0.0) allow more lenient semantic matching. Range: 0.0-1.0, Default: + 0.7. + """ + + +class GraphToolFunction(BaseModel): + """A tool that uses Knowledge Graphs as context for responses.""" + + graph_ids: List[str] + """An array of graph IDs to use in the tool.""" + + subqueries: bool + """Boolean to indicate whether to include subqueries in the response.""" + + description: Optional[str] = None + """A description of the graph content.""" + + query_config: Optional[GraphToolFunctionQueryConfig] = None + """ + Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + """ + + +class GraphTool(BaseModel): + function: GraphToolFunction + """A tool that uses Knowledge Graphs as context for responses.""" + + type: Literal["graph"] + """The type of tool.""" + + +class LlmToolFunction(BaseModel): + """A tool that uses another Writer model to generate a response.""" + + description: str + """A description of the model to use.""" + + model: str + """The model to use.""" + + +class LlmTool(BaseModel): + function: LlmToolFunction + """A tool that uses another Writer model to generate a response.""" + + type: Literal["llm"] + """The type of tool.""" + + +class TranslationToolFunction(BaseModel): + """A tool that uses Palmyra Translate to translate text.""" + + formality: bool + """Whether to use formal or informal language in the translation. + + See the + [list of languages that support formality](https://dev.writer.com/api-reference/translation-api/language-support#formality). + If the language does not support formality, this parameter is ignored. + """ + + length_control: bool + """Whether to control the length of the translated text. + + See the + [list of languages that support length control](https://dev.writer.com/api-reference/translation-api/language-support#length-control). + If the language does not support length control, this parameter is ignored. + """ + + mask_profanity: bool + """Whether to mask profane words in the translated text. + + See the + [list of languages that do not support profanity masking](https://dev.writer.com/api-reference/translation-api/language-support#profanity-masking). + If the language does not support profanity masking, this parameter is ignored. + """ + + model: Literal["palmyra-translate"] + """The model to use for translation.""" + + source_language_code: Optional[str] = None + """Optional. + + The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the original text to translate. For example, `en` for English, + `zh` for Chinese, `fr` for French, `es` for Spanish. If the language has a + variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + If you do not provide a language code, the LLM detects the language of the text. + """ + + target_language_code: Optional[str] = None + """Optional. + + The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the target language for the translation. For example, `en` for + English, `zh` for Chinese, `fr` for French, `es` for Spanish. If the language + has a variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + If you do not provide a language code, the LLM uses the content of the chat + message to determine the target language. + """ + + +class TranslationTool(BaseModel): + """A tool that uses Palmyra Translate to translate text. + + Note that this tool does not stream results. The response is returned after the translation is complete. + """ + + function: TranslationToolFunction + """A tool that uses Palmyra Translate to translate text.""" + + type: Literal["translation"] + """The type of tool.""" + + +class VisionToolFunctionVariable(BaseModel): + file_id: str + """The File ID of the file to analyze. + + The file must be uploaded to the Writer platform before you use it with the + Vision tool. Supported file types: JPG, PNG, PDF, TXT. The maximum allowed file + size is 7MB. + """ + + name: str + """The name of the file variable. + + You must reference this name in the `message.content` field of the request to + the chat completions endpoint. Use double curly braces (`{{}}`) to reference the + file. For example, + `Describe the difference between the image {{image_1}} and the image {{image_2}}`. + """ + + +class VisionToolFunction(BaseModel): + """A tool that uses Palmyra Vision to analyze images and documents. + + Supports JPG, PNG, PDF, and TXT files up to 7MB each. + """ + + model: Literal["palmyra-vision"] + """The model to use for image analysis.""" + + variables: List[VisionToolFunctionVariable] + + +class VisionTool(BaseModel): + function: VisionToolFunction + """A tool that uses Palmyra Vision to analyze images and documents. + + Supports JPG, PNG, PDF, and TXT files up to 7MB each. + """ + + type: Literal["vision"] + """The type of tool.""" + + +class WebSearchToolFunction(BaseModel): + """A tool that uses web search to find information.""" + + exclude_domains: List[str] + """An array of domains to exclude from the search results.""" + + include_domains: List[str] + """An array of domains to include in the search results.""" + + +class WebSearchTool(BaseModel): + function: WebSearchToolFunction + """A tool that uses web search to find information.""" + + type: Literal["web_search"] + """The type of tool.""" + + +ToolParam: TypeAlias = Annotated[ + Union[FunctionTool, GraphTool, LlmTool, TranslationTool, VisionTool, WebSearchTool], + PropertyInfo(discriminator="type"), +] diff --git a/src/writerai/types/shared_params/__init__.py b/src/writerai/types/shared_params/__init__.py new file mode 100644 index 00000000..1bd920cf --- /dev/null +++ b/src/writerai/types/shared_params/__init__.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .source import Source as Source +from .tool_call import ToolCall as ToolCall +from .graph_data import GraphData as GraphData +from .tool_param import ToolParam as ToolParam +from .function_params import FunctionParams as FunctionParams +from .tool_choice_string import ToolChoiceString as ToolChoiceString +from .function_definition import FunctionDefinition as FunctionDefinition +from .tool_choice_json_object import ToolChoiceJsonObject as ToolChoiceJsonObject diff --git a/src/writerai/types/shared_params/function_definition.py b/src/writerai/types/shared_params/function_definition.py new file mode 100644 index 00000000..44dcd0c7 --- /dev/null +++ b/src/writerai/types/shared_params/function_definition.py @@ -0,0 +1,22 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +from .function_params import FunctionParams + +__all__ = ["FunctionDefinition"] + + +class FunctionDefinition(TypedDict, total=False): + """A tool that uses a custom function.""" + + name: Required[str] + """Name of the function.""" + + description: str + """Description of the function.""" + + parameters: FunctionParams + """The parameters of the function.""" diff --git a/src/writerai/types/shared_params/function_params.py b/src/writerai/types/shared_params/function_params.py new file mode 100644 index 00000000..dd1f1c2a --- /dev/null +++ b/src/writerai/types/shared_params/function_params.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict +from typing_extensions import TypeAlias + +__all__ = ["FunctionParams"] + +FunctionParams: TypeAlias = Dict[str, object] diff --git a/src/writerai/types/shared_params/graph_data.py b/src/writerai/types/shared_params/graph_data.py new file mode 100644 index 00000000..70f0caad --- /dev/null +++ b/src/writerai/types/shared_params/graph_data.py @@ -0,0 +1,106 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable, Optional +from typing_extensions import Literal, Required, Annotated, TypedDict + +from .source import Source +from ..._utils import PropertyInfo + +__all__ = ["GraphData", "References", "ReferencesFile", "ReferencesWeb", "Subquery"] + + +class ReferencesFile(TypedDict, total=False): + """ + A file-based reference containing text snippets from uploaded documents in the Knowledge Graph. + """ + + file_id: Required[Annotated[str, PropertyInfo(alias="fileId")]] + """The unique identifier of the file in your Writer account.""" + + score: Required[float] + """ + Internal score used during the retrieval process for ranking and selecting + relevant snippets. + """ + + text: Required[str] + """ + The exact text snippet from the source document that was used to support the + response. + """ + + cite: str + """ + Unique citation ID that appears in inline citations within the response text + (null if not cited). + """ + + page: int + """Page number where this snippet was found in the source document.""" + + +class ReferencesWeb(TypedDict, total=False): + """ + A web-based reference containing text snippets from online sources accessed during the query. + """ + + score: Required[float] + """ + Internal score used during the retrieval process for ranking and selecting + relevant snippets. + """ + + text: Required[str] + """ + The exact text snippet from the web source that was used to support the + response. + """ + + title: Required[str] + """The title of the web page where this content was found.""" + + url: Required[str] + """The URL of the web page where this content was found.""" + + +class References(TypedDict, total=False): + """ + Detailed source information organized by reference type, providing comprehensive metadata about the sources used to generate the response. + """ + + files: Iterable[ReferencesFile] + """Array of file-based references from uploaded documents in the Knowledge Graph.""" + + web: Iterable[ReferencesWeb] + """Array of web-based references from online sources accessed during the query.""" + + +class Subquery(TypedDict, total=False): + """ + A sub-question generated to break down complex queries into more manageable parts, along with its answer and supporting sources. + """ + + answer: Required[str] + """The answer to the subquery based on Knowledge Graph content.""" + + query: Required[str] + """The subquery that was generated to help answer the main question.""" + + sources: Required[Iterable[Optional[Source]]] + """Array of source snippets that were used to answer this subquery.""" + + +class GraphData(TypedDict, total=False): + references: References + """ + Detailed source information organized by reference type, providing comprehensive + metadata about the sources used to generate the response. + """ + + sources: Iterable[Optional[Source]] + + status: Optional[Literal["processing", "finished"]] + + subqueries: Iterable[Optional[Subquery]] diff --git a/src/writerai/types/shared_params/source.py b/src/writerai/types/shared_params/source.py new file mode 100644 index 00000000..a1397fb5 --- /dev/null +++ b/src/writerai/types/shared_params/source.py @@ -0,0 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["Source"] + + +class Source(TypedDict, total=False): + """A source snippet containing text and fileId from Knowledge Graph content.""" + + file_id: Required[str] + """The unique identifier of the file in your Writer account.""" + + snippet: Required[str] + """ + The exact text snippet from the source document that was used to support the + response. + """ diff --git a/src/writerai/types/shared_params/tool_call.py b/src/writerai/types/shared_params/tool_call.py new file mode 100644 index 00000000..cb736487 --- /dev/null +++ b/src/writerai/types/shared_params/tool_call.py @@ -0,0 +1,23 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ToolCall", "Function"] + + +class Function(TypedDict, total=False): + arguments: Required[str] + + name: str + + +class ToolCall(TypedDict, total=False): + id: Required[str] + + function: Required[Function] + + type: Required[Literal["function"]] + + index: int diff --git a/src/writerai/types/shared_params/tool_choice_json_object.py b/src/writerai/types/shared_params/tool_choice_json_object.py new file mode 100644 index 00000000..30d0f7f6 --- /dev/null +++ b/src/writerai/types/shared_params/tool_choice_json_object.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict +from typing_extensions import Required, TypedDict + +__all__ = ["ToolChoiceJsonObject"] + + +class ToolChoiceJsonObject(TypedDict, total=False): + value: Required[Dict[str, object]] + """A JSON object that specifies the tool to call. + + For example, `{"type": "function", "function": {"name": "get_current_weather"}}` + """ diff --git a/src/writerai/types/shared_params/tool_choice_string.py b/src/writerai/types/shared_params/tool_choice_string.py new file mode 100644 index 00000000..5852b152 --- /dev/null +++ b/src/writerai/types/shared_params/tool_choice_string.py @@ -0,0 +1,11 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ToolChoiceString"] + + +class ToolChoiceString(TypedDict, total=False): + value: Required[Literal["none", "auto", "required"]] diff --git a/src/writerai/types/shared_params/tool_param.py b/src/writerai/types/shared_params/tool_param.py new file mode 100644 index 00000000..1a9c4dd7 --- /dev/null +++ b/src/writerai/types/shared_params/tool_param.py @@ -0,0 +1,280 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union, Iterable +from typing_extensions import Literal, Required, TypeAlias, TypedDict + +from ..._types import SequenceNotStr +from .function_definition import FunctionDefinition + +__all__ = [ + "ToolParam", + "FunctionTool", + "GraphTool", + "GraphToolFunction", + "GraphToolFunctionQueryConfig", + "LlmTool", + "LlmToolFunction", + "TranslationTool", + "TranslationToolFunction", + "VisionTool", + "VisionToolFunction", + "VisionToolFunctionVariable", + "WebSearchTool", + "WebSearchToolFunction", +] + + +class FunctionTool(TypedDict, total=False): + function: Required[FunctionDefinition] + """A tool that uses a custom function.""" + + type: Required[Literal["function"]] + """The type of tool.""" + + +class GraphToolFunctionQueryConfig(TypedDict, total=False): + """ + Configuration options for Knowledge Graph queries, including search parameters and citation settings. + """ + + grounding_level: float + """ + Level of grounding required for responses, controlling how closely answers must + be tied to source material. Set lower for grounded outputs, higher for + creativity. Higher values (closer to 1.0) allow more creative interpretation, + while lower values (closer to 0.0) stick more closely to source material. Range: + 0.0-1.0, Default: 0.0. + """ + + inline_citations: bool + """ + Whether to include inline citations in the response, showing which Knowledge + Graph sources were used. Default: false. + """ + + keyword_threshold: float + """Threshold for keyword-based matching when searching Knowledge Graph content. + + Set higher for stricter relevance, lower for broader range. Higher values + (closer to 1.0) require stronger keyword matches, while lower values (closer to + 0.0) allow more lenient matching. Range: 0.0-1.0, Default: 0.7. + """ + + max_snippets: int + """Maximum number of text snippets to retrieve from the Knowledge Graph for + context. + + Works in concert with `search_weight` to control best matches vs broader + coverage. While technically supports 1-60, values below 5 may return no results + due to RAG implementation. Recommended range: 5-25. Due to RAG system behavior, + you may see more snippets than requested. Range: 1-60, Default: 30. + """ + + max_subquestions: int + """Maximum number of subquestions to generate when processing complex queries. + + Set higher to improve detail, set lower to reduce response time. Range: 1-10, + Default: 6. + """ + + max_tokens: int + """Maximum number of tokens the model can generate in the response. + + This controls the length of the AI's answer. Set higher for longer answers, set + lower for shorter, faster answers. Range: 100-8000, Default: 4000. + """ + + search_weight: int + """Weight given to search results when ranking and selecting relevant information. + + Higher values (closer to 100) prioritize keyword-based matching, while lower + values (closer to 0) prioritize semantic similarity matching. Use higher values + for exact keyword searches, lower values for conceptual similarity searches. + Range: 0-100, Default: 50. + """ + + semantic_threshold: float + """ + Threshold for semantic similarity matching when searching Knowledge Graph + content. Set higher for stricter relevance, lower for broader range. Higher + values (closer to 1.0) require stronger semantic similarity, while lower values + (closer to 0.0) allow more lenient semantic matching. Range: 0.0-1.0, Default: + 0.7. + """ + + +class GraphToolFunction(TypedDict, total=False): + """A tool that uses Knowledge Graphs as context for responses.""" + + graph_ids: Required[SequenceNotStr[str]] + """An array of graph IDs to use in the tool.""" + + subqueries: Required[bool] + """Boolean to indicate whether to include subqueries in the response.""" + + description: str + """A description of the graph content.""" + + query_config: GraphToolFunctionQueryConfig + """ + Configuration options for Knowledge Graph queries, including search parameters + and citation settings. + """ + + +class GraphTool(TypedDict, total=False): + function: Required[GraphToolFunction] + """A tool that uses Knowledge Graphs as context for responses.""" + + type: Required[Literal["graph"]] + """The type of tool.""" + + +class LlmToolFunction(TypedDict, total=False): + """A tool that uses another Writer model to generate a response.""" + + description: Required[str] + """A description of the model to use.""" + + model: Required[str] + """The model to use.""" + + +class LlmTool(TypedDict, total=False): + function: Required[LlmToolFunction] + """A tool that uses another Writer model to generate a response.""" + + type: Required[Literal["llm"]] + """The type of tool.""" + + +class TranslationToolFunction(TypedDict, total=False): + """A tool that uses Palmyra Translate to translate text.""" + + formality: Required[bool] + """Whether to use formal or informal language in the translation. + + See the + [list of languages that support formality](https://dev.writer.com/api-reference/translation-api/language-support#formality). + If the language does not support formality, this parameter is ignored. + """ + + length_control: Required[bool] + """Whether to control the length of the translated text. + + See the + [list of languages that support length control](https://dev.writer.com/api-reference/translation-api/language-support#length-control). + If the language does not support length control, this parameter is ignored. + """ + + mask_profanity: Required[bool] + """Whether to mask profane words in the translated text. + + See the + [list of languages that do not support profanity masking](https://dev.writer.com/api-reference/translation-api/language-support#profanity-masking). + If the language does not support profanity masking, this parameter is ignored. + """ + + model: Required[Literal["palmyra-translate"]] + """The model to use for translation.""" + + source_language_code: str + """Optional. + + The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the original text to translate. For example, `en` for English, + `zh` for Chinese, `fr` for French, `es` for Spanish. If the language has a + variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + If you do not provide a language code, the LLM detects the language of the text. + """ + + target_language_code: str + """Optional. + + The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the target language for the translation. For example, `en` for + English, `zh` for Chinese, `fr` for French, `es` for Spanish. If the language + has a variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + If you do not provide a language code, the LLM uses the content of the chat + message to determine the target language. + """ + + +class TranslationTool(TypedDict, total=False): + """A tool that uses Palmyra Translate to translate text. + + Note that this tool does not stream results. The response is returned after the translation is complete. + """ + + function: Required[TranslationToolFunction] + """A tool that uses Palmyra Translate to translate text.""" + + type: Required[Literal["translation"]] + """The type of tool.""" + + +class VisionToolFunctionVariable(TypedDict, total=False): + file_id: Required[str] + """The File ID of the file to analyze. + + The file must be uploaded to the Writer platform before you use it with the + Vision tool. Supported file types: JPG, PNG, PDF, TXT. The maximum allowed file + size is 7MB. + """ + + name: Required[str] + """The name of the file variable. + + You must reference this name in the `message.content` field of the request to + the chat completions endpoint. Use double curly braces (`{{}}`) to reference the + file. For example, + `Describe the difference between the image {{image_1}} and the image {{image_2}}`. + """ + + +class VisionToolFunction(TypedDict, total=False): + """A tool that uses Palmyra Vision to analyze images and documents. + + Supports JPG, PNG, PDF, and TXT files up to 7MB each. + """ + + model: Required[Literal["palmyra-vision"]] + """The model to use for image analysis.""" + + variables: Required[Iterable[VisionToolFunctionVariable]] + + +class VisionTool(TypedDict, total=False): + function: Required[VisionToolFunction] + """A tool that uses Palmyra Vision to analyze images and documents. + + Supports JPG, PNG, PDF, and TXT files up to 7MB each. + """ + + type: Required[Literal["vision"]] + """The type of tool.""" + + +class WebSearchToolFunction(TypedDict, total=False): + """A tool that uses web search to find information.""" + + exclude_domains: Required[SequenceNotStr[str]] + """An array of domains to exclude from the search results.""" + + include_domains: Required[SequenceNotStr[str]] + """An array of domains to include in the search results.""" + + +class WebSearchTool(TypedDict, total=False): + function: Required[WebSearchToolFunction] + """A tool that uses web search to find information.""" + + type: Required[Literal["web_search"]] + """The type of tool.""" + + +ToolParam: TypeAlias = Union[FunctionTool, GraphTool, LlmTool, TranslationTool, VisionTool, WebSearchTool] diff --git a/src/writerai/types/tool_parse_pdf_params.py b/src/writerai/types/tool_parse_pdf_params.py new file mode 100644 index 00000000..52c535c2 --- /dev/null +++ b/src/writerai/types/tool_parse_pdf_params.py @@ -0,0 +1,12 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ToolParsePdfParams"] + + +class ToolParsePdfParams(TypedDict, total=False): + format: Required[Literal["text", "markdown"]] + """The format into which the PDF content should be converted.""" diff --git a/src/writerai/types/tool_parse_pdf_response.py b/src/writerai/types/tool_parse_pdf_response.py new file mode 100644 index 00000000..0d601ec8 --- /dev/null +++ b/src/writerai/types/tool_parse_pdf_response.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["ToolParsePdfResponse"] + + +class ToolParsePdfResponse(BaseModel): + content: str + """The extracted content from the PDF file, converted to the specified format.""" diff --git a/src/writerai/types/tool_web_search_params.py b/src/writerai/types/tool_web_search_params.py new file mode 100644 index 00000000..6f639e02 --- /dev/null +++ b/src/writerai/types/tool_web_search_params.py @@ -0,0 +1,247 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union +from typing_extensions import Literal, TypedDict + +from .._types import SequenceNotStr + +__all__ = ["ToolWebSearchParams"] + + +class ToolWebSearchParams(TypedDict, total=False): + chunks_per_source: int + """Only applies when `search_depth` is `advanced`. + + Specifies how many text segments to extract from each source. Limited to 3 + chunks maximum. + """ + + country: Literal[ + "afghanistan", + "albania", + "algeria", + "andorra", + "angola", + "argentina", + "armenia", + "australia", + "austria", + "azerbaijan", + "bahamas", + "bahrain", + "bangladesh", + "barbados", + "belarus", + "belgium", + "belize", + "benin", + "bhutan", + "bolivia", + "bosnia and herzegovina", + "botswana", + "brazil", + "brunei", + "bulgaria", + "burkina faso", + "burundi", + "cambodia", + "cameroon", + "canada", + "cape verde", + "central african republic", + "chad", + "chile", + "china", + "colombia", + "comoros", + "congo", + "costa rica", + "croatia", + "cuba", + "cyprus", + "czech republic", + "denmark", + "djibouti", + "dominican republic", + "ecuador", + "egypt", + "el salvador", + "equatorial guinea", + "eritrea", + "estonia", + "ethiopia", + "fiji", + "finland", + "france", + "gabon", + "gambia", + "georgia", + "germany", + "ghana", + "greece", + "guatemala", + "guinea", + "haiti", + "honduras", + "hungary", + "iceland", + "india", + "indonesia", + "iran", + "iraq", + "ireland", + "israel", + "italy", + "jamaica", + "japan", + "jordan", + "kazakhstan", + "kenya", + "kuwait", + "kyrgyzstan", + "latvia", + "lebanon", + "lesotho", + "liberia", + "libya", + "liechtenstein", + "lithuania", + "luxembourg", + "madagascar", + "malawi", + "malaysia", + "maldives", + "mali", + "malta", + "mauritania", + "mauritius", + "mexico", + "moldova", + "monaco", + "mongolia", + "montenegro", + "morocco", + "mozambique", + "myanmar", + "namibia", + "nepal", + "netherlands", + "new zealand", + "nicaragua", + "niger", + "nigeria", + "north korea", + "north macedonia", + "norway", + "oman", + "pakistan", + "panama", + "papua new guinea", + "paraguay", + "peru", + "philippines", + "poland", + "portugal", + "qatar", + "romania", + "russia", + "rwanda", + "saudi arabia", + "senegal", + "serbia", + "singapore", + "slovakia", + "slovenia", + "somalia", + "south africa", + "south korea", + "south sudan", + "spain", + "sri lanka", + "sudan", + "sweden", + "switzerland", + "syria", + "taiwan", + "tajikistan", + "tanzania", + "thailand", + "togo", + "trinidad and tobago", + "tunisia", + "turkey", + "turkmenistan", + "uganda", + "ukraine", + "united arab emirates", + "united kingdom", + "united states", + "uruguay", + "uzbekistan", + "venezuela", + "vietnam", + "yemen", + "zambia", + "zimbabwe", + ] + """Localizes search results to a specific country. + + Only applies to general topic searches. + """ + + days: int + """For news topic searches, specifies how many days of news coverage to include.""" + + exclude_domains: SequenceNotStr[str] + """Domains to exclude from the search. If unset, the search includes all domains.""" + + include_answer: bool + """Whether to include a generated answer to the query in the response. + + If `false`, only search results are returned. + """ + + include_domains: SequenceNotStr[str] + """Domains to include in the search. If unset, the search includes all domains.""" + + include_raw_content: Union[Literal["text", "markdown"], bool] + """Controls how raw content is included in search results: + + - `text`: Returns plain text without formatting markup + - `markdown`: Returns structured content with markdown formatting (headers, + links, bold text) + - `true`: Same as `markdown` + - `false`: Raw content is not included (default if unset) + """ + + max_results: int + """Limits the number of search results returned. Cannot exceed 20 sources.""" + + query: str + """The search query.""" + + search_depth: Literal["basic", "advanced"] + """Controls search comprehensiveness: + + - `basic`: Returns fewer but highly relevant results + - `advanced`: Performs a deeper search with more results + """ + + stream: bool + """Enables streaming of search results as they become available.""" + + time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] + """ + Filters results to content published within the specified time range back from + the current date. For example, `week` or `w` returns results from the past 7 + days. + """ + + topic: Literal["general", "news"] + """The search topic category. + + Use `news` for current events and news articles, or `general` for broader web + search. + """ diff --git a/src/writerai/types/tool_web_search_response.py b/src/writerai/types/tool_web_search_response.py new file mode 100644 index 00000000..2ed9ed71 --- /dev/null +++ b/src/writerai/types/tool_web_search_response.py @@ -0,0 +1,32 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional + +from .._models import BaseModel + +__all__ = ["ToolWebSearchResponse", "Source"] + + +class Source(BaseModel): + raw_content: Optional[str] = None + """Raw content from the source URL. + + Not included if `include_raw_content` is `false`. + """ + + url: Optional[str] = None + """URL of the search result.""" + + +class ToolWebSearchResponse(BaseModel): + query: str + """The search query that was submitted.""" + + sources: List[Source] + """The search results found.""" + + answer: Optional[str] = None + """Generated answer based on the search results. + + Not included if `include_answer` is `false`. + """ diff --git a/src/writerai/types/translation_response.py b/src/writerai/types/translation_response.py new file mode 100644 index 00000000..9d947d8b --- /dev/null +++ b/src/writerai/types/translation_response.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["TranslationResponse"] + + +class TranslationResponse(BaseModel): + data: str + """The result of the translation.""" diff --git a/src/writerai/types/translation_translate_params.py b/src/writerai/types/translation_translate_params.py new file mode 100644 index 00000000..26de74d7 --- /dev/null +++ b/src/writerai/types/translation_translate_params.py @@ -0,0 +1,61 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["TranslationTranslateParams"] + + +class TranslationTranslateParams(TypedDict, total=False): + formality: Required[bool] + """Whether to use formal or informal language in the translation. + + See the + [list of languages that support formality](https://dev.writer.com/api-reference/translation-api/language-support#formality). + If the language does not support formality, this parameter is ignored. + """ + + length_control: Required[bool] + """Whether to control the length of the translated text. + + See the + [list of languages that support length control](https://dev.writer.com/api-reference/translation-api/language-support#length-control). + If the language does not support length control, this parameter is ignored. + """ + + mask_profanity: Required[bool] + """Whether to mask profane words in the translated text. + + See the + [list of languages that do not support profanity masking](https://dev.writer.com/api-reference/translation-api/language-support#profanity-masking). + If the language does not support profanity masking, this parameter is ignored. + """ + + model: Required[Literal["palmyra-translate"]] + """The model to use for translation.""" + + source_language_code: Required[str] + """ + The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the original text to translate. For example, `en` for English, + `zh` for Chinese, `fr` for French, `es` for Spanish. If the language has a + variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + For example, Mexican Spanish is `es-MX`. See the + [list of supported languages and language codes](https://dev.writer.com/api-reference/translation-api/language-support). + """ + + target_language_code: Required[str] + """ + The [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) + language code of the target language for the translation. For example, `en` for + English, `zh` for Chinese, `fr` for French, `es` for Spanish. If the language + has a variant, the code appends the two-digit + [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes). + For example, Mexican Spanish is `es-MX`. See the + [list of supported languages and language codes](https://dev.writer.com/api-reference/translation-api/language-support). + """ + + text: Required[str] + """The text to translate. Maximum of 100,000 words.""" diff --git a/src/writerai/types/vision_analyze_params.py b/src/writerai/types/vision_analyze_params.py new file mode 100644 index 00000000..2e0adf2d --- /dev/null +++ b/src/writerai/types/vision_analyze_params.py @@ -0,0 +1,47 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Iterable +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["VisionAnalyzeParams", "Variable"] + + +class VisionAnalyzeParams(TypedDict, total=False): + model: Required[Literal["palmyra-vision"]] + """The model to use for image analysis.""" + + prompt: Required[str] + """The prompt to use for the image analysis. + + The prompt must include the name of each image variable, surrounded by double + curly braces (`{{}}`). For example, + `Describe the difference between the image {{image_1}} and the image {{image_2}}`. + """ + + variables: Required[Iterable[Variable]] + + +class Variable(TypedDict, total=False): + """An array of file variables required for the analysis. + + The files must be uploaded to the Writer platform before they can be used in a vision request. Learn how to upload files using the [Files API](https://dev.writer.com/api-reference/file-api/upload-files). + + Supported file types: JPG, PNG, PDF, TXT. The maximum allowed file size for each file is 7MB. + """ + + file_id: Required[str] + """The File ID of the file to analyze. + + The file must be uploaded to the Writer platform before it can be used in a + vision request. Supported file types: JPG, PNG, PDF, TXT (max 7MB each). + """ + + name: Required[str] + """The name of the file variable. + + You must reference this name in the prompt with double curly braces (`{{}}`). + For example, + `Describe the difference between the image {{image_1}} and the image {{image_2}}`. + """ diff --git a/src/writerai/types/vision_response.py b/src/writerai/types/vision_response.py new file mode 100644 index 00000000..56b65a3e --- /dev/null +++ b/src/writerai/types/vision_response.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .._models import BaseModel + +__all__ = ["VisionResponse"] + + +class VisionResponse(BaseModel): + data: str + """The result of the image analysis.""" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..fd8019a9 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/__init__.py b/tests/api_resources/__init__.py new file mode 100644 index 00000000..fd8019a9 --- /dev/null +++ b/tests/api_resources/__init__.py @@ -0,0 +1 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/applications/__init__.py b/tests/api_resources/applications/__init__.py new file mode 100644 index 00000000..fd8019a9 --- /dev/null +++ b/tests/api_resources/applications/__init__.py @@ -0,0 +1 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/applications/test_graphs.py b/tests/api_resources/applications/test_graphs.py new file mode 100644 index 00000000..e7c2c30e --- /dev/null +++ b/tests/api_resources/applications/test_graphs.py @@ -0,0 +1,184 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types.applications import ApplicationGraphsResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestGraphs: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_update(self, client: Writer) -> None: + graph = client.applications.graphs.update( + application_id="application_id", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + def test_raw_response_update(self, client: Writer) -> None: + response = client.applications.graphs.with_raw_response.update( + application_id="application_id", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + def test_streaming_response_update(self, client: Writer) -> None: + with client.applications.graphs.with_streaming_response.update( + application_id="application_id", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_update(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.graphs.with_raw_response.update( + application_id="", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + + @parametrize + def test_method_list(self, client: Writer) -> None: + graph = client.applications.graphs.list( + "application_id", + ) + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Writer) -> None: + response = client.applications.graphs.with_raw_response.list( + "application_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Writer) -> None: + with client.applications.graphs.with_streaming_response.list( + "application_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_list(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.graphs.with_raw_response.list( + "", + ) + + +class TestAsyncGraphs: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_update(self, async_client: AsyncWriter) -> None: + graph = await async_client.applications.graphs.update( + application_id="application_id", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + async def test_raw_response_update(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.graphs.with_raw_response.update( + application_id="application_id", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + async def test_streaming_response_update(self, async_client: AsyncWriter) -> None: + async with async_client.applications.graphs.with_streaming_response.update( + application_id="application_id", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_update(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.graphs.with_raw_response.update( + application_id="", + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + + @parametrize + async def test_method_list(self, async_client: AsyncWriter) -> None: + graph = await async_client.applications.graphs.list( + "application_id", + ) + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.graphs.with_raw_response.list( + "application_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncWriter) -> None: + async with async_client.applications.graphs.with_streaming_response.list( + "application_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(ApplicationGraphsResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_list(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.graphs.with_raw_response.list( + "", + ) diff --git a/tests/api_resources/applications/test_jobs.py b/tests/api_resources/applications/test_jobs.py new file mode 100644 index 00000000..ff13b328 --- /dev/null +++ b/tests/api_resources/applications/test_jobs.py @@ -0,0 +1,401 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.pagination import SyncApplicationJobsOffset, AsyncApplicationJobsOffset +from writerai.types.applications import ( + JobRetryResponse, + JobCreateResponse, + ApplicationGenerateAsyncResponse, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestJobs: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_create(self, client: Writer) -> None: + job = client.applications.jobs.create( + application_id="application_id", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + assert_matches_type(JobCreateResponse, job, path=["response"]) + + @parametrize + def test_raw_response_create(self, client: Writer) -> None: + response = client.applications.jobs.with_raw_response.create( + application_id="application_id", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(JobCreateResponse, job, path=["response"]) + + @parametrize + def test_streaming_response_create(self, client: Writer) -> None: + with client.applications.jobs.with_streaming_response.create( + application_id="application_id", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(JobCreateResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_create(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.jobs.with_raw_response.create( + application_id="", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + @parametrize + def test_method_retrieve(self, client: Writer) -> None: + job = client.applications.jobs.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(ApplicationGenerateAsyncResponse, job, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Writer) -> None: + response = client.applications.jobs.with_raw_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(ApplicationGenerateAsyncResponse, job, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Writer) -> None: + with client.applications.jobs.with_streaming_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(ApplicationGenerateAsyncResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `job_id` but received ''"): + client.applications.jobs.with_raw_response.retrieve( + "", + ) + + @parametrize + def test_method_list(self, client: Writer) -> None: + job = client.applications.jobs.list( + application_id="application_id", + ) + assert_matches_type(SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: Writer) -> None: + job = client.applications.jobs.list( + application_id="application_id", + limit=0, + offset=0, + status="in_progress", + ) + assert_matches_type(SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Writer) -> None: + response = client.applications.jobs.with_raw_response.list( + application_id="application_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Writer) -> None: + with client.applications.jobs.with_streaming_response.list( + application_id="application_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(SyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_list(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.jobs.with_raw_response.list( + application_id="", + ) + + @parametrize + def test_method_retry(self, client: Writer) -> None: + job = client.applications.jobs.retry( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(JobRetryResponse, job, path=["response"]) + + @parametrize + def test_raw_response_retry(self, client: Writer) -> None: + response = client.applications.jobs.with_raw_response.retry( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(JobRetryResponse, job, path=["response"]) + + @parametrize + def test_streaming_response_retry(self, client: Writer) -> None: + with client.applications.jobs.with_streaming_response.retry( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(JobRetryResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retry(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `job_id` but received ''"): + client.applications.jobs.with_raw_response.retry( + "", + ) + + +class TestAsyncJobs: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_create(self, async_client: AsyncWriter) -> None: + job = await async_client.applications.jobs.create( + application_id="application_id", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + assert_matches_type(JobCreateResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_create(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.jobs.with_raw_response.create( + application_id="application_id", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(JobCreateResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_create(self, async_client: AsyncWriter) -> None: + async with async_client.applications.jobs.with_streaming_response.create( + application_id="application_id", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(JobCreateResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_create(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.jobs.with_raw_response.create( + application_id="", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + @parametrize + async def test_method_retrieve(self, async_client: AsyncWriter) -> None: + job = await async_client.applications.jobs.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(ApplicationGenerateAsyncResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.jobs.with_raw_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(ApplicationGenerateAsyncResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncWriter) -> None: + async with async_client.applications.jobs.with_streaming_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(ApplicationGenerateAsyncResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `job_id` but received ''"): + await async_client.applications.jobs.with_raw_response.retrieve( + "", + ) + + @parametrize + async def test_method_list(self, async_client: AsyncWriter) -> None: + job = await async_client.applications.jobs.list( + application_id="application_id", + ) + assert_matches_type(AsyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncWriter) -> None: + job = await async_client.applications.jobs.list( + application_id="application_id", + limit=0, + offset=0, + status="in_progress", + ) + assert_matches_type(AsyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.jobs.with_raw_response.list( + application_id="application_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(AsyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncWriter) -> None: + async with async_client.applications.jobs.with_streaming_response.list( + application_id="application_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(AsyncApplicationJobsOffset[ApplicationGenerateAsyncResponse], job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_list(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.jobs.with_raw_response.list( + application_id="", + ) + + @parametrize + async def test_method_retry(self, async_client: AsyncWriter) -> None: + job = await async_client.applications.jobs.retry( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(JobRetryResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_retry(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.jobs.with_raw_response.retry( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(JobRetryResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_retry(self, async_client: AsyncWriter) -> None: + async with async_client.applications.jobs.with_streaming_response.retry( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(JobRetryResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retry(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `job_id` but received ''"): + await async_client.applications.jobs.with_raw_response.retry( + "", + ) diff --git a/tests/api_resources/test_applications.py b/tests/api_resources/test_applications.py new file mode 100644 index 00000000..b4813d58 --- /dev/null +++ b/tests/api_resources/test_applications.py @@ -0,0 +1,459 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import ( + ApplicationListResponse, + ApplicationRetrieveResponse, + ApplicationGenerateContentResponse, +) +from writerai.pagination import SyncCursorPage, AsyncCursorPage + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestApplications: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_retrieve(self, client: Writer) -> None: + application = client.applications.retrieve( + "application_id", + ) + assert_matches_type(ApplicationRetrieveResponse, application, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Writer) -> None: + response = client.applications.with_raw_response.retrieve( + "application_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + application = response.parse() + assert_matches_type(ApplicationRetrieveResponse, application, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Writer) -> None: + with client.applications.with_streaming_response.retrieve( + "application_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + application = response.parse() + assert_matches_type(ApplicationRetrieveResponse, application, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.with_raw_response.retrieve( + "", + ) + + @parametrize + def test_method_list(self, client: Writer) -> None: + application = client.applications.list() + assert_matches_type(SyncCursorPage[ApplicationListResponse], application, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: Writer) -> None: + application = client.applications.list( + after="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + before="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + limit=0, + order="asc", + type="generation", + ) + assert_matches_type(SyncCursorPage[ApplicationListResponse], application, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Writer) -> None: + response = client.applications.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + application = response.parse() + assert_matches_type(SyncCursorPage[ApplicationListResponse], application, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Writer) -> None: + with client.applications.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + application = response.parse() + assert_matches_type(SyncCursorPage[ApplicationListResponse], application, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_generate_content_overload_1(self, client: Writer) -> None: + application = client.applications.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + @parametrize + def test_method_generate_content_with_all_params_overload_1(self, client: Writer) -> None: + application = client.applications.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=False, + ) + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + @parametrize + def test_raw_response_generate_content_overload_1(self, client: Writer) -> None: + response = client.applications.with_raw_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + application = response.parse() + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + @parametrize + def test_streaming_response_generate_content_overload_1(self, client: Writer) -> None: + with client.applications.with_streaming_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + application = response.parse() + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_generate_content_overload_1(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.with_raw_response.generate_content( + application_id="", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + @parametrize + def test_method_generate_content_overload_2(self, client: Writer) -> None: + application_stream = client.applications.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) + application_stream.response.close() + + @parametrize + def test_raw_response_generate_content_overload_2(self, client: Writer) -> None: + response = client.applications.with_raw_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_generate_content_overload_2(self, client: Writer) -> None: + with client.applications.with_streaming_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_generate_content_overload_2(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + client.applications.with_raw_response.generate_content( + application_id="", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) + + +class TestAsyncApplications: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_retrieve(self, async_client: AsyncWriter) -> None: + application = await async_client.applications.retrieve( + "application_id", + ) + assert_matches_type(ApplicationRetrieveResponse, application, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.with_raw_response.retrieve( + "application_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + application = await response.parse() + assert_matches_type(ApplicationRetrieveResponse, application, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncWriter) -> None: + async with async_client.applications.with_streaming_response.retrieve( + "application_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + application = await response.parse() + assert_matches_type(ApplicationRetrieveResponse, application, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.with_raw_response.retrieve( + "", + ) + + @parametrize + async def test_method_list(self, async_client: AsyncWriter) -> None: + application = await async_client.applications.list() + assert_matches_type(AsyncCursorPage[ApplicationListResponse], application, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncWriter) -> None: + application = await async_client.applications.list( + after="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + before="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + limit=0, + order="asc", + type="generation", + ) + assert_matches_type(AsyncCursorPage[ApplicationListResponse], application, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + application = await response.parse() + assert_matches_type(AsyncCursorPage[ApplicationListResponse], application, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncWriter) -> None: + async with async_client.applications.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + application = await response.parse() + assert_matches_type(AsyncCursorPage[ApplicationListResponse], application, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_generate_content_overload_1(self, async_client: AsyncWriter) -> None: + application = await async_client.applications.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + @parametrize + async def test_method_generate_content_with_all_params_overload_1(self, async_client: AsyncWriter) -> None: + application = await async_client.applications.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=False, + ) + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + @parametrize + async def test_raw_response_generate_content_overload_1(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.with_raw_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + application = await response.parse() + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + @parametrize + async def test_streaming_response_generate_content_overload_1(self, async_client: AsyncWriter) -> None: + async with async_client.applications.with_streaming_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + application = await response.parse() + assert_matches_type(ApplicationGenerateContentResponse, application, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_generate_content_overload_1(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.with_raw_response.generate_content( + application_id="", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + ) + + @parametrize + async def test_method_generate_content_overload_2(self, async_client: AsyncWriter) -> None: + application_stream = await async_client.applications.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) + await application_stream.response.aclose() + + @parametrize + async def test_raw_response_generate_content_overload_2(self, async_client: AsyncWriter) -> None: + response = await async_client.applications.with_raw_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_generate_content_overload_2(self, async_client: AsyncWriter) -> None: + async with async_client.applications.with_streaming_response.generate_content( + application_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_generate_content_overload_2(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `application_id` but received ''"): + await async_client.applications.with_raw_response.generate_content( + application_id="", + inputs=[ + { + "id": "id", + "value": ["string"], + } + ], + stream=True, + ) diff --git a/tests/api_resources/test_chat.py b/tests/api_resources/test_chat.py new file mode 100644 index 00000000..ab61287b --- /dev/null +++ b/tests/api_resources/test_chat.py @@ -0,0 +1,524 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import ChatCompletion + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestChat: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_chat_overload_1(self, client: Writer) -> None: + chat = client.chat.chat( + messages=[{"role": "user"}], + model="model", + ) + assert_matches_type(ChatCompletion, chat, path=["response"]) + + @parametrize + def test_method_chat_with_all_params_overload_1(self, client: Writer) -> None: + chat = client.chat.chat( + messages=[ + { + "role": "user", + "content": "string", + "graph_data": { + "references": { + "files": [ + { + "file_id": "fileId", + "score": 0, + "text": "text", + "cite": "cite", + "page": 0, + } + ], + "web": [ + { + "score": 0, + "text": "text", + "title": "title", + "url": "https://example.com", + } + ], + }, + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + "status": "processing", + "subqueries": [ + { + "answer": "answer", + "query": "query", + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + } + ], + }, + "name": "name", + "refusal": "refusal", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "name", + }, + "type": "function", + "index": 0, + } + ], + } + ], + model="model", + logprobs=True, + max_tokens=0, + n=0, + response_format={ + "type": "text", + "json_schema": {}, + }, + stop=["string"], + stream=False, + stream_options={"include_usage": True}, + temperature=0, + tool_choice={"value": "none"}, + tools=[ + { + "function": { + "name": "name", + "description": "description", + "parameters": {"foo": "bar"}, + }, + "type": "function", + } + ], + top_p=0, + ) + assert_matches_type(ChatCompletion, chat, path=["response"]) + + @parametrize + def test_raw_response_chat_overload_1(self, client: Writer) -> None: + response = client.chat.with_raw_response.chat( + messages=[{"role": "user"}], + model="model", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + chat = response.parse() + assert_matches_type(ChatCompletion, chat, path=["response"]) + + @parametrize + def test_streaming_response_chat_overload_1(self, client: Writer) -> None: + with client.chat.with_streaming_response.chat( + messages=[{"role": "user"}], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + chat = response.parse() + assert_matches_type(ChatCompletion, chat, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_chat_overload_2(self, client: Writer) -> None: + chat_stream = client.chat.chat( + messages=[{"role": "user"}], + model="model", + stream=True, + ) + chat_stream.response.close() + + @parametrize + def test_method_chat_with_all_params_overload_2(self, client: Writer) -> None: + chat_stream = client.chat.chat( + messages=[ + { + "role": "user", + "content": "string", + "graph_data": { + "references": { + "files": [ + { + "file_id": "fileId", + "score": 0, + "text": "text", + "cite": "cite", + "page": 0, + } + ], + "web": [ + { + "score": 0, + "text": "text", + "title": "title", + "url": "https://example.com", + } + ], + }, + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + "status": "processing", + "subqueries": [ + { + "answer": "answer", + "query": "query", + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + } + ], + }, + "name": "name", + "refusal": "refusal", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "name", + }, + "type": "function", + "index": 0, + } + ], + } + ], + model="model", + stream=True, + logprobs=True, + max_tokens=0, + n=0, + response_format={ + "type": "text", + "json_schema": {}, + }, + stop=["string"], + stream_options={"include_usage": True}, + temperature=0, + tool_choice={"value": "none"}, + tools=[ + { + "function": { + "name": "name", + "description": "description", + "parameters": {"foo": "bar"}, + }, + "type": "function", + } + ], + top_p=0, + ) + chat_stream.response.close() + + @parametrize + def test_raw_response_chat_overload_2(self, client: Writer) -> None: + response = client.chat.with_raw_response.chat( + messages=[{"role": "user"}], + model="model", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_chat_overload_2(self, client: Writer) -> None: + with client.chat.with_streaming_response.chat( + messages=[{"role": "user"}], + model="model", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + +class TestAsyncChat: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_chat_overload_1(self, async_client: AsyncWriter) -> None: + chat = await async_client.chat.chat( + messages=[{"role": "user"}], + model="model", + ) + assert_matches_type(ChatCompletion, chat, path=["response"]) + + @parametrize + async def test_method_chat_with_all_params_overload_1(self, async_client: AsyncWriter) -> None: + chat = await async_client.chat.chat( + messages=[ + { + "role": "user", + "content": "string", + "graph_data": { + "references": { + "files": [ + { + "file_id": "fileId", + "score": 0, + "text": "text", + "cite": "cite", + "page": 0, + } + ], + "web": [ + { + "score": 0, + "text": "text", + "title": "title", + "url": "https://example.com", + } + ], + }, + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + "status": "processing", + "subqueries": [ + { + "answer": "answer", + "query": "query", + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + } + ], + }, + "name": "name", + "refusal": "refusal", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "name", + }, + "type": "function", + "index": 0, + } + ], + } + ], + model="model", + logprobs=True, + max_tokens=0, + n=0, + response_format={ + "type": "text", + "json_schema": {}, + }, + stop=["string"], + stream=False, + stream_options={"include_usage": True}, + temperature=0, + tool_choice={"value": "none"}, + tools=[ + { + "function": { + "name": "name", + "description": "description", + "parameters": {"foo": "bar"}, + }, + "type": "function", + } + ], + top_p=0, + ) + assert_matches_type(ChatCompletion, chat, path=["response"]) + + @parametrize + async def test_raw_response_chat_overload_1(self, async_client: AsyncWriter) -> None: + response = await async_client.chat.with_raw_response.chat( + messages=[{"role": "user"}], + model="model", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + chat = await response.parse() + assert_matches_type(ChatCompletion, chat, path=["response"]) + + @parametrize + async def test_streaming_response_chat_overload_1(self, async_client: AsyncWriter) -> None: + async with async_client.chat.with_streaming_response.chat( + messages=[{"role": "user"}], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + chat = await response.parse() + assert_matches_type(ChatCompletion, chat, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_chat_overload_2(self, async_client: AsyncWriter) -> None: + chat_stream = await async_client.chat.chat( + messages=[{"role": "user"}], + model="model", + stream=True, + ) + await chat_stream.response.aclose() + + @parametrize + async def test_method_chat_with_all_params_overload_2(self, async_client: AsyncWriter) -> None: + chat_stream = await async_client.chat.chat( + messages=[ + { + "role": "user", + "content": "string", + "graph_data": { + "references": { + "files": [ + { + "file_id": "fileId", + "score": 0, + "text": "text", + "cite": "cite", + "page": 0, + } + ], + "web": [ + { + "score": 0, + "text": "text", + "title": "title", + "url": "https://example.com", + } + ], + }, + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + "status": "processing", + "subqueries": [ + { + "answer": "answer", + "query": "query", + "sources": [ + { + "file_id": "file_id", + "snippet": "snippet", + } + ], + } + ], + }, + "name": "name", + "refusal": "refusal", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "name", + }, + "type": "function", + "index": 0, + } + ], + } + ], + model="model", + stream=True, + logprobs=True, + max_tokens=0, + n=0, + response_format={ + "type": "text", + "json_schema": {}, + }, + stop=["string"], + stream_options={"include_usage": True}, + temperature=0, + tool_choice={"value": "none"}, + tools=[ + { + "function": { + "name": "name", + "description": "description", + "parameters": {"foo": "bar"}, + }, + "type": "function", + } + ], + top_p=0, + ) + await chat_stream.response.aclose() + + @parametrize + async def test_raw_response_chat_overload_2(self, async_client: AsyncWriter) -> None: + response = await async_client.chat.with_raw_response.chat( + messages=[{"role": "user"}], + model="model", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_chat_overload_2(self, async_client: AsyncWriter) -> None: + async with async_client.chat.with_streaming_response.chat( + messages=[{"role": "user"}], + model="model", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_completions.py b/tests/api_resources/test_completions.py new file mode 100644 index 00000000..246bdfc9 --- /dev/null +++ b/tests/api_resources/test_completions.py @@ -0,0 +1,224 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import Completion + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestCompletions: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_create_overload_1(self, client: Writer) -> None: + completion = client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + ) + assert_matches_type(Completion, completion, path=["response"]) + + @parametrize + def test_method_create_with_all_params_overload_1(self, client: Writer) -> None: + completion = client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + best_of=1, + max_tokens=150, + random_seed=42, + stop=["."], + stream=False, + temperature=0.7, + top_p=0.9, + ) + assert_matches_type(Completion, completion, path=["response"]) + + @parametrize + def test_raw_response_create_overload_1(self, client: Writer) -> None: + response = client.completions.with_raw_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + completion = response.parse() + assert_matches_type(Completion, completion, path=["response"]) + + @parametrize + def test_streaming_response_create_overload_1(self, client: Writer) -> None: + with client.completions.with_streaming_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + completion = response.parse() + assert_matches_type(Completion, completion, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_create_overload_2(self, client: Writer) -> None: + completion_stream = client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + ) + completion_stream.response.close() + + @parametrize + def test_method_create_with_all_params_overload_2(self, client: Writer) -> None: + completion_stream = client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + best_of=1, + max_tokens=150, + random_seed=42, + stop=["."], + temperature=0.7, + top_p=0.9, + ) + completion_stream.response.close() + + @parametrize + def test_raw_response_create_overload_2(self, client: Writer) -> None: + response = client.completions.with_raw_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_create_overload_2(self, client: Writer) -> None: + with client.completions.with_streaming_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + +class TestAsyncCompletions: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_create_overload_1(self, async_client: AsyncWriter) -> None: + completion = await async_client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + ) + assert_matches_type(Completion, completion, path=["response"]) + + @parametrize + async def test_method_create_with_all_params_overload_1(self, async_client: AsyncWriter) -> None: + completion = await async_client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + best_of=1, + max_tokens=150, + random_seed=42, + stop=["."], + stream=False, + temperature=0.7, + top_p=0.9, + ) + assert_matches_type(Completion, completion, path=["response"]) + + @parametrize + async def test_raw_response_create_overload_1(self, async_client: AsyncWriter) -> None: + response = await async_client.completions.with_raw_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + completion = await response.parse() + assert_matches_type(Completion, completion, path=["response"]) + + @parametrize + async def test_streaming_response_create_overload_1(self, async_client: AsyncWriter) -> None: + async with async_client.completions.with_streaming_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + completion = await response.parse() + assert_matches_type(Completion, completion, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_create_overload_2(self, async_client: AsyncWriter) -> None: + completion_stream = await async_client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + ) + await completion_stream.response.aclose() + + @parametrize + async def test_method_create_with_all_params_overload_2(self, async_client: AsyncWriter) -> None: + completion_stream = await async_client.completions.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + best_of=1, + max_tokens=150, + random_seed=42, + stop=["."], + temperature=0.7, + top_p=0.9, + ) + await completion_stream.response.aclose() + + @parametrize + async def test_raw_response_create_overload_2(self, async_client: AsyncWriter) -> None: + response = await async_client.completions.with_raw_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_create_overload_2(self, async_client: AsyncWriter) -> None: + async with async_client.completions.with_streaming_response.create( + model="palmyra-x-003-instruct", + prompt="Write me an SEO article about...", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_files.py b/tests/api_resources/test_files.py new file mode 100644 index 00000000..52532517 --- /dev/null +++ b/tests/api_resources/test_files.py @@ -0,0 +1,529 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import httpx +import pytest +from respx import MockRouter + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import ( + File, + FileRetryResponse, + FileDeleteResponse, +) +from writerai._response import ( + BinaryAPIResponse, + AsyncBinaryAPIResponse, + StreamedBinaryAPIResponse, + AsyncStreamedBinaryAPIResponse, +) +from writerai.pagination import SyncCursorPage, AsyncCursorPage + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestFiles: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_retrieve(self, client: Writer) -> None: + file = client.files.retrieve( + "file_id", + ) + assert_matches_type(File, file, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Writer) -> None: + response = client.files.with_raw_response.retrieve( + "file_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = response.parse() + assert_matches_type(File, file, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Writer) -> None: + with client.files.with_streaming_response.retrieve( + "file_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = response.parse() + assert_matches_type(File, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + client.files.with_raw_response.retrieve( + "", + ) + + @parametrize + def test_method_list(self, client: Writer) -> None: + file = client.files.list() + assert_matches_type(SyncCursorPage[File], file, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: Writer) -> None: + file = client.files.list( + after="after", + before="before", + file_types="file_types", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + limit=0, + order="asc", + status="in_progress", + ) + assert_matches_type(SyncCursorPage[File], file, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Writer) -> None: + response = client.files.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = response.parse() + assert_matches_type(SyncCursorPage[File], file, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Writer) -> None: + with client.files.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = response.parse() + assert_matches_type(SyncCursorPage[File], file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_delete(self, client: Writer) -> None: + file = client.files.delete( + "file_id", + ) + assert_matches_type(FileDeleteResponse, file, path=["response"]) + + @parametrize + def test_raw_response_delete(self, client: Writer) -> None: + response = client.files.with_raw_response.delete( + "file_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = response.parse() + assert_matches_type(FileDeleteResponse, file, path=["response"]) + + @parametrize + def test_streaming_response_delete(self, client: Writer) -> None: + with client.files.with_streaming_response.delete( + "file_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = response.parse() + assert_matches_type(FileDeleteResponse, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_delete(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + client.files.with_raw_response.delete( + "", + ) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + def test_method_download(self, client: Writer, respx_mock: MockRouter) -> None: + respx_mock.get("/v1/files/file_id/download").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + file = client.files.download( + "file_id", + ) + assert file.is_closed + assert file.json() == {"foo": "bar"} + assert cast(Any, file.is_closed) is True + assert isinstance(file, BinaryAPIResponse) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + def test_raw_response_download(self, client: Writer, respx_mock: MockRouter) -> None: + respx_mock.get("/v1/files/file_id/download").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + file = client.files.with_raw_response.download( + "file_id", + ) + + assert file.is_closed is True + assert file.http_request.headers.get("X-Stainless-Lang") == "python" + assert file.json() == {"foo": "bar"} + assert isinstance(file, BinaryAPIResponse) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + def test_streaming_response_download(self, client: Writer, respx_mock: MockRouter) -> None: + respx_mock.get("/v1/files/file_id/download").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + with client.files.with_streaming_response.download( + "file_id", + ) as file: + assert not file.is_closed + assert file.http_request.headers.get("X-Stainless-Lang") == "python" + + assert file.json() == {"foo": "bar"} + assert cast(Any, file.is_closed) is True + assert isinstance(file, StreamedBinaryAPIResponse) + + assert cast(Any, file.is_closed) is True + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + def test_path_params_download(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + client.files.with_raw_response.download( + "", + ) + + @parametrize + def test_method_retry(self, client: Writer) -> None: + file = client.files.retry( + file_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + assert_matches_type(FileRetryResponse, file, path=["response"]) + + @parametrize + def test_raw_response_retry(self, client: Writer) -> None: + response = client.files.with_raw_response.retry( + file_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = response.parse() + assert_matches_type(FileRetryResponse, file, path=["response"]) + + @parametrize + def test_streaming_response_retry(self, client: Writer) -> None: + with client.files.with_streaming_response.retry( + file_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = response.parse() + assert_matches_type(FileRetryResponse, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + def test_method_upload(self, client: Writer) -> None: + file = client.files.upload( + content=b"Example data", + content_disposition="Content-Disposition", + ) + assert_matches_type(File, file, path=["response"]) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + def test_method_upload_with_all_params(self, client: Writer) -> None: + file = client.files.upload( + content=b"Example data", + content_disposition="Content-Disposition", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(File, file, path=["response"]) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + def test_raw_response_upload(self, client: Writer) -> None: + response = client.files.with_raw_response.upload( + content=b"Example data", + content_disposition="Content-Disposition", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = response.parse() + assert_matches_type(File, file, path=["response"]) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + def test_streaming_response_upload(self, client: Writer) -> None: + with client.files.with_streaming_response.upload( + content=b"Example data", + content_disposition="Content-Disposition", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = response.parse() + assert_matches_type(File, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncFiles: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_retrieve(self, async_client: AsyncWriter) -> None: + file = await async_client.files.retrieve( + "file_id", + ) + assert_matches_type(File, file, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncWriter) -> None: + response = await async_client.files.with_raw_response.retrieve( + "file_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = await response.parse() + assert_matches_type(File, file, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncWriter) -> None: + async with async_client.files.with_streaming_response.retrieve( + "file_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = await response.parse() + assert_matches_type(File, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + await async_client.files.with_raw_response.retrieve( + "", + ) + + @parametrize + async def test_method_list(self, async_client: AsyncWriter) -> None: + file = await async_client.files.list() + assert_matches_type(AsyncCursorPage[File], file, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncWriter) -> None: + file = await async_client.files.list( + after="after", + before="before", + file_types="file_types", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + limit=0, + order="asc", + status="in_progress", + ) + assert_matches_type(AsyncCursorPage[File], file, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncWriter) -> None: + response = await async_client.files.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = await response.parse() + assert_matches_type(AsyncCursorPage[File], file, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncWriter) -> None: + async with async_client.files.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = await response.parse() + assert_matches_type(AsyncCursorPage[File], file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_delete(self, async_client: AsyncWriter) -> None: + file = await async_client.files.delete( + "file_id", + ) + assert_matches_type(FileDeleteResponse, file, path=["response"]) + + @parametrize + async def test_raw_response_delete(self, async_client: AsyncWriter) -> None: + response = await async_client.files.with_raw_response.delete( + "file_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = await response.parse() + assert_matches_type(FileDeleteResponse, file, path=["response"]) + + @parametrize + async def test_streaming_response_delete(self, async_client: AsyncWriter) -> None: + async with async_client.files.with_streaming_response.delete( + "file_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = await response.parse() + assert_matches_type(FileDeleteResponse, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_delete(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + await async_client.files.with_raw_response.delete( + "", + ) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + async def test_method_download(self, async_client: AsyncWriter, respx_mock: MockRouter) -> None: + respx_mock.get("/v1/files/file_id/download").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + file = await async_client.files.download( + "file_id", + ) + assert file.is_closed + assert await file.json() == {"foo": "bar"} + assert cast(Any, file.is_closed) is True + assert isinstance(file, AsyncBinaryAPIResponse) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + async def test_raw_response_download(self, async_client: AsyncWriter, respx_mock: MockRouter) -> None: + respx_mock.get("/v1/files/file_id/download").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + file = await async_client.files.with_raw_response.download( + "file_id", + ) + + assert file.is_closed is True + assert file.http_request.headers.get("X-Stainless-Lang") == "python" + assert await file.json() == {"foo": "bar"} + assert isinstance(file, AsyncBinaryAPIResponse) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + async def test_streaming_response_download(self, async_client: AsyncWriter, respx_mock: MockRouter) -> None: + respx_mock.get("/v1/files/file_id/download").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + async with async_client.files.with_streaming_response.download( + "file_id", + ) as file: + assert not file.is_closed + assert file.http_request.headers.get("X-Stainless-Lang") == "python" + + assert await file.json() == {"foo": "bar"} + assert cast(Any, file.is_closed) is True + assert isinstance(file, AsyncStreamedBinaryAPIResponse) + + assert cast(Any, file.is_closed) is True + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + @pytest.mark.respx(base_url=base_url) + async def test_path_params_download(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + await async_client.files.with_raw_response.download( + "", + ) + + @parametrize + async def test_method_retry(self, async_client: AsyncWriter) -> None: + file = await async_client.files.retry( + file_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + assert_matches_type(FileRetryResponse, file, path=["response"]) + + @parametrize + async def test_raw_response_retry(self, async_client: AsyncWriter) -> None: + response = await async_client.files.with_raw_response.retry( + file_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = await response.parse() + assert_matches_type(FileRetryResponse, file, path=["response"]) + + @parametrize + async def test_streaming_response_retry(self, async_client: AsyncWriter) -> None: + async with async_client.files.with_streaming_response.retry( + file_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = await response.parse() + assert_matches_type(FileRetryResponse, file, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + async def test_method_upload(self, async_client: AsyncWriter) -> None: + file = await async_client.files.upload( + content=b"Example data", + content_disposition="Content-Disposition", + ) + assert_matches_type(File, file, path=["response"]) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + async def test_method_upload_with_all_params(self, async_client: AsyncWriter) -> None: + file = await async_client.files.upload( + content=b"Example data", + content_disposition="Content-Disposition", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(File, file, path=["response"]) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + async def test_raw_response_upload(self, async_client: AsyncWriter) -> None: + response = await async_client.files.with_raw_response.upload( + content=b"Example data", + content_disposition="Content-Disposition", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + file = await response.parse() + assert_matches_type(File, file, path=["response"]) + + @pytest.mark.skip(reason="requests with binary data not yet supported in test environment") + @parametrize + async def test_streaming_response_upload(self, async_client: AsyncWriter) -> None: + async with async_client.files.with_streaming_response.upload( + content=b"Example data", + content_disposition="Content-Disposition", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + file = await response.parse() + assert_matches_type(File, file, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_graphs.py b/tests/api_resources/test_graphs.py new file mode 100644 index 00000000..d4859225 --- /dev/null +++ b/tests/api_resources/test_graphs.py @@ -0,0 +1,829 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import ( + File, + Graph, + Question, + GraphCreateResponse, + GraphDeleteResponse, + GraphUpdateResponse, + GraphRemoveFileFromGraphResponse, +) +from writerai.pagination import SyncCursorPage, AsyncCursorPage + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestGraphs: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_create(self, client: Writer) -> None: + graph = client.graphs.create() + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + @parametrize + def test_method_create_with_all_params(self, client: Writer) -> None: + graph = client.graphs.create( + description="description", + name="name", + ) + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + @parametrize + def test_raw_response_create(self, client: Writer) -> None: + response = client.graphs.with_raw_response.create() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + @parametrize + def test_streaming_response_create(self, client: Writer) -> None: + with client.graphs.with_streaming_response.create() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_retrieve(self, client: Writer) -> None: + graph = client.graphs.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(Graph, graph, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Writer) -> None: + response = client.graphs.with_raw_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(Graph, graph, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Writer) -> None: + with client.graphs.with_streaming_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(Graph, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + client.graphs.with_raw_response.retrieve( + "", + ) + + @parametrize + def test_method_update(self, client: Writer) -> None: + graph = client.graphs.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + @parametrize + def test_method_update_with_all_params(self, client: Writer) -> None: + graph = client.graphs.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + description="description", + name="name", + urls=[ + { + "type": "single_page", + "url": "url", + "exclude_urls": ["string"], + } + ], + ) + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + @parametrize + def test_raw_response_update(self, client: Writer) -> None: + response = client.graphs.with_raw_response.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + @parametrize + def test_streaming_response_update(self, client: Writer) -> None: + with client.graphs.with_streaming_response.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_update(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + client.graphs.with_raw_response.update( + graph_id="", + ) + + @parametrize + def test_method_list(self, client: Writer) -> None: + graph = client.graphs.list() + assert_matches_type(SyncCursorPage[Graph], graph, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: Writer) -> None: + graph = client.graphs.list( + after="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + before="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + limit=0, + order="asc", + ) + assert_matches_type(SyncCursorPage[Graph], graph, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Writer) -> None: + response = client.graphs.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(SyncCursorPage[Graph], graph, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Writer) -> None: + with client.graphs.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(SyncCursorPage[Graph], graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_delete(self, client: Writer) -> None: + graph = client.graphs.delete( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(GraphDeleteResponse, graph, path=["response"]) + + @parametrize + def test_raw_response_delete(self, client: Writer) -> None: + response = client.graphs.with_raw_response.delete( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(GraphDeleteResponse, graph, path=["response"]) + + @parametrize + def test_streaming_response_delete(self, client: Writer) -> None: + with client.graphs.with_streaming_response.delete( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(GraphDeleteResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_delete(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + client.graphs.with_raw_response.delete( + "", + ) + + @parametrize + def test_method_add_file_to_graph(self, client: Writer) -> None: + graph = client.graphs.add_file_to_graph( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + file_id="file_id", + ) + assert_matches_type(File, graph, path=["response"]) + + @parametrize + def test_raw_response_add_file_to_graph(self, client: Writer) -> None: + response = client.graphs.with_raw_response.add_file_to_graph( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + file_id="file_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(File, graph, path=["response"]) + + @parametrize + def test_streaming_response_add_file_to_graph(self, client: Writer) -> None: + with client.graphs.with_streaming_response.add_file_to_graph( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + file_id="file_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(File, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_add_file_to_graph(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + client.graphs.with_raw_response.add_file_to_graph( + graph_id="", + file_id="file_id", + ) + + @parametrize + def test_method_question_overload_1(self, client: Writer) -> None: + graph = client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + ) + assert_matches_type(Question, graph, path=["response"]) + + @parametrize + def test_method_question_with_all_params_overload_1(self, client: Writer) -> None: + graph = client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + query_config={ + "grounding_level": 0, + "inline_citations": True, + "keyword_threshold": 0, + "max_snippets": 1, + "max_subquestions": 1, + "max_tokens": 100, + "search_weight": 0, + "semantic_threshold": 0, + }, + stream=False, + subqueries=True, + ) + assert_matches_type(Question, graph, path=["response"]) + + @parametrize + def test_raw_response_question_overload_1(self, client: Writer) -> None: + response = client.graphs.with_raw_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(Question, graph, path=["response"]) + + @parametrize + def test_streaming_response_question_overload_1(self, client: Writer) -> None: + with client.graphs.with_streaming_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(Question, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_question_overload_2(self, client: Writer) -> None: + graph_stream = client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + ) + graph_stream.response.close() + + @parametrize + def test_method_question_with_all_params_overload_2(self, client: Writer) -> None: + graph_stream = client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + query_config={ + "grounding_level": 0, + "inline_citations": True, + "keyword_threshold": 0, + "max_snippets": 1, + "max_subquestions": 1, + "max_tokens": 100, + "search_weight": 0, + "semantic_threshold": 0, + }, + subqueries=True, + ) + graph_stream.response.close() + + @parametrize + def test_raw_response_question_overload_2(self, client: Writer) -> None: + response = client.graphs.with_raw_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_question_overload_2(self, client: Writer) -> None: + with client.graphs.with_streaming_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_remove_file_from_graph(self, client: Writer) -> None: + graph = client.graphs.remove_file_from_graph( + file_id="file_id", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(GraphRemoveFileFromGraphResponse, graph, path=["response"]) + + @parametrize + def test_raw_response_remove_file_from_graph(self, client: Writer) -> None: + response = client.graphs.with_raw_response.remove_file_from_graph( + file_id="file_id", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = response.parse() + assert_matches_type(GraphRemoveFileFromGraphResponse, graph, path=["response"]) + + @parametrize + def test_streaming_response_remove_file_from_graph(self, client: Writer) -> None: + with client.graphs.with_streaming_response.remove_file_from_graph( + file_id="file_id", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = response.parse() + assert_matches_type(GraphRemoveFileFromGraphResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_remove_file_from_graph(self, client: Writer) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + client.graphs.with_raw_response.remove_file_from_graph( + file_id="file_id", + graph_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + client.graphs.with_raw_response.remove_file_from_graph( + file_id="", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + +class TestAsyncGraphs: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_create(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.create() + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.create( + description="description", + name="name", + ) + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + @parametrize + async def test_raw_response_create(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.create() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + @parametrize + async def test_streaming_response_create(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.create() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(GraphCreateResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_retrieve(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(Graph, graph, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(Graph, graph, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.retrieve( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(Graph, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + await async_client.graphs.with_raw_response.retrieve( + "", + ) + + @parametrize + async def test_method_update(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + @parametrize + async def test_method_update_with_all_params(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + description="description", + name="name", + urls=[ + { + "type": "single_page", + "url": "url", + "exclude_urls": ["string"], + } + ], + ) + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + @parametrize + async def test_raw_response_update(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + @parametrize + async def test_streaming_response_update(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.update( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(GraphUpdateResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_update(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + await async_client.graphs.with_raw_response.update( + graph_id="", + ) + + @parametrize + async def test_method_list(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.list() + assert_matches_type(AsyncCursorPage[Graph], graph, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.list( + after="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + before="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + limit=0, + order="asc", + ) + assert_matches_type(AsyncCursorPage[Graph], graph, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(AsyncCursorPage[Graph], graph, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(AsyncCursorPage[Graph], graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_delete(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.delete( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(GraphDeleteResponse, graph, path=["response"]) + + @parametrize + async def test_raw_response_delete(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.delete( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(GraphDeleteResponse, graph, path=["response"]) + + @parametrize + async def test_streaming_response_delete(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.delete( + "182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(GraphDeleteResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_delete(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + await async_client.graphs.with_raw_response.delete( + "", + ) + + @parametrize + async def test_method_add_file_to_graph(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.add_file_to_graph( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + file_id="file_id", + ) + assert_matches_type(File, graph, path=["response"]) + + @parametrize + async def test_raw_response_add_file_to_graph(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.add_file_to_graph( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + file_id="file_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(File, graph, path=["response"]) + + @parametrize + async def test_streaming_response_add_file_to_graph(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.add_file_to_graph( + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + file_id="file_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(File, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_add_file_to_graph(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + await async_client.graphs.with_raw_response.add_file_to_graph( + graph_id="", + file_id="file_id", + ) + + @parametrize + async def test_method_question_overload_1(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + ) + assert_matches_type(Question, graph, path=["response"]) + + @parametrize + async def test_method_question_with_all_params_overload_1(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + query_config={ + "grounding_level": 0, + "inline_citations": True, + "keyword_threshold": 0, + "max_snippets": 1, + "max_subquestions": 1, + "max_tokens": 100, + "search_weight": 0, + "semantic_threshold": 0, + }, + stream=False, + subqueries=True, + ) + assert_matches_type(Question, graph, path=["response"]) + + @parametrize + async def test_raw_response_question_overload_1(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(Question, graph, path=["response"]) + + @parametrize + async def test_streaming_response_question_overload_1(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(Question, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_question_overload_2(self, async_client: AsyncWriter) -> None: + graph_stream = await async_client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + ) + await graph_stream.response.aclose() + + @parametrize + async def test_method_question_with_all_params_overload_2(self, async_client: AsyncWriter) -> None: + graph_stream = await async_client.graphs.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + query_config={ + "grounding_level": 0, + "inline_citations": True, + "keyword_threshold": 0, + "max_snippets": 1, + "max_subquestions": 1, + "max_tokens": 100, + "search_weight": 0, + "semantic_threshold": 0, + }, + subqueries=True, + ) + await graph_stream.response.aclose() + + @parametrize + async def test_raw_response_question_overload_2(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_question_overload_2(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.question( + graph_ids=["182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e"], + question="question", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_remove_file_from_graph(self, async_client: AsyncWriter) -> None: + graph = await async_client.graphs.remove_file_from_graph( + file_id="file_id", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + assert_matches_type(GraphRemoveFileFromGraphResponse, graph, path=["response"]) + + @parametrize + async def test_raw_response_remove_file_from_graph(self, async_client: AsyncWriter) -> None: + response = await async_client.graphs.with_raw_response.remove_file_from_graph( + file_id="file_id", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + graph = await response.parse() + assert_matches_type(GraphRemoveFileFromGraphResponse, graph, path=["response"]) + + @parametrize + async def test_streaming_response_remove_file_from_graph(self, async_client: AsyncWriter) -> None: + async with async_client.graphs.with_streaming_response.remove_file_from_graph( + file_id="file_id", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + graph = await response.parse() + assert_matches_type(GraphRemoveFileFromGraphResponse, graph, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_remove_file_from_graph(self, async_client: AsyncWriter) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `graph_id` but received ''"): + await async_client.graphs.with_raw_response.remove_file_from_graph( + file_id="file_id", + graph_id="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + await async_client.graphs.with_raw_response.remove_file_from_graph( + file_id="", + graph_id="182bd5e5-6e1a-4fe4-a799-aa6d9a6ab26e", + ) diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py new file mode 100644 index 00000000..3ad1f08f --- /dev/null +++ b/tests/api_resources/test_models.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import ModelListResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestModels: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_list(self, client: Writer) -> None: + model = client.models.list() + assert_matches_type(ModelListResponse, model, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: Writer) -> None: + response = client.models.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert_matches_type(ModelListResponse, model, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: Writer) -> None: + with client.models.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = response.parse() + assert_matches_type(ModelListResponse, model, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncModels: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_list(self, async_client: AsyncWriter) -> None: + model = await async_client.models.list() + assert_matches_type(ModelListResponse, model, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncWriter) -> None: + response = await async_client.models.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert_matches_type(ModelListResponse, model, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncWriter) -> None: + async with async_client.models.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = await response.parse() + assert_matches_type(ModelListResponse, model, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_tools.py b/tests/api_resources/test_tools.py new file mode 100644 index 00000000..8e0bd7a7 --- /dev/null +++ b/tests/api_resources/test_tools.py @@ -0,0 +1,220 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import ToolParsePdfResponse, ToolWebSearchResponse + +# pyright: reportDeprecated=false + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestTools: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_parse_pdf(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + tool = client.tools.parse_pdf( + file_id="file_id", + format="text", + ) + + assert_matches_type(ToolParsePdfResponse, tool, path=["response"]) + + @parametrize + def test_raw_response_parse_pdf(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + response = client.tools.with_raw_response.parse_pdf( + file_id="file_id", + format="text", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + tool = response.parse() + assert_matches_type(ToolParsePdfResponse, tool, path=["response"]) + + @parametrize + def test_streaming_response_parse_pdf(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + with client.tools.with_streaming_response.parse_pdf( + file_id="file_id", + format="text", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + tool = response.parse() + assert_matches_type(ToolParsePdfResponse, tool, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_parse_pdf(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + client.tools.with_raw_response.parse_pdf( + file_id="", + format="text", + ) + + @parametrize + def test_method_web_search(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + tool = client.tools.web_search() + + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + @parametrize + def test_method_web_search_with_all_params(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + tool = client.tools.web_search( + chunks_per_source=0, + country="afghanistan", + days=0, + exclude_domains=["string"], + include_answer=True, + include_domains=["dev.writer.com"], + include_raw_content="text", + max_results=0, + query="How do I get an API key for the Writer API?", + search_depth="basic", + stream=True, + time_range="day", + topic="general", + ) + + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + @parametrize + def test_raw_response_web_search(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + response = client.tools.with_raw_response.web_search() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + tool = response.parse() + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + @parametrize + def test_streaming_response_web_search(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + with client.tools.with_streaming_response.web_search() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + tool = response.parse() + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncTools: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_parse_pdf(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + tool = await async_client.tools.parse_pdf( + file_id="file_id", + format="text", + ) + + assert_matches_type(ToolParsePdfResponse, tool, path=["response"]) + + @parametrize + async def test_raw_response_parse_pdf(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + response = await async_client.tools.with_raw_response.parse_pdf( + file_id="file_id", + format="text", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + tool = await response.parse() + assert_matches_type(ToolParsePdfResponse, tool, path=["response"]) + + @parametrize + async def test_streaming_response_parse_pdf(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + async with async_client.tools.with_streaming_response.parse_pdf( + file_id="file_id", + format="text", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + tool = await response.parse() + assert_matches_type(ToolParsePdfResponse, tool, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_parse_pdf(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError, match=r"Expected a non-empty value for `file_id` but received ''"): + await async_client.tools.with_raw_response.parse_pdf( + file_id="", + format="text", + ) + + @parametrize + async def test_method_web_search(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + tool = await async_client.tools.web_search() + + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + @parametrize + async def test_method_web_search_with_all_params(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + tool = await async_client.tools.web_search( + chunks_per_source=0, + country="afghanistan", + days=0, + exclude_domains=["string"], + include_answer=True, + include_domains=["dev.writer.com"], + include_raw_content="text", + max_results=0, + query="How do I get an API key for the Writer API?", + search_depth="basic", + stream=True, + time_range="day", + topic="general", + ) + + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + @parametrize + async def test_raw_response_web_search(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + response = await async_client.tools.with_raw_response.web_search() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + tool = await response.parse() + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + @parametrize + async def test_streaming_response_web_search(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + async with async_client.tools.with_streaming_response.web_search() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + tool = await response.parse() + assert_matches_type(ToolWebSearchResponse, tool, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_translation.py b/tests/api_resources/test_translation.py new file mode 100644 index 00000000..12497e7a --- /dev/null +++ b/tests/api_resources/test_translation.py @@ -0,0 +1,132 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import TranslationResponse + +# pyright: reportDeprecated=false + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestTranslation: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_translate(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + translation = client.translation.translate( + formality=True, + length_control=True, + mask_profanity=True, + model="palmyra-translate", + source_language_code="en", + target_language_code="es", + text="Hello, world!", + ) + + assert_matches_type(TranslationResponse, translation, path=["response"]) + + @parametrize + def test_raw_response_translate(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + response = client.translation.with_raw_response.translate( + formality=True, + length_control=True, + mask_profanity=True, + model="palmyra-translate", + source_language_code="en", + target_language_code="es", + text="Hello, world!", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + translation = response.parse() + assert_matches_type(TranslationResponse, translation, path=["response"]) + + @parametrize + def test_streaming_response_translate(self, client: Writer) -> None: + with pytest.warns(DeprecationWarning): + with client.translation.with_streaming_response.translate( + formality=True, + length_control=True, + mask_profanity=True, + model="palmyra-translate", + source_language_code="en", + target_language_code="es", + text="Hello, world!", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + translation = response.parse() + assert_matches_type(TranslationResponse, translation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncTranslation: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_translate(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + translation = await async_client.translation.translate( + formality=True, + length_control=True, + mask_profanity=True, + model="palmyra-translate", + source_language_code="en", + target_language_code="es", + text="Hello, world!", + ) + + assert_matches_type(TranslationResponse, translation, path=["response"]) + + @parametrize + async def test_raw_response_translate(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + response = await async_client.translation.with_raw_response.translate( + formality=True, + length_control=True, + mask_profanity=True, + model="palmyra-translate", + source_language_code="en", + target_language_code="es", + text="Hello, world!", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + translation = await response.parse() + assert_matches_type(TranslationResponse, translation, path=["response"]) + + @parametrize + async def test_streaming_response_translate(self, async_client: AsyncWriter) -> None: + with pytest.warns(DeprecationWarning): + async with async_client.translation.with_streaming_response.translate( + formality=True, + length_control=True, + mask_profanity=True, + model="palmyra-translate", + source_language_code="en", + target_language_code="es", + text="Hello, world!", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + translation = await response.parse() + assert_matches_type(TranslationResponse, translation, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_vision.py b/tests/api_resources/test_vision.py new file mode 100644 index 00000000..dfca8433 --- /dev/null +++ b/tests/api_resources/test_vision.py @@ -0,0 +1,152 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from writerai import Writer, AsyncWriter +from tests.utils import assert_matches_type +from writerai.types import VisionResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestVision: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_analyze(self, client: Writer) -> None: + vision = client.vision.analyze( + model="palmyra-vision", + prompt="Describe the difference between the image {{image_1}} and the image {{image_2}}.", + variables=[ + { + "file_id": "f1234", + "name": "image_1", + }, + { + "file_id": "f9876", + "name": "image_2", + }, + ], + ) + assert_matches_type(VisionResponse, vision, path=["response"]) + + @parametrize + def test_raw_response_analyze(self, client: Writer) -> None: + response = client.vision.with_raw_response.analyze( + model="palmyra-vision", + prompt="Describe the difference between the image {{image_1}} and the image {{image_2}}.", + variables=[ + { + "file_id": "f1234", + "name": "image_1", + }, + { + "file_id": "f9876", + "name": "image_2", + }, + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + vision = response.parse() + assert_matches_type(VisionResponse, vision, path=["response"]) + + @parametrize + def test_streaming_response_analyze(self, client: Writer) -> None: + with client.vision.with_streaming_response.analyze( + model="palmyra-vision", + prompt="Describe the difference between the image {{image_1}} and the image {{image_2}}.", + variables=[ + { + "file_id": "f1234", + "name": "image_1", + }, + { + "file_id": "f9876", + "name": "image_2", + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + vision = response.parse() + assert_matches_type(VisionResponse, vision, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncVision: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @parametrize + async def test_method_analyze(self, async_client: AsyncWriter) -> None: + vision = await async_client.vision.analyze( + model="palmyra-vision", + prompt="Describe the difference between the image {{image_1}} and the image {{image_2}}.", + variables=[ + { + "file_id": "f1234", + "name": "image_1", + }, + { + "file_id": "f9876", + "name": "image_2", + }, + ], + ) + assert_matches_type(VisionResponse, vision, path=["response"]) + + @parametrize + async def test_raw_response_analyze(self, async_client: AsyncWriter) -> None: + response = await async_client.vision.with_raw_response.analyze( + model="palmyra-vision", + prompt="Describe the difference between the image {{image_1}} and the image {{image_2}}.", + variables=[ + { + "file_id": "f1234", + "name": "image_1", + }, + { + "file_id": "f9876", + "name": "image_2", + }, + ], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + vision = await response.parse() + assert_matches_type(VisionResponse, vision, path=["response"]) + + @parametrize + async def test_streaming_response_analyze(self, async_client: AsyncWriter) -> None: + async with async_client.vision.with_streaming_response.analyze( + model="palmyra-vision", + prompt="Describe the difference between the image {{image_1}} and the image {{image_2}}.", + variables=[ + { + "file_id": "f1234", + "name": "image_1", + }, + { + "file_id": "f9876", + "name": "image_2", + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + vision = await response.parse() + assert_matches_type(VisionResponse, vision, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ec6f42e1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,84 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +import logging +from typing import TYPE_CHECKING, Iterator, AsyncIterator + +import httpx +import pytest +from pytest_asyncio import is_async_test + +from writerai import Writer, AsyncWriter, DefaultAioHttpClient +from writerai._utils import is_dict + +if TYPE_CHECKING: + from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] + +pytest.register_assert_rewrite("tests.utils") + +logging.getLogger("writerai").setLevel(logging.DEBUG) + + +# automatically add `pytest.mark.asyncio()` to all of our async tests +# so we don't have to add that boilerplate everywhere +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) + + # We skip tests that use both the aiohttp client and respx_mock as respx_mock + # doesn't support custom transports. + for item in items: + if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames: + continue + + if not hasattr(item, "callspec"): + continue + + async_client_param = item.callspec.params.get("async_client") + if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp": + item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock")) + + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + +api_key = "My API Key" + + +@pytest.fixture(scope="session") +def client(request: FixtureRequest) -> Iterator[Writer]: + strict = getattr(request, "param", True) + if not isinstance(strict, bool): + raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") + + with Writer(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: + yield client + + +@pytest.fixture(scope="session") +async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncWriter]: + param = getattr(request, "param", True) + + # defaults + strict = True + http_client: None | httpx.AsyncClient = None + + if isinstance(param, bool): + strict = param + elif is_dict(param): + strict = param.get("strict", True) + assert isinstance(strict, bool) + + http_client_type = param.get("http_client", "httpx") + if http_client_type == "aiohttp": + http_client = DefaultAioHttpClient() + else: + raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict") + + async with AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=strict, http_client=http_client + ) as client: + yield client diff --git a/tests/sample_file.txt b/tests/sample_file.txt new file mode 100644 index 00000000..af5626b4 --- /dev/null +++ b/tests/sample_file.txt @@ -0,0 +1 @@ +Hello, world! diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..2b13f897 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,1993 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import gc +import os +import sys +import json +import asyncio +import inspect +import dataclasses +import tracemalloc +from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Coroutine, cast +from unittest import mock +from typing_extensions import Literal, AsyncIterator, override + +import httpx +import pytest +from respx import MockRouter +from pydantic import ValidationError + +from writerai import Writer, AsyncWriter, APIResponseValidationError +from writerai._types import Omit +from writerai._utils import asyncify +from writerai._models import BaseModel, FinalRequestOptions +from writerai._streaming import Stream, AsyncStream +from writerai._exceptions import WriterError, APIStatusError, APITimeoutError, APIResponseValidationError +from writerai._base_client import ( + DEFAULT_TIMEOUT, + HTTPX_DEFAULT_TIMEOUT, + BaseClient, + OtherPlatform, + DefaultHttpxClient, + DefaultAsyncHttpxClient, + get_platform, + make_request_options, +) + +from .utils import update_env + +T = TypeVar("T") +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +api_key = "My API Key" + + +def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + url = httpx.URL(request.url) + return dict(url.params) + + +def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float: + return 0.1 + + +def mirror_request_content(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=request.content) + + +# note: we can't use the httpx.MockTransport class as it consumes the request +# body itself, which means we can't test that the body is read lazily +class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport): + def __init__( + self, + handler: Callable[[httpx.Request], httpx.Response] + | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]], + ) -> None: + self.handler = handler + + @override + def handle_request( + self, + request: httpx.Request, + ) -> httpx.Response: + assert not inspect.iscoroutinefunction(self.handler), "handler must not be a coroutine function" + assert inspect.isfunction(self.handler), "handler must be a function" + return self.handler(request) + + @override + async def handle_async_request( + self, + request: httpx.Request, + ) -> httpx.Response: + assert inspect.iscoroutinefunction(self.handler), "handler must be a coroutine function" + return await self.handler(request) + + +@dataclasses.dataclass +class Counter: + value: int = 0 + + +def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]: + for item in iterable: + if counter: + counter.value += 1 + yield item + + +async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]: + for item in iterable: + if counter: + counter.value += 1 + yield item + + +def _get_open_connections(client: Writer | AsyncWriter) -> int: + transport = client._client._transport + assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) + + pool = transport._pool + return len(pool._requests) + + +class TestWriter: + @pytest.mark.respx(base_url=base_url) + def test_raw_response(self, respx_mock: MockRouter, client: Writer) -> None: + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + @pytest.mark.respx(base_url=base_url) + def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Writer) -> None: + respx_mock.post("/foo").mock( + return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') + ) + + response = client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + def test_copy(self, client: Writer) -> None: + copied = client.copy() + assert id(copied) != id(client) + + copied = client.copy(api_key="another My API Key") + assert copied.api_key == "another My API Key" + assert client.api_key == "My API Key" + + def test_copy_default_options(self, client: Writer) -> None: + # options that have a default are overridden correctly + copied = client.copy(max_retries=7) + assert copied.max_retries == 7 + assert client.max_retries == 7 + + copied2 = copied.copy(max_retries=6) + assert copied2.max_retries == 6 + assert copied.max_retries == 7 + + # timeout + assert isinstance(client.timeout, httpx.Timeout) + copied = client.copy(timeout=None) + assert copied.timeout is None + assert isinstance(client.timeout, httpx.Timeout) + + def test_copy_default_headers(self) -> None: + client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + assert client.default_headers["X-Foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert copied.default_headers["X-Foo"] == "bar" + + # merges already given headers + copied = client.copy(default_headers={"X-Bar": "stainless"}) + assert copied.default_headers["X-Foo"] == "bar" + assert copied.default_headers["X-Bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_headers={"X-Foo": "stainless"}) + assert copied.default_headers["X-Foo"] == "stainless" + + # set_default_headers + + # completely overrides already set values + copied = client.copy(set_default_headers={}) + assert copied.default_headers.get("X-Foo") is None + + copied = client.copy(set_default_headers={"X-Bar": "Robert"}) + assert copied.default_headers["X-Bar"] == "Robert" + + with pytest.raises( + ValueError, + match="`default_headers` and `set_default_headers` arguments are mutually exclusive", + ): + client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + client.close() + + def test_copy_default_query(self) -> None: + client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} + ) + assert _get_params(client)["foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert _get_params(copied)["foo"] == "bar" + + # merges already given params + copied = client.copy(default_query={"bar": "stainless"}) + params = _get_params(copied) + assert params["foo"] == "bar" + assert params["bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_query={"foo": "stainless"}) + assert _get_params(copied)["foo"] == "stainless" + + # set_default_query + + # completely overrides already set values + copied = client.copy(set_default_query={}) + assert _get_params(copied) == {} + + copied = client.copy(set_default_query={"bar": "Robert"}) + assert _get_params(copied)["bar"] == "Robert" + + with pytest.raises( + ValueError, + # TODO: update + match="`default_query` and `set_default_query` arguments are mutually exclusive", + ): + client.copy(set_default_query={}, default_query={"foo": "Bar"}) + + client.close() + + def test_copy_signature(self, client: Writer) -> None: + # ensure the same parameters that can be passed to the client are defined in the `.copy()` method + init_signature = inspect.signature( + # mypy doesn't like that we access the `__init__` property. + client.__init__, # type: ignore[misc] + ) + copy_signature = inspect.signature(client.copy) + exclude_params = {"transport", "proxies", "_strict_response_validation"} + + for name in init_signature.parameters.keys(): + if name in exclude_params: + continue + + copy_param = copy_signature.parameters.get(name) + assert copy_param is not None, f"copy() signature is missing the {name} param" + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") + def test_copy_build_request(self, client: Writer) -> None: + options = FinalRequestOptions(method="get", url="/foo") + + def build_request(options: FinalRequestOptions) -> None: + client_copy = client.copy() + client_copy._build_request(options) + + # ensure that the machinery is warmed up before tracing starts. + build_request(options) + gc.collect() + + tracemalloc.start(1000) + + snapshot_before = tracemalloc.take_snapshot() + + ITERATIONS = 10 + for _ in range(ITERATIONS): + build_request(options) + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + + tracemalloc.stop() + + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: + if diff.count == 0: + # Avoid false positives by considering only leaks (i.e. allocations that persist). + return + + if diff.count % ITERATIONS != 0: + # Avoid false positives by considering only leaks that appear per iteration. + return + + for frame in diff.traceback: + if any( + frame.filename.endswith(fragment) + for fragment in [ + # to_raw_response_wrapper leaks through the @functools.wraps() decorator. + # + # removing the decorator fixes the leak for reasons we don't understand. + "writerai/_legacy_response.py", + "writerai/_response.py", + # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. + "writerai/_compat.py", + # Standard library leaks we don't care about. + "/logging/__init__.py", + ] + ): + return + + leaks.append(diff) + + leaks: list[tracemalloc.StatisticDiff] = [] + for diff in snapshot_after.compare_to(snapshot_before, "traceback"): + add_leak(leaks, diff) + if leaks: + for leak in leaks: + print("MEMORY LEAK:", leak) + for frame in leak.traceback: + print(frame) + raise AssertionError() + + def test_request_timeout(self, client: Writer) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(100.0) + + def test_client_timeout_option(self) -> None: + client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(0) + + client.close() + + def test_http_client_timeout_option(self) -> None: + # custom timeout given to the httpx client should be used + with httpx.Client(timeout=None) as http_client: + client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(None) + + client.close() + + # no timeout given to the httpx client should not use the httpx default + with httpx.Client() as http_client: + client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + client.close() + + # explicitly passing the default timeout currently results in it being ignored + with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: + client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT # our default + + client.close() + + async def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + async with httpx.AsyncClient() as http_client: + Writer( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + + def test_default_headers_option(self) -> None: + test_client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-foo") == "bar" + assert request.headers.get("x-stainless-lang") == "python" + + test_client2 = Writer( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + default_headers={ + "X-Foo": "stainless", + "X-Stainless-Lang": "my-overriding-header", + }, + ) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-foo") == "stainless" + assert request.headers.get("x-stainless-lang") == "my-overriding-header" + + test_client.close() + test_client2.close() + + def test_validate_headers(self) -> None: + client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("Authorization") == f"Bearer {api_key}" + + with pytest.raises(WriterError): + with update_env(**{"WRITER_API_KEY": Omit()}): + client2 = Writer(base_url=base_url, api_key=None, _strict_response_validation=True) + _ = client2 + + def test_default_query_option(self) -> None: + client = Writer( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + url = httpx.URL(request.url) + assert dict(url.params) == {"query_param": "bar"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/foo", + params={"foo": "baz", "query_param": "overridden"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} + + client.close() + + def test_hardcoded_query_params_in_url(self, client: Writer) -> None: + request = client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true")) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/foo?beta=true", + params={"limit": "10", "page": "abc"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/files/a%2Fb?beta=true", + params={"limit": "10"}, + ) + ) + assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10" + + def test_request_extra_json(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": False} + + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"baz": False} + + # `extra_json` takes priority over `json_data` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar", "baz": True}, + extra_json={"baz": None}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": None} + + def test_request_extra_headers(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options(extra_headers={"X-Foo": "Foo"}), + ), + ) + assert request.headers.get("X-Foo") == "Foo" + + # `extra_headers` takes priority over `default_headers` when keys clash + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_headers={"X-Bar": "false"}, + ), + ), + ) + assert request.headers.get("X-Bar") == "false" + + def test_request_extra_query(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_query={"my_query_param": "Foo"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"my_query_param": "Foo"} + + # if both `query` and `extra_query` are given, they are merged + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"bar": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"bar": "1", "foo": "2"} + + # `extra_query` takes priority over `query` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"foo": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"foo": "2"} + + def test_multipart_repeating_array(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions.construct( + method="post", + url="/foo", + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, + json_data={"array": ["foo", "bar"]}, + files=[("foo.txt", b"hello world")], + ) + ) + + assert request.read().split(b"\r\n") == [ + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"foo", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"bar", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="foo.txt"; filename="upload"', + b"Content-Type: application/octet-stream", + b"", + b"hello world", + b"--6b7ba517decee4a450543ea6ae821c82--", + b"", + ] + + @pytest.mark.respx(base_url=base_url) + def test_binary_content_upload(self, respx_mock: MockRouter, client: Writer) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + response = client.post( + "/upload", + content=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + def test_binary_content_upload_with_iterator(self) -> None: + file_content = b"Hello, this is a test file." + counter = Counter() + iterator = _make_sync_iterator([file_content], counter=counter) + + def mock_handler(request: httpx.Request) -> httpx.Response: + assert counter.value == 0, "the request body should not have been read" + return httpx.Response(200, content=request.read()) + + with Writer( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(transport=MockTransport(handler=mock_handler)), + ) as client: + response = client.post( + "/upload", + content=iterator, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + assert counter.value == 1 + + @pytest.mark.respx(base_url=base_url) + def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: Writer) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + with pytest.deprecated_call( + match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead." + ): + response = client.post( + "/upload", + body=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + @pytest.mark.respx(base_url=base_url) + def test_basic_union_response(self, respx_mock: MockRouter, client: Writer) -> None: + class Model1(BaseModel): + name: str + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + @pytest.mark.respx(base_url=base_url) + def test_union_response_different_types(self, respx_mock: MockRouter, client: Writer) -> None: + """Union of objects with the same field name using a different type""" + + class Model1(BaseModel): + foo: int + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) + + response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model1) + assert response.foo == 1 + + @pytest.mark.respx(base_url=base_url) + def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Writer) -> None: + """ + Response that sets Content-Type to something other than application/json but returns json data + """ + + class Model(BaseModel): + foo: int + + respx_mock.get("/foo").mock( + return_value=httpx.Response( + 200, + content=json.dumps({"foo": 2}), + headers={"Content-Type": "application/text"}, + ) + ) + + response = client.get("/foo", cast_to=Model) + assert isinstance(response, Model) + assert response.foo == 2 + + def test_base_url_setter(self) -> None: + client = Writer(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True) + assert client.base_url == "https://example.com/from_init/" + + client.base_url = "https://example.com/from_setter" # type: ignore[assignment] + + assert client.base_url == "https://example.com/from_setter/" + + client.close() + + def test_base_url_env(self) -> None: + with update_env(WRITER_BASE_URL="http://localhost:5000/from/env"): + client = Writer(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + + @pytest.mark.parametrize( + "client", + [ + Writer(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True), + Writer( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(), + ), + ], + ids=["standard", "custom http client"], + ) + def test_base_url_trailing_slash(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + client.close() + + @pytest.mark.parametrize( + "client", + [ + Writer(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True), + Writer( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(), + ), + ], + ids=["standard", "custom http client"], + ) + def test_base_url_no_trailing_slash(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + client.close() + + @pytest.mark.parametrize( + "client", + [ + Writer(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True), + Writer( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client(), + ), + ], + ids=["standard", "custom http client"], + ) + def test_absolute_request_url(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="https://myapi.com/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "https://myapi.com/foo" + client.close() + + def test_copied_client_does_not_close_http(self) -> None: + test_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() + + copied = test_client.copy() + assert copied is not test_client + + del copied + + assert not test_client.is_closed() + + def test_client_context_manager(self) -> None: + test_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + with test_client as c2: + assert c2 is test_client + assert not c2.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() + + @pytest.mark.respx(base_url=base_url) + def test_client_response_validation_error(self, respx_mock: MockRouter, client: Writer) -> None: + class Model(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) + + with pytest.raises(APIResponseValidationError) as exc: + client.get("/foo", cast_to=Model) + + assert isinstance(exc.value.__cause__, ValidationError) + + def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) + + @pytest.mark.respx(base_url=base_url) + def test_default_stream_cls(self, respx_mock: MockRouter, client: Writer) -> None: + class Model(BaseModel): + name: str + + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) + assert isinstance(stream, Stream) + stream.response.close() + + @pytest.mark.respx(base_url=base_url) + def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: + class Model(BaseModel): + name: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) + + strict_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=True) + + with pytest.raises(APIResponseValidationError): + strict_client.get("/foo", cast_to=Model) + + non_strict_client = Writer(base_url=base_url, api_key=api_key, _strict_response_validation=False) + + response = non_strict_client.get("/foo", cast_to=Model) + assert isinstance(response, str) # type: ignore[unreachable] + + strict_client.close() + non_strict_client.close() + + @pytest.mark.parametrize( + "remaining_retries,retry_after,timeout", + [ + [3, "20", 20], + [3, "0", 1], + [3, "-10", 1], + [3, "60", 60], + [3, "61", 1], + [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20], + [3, "Fri, 29 Sep 2023 16:26:37 GMT", 1], + [3, "Fri, 29 Sep 2023 16:26:27 GMT", 1], + [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60], + [3, "Fri, 29 Sep 2023 16:27:38 GMT", 1], + [3, "99999999999999999999999999999999999", 1], + [3, "Zun, 29 Sep 2023 16:26:27 GMT", 1], + [3, "", 1], + [2, "", 1 * 2.0], + [1, "", 1 * 4.0], + [-1100, "", 60], # test large number potentially overflowing + ], + ) + @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) + def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, client: Writer + ) -> None: + headers = httpx.Headers({"retry-after": retry_after}) + options = FinalRequestOptions(method="get", url="/foo", max_retries=3) + calculated = client._calculate_retry_timeout(remaining_retries, options, headers) + assert calculated == pytest.approx(timeout, 1 * 0.875) # pyright: ignore[reportUnknownMemberType] + + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Writer) -> None: + respx_mock.post("/v1/chat").mock(side_effect=httpx.TimeoutException("Test timeout error")) + + with pytest.raises(APITimeoutError): + client.chat.with_streaming_response.chat(messages=[{"role": "user"}], model="model").__enter__() + + assert _get_open_connections(client) == 0 + + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Writer) -> None: + respx_mock.post("/v1/chat").mock(return_value=httpx.Response(500)) + + with pytest.raises(APIStatusError): + client.chat.with_streaming_response.chat(messages=[{"role": "user"}], model="model").__enter__() + assert _get_open_connections(client) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + def test_retries_taken( + self, + client: Writer, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/chat").mock(side_effect=retry_handler) + + response = client.chat.with_raw_response.chat(messages=[{"role": "user"}], model="model") + + assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_omit_retry_count_header( + self, client: Writer, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/chat").mock(side_effect=retry_handler) + + response = client.chat.with_raw_response.chat( + messages=[{"role": "user"}], model="model", extra_headers={"x-stainless-retry-count": Omit()} + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_overwrite_retry_count_header( + self, client: Writer, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/chat").mock(side_effect=retry_handler) + + response = client.chat.with_raw_response.chat( + messages=[{"role": "user"}], model="model", extra_headers={"x-stainless-retry-count": "42"} + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" + + def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + # Delete in case our environment has any proxy env vars set + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("ALL_PROXY", raising=False) + monkeypatch.delenv("NO_PROXY", raising=False) + monkeypatch.delenv("http_proxy", raising=False) + monkeypatch.delenv("https_proxy", raising=False) + monkeypatch.delenv("all_proxy", raising=False) + monkeypatch.delenv("no_proxy", raising=False) + + client = DefaultHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + + @pytest.mark.respx(base_url=base_url) + def test_follow_redirects(self, respx_mock: MockRouter, client: Writer) -> None: + # Test that the default follow_redirects=True allows following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) + + response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.respx(base_url=base_url) + def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Writer) -> None: + # Test that follow_redirects=False prevents following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + + with pytest.raises(APIStatusError) as exc_info: + client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response) + + assert exc_info.value.response.status_code == 302 + assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + +class TestAsyncWriter: + @pytest.mark.respx(base_url=base_url) + async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = await async_client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + @pytest.mark.respx(base_url=base_url) + async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + respx_mock.post("/foo").mock( + return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') + ) + + response = await async_client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == {"foo": "bar"} + + def test_copy(self, async_client: AsyncWriter) -> None: + copied = async_client.copy() + assert id(copied) != id(async_client) + + copied = async_client.copy(api_key="another My API Key") + assert copied.api_key == "another My API Key" + assert async_client.api_key == "My API Key" + + def test_copy_default_options(self, async_client: AsyncWriter) -> None: + # options that have a default are overridden correctly + copied = async_client.copy(max_retries=7) + assert copied.max_retries == 7 + assert async_client.max_retries == 7 + + copied2 = copied.copy(max_retries=6) + assert copied2.max_retries == 6 + assert copied.max_retries == 7 + + # timeout + assert isinstance(async_client.timeout, httpx.Timeout) + copied = async_client.copy(timeout=None) + assert copied.timeout is None + assert isinstance(async_client.timeout, httpx.Timeout) + + async def test_copy_default_headers(self) -> None: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + assert client.default_headers["X-Foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert copied.default_headers["X-Foo"] == "bar" + + # merges already given headers + copied = client.copy(default_headers={"X-Bar": "stainless"}) + assert copied.default_headers["X-Foo"] == "bar" + assert copied.default_headers["X-Bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_headers={"X-Foo": "stainless"}) + assert copied.default_headers["X-Foo"] == "stainless" + + # set_default_headers + + # completely overrides already set values + copied = client.copy(set_default_headers={}) + assert copied.default_headers.get("X-Foo") is None + + copied = client.copy(set_default_headers={"X-Bar": "Robert"}) + assert copied.default_headers["X-Bar"] == "Robert" + + with pytest.raises( + ValueError, + match="`default_headers` and `set_default_headers` arguments are mutually exclusive", + ): + client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"}) + await client.close() + + async def test_copy_default_query(self) -> None: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"} + ) + assert _get_params(client)["foo"] == "bar" + + # does not override the already given value when not specified + copied = client.copy() + assert _get_params(copied)["foo"] == "bar" + + # merges already given params + copied = client.copy(default_query={"bar": "stainless"}) + params = _get_params(copied) + assert params["foo"] == "bar" + assert params["bar"] == "stainless" + + # uses new values for any already given headers + copied = client.copy(default_query={"foo": "stainless"}) + assert _get_params(copied)["foo"] == "stainless" + + # set_default_query + + # completely overrides already set values + copied = client.copy(set_default_query={}) + assert _get_params(copied) == {} + + copied = client.copy(set_default_query={"bar": "Robert"}) + assert _get_params(copied)["bar"] == "Robert" + + with pytest.raises( + ValueError, + # TODO: update + match="`default_query` and `set_default_query` arguments are mutually exclusive", + ): + client.copy(set_default_query={}, default_query={"foo": "Bar"}) + + await client.close() + + def test_copy_signature(self, async_client: AsyncWriter) -> None: + # ensure the same parameters that can be passed to the client are defined in the `.copy()` method + init_signature = inspect.signature( + # mypy doesn't like that we access the `__init__` property. + async_client.__init__, # type: ignore[misc] + ) + copy_signature = inspect.signature(async_client.copy) + exclude_params = {"transport", "proxies", "_strict_response_validation"} + + for name in init_signature.parameters.keys(): + if name in exclude_params: + continue + + copy_param = copy_signature.parameters.get(name) + assert copy_param is not None, f"copy() signature is missing the {name} param" + + @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12") + def test_copy_build_request(self, async_client: AsyncWriter) -> None: + options = FinalRequestOptions(method="get", url="/foo") + + def build_request(options: FinalRequestOptions) -> None: + client_copy = async_client.copy() + client_copy._build_request(options) + + # ensure that the machinery is warmed up before tracing starts. + build_request(options) + gc.collect() + + tracemalloc.start(1000) + + snapshot_before = tracemalloc.take_snapshot() + + ITERATIONS = 10 + for _ in range(ITERATIONS): + build_request(options) + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + + tracemalloc.stop() + + def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None: + if diff.count == 0: + # Avoid false positives by considering only leaks (i.e. allocations that persist). + return + + if diff.count % ITERATIONS != 0: + # Avoid false positives by considering only leaks that appear per iteration. + return + + for frame in diff.traceback: + if any( + frame.filename.endswith(fragment) + for fragment in [ + # to_raw_response_wrapper leaks through the @functools.wraps() decorator. + # + # removing the decorator fixes the leak for reasons we don't understand. + "writerai/_legacy_response.py", + "writerai/_response.py", + # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. + "writerai/_compat.py", + # Standard library leaks we don't care about. + "/logging/__init__.py", + ] + ): + return + + leaks.append(diff) + + leaks: list[tracemalloc.StatisticDiff] = [] + for diff in snapshot_after.compare_to(snapshot_before, "traceback"): + add_leak(leaks, diff) + if leaks: + for leak in leaks: + print("MEMORY LEAK:", leak) + for frame in leak.traceback: + print(frame) + raise AssertionError() + + async def test_request_timeout(self, async_client: AsyncWriter) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + request = async_client._build_request( + FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)) + ) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(100.0) + + async def test_client_timeout_option(self) -> None: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0) + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(0) + + await client.close() + + async def test_http_client_timeout_option(self) -> None: + # custom timeout given to the httpx client should be used + async with httpx.AsyncClient(timeout=None) as http_client: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == httpx.Timeout(None) + + await client.close() + + # no timeout given to the httpx client should not use the httpx default + async with httpx.AsyncClient() as http_client: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT + + await client.close() + + # explicitly passing the default timeout currently results in it being ignored + async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client + ) + + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore + assert timeout == DEFAULT_TIMEOUT # our default + + await client.close() + + def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + with httpx.Client() as http_client: + AsyncWriter( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + + async def test_default_headers_option(self) -> None: + test_client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} + ) + request = test_client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-foo") == "bar" + assert request.headers.get("x-stainless-lang") == "python" + + test_client2 = AsyncWriter( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + default_headers={ + "X-Foo": "stainless", + "X-Stainless-Lang": "my-overriding-header", + }, + ) + request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("x-foo") == "stainless" + assert request.headers.get("x-stainless-lang") == "my-overriding-header" + + await test_client.close() + await test_client2.close() + + def test_validate_headers(self) -> None: + client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + assert request.headers.get("Authorization") == f"Bearer {api_key}" + + with pytest.raises(WriterError): + with update_env(**{"WRITER_API_KEY": Omit()}): + client2 = AsyncWriter(base_url=base_url, api_key=None, _strict_response_validation=True) + _ = client2 + + async def test_default_query_option(self) -> None: + client = AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"} + ) + request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + url = httpx.URL(request.url) + assert dict(url.params) == {"query_param": "bar"} + + request = client._build_request( + FinalRequestOptions( + method="get", + url="/foo", + params={"foo": "baz", "query_param": "overridden"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"foo": "baz", "query_param": "overridden"} + + await client.close() + + async def test_hardcoded_query_params_in_url(self, async_client: AsyncWriter) -> None: + request = async_client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true")) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true"} + + request = async_client._build_request( + FinalRequestOptions( + method="get", + url="/foo?beta=true", + params={"limit": "10", "page": "abc"}, + ) + ) + url = httpx.URL(request.url) + assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"} + + request = async_client._build_request( + FinalRequestOptions( + method="get", + url="/files/a%2Fb?beta=true", + params={"limit": "10"}, + ) + ) + assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10" + + def test_request_extra_json(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": False} + + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + extra_json={"baz": False}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"baz": False} + + # `extra_json` takes priority over `json_data` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar", "baz": True}, + extra_json={"baz": None}, + ), + ) + data = json.loads(request.content.decode("utf-8")) + assert data == {"foo": "bar", "baz": None} + + def test_request_extra_headers(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options(extra_headers={"X-Foo": "Foo"}), + ), + ) + assert request.headers.get("X-Foo") == "Foo" + + # `extra_headers` takes priority over `default_headers` when keys clash + request = client.with_options(default_headers={"X-Bar": "true"})._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_headers={"X-Bar": "false"}, + ), + ), + ) + assert request.headers.get("X-Bar") == "false" + + def test_request_extra_query(self, client: Writer) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + extra_query={"my_query_param": "Foo"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"my_query_param": "Foo"} + + # if both `query` and `extra_query` are given, they are merged + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"bar": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"bar": "1", "foo": "2"} + + # `extra_query` takes priority over `query` when keys clash + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + **make_request_options( + query={"foo": "1"}, + extra_query={"foo": "2"}, + ), + ), + ) + params = dict(request.url.params) + assert params == {"foo": "2"} + + def test_multipart_repeating_array(self, async_client: AsyncWriter) -> None: + request = async_client._build_request( + FinalRequestOptions.construct( + method="post", + url="/foo", + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, + json_data={"array": ["foo", "bar"]}, + files=[("foo.txt", b"hello world")], + ) + ) + + assert request.read().split(b"\r\n") == [ + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"foo", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"bar", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="foo.txt"; filename="upload"', + b"Content-Type: application/octet-stream", + b"", + b"hello world", + b"--6b7ba517decee4a450543ea6ae821c82--", + b"", + ] + + @pytest.mark.respx(base_url=base_url) + async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + response = await async_client.post( + "/upload", + content=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + async def test_binary_content_upload_with_asynciterator(self) -> None: + file_content = b"Hello, this is a test file." + counter = Counter() + iterator = _make_async_iterator([file_content], counter=counter) + + async def mock_handler(request: httpx.Request) -> httpx.Response: + assert counter.value == 0, "the request body should not have been read" + return httpx.Response(200, content=await request.aread()) + + async with AsyncWriter( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)), + ) as client: + response = await client.post( + "/upload", + content=iterator, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + assert counter.value == 1 + + @pytest.mark.respx(base_url=base_url) + async def test_binary_content_upload_with_body_is_deprecated( + self, respx_mock: MockRouter, async_client: AsyncWriter + ) -> None: + respx_mock.post("/upload").mock(side_effect=mirror_request_content) + + file_content = b"Hello, this is a test file." + + with pytest.deprecated_call( + match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead." + ): + response = await async_client.post( + "/upload", + body=file_content, + cast_to=httpx.Response, + options={"headers": {"Content-Type": "application/octet-stream"}}, + ) + + assert response.status_code == 200 + assert response.request.headers["Content-Type"] == "application/octet-stream" + assert response.content == file_content + + @pytest.mark.respx(base_url=base_url) + async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + class Model1(BaseModel): + name: str + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + @pytest.mark.respx(base_url=base_url) + async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + """Union of objects with the same field name using a different type""" + + class Model1(BaseModel): + foo: int + + class Model2(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model2) + assert response.foo == "bar" + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1})) + + response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2])) + assert isinstance(response, Model1) + assert response.foo == 1 + + @pytest.mark.respx(base_url=base_url) + async def test_non_application_json_content_type_for_json_data( + self, respx_mock: MockRouter, async_client: AsyncWriter + ) -> None: + """ + Response that sets Content-Type to something other than application/json but returns json data + """ + + class Model(BaseModel): + foo: int + + respx_mock.get("/foo").mock( + return_value=httpx.Response( + 200, + content=json.dumps({"foo": 2}), + headers={"Content-Type": "application/text"}, + ) + ) + + response = await async_client.get("/foo", cast_to=Model) + assert isinstance(response, Model) + assert response.foo == 2 + + async def test_base_url_setter(self) -> None: + client = AsyncWriter( + base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True + ) + assert client.base_url == "https://example.com/from_init/" + + client.base_url = "https://example.com/from_setter" # type: ignore[assignment] + + assert client.base_url == "https://example.com/from_setter/" + + await client.close() + + async def test_base_url_env(self) -> None: + with update_env(WRITER_BASE_URL="http://localhost:5000/from/env"): + client = AsyncWriter(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + + @pytest.mark.parametrize( + "client", + [ + AsyncWriter( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + AsyncWriter( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(), + ), + ], + ids=["standard", "custom http client"], + ) + async def test_base_url_trailing_slash(self, client: AsyncWriter) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() + + @pytest.mark.parametrize( + "client", + [ + AsyncWriter( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + AsyncWriter( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(), + ), + ], + ids=["standard", "custom http client"], + ) + async def test_base_url_no_trailing_slash(self, client: AsyncWriter) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "http://localhost:5000/custom/path/foo" + await client.close() + + @pytest.mark.parametrize( + "client", + [ + AsyncWriter( + base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True + ), + AsyncWriter( + base_url="http://localhost:5000/custom/path/", + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient(), + ), + ], + ids=["standard", "custom http client"], + ) + async def test_absolute_request_url(self, client: AsyncWriter) -> None: + request = client._build_request( + FinalRequestOptions( + method="post", + url="https://myapi.com/foo", + json_data={"foo": "bar"}, + ), + ) + assert request.url == "https://myapi.com/foo" + await client.close() + + async def test_copied_client_does_not_close_http(self) -> None: + test_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) + assert not test_client.is_closed() + + copied = test_client.copy() + assert copied is not test_client + + del copied + + await asyncio.sleep(0.2) + assert not test_client.is_closed() + + async def test_client_context_manager(self) -> None: + test_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) + async with test_client as c2: + assert c2 is test_client + assert not c2.is_closed() + assert not test_client.is_closed() + assert test_client.is_closed() + + @pytest.mark.respx(base_url=base_url) + async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + class Model(BaseModel): + foo: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}})) + + with pytest.raises(APIResponseValidationError) as exc: + await async_client.get("/foo", cast_to=Model) + + assert isinstance(exc.value.__cause__, ValidationError) + + async def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + AsyncWriter( + base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None) + ) + + @pytest.mark.respx(base_url=base_url) + async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + class Model(BaseModel): + name: str + + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) + assert isinstance(stream, AsyncStream) + await stream.response.aclose() + + @pytest.mark.respx(base_url=base_url) + async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: + class Model(BaseModel): + name: str + + respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format")) + + strict_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=True) + + with pytest.raises(APIResponseValidationError): + await strict_client.get("/foo", cast_to=Model) + + non_strict_client = AsyncWriter(base_url=base_url, api_key=api_key, _strict_response_validation=False) + + response = await non_strict_client.get("/foo", cast_to=Model) + assert isinstance(response, str) # type: ignore[unreachable] + + await strict_client.close() + await non_strict_client.close() + + @pytest.mark.parametrize( + "remaining_retries,retry_after,timeout", + [ + [3, "20", 20], + [3, "0", 1], + [3, "-10", 1], + [3, "60", 60], + [3, "61", 1], + [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20], + [3, "Fri, 29 Sep 2023 16:26:37 GMT", 1], + [3, "Fri, 29 Sep 2023 16:26:27 GMT", 1], + [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60], + [3, "Fri, 29 Sep 2023 16:27:38 GMT", 1], + [3, "99999999999999999999999999999999999", 1], + [3, "Zun, 29 Sep 2023 16:26:27 GMT", 1], + [3, "", 1], + [2, "", 1 * 2.0], + [1, "", 1 * 4.0], + [-1100, "", 60], # test large number potentially overflowing + ], + ) + @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) + async def test_parse_retry_after_header( + self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncWriter + ) -> None: + headers = httpx.Headers({"retry-after": retry_after}) + options = FinalRequestOptions(method="get", url="/foo", max_retries=3) + calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers) + assert calculated == pytest.approx(timeout, 1 * 0.875) # pyright: ignore[reportUnknownMemberType] + + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + respx_mock.post("/v1/chat").mock(side_effect=httpx.TimeoutException("Test timeout error")) + + with pytest.raises(APITimeoutError): + await async_client.chat.with_streaming_response.chat( + messages=[{"role": "user"}], model="model" + ).__aenter__() + + assert _get_open_connections(async_client) == 0 + + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + respx_mock.post("/v1/chat").mock(return_value=httpx.Response(500)) + + with pytest.raises(APIStatusError): + await async_client.chat.with_streaming_response.chat( + messages=[{"role": "user"}], model="model" + ).__aenter__() + assert _get_open_connections(async_client) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) + async def test_retries_taken( + self, + async_client: AsyncWriter, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/chat").mock(side_effect=retry_handler) + + response = await client.chat.with_raw_response.chat(messages=[{"role": "user"}], model="model") + + assert response.retries_taken == failures_before_success + assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_omit_retry_count_header( + self, async_client: AsyncWriter, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/chat").mock(side_effect=retry_handler) + + response = await client.chat.with_raw_response.chat( + messages=[{"role": "user"}], model="model", extra_headers={"x-stainless-retry-count": Omit()} + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("writerai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_overwrite_retry_count_header( + self, async_client: AsyncWriter, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/v1/chat").mock(side_effect=retry_handler) + + response = await client.chat.with_raw_response.chat( + messages=[{"role": "user"}], model="model", extra_headers={"x-stainless-retry-count": "42"} + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" + + async def test_get_platform(self) -> None: + platform = await asyncify(get_platform)() + assert isinstance(platform, (str, OtherPlatform)) + + async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Test that the proxy environment variables are set correctly + monkeypatch.setenv("HTTPS_PROXY", "https://example.org") + # Delete in case our environment has any proxy env vars set + monkeypatch.delenv("HTTP_PROXY", raising=False) + monkeypatch.delenv("ALL_PROXY", raising=False) + monkeypatch.delenv("NO_PROXY", raising=False) + monkeypatch.delenv("http_proxy", raising=False) + monkeypatch.delenv("https_proxy", raising=False) + monkeypatch.delenv("all_proxy", raising=False) + monkeypatch.delenv("no_proxy", raising=False) + + client = DefaultAsyncHttpxClient() + + mounts = tuple(client._mounts.items()) + assert len(mounts) == 1 + assert mounts[0][0].pattern == "https://" + + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") + async def test_default_client_creation(self) -> None: + # Ensure that the client can be initialized without any exceptions + DefaultAsyncHttpxClient( + verify=True, + cert=None, + trust_env=True, + http1=True, + http2=False, + limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), + ) + + @pytest.mark.respx(base_url=base_url) + async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + # Test that the default follow_redirects=True allows following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"})) + + response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response) + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + @pytest.mark.respx(base_url=base_url) + async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncWriter) -> None: + # Test that follow_redirects=False prevents following redirects + respx_mock.post("/redirect").mock( + return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"}) + ) + + with pytest.raises(APIStatusError) as exc_info: + await async_client.post( + "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response + ) + + assert exc_info.value.response.status_code == 302 + assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" diff --git a/tests/test_extract_files.py b/tests/test_extract_files.py new file mode 100644 index 00000000..0ad3886c --- /dev/null +++ b/tests/test_extract_files.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest + +from writerai._types import FileTypes, ArrayFormat +from writerai._utils import extract_files + + +def test_removes_files_from_input() -> None: + query = {"foo": "bar"} + assert extract_files(query, paths=[]) == [] + assert query == {"foo": "bar"} + + query2 = {"foo": b"Bar", "hello": "world"} + assert extract_files(query2, paths=[["foo"]]) == [("foo", b"Bar")] + assert query2 == {"hello": "world"} + + query3 = {"foo": {"foo": {"bar": b"Bar"}}, "hello": "world"} + assert extract_files(query3, paths=[["foo", "foo", "bar"]]) == [("foo[foo][bar]", b"Bar")] + assert query3 == {"foo": {"foo": {}}, "hello": "world"} + + query4 = {"foo": {"bar": b"Bar", "baz": "foo"}, "hello": "world"} + assert extract_files(query4, paths=[["foo", "bar"]]) == [("foo[bar]", b"Bar")] + assert query4 == {"hello": "world", "foo": {"baz": "foo"}} + + +def test_multiple_files() -> None: + query = {"documents": [{"file": b"My first file"}, {"file": b"My second file"}]} + assert extract_files(query, paths=[["documents", "", "file"]]) == [ + ("documents[][file]", b"My first file"), + ("documents[][file]", b"My second file"), + ] + assert query == {"documents": [{}, {}]} + + +def test_top_level_file_array() -> None: + query = {"files": [b"file one", b"file two"], "title": "hello"} + assert extract_files(query, paths=[["files", ""]]) == [("files[]", b"file one"), ("files[]", b"file two")] + assert query == {"title": "hello"} + + +@pytest.mark.parametrize( + "query,paths,expected", + [ + [ + {"foo": {"bar": "baz"}}, + [["foo", "", "bar"]], + [], + ], + [ + {"foo": ["bar", "baz"]}, + [["foo", "bar"]], + [], + ], + [ + {"foo": {"bar": "baz"}}, + [["foo", "foo"]], + [], + ], + ], + ids=["dict expecting array", "array expecting dict", "unknown keys"], +) +def test_ignores_incorrect_paths( + query: dict[str, object], + paths: Sequence[Sequence[str]], + expected: list[tuple[str, FileTypes]], +) -> None: + assert extract_files(query, paths=paths) == expected + + +@pytest.mark.parametrize( + "array_format,expected_top_level,expected_nested", + [ + ("brackets", [("files[]", b"a"), ("files[]", b"b")], [("items[][file]", b"a"), ("items[][file]", b"b")]), + ("repeat", [("files", b"a"), ("files", b"b")], [("items[file]", b"a"), ("items[file]", b"b")]), + ("comma", [("files", b"a"), ("files", b"b")], [("items[file]", b"a"), ("items[file]", b"b")]), + ("indices", [("files[0]", b"a"), ("files[1]", b"b")], [("items[0][file]", b"a"), ("items[1][file]", b"b")]), + ], +) +def test_array_format_controls_file_field_names( + array_format: ArrayFormat, + expected_top_level: list[tuple[str, FileTypes]], + expected_nested: list[tuple[str, FileTypes]], +) -> None: + top_level = {"files": [b"a", b"b"]} + assert extract_files(top_level, paths=[["files", ""]], array_format=array_format) == expected_top_level + + nested = {"items": [{"file": b"a"}, {"file": b"b"}]} + assert extract_files(nested, paths=[["items", "", "file"]], array_format=array_format) == expected_nested diff --git a/tests/test_files.py b/tests/test_files.py new file mode 100644 index 00000000..e37bcded --- /dev/null +++ b/tests/test_files.py @@ -0,0 +1,148 @@ +from pathlib import Path + +import anyio +import pytest +from dirty_equals import IsDict, IsList, IsBytes, IsTuple + +from writerai._files import to_httpx_files, deepcopy_with_paths, async_to_httpx_files +from writerai._utils import extract_files + +readme_path = Path(__file__).parent.parent.joinpath("README.md") + + +def test_pathlib_includes_file_name() -> None: + result = to_httpx_files({"file": readme_path}) + print(result) + assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) + + +def test_tuple_input() -> None: + result = to_httpx_files([("file", readme_path)]) + print(result) + assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) + + +@pytest.mark.asyncio +async def test_async_pathlib_includes_file_name() -> None: + result = await async_to_httpx_files({"file": readme_path}) + print(result) + assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) + + +@pytest.mark.asyncio +async def test_async_supports_anyio_path() -> None: + result = await async_to_httpx_files({"file": anyio.Path(readme_path)}) + print(result) + assert result == IsDict({"file": IsTuple("README.md", IsBytes())}) + + +@pytest.mark.asyncio +async def test_async_tuple_input() -> None: + result = await async_to_httpx_files([("file", readme_path)]) + print(result) + assert result == IsList(IsTuple("file", IsTuple("README.md", IsBytes()))) + + +def test_string_not_allowed() -> None: + with pytest.raises(TypeError, match="Expected file types input to be a FileContent type or to be a tuple"): + to_httpx_files( + { + "file": "foo", # type: ignore + } + ) + + +def assert_different_identities(obj1: object, obj2: object) -> None: + assert obj1 == obj2 + assert obj1 is not obj2 + + +class TestDeepcopyWithPaths: + def test_copies_top_level_dict(self) -> None: + original = {"file": b"data", "other": "value"} + result = deepcopy_with_paths(original, [["file"]]) + assert_different_identities(result, original) + + def test_file_value_is_same_reference(self) -> None: + file_bytes = b"contents" + original = {"file": file_bytes} + result = deepcopy_with_paths(original, [["file"]]) + assert_different_identities(result, original) + assert result["file"] is file_bytes + + def test_list_popped_wholesale(self) -> None: + files = [b"f1", b"f2"] + original = {"files": files, "title": "t"} + result = deepcopy_with_paths(original, [["files", ""]]) + assert_different_identities(result, original) + result_files = result["files"] + assert isinstance(result_files, list) + assert_different_identities(result_files, files) + + def test_nested_array_path_copies_list_and_elements(self) -> None: + elem1 = {"file": b"f1", "extra": 1} + elem2 = {"file": b"f2", "extra": 2} + original = {"items": [elem1, elem2]} + result = deepcopy_with_paths(original, [["items", "", "file"]]) + assert_different_identities(result, original) + result_items = result["items"] + assert isinstance(result_items, list) + assert_different_identities(result_items, original["items"]) + assert_different_identities(result_items[0], elem1) + assert_different_identities(result_items[1], elem2) + + def test_empty_paths_returns_same_object(self) -> None: + original = {"foo": "bar"} + result = deepcopy_with_paths(original, []) + assert result is original + + def test_multiple_paths(self) -> None: + f1 = b"file1" + f2 = b"file2" + original = {"a": f1, "b": f2, "c": "unchanged"} + result = deepcopy_with_paths(original, [["a"], ["b"]]) + assert_different_identities(result, original) + assert result["a"] is f1 + assert result["b"] is f2 + assert result["c"] is original["c"] + + def test_extract_files_does_not_mutate_original_top_level(self) -> None: + file_bytes = b"contents" + original = {"file": file_bytes, "other": "value"} + + copied = deepcopy_with_paths(original, [["file"]]) + extracted = extract_files(copied, paths=[["file"]]) + + assert extracted == [("file", file_bytes)] + assert original == {"file": file_bytes, "other": "value"} + assert copied == {"other": "value"} + + def test_extract_files_does_not_mutate_original_nested_array_path(self) -> None: + file1 = b"f1" + file2 = b"f2" + original = { + "items": [ + {"file": file1, "extra": 1}, + {"file": file2, "extra": 2}, + ], + "title": "example", + } + + copied = deepcopy_with_paths(original, [["items", "", "file"]]) + extracted = extract_files(copied, paths=[["items", "", "file"]]) + + assert [entry for _, entry in extracted] == [file1, file2] + assert original == { + "items": [ + {"file": file1, "extra": 1}, + {"file": file2, "extra": 2}, + ], + "title": "example", + } + assert copied == { + "items": [ + {"extra": 1}, + {"extra": 2}, + ], + "title": "example", + } diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..5be5d2fc --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,1017 @@ +import json +from typing import TYPE_CHECKING, Any, Dict, List, Union, Iterable, Optional, cast +from datetime import datetime, timezone +from collections import deque +from typing_extensions import Literal, Annotated, TypedDict, TypeAliasType + +import pytest +import pydantic +from pydantic import Field + +from writerai._utils import PropertyInfo +from writerai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json +from writerai._models import DISCRIMINATOR_CACHE, BaseModel, EagerIterable, construct_type + + +class BasicModel(BaseModel): + foo: str + + +@pytest.mark.parametrize("value", ["hello", 1], ids=["correct type", "mismatched"]) +def test_basic(value: object) -> None: + m = BasicModel.construct(foo=value) + assert m.foo == value + + +def test_directly_nested_model() -> None: + class NestedModel(BaseModel): + nested: BasicModel + + m = NestedModel.construct(nested={"foo": "Foo!"}) + assert m.nested.foo == "Foo!" + + # mismatched types + m = NestedModel.construct(nested="hello!") + assert cast(Any, m.nested) == "hello!" + + +def test_optional_nested_model() -> None: + class NestedModel(BaseModel): + nested: Optional[BasicModel] + + m1 = NestedModel.construct(nested=None) + assert m1.nested is None + + m2 = NestedModel.construct(nested={"foo": "bar"}) + assert m2.nested is not None + assert m2.nested.foo == "bar" + + # mismatched types + m3 = NestedModel.construct(nested={"foo"}) + assert isinstance(cast(Any, m3.nested), set) + assert cast(Any, m3.nested) == {"foo"} + + +def test_list_nested_model() -> None: + class NestedModel(BaseModel): + nested: List[BasicModel] + + m = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}]) + assert m.nested is not None + assert isinstance(m.nested, list) + assert len(m.nested) == 2 + assert m.nested[0].foo == "bar" + assert m.nested[1].foo == "2" + + # mismatched types + m = NestedModel.construct(nested=True) + assert cast(Any, m.nested) is True + + m = NestedModel.construct(nested=[False]) + assert cast(Any, m.nested) == [False] + + +def test_optional_list_nested_model() -> None: + class NestedModel(BaseModel): + nested: Optional[List[BasicModel]] + + m1 = NestedModel.construct(nested=[{"foo": "bar"}, {"foo": "2"}]) + assert m1.nested is not None + assert isinstance(m1.nested, list) + assert len(m1.nested) == 2 + assert m1.nested[0].foo == "bar" + assert m1.nested[1].foo == "2" + + m2 = NestedModel.construct(nested=None) + assert m2.nested is None + + # mismatched types + m3 = NestedModel.construct(nested={1}) + assert cast(Any, m3.nested) == {1} + + m4 = NestedModel.construct(nested=[False]) + assert cast(Any, m4.nested) == [False] + + +def test_list_optional_items_nested_model() -> None: + class NestedModel(BaseModel): + nested: List[Optional[BasicModel]] + + m = NestedModel.construct(nested=[None, {"foo": "bar"}]) + assert m.nested is not None + assert isinstance(m.nested, list) + assert len(m.nested) == 2 + assert m.nested[0] is None + assert m.nested[1] is not None + assert m.nested[1].foo == "bar" + + # mismatched types + m3 = NestedModel.construct(nested="foo") + assert cast(Any, m3.nested) == "foo" + + m4 = NestedModel.construct(nested=[False]) + assert cast(Any, m4.nested) == [False] + + +def test_list_mismatched_type() -> None: + class NestedModel(BaseModel): + nested: List[str] + + m = NestedModel.construct(nested=False) + assert cast(Any, m.nested) is False + + +def test_raw_dictionary() -> None: + class NestedModel(BaseModel): + nested: Dict[str, str] + + m = NestedModel.construct(nested={"hello": "world"}) + assert m.nested == {"hello": "world"} + + # mismatched types + m = NestedModel.construct(nested=False) + assert cast(Any, m.nested) is False + + +def test_nested_dictionary_model() -> None: + class NestedModel(BaseModel): + nested: Dict[str, BasicModel] + + m = NestedModel.construct(nested={"hello": {"foo": "bar"}}) + assert isinstance(m.nested, dict) + assert m.nested["hello"].foo == "bar" + + # mismatched types + m = NestedModel.construct(nested={"hello": False}) + assert cast(Any, m.nested["hello"]) is False + + +def test_unknown_fields() -> None: + m1 = BasicModel.construct(foo="foo", unknown=1) + assert m1.foo == "foo" + assert cast(Any, m1).unknown == 1 + + m2 = BasicModel.construct(foo="foo", unknown={"foo_bar": True}) + assert m2.foo == "foo" + assert cast(Any, m2).unknown == {"foo_bar": True} + + assert model_dump(m2) == {"foo": "foo", "unknown": {"foo_bar": True}} + + +def test_strict_validation_unknown_fields() -> None: + class Model(BaseModel): + foo: str + + model = parse_obj(Model, dict(foo="hello!", user="Robert")) + assert model.foo == "hello!" + assert cast(Any, model).user == "Robert" + + assert model_dump(model) == {"foo": "hello!", "user": "Robert"} + + +def test_aliases() -> None: + class Model(BaseModel): + my_field: int = Field(alias="myField") + + m = Model.construct(myField=1) + assert m.my_field == 1 + + # mismatched types + m = Model.construct(myField={"hello": False}) + assert cast(Any, m.my_field) == {"hello": False} + + +def test_repr() -> None: + model = BasicModel(foo="bar") + assert str(model) == "BasicModel(foo='bar')" + assert repr(model) == "BasicModel(foo='bar')" + + +def test_repr_nested_model() -> None: + class Child(BaseModel): + name: str + age: int + + class Parent(BaseModel): + name: str + child: Child + + model = Parent(name="Robert", child=Child(name="Foo", age=5)) + assert str(model) == "Parent(name='Robert', child=Child(name='Foo', age=5))" + assert repr(model) == "Parent(name='Robert', child=Child(name='Foo', age=5))" + + +def test_optional_list() -> None: + class Submodel(BaseModel): + name: str + + class Model(BaseModel): + items: Optional[List[Submodel]] + + m = Model.construct(items=None) + assert m.items is None + + m = Model.construct(items=[]) + assert m.items == [] + + m = Model.construct(items=[{"name": "Robert"}]) + assert m.items is not None + assert len(m.items) == 1 + assert m.items[0].name == "Robert" + + +def test_nested_union_of_models() -> None: + class Submodel1(BaseModel): + bar: bool + + class Submodel2(BaseModel): + thing: str + + class Model(BaseModel): + foo: Union[Submodel1, Submodel2] + + m = Model.construct(foo={"thing": "hello"}) + assert isinstance(m.foo, Submodel2) + assert m.foo.thing == "hello" + + +def test_nested_union_of_mixed_types() -> None: + class Submodel1(BaseModel): + bar: bool + + class Model(BaseModel): + foo: Union[Submodel1, Literal[True], Literal["CARD_HOLDER"]] + + m = Model.construct(foo=True) + assert m.foo is True + + m = Model.construct(foo="CARD_HOLDER") + assert m.foo == "CARD_HOLDER" + + m = Model.construct(foo={"bar": False}) + assert isinstance(m.foo, Submodel1) + assert m.foo.bar is False + + +def test_nested_union_multiple_variants() -> None: + class Submodel1(BaseModel): + bar: bool + + class Submodel2(BaseModel): + thing: str + + class Submodel3(BaseModel): + foo: int + + class Model(BaseModel): + foo: Union[Submodel1, Submodel2, None, Submodel3] + + m = Model.construct(foo={"thing": "hello"}) + assert isinstance(m.foo, Submodel2) + assert m.foo.thing == "hello" + + m = Model.construct(foo=None) + assert m.foo is None + + m = Model.construct() + assert m.foo is None + + m = Model.construct(foo={"foo": "1"}) + assert isinstance(m.foo, Submodel3) + assert m.foo.foo == 1 + + +def test_nested_union_invalid_data() -> None: + class Submodel1(BaseModel): + level: int + + class Submodel2(BaseModel): + name: str + + class Model(BaseModel): + foo: Union[Submodel1, Submodel2] + + m = Model.construct(foo=True) + assert cast(bool, m.foo) is True + + m = Model.construct(foo={"name": 3}) + if PYDANTIC_V1: + assert isinstance(m.foo, Submodel2) + assert m.foo.name == "3" + else: + assert isinstance(m.foo, Submodel1) + assert m.foo.name == 3 # type: ignore + + +def test_list_of_unions() -> None: + class Submodel1(BaseModel): + level: int + + class Submodel2(BaseModel): + name: str + + class Model(BaseModel): + items: List[Union[Submodel1, Submodel2]] + + m = Model.construct(items=[{"level": 1}, {"name": "Robert"}]) + assert len(m.items) == 2 + assert isinstance(m.items[0], Submodel1) + assert m.items[0].level == 1 + assert isinstance(m.items[1], Submodel2) + assert m.items[1].name == "Robert" + + m = Model.construct(items=[{"level": -1}, 156]) + assert len(m.items) == 2 + assert isinstance(m.items[0], Submodel1) + assert m.items[0].level == -1 + assert cast(Any, m.items[1]) == 156 + + +def test_union_of_lists() -> None: + class SubModel1(BaseModel): + level: int + + class SubModel2(BaseModel): + name: str + + class Model(BaseModel): + items: Union[List[SubModel1], List[SubModel2]] + + # with one valid entry + m = Model.construct(items=[{"name": "Robert"}]) + assert len(m.items) == 1 + assert isinstance(m.items[0], SubModel2) + assert m.items[0].name == "Robert" + + # with two entries pointing to different types + m = Model.construct(items=[{"level": 1}, {"name": "Robert"}]) + assert len(m.items) == 2 + assert isinstance(m.items[0], SubModel1) + assert m.items[0].level == 1 + assert isinstance(m.items[1], SubModel1) + assert cast(Any, m.items[1]).name == "Robert" + + # with two entries pointing to *completely* different types + m = Model.construct(items=[{"level": -1}, 156]) + assert len(m.items) == 2 + assert isinstance(m.items[0], SubModel1) + assert m.items[0].level == -1 + assert cast(Any, m.items[1]) == 156 + + +def test_dict_of_union() -> None: + class SubModel1(BaseModel): + name: str + + class SubModel2(BaseModel): + foo: str + + class Model(BaseModel): + data: Dict[str, Union[SubModel1, SubModel2]] + + m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}}) + assert len(list(m.data.keys())) == 2 + assert isinstance(m.data["hello"], SubModel1) + assert m.data["hello"].name == "there" + assert isinstance(m.data["foo"], SubModel2) + assert m.data["foo"].foo == "bar" + + # TODO: test mismatched type + + +def test_double_nested_union() -> None: + class SubModel1(BaseModel): + name: str + + class SubModel2(BaseModel): + bar: str + + class Model(BaseModel): + data: Dict[str, List[Union[SubModel1, SubModel2]]] + + m = Model.construct(data={"foo": [{"bar": "baz"}, {"name": "Robert"}]}) + assert len(m.data["foo"]) == 2 + + entry1 = m.data["foo"][0] + assert isinstance(entry1, SubModel2) + assert entry1.bar == "baz" + + entry2 = m.data["foo"][1] + assert isinstance(entry2, SubModel1) + assert entry2.name == "Robert" + + # TODO: test mismatched type + + +def test_union_of_dict() -> None: + class SubModel1(BaseModel): + name: str + + class SubModel2(BaseModel): + foo: str + + class Model(BaseModel): + data: Union[Dict[str, SubModel1], Dict[str, SubModel2]] + + m = Model.construct(data={"hello": {"name": "there"}, "foo": {"foo": "bar"}}) + assert len(list(m.data.keys())) == 2 + assert isinstance(m.data["hello"], SubModel1) + assert m.data["hello"].name == "there" + assert isinstance(m.data["foo"], SubModel1) + assert cast(Any, m.data["foo"]).foo == "bar" + + +def test_iso8601_datetime() -> None: + class Model(BaseModel): + created_at: datetime + + expected = datetime(2019, 12, 27, 18, 11, 19, 117000, tzinfo=timezone.utc) + + if PYDANTIC_V1: + expected_json = '{"created_at": "2019-12-27T18:11:19.117000+00:00"}' + else: + expected_json = '{"created_at":"2019-12-27T18:11:19.117000Z"}' + + model = Model.construct(created_at="2019-12-27T18:11:19.117Z") + assert model.created_at == expected + assert model_json(model) == expected_json + + model = parse_obj(Model, dict(created_at="2019-12-27T18:11:19.117Z")) + assert model.created_at == expected + assert model_json(model) == expected_json + + +def test_does_not_coerce_int() -> None: + class Model(BaseModel): + bar: int + + assert Model.construct(bar=1).bar == 1 + assert Model.construct(bar=10.9).bar == 10.9 + assert Model.construct(bar="19").bar == "19" # type: ignore[comparison-overlap] + assert Model.construct(bar=False).bar is False + + +def test_int_to_float_safe_conversion() -> None: + class Model(BaseModel): + float_field: float + + m = Model.construct(float_field=10) + assert m.float_field == 10.0 + assert isinstance(m.float_field, float) + + m = Model.construct(float_field=10.12) + assert m.float_field == 10.12 + assert isinstance(m.float_field, float) + + # number too big + m = Model.construct(float_field=2**53 + 1) + assert m.float_field == 2**53 + 1 + assert isinstance(m.float_field, int) + + +def test_deprecated_alias() -> None: + class Model(BaseModel): + resource_id: str = Field(alias="model_id") + + @property + def model_id(self) -> str: + return self.resource_id + + m = Model.construct(model_id="id") + assert m.model_id == "id" + assert m.resource_id == "id" + assert m.resource_id is m.model_id + + m = parse_obj(Model, {"model_id": "id"}) + assert m.model_id == "id" + assert m.resource_id == "id" + assert m.resource_id is m.model_id + + +def test_omitted_fields() -> None: + class Model(BaseModel): + resource_id: Optional[str] = None + + m = Model.construct() + assert m.resource_id is None + assert "resource_id" not in m.model_fields_set + + m = Model.construct(resource_id=None) + assert m.resource_id is None + assert "resource_id" in m.model_fields_set + + m = Model.construct(resource_id="foo") + assert m.resource_id == "foo" + assert "resource_id" in m.model_fields_set + + +def test_to_dict() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert m.to_dict() == {"FOO": "hello"} + assert m.to_dict(use_api_names=False) == {"foo": "hello"} + + m2 = Model() + assert m2.to_dict() == {} + assert m2.to_dict(exclude_unset=False) == {"FOO": None} + assert m2.to_dict(exclude_unset=False, exclude_none=True) == {} + assert m2.to_dict(exclude_unset=False, exclude_defaults=True) == {} + + m3 = Model(FOO=None) + assert m3.to_dict() == {"FOO": None} + assert m3.to_dict(exclude_none=True) == {} + assert m3.to_dict(exclude_defaults=True) == {} + + class Model2(BaseModel): + created_at: datetime + + time_str = "2024-03-21T11:39:01.275859" + m4 = Model2.construct(created_at=time_str) + assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)} + assert m4.to_dict(mode="json") == {"created_at": time_str} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.to_dict(warnings=False) + + +def test_forwards_compat_model_dump_method() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert m.model_dump() == {"foo": "hello"} + assert m.model_dump(include={"bar"}) == {} + assert m.model_dump(exclude={"foo"}) == {} + assert m.model_dump(by_alias=True) == {"FOO": "hello"} + + m2 = Model() + assert m2.model_dump() == {"foo": None} + assert m2.model_dump(exclude_unset=True) == {} + assert m2.model_dump(exclude_none=True) == {} + assert m2.model_dump(exclude_defaults=True) == {} + + m3 = Model(FOO=None) + assert m3.model_dump() == {"foo": None} + assert m3.model_dump(exclude_none=True) == {} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): + m.model_dump(round_trip=True) + + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.model_dump(warnings=False) + + +def test_compat_method_no_error_for_warnings() -> None: + class Model(BaseModel): + foo: Optional[str] + + m = Model(foo="hello") + assert isinstance(model_dump(m, warnings=False), dict) + + +def test_to_json() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert json.loads(m.to_json()) == {"FOO": "hello"} + assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"} + + if PYDANTIC_V1: + assert m.to_json(indent=None) == '{"FOO": "hello"}' + else: + assert m.to_json(indent=None) == '{"FOO":"hello"}' + + m2 = Model() + assert json.loads(m2.to_json()) == {} + assert json.loads(m2.to_json(exclude_unset=False)) == {"FOO": None} + assert json.loads(m2.to_json(exclude_unset=False, exclude_none=True)) == {} + assert json.loads(m2.to_json(exclude_unset=False, exclude_defaults=True)) == {} + + m3 = Model(FOO=None) + assert json.loads(m3.to_json()) == {"FOO": None} + assert json.loads(m3.to_json(exclude_none=True)) == {} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.to_json(warnings=False) + + +def test_forwards_compat_model_dump_json_method() -> None: + class Model(BaseModel): + foo: Optional[str] = Field(alias="FOO", default=None) + + m = Model(FOO="hello") + assert json.loads(m.model_dump_json()) == {"foo": "hello"} + assert json.loads(m.model_dump_json(include={"bar"})) == {} + assert json.loads(m.model_dump_json(include={"foo"})) == {"foo": "hello"} + assert json.loads(m.model_dump_json(by_alias=True)) == {"FOO": "hello"} + + assert m.model_dump_json(indent=2) == '{\n "foo": "hello"\n}' + + m2 = Model() + assert json.loads(m2.model_dump_json()) == {"foo": None} + assert json.loads(m2.model_dump_json(exclude_unset=True)) == {} + assert json.loads(m2.model_dump_json(exclude_none=True)) == {} + assert json.loads(m2.model_dump_json(exclude_defaults=True)) == {} + + m3 = Model(FOO=None) + assert json.loads(m3.model_dump_json()) == {"foo": None} + assert json.loads(m3.model_dump_json(exclude_none=True)) == {} + + if PYDANTIC_V1: + with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"): + m.model_dump_json(round_trip=True) + + with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"): + m.model_dump_json(warnings=False) + + +def test_type_compat() -> None: + # our model type can be assigned to Pydantic's model type + + def takes_pydantic(model: pydantic.BaseModel) -> None: # noqa: ARG001 + ... + + class OurModel(BaseModel): + foo: Optional[str] = None + + takes_pydantic(OurModel()) + + +def test_annotated_types() -> None: + class Model(BaseModel): + value: str + + m = construct_type( + value={"value": "foo"}, + type_=cast(Any, Annotated[Model, "random metadata"]), + ) + assert isinstance(m, Model) + assert m.value == "foo" + + +def test_discriminated_unions_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, A) + assert m.type == "a" + if PYDANTIC_V1: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] + + +def test_discriminated_unions_unknown_variant() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "c", "data": None, "new_thing": "bar"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + + # just chooses the first variant + assert isinstance(m, A) + assert m.type == "c" # type: ignore[comparison-overlap] + assert m.data == None # type: ignore[unreachable] + assert m.new_thing == "bar" + + +def test_discriminated_unions_invalid_data_nested_unions() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + class C(BaseModel): + type: Literal["c"] + + data: bool + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "c", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, C) + assert m.type == "c" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_with_aliases_invalid_data() -> None: + class A(BaseModel): + foo_type: Literal["a"] = Field(alias="type") + + data: str + + class B(BaseModel): + foo_type: Literal["b"] = Field(alias="type") + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, B) + assert m.foo_type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, A) + assert m.foo_type == "a" + if PYDANTIC_V1: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + else: + assert m.data == 100 # type: ignore[comparison-overlap] + + +def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["a"] + + data: int + + m = construct_type( + value={"type": "a", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "a" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_invalid_data_uses_cache() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + UnionType = cast(Any, Union[A, B]) + + assert not DISCRIMINATOR_CACHE.get(UnionType) + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + discriminator = DISCRIMINATOR_CACHE.get(UnionType) + assert discriminator is not None + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + # if the discriminator details object stays the same between invocations then + # we hit the cache + assert DISCRIMINATOR_CACHE.get(UnionType) is discriminator + + +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") +def test_type_alias_type() -> None: + Alias = TypeAliasType("Alias", str) # pyright: ignore + + class Model(BaseModel): + alias: Alias + union: Union[int, Alias] + + m = construct_type(value={"alias": "foo", "union": "bar"}, type_=Model) + assert isinstance(m, Model) + assert isinstance(m.alias, str) + assert m.alias == "foo" + assert isinstance(m.union, str) + assert m.union == "bar" + + +@pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1") +def test_field_named_cls() -> None: + class Model(BaseModel): + cls: str + + m = construct_type(value={"cls": "foo"}, type_=Model) + assert isinstance(m, Model) + assert isinstance(m.cls, str) + + +def test_discriminated_union_case() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["b"] + + data: List[Union[A, object]] + + class ModelA(BaseModel): + type: Literal["modelA"] + + data: int + + class ModelB(BaseModel): + type: Literal["modelB"] + + required: str + + data: Union[A, B] + + # when constructing ModelA | ModelB, value data doesn't match ModelB exactly - missing `required` + m = construct_type( + value={"type": "modelB", "data": {"type": "a", "data": True}}, + type_=cast(Any, Annotated[Union[ModelA, ModelB], PropertyInfo(discriminator="type")]), + ) + + assert isinstance(m, ModelB) + + +def test_nested_discriminated_union() -> None: + class InnerType1(BaseModel): + type: Literal["type_1"] + + class InnerModel(BaseModel): + inner_value: str + + class InnerType2(BaseModel): + type: Literal["type_2"] + some_inner_model: InnerModel + + class Type1(BaseModel): + base_type: Literal["base_type_1"] + value: Annotated[ + Union[ + InnerType1, + InnerType2, + ], + PropertyInfo(discriminator="type"), + ] + + class Type2(BaseModel): + base_type: Literal["base_type_2"] + + T = Annotated[ + Union[ + Type1, + Type2, + ], + PropertyInfo(discriminator="base_type"), + ] + + model = construct_type( + type_=T, + value={ + "base_type": "base_type_1", + "value": { + "type": "type_2", + }, + }, + ) + assert isinstance(model, Type1) + assert isinstance(model.value, InnerType2) + + +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2 for now") +def test_extra_properties() -> None: + class Item(BaseModel): + prop: int + + class Model(BaseModel): + __pydantic_extra__: Dict[str, Item] = Field(init=False) # pyright: ignore[reportIncompatibleVariableOverride] + + other: str + + if TYPE_CHECKING: + + def __getattr__(self, attr: str) -> Item: ... + + model = construct_type( + type_=Model, + value={ + "a": {"prop": 1}, + "other": "foo", + }, + ) + assert isinstance(model, Model) + assert model.a.prop == 1 + assert isinstance(model.a, Item) + assert model.other == "foo" + + +# NOTE: Workaround for Pydantic Iterable behavior. +# Iterable fields are replaced with a ValidatorIterator and may be consumed +# during serialization, which can cause subsequent dumps to return empty data. +# See: https://github.com/pydantic/pydantic/issues/9541 +@pytest.mark.parametrize( + "data, expected_validated", + [ + ([1, 2, 3], [1, 2, 3]), + ((1, 2, 3), (1, 2, 3)), + (set([1, 2, 3]), set([1, 2, 3])), + (iter([1, 2, 3]), [1, 2, 3]), + ([], []), + ((x for x in [1, 2, 3]), [1, 2, 3]), + (map(lambda x: x, [1, 2, 3]), [1, 2, 3]), + (frozenset([1, 2, 3]), frozenset([1, 2, 3])), + (deque([1, 2, 3]), deque([1, 2, 3])), + ], + ids=["list", "tuple", "set", "iterator", "empty", "generator", "map", "frozenset", "deque"], +) +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2") +def test_iterable_construction(data: Iterable[int], expected_validated: Iterable[int]) -> None: + class TypeWithIterable(TypedDict): + items: EagerIterable[int] + + class Model(BaseModel): + data: TypeWithIterable + + m = Model.model_validate({"data": {"items": data}}) + assert m.data["items"] == expected_validated + + # Verify repeated dumps don't lose data (the original bug) + assert m.model_dump()["data"]["items"] == list(expected_validated) + assert m.model_dump()["data"]["items"] == list(expected_validated) + + +@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2") +def test_iterable_construction_str_falls_back_to_list() -> None: + # str is iterable (over chars), but str(list_of_chars) produces the list's repr + # rather than reconstructing a string from items. We special-case str to fall + # back to list instead of attempting reconstruction. + class TypeWithIterable(TypedDict): + items: EagerIterable[str] + + class Model(BaseModel): + data: TypeWithIterable + + m = Model.model_validate({"data": {"items": "hello"}}) + + # falls back to list of chars rather than calling str(["h", "e", "l", "l", "o"]) + assert m.data["items"] == ["h", "e", "l", "l", "o"] + assert m.model_dump()["data"]["items"] == ["h", "e", "l", "l", "o"] diff --git a/tests/test_qs.py b/tests/test_qs.py new file mode 100644 index 00000000..3dbda131 --- /dev/null +++ b/tests/test_qs.py @@ -0,0 +1,78 @@ +from typing import Any, cast +from functools import partial +from urllib.parse import unquote + +import pytest + +from writerai._qs import Querystring, stringify + + +def test_empty() -> None: + assert stringify({}) == "" + assert stringify({"a": {}}) == "" + assert stringify({"a": {"b": {"c": {}}}}) == "" + + +def test_basic() -> None: + assert stringify({"a": 1}) == "a=1" + assert stringify({"a": "b"}) == "a=b" + assert stringify({"a": True}) == "a=true" + assert stringify({"a": False}) == "a=false" + assert stringify({"a": 1.23456}) == "a=1.23456" + assert stringify({"a": None}) == "" + + +@pytest.mark.parametrize("method", ["class", "function"]) +def test_nested_dotted(method: str) -> None: + if method == "class": + serialise = Querystring(nested_format="dots").stringify + else: + serialise = partial(stringify, nested_format="dots") + + assert unquote(serialise({"a": {"b": "c"}})) == "a.b=c" + assert unquote(serialise({"a": {"b": "c", "d": "e", "f": "g"}})) == "a.b=c&a.d=e&a.f=g" + assert unquote(serialise({"a": {"b": {"c": {"d": "e"}}}})) == "a.b.c.d=e" + assert unquote(serialise({"a": {"b": True}})) == "a.b=true" + + +def test_nested_brackets() -> None: + assert unquote(stringify({"a": {"b": "c"}})) == "a[b]=c" + assert unquote(stringify({"a": {"b": "c", "d": "e", "f": "g"}})) == "a[b]=c&a[d]=e&a[f]=g" + assert unquote(stringify({"a": {"b": {"c": {"d": "e"}}}})) == "a[b][c][d]=e" + assert unquote(stringify({"a": {"b": True}})) == "a[b]=true" + + +@pytest.mark.parametrize("method", ["class", "function"]) +def test_array_comma(method: str) -> None: + if method == "class": + serialise = Querystring(array_format="comma").stringify + else: + serialise = partial(stringify, array_format="comma") + + assert unquote(serialise({"in": ["foo", "bar"]})) == "in=foo,bar" + assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b]=true,false" + assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b]=true,false,true" + + +def test_array_repeat() -> None: + assert unquote(stringify({"in": ["foo", "bar"]})) == "in=foo&in=bar" + assert unquote(stringify({"a": {"b": [True, False]}})) == "a[b]=true&a[b]=false" + assert unquote(stringify({"a": {"b": [True, False, None, True]}})) == "a[b]=true&a[b]=false&a[b]=true" + assert unquote(stringify({"in": ["foo", {"b": {"c": ["d", "e"]}}]})) == "in=foo&in[b][c]=d&in[b][c]=e" + + +@pytest.mark.parametrize("method", ["class", "function"]) +def test_array_brackets(method: str) -> None: + if method == "class": + serialise = Querystring(array_format="brackets").stringify + else: + serialise = partial(stringify, array_format="brackets") + + assert unquote(serialise({"in": ["foo", "bar"]})) == "in[]=foo&in[]=bar" + assert unquote(serialise({"a": {"b": [True, False]}})) == "a[b][]=true&a[b][]=false" + assert unquote(serialise({"a": {"b": [True, False, None, True]}})) == "a[b][]=true&a[b][]=false&a[b][]=true" + + +def test_unknown_array_format() -> None: + with pytest.raises(NotImplementedError, match="Unknown array_format value: foo, choose from comma, repeat"): + stringify({"a": ["foo", "bar"]}, array_format=cast(Any, "foo")) diff --git a/tests/test_required_args.py b/tests/test_required_args.py new file mode 100644 index 00000000..db5f699c --- /dev/null +++ b/tests/test_required_args.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import pytest + +from writerai._utils import required_args + + +def test_too_many_positional_params() -> None: + @required_args(["a"]) + def foo(a: str | None = None) -> str | None: + return a + + with pytest.raises(TypeError, match=r"foo\(\) takes 1 argument\(s\) but 2 were given"): + foo("a", "b") # type: ignore + + +def test_positional_param() -> None: + @required_args(["a"]) + def foo(a: str | None = None) -> str | None: + return a + + assert foo("a") == "a" + assert foo(None) is None + assert foo(a="b") == "b" + + with pytest.raises(TypeError, match="Missing required argument: 'a'"): + foo() + + +def test_keyword_only_param() -> None: + @required_args(["a"]) + def foo(*, a: str | None = None) -> str | None: + return a + + assert foo(a="a") == "a" + assert foo(a=None) is None + assert foo(a="b") == "b" + + with pytest.raises(TypeError, match="Missing required argument: 'a'"): + foo() + + +def test_multiple_params() -> None: + @required_args(["a", "b", "c"]) + def foo(a: str = "", *, b: str = "", c: str = "") -> str | None: + return f"{a} {b} {c}" + + assert foo(a="a", b="b", c="c") == "a b c" + + error_message = r"Missing required arguments.*" + + with pytest.raises(TypeError, match=error_message): + foo() + + with pytest.raises(TypeError, match=error_message): + foo(a="a") + + with pytest.raises(TypeError, match=error_message): + foo(b="b") + + with pytest.raises(TypeError, match=error_message): + foo(c="c") + + with pytest.raises(TypeError, match=r"Missing required argument: 'a'"): + foo(b="a", c="c") + + with pytest.raises(TypeError, match=r"Missing required argument: 'b'"): + foo("a", c="c") + + +def test_multiple_variants() -> None: + @required_args(["a"], ["b"]) + def foo(*, a: str | None = None, b: str | None = None) -> str | None: + return a if a is not None else b + + assert foo(a="foo") == "foo" + assert foo(b="bar") == "bar" + assert foo(a=None) is None + assert foo(b=None) is None + + # TODO: this error message could probably be improved + with pytest.raises( + TypeError, + match=r"Missing required arguments; Expected either \('a'\) or \('b'\) arguments to be given", + ): + foo() + + +def test_multiple_params_multiple_variants() -> None: + @required_args(["a", "b"], ["c"]) + def foo(*, a: str | None = None, b: str | None = None, c: str | None = None) -> str | None: + if a is not None: + return a + if b is not None: + return b + return c + + error_message = r"Missing required arguments; Expected either \('a' and 'b'\) or \('c'\) arguments to be given" + + with pytest.raises(TypeError, match=error_message): + foo(a="foo") + + with pytest.raises(TypeError, match=error_message): + foo(b="bar") + + with pytest.raises(TypeError, match=error_message): + foo() + + assert foo(a=None, b="bar") == "bar" + assert foo(c=None) is None + assert foo(c="foo") == "foo" diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 00000000..dcbaece2 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,277 @@ +import json +from typing import Any, List, Union, cast +from typing_extensions import Annotated + +import httpx +import pytest +import pydantic + +from writerai import Writer, BaseModel, AsyncWriter +from writerai._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + BinaryAPIResponse, + AsyncBinaryAPIResponse, + extract_response_type, +) +from writerai._streaming import Stream +from writerai._base_client import FinalRequestOptions + + +class ConcreteBaseAPIResponse(APIResponse[bytes]): ... + + +class ConcreteAPIResponse(APIResponse[List[str]]): ... + + +class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): ... + + +def test_extract_response_type_direct_classes() -> None: + assert extract_response_type(BaseAPIResponse[str]) == str + assert extract_response_type(APIResponse[str]) == str + assert extract_response_type(AsyncAPIResponse[str]) == str + + +def test_extract_response_type_direct_class_missing_type_arg() -> None: + with pytest.raises( + RuntimeError, + match="Expected type to have a type argument at index 0 but it did not", + ): + extract_response_type(AsyncAPIResponse) + + +def test_extract_response_type_concrete_subclasses() -> None: + assert extract_response_type(ConcreteBaseAPIResponse) == bytes + assert extract_response_type(ConcreteAPIResponse) == List[str] + assert extract_response_type(ConcreteAsyncAPIResponse) == httpx.Response + + +def test_extract_response_type_binary_response() -> None: + assert extract_response_type(BinaryAPIResponse) == bytes + assert extract_response_type(AsyncBinaryAPIResponse) == bytes + + +class PydanticModel(pydantic.BaseModel): ... + + +def test_response_parse_mismatched_basemodel(client: Writer) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from writerai import BaseModel`", + ): + response.parse(to=PydanticModel) + + +@pytest.mark.asyncio +async def test_async_response_parse_mismatched_basemodel(async_client: AsyncWriter) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from writerai import BaseModel`", + ): + await response.parse(to=PydanticModel) + + +def test_response_parse_custom_stream(client: Writer) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = response.parse(to=Stream[int]) + assert stream._cast_to == int + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_stream(async_client: AsyncWriter) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = await response.parse(to=Stream[int]) + assert stream._cast_to == int + + +class CustomModel(BaseModel): + foo: str + bar: int + + +def test_response_parse_custom_model(client: Writer) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_model(async_client: AsyncWriter) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +def test_response_parse_annotated_type(client: Writer) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +async def test_async_response_parse_annotated_type(async_client: AsyncWriter) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +def test_response_parse_bool(client: Writer, content: str, expected: bool) -> None: + response = APIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = response.parse(to=bool) + assert result is expected + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +async def test_async_response_parse_bool(client: AsyncWriter, content: str, expected: bool) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = await response.parse(to=bool) + assert result is expected + + +class OtherModel(BaseModel): + a: str + + +@pytest.mark.parametrize("client", [False], indirect=True) # loose validation +def test_response_parse_expect_model_union_non_json_content(client: Writer) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=cast(Any, Union[CustomModel, OtherModel])) + assert isinstance(obj, str) + assert obj == "foo" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_client", [False], indirect=True) # loose validation +async def test_async_response_parse_expect_model_union_non_json_content(async_client: AsyncWriter) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo", headers={"Content-Type": "application/text"}), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=cast(Any, Union[CustomModel, OtherModel])) + assert isinstance(obj, str) + assert obj == "foo" diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 00000000..1c1f6d23 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from typing import Iterator, AsyncIterator + +import httpx +import pytest + +from writerai import Writer, AsyncWriter +from writerai._streaming import Stream, AsyncStream, ServerSentEvent + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_basic(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: completion\n" + yield b'data: {"foo":true}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "completion" + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_missing_event(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"foo":true}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_event_missing_data(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.data == "" + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" + yield b"event: completion\n" + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.data == "" + + sse = await iter_next(iterator) + assert sse.event == "completion" + assert sse.data == "" + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events_with_data(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo":true}\n' + yield b"\n" + yield b"event: completion\n" + yield b'data: {"bar":false}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + sse = await iter_next(iterator) + assert sse.event == "completion" + assert sse.json() == {"bar": False} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines_with_empty_line(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: \n" + yield b"data:\n" + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + assert sse.data == '{\n"foo":\n\n\ntrue}' + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_json_escaped_double_new_line(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo": "my long\\n\\ncontent"}' + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": "my long\n\ncontent"} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines(sync: bool, client: Writer, async_client: AsyncWriter) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_special_new_line_character( + sync: bool, + client: Writer, + async_client: AsyncWriter, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":" culpa"}\n' + yield b"\n" + yield b'data: {"content":" \xe2\x80\xa8"}\n' + yield b"\n" + yield b'data: {"content":"foo"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " culpa"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " 
"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "foo"} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multi_byte_character_multiple_chunks( + sync: bool, + client: Writer, + async_client: AsyncWriter, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":"' + # bytes taken from the string 'известни' and arbitrarily split + # so that some multi-byte characters span multiple chunks + yield b"\xd0" + yield b"\xb8\xd0\xb7\xd0" + yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8" + yield b'"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "известни"} + + +async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: + for chunk in iter: + yield chunk + + +async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent: + if isinstance(iter, AsyncIterator): + return await iter.__anext__() + + return next(iter) + + +async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None: + with pytest.raises((StopAsyncIteration, RuntimeError)): + await iter_next(iter) + + +def make_event_iterator( + content: Iterator[bytes], + *, + sync: bool, + client: Writer, + async_client: AsyncWriter, +) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]: + if sync: + return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events() + + return AsyncStream( + cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)) + )._iter_events() diff --git a/tests/test_transform.py b/tests/test_transform.py new file mode 100644 index 00000000..8fb16570 --- /dev/null +++ b/tests/test_transform.py @@ -0,0 +1,460 @@ +from __future__ import annotations + +import io +import pathlib +from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast +from datetime import date, datetime +from typing_extensions import Required, Annotated, TypedDict + +import pytest + +from writerai._types import Base64FileInput, omit, not_given +from writerai._utils import ( + PropertyInfo, + transform as _transform, + parse_datetime, + async_transform as _async_transform, +) +from writerai._compat import PYDANTIC_V1 +from writerai._models import BaseModel + +_T = TypeVar("_T") + +SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt") + + +async def transform( + data: _T, + expected_type: object, + use_async: bool, +) -> _T: + if use_async: + return await _async_transform(data, expected_type=expected_type) + + return _transform(data, expected_type=expected_type) + + +parametrize = pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) + + +class Foo1(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +@parametrize +@pytest.mark.asyncio +async def test_top_level_alias(use_async: bool) -> None: + assert await transform({"foo_bar": "hello"}, expected_type=Foo1, use_async=use_async) == {"fooBar": "hello"} + + +class Foo2(TypedDict): + bar: Bar2 + + +class Bar2(TypedDict): + this_thing: Annotated[int, PropertyInfo(alias="this__thing")] + baz: Annotated[Baz2, PropertyInfo(alias="Baz")] + + +class Baz2(TypedDict): + my_baz: Annotated[str, PropertyInfo(alias="myBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_recursive_typeddict(use_async: bool) -> None: + assert await transform({"bar": {"this_thing": 1}}, Foo2, use_async) == {"bar": {"this__thing": 1}} + assert await transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2, use_async) == {"bar": {"Baz": {"myBaz": "foo"}}} + + +class Foo3(TypedDict): + things: List[Bar3] + + +class Bar3(TypedDict): + my_field: Annotated[str, PropertyInfo(alias="myField")] + + +@parametrize +@pytest.mark.asyncio +async def test_list_of_typeddict(use_async: bool) -> None: + result = await transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, Foo3, use_async) + assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]} + + +class Foo4(TypedDict): + foo: Union[Bar4, Baz4] + + +class Bar4(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz4(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_union_of_typeddict(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo4, use_async) == {"foo": {"fooBar": "bar"}} + assert await transform({"foo": {"foo_baz": "baz"}}, Foo4, use_async) == {"foo": {"fooBaz": "baz"}} + assert await transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4, use_async) == { + "foo": {"fooBaz": "baz", "fooBar": "bar"} + } + + +class Foo5(TypedDict): + foo: Annotated[Union[Bar4, List[Baz4]], PropertyInfo(alias="FOO")] + + +class Bar5(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz5(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_union_of_list(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo5, use_async) == {"FOO": {"fooBar": "bar"}} + assert await transform( + { + "foo": [ + {"foo_baz": "baz"}, + {"foo_baz": "baz"}, + ] + }, + Foo5, + use_async, + ) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]} + + +class Foo6(TypedDict): + bar: Annotated[str, PropertyInfo(alias="Bar")] + + +@parametrize +@pytest.mark.asyncio +async def test_includes_unknown_keys(use_async: bool) -> None: + assert await transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6, use_async) == { + "Bar": "bar", + "baz_": {"FOO": 1}, + } + + +class Foo7(TypedDict): + bar: Annotated[List[Bar7], PropertyInfo(alias="bAr")] + foo: Bar7 + + +class Bar7(TypedDict): + foo: str + + +@parametrize +@pytest.mark.asyncio +async def test_ignores_invalid_input(use_async: bool) -> None: + assert await transform({"bar": ""}, Foo7, use_async) == {"bAr": ""} + assert await transform({"foo": ""}, Foo7, use_async) == {"foo": ""} + + +class DatetimeDict(TypedDict, total=False): + foo: Annotated[datetime, PropertyInfo(format="iso8601")] + + bar: Annotated[Optional[datetime], PropertyInfo(format="iso8601")] + + required: Required[Annotated[Optional[datetime], PropertyInfo(format="iso8601")]] + + list_: Required[Annotated[Optional[List[datetime]], PropertyInfo(format="iso8601")]] + + union: Annotated[Union[int, datetime], PropertyInfo(format="iso8601")] + + +class DateDict(TypedDict, total=False): + foo: Annotated[date, PropertyInfo(format="iso8601")] + + +class DatetimeModel(BaseModel): + foo: datetime + + +class DateModel(BaseModel): + foo: Optional[date] + + +@parametrize +@pytest.mark.asyncio +async def test_iso8601_format(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + tz = "+00:00" if PYDANTIC_V1 else "Z" + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap] + + dt = dt.replace(tzinfo=None) + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + + assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore + assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == { + "foo": "2023-02-23" + } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_optional_iso8601_format(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + assert await transform({"bar": dt}, DatetimeDict, use_async) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + + assert await transform({"bar": None}, DatetimeDict, use_async) == {"bar": None} + + +@parametrize +@pytest.mark.asyncio +async def test_required_iso8601_format(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + assert await transform({"required": dt}, DatetimeDict, use_async) == { + "required": "2023-02-23T14:16:36.337692+00:00" + } # type: ignore[comparison-overlap] + + assert await transform({"required": None}, DatetimeDict, use_async) == {"required": None} + + +@parametrize +@pytest.mark.asyncio +async def test_union_datetime(use_async: bool) -> None: + dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + assert await transform({"union": dt}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] + "union": "2023-02-23T14:16:36.337692+00:00" + } + + assert await transform({"union": "foo"}, DatetimeDict, use_async) == {"union": "foo"} + + +@parametrize +@pytest.mark.asyncio +async def test_nested_list_iso6801_format(use_async: bool) -> None: + dt1 = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") + dt2 = parse_datetime("2022-01-15T06:34:23Z") + assert await transform({"list_": [dt1, dt2]}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] + "list_": ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"] + } + + +@parametrize +@pytest.mark.asyncio +async def test_datetime_custom_format(use_async: bool) -> None: + dt = parse_datetime("2022-01-15T06:34:23Z") + + result = await transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")], use_async) + assert result == "06" # type: ignore[comparison-overlap] + + +class DateDictWithRequiredAlias(TypedDict, total=False): + required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]] + + +@parametrize +@pytest.mark.asyncio +async def test_datetime_with_alias(use_async: bool) -> None: + assert await transform({"required_prop": None}, DateDictWithRequiredAlias, use_async) == {"prop": None} # type: ignore[comparison-overlap] + assert await transform( + {"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias, use_async + ) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap] + + +class MyModel(BaseModel): + foo: str + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_model_to_dictionary(use_async: bool) -> None: + assert cast(Any, await transform(MyModel(foo="hi!"), Any, use_async)) == {"foo": "hi!"} + assert cast(Any, await transform(MyModel.construct(foo="hi!"), Any, use_async)) == {"foo": "hi!"} + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_empty_model(use_async: bool) -> None: + assert cast(Any, await transform(MyModel.construct(), Any, use_async)) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_unknown_field(use_async: bool) -> None: + assert cast(Any, await transform(MyModel.construct(my_untyped_field=True), Any, use_async)) == { + "my_untyped_field": True + } + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_types(use_async: bool) -> None: + model = MyModel.construct(foo=True) + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: + with pytest.warns(UserWarning): + params = await transform(model, Any, use_async) + assert cast(Any, params) == {"foo": True} + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_object_type(use_async: bool) -> None: + model = MyModel.construct(foo=MyModel.construct(hello="world")) + if PYDANTIC_V1: + params = await transform(model, Any, use_async) + else: + with pytest.warns(UserWarning): + params = await transform(model, Any, use_async) + assert cast(Any, params) == {"foo": {"hello": "world"}} + + +class ModelNestedObjects(BaseModel): + nested: MyModel + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_nested_objects(use_async: bool) -> None: + model = ModelNestedObjects.construct(nested={"foo": "stainless"}) + assert isinstance(model.nested, MyModel) + assert cast(Any, await transform(model, Any, use_async)) == {"nested": {"foo": "stainless"}} + + +class ModelWithDefaultField(BaseModel): + foo: str + with_none_default: Union[str, None] = None + with_str_default: str = "foo" + + +@parametrize +@pytest.mark.asyncio +async def test_pydantic_default_field(use_async: bool) -> None: + # should be excluded when defaults are used + model = ModelWithDefaultField.construct() + assert model.with_none_default is None + assert model.with_str_default == "foo" + assert cast(Any, await transform(model, Any, use_async)) == {} + + # should be included when the default value is explicitly given + model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo") + assert model.with_none_default is None + assert model.with_str_default == "foo" + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": None, "with_str_default": "foo"} + + # should be included when a non-default value is explicitly given + model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz") + assert model.with_none_default == "bar" + assert model.with_str_default == "baz" + assert cast(Any, await transform(model, Any, use_async)) == {"with_none_default": "bar", "with_str_default": "baz"} + + +class TypedDictIterableUnion(TypedDict): + foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +class Bar8(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz8(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +@parametrize +@pytest.mark.asyncio +async def test_iterable_of_dictionaries(use_async: bool) -> None: + assert await transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "bar"}] + } + assert cast(Any, await transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion, use_async)) == { + "FOO": [{"fooBaz": "bar"}] + } + + def my_iter() -> Iterable[Baz8]: + yield {"foo_baz": "hello"} + yield {"foo_baz": "world"} + + assert await transform({"foo": my_iter()}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}] + } + + +@parametrize +@pytest.mark.asyncio +async def test_dictionary_items(use_async: bool) -> None: + class DictItems(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}} + + +class TypedDictIterableUnionStr(TypedDict): + foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +@parametrize +@pytest.mark.asyncio +async def test_iterable_union_str(use_async: bool) -> None: + assert await transform({"foo": "bar"}, TypedDictIterableUnionStr, use_async) == {"FOO": "bar"} + assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [ + {"fooBaz": "bar"} + ] + + +class TypedDictBase64Input(TypedDict): + foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")] + + +@parametrize +@pytest.mark.asyncio +async def test_base64_file_input(use_async: bool) -> None: + # strings are left as-is + assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"} + + # pathlib.Path is automatically converted to base64 + assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQo=" + } # type: ignore[comparison-overlap] + + # io instances are automatically converted to base64 + assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_transform_skipping(use_async: bool) -> None: + # lists of ints are left as-is + data = [1, 2, 3] + assert await transform(data, List[int], use_async) is data + + # iterables of ints are converted to a list + data = iter([1, 2, 3]) + assert await transform(data, Iterable[int], use_async) == [1, 2, 3] + + +@parametrize +@pytest.mark.asyncio +async def test_strips_notgiven(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": not_given}, Foo1, use_async) == {} + + +@parametrize +@pytest.mark.asyncio +async def test_strips_omit(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": omit}, Foo1, use_async) == {} diff --git a/tests/test_utils/test_datetime_parse.py b/tests/test_utils/test_datetime_parse.py new file mode 100644 index 00000000..ba09afb5 --- /dev/null +++ b/tests/test_utils/test_datetime_parse.py @@ -0,0 +1,110 @@ +""" +Copied from https://github.com/pydantic/pydantic/blob/v1.10.22/tests/test_datetime_parse.py +with modifications so it works without pydantic v1 imports. +""" + +from typing import Type, Union +from datetime import date, datetime, timezone, timedelta + +import pytest + +from writerai._utils import parse_date, parse_datetime + + +def create_tz(minutes: int) -> timezone: + return timezone(timedelta(minutes=minutes)) + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + ("1494012444.883309", date(2017, 5, 5)), + (b"1494012444.883309", date(2017, 5, 5)), + (1_494_012_444.883_309, date(2017, 5, 5)), + ("1494012444", date(2017, 5, 5)), + (1_494_012_444, date(2017, 5, 5)), + (0, date(1970, 1, 1)), + ("2012-04-23", date(2012, 4, 23)), + (b"2012-04-23", date(2012, 4, 23)), + ("2012-4-9", date(2012, 4, 9)), + (date(2012, 4, 9), date(2012, 4, 9)), + (datetime(2012, 4, 9, 12, 15), date(2012, 4, 9)), + # Invalid inputs + ("x20120423", ValueError), + ("2012-04-56", ValueError), + (19_999_999_999, date(2603, 10, 11)), # just before watershed + (20_000_000_001, date(1970, 8, 20)), # just after watershed + (1_549_316_052, date(2019, 2, 4)), # nowish in s + (1_549_316_052_104, date(2019, 2, 4)), # nowish in ms + (1_549_316_052_104_324, date(2019, 2, 4)), # nowish in μs + (1_549_316_052_104_324_096, date(2019, 2, 4)), # nowish in ns + ("infinity", date(9999, 12, 31)), + ("inf", date(9999, 12, 31)), + (float("inf"), date(9999, 12, 31)), + ("infinity ", date(9999, 12, 31)), + (int("1" + "0" * 100), date(9999, 12, 31)), + (1e1000, date(9999, 12, 31)), + ("-infinity", date(1, 1, 1)), + ("-inf", date(1, 1, 1)), + ("nan", ValueError), + ], +) +def test_date_parsing(value: Union[str, bytes, int, float], result: Union[date, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_date(value) + else: + assert parse_date(value) == result + + +@pytest.mark.parametrize( + "value,result", + [ + # Valid inputs + # values in seconds + ("1494012444.883309", datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + (1_494_012_444.883_309, datetime(2017, 5, 5, 19, 27, 24, 883_309, tzinfo=timezone.utc)), + ("1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (b"1494012444", datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + (1_494_012_444, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + # values in ms + ("1494012444000.883309", datetime(2017, 5, 5, 19, 27, 24, 883, tzinfo=timezone.utc)), + ("-1494012444000.883309", datetime(1922, 8, 29, 4, 32, 35, 999117, tzinfo=timezone.utc)), + (1_494_012_444_000, datetime(2017, 5, 5, 19, 27, 24, tzinfo=timezone.utc)), + ("2012-04-23T09:15:00", datetime(2012, 4, 23, 9, 15)), + ("2012-4-9 4:8:16", datetime(2012, 4, 9, 4, 8, 16)), + ("2012-04-23T09:15:00Z", datetime(2012, 4, 23, 9, 15, 0, 0, timezone.utc)), + ("2012-4-9 4:8:16-0320", datetime(2012, 4, 9, 4, 8, 16, 0, create_tz(-200))), + ("2012-04-23T10:20:30.400+02:30", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(150))), + ("2012-04-23T10:20:30.400+02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(120))), + ("2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (b"2012-04-23T10:20:30.400-02", datetime(2012, 4, 23, 10, 20, 30, 400_000, create_tz(-120))), + (datetime(2017, 5, 5), datetime(2017, 5, 5)), + (0, datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc)), + # Invalid inputs + ("x20120423091500", ValueError), + ("2012-04-56T09:15:90", ValueError), + ("2012-04-23T11:05:00-25:00", ValueError), + (19_999_999_999, datetime(2603, 10, 11, 11, 33, 19, tzinfo=timezone.utc)), # just before watershed + (20_000_000_001, datetime(1970, 8, 20, 11, 33, 20, 1000, tzinfo=timezone.utc)), # just after watershed + (1_549_316_052, datetime(2019, 2, 4, 21, 34, 12, 0, tzinfo=timezone.utc)), # nowish in s + (1_549_316_052_104, datetime(2019, 2, 4, 21, 34, 12, 104_000, tzinfo=timezone.utc)), # nowish in ms + (1_549_316_052_104_324, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in μs + (1_549_316_052_104_324_096, datetime(2019, 2, 4, 21, 34, 12, 104_324, tzinfo=timezone.utc)), # nowish in ns + ("infinity", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf", datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("inf ", datetime(9999, 12, 31, 23, 59, 59, 999999)), + (1e50, datetime(9999, 12, 31, 23, 59, 59, 999999)), + (float("inf"), datetime(9999, 12, 31, 23, 59, 59, 999999)), + ("-infinity", datetime(1, 1, 1, 0, 0)), + ("-inf", datetime(1, 1, 1, 0, 0)), + ("nan", ValueError), + ], +) +def test_datetime_parsing(value: Union[str, bytes, int, float], result: Union[datetime, Type[Exception]]) -> None: + if type(result) == type and issubclass(result, Exception): # pyright: ignore[reportUnnecessaryIsInstance] + with pytest.raises(result): + parse_datetime(value) + else: + assert parse_datetime(value) == result diff --git a/tests/test_utils/test_json.py b/tests/test_utils/test_json.py new file mode 100644 index 00000000..9d63c901 --- /dev/null +++ b/tests/test_utils/test_json.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import datetime +from typing import Union + +import pydantic + +from writerai import _compat +from writerai._utils._json import openapi_dumps + + +class TestOpenapiDumps: + def test_basic(self) -> None: + data = {"key": "value", "number": 42} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"key":"value","number":42}' + + def test_datetime_serialization(self) -> None: + dt = datetime.datetime(2023, 1, 1, 12, 0, 0) + data = {"datetime": dt} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"datetime":"2023-01-01T12:00:00"}' + + def test_pydantic_model_serialization(self) -> None: + class User(pydantic.BaseModel): + first_name: str + last_name: str + age: int + + model_instance = User(first_name="John", last_name="Kramer", age=83) + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"first_name":"John","last_name":"Kramer","age":83}}' + + def test_pydantic_model_with_default_values(self) -> None: + class User(pydantic.BaseModel): + name: str + role: str = "user" + active: bool = True + score: int = 0 + + model_instance = User(name="Alice") + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Alice"}}' + + def test_pydantic_model_with_default_values_overridden(self) -> None: + class User(pydantic.BaseModel): + name: str + role: str = "user" + active: bool = True + + model_instance = User(name="Bob", role="admin", active=False) + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Bob","role":"admin","active":false}}' + + def test_pydantic_model_with_alias(self) -> None: + class User(pydantic.BaseModel): + first_name: str = pydantic.Field(alias="firstName") + last_name: str = pydantic.Field(alias="lastName") + + model_instance = User(firstName="John", lastName="Doe") + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"firstName":"John","lastName":"Doe"}}' + + def test_pydantic_model_with_alias_and_default(self) -> None: + class User(pydantic.BaseModel): + user_name: str = pydantic.Field(alias="userName") + user_role: str = pydantic.Field(default="member", alias="userRole") + is_active: bool = pydantic.Field(default=True, alias="isActive") + + model_instance = User(userName="charlie") + data = {"model": model_instance} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"userName":"charlie"}}' + + model_with_overrides = User(userName="diana", userRole="admin", isActive=False) + data = {"model": model_with_overrides} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"userName":"diana","userRole":"admin","isActive":false}}' + + def test_pydantic_model_with_nested_models_and_defaults(self) -> None: + class Address(pydantic.BaseModel): + street: str + city: str = "Unknown" + + class User(pydantic.BaseModel): + name: str + address: Address + verified: bool = False + + if _compat.PYDANTIC_V1: + # to handle forward references in Pydantic v1 + User.update_forward_refs(**locals()) # type: ignore[reportDeprecated] + + address = Address(street="123 Main St") + user = User(name="Diana", address=address) + data = {"user": user} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"user":{"name":"Diana","address":{"street":"123 Main St"}}}' + + address_with_city = Address(street="456 Oak Ave", city="Boston") + user_verified = User(name="Eve", address=address_with_city, verified=True) + data = {"user": user_verified} + json_bytes = openapi_dumps(data) + assert ( + json_bytes == b'{"user":{"name":"Eve","address":{"street":"456 Oak Ave","city":"Boston"},"verified":true}}' + ) + + def test_pydantic_model_with_optional_fields(self) -> None: + class User(pydantic.BaseModel): + name: str + email: Union[str, None] + phone: Union[str, None] + + model_with_none = User(name="Eve", email=None, phone=None) + data = {"model": model_with_none} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Eve","email":null,"phone":null}}' + + model_with_values = User(name="Frank", email="frank@example.com", phone=None) + data = {"model": model_with_values} + json_bytes = openapi_dumps(data) + assert json_bytes == b'{"model":{"name":"Frank","email":"frank@example.com","phone":null}}' diff --git a/tests/test_utils/test_path.py b/tests/test_utils/test_path.py new file mode 100644 index 00000000..b42e3d87 --- /dev/null +++ b/tests/test_utils/test_path.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from writerai._utils._path import path_template + + +@pytest.mark.parametrize( + "template, kwargs, expected", + [ + ("/v1/{id}", dict(id="abc"), "/v1/abc"), + ("/v1/{a}/{b}", dict(a="x", b="y"), "/v1/x/y"), + ("/v1/{a}{b}/path/{c}?val={d}#{e}", dict(a="x", b="y", c="z", d="u", e="v"), "/v1/xy/path/z?val=u#v"), + ("/{w}/{w}", dict(w="echo"), "/echo/echo"), + ("/v1/static", {}, "/v1/static"), + ("", {}, ""), + ("/v1/?q={n}&count=10", dict(n=42), "/v1/?q=42&count=10"), + ("/v1/{v}", dict(v=None), "/v1/null"), + ("/v1/{v}", dict(v=True), "/v1/true"), + ("/v1/{v}", dict(v=False), "/v1/false"), + ("/v1/{v}", dict(v=".hidden"), "/v1/.hidden"), # dot prefix ok + ("/v1/{v}", dict(v="file.txt"), "/v1/file.txt"), # dot in middle ok + ("/v1/{v}", dict(v="..."), "/v1/..."), # triple dot ok + ("/v1/{a}{b}", dict(a=".", b="txt"), "/v1/.txt"), # dot var combining with adjacent to be ok + ("/items?q={v}#{f}", dict(v=".", f=".."), "/items?q=.#.."), # dots in query/fragment are fine + ( + "/v1/{a}?query={b}", + dict(a="../../other/endpoint", b="a&bad=true"), + "/v1/..%2F..%2Fother%2Fendpoint?query=a%26bad%3Dtrue", + ), + ("/v1/{val}", dict(val="a/b/c"), "/v1/a%2Fb%2Fc"), + ("/v1/{val}", dict(val="a/b/c?query=value"), "/v1/a%2Fb%2Fc%3Fquery=value"), + ("/v1/{val}", dict(val="a/b/c?query=value&bad=true"), "/v1/a%2Fb%2Fc%3Fquery=value&bad=true"), + ("/v1/{val}", dict(val="%20"), "/v1/%2520"), # escapes escape sequences in input + # Query: slash and ? are safe, # is not + ("/items?q={v}", dict(v="a/b"), "/items?q=a/b"), + ("/items?q={v}", dict(v="a?b"), "/items?q=a?b"), + ("/items?q={v}", dict(v="a#b"), "/items?q=a%23b"), + ("/items?q={v}", dict(v="a b"), "/items?q=a%20b"), + # Fragment: slash and ? are safe + ("/docs#{v}", dict(v="a/b"), "/docs#a/b"), + ("/docs#{v}", dict(v="a?b"), "/docs#a?b"), + # Path: slash, ? and # are all encoded + ("/v1/{v}", dict(v="a/b"), "/v1/a%2Fb"), + ("/v1/{v}", dict(v="a?b"), "/v1/a%3Fb"), + ("/v1/{v}", dict(v="a#b"), "/v1/a%23b"), + # same var encoded differently by component + ( + "/v1/{v}?q={v}#{v}", + dict(v="a/b?c#d"), + "/v1/a%2Fb%3Fc%23d?q=a/b?c%23d#a/b?c%23d", + ), + ("/v1/{val}", dict(val="x?admin=true"), "/v1/x%3Fadmin=true"), # query injection + ("/v1/{val}", dict(val="x#admin"), "/v1/x%23admin"), # fragment injection + ], +) +def test_interpolation(template: str, kwargs: dict[str, Any], expected: str) -> None: + assert path_template(template, **kwargs) == expected + + +def test_missing_kwarg_raises_key_error() -> None: + with pytest.raises(KeyError, match="org_id"): + path_template("/v1/{org_id}") + + +@pytest.mark.parametrize( + "template, kwargs", + [ + ("{a}/path", dict(a=".")), + ("{a}/path", dict(a="..")), + ("/v1/{a}", dict(a=".")), + ("/v1/{a}", dict(a="..")), + ("/v1/{a}/path", dict(a=".")), + ("/v1/{a}/path", dict(a="..")), + ("/v1/{a}{b}", dict(a=".", b=".")), # adjacent vars → ".." + ("/v1/{a}.", dict(a=".")), # var + static → ".." + ("/v1/{a}{b}", dict(a="", b=".")), # empty + dot → "." + ("/v1/%2e/{x}", dict(x="ok")), # encoded dot in static text + ("/v1/%2e./{x}", dict(x="ok")), # mixed encoded ".." in static + ("/v1/.%2E/{x}", dict(x="ok")), # mixed encoded ".." in static + ("/v1/{v}?q=1", dict(v="..")), + ("/v1/{v}#frag", dict(v="..")), + ], +) +def test_dot_segment_rejected(template: str, kwargs: dict[str, Any]) -> None: + with pytest.raises(ValueError, match="dot-segment"): + path_template(template, **kwargs) diff --git a/tests/test_utils/test_proxy.py b/tests/test_utils/test_proxy.py new file mode 100644 index 00000000..8d75d83d --- /dev/null +++ b/tests/test_utils/test_proxy.py @@ -0,0 +1,34 @@ +import operator +from typing import Any +from typing_extensions import override + +from writerai._utils import LazyProxy + + +class RecursiveLazyProxy(LazyProxy[Any]): + @override + def __load__(self) -> Any: + return self + + def __call__(self, *_args: Any, **_kwds: Any) -> Any: + raise RuntimeError("This should never be called!") + + +def test_recursive_proxy() -> None: + proxy = RecursiveLazyProxy() + assert repr(proxy) == "RecursiveLazyProxy" + assert str(proxy) == "RecursiveLazyProxy" + assert dir(proxy) == [] + assert type(proxy).__name__ == "RecursiveLazyProxy" + assert type(operator.attrgetter("name.foo.bar.baz")(proxy)).__name__ == "RecursiveLazyProxy" + + +def test_isinstance_does_not_error() -> None: + class AlwaysErrorProxy(LazyProxy[Any]): + @override + def __load__(self) -> Any: + raise RuntimeError("Mocking missing dependency") + + proxy = AlwaysErrorProxy() + assert not isinstance(proxy, dict) + assert isinstance(proxy, LazyProxy) diff --git a/tests/test_utils/test_typing.py b/tests/test_utils/test_typing.py new file mode 100644 index 00000000..8237ec21 --- /dev/null +++ b/tests/test_utils/test_typing.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, cast + +from writerai._utils import extract_type_var_from_base + +_T = TypeVar("_T") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") + + +class BaseGeneric(Generic[_T]): ... + + +class SubclassGeneric(BaseGeneric[_T]): ... + + +class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): ... + + +class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): ... + + +class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): ... + + +def test_extract_type_var() -> None: + assert ( + extract_type_var_from_base( + BaseGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_generic_subclass() -> None: + assert ( + extract_type_var_from_base( + SubclassGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_multiple() -> None: + typ = BaseGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_multiple() -> None: + typ = SubclassGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_different_ordering_multiple() -> None: + typ = SubclassDifferentOrderGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..7e3ba26a --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import os +import inspect +import traceback +import contextlib +from typing import Any, TypeVar, Iterator, Sequence, cast +from datetime import date, datetime +from typing_extensions import Literal, get_args, get_origin, assert_type + +from writerai._types import Omit, NoneType +from writerai._utils import ( + is_dict, + is_list, + is_list_type, + is_union_type, + extract_type_arg, + is_sequence_type, + is_annotated_type, + is_type_alias_type, +) +from writerai._compat import PYDANTIC_V1, field_outer_type, get_model_fields +from writerai._models import BaseModel + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) + + +def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool: + for name, field in get_model_fields(model).items(): + field_value = getattr(value, name) + if PYDANTIC_V1: + # in v1 nullability was structured differently + # https://docs.pydantic.dev/2.0/migration/#required-optional-and-nullable-fields + allow_none = getattr(field, "allow_none", False) + else: + allow_none = False + + assert_matches_type( + field_outer_type(field), + field_value, + path=[*path, name], + allow_none=allow_none, + ) + + return True + + +# Note: the `path` argument is only used to improve error messages when `--showlocals` is used +def assert_matches_type( + type_: Any, + value: object, + *, + path: list[str], + allow_none: bool = False, +) -> None: + if is_type_alias_type(type_): + type_ = type_.__value__ + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + type_ = extract_type_arg(type_, 0) + + if allow_none and value is None: + return + + if type_ is None or type_ is NoneType: + assert value is None + return + + origin = get_origin(type_) or type_ + + if is_list_type(type_): + return _assert_list_type(type_, value) + + if is_sequence_type(type_): + assert isinstance(value, Sequence) + inner_type = get_args(type_)[0] + for entry in value: # type: ignore + assert_type(inner_type, entry) # type: ignore + return + + if origin == str: + assert isinstance(value, str) + elif origin == int: + assert isinstance(value, int) + elif origin == bool: + assert isinstance(value, bool) + elif origin == float: + assert isinstance(value, float) + elif origin == bytes: + assert isinstance(value, bytes) + elif origin == datetime: + assert isinstance(value, datetime) + elif origin == date: + assert isinstance(value, date) + elif origin == object: + # nothing to do here, the expected type is unknown + pass + elif origin == Literal: + assert value in get_args(type_) + elif origin == dict: + assert is_dict(value) + + args = get_args(type_) + key_type = args[0] + items_type = args[1] + + for key, item in value.items(): + assert_matches_type(key_type, key, path=[*path, ""]) + assert_matches_type(items_type, item, path=[*path, ""]) + elif is_union_type(type_): + variants = get_args(type_) + + try: + none_index = variants.index(type(None)) + except ValueError: + pass + else: + # special case Optional[T] for better error messages + if len(variants) == 2: + if value is None: + # valid + return + + return assert_matches_type(type_=variants[not none_index], value=value, path=path) + + for i, variant in enumerate(variants): + try: + assert_matches_type(variant, value, path=[*path, f"variant {i}"]) + return + except AssertionError: + traceback.print_exc() + continue + + raise AssertionError("Did not match any variants") + elif issubclass(origin, BaseModel): + assert isinstance(value, type_) + assert assert_matches_model(type_, cast(Any, value), path=path) + elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent": + assert value.__class__.__name__ == "HttpxBinaryResponseContent" + else: + assert None, f"Unhandled field type: {type_}" + + +def _assert_list_type(type_: type[object], value: object) -> None: + assert is_list(value) + + inner_type = get_args(type_)[0] + for entry in value: + assert_type(inner_type, entry) # type: ignore + + +@contextlib.contextmanager +def update_env(**new_env: str | Omit) -> Iterator[None]: + old = os.environ.copy() + + try: + for name, value in new_env.items(): + if isinstance(value, Omit): + os.environ.pop(name, None) + else: + os.environ[name] = value + + yield None + finally: + os.environ.clear() + os.environ.update(old)