Skip to content

Commit

Permalink
Short examples and ux fixes (#1508)
Browse files Browse the repository at this point in the history
This PR adds `examples/short_examples.ipynb` that runs through quick
examples on using agent0 to:
- Analyze positions
- Close all mature positions
- Write a quick policy

This involves the following changes:
- Agents now can be passed in a public address instead of a private
address for analysis functions.
- Adding a `get_hyperdrive_pools_from_registry` that returns a list of
`Hyperdrive` objects given a registry address.
- `agent.get_positions` now allows a list of `pool_filters`.
  • Loading branch information
Sheng Lundquist authored May 31, 2024
1 parent 8f11205 commit a3cbc79
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 45 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ from agent0 import LocalHyperdrive, LocalChain
# Initialize
chain = LocalChain()
hyperdrive = LocalHyperdrive(chain)
hyperdrive_agent0 = hyperdrive.init_agent(base=FixedPoint(100_000), eth=FixedPoint(10))
hyperdrive_agent0 = chain.init_agent(base=FixedPoint(100_000), eth=FixedPoint(10), pool=hyperdrive)

# Run trades
chain.advance_time(datetime.timedelta(weeks=1))
Expand All @@ -48,11 +48,11 @@ close_event = hyperdrive_agent0.close_long(
)

# Analyze
pool_state = hyperdrive.get_pool_state(coerce_float=True)
pool_state.plot(x="block_number", y="longs_outstanding", kind="line")
pool_info = hyperdrive.get_pool_info(coerce_float=True)
pool_info.plot(x="block_number", y="longs_outstanding", kind="line")
```

See our [tutorial notebook](examples/tutorial.ipynb) for more information, including details on executing trades on remote chains.
See our [tutorial notebook](examples/tutorial.ipynb) and [examples notebook](examples/short_examples.ipynb) for more information, including details on executing trades on remote chains.

## Install

Expand Down
165 changes: 165 additions & 0 deletions examples/short_examples.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv(\".env\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query all your positions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from agent0 import Chain, Hyperdrive\n",
"import os\n",
"\n",
"# PUBLIC_ADDRESS = \"0xPUBLIC_ADDRESS\"\n",
"# RPC_URI = \"https://sepolia.rpc.url\"\n",
"PUBLIC_ADDRESS = os.getenv(\"PUBLIC_ADDRESS\")\n",
"RPC_URI = os.getenv(\"RPC_URI\")\n",
"\n",
"# Address of Hyperdrive Sepolia registry\n",
"REGISTRY_ADDRESS = \"0x4ba58147e50e57e71177cfedb1fac0303f216104\"\n",
"\n",
"## View open and closed positions in all pools\n",
"chain = Chain(RPC_URI)\n",
"agent = chain.init_agent(public_address=PUBLIC_ADDRESS)\n",
"registered_pools = Hyperdrive.get_hyperdrive_pools_from_registry(\n",
" chain,\n",
" registry_address=REGISTRY_ADDRESS,\n",
")\n",
"agent.get_positions(pool_filter=registered_pools, show_closed_positions=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Automate withdrawing funds from matured positions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PRIVATE_KEY = \"0xPRIVATE_KEY\"\n",
"PRIVATE_KEY = os.getenv(\"PRIVATE_KEY\")\n",
"\n",
"# Initialize agent with private key for transactions\n",
"agent = chain.init_agent(private_key=PRIVATE_KEY)\n",
"for pool in registered_pools:\n",
" # Close all mature longs\n",
" for long in agent.get_longs(pool=pool):\n",
" if long.maturity_time <= chain.block_time():\n",
" print(f\"Close long-{long.maturity_time} on {pool.name}\")\n",
" agent.close_long(maturity_time=long.maturity_time, bonds=long.balance, pool=pool)\n",
" # Close all mature shorts\n",
" for short in agent.get_shorts(pool=pool):\n",
" if short.maturity_time <= chain.block_time():\n",
" print(f\"Close short-{short.maturity_time} on {pool.name}\")\n",
" agent.close_short(maturity_time=short.maturity_time, bonds=short.balance, pool=pool)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write policies in Python\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from agent0 import HyperdriveBasePolicy, open_long_trade\n",
"from fixedpointmath import FixedPoint\n",
"from dataclasses import dataclass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class OpenLongPolicy(HyperdriveBasePolicy):\n",
" @dataclass(kw_only=True)\n",
" class Config(HyperdriveBasePolicy.Config):\n",
" fixed_rate_threshold: FixedPoint\n",
" open_long_amount: FixedPoint\n",
"\n",
" def action(self, interface, wallet):\n",
" \"\"\"Get agent actions for teh current block\n",
"\n",
" Action fn returns the trades to be executed\n",
" at a given moment in time.\n",
" \"\"\"\n",
" done_trading = False # Never done trading\n",
"\n",
" # If no longs in wallet, we check our fixed rate\n",
" # threshold and open the long if threshold reached.\n",
" if len(wallet.longs) == 0:\n",
" if interface.calc_spot_rate() > self.config.fixed_rate_threshold:\n",
" return [open_long_trade(self.config.open_long_amount)], done_trading\n",
"\n",
" # We don't do any trades otherwise\n",
" return [], done_trading"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"agent.set_active(\n",
" policy=OpenLongPolicy,\n",
" policy_config=OpenLongPolicy.Config(\n",
" fixed_rate_threshold=FixedPoint(0.06),\n",
" open_long_amount=FixedPoint(\"100_000\"),\n",
" ),\n",
")\n",
"agent.execute_policy_action(pool=registered_pools[0])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
10 changes: 6 additions & 4 deletions src/agent0/chainsync/db/hyperdrive/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def get_all_traders(session: Session, hyperdrive_address: str | None = None) ->
# pylint: disable=too-many-arguments
def get_position_snapshot(
session: Session,
hyperdrive_address: str | None = None,
hyperdrive_address: list[str] | str | None = None,
start_block: int | None = None,
end_block: int | None = None,
wallet_address: list[str] | str | None = None,
Expand All @@ -607,8 +607,8 @@ def get_position_snapshot(
---------
session: Session
The initialized session object.
hyperdrive_address: str | None, optional
The hyperdrive pool address to filter the query on. Defaults to returning all position snapshots.
hyperdrive_address: list[str] | str | None, optional
The hyperdrive pool address(es) to filter the query on. Defaults to returning all position snapshots.
start_block: int | None, optional
The starting block to filter the query on. start_block integers
matches python slicing notation, e.g., list[:3], list[:-3].
Expand All @@ -630,7 +630,9 @@ def get_position_snapshot(
"""
query = session.query(PositionSnapshot)

if hyperdrive_address is not None:
if isinstance(hyperdrive_address, list):
query = query.filter(PositionSnapshot.hyperdrive_address.in_(hyperdrive_address))
elif hyperdrive_address is not None:
query = query.filter(PositionSnapshot.hyperdrive_address == hyperdrive_address)

latest_block = get_latest_block_number_from_table(PositionSnapshot, session)
Expand Down
16 changes: 14 additions & 2 deletions src/agent0/core/hyperdrive/interactive/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ def block_time(self) -> Timestamp:
raise AssertionError("The provided block has no timestamp")
return block_timestamp

@property
def is_local_chain(self) -> bool:
"""Returns if this object is a local chain."""
return False

################
# Agent functions
################
Expand Down Expand Up @@ -366,7 +371,8 @@ def _handle_policy_config(

def init_agent(
self,
private_key: str,
private_key: str | None = None,
public_address: str | None = None,
pool: Hyperdrive | None = None,
policy: Type[HyperdriveBasePolicy] | None = None,
policy_config: HyperdriveBasePolicy.Config | None = None,
Expand All @@ -379,8 +385,13 @@ def init_agent(
Arguments
---------
private_key: str
private_key: str, optional
The private key of the associated account.
Must be supplied to allow this agent to do any transactions.
public_address: str | None, optional
The public address of the associated account. This allows this agent
to be used for analyzing data.
Can't be used in conjunction with private_key.
pool: LocalHyperdrive, optional
An optional pool to set as the active pool.
policy: HyperdrivePolicy, optional
Expand All @@ -406,6 +417,7 @@ def init_agent(
policy=policy,
policy_config=policy_config,
private_key=private_key,
public_address=public_address,
)
return out_agent

Expand Down
46 changes: 38 additions & 8 deletions src/agent0/core/hyperdrive/interactive/hyperdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import nest_asyncio
import pandas as pd
Expand All @@ -15,7 +16,8 @@
get_hyperdrive_addresses_from_registry,
)

from .chain import Chain
if TYPE_CHECKING:
from .chain import Chain

# In order to support both scripts and jupyter notebooks with underlying async functions,
# we use the nest_asyncio package so that we can execute asyncio.run within a running event loop.
Expand All @@ -38,15 +40,15 @@ class Config:
def get_hyperdrive_addresses_from_registry(
cls,
chain: Chain,
registry_contract_addr: str,
registry_address: str,
) -> dict[str, ChecksumAddress]:
"""Gather deployed Hyperdrive pool addresses.
Arguments
---------
chain: Chain
The Chain object connected to a chain.
registry_contract_addr: str
registry_address: str
The address of the Hyperdrive factory contract.
Returns
Expand All @@ -55,7 +57,37 @@ def get_hyperdrive_addresses_from_registry(
A dictionary keyed by the pool's name, valued by the pool's address
"""
# pylint: disable=protected-access
return get_hyperdrive_addresses_from_registry(registry_contract_addr, chain._web3)
return get_hyperdrive_addresses_from_registry(registry_address, chain._web3)

@classmethod
def get_hyperdrive_pools_from_registry(
cls,
chain: Chain,
registry_address: str,
) -> list[Hyperdrive]:
"""Gather deployed Hyperdrive pool addresses.
Arguments
---------
chain: Chain
The Chain object connected to a chain.
registry_address: str
The address of the Hyperdrive registry contract.
Returns
-------
list[Hyperdrive]
The hyperdrive objects for all registered pools
"""
hyperdrive_addresses = cls.get_hyperdrive_addresses_from_registry(chain, registry_address)
if len(hyperdrive_addresses) == 0:
raise ValueError("Registry does not have any hyperdrive pools registered.")
# Generate hyperdrive pool objects here
registered_pools = []
for hyperdrive_name, hyperdrive_address in hyperdrive_addresses.items():
registered_pools.append(Hyperdrive(chain, hyperdrive_address, name=hyperdrive_name))

return registered_pools

def _initialize(self, chain: Chain, hyperdrive_address: ChecksumAddress, name: str | None):
self.chain = chain
Expand All @@ -76,6 +108,7 @@ def _initialize(self, chain: Chain, hyperdrive_address: ChecksumAddress, name: s
)

add_hyperdrive_addr_to_name(name, self.hyperdrive_address, self.chain.db_session)
self.name = name

def __init__(
self,
Expand Down Expand Up @@ -106,10 +139,7 @@ def __init__(
# held by the chain object, we want to ensure that we dont mix and match
# local vs non-local hyperdrive objects. Hence, we ensure that any hyperdrive
# objects must come from a base Chain object and not a LocalChain.
# We use `type` instead of `isinstance` to explicitly check for
# the base Chain type instead of any subclass.
# pylint: disable=unidiomatic-typecheck
if type(chain) != Chain:
if chain.is_local_chain:
raise TypeError("The chain parameter must be a Chain object, not a LocalChain.")

self._initialize(chain, hyperdrive_address, name)
Expand Down
Loading

0 comments on commit a3cbc79

Please sign in to comment.