Skip to content

feat: handle serialization for objects with __dict__ magic method#619

Closed
AgustinRamiroDiaz wants to merge 1 commit intomainfrom
agustin/improve-serialization
Closed

feat: handle serialization for objects with __dict__ magic method#619
AgustinRamiroDiaz wants to merge 1 commit intomainfrom
agustin/improve-serialization

Conversation

@AgustinRamiroDiaz
Copy link
Copy Markdown
Contributor

@AgustinRamiroDiaz AgustinRamiroDiaz commented Nov 19, 2024

What

handle serialization for objects with __dict__ magic method, by simply calling the __dict__ method on the object

Why

To allow for more flexible Intelligent Contracts, like using @DataClass es

Testing done

Tested using this IC

"""
TODO:
- should we allow to submit old tweets?
"""

from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import json
import os
import uuid
import re
from itertools import groupby
import urllib.request
import urllib.parse

from backend.node.genvm.icontract import IContract
from backend.node.genvm.equivalence_principle import (
    EquivalencePrinciple,
)


@dataclass
class Submission:
    global datetime
    mission_uuid: str
    user: str
    tweet_url: str
    checks_accomplished: list[bool]
    timestamp: datetime


@dataclass
class MissionStaticCheck:
    description: str
    points: int


@dataclass
class Mission:
    global MissionStaticCheck
    title: str
    description: str
    static_checks: list[MissionStaticCheck]
    day_index: int
    # TODO: add dynamic points


def dict_to_mission(data: dict) -> Mission:
    # Use list comprehension to transform nested dictionaries to MissionStaticCheck
    data["static_checks"] = [
        MissionStaticCheck(**check) for check in data["static_checks"]
    ]
    # Directly unpack into the Mission dataclass
    return Mission(**data)


class Campaign(IContract):
    # needed due to limitation in the simulator imports
    global Mission
    global Submission
    global datetime
    global MissionStaticCheck

    campaign_id: str
    title: str
    creator_address: str
    pot_size: int
    pot_coin: (
        str  # TODO: we should store the original ERC20 contract to know the real coin
    )
    # TODO: we should probably also store the token decimals
    duration_days: int
    missions: dict[str, Mission]
    faq: str
    start_date: datetime
    end_date: datetime
    x_id_contract: str
    submissions: list[Submission]
    daily_scores: dict[str, dict[str, int]]
    total_scores: dict[str, int]

    def __init__(
        self,
        campaign_id: str,
        title: str,
        start_date: str,
        pot_coin: str,
        pot_size: int,
        duration_days: int,
        missions: dict[str, dict],
        faq: str,
        x_id_contract: str,
    ):
        self.campaign_id = campaign_id
        self.title = title
        self.creator_address = contract_runner.from_address
        self.pot_size = pot_size
        self.pot_coin = pot_coin
        self.duration_days = duration_days
        self.start_date = datetime.fromisoformat(start_date)
        self.end_date = self.start_date + timedelta(days=duration_days)
        self.missions = {
            id: dict_to_mission(mission) for id, mission in missions.items()
        }
        self.faq = faq
        self.submissions: list[Submission] = []
        self.daily_scores = {}  # date: {twitter_username: points}
        self.total_scores = {}  # twitter_username: total_points
        self.x_id_contract = x_id_contract

    def _get_day_index(self, date: datetime) -> int:
        return (date - self.start_date).days

    def _require_now_in_valid_range(self):
        now = datetime.now(timezone.utc)
        now_day_index = self._get_day_index(now)
        if now_day_index >= self.duration_days:
            raise ValueError("Campaign has already ended.")

    async def submit_mission(
        self,
        mission_day_index: int,  # 0 to duration_days - 1
        mission_uuid: str,
        mission_title: str,
        mission_description: str,
        mission_static_checks: list[str],
        mission_static_checks_points: list[int],
    ):
        if contract_runner.from_address != self.creator_address:
            raise ValueError("Only the creator can submit missions.")

        now_day_index = self._get_day_index(datetime.now(timezone.utc))

        if mission_day_index <= now_day_index:
            raise ValueError("Mission can only be submitted from the next day onward.")

        self._require_now_in_valid_range()

        if not is_valid_uuid(mission_uuid):
            raise ValueError("Invalid mission UUID.")

        if mission_uuid in self.missions:
            raise ValueError("Mission id already exists.")

        if len(mission_static_checks) != len(mission_static_checks_points):
            raise ValueError(
                "mission_static_checks and mission_static_checks_points must have the same length."
            )

        # Create mission static checks
        static_checks = [
            MissionStaticCheck(description=check, points=points, id=id)
            for id, (check, points) in enumerate(
                zip(mission_static_checks, mission_static_checks_points)
            )
        ]

        # Create and store the mission
        self.missions[mission_uuid] = Mission(
            title=mission_title,
            description=mission_description,
            static_checks=static_checks,
            day_index=mission_day_index,
        )

    async def submit_tweet(self, mission_uuid: str, tweet_url: str):
        now = datetime.now(timezone.utc)
        now_day_index = self._get_day_index(now)

        self._require_now_in_valid_range()

        # Validate mission exists
        if mission_uuid not in self.missions:
            raise ValueError("Mission not found")

        mission = self.missions[mission_uuid]
        if mission.day_index != now_day_index:
            raise ValueError("Mission can only be submitted for the current day.")

        # Get checks accomplished from tweet
        user, checks_accomplished = await self._get_mission_checks_for_tweet(
            mission, tweet_url
        )

        # TODO: evaluate using dicts
        if any(
            submission.mission_uuid == mission_uuid and submission.user == user
            for submission in self.submissions
        ):
            raise ValueError(
                "User already submitted for this mission."
            )  # TODO: should we handle retries?

        # Store submission data
        self.submissions.append(
            Submission(
                mission_uuid=mission_uuid,
                user=user,
                tweet_url=tweet_url,
                checks_accomplished=checks_accomplished,
                timestamp=now,
            )
        )

    def get_scoreboard(self) -> dict[str, dict[str, int]]:
        """
        Returns a dictionary with the day index as the key and a dictionary with the user and their total points for that day as the value.
        """
        return {
            day_index: {
                user: sum(
                    self._get_points_for_accomplished_checks(
                        self.missions[submission.mission_uuid],
                        submission.checks_accomplished,
                    )
                    for submission in user_submissions
                )
                for user, user_submissions in groupby(day_submissions, lambda s: s.user)
            }
            for day_index, day_submissions in groupby(
                self.submissions, key=lambda s: self._get_day_index(s.timestamp)
            )
        }

    # TODO: review if rounding causes trouble. Maybe we should do the rounding in Base?
    def get_daily_distribution(self, day_index: int) -> dict[str, int]:
        """
        Distribution for a given day.
        The sum of all values is 1.
        The distribution for each user is the sum of the points they have for that day divided by the total points for that day.
        """
        if day_index < 0 or day_index >= self.duration_days:
            raise ValueError("Invalid day index.")
        total_points_for_day = sum(
            self._get_points_for_accomplished_checks(mission, checks)
            for mission, checks in self.missions.values()
        )
        return {
            user: int(points / total_points_for_day)
            for user, points in self.get_scoreboard()[
                day_index
            ].items()  # Proably can be optimized if we only get the scoreboard for this date
        }

    def _get_points_for_accomplished_checks(
        self, mission: Mission, checks_accomplished: list[bool]
    ) -> int:
        assert len(checks_accomplished) == len(mission.static_checks)
        return sum(
            check.points
            for check, accomplished in zip(mission.static_checks, checks_accomplished)
            if accomplished
        )

    async def _get_mission_checks_for_tweet(
        self, mission: Mission, tweet_url: str
    ) -> tuple[str, list[bool]]:
        print(f"Getting checks for tweet {tweet_url} and mission {mission}")

        tweet_text, tweet_created_at = get_tweet_content(tweet_url)
        user = get_tweet_user_handle(tweet_url)

        if self._get_day_index(tweet_created_at) != mission.day_index:
            raise ValueError("Tweet was not tweeted on the mission's day.")

        final_result = {}
        async with EquivalencePrinciple(
            result=final_result,
            principle="The JSONs have to be exactly the same",
            comparative=True,
        ) as eq:
            task = f"""
            Given the following mission

            - Title: {mission.title}
            - Description: {mission.description}
            - Checks: {[check.description for check in mission.static_checks]}

            Analyze the following tweet and for each check, evaluate if the tweet satisfies the check.
            If it does, assign a `true` value. If it does not, assign a `false` value.
            
            <<<Tweet>>>
            {tweet_text}
            <<<End of Tweet>>>

            Your response should be only a JSON array of boolean values, one for each check in the mission, respecting the order of the checks. Here's an example:
            [
                true,
                false,
            ]
            """
            checks = await eq.call_llm(task)
            print(f"{checks=}")
            eq.set(checks)

        checks_accomplished = json.loads(
            final_result["output"].strip("```json").strip("```")
        )
        print(f"{checks_accomplished=}")

        assert len(checks_accomplished) == len(mission.static_checks)
        assert isinstance(checks_accomplished, list)
        assert all(isinstance(check, bool) for check in checks_accomplished)

        return user, checks_accomplished

    def get_campaign_info(self):
        return {
            "campaign_id": self.campaign_id,
            "pot_coin": self.pot_coin,
            "title": self.title,
            "creator_address": self.creator_address,
            "pot_size": self.pot_size,
            "duration_days": self.duration_days,
            "start_date": self.start_date.isoformat(),
            "end_date": self.end_date.isoformat(),
            "missions": self.missions,
            "submissions": self.submissions,
            "daily_scores": self.daily_scores,
            "total_scores": self.total_scores,
            "faq": self.faq,
        }

    async def link_wallet(self, twitter_username: str, wallet_address: str):
        Contract(self.x_id_contract).link_wallet(twitter_username, wallet_address)


def is_valid_uuid(uuid_str: str) -> bool:
    try:
        uuid_obj = uuid.UUID(uuid_str)
        return str(uuid_obj) == uuid_str
    except ValueError:
        return False


def create_headers(bearer_token: str) -> dict[str, str]:
    headers = {
        "Authorization": f"Bearer {bearer_token}",
        "User-Agent": "v2TweetLookupPython",
    }
    return headers


# TODO: improve these methods for more robustness
def get_tweet_id(tweet_url: str) -> str:
    # Extract tweet ID using regex
    match = re.search(r"status/(\d+)", tweet_url)
    if match:
        return match.group(1)
    else:
        raise ValueError("Invalid Tweet URL")


def get_tweet_user_handle(tweet_url: str) -> str:
    # Extract tweet user handle using regex
    match = re.search(r"x\.com/([^/]+)/status/(\d+)", tweet_url)
    if match:
        return match.group(1)
    else:
        raise ValueError("Invalid Tweet URL")


# TODO: should we use equivalence principle for this?
def get_tweet_content(tweet_url: str) -> tuple[str, datetime]:
    tweet_id = get_tweet_id(tweet_url)
    headers = create_headers(os.getenv("X_API_BEARER_TOKEN"))
    url = f"https://api.twitter.com/2/tweets/{tweet_id}"
    params = {"tweet.fields": "created_at,text"}
    url_with_params = f"{url}?{urllib.parse.urlencode(params)}"
    req = urllib.request.Request(url_with_params, headers=headers)

    try:
        with urllib.request.urlopen(req) as response:
            if response.status == 200:
                tweet_data = json.loads(response.read().decode())
                return tweet_data["data"]["text"], datetime.fromisoformat(
                    tweet_data["data"]["created_at"]
                )
            else:
                raise Exception(
                    f"Request returned an error: {response.status} {response.reason}"
                )
    except urllib.error.HTTPError as e:
        raise Exception(f"Request returned an error: {e.code} {e.reason}")

Checks

  • I have tested this code
  • I have reviewed my own PR
  • I have created an issue for this PR
  • I have set a descriptive PR title compliant with conventional commits

User facing release notes

Now your Intelligent Contracts can return Python objects that implement __dict__ magic methods

Signed-off-by: Agustín Ramiro Díaz <agustin.ramiro.diaz@gmail.com>
@codecov
Copy link
Copy Markdown

codecov Bot commented Nov 19, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 18.61%. Comparing base (af6baf7) to head (603a176).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #619   +/-   ##
=======================================
  Coverage   18.61%   18.61%           
=======================================
  Files         123      123           
  Lines        9618     9618           
  Branches      299      299           
=======================================
  Hits         1790     1790           
  Misses       7744     7744           
  Partials       84       84           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


🚨 Try these New Features:

@sonarqubecloud
Copy link
Copy Markdown

@kp2pml30
Copy link
Copy Markdown
Member

I don't know global plans regarding merging new genvm (#599), but as soon as it is done calldata encoding (from python code) will happen there and this change will do nothing

@kp2pml30
Copy link
Copy Markdown
Member

kp2pml30 commented Nov 20, 2024

I understand the reasoning, but I think more research/design is needed, and I am also not sure about the best approach now...

  1. It doesn't support __slots__
  2. I am more of a fan of json.dumps approach (which receives a custom function), and is more "pythonic" in that sense
  3. maybe new decorator is needed?
  4. maybe we should have special __to_calldata__ method? (ruby approach)

second would be the least evil right now

@AgustinRamiroDiaz
Copy link
Copy Markdown
Contributor Author

Closing until further definition is made

@kp2pml30
Copy link
Copy Markdown
Member

kp2pml30 commented Nov 21, 2024

FYI: for now I am replacing checks for exact types with checks for Sequence and Mapping + I am adding a check for dataclasses.is_dataclass and then dataclasses.asdict

related genvm issue: genlayerlabs/genvm#103

@cristiam86 cristiam86 deleted the agustin/improve-serialization branch December 9, 2024 13:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants