Skip to content
Snippets Groups Projects
ball_analysis.ipynb 80.3 KiB
Newer Older
gautham's avatar
gautham committed
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached ball detections from the tracker stub file.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only open trusted stub files.\n",
    "ball_stub_path = '../tracker_stubs/ball_detections.pkl'\n",
    "with open(ball_stub_path, 'rb') as stub_file:\n",
    "    ball_positions = pickle.load(stub_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each loaded entry is a mapping (it supports .get); key 1 is presumed to be the\n",
    "# ball's track id -- TODO confirm against the tracker. Frames without that key\n",
    "# yield [] and end up as a row of missing values in the DataFrame.\n",
    "# A new name is used instead of overwriting `ball_positions`, so re-running this\n",
    "# cell does not crash on already-converted lists (a list has no .get).\n",
    "ball_bboxes = [frame_detections.get(1, []) for frame_detections in ball_positions]\n",
    "df_ball_positions = pd.DataFrame(ball_bboxes, columns=['x1', 'y1', 'x2', 'y2'])\n",
    "\n",
    "# Fill gaps between detections by interpolation (pandas default: linear), then\n",
    "# back-fill any leading frames that precede the first detection.\n",
    "df_ball_positions = df_ball_positions.interpolate()\n",
    "df_ball_positions = df_ball_positions.bfill()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vertical centre of the ball bounding box in every frame\n",
    "box_y1 = df_ball_positions['y1']\n",
    "box_y2 = df_ball_positions['y2']\n",
    "df_ball_positions['mid_y'] = (box_y1 + box_y2) / 2\n",
    "\n",
    "# Trailing 5-frame mean to smooth the signal;\n",
    "# min_periods=1 keeps the first frames defined instead of NaN\n",
    "smoothed = df_ball_positions['mid_y'].rolling(window=5, min_periods=1, center=False)\n",
    "df_ball_positions['mid_y_rolling_mean'] = smoothed.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x20f1bba1d10>]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the smoothed vertical ball position across the clip\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(df_ball_positions['mid_y_rolling_mean'])\n",
    "ax.set(title='Ball mid_y (5-frame rolling mean)', xlabel='Frame', ylabel='mid_y_rolling_mean');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frame-to-frame change in the smoothed vertical position;\n",
    "# the first row has no predecessor and is NaN\n",
    "smoothed_mid_y = df_ball_positions['mid_y_rolling_mean']\n",
    "df_ball_positions['delta_y'] = smoothed_mid_y.diff(periods=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x20f444c1390>]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot delta_y (frame-to-frame change of the smoothed mid_y);\n",
    "# hit candidates are where the curve crosses zero\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(df_ball_positions['delta_y'])\n",
    "ax.axhline(0, color='grey', linewidth=0.8)\n",
    "ax.set(title='Ball delta_y (change in rolling-mean mid_y)', xlabel='Frame', ylabel='delta_y');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start every frame as 'no hit' (0); hit frames are later flagged with 1\n",
    "no_hit = 0\n",
    "df_ball_positions['ball_hit'] = no_hit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[11, 58, 95, 131, 182]\n"
     ]
    }
   ],
   "source": [
    "# Detect ball hits as sustained reversals of the ball's vertical direction.\n",
    "# Frame i is a candidate when delta_y changes sign strictly between i and i+1;\n",
    "# it is flagged as a hit only if at least `minimum_change_frames_for_hit` of the\n",
    "# next `lookahead_frames` frames (i+1 .. i+lookahead_frames) carry the new sign.\n",
    "# Frames whose lookahead window would run past the end are never candidates.\n",
    "\n",
    "# Reset the hit flag so this cell can be re-run safely\n",
    "df_ball_positions['ball_hit'] = 0\n",
    "\n",
    "minimum_change_frames_for_hit = 25\n",
    "# Frames inspected after a candidate reversal: 20% slack over the minimum\n",
    "lookahead_frames = int(minimum_change_frames_for_hit * 1.2)\n",
    "\n",
    "# Plain ndarray instead of a scalar .iloc lookup per comparison in the nested scan.\n",
    "# NaN (e.g. the first row, produced by .diff()) compares False and never counts.\n",
    "delta_y = df_ball_positions['delta_y'].to_numpy()\n",
    "\n",
    "hit_frames = []\n",
    "for i in range(1, len(delta_y) - lookahead_frames):\n",
    "    current, following = delta_y[i], delta_y[i + 1]\n",
    "    window = delta_y[i + 1 : i + lookahead_frames + 1]\n",
    "\n",
    "    # Count lookahead frames whose sign is opposite to the sign at frame i\n",
    "    if current > 0 and following < 0:\n",
    "        change_count = int((window < 0).sum())\n",
    "    elif current < 0 and following > 0:\n",
    "        change_count = int((window > 0).sum())\n",
    "    else:\n",
    "        continue  # no sign reversal at this frame\n",
    "\n",
    "    if change_count >= minimum_change_frames_for_hit:\n",
    "        hit_frames.append(i)\n",
    "\n",
    "# Assumes the default RangeIndex (position i == label i), as the original .at[i, ...] did\n",
    "for frame in hit_frames:\n",
    "    df_ball_positions.at[frame, 'ball_hit'] = 1\n",
    "\n",
    "# Frame numbers where ball hits were detected\n",
    "frame_nums_with_ball_hits = df_ball_positions[df_ball_positions['ball_hit'] == 1].index.tolist()\n",
    "print(frame_nums_with_ball_hits)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>y1</th>\n",
       "      <th>x2</th>\n",
       "      <th>y2</th>\n",
       "      <th>mid_y</th>\n",
       "      <th>mid_y_rolling_mean</th>\n",
       "      <th>delta_y</th>\n",
       "      <th>ball_hit</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>776.865967</td>\n",
       "      <td>717.330017</td>\n",
       "      <td>796.806519</td>\n",
       "      <td>738.393188</td>\n",
       "      <td>727.861603</td>\n",
       "      <td>735.918115</td>\n",
       "      <td>6.523407</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>58</th>\n",
       "      <td>925.881409</td>\n",
       "      <td>240.971042</td>\n",
       "      <td>939.039478</td>\n",
       "      <td>253.989072</td>\n",
       "      <td>247.480057</td>\n",
       "      <td>243.406097</td>\n",
       "      <td>-1.957851</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>624.777161</td>\n",
       "      <td>748.891968</td>\n",
       "      <td>642.157257</td>\n",
       "      <td>766.698242</td>\n",
       "      <td>757.795105</td>\n",
       "      <td>775.403400</td>\n",
       "      <td>0.871759</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>131</th>\n",
       "      <td>716.963562</td>\n",
       "      <td>229.095024</td>\n",
       "      <td>729.239868</td>\n",
       "      <td>242.786232</td>\n",
       "      <td>235.940628</td>\n",
       "      <td>235.241684</td>\n",
       "      <td>-0.557164</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>182</th>\n",
       "      <td>1294.891235</td>\n",
       "      <td>739.127197</td>\n",
       "      <td>1314.160156</td>\n",
       "      <td>760.564819</td>\n",
       "      <td>749.846008</td>\n",
       "      <td>738.733578</td>\n",
       "      <td>5.602832</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              x1          y1           x2          y2       mid_y  \\\n",
       "11    776.865967  717.330017   796.806519  738.393188  727.861603   \n",
       "58    925.881409  240.971042   939.039478  253.989072  247.480057   \n",
       "95    624.777161  748.891968   642.157257  766.698242  757.795105   \n",
       "131   716.963562  229.095024   729.239868  242.786232  235.940628   \n",
       "182  1294.891235  739.127197  1314.160156  760.564819  749.846008   \n",
       "\n",
       "     mid_y_rolling_mean   delta_y  ball_hit  \n",
       "11           735.918115  6.523407         1  \n",
       "58           243.406097 -1.957851         1  \n",
       "95           775.403400  0.871759         1  \n",
       "131          235.241684 -0.557164         1  \n",
       "182          738.733578  5.602832         1  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show only the frames flagged as hits\n",
    "hit_mask = df_ball_positions['ball_hit'].eq(1)\n",
    "df_ball_positions.loc[hit_mask]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "myenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}