1- # noqa: D100 
21import  warnings 
32from  collections .abc  import  Callable 
43from  dataclasses  import  fields 
@@ -75,6 +74,7 @@ def collect_agent_data(
7574            "stroke" : [],  # Stroke color 
7675            "strokeWidth" : [],
7776            "filled" : [],
77+             "tooltip" : [],
7878        }
7979
8080        # Import here to avoid circular import issues 
@@ -129,6 +129,7 @@ def collect_agent_data(
129129                    linewidths = dict_data .pop (
130130                        "linewidths" , style_fields .get ("linewidths" )
131131                    ),
132+                     tooltip = dict_data .pop ("tooltip" , None ),
132133                )
133134                if  dict_data :
134135                    ignored_keys  =  list (dict_data .keys ())
@@ -184,6 +185,7 @@ def collect_agent_data(
184185            # FIXME: Make filled user-controllable 
185186            filled_value  =  True 
186187            arguments ["filled" ].append (filled_value )
188+             arguments ["tooltip" ].append (aps .tooltip )
187189
188190        final_data  =  {}
189191        for  k , v  in  arguments .items ():
@@ -199,87 +201,84 @@ def collect_agent_data(
199201
200202        return  final_data 
201203
204+ 
205+ 
202206    def  draw_agents (
203207        self , arguments , chart_width : int  =  450 , chart_height : int  =  350 , ** kwargs 
204208    ):
205-         """Draw agents using Altair backend. 
206- 
207-         Args: 
208-             arguments: Dictionary containing agent data arrays. 
209-             chart_width: Width of the chart. 
210-             chart_height: Height of the chart. 
211-             **kwargs: Additional keyword arguments for customization. 
212-             Checkout respective `SpaceDrawer` class on details how to pass **kwargs. 
213- 
214-         Returns: 
215-             alt.Chart: The Altair chart representing the agents, or None if no agents. 
216-         """ 
209+         """Draw agents using Altair backend.""" 
217210        if  arguments ["loc" ].size  ==  0 :
218211            return  None 
219212
220-         # To get a continuous scale for color the domain should be between [0, 1] 
221-         # that's why changing the the domain of strokeWidth beforehand. 
222-         stroke_width  =  [data  /  10  for  data  in  arguments ["strokeWidth" ]]
223- 
224-         # Agent data preparation 
225-         df_data  =  {
226-             "x" : arguments ["loc" ][:, 0 ],
227-             "y" : arguments ["loc" ][:, 1 ],
228-             "size" : arguments ["size" ],
229-             "shape" : arguments ["shape" ],
230-             "opacity" : arguments ["opacity" ],
231-             "strokeWidth" : stroke_width ,
232-             "original_color" : arguments ["color" ],
233-             "is_filled" : arguments ["filled" ],
234-             "original_stroke" : arguments ["stroke" ],
235-         }
236-         df  =  pd .DataFrame (df_data )
237- 
238-         # To ensure distinct shapes according to agent portrayal 
239-         unique_shape_names_in_data  =  df ["shape" ].unique ().tolist ()
240- 
241-         fill_colors  =  []
242-         stroke_colors  =  []
243-         for  i  in  range (len (df )):
244-             filled  =  df ["is_filled" ][i ]
245-             main_color  =  df ["original_color" ][i ]
246-             stroke_spec  =  (
247-                 df ["original_stroke" ][i ]
248-                 if  isinstance (df ["original_stroke" ][i ], str )
249-                 else  None 
250-             )
251-             if  filled :
252-                 fill_colors .append (main_color )
253-                 stroke_colors .append (stroke_spec )
213+         # Prepare a list of dictionaries, which is a robust way to create a DataFrame 
214+         records  =  []
215+         for  i  in  range (len (arguments ["loc" ])):
216+             record  =  {
217+                 "x" : arguments ["loc" ][i ][0 ],
218+                 "y" : arguments ["loc" ][i ][1 ],
219+                 "size" : arguments ["size" ][i ],
220+                 "shape" : arguments ["shape" ][i ],
221+                 "opacity" : arguments ["opacity" ][i ],
222+                 "strokeWidth" : arguments ["strokeWidth" ][i ] /  10 , # Scale for continuous domain 
223+                 "original_color" : arguments ["color" ][i ],
224+             }
225+             # Add tooltip data if available 
226+             tooltip  =  arguments ["tooltip" ][i ]
227+             if  tooltip :
228+                 record .update (tooltip )
229+ 
230+             # Determine fill and stroke colors 
231+             if  arguments ["filled" ][i ]:
232+                 record ["viz_fill_color" ] =  arguments ["color" ][i ]
233+                 record ["viz_stroke_color" ] =  arguments ["stroke" ][i ] if  isinstance (arguments ["stroke" ][i ], str ) else  None 
254234            else :
255-                 fill_colors .append (None )
256-                 stroke_colors .append (main_color )
257-         df ["viz_fill_color" ] =  fill_colors 
258-         df ["viz_stroke_color" ] =  stroke_colors 
235+                 record ["viz_fill_color" ] =  None 
236+                 record ["viz_stroke_color" ] =  arguments ["color" ][i ]
237+ 
238+             records .append (record )
239+ 
240+         df  =  pd .DataFrame (records )
241+ 
242+         # Ensure all columns that should be numeric are, handling potential Nones 
243+         numeric_cols  =  ['x' , 'y' , 'size' , 'opacity' , 'strokeWidth' , 'original_color' ]
244+         for  col  in  numeric_cols :
245+             if  col  in  df .columns :
246+                 df [col ] =  pd .to_numeric (df [col ], errors = 'coerce' )
247+ 
248+ 
249+         # Get tooltip keys from the first valid record 
250+         tooltip_list  =  ["x" , "y" ]
251+         # This is the corrected line: 
252+         if  any (t  is  not None  for  t  in  arguments ["tooltip" ]):
253+              first_valid_tooltip  =  next ((t  for  t  in  arguments ["tooltip" ] if  t ), None )
254+              if  first_valid_tooltip :
255+                  tooltip_list .extend (first_valid_tooltip .keys ())
259256
260257        # Extract additional parameters from kwargs 
261-         # FIXME: Add more parameters to kwargs 
262258        title  =  kwargs .pop ("title" , "" )
263259        xlabel  =  kwargs .pop ("xlabel" , "" )
264260        ylabel  =  kwargs .pop ("ylabel" , "" )
265- 
266-         # Tooltip list for interactivity 
267-         # FIXME: Add more fields to tooltip (preferably from agent_portrayal) 
268-         tooltip_list  =  ["x" , "y" ]
261+         legend_title  =  kwargs .pop ("legend_title" , "Color" )
269262
270263        # Handle custom colormapping 
271264        cmap  =  kwargs .pop ("cmap" , "viridis" )
272265        vmin  =  kwargs .pop ("vmin" , None )
273266        vmax  =  kwargs .pop ("vmax" , None )
274267
275-         color_is_numeric  =  np . issubdtype (df ["original_color" ]. dtype ,  np . number )
268+         color_is_numeric  =  pd . api . types . is_numeric_dtype (df ["original_color" ])
276269        if  color_is_numeric :
277270            color_min  =  vmin  if  vmin  is  not None  else  df ["original_color" ].min ()
278271            color_max  =  vmax  if  vmax  is  not None  else  df ["original_color" ].max ()
279272
280273            fill_encoding  =  alt .Fill (
281274                "original_color:Q" ,
282275                scale = alt .Scale (scheme = cmap , domain = [color_min , color_max ]),
276+                 legend = alt .Legend (
277+                     title = legend_title ,
278+                     orient = "right" ,
279+                     type = "gradient" ,
280+                     gradientLength = 200 ,
281+                 ),
283282            )
284283        else :
285284            fill_encoding  =  alt .Fill (
@@ -290,6 +289,7 @@ def draw_agents(
290289
291290        # Determine space dimensions 
292291        xmin , xmax , ymin , ymax  =  self .space_drawer .get_viz_limits ()
292+         unique_shape_names_in_data  =  df ["shape" ].dropna ().unique ().tolist ()
293293
294294        chart  =  (
295295            alt .Chart (df )
@@ -316,16 +316,10 @@ def draw_agents(
316316                    ),
317317                    title = "Shape" ,
318318                ),
319-                 opacity = alt .Opacity (
320-                     "opacity:Q" ,
321-                     title = "Opacity" ,
322-                     scale = alt .Scale (domain = [0 , 1 ], range = [0 , 1 ]),
323-                 ),
319+                 opacity = alt .Opacity ("opacity:Q" , title = "Opacity" , scale = alt .Scale (domain = [0 , 1 ], range = [0 , 1 ])),
324320                fill = fill_encoding ,
325321                stroke = alt .Stroke ("viz_stroke_color:N" , scale = None ),
326-                 strokeWidth = alt .StrokeWidth (
327-                     "strokeWidth:Q" , scale = alt .Scale (domain = [0 , 1 ])
328-                 ),
322+                 strokeWidth = alt .StrokeWidth ("strokeWidth:Q" , scale = alt .Scale (domain = [0 , 1 ])),
329323                tooltip = tooltip_list ,
330324            )
331325            .properties (title = title , width = chart_width , height = chart_height )
@@ -437,4 +431,4 @@ def draw_propertylayer(
437431                main_charts .append (current_chart )
438432
439433        base  =  alt .layer (* main_charts ).resolve_scale (color = "independent" )
440-         return  base 
434+         return  base 
0 commit comments