MaxQuadPool2d

class torch_geopooling.nn.MaxQuadPool2d(feature_dim: int, polygon: Polygon, exterior: Exterior | Tuple[float, float, float, float], max_terminal_nodes: int | None = None, max_depth: int = 17, precision: int | None = 7)

Maximum pooling over a quadtree decomposition of input 2D coordinates.

This module constructs an internal lookup tree that organizes closely situated 2D points using the specified polygon and exterior, where the polygon is treated as the boundary for the terminal nodes of the quadtree.
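The construction can be sketched as a breadth-first subdivision of the exterior. This is a hypothetical reading of the description, not the library's implementation. It assumes three things: a tile is split into four children while it intersects the polygon; splitting stops at `max_depth` or when another split would push the number of terminal nodes past `max_terminal_nodes`; and tiles are addressed as `(z, x, y)` (depth, column, row) in a `2**z` grid laid over the exterior. `tile_rect` and `build_quadtree` are illustrative names. The polygon test is passed in as a callable so the sketch needs no geometry library.

```python
from collections import deque
from typing import Callable, Deque, List, Optional, Tuple

# Hypothetical tile address: (depth z, column x, row y) in a 2**z by 2**z grid
# laid over the exterior. This scheme is an assumption of the sketch.
Tile = Tuple[int, int, int]
Rect = Tuple[float, float, float, float]  # (X, Y, W, H), as for `exterior`


def tile_rect(exterior: Rect, tile: Tile) -> Rect:
    """Return the (X, Y, W, H) rectangle that `tile` covers inside `exterior`."""
    x0, y0, w, h = exterior
    z, x, y = tile
    tw, th = w / 2**z, h / 2**z
    return (x0 + x * tw, y0 + y * th, tw, th)


def build_quadtree(
    exterior: Rect,
    intersects: Callable[[Rect], bool],
    max_depth: int,
    max_terminal_nodes: Optional[int] = None,
) -> List[Tile]:
    """Subdivide breadth-first and return the terminal (leaf) tiles.

    A terminal tile is split into four children only if it intersects the
    polygon, it is shallower than `max_depth`, and the split keeps the number
    of terminal tiles within `max_terminal_nodes` (when that is set).
    """
    terminals: List[Tile] = []
    queue: Deque[Tile] = deque([(0, 0, 0)])  # start from the root: the whole exterior
    count = 1  # number of terminal tiles in the current tree
    while queue:
        tile = queue.popleft()
        z, x, y = tile
        within_budget = max_terminal_nodes is None or count + 3 <= max_terminal_nodes
        if z < max_depth and within_budget and intersects(tile_rect(exterior, tile)):
            count += 3  # one terminal tile is replaced by its four children
            for dx in (0, 1):
                for dy in (0, 1):
                    queue.append((z + 1, 2 * x + dx, 2 * y + dy))
        else:
            terminals.append(tile)
    return terminals
```

With shapely installed, `intersects` could be `lambda r: box(r[0], r[1], r[0] + r[2], r[1] + r[3]).intersects(polygon)`, using `shapely.geometry.box`.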

Each terminal node in the resulting quadtree is paired with a weight vector of size feature_dim. Given an input coordinate, the module retrieves the terminal group of nodes that the coordinate falls into and computes the maximum over the group's weights independently for each of the feature_dim features.
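The lookup and pooling step can be sketched in plain Python. The sketch is hypothetical, not the module's internals. It assumes terminal tiles are addressed as `(z, x, y)` (depth, column, row) in a `2**z` grid laid over the exterior. It also assumes that the terminal group of a coordinate is the set of terminal children of its located tile's parent. `locate` and `max_quad_pool` are illustrative helper names.

```python
from typing import Dict, List, Sequence, Set, Tuple

Tile = Tuple[int, int, int]  # hypothetical (depth z, column x, row y) address
Rect = Tuple[float, float, float, float]  # (X, Y, W, H)


def locate(exterior: Rect, terminals: Set[Tile], max_depth: int, px: float, py: float) -> Tile:
    """Descend from the root until reaching the terminal tile containing (px, py)."""
    x0, y0, w, h = exterior
    for z in range(max_depth + 1):
        n = 2**z
        # Clamp so points on the right/top edge of the exterior fall in the last tile.
        x = min(int((px - x0) / w * n), n - 1)
        y = min(int((py - y0) / h * n), n - 1)
        if (z, x, y) in terminals:
            return (z, x, y)
    raise KeyError(f"no terminal tile contains ({px}, {py})")


def max_quad_pool(
    exterior: Rect,
    terminals: Set[Tile],
    weights: Dict[Tile, Sequence[float]],
    max_depth: int,
    px: float,
    py: float,
) -> List[float]:
    """Element-wise maximum of the weights in the terminal group of (px, py)."""
    z, x, y = locate(exterior, terminals, max_depth, px, py)
    if z == 0:
        group = [(0, 0, 0)]  # the root has no parent, so it is its own group
    else:
        parent_x, parent_y = x // 2, y // 2  # parent tile indices at depth z - 1
        siblings = [(z, 2 * parent_x + dx, 2 * parent_y + dy) for dx in (0, 1) for dy in (0, 1)]
        # Keep only terminal siblings; the others were subdivided further.
        group = [t for t in siblings if t in terminals]
    # One maximum per feature, taken across the group's weight vectors.
    return [max(column) for column in zip(*(weights[t] for t in group))]
```

Taking the maximum column by column is what makes the output length equal to feature_dim, regardless of how many nodes the group contains.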

Parameters:
  • feature_dim – Size of each feature vector.

  • polygon – Polygon that defines the boundary for the terminal nodes of the quadtree.

  • exterior – Geometrical boundary of the learning space in (X, Y, W, H) format.

  • max_terminal_nodes – Optional maximum number of terminal nodes in the quadtree. Once the maximum is reached, internal nodes are no longer subdivided and the tree stops growing.

  • max_depth – Maximum depth of the quadtree. Default: 17.

  • precision – Optional rounding of the input coordinates. Default: 7.

Shape:
  • Input: \((*, 2)\), where the last dimension holds the longitude and latitude coordinates.

  • Output: \((*, H)\), where \(*\) matches the leading dimensions of the input and \(H = \text{feature\_dim}\).
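The shape contract keeps the leading dimensions unchanged and replaces only the trailing coordinate axis of size 2 with H features. In PyTorch this is commonly done by flattening the input with `reshape(-1, 2)`, pooling each row, and reshaping the result to `(*input.shape[:-1], H)`. The stdlib sketch below shows the same contract on nested lists. `apply_pointwise` and the per-point function `pool_point` are hypothetical stand-ins, not part of the module's API.

```python
from typing import Any, Callable, List


def apply_pointwise(coords: Any, pool_point: Callable[[float, float], List[float]]) -> Any:
    """Map nested lists of shape (*, 2) to shape (*, H).

    Leading dimensions are preserved; each innermost (lon, lat) pair is replaced
    by the H values returned by `pool_point`.
    """
    is_pair = len(coords) == 2 and all(isinstance(v, (int, float)) for v in coords)
    if is_pair:
        return pool_point(coords[0], coords[1])
    # Not yet at a coordinate pair: recurse into the next leading dimension.
    return [apply_pointwise(c, pool_point) for c in coords]
```

For example, a `(2, 2, 2)` input with a `pool_point` that returns three values produces a `(2, 2, 3)` output.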

Note

Input coordinates must lie within the specified exterior geometry (boundaries included). For input coordinates outside of the specified exterior, the module raises an exception.
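A minimal sketch of the inclusive containment check described in this note, for an exterior in (X, Y, W, H) format. The helper name `check_within_exterior` and the choice of `ValueError` are assumptions. The note does not say which exception type the module raises.

```python
from typing import Iterable, Tuple

Rect = Tuple[float, float, float, float]  # (X, Y, W, H), as for `exterior`


def check_within_exterior(points: Iterable[Tuple[float, float]], exterior: Rect) -> None:
    """Raise ValueError if any (x, y) point lies outside the closed exterior rectangle."""
    x0, y0, w, h = exterior
    for px, py in points:
        # Boundaries are inclusive: points on the edge of the exterior are accepted.
        if not (x0 <= px <= x0 + w and y0 <= py <= y0 + h):
            raise ValueError(f"coordinate ({px}, {py}) lies outside exterior {exterior}")
```

For a tensor input, the same test can be vectorized with element-wise comparisons followed by `.all()`.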

Note

A terminal group refers to a collection of terminal nodes within the quadtree that share the same parent tile.
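Suppose tiles are addressed as `(z, x, y)` (depth, column, row) in a `2**z` grid, so that a tile's parent is `(z - 1, x // 2, y // 2)`; this addressing is an assumption. Under it, terminal groups can be formed by bucketing terminal tiles by their parent. The helpers `parent` and `terminal_groups` are illustrative only.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

Tile = Tuple[int, int, int]  # hypothetical (depth z, column x, row y) address


def parent(tile: Tile) -> Optional[Tile]:
    """Parent tile one level up; the root (depth 0) has no parent."""
    z, x, y = tile
    return None if z == 0 else (z - 1, x // 2, y // 2)


def terminal_groups(terminals: Iterable[Tile]) -> Dict[Optional[Tile], List[Tile]]:
    """Partition terminal tiles into groups keyed by their shared parent tile."""
    groups: Dict[Optional[Tile], List[Tile]] = defaultdict(list)
    for tile in terminals:
        groups[parent(tile)].append(tile)
    return dict(groups)
```

A group can have fewer than four members. If one child of a parent was subdivided further, only its three undivided siblings remain terminal and form that parent's group.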

Note

All terminal nodes that intersect the specified polygon boundary are included in the quadtree.
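Deciding whether a tile intersects the polygon is a standard rectangle-polygon intersection test. With shapely, `box(minx, miny, maxx, maxy).intersects(polygon)` from `shapely.geometry` performs it. The dependency-free sketch below assumes a simple polygon without holes, given as a list of vertices and treated as a closed, filled region. Under that assumption a tile that only touches the polygon counts as intersecting, and whether the module treats touching tiles the same way is not stated here. The sketch reports an intersection if a polygon vertex lies in the tile, a tile corner lies inside the polygon, or an edge of one crosses an edge of the other. All function names are hypothetical.

```python
from typing import Sequence, Tuple

Point = Tuple[float, float]
Rect = Tuple[float, float, float, float]  # (X, Y, W, H)


def _orient(a: Point, b: Point, c: Point) -> float:
    """Twice the signed area of triangle abc (positive if counter-clockwise)."""
    return (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])


def _on_segment(a: Point, b: Point, c: Point) -> bool:
    """For c collinear with ab: True if c lies within the segment's bounding box."""
    return min(a[0], b[0]) <= c[0] <= max(a[0], b[0]) and min(a[1], b[1]) <= c[1] <= max(a[1], b[1])


def segments_intersect(p1: Point, p2: Point, q1: Point, q2: Point) -> bool:
    """Closed-segment intersection test, including touching and collinear overlap."""
    d1, d2 = _orient(q1, q2, p1), _orient(q1, q2, p2)
    d3, d4 = _orient(p1, p2, q1), _orient(p1, p2, q2)
    if ((d1 > 0 and d2 < 0) or (d1 < 0 and d2 > 0)) and ((d3 > 0 and d4 < 0) or (d3 < 0 and d4 > 0)):
        return True  # proper crossing
    return (
        (d1 == 0 and _on_segment(q1, q2, p1))
        or (d2 == 0 and _on_segment(q1, q2, p2))
        or (d3 == 0 and _on_segment(p1, p2, q1))
        or (d4 == 0 and _on_segment(p1, p2, q2))
    )


def point_in_polygon(pt: Point, poly: Sequence[Point]) -> bool:
    """Even-odd ray casting; reliable for points strictly inside or outside."""
    x, y = pt
    inside = False
    for i in range(len(poly)):
        (x1, y1), (x2, y2) = poly[i], poly[(i + 1) % len(poly)]
        if (y1 > y) != (y2 > y):  # edge straddles the horizontal ray through pt
            if x < x1 + (y - y1) * (x2 - x1) / (y2 - y1):
                inside = not inside
    return inside


def rect_intersects_polygon(rect: Rect, poly: Sequence[Point]) -> bool:
    x0, y0, w, h = rect
    x1, y1 = x0 + w, y0 + h
    corners = [(x0, y0), (x1, y0), (x1, y1), (x0, y1)]
    # 1. A polygon vertex lies inside (or on) the tile: polygon may be fully inside.
    if any(x0 <= px <= x1 and y0 <= py <= y1 for px, py in poly):
        return True
    # 2. A tile corner lies inside the polygon: tile may be fully inside.
    if any(point_in_polygon(c, poly) for c in corners):
        return True
    # 3. Otherwise the regions meet only if their boundaries cross or touch.
    return any(
        segments_intersect(poly[i], poly[(i + 1) % len(poly)], corners[j], corners[(j + 1) % 4])
        for i in range(len(poly))
        for j in range(4)
    )
```

The boundary ambiguity of ray casting is harmless here: any point on the polygon's boundary is also caught by the edge test in step 3.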

Examples:

>>> import torch
>>> from shapely.geometry import Polygon
>>> from torch_geopooling import nn
>>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
>>> pool = nn.MaxQuadPool2d(3, poly, exterior=(0, 0, 100, 100))
>>> input = torch.rand((2048, 2), dtype=torch.float64)  # points in [0, 1), inside the exterior
>>> output = pool(input)